├── .idea
│   ├── CNN-RNN2016.iml
│   ├── codeStyles
│   │   ├── Project.xml
│   │   └── codeStyleConfig.xml
│   ├── dictionaries
│   │   └── ying.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── dataprocess
│   ├── __init__.py
│   ├── sampler.py
│   ├── seqtransform.py
│   └── video_loader.py
├── dataset
│   ├── __init__.py
│   └── prid2011.py
├── eval
│   ├── __init__.py
│   ├── eva_functions.py
│   └── evaluator.py
├── individualImage.png
├── models
│   ├── __init__.py
│   └── cnnrnn.py
├── prid_data.py
├── splits_prid2011.json
├── train.py
└── utils.py
/.idea/dictionaries/ying.xml:
--------------------------------------------------------------------------------
 1 | <component name="ProjectDictionaryState">
 2 |   <dictionary name="ying">
 3 |     <words>
 4 |       <w>dirname</w>
 5 |       <w>dirnames</w>
 6 |       <w>identites</w>
 7 |       <w>imgs</w>
 8 |       <w>infostruct</w>
 9 |       <w>prid</w>
10 |       <w>tracklet</w>
11 |       <w>tracklets</w>
12 |       <w>ying</w>
13 |     </words>
14 |   </dictionary>
15 | </component>
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CNN-RNN2016
2 | A PyTorch reimplementation of the paper "Recurrent Convolutional Network for Video-based Person Re-Identification".
3 | # Preparation
4 | * Python 3.6
5 | * PyTorch >= 0.4.0
6 | # Results on PRID-2011
7 |
8 | | version | mAP | rank-1 | rank-5 | rank-10 | rank-20 |
9 | | :---: | :---: | :---: | :---: | :---: | :---: |
10 | | this repo | 58.8% | 49.4% | 68.5% | 83.1% | 89.9% |
11 | | original paper | -- | 70% | 90% | 95% | 97% |
12 |
13 |
14 | # Problems
15 | P.1 Using the official split to form the dataset leaves it very small.
16 |
17 | train identities: 89, test identities: 89
18 |
19 | => PRID-2011 loaded
20 |
21 | | subset | # ids | # tracklets |
22 | | :---: | :---: | :---: |
23 | | train | 89| 178 |
24 | | query | 89| 89 |
25 | | gallery | 89| 89 |
26 | | total | 178| 356 |
27 |
28 | P.2 Data Augmentation
29 |
30 | 1. Mirroring is not implemented.
31 | 2. Images are resized from (128, 64) to (256, 128), not to (64, 48) as in the official code.
32 |
33 | P.3 Training
34 |
35 | 1. If the batch size is set to 1, the network does not converge.
36 | 2. The dataset is too small; it could be enlarged by changing how training samples
37 | are generated, e.g. following the paper 'Video Person Re-identification with
38 | Competitive Snippet-similarity Aggregation and Co-attentive Snippet Embedding'.
39 |
40 | # Reference
41 | [Recurrent-Convolutional-Video-ReID](https://github.com/niallmcl/Recurrent-Convolutional-Video-ReID)
42 |
43 | [Spatial-Temporal-Pooling-Networks-ReID](https://github.com/YuanLeung/Spatial-Temporal-Pooling-Networks-ReID)
44 |
--------------------------------------------------------------------------------
/dataprocess/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-20 8:10 PM
5 |
6 |
--------------------------------------------------------------------------------
/dataprocess/sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-25 9:42 AM
5 | import numpy as np
6 | import torch
7 |
8 | from torch.utils.data.sampler import (Sampler, SequentialSampler)
9 | from collections import defaultdict
10 |
11 |
12 | # When forming a positive pair, pick the same person's tracklet from the other camera.
13 | def No_index(a, b):
14 | assert isinstance(a, list)
15 | return [i for i, j in enumerate(a) if j != b]
16 |
17 |
18 | class RandomPairSampler(Sampler):
19 |
20 | def __init__(self, data_source):
21 |         self.data_source = data_source  # each entry of data_source is a tuple (img_paths, pid, cam_id)
22 |         self.index_pid = defaultdict(int)  # index -> pid
23 |         self.pid_cam = defaultdict(list)  # pid -> list of camera ids
24 |         self.pid_index = defaultdict(list)  # pid -> list of dataset indices
25 |         self.num_samples = len(self.data_source)  # dataset length, 178 for the PRID train set
26 |
27 | for index, (_, pid, cam) in enumerate(self.data_source):
28 | self.index_pid[index] = pid
29 | self.pid_cam[pid].append(cam)
30 | self.pid_index[pid].append(index)
31 |
32 | def __len__(self):
33 |         return self.num_samples * 2  # the sampler yields twice the dataset length: one positive and one negative pair per tracklet
34 |
35 |     def __iter__(self):  # yield positive and negative sequence pairs alternately
36 |         indices = torch.randperm(self.num_samples)
37 |         ret = []  # elements are (seqA_index, seqB_index, target)
38 | for i in range(2*self.num_samples):
39 |
40 | if i % 2 == 0: # positive pair
41 | j = i // 2
42 |                 j = int(indices[j])  # the first sequence of the pair
43 |                 _, j_pid, j_cam = self.data_source[j]
44 |                 pid_j = self.index_pid[j]
45 |                 cams = self.pid_cam[pid_j]
46 |                 index = self.pid_index[pid_j]
47 |                 select_cams = No_index(cams, j_cam)  # candidate tracklets of the same pid from the other camera
48 |                 try:
49 |                     select_camind = np.random.choice(select_cams)
50 |                 except ValueError:
51 |                     print(cams, pid_j)
52 |                     raise  # this pid has no tracklet from another camera
53 |                 select_ind = index[select_camind]  # the second sequence of the pair
54 |                 target = [1, pid_j, pid_j]  # [pair label, pid of seq A, pid of seq B]
55 | ret.append((j, select_ind, target))
56 | else: # negative pair
57 |                 p_rand_id = torch.randperm(self.num_samples)
58 |                 a = int(p_rand_id[0])  # pick the first sequence at random
59 |                 pid_a = self.index_pid[a]
60 |
61 |                 b = int(p_rand_id[1])
62 |                 _, b_pid, b_cam = self.data_source[b]
63 |                 pid_b = self.index_pid[b]
64 |                 cams = self.pid_cam[pid_b]
65 |                 index = self.pid_index[pid_b]
66 |                 select_cams = No_index(cams, b_cam)
67 |                 try:
68 |                     select_camind = np.random.choice(select_cams)
69 |                 except ValueError:
70 |                     print(cams, pid_b)
71 |                     raise  # this pid has no tracklet from another camera
72 |                 select_ind = index[select_camind]
73 |
74 |                 target = [-1, pid_a, pid_b]  # -1 marks a negative pair
75 |                 ret.append((a, select_ind, target))
76 | return iter(ret)
77 |
--------------------------------------------------------------------------------
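A minimal sketch of how the sampler behaves on a toy data source (the tuple layout and the `[label, pid_a, pid_b]` target follow the class above; the file names are made up):

```python
from dataprocess.sampler import RandomPairSampler

# toy data source: (img_paths, pid, camid) -- two people, each seen by two cameras
data = [(('a_cam0.png',), 0, 0), (('a_cam1.png',), 0, 1),
        (('b_cam0.png',), 1, 0), (('b_cam1.png',), 1, 1)]
sampler = RandomPairSampler(data)
print(len(sampler))  # 8: one positive and one negative pair per tracklet
for idx_a, idx_b, target in sampler:
    print(idx_a, idx_b, target)  # e.g. 2 3 [1, 1, 1] -> positive: same pid, other camera
```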
/dataprocess/seqtransform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-26 9:53 AM
5 | import torch
6 | import math
7 | import random
8 | from PIL import Image, ImageOps
9 | import numpy as np
10 | import cv2 as cv
11 | from utils import to_numpy
12 |
13 |
14 | class Compose(object):
15 | """Composes several transforms together.
16 |
17 | Args:
18 | transforms (List[Transform]): list of transforms to compose.
19 |
20 | Example:
21 | >>> transforms.Compose([
22 | >>> transforms.CenterCrop(10),
23 | >>> transforms.ToTensor(),
24 | >>> ])
25 | """
26 |
27 | def __init__(self, transforms):
28 | self.transforms = transforms
29 |
30 | def __call__(self, seqs):
31 | for t in self.transforms:
32 | seqs = t(seqs)
33 | return seqs
34 |
35 |
36 | class RectScale(object):
37 | def __init__(self, height, width, interpolation=Image.BILINEAR):
38 | self.height = height
39 | self.width = width
40 | self.interpolation = interpolation
41 |
42 | def __call__(self, seqs): # seqs = list[[image0,...,image8]]
43 |         modallen = len(seqs)  # number of modalities, e.g. 1 (RGB) or 2 (RGB + optical flow)
44 |         framelen = len(seqs[0])  # frames per sequence, e.g. 8
45 |         new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]  # e.g. [[[], [], ..., []]]
46 |
47 |         for modal_ind, modal in enumerate(seqs):  # iterate over modalities
48 |             for frame_ind, frame in enumerate(modal):  # iterate over the frames
49 | w, h = frame.size # w:128,h:256
50 | if h == self.height and w == self.width:
51 | new_seqs[modal_ind][frame_ind] = frame
52 | else:
53 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation)
54 |
55 | return new_seqs
56 |
57 |
58 | class RandomSizedRectCrop(object):
59 | def __init__(self, height, width, interpolation=Image.BILINEAR):
60 | self.height = height
61 | self.width = width
62 | self.interpolation = interpolation
63 |
64 | def __call__(self, seqs):
65 | sample_img = seqs[0][0]
66 | for attempt in range(10):
67 | area = sample_img.size[0] * sample_img.size[1]
68 | target_area = random.uniform(0.64, 1.0) * area
69 | aspect_ratio = random.uniform(2, 3)
70 |
71 | h = int(round(math.sqrt(target_area * aspect_ratio)))
72 | w = int(round(math.sqrt(target_area / aspect_ratio)))
73 |
74 | if w <= sample_img.size[0] and h <= sample_img.size[1]:
75 | x1 = random.randint(0, sample_img.size[0] - w)
76 | y1 = random.randint(0, sample_img.size[1] - h)
77 |
78 | sample_img = sample_img.crop((x1, y1, x1 + w, y1 + h))
79 | assert (sample_img.size == (w, h))
80 | modallen = len(seqs)
81 | framelen = len(seqs[0])
82 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
83 |
84 | for modal_ind, modal in enumerate(seqs):
85 | for frame_ind, frame in enumerate(modal):
86 |
87 | frame = frame.crop((x1, y1, x1 + w, y1 + h))
88 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation)
89 |
90 | return new_seqs
91 |
92 | # Fallback
93 | scale = RectScale(self.height, self.width,
94 | interpolation=self.interpolation)
95 | return scale(seqs)
96 |
97 |
98 | class RandomSizedEarser(object):
99 |
100 |     def __init__(self, sl=0.02, sh=0.2, asratio=0.3, p=0.5):
101 |         self.sl = sl  # lower bound on erased area / frame area
102 |         self.sh = sh  # upper bound on erased area / frame area
103 |         self.asratio = asratio  # lower bound on the erased region's aspect ratio
104 |         self.p = p  # probability of erasing a given frame
105 |
106 | def __call__(self, seqs):
107 | modallen = len(seqs)
108 | framelen = len(seqs[0])
109 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
110 | for modal_ind, modal in enumerate(seqs):
111 | for frame_ind, frame in enumerate(modal):
112 | p1 = random.uniform(0.0, 1.0)
113 | W = frame.size[0]
114 | H = frame.size[1]
115 | area = H * W
116 |
117 | if p1 > self.p:
118 | new_seqs[modal_ind][frame_ind] = frame
119 | else:
120 | gen = True
121 | while gen:
122 | Se = random.uniform(self.sl, self.sh) * area
123 | re = random.uniform(self.asratio, 1 / self.asratio)
124 | He = np.sqrt(Se * re)
125 | We = np.sqrt(Se / re)
126 | xe = random.uniform(0, W - We)
127 | ye = random.uniform(0, H - He)
128 | if xe + We <= W and ye + He <= H and xe > 0 and ye > 0:
129 | x1 = int(np.ceil(xe))
130 | y1 = int(np.ceil(ye))
131 | x2 = int(np.floor(x1 + We))
132 | y2 = int(np.floor(y1 + He))
133 | part1 = frame.crop((x1, y1, x2, y2))
134 | Rc = random.randint(0, 255)
135 | Gc = random.randint(0, 255)
136 | Bc = random.randint(0, 255)
137 | I = Image.new('RGB', part1.size, (Rc, Gc, Bc))
138 |                         frame.paste(I, (x1, y1))  # paste the random-colour patch over the erased region
139 | break
140 |
141 | new_seqs[modal_ind][frame_ind] = frame
142 |
143 | return new_seqs
144 |
145 |
146 | class RandomHorizontalFlip(object):
147 | """Randomly horizontally flips the given PIL.Image Sequence with a probability of 0.5
148 | """
149 | def __call__(self, seqs):
150 | if random.random() < 0.5:
151 | modallen = len(seqs)
152 | framelen = len(seqs[0])
153 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
154 | for modal_ind, modal in enumerate(seqs):
155 | for frame_ind, frame in enumerate(modal):
156 | new_seqs[modal_ind][frame_ind] = frame.transpose(Image.FLIP_LEFT_RIGHT)
157 | return new_seqs
158 | return seqs
159 |
160 |
161 | class ToTensor(object):
162 |
163 | def __call__(self, seqs):
164 | modallen = len(seqs)
165 | framelen = len(seqs[0])
166 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
167 | pic = seqs[0][0]
168 |
169 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
170 | if pic.mode == 'YCbCr':
171 | nchannel = 3
172 | elif pic.mode == 'I;16':
173 | nchannel = 1
174 | else:
175 | nchannel = len(pic.mode)
176 |
177 | if pic.mode == 'I':
178 | for modal_ind, modal in enumerate(seqs):
179 | for frame_ind, frame in enumerate(modal):
180 | img = torch.from_numpy(np.array(frame, np.int32, copy=False))
181 | img = img.view(pic.size[1], pic.size[0], nchannel)
182 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous()
183 |
184 | elif pic.mode == 'I;16':
185 | for modal_ind, modal in enumerate(seqs):
186 | for frame_ind, frame in enumerate(modal):
187 | img = torch.from_numpy(np.array(frame, np.int16, copy=False))
188 | img = img.view(pic.size[1], pic.size[0], nchannel)
189 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous()
190 | else:
191 | for modal_ind, modal in enumerate(seqs):
192 | for frame_ind, frame in enumerate(modal):
193 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(frame.tobytes()))
194 | img = img.view(pic.size[1], pic.size[0], nchannel) # torch.Size([128, 64, 3])
195 | img = img.transpose(0, 1).transpose(0, 2).contiguous() # torch.Size([3, 128, 64])
196 | new_seqs[modal_ind][frame_ind] = img.float().div(255)
197 |
198 | return new_seqs
199 |
200 |
201 | class Normalize(object):
202 | """Given mean: (R, G, B) and std: (R, G, B),
203 | will normalize each channel of the torch.*Tensor, i.e.
204 | channel = (channel - mean) / std
205 | """
206 | def __init__(self, mean, std):
207 | self.mean = mean
208 | self.std = std
209 |
210 | def __call__(self, seqs):
211 | # TODO: make efficient
212 | modallen = len(seqs)
213 | framelen = len(seqs[0])
214 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
215 |
216 | for modal_ind, modal in enumerate(seqs):
217 | for frame_ind, frame in enumerate(modal):
218 | for t, m, s in zip(frame, self.mean, self.std):
219 | t.sub_(m).div_(s)
220 | new_seqs[modal_ind][frame_ind] = frame
221 |
222 | return new_seqs
223 |
224 |
225 | class ToYUV(object):
226 |
227 | def __call__(self, seqs):
228 |         modallen = len(seqs)  # seqs = [rgb_frames, flow_frames]
229 |         framelen = len(seqs[0])
230 |         imagePixelData = torch.zeros((framelen, 5, 256, 128))  # (frames, 3 YUV + 2 flow channels, H, W)
231 |
232 |         for i in range(framelen):
233 |             fileRGB = seqs[0][i]
234 |             fileOF = seqs[1][i]
235 |             img = torch.ByteTensor(torch.ByteStorage.from_buffer(fileRGB.tobytes()))  # flat (h, w, c) buffer
236 |             img = img.view(fileRGB.size[1], fileRGB.size[0], 3)  # (H, W, 3)
237 |             # img = img.transpose(0, 1).transpose(0, 2).contiguous()  # would give (3, H, W)
238 |             img = to_numpy(img.float())  # (H, W, 3) float array
239 |             # img = cv.cvtColor(img, cv.COLOR_RGB2BGR)
240 |             img = cv.cvtColor(img, cv.COLOR_BGR2YUV)
241 |
242 |             imgof = torch.ByteTensor(torch.ByteStorage.from_buffer(fileOF.tobytes()))
243 |             imgof = imgof.view(fileOF.size[1], fileOF.size[0], 3)  # (H, W, 3)
244 |             # imgof = imgof.transpose(0, 1).transpose(0, 2).contiguous()  # would give (3, H, W)
245 |             imgof = to_numpy(imgof.float())
246 |             img_tensor = torch.from_numpy(img)  # torch.Size([256, 128, 3])
247 |             imgof_tensor = torch.from_numpy(imgof)  # torch.Size([256, 128, 3])
248 | for c in range(3):
249 |                 v = torch.var(img_tensor[:, :, c])  # channel variance; its sqrt (the std) is divided out below
250 | m = torch.mean(img_tensor[:, :, c])
251 | img_tensor[:, :, c] = img_tensor[:, :, c] - m
252 | img_tensor[:, :, c] = img_tensor[:, :, c] / torch.sqrt(v)
253 | imagePixelData[i, c] = img_tensor[:, :, c]
254 |
255 | for j in range(2):
256 | c = j + 1
257 |                 v = torch.var(imgof_tensor[:, :, c])
258 | m = torch.mean(imgof_tensor[:, :, c])
259 | imgof_tensor[:, :, c] = imgof_tensor[:, :, c] - m
260 | imgof_tensor[:, :, c] = imgof_tensor[:, :, c] / torch.sqrt(v)
261 | imagePixelData[i, j + 3] = imgof_tensor[:, :, c]
262 |
263 | # for c in range(2):
264 | #
265 | # v = torch.sqrt(torch.var(imgof_tensor[:, :, c]))
266 | # m = torch.mean(imgof_tensor[:, :, c])
267 | # imgof_tensor[:, :, c] = imgof_tensor[:, :, c] - m
268 | # imgof_tensor[:, :, c] = imgof_tensor[:, :, c] / torch.sqrt(v)
269 | # imagePixelData[i, c + 3] = imgof_tensor[:, :, c]
270 |
271 | return imagePixelData
272 |
--------------------------------------------------------------------------------
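A minimal end-to-end sketch of the pipeline above on synthetic frames. Noise images are used rather than solid colours because `ToYUV` divides by each channel's standard deviation and would produce NaNs on zero-variance input; the frame count of 4 is arbitrary, and `utils.to_numpy` is assumed to return a NumPy array, as the import above implies:

```python
import numpy as np
from PIL import Image
from dataprocess import seqtransform as T

# four fake RGB frames and four stand-in optical-flow frames at PRID size (64x128)
rgb = [Image.fromarray(np.random.randint(0, 256, (128, 64, 3), dtype=np.uint8)) for _ in range(4)]
flow = [Image.fromarray(np.random.randint(0, 256, (128, 64, 3), dtype=np.uint8)) for _ in range(4)]

pipe = T.Compose([T.RectScale(256, 128), T.RandomHorizontalFlip(), T.ToYUV()])
out = pipe([rgb, flow])
print(out.shape)  # torch.Size([4, 5, 256, 128]): 3 YUV channels + 2 flow channels per frame
```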
/dataprocess/video_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-20 8:38 PM
5 | """Step 4: batch construction."""
6 | from PIL import Image
7 | import numpy as np
8 | import random
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | from utils import to_torch
13 | import os.path as osp
14 | import cv2 as cv
15 |
16 |
17 | class VideoDataset(Dataset):
18 | """形成一个批次的数据 (批次大小,序列长度,channel, height,width)"""
19 | sample_methods = ['random', 'dense']
20 |
21 | def __init__(self, dataset, seq_len=16, sample='random', transform=None):
22 |         self.dataset = dataset  # list of (img_paths, pid, camid) tuples
23 |         self.seq_len = seq_len  # frames per sampled sequence
24 |         self.sample = sample  # 'random': one seq_len clip; 'dense': all frames as consecutive clips
25 |         self.transform = transform  # sequence-level data augmentation
26 | self.useOpticalFlow = 1
27 |
28 | def __len__(self): # 获得数据集的长度
29 | return len(self.dataset)
30 |
31 |     def getOfPath(self, rgb_path):  # map an RGB frame path to its optical-flow counterpart
32 |         root_of = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid2011flow/prid2011flow'
33 |         fname_list = rgb_path.split('/')
34 |         of_path = osp.join(root_of, fname_list[-3], fname_list[-2], fname_list[-1])  # same cam/person/frame layout under the flow root
35 |
36 | return of_path
37 |
38 |     def __getitem__(self, item):  # in 'random' mode item = (index0, index1, target), e.g. (103, 102, [1, 51, 51])
39 | if self.sample == 'random':
40 | """
41 | Randomly sample seq_len consecutive frames from num frames,
42 | if num is smaller than seq_len, then replicate items.
43 | This sampling strategy is used in training phase.
44 | """
45 | item0, item1, target = item
46 | img0_paths, pid0, camid0 = self.dataset[item0]
47 |
48 |             img1_paths, pid1, camid1 = self.dataset[item1]  # look up both tracklets
49 |             num0 = len(img0_paths)  # frames in tracklet 0, e.g. 73
50 |             num1 = len(img1_paths)  # e.g. 119
51 |             # seq0
52 |             frame_indices0 = list(range(num0))  # frame indices [0, 1, ..., num0-1]
53 |             rand_end0 = max(0, len(frame_indices0) - self.seq_len - 1)  # largest valid clip start
54 |             begin_index0 = random.randint(0, rand_end0)  # random clip start
55 |             end_index0 = min(begin_index0 + self.seq_len, len(frame_indices0))  # clip end
56 |             indices0 = frame_indices0[begin_index0:end_index0]  # the sampled clip indices
57 |
58 |             for index0 in indices0:  # if the clip is shorter than seq_len,
59 |                 if len(indices0) >= self.seq_len:
60 |                     break
61 |                 indices0.append(index0)  # replicate frames until it reaches seq_len
62 |
63 |             indices0 = np.array(indices0)
64 |             imgseq0 = []  # RGB frames of the clip
65 |             flowseq0 = []  # optical-flow frames of the clip
66 |
67 |             for index0 in indices0:
68 |                 index = int(index0)
69 |                 img_paths0 = img0_paths[index]  # absolute path of the RGB frame
70 |                 of_paths0 = self.getOfPath(img_paths0)  # matching optical-flow frame
71 |                 imgrgb0 = Image.open(img_paths0).convert('RGB')
72 |                 ofrgb0 = Image.open(of_paths0).convert('RGB')
73 |                 imgseq0.append(imgrgb0)
74 |                 flowseq0.append(ofrgb0)
75 |             seq0 = [imgseq0, flowseq0]  # [rgb_frames, flow_frames]
76 |             if self.transform is not None:  # sequence-level augmentation; ToYUV also folds in the flow channels
77 |                 seq0 = self.transform(seq0)
78 |             img_tensor0 = seq0  # torch.Size([16, 5, 256, 128])
79 |
80 |             # seq1: same procedure as seq0
81 | frame_indices1 = list(range(num1))
82 | rand_end1 = max(0, len(frame_indices1) - self.seq_len - 1)
83 | begin_index1 = random.randint(0, rand_end1)
84 | end_index1 = min(begin_index1 + self.seq_len, len(frame_indices1))
85 | indices1 = frame_indices1[begin_index1:end_index1]
86 |
87 | for index1 in indices1:
88 | if len(indices1) >= self.seq_len:
89 | break
90 | indices1.append(index1)
91 |
92 | indices1 = np.array(indices1)
93 | imgseq1 = []
94 | flowseq1 = []
95 |
96 | for index1 in indices1:
97 | index = int(index1)
98 | img_paths1 = img1_paths[index]
99 | of_paths1 = self.getOfPath(img_paths1)
100 |                 imgrgb1 = Image.open(img_paths1).convert('RGB')
101 |                 ofrgb1 = Image.open(of_paths1).convert('RGB')
102 | imgseq1.append(imgrgb1)
103 | flowseq1.append(ofrgb1)
104 | seq1 = [imgseq1, flowseq1]
105 | if self.transform is not None:
106 | seq1 = self.transform(seq1)
107 | img_tensor1 = seq1
108 |
109 | return img_tensor0, img_tensor1, target
110 | elif self.sample == 'dense':
111 | """
112 | Sample all frames in a video into a list of clips,
113 | each clip contains seq_len frames, batch_size needs to be set to 1.
114 | This sampling strategy is used in test phase.
115 | """
116 | img_paths, pid, camid = self.dataset[item]
117 |             num = len(img_paths)  # e.g. 27 frames
118 |             cur_index = 0  # dense sampling starts at frame 0
119 |             frame_indices = list(range(num))
120 |             indices_list = []
121 |
122 |             while num - cur_index > self.seq_len:  # slice off full seq_len clips while more than seq_len frames remain
123 |                 indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
124 |                 cur_index += self.seq_len
125 |             last_seq = frame_indices[cur_index:]  # the remaining, possibly shorter, clip
126 |             for index in last_seq:
127 |                 if len(last_seq) >= self.seq_len:
128 |                     break
129 |                 last_seq.append(index)  # replicate frames to pad it to seq_len
130 |             indices_list.append(last_seq)
131 |
132 | imgs_list = []
133 |             for indices in indices_list:  # iterate over clips
134 |                 imgseq = []  # frames of one clip
135 |                 for index in indices:  # iterate over the frames of the clip
136 | index = int(index)
137 | img_path = img_paths[index]
138 | img = Image.open(img_path).convert('RGB')
139 | if self.transform is not None:
140 | img = self.transform(img)
141 |                     img = img.unsqueeze(0)  # [1, C, H, W]
142 |                     imgseq.append(img)
143 |                 imgseq = to_torch(imgseq)
144 |                 imgseq = torch.cat(imgseq, dim=0)  # (seq_len, C, H, W): one clip's images
145 |                 imgs_list.append(imgseq)
146 |             imgs_list = tuple(imgs_list)
147 |             imgs_array = torch.stack(imgs_list)  # all dense clips of this tracklet
148 |
149 | return imgs_array, pid, camid
150 | else:
151 | raise KeyError("unknown sample method: {}.".format(self.sample))
152 |
--------------------------------------------------------------------------------
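A sketch of how this dataset is wired to the pair sampler and a loader; it mirrors the wiring in prid_data.py below and assumes `dataset` and `transform_train` are built as in that file:

```python
from torch.utils.data import DataLoader
from dataprocess.sampler import RandomPairSampler
from dataprocess.video_loader import VideoDataset

train_loader = DataLoader(
    VideoDataset(dataset.train, seq_len=16, sample='random', transform=transform_train),
    batch_size=8, sampler=RandomPairSampler(dataset.train), drop_last=True)

for seq0, seq1, target in train_loader:
    # two paired clips plus the collated target [pair_label, pid_a, pid_b]
    print(seq0.shape)  # torch.Size([8, 16, 5, 256, 128])
    break
```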
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-18 11:01 AM
5 |
6 |
7 | from .prid2011 import PRID
8 |
9 |
10 | def get_sequence(name, *args, **kwargs):
11 | __factory = {
12 | 'prid2011': PRID,
13 | }
14 |
15 | if name not in __factory:
16 | raise KeyError("Unknown dataset", name)
17 | return __factory[name](*args, **kwargs)
18 |
--------------------------------------------------------------------------------
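A short usage sketch of the factory; the statistics table is printed by the PRID constructor in the next file, and index 0 simply picks the first training tracklet:

```python
from dataset import get_sequence

dataset = get_sequence('prid2011', 0)  # split_id=0; prints the dataset statistics
img_paths, pid, camid = dataset.train[0]
print(len(img_paths), pid, camid)  # frames in the tracklet, person id, camera id (0 = cam_a)
```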
/dataset/prid2011.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-18 9:28 PM
5 | import glob
6 | import os.path as osp
7 | import numpy as np
8 | from utils import read_json
9 |
10 | """第三步 数据集形成"""
11 |
12 |
13 | # simple namespace for dataset metadata (pid / camid / tracklet-length lists)
14 | class infostruct(object):
15 | pass
16 |
17 |
18 | # 1. Dataset wrapper: overall structure of PRID-2011
19 | class PRID(object):
20 | """
21 | code mainly from https://github.com/KaiyangZhou/deep-person-reid
22 | """
23 | root = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid_2011'
24 | split_path = osp.join(root, 'splits_prid2011.json')
25 | cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
26 | cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')
27 |
28 | def __init__(self, split_id=0, min_seq_len=0):
29 |         self._check_before_run()  # 1. make sure the dataset directory exists
30 |
31 |         splits = read_json(self.split_path)  # 2. read the train/test split definitions
32 |         if split_id >= len(splits):  # the requested split_id must lie inside the split list
33 |             raise ValueError("split id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits) - 1))
34 |         split = splits[split_id]  # pick the split with index split_id
35 |         train_split, test_split = split['train'], split['test']
36 |         print("# train identities: {}, # test identities: {}".format(len(train_split), len(test_split)))
37 |
38 |         # 3. Build the subsets from the split: returns the tracklets plus tracklet/pid/image counts
39 | train, num_train_tracklets, num_train_pids, num_imgs_train = \
40 | self._process_data(train_split, cam1=True, cam2=True)
41 | query, num_query_tracklets, num_query_pids, num_imgs_query, query_pid, query_camid = \
42 | self._process_data2(test_split, cam1=True, cam2=False)
43 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery, gallery_pid, gallery_camid = \
44 | self._process_data2(test_split, cam1=False, cam2=True)
45 |
46 |         # 4. Per-tracklet length statistics
47 |         num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery  # concatenated lists
48 | min_num = np.min(num_imgs_per_tracklet)
49 | max_num = np.max(num_imgs_per_tracklet)
50 | avg_num = np.mean(num_imgs_per_tracklet)
51 |
52 |         # 5. Identity statistics
53 |         num_total_pids = num_train_pids + num_query_pids + num_gallery_pids  # scalar totals
54 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
55 |
56 |         # 6. Expose the subsets as attributes
57 | self.train = train
58 | self.query = query
59 | self.gallery = gallery
60 | self.num_train_pids = num_train_pids
61 | self.num_query_pids = num_query_pids
62 | self.num_gallery_pids = num_gallery_pids
63 |
64 | self.queryinfo = infostruct()
65 | self.queryinfo.pid = query_pid
66 | self.queryinfo.camid = query_camid
67 | self.queryinfo.tranum = num_imgs_query
68 |
69 | self.galleryinfo = infostruct()
70 | self.galleryinfo.pid = gallery_pid
71 | self.galleryinfo.camid = gallery_camid
72 | self.galleryinfo.tranum = num_imgs_gallery
73 |
74 |         # 7. Print basic dataset statistics
75 | print("=> PRID-2011 loaded")
76 | print("Dataset statistics:")
77 | print(" ------------------------------")
78 | print(" subset | # ids | # tracklets")
79 | print(" ------------------------------")
80 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
81 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
82 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
83 | print(" ------------------------------")
84 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
85 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
86 | print(" ------------------------------")
87 |
88 | # check if all files are available
89 | def _check_before_run(self):
90 | if not osp.exists(self.root):
91 | raise RuntimeError("'{}' is not available".format(self.root))
92 |
93 |     # Build the tracklet list of one subset according to the split
94 |     def _process_data(self, dirnames, cam1=True, cam2=True):
95 |         tracklets = []  # one (img_paths, pid, camid) tuple per tracklet
96 |         num_imgs_per_tracklet = []  # image count of each tracklet
97 |         dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}  # map folder names to consecutive pids
98 |
99 |         for dirname in dirnames:
100 |             if cam1:  # data from camera a
101 |                 person_dir = osp.join(self.cam_a_path, dirname)  # this person's folder
102 |                 img_names = glob.glob(osp.join(person_dir, '*.png'))  # all frame paths in the folder
103 |                 assert len(img_names) > 0
104 |                 img_names = tuple(img_names)
105 |                 pid = dirname2pid[dirname]
106 |                 tracklets.append((img_names, pid, 0))
107 |                 num_imgs_per_tracklet.append(len(img_names))
108 |             if cam2:  # data from camera b
109 |                 person_dir = osp.join(self.cam_b_path, dirname)
110 |                 img_names = glob.glob(osp.join(person_dir, '*.png'))
111 |                 assert len(img_names) > 0
112 |                 img_names = tuple(img_names)
113 |                 pid = dirname2pid[dirname]
114 |                 tracklets.append((img_names, pid, 1))
115 |                 num_imgs_per_tracklet.append(len(img_names))
116 |
117 |         num_tracklets = len(tracklets)  # number of tracklets in this subset
118 |         num_pid = len(dirnames)  # number of identities
119 |
120 | return tracklets, num_tracklets, num_pid, num_imgs_per_tracklet
121 |
122 | def _process_data2(self, dirnames, cam1=True, cam2=True):
123 |         tracklets = []  # one (img_paths, pid, camid) tuple per tracklet
124 |         num_imgs_per_tracklet = []  # image count of each tracklet
125 |         dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}  # map folder names to consecutive pids
126 |         pid_list = []
127 |         camid_list = []
128 |
129 |         for dirname in dirnames:
130 |             if cam1:  # data from camera a
131 |                 person_dir = osp.join(self.cam_a_path, dirname)
132 |                 img_names = glob.glob(osp.join(person_dir, '*.png'))
133 |                 assert len(img_names) > 0
134 |                 img_names = tuple(img_names)
135 |                 pid = dirname2pid[dirname]
136 |                 tracklets.append((img_names, pid, 0))
137 |                 pid_list.append(pid)
138 |                 camid_list.append(0)
139 |                 num_imgs_per_tracklet.append(len(img_names))
140 |             if cam2:  # data from camera b
141 |                 person_dir = osp.join(self.cam_b_path, dirname)
142 |                 img_names = glob.glob(osp.join(person_dir, '*.png'))
143 |                 assert len(img_names) > 0
144 |                 img_names = tuple(img_names)
145 |                 pid = dirname2pid[dirname]
146 |                 tracklets.append((img_names, pid, 1))
147 |                 pid_list.append(pid)
148 |                 camid_list.append(1)
149 |                 num_imgs_per_tracklet.append(len(img_names))
150 |
151 |         num_tracklets = len(tracklets)  # number of tracklets in this subset
152 |         num_pid = len(dirnames)  # number of identities
153 |
154 | return tracklets, num_tracklets, num_pid, num_imgs_per_tracklet, pid_list, camid_list
155 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-24 9:20 PM
5 | from .eva_functions import evaluate
6 | from .evaluator import Evaluator
7 |
--------------------------------------------------------------------------------
/eval/eva_functions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-24 9:20 PM
5 | import numpy as np
6 | from utils import to_torch
7 |
8 |
9 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
10 | num_q, num_g = distmat.shape
11 | if num_g < max_rank:
12 | max_rank = num_g
13 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 | indices = np.argsort(distmat, axis=1)
15 |
16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 |
18 | # compute cmc curve for each query
19 | all_cmc = []
20 | all_AP = []
21 | num_valid_q = 0.
22 | for q_idx in range(num_q):
23 | # get query pid and camid
24 | q_pid = q_pids[q_idx]
25 | q_camid = q_camids[q_idx]
26 | # remove gallery samples that have the same pid and camid with query
27 | order = indices[q_idx]
28 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
29 | keep = np.invert(remove)
30 |
31 | # compute cmc curve
32 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
33 |
34 | if not np.any(orig_cmc):
35 | # this condition is true when query identity does not appear in gallery
36 | continue
37 |
38 | cmc = orig_cmc.cumsum()
39 | cmc[cmc > 1] = 1
40 | all_cmc.append(cmc[:max_rank])
41 | num_valid_q += 1.
42 |
43 | # compute average precision
44 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
45 | num_rel = orig_cmc.sum()
46 | tmp_cmc = orig_cmc.cumsum()
47 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
48 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
49 | AP = tmp_cmc.sum() / num_rel
50 |
51 | all_AP.append(AP)
52 |
53 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
54 |
55 | all_cmc = np.asarray(all_cmc).astype(np.float32)
56 | all_cmc = all_cmc.sum(0) / num_valid_q
57 | mAP = np.mean(all_AP)
58 |
59 | return all_cmc, mAP
60 |
61 |
62 | def accuracy(output, target, topk=(1,)):
63 | output, target = to_torch(output), to_torch(target)
64 | maxk = max(topk)
65 | batch_size = target.size(0)
66 |
67 | _, pred = output.topk(maxk, 1, True, True)
68 | pred = pred.t()
69 | correct = pred.eq(target.view(1, -1).expand_as(pred))
70 |
71 | ret = []
72 | for k in topk:
73 | correct_k = correct[:k].view(-1).float().sum(0)
74 | ret.append(correct_k.mul_(1. / batch_size))
75 | return ret
76 |
--------------------------------------------------------------------------------
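A toy sanity check of `evaluate` on a hypothetical 2x3 distance matrix; queries come from camera 0 and gallery from camera 1, so the same-pid/same-cam filter removes nothing:

```python
import numpy as np
from eval.eva_functions import evaluate

distmat = np.array([[0.1, 0.9, 0.5],   # query 0 is closest to its true match (pid 0)
                    [0.8, 0.2, 0.4]])  # query 1 is closest to its true match (pid 1)
q_pids, g_pids = np.array([0, 1]), np.array([0, 1, 2])
q_camids, g_camids = np.array([0, 0]), np.array([1, 1, 1])

cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=3)
print(cmc[0], mAP)  # 1.0 1.0 -- both queries rank their match first
```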
/eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import to_torch
3 | from .eva_functions import evaluate
4 | import numpy as np
5 | from torch import nn
6 |
7 |
8 | def evaluate_seq(distmat, query_pids, query_camids, gallery_pids, gallery_camids, cmc_topk=[1, 5, 10, 20]):
9 | query_ids = np.array(query_pids)
10 | gallery_ids = np.array(gallery_pids)
11 | query_cams = np.array(query_camids)
12 | gallery_cams = np.array(gallery_camids)
13 |
14 | cmc_scores, mAP = evaluate(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
15 | print('Mean AP: {:4.1%}'.format(mAP))
16 |
17 | for r in cmc_topk:
18 | print("Rank-{:<3}: {:.1%}".format(r, cmc_scores[r-1]))
19 | print("------------------")
20 |
21 | # Use the allshots cmc top-1 score for validation criterion
22 | return cmc_scores[0]
23 |
24 |
25 | def pairwise_distance_tensor(query_x, gallery_x):
26 |
27 | m, n = query_x.size(0), gallery_x.size(0)
28 | x = query_x.view(m, -1)
29 | y = gallery_x.view(n, -1)
30 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +\
31 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
32 | dist.addmm_(1, -2, x, y.t())
33 |
34 | return dist
35 |
36 |
37 | class Evaluator(object):
38 |
39 | def __init__(self, cnn_model):
40 | super(Evaluator, self).__init__()
41 | self.cnn_model = cnn_model
42 | self.softmax = nn.Softmax(dim=-1)
43 |
44 | def extract_feature(self, data_loader): # 2
45 | # print_freq = 50
46 | self.cnn_model.eval()
47 |
48 | qf = []
49 | # qf_raw = []
50 |
51 | for i, inputs in enumerate(data_loader):
52 | imgs, _, _ = inputs
53 | b, n, s, c, h, w = imgs.size()
54 | imgs = imgs.view(b*n, s, c, h, w)
55 | imgs = to_torch(imgs) # torch.Size([8, 8, 3, 256, 128])
56 | # flows = to_torch(flows) # torch.Size([8, 8, 3, 256, 128])
57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58 | imgs = imgs.to(device)
59 | # flows = flows.to(device)
60 | with torch.no_grad():
61 | out_feat, out_raw = self.cnn_model(imgs)
62 | allfeatures = out_feat.view(n, -1) # torch.Size([8, 128])
63 | # allfeatures_raw = out_raw.view(n, -1) # torch.Size([8, 128])
64 |                 allfeatures = torch.mean(allfeatures, 0).data.cpu()  # average-pool one tracklet's clip features
65 | # allfeatures_raw = torch.mean(allfeatures_raw, 0).data
66 | qf.append(allfeatures)
67 | # qf_raw.append(allfeatures_raw)
68 | qf = torch.stack(qf)
69 | # qf_raw = torch.stack(allfeatures_raw)
70 |
71 | print("Extracted features for query/gallery set, obtained {}-by-{} matrix"
72 | .format(qf.size(0), qf.size(1)))
73 | return qf
74 |
75 | def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
76 | # 1
77 | self.cnn_model.eval()
78 |
79 |         querypid = queryinfo.pid  # query person ids, e.g. [74, 20, 90, 151, 1, ...]
80 |         querycamid = queryinfo.camid  # all 0 (camera a)
81 |
82 |         gallerypid = galleryinfo.pid  # gallery person ids
83 |         gallerycamid = galleryinfo.camid  # all 1 (camera b)
84 |
85 |         pooled_probe = self.extract_feature(query_loader)  # (num_query, num_features)
86 | pooled_gallery = self.extract_feature(gallery_loader)
87 | print("Computing distance matrix")
88 | distmat = pairwise_distance_tensor(pooled_probe, pooled_gallery)
89 |
90 | return evaluate_seq(distmat, querypid, querycamid, gallerypid, gallerycamid)
91 |
--------------------------------------------------------------------------------
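`pairwise_distance_tensor` expands the squared Euclidean distance as ||x||^2 + ||y||^2 - 2xy^T (the in-place `addmm_` contributes the -2xy^T term); a quick broadcast-based check:

```python
import torch
from eval.evaluator import pairwise_distance_tensor

q, g = torch.randn(4, 128), torch.randn(6, 128)
d = pairwise_distance_tensor(q, g)                      # (4, 6) squared distances
ref = ((q.unsqueeze(1) - g.unsqueeze(0)) ** 2).sum(-1)  # reference by broadcasting
print(torch.allclose(d, ref, atol=1e-4))                # True, up to float error
```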
/individualImage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CNN-RNN2016/875a54a7bfa91a05d737e8155258079d5b26a701/individualImage.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-18 11:03 AM
5 | from .cnnrnn import Net
6 |
7 |
8 | __factory = {
9 | 'cnn-rnn': Net,
10 | }
11 |
12 |
13 | def names():
14 | return sorted(__factory.keys())
15 |
16 |
17 | def create(name, *args, **kwargs):
18 | if name not in __factory:
19 | raise KeyError("unknown model:", name)
20 | return __factory[name](*args, **kwargs)
21 |
--------------------------------------------------------------------------------
/models/cnnrnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-21 7:29 PM
5 | """CNN-RNN architecture: a three-layer CNN followed by an RNN.
6 | 1. Layer 1: 16 conv filters of size 5x5; 2x2 max pooling with stride 2; tanh activation
7 | 2. Layer 2: 64 conv filters of size 5x5; 2x2 max pooling with stride 2; tanh activation
8 | 3. Layer 3: 64 conv filters of size 5x5; 2x2 max pooling with stride 2; tanh activation
9 | 4. Dropout with probability 0.5
10 | 5. Fully connected layer with 128 outputs
11 |
12 | """
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from torch.nn import functional as F
17 |
18 |
19 | class Net(nn.Module):
20 | def __init__(self, nFilter1, nFilter2, nFilter3, num_person_train, dropout=0.5, num_features=0, seq_len=0, batch=0):
21 | super(Net, self).__init__()
22 | self.batch = batch
23 | self.seq_len = seq_len
24 | self.num_person_train = num_person_train
25 |         self.dropout = dropout  # dropout probability in [0, 1]
26 |         self.num_features = num_features  # output feature dimension, 128
27 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28 |
29 |         self.nFilters = [nFilter1, nFilter2, nFilter3]  # number of filters in each conv layer
30 |         self.filter_size = [5, 5, 5]  # conv kernel sizes
31 |
32 |         self.poolsize = [2, 2, 2]  # max-pooling window
33 |         self.stepsize = [2, 2, 2]  # pooling stride
34 |         self.padDim = 4  # zero padding
35 |         self.input_channel = 5  # 3 image channels (YUV) + 2 optical-flow channels
36 |
37 |         # conv layers: nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
38 |         self.conv1 = nn.Conv2d(self.input_channel, self.nFilters[0], self.filter_size[0], stride=1, padding=self.padDim)
39 |         self.conv2 = nn.Conv2d(self.nFilters[0], self.nFilters[1], self.filter_size[1], stride=1, padding=self.padDim)
40 |         self.conv3 = nn.Conv2d(self.nFilters[1], self.nFilters[2], self.filter_size[2], stride=1, padding=self.padDim)
41 |
42 |         # max-pooling layers
43 |         self.pooling1 = nn.MaxPool2d(self.poolsize[0], self.stepsize[0])
44 |         self.pooling2 = nn.MaxPool2d(self.poolsize[1], self.stepsize[1])
45 |         self.pooling3 = nn.MaxPool2d(self.poolsize[2], self.stepsize[2])
46 |
47 |         # tanh activation
48 |         self.tanh = nn.Tanh()
49 |
50 |         # fully connected layer
51 |         n_fully_connected = 21280  # = 32 * 35 * 19, the flattened conv output for 256x128 inputs; adjust if the image size changes
52 |
53 | self.seq2 = nn.Sequential(
54 | nn.Dropout(self.dropout),
55 | nn.Linear(n_fully_connected, self.num_features)
56 | )
57 |
58 |         # RNN layer
59 | self.rnn = nn.RNN(self.num_features, self.num_features)
60 | self.hid_weight = nn.Parameter(
61 | nn.init.xavier_uniform_(
62 | torch.Tensor(1, self.seq_len * self.batch, self.num_features).to(self.device), gain=np.sqrt(2.0)
63 | ), requires_grad=True,
64 | )
65 |
66 |         # final fully connected (classification) layer
67 | self.final_FC = nn.Linear(self.num_features, self.num_person_train)
68 |
69 | def build_net(self, input1, input2):
70 | seq1 = nn.Sequential(
71 | self.conv1, self.tanh, self.pooling1,
72 | self.conv2, self.tanh, self.pooling2,
73 | self.conv3, self.tanh, self.pooling3,
74 | )
75 |         b = input1.size(0)  # batch size
76 |         n = input1.size(1)  # frames per sequence
77 |         input1 = input1.view(b*n, input1.size(2), input1.size(3), input1.size(4))
78 |         input2 = input2.view(b*n, input2.size(2), input2.size(3), input2.size(4))
79 |         inp1_seq1_out = seq1(input1).view(input1.size(0), -1)  # conv output flattened from (b*n, 32, 35, 19)
80 |         inp2_seq1_out = seq1(input2).view(input2.size(0), -1)
81 |         inp1_seq2_out = self.seq2(inp1_seq1_out).unsqueeze_(0)
82 |         inp2_seq2_out = self.seq2(inp2_seq1_out).unsqueeze_(0)  # FC output, shaped (1, b*n, num_features) for the RNN
83 |
84 | inp1_rnn_out, hn1 = self.rnn(inp1_seq2_out, self.hid_weight)
85 |         inp2_rnn_out, hn2 = self.rnn(inp2_seq2_out, self.hid_weight)  # TODO: should debug here
86 |         inp1_rnn_out = inp1_rnn_out.view(b, n, -1)  # e.g. torch.Size([8, 16, 128])
87 |         inp2_rnn_out = inp2_rnn_out.view(b, n, -1)
88 |         inp1_rnn_out = inp1_rnn_out.permute(0, 2, 1)
89 |         inp2_rnn_out = inp2_rnn_out.permute(0, 2, 1)  # (b, num_features, n)
90 |
91 |         # temporal pooling (max) over the sequence dimension
92 |         feature_p = F.max_pool1d(inp1_rnn_out, inp1_rnn_out.size(2))  # sequence feature, (b, num_features, 1)
93 | feature_g = F.max_pool1d(inp2_rnn_out, inp2_rnn_out.size(2))
94 |
95 |         feature_p = feature_p.view(b, self.num_features)  # (b, num_features)
96 | feature_g = feature_g.view(b, self.num_features)
97 |
98 |         # classification
99 |         identity_p = self.final_FC(feature_p)  # identity logits, e.g. torch.Size([8, 89])
100 | identity_g = self.final_FC(feature_g)
101 | return feature_p, feature_g, identity_p, identity_g
102 |
103 | def forward(self, input1, input2):
104 | feature_p, feature_g, identity_p, identity_g = self.build_net(input1, input2)
105 | return feature_p, feature_g, identity_p, identity_g
106 |
107 |
108 | class Criterion(nn.Module):
109 | def __init__(self, hinge_margin=2):
110 | super(Criterion, self).__init__()
111 | self.hinge_margin = hinge_margin
112 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
113 |
114 | def forward(self, feature_p, feature_g, identity_p, identity_g, target):
115 | dist = nn.PairwiseDistance(p=2)
116 |         pair_dist = dist(feature_p, feature_g)  # Euclidean distance between the paired features
117 |
118 |         # 1. hinge embedding loss on the pair distance
119 | hing = nn.HingeEmbeddingLoss(margin=self.hinge_margin, reduce=False)
120 | label0 = target[0].to(self.device)
121 | hing_loss = hing(pair_dist, label0)
122 |
123 |         # 2. cross-entropy identity losses
124 | nll = nn.CrossEntropyLoss()
125 | label1 = target[1].to(self.device)
126 | label2 = target[2].to(self.device)
127 | loss_p = nll(identity_p, label1)
128 | loss_g = nll(identity_g, label2)
129 |
130 |         # 3. sum the losses
131 | total_loss = hing_loss + loss_p + loss_g
132 | mean_loss = torch.mean(total_loss)
133 |
134 | return mean_loss
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
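A shape sanity check for the network, using the same hyperparameters prid_data.py passes in (16/32/32 filters, 89 training identities, 128 features, seq_len 16, batch 8) and random stand-in input:

```python
import torch
from models.cnnrnn import Net

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net(16, 32, 32, num_person_train=89, num_features=128, seq_len=16, batch=8).to(device)
x1 = torch.randn(8, 16, 5, 256, 128, device=device)  # (batch, seq_len, YUV + flow channels, H, W)
x2 = torch.randn(8, 16, 5, 256, 128, device=device)
feat_p, feat_g, id_p, id_g = net(x1, x2)
print(feat_p.shape, id_p.shape)  # torch.Size([8, 128]) torch.Size([8, 89])
```

Note that `hid_weight` is sized `(1, seq_len * batch, num_features)`, so the forward pass only works when the incoming batch matches the `batch` and `seq_len` the model was built with.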
/prid_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-18 11:04 AM
5 | """Step 1: main script."""
6 |
7 | import torch
8 | import numpy as np
9 | import sys
10 | import argparse
11 | import os.path as osp
12 | import torch.backends.cudnn as cudnn
13 | # import torchvision.transforms as T
14 | from torch.utils.data import DataLoader
15 | import torch.optim as optim
16 |
17 | from utils import Logger, load_checkpoint, save_checkpoint
18 | from dataset import get_sequence
19 | from dataprocess.sampler import RandomPairSampler
20 | from dataprocess.video_loader import VideoDataset
21 | from dataprocess import seqtransform as T
22 | import models
23 | from models.cnnrnn import Criterion
24 | from eval import Evaluator
25 | from train import SEQTrainer
26 |
27 |
28 | # dataset preparation
29 | def getdata(dataset_name, split_id, batch_size, seq_len, seq_srd, workers):
30 |     # todo: add optical flow, fix the paths, and extend the data augmentation
31 | # root_rgb = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid_2011/prid_2011/multi_shot'
32 | # root_of = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid2011flow'
33 | dataset = get_sequence(dataset_name, split_id)
34 | train_set = dataset.train
35 | num_classes = dataset.num_train_pids # 89
36 | # normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37 |     # training augmentation: rescaling, random erasing, horizontal flipping
38 | transform_train = T.Compose([
39 | T.RectScale(256, 128),
40 | T.RandomSizedEarser(),
41 | T.RandomHorizontalFlip(),
42 | T.ToYUV(),
43 | ])
44 | transform_query = T.Compose([
45 | T.RectScale(256, 128),
46 | T.ToYUV(),
47 | ])
48 | transform_gallery = T.Compose([
49 | T.RectScale(256, 128),
50 | T.ToYUV(),
51 | ])
52 |     # wrap the subsets in VideoDataset and DataLoader
53 | train_processor = VideoDataset(dataset.train, seq_len=seq_len, sample='random', transform=transform_train)
54 | query_processor = VideoDataset(dataset.query, seq_len=seq_len, sample='dense', transform=transform_query)
55 | gallery_processor = VideoDataset(dataset.gallery, seq_len=seq_len, sample='dense', transform=transform_gallery)
56 |
57 | train_loader = DataLoader(train_processor, batch_size=batch_size, num_workers=workers, sampler=RandomPairSampler(train_set), pin_memory=True, drop_last=True)
58 | query_loader = DataLoader(query_processor, batch_size=1, num_workers=workers, shuffle=False, pin_memory=True, drop_last=False)
59 | gallery_loader = DataLoader(gallery_processor, batch_size=1, num_workers=workers, shuffle=False, pin_memory=True, drop_last=False)
60 |
61 | return dataset, num_classes, train_loader, query_loader, gallery_loader
62 |
63 |
64 | def main(args):
65 |     # 1. setup
66 |     np.random.seed(args.seed)
67 |     torch.manual_seed(args.seed)
68 |     torch.cuda.manual_seed_all(args.seed)  # seed all GPUs as well
69 |     cudnn.benchmark = True  # small training speed-up at no extra cost
70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71 |
72 |     # 2. log files
73 | if args.evaluate == 1:
74 | sys.stdout = Logger(osp.join(args.logs_dir, 'log_test.txt'))
75 | else:
76 | sys.stdout = Logger(osp.join(args.logs_dir, 'log_train.txt'))
77 |
78 |     # 3. dataset (todo: rebuild the dataset here)
79 | dataset, numclasses, train_loader, query_loader, gallery_loader = getdata(args.dataset, args.split, args.batch_size, args.seq_len, args.seq_srd, args.workers)
80 |
81 |     # 4. build the network
82 |     cnn_rnn_model = models.create(args.a1, 16, 32, 32, numclasses, num_features=args.features, seq_len=args.seq_len, batch=args.batch_size).to(device)
83 | criterion = Criterion(args.hingeMargin).to(device)
84 | optimizer = optim.SGD(cnn_rnn_model.parameters(), lr=args.lr1, momentum=args.momentum, weight_decay=args.weight_decay)
85 |
86 |     # 5. instantiate the trainer
87 | trainer = SEQTrainer(cnn_rnn_model, criterion)
88 |
89 |     # 6. instantiate the evaluator
90 | evaluator = Evaluator(cnn_rnn_model)
91 | best_top1 = 0
92 |
93 |     # 7. train or test
94 | if args.evaluate == 1:
95 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'cnn_rnn_best.pth.tar'))
96 | cnn_rnn_model.load_state_dict(checkpoint['state_dict'])
97 | rank1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
98 | else:
99 | cnn_rnn_model.train()
100 | for epoch in range(args.start_epoch, args.epochs):
101 | trainer.train(epoch, train_loader, optimizer)
102 |
103 |             # evaluate every few epochs
104 | if (epoch + 1) % 400 == 0 or (epoch + 1) == args.epochs:
105 |
106 | top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
107 | is_best = top1 > best_top1
108 | if is_best:
109 | best_top1 = top1
110 |
111 | save_checkpoint({
112 | 'state_dict': cnn_rnn_model.state_dict(),
113 | 'epoch': epoch + 1,
114 | 'best_top1': best_top1,
115 | }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))
116 |
117 |
118 | if __name__ == '__main__':
119 |     parser = argparse.ArgumentParser(description='CNN-RNN for video-based person re-identification')
120 | parser.add_argument('--seed', type=int, default=1)
121 | # DATA
122 | parser.add_argument('--dataset', type=str, default='prid2011', choices=['ilds', 'prid2011', 'mars'])
123 | parser.add_argument('--batch-size', type=int, default=8, help='depend on your device')
124 | parser.add_argument('--workers', type=int, default=4)
125 | parser.add_argument('--seq_len', type=int, default=16, help='the number of images in a sequence')
126 |     parser.add_argument('--seq_srd', type=int, default=16, help='sampling stride')
127 | parser.add_argument('--split', type=int, default=0, help='total 10')
128 | # Model
129 | parser.add_argument('--a1', type=str, default='cnn-rnn')
130 | parser.add_argument('--nConvFilters', type=int, default=32)
131 | parser.add_argument('--features', type=int, default=128, help='features dimension')
132 | parser.add_argument('--hingeMargin', type=int, default=2)
133 | # parser.add_argument('--dropout', type=float, default=0.0)
134 | # Optimizer
135 | parser.add_argument('--lr1', type=float, default=0.001)
136 | parser.add_argument('--lrstep', type=int, default=20)
137 | parser.add_argument('--momentum', type=float, default=0.9)
138 | parser.add_argument('--weight-decay', type=float, default=5e-4)
139 | # Train
140 | parser.add_argument('--start-epoch', type=int, default=0)
141 | parser.add_argument('--epochs', type=int, default=400)
142 | parser.add_argument('--evaluate', type=int, default=0, help='0 => train; 1 =>test')
143 | # Path
144 | working_dir = osp.dirname(osp.abspath(__file__))
145 | parser.add_argument('--dataset-dir', type=str, metavar='PATH', default=osp.join(working_dir, '../video_reid_mars/data'))
146 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=osp.join(working_dir, 'log/yuci'))
147 | args = parser.parse_args()
148 | # main func
149 | main(args)
150 |
--------------------------------------------------------------------------------
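With the defaults above, training is started with `python prid_data.py` (checkpoints and `log_train.txt` go to `log/yuci/`), and a saved model is evaluated with `python prid_data.py --evaluate 1`, which loads `log/yuci/cnn_rnn_best.pth.tar`.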
/splits_prid2011.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "test": [
4 | "person_0001",
5 | "person_0002",
6 | "person_0003",
7 | "person_0007",
8 | "person_0012",
9 | "person_0013",
10 | "person_0014",
11 | "person_0016",
12 | "person_0017",
13 | "person_0023",
14 | "person_0027",
15 | "person_0030",
16 | "person_0032",
17 | "person_0033",
18 | "person_0035",
19 | "person_0036",
20 | "person_0037",
21 | "person_0038",
22 | "person_0040",
23 | "person_0043",
24 | "person_0044",
25 | "person_0045",
26 | "person_0046",
27 | "person_0047",
28 | "person_0048",
29 | "person_0049",
30 | "person_0054",
31 | "person_0058",
32 | "person_0060",
33 | "person_0061",
34 | "person_0063",
35 | "person_0067",
36 | "person_0069",
37 | "person_0070",
38 | "person_0073",
39 | "person_0074",
40 | "person_0075",
41 | "person_0078",
42 | "person_0081",
43 | "person_0082",
44 | "person_0083",
45 | "person_0085",
46 | "person_0087",
47 | "person_0090",
48 | "person_0096",
49 | "person_0097",
50 | "person_0099",
51 | "person_0101",
52 | "person_0102",
53 | "person_0107",
54 | "person_0108",
55 | "person_0109",
56 | "person_0111",
57 | "person_0116",
58 | "person_0117",
59 | "person_0122",
60 | "person_0124",
61 | "person_0126",
62 | "person_0127",
63 | "person_0130",
64 | "person_0131",
65 | "person_0132",
66 | "person_0141",
67 | "person_0148",
68 | "person_0149",
69 | "person_0151",
70 | "person_0152",
71 | "person_0153",
72 | "person_0156",
73 | "person_0169",
74 | "person_0170",
75 | "person_0171",
76 | "person_0172",
77 | "person_0173",
78 | "person_0174",
79 | "person_0175",
80 | "person_0177",
81 | "person_0178",
82 | "person_0181",
83 | "person_0182",
84 | "person_0184",
85 | "person_0186",
86 | "person_0188",
87 | "person_0189",
88 | "person_0191",
89 | "person_0192",
90 | "person_0193",
91 | "person_0195",
92 | "person_0197"
93 | ],
94 | "train": [
95 | "person_0004",
96 | "person_0005",
97 | "person_0008",
98 | "person_0009",
99 | "person_0010",
100 | "person_0011",
101 | "person_0015",
102 | "person_0018",
103 | "person_0020",
104 | "person_0021",
105 | "person_0022",
106 | "person_0024",
107 | "person_0025",
108 | "person_0026",
109 | "person_0028",
110 | "person_0029",
111 | "person_0034",
112 | "person_0039",
113 | "person_0041",
114 | "person_0042",
115 | "person_0050",
116 | "person_0051",
117 | "person_0053",
118 | "person_0055",
119 | "person_0056",
120 | "person_0059",
121 | "person_0064",
122 | "person_0065",
123 | "person_0066",
124 | "person_0068",
125 | "person_0072",
126 | "person_0076",
127 | "person_0077",
128 | "person_0079",
129 | "person_0080",
130 | "person_0084",
131 | "person_0086",
132 | "person_0088",
133 | "person_0089",
134 | "person_0091",
135 | "person_0092",
136 | "person_0093",
137 | "person_0095",
138 | "person_0098",
139 | "person_0100",
140 | "person_0103",
141 | "person_0104",
142 | "person_0105",
143 | "person_0110",
144 | "person_0112",
145 | "person_0113",
146 | "person_0114",
147 | "person_0115",
148 | "person_0118",
149 | "person_0119",
150 | "person_0121",
151 | "person_0123",
152 | "person_0125",
153 | "person_0128",
154 | "person_0129",
155 | "person_0133",
156 | "person_0134",
157 | "person_0135",
158 | "person_0137",
159 | "person_0140",
160 | "person_0143",
161 | "person_0144",
162 | "person_0146",
163 | "person_0147",
164 | "person_0150",
165 | "person_0154",
166 | "person_0155",
167 | "person_0157",
168 | "person_0160",
169 | "person_0161",
170 | "person_0162",
171 | "person_0163",
172 | "person_0164",
173 | "person_0165",
174 | "person_0167",
175 | "person_0168",
176 | "person_0176",
177 | "person_0179",
178 | "person_0180",
179 | "person_0185",
180 | "person_0187",
181 | "person_0194",
182 | "person_0196",
183 | "person_0198"
184 | ]
185 | },
186 | {
187 | "test": [
188 | "person_0001",
189 | "person_0004",
190 | "person_0005",
191 | "person_0007",
192 | "person_0008",
193 | "person_0009",
194 | "person_0010",
195 | "person_0011",
196 | "person_0020",
197 | "person_0021",
198 | "person_0022",
199 | "person_0024",
200 | "person_0025",
201 | "person_0026",
202 | "person_0030",
203 | "person_0032",
204 | "person_0034",
205 | "person_0035",
206 | "person_0041",
207 | "person_0043",
208 | "person_0046",
209 | "person_0050",
210 | "person_0051",
211 | "person_0055",
212 | "person_0056",
213 | "person_0060",
214 | "person_0061",
215 | "person_0066",
216 | "person_0067",
217 | "person_0068",
218 | "person_0069",
219 | "person_0073",
220 | "person_0074",
221 | "person_0076",
222 | "person_0077",
223 | "person_0078",
224 | "person_0081",
225 | "person_0083",
226 | "person_0085",
227 | "person_0087",
228 | "person_0088",
229 | "person_0089",
230 | "person_0090",
231 | "person_0091",
232 | "person_0092",
233 | "person_0095",
234 | "person_0096",
235 | "person_0099",
236 | "person_0100",
237 | "person_0107",
238 | "person_0108",
239 | "person_0110",
240 | "person_0112",
241 | "person_0113",
242 | "person_0115",
243 | "person_0116",
244 | "person_0117",
245 | "person_0119",
246 | "person_0121",
247 | "person_0123",
248 | "person_0137",
249 | "person_0141",
250 | "person_0146",
251 | "person_0148",
252 | "person_0149",
253 | "person_0153",
254 | "person_0154",
255 | "person_0155",
256 | "person_0165",
257 | "person_0167",
258 | "person_0169",
259 | "person_0172",
260 | "person_0173",
261 | "person_0174",
262 | "person_0175",
263 | "person_0179",
264 | "person_0180",
265 | "person_0181",
266 | "person_0185",
267 | "person_0186",
268 | "person_0187",
269 | "person_0189",
270 | "person_0191",
271 | "person_0192",
272 | "person_0194",
273 | "person_0195",
274 | "person_0196",
275 | "person_0197",
276 | "person_0198"
277 | ],
278 | "train": [
279 | "person_0002",
280 | "person_0003",
281 | "person_0012",
282 | "person_0013",
283 | "person_0014",
284 | "person_0015",
285 | "person_0016",
286 | "person_0017",
287 | "person_0018",
288 | "person_0023",
289 | "person_0027",
290 | "person_0028",
291 | "person_0029",
292 | "person_0033",
293 | "person_0036",
294 | "person_0037",
295 | "person_0038",
296 | "person_0039",
297 | "person_0040",
298 | "person_0042",
299 | "person_0044",
300 | "person_0045",
301 | "person_0047",
302 | "person_0048",
303 | "person_0049",
304 | "person_0053",
305 | "person_0054",
306 | "person_0058",
307 | "person_0059",
308 | "person_0063",
309 | "person_0064",
310 | "person_0065",
311 | "person_0070",
312 | "person_0072",
313 | "person_0075",
314 | "person_0079",
315 | "person_0080",
316 | "person_0082",
317 | "person_0084",
318 | "person_0086",
319 | "person_0093",
320 | "person_0097",
321 | "person_0098",
322 | "person_0101",
323 | "person_0102",
324 | "person_0103",
325 | "person_0104",
326 | "person_0105",
327 | "person_0109",
328 | "person_0111",
329 | "person_0114",
330 | "person_0118",
331 | "person_0122",
332 | "person_0124",
333 | "person_0125",
334 | "person_0126",
335 | "person_0127",
336 | "person_0128",
337 | "person_0129",
338 | "person_0130",
339 | "person_0131",
340 | "person_0132",
341 | "person_0133",
342 | "person_0134",
343 | "person_0135",
344 | "person_0140",
345 | "person_0143",
346 | "person_0144",
347 | "person_0147",
348 | "person_0150",
349 | "person_0151",
350 | "person_0152",
351 | "person_0156",
352 | "person_0157",
353 | "person_0160",
354 | "person_0161",
355 | "person_0162",
356 | "person_0163",
357 | "person_0164",
358 | "person_0168",
359 | "person_0170",
360 | "person_0171",
361 | "person_0176",
362 | "person_0177",
363 | "person_0178",
364 | "person_0182",
365 | "person_0184",
366 | "person_0188",
367 | "person_0193"
368 | ]
369 | },
370 | {
371 | "test": [
372 | "person_0003",
373 | "person_0007",
374 | "person_0010",
375 | "person_0012",
376 | "person_0015",
377 | "person_0016",
378 | "person_0017",
379 | "person_0018",
380 | "person_0023",
381 | "person_0024",
382 | "person_0025",
383 | "person_0027",
384 | "person_0028",
385 | "person_0029",
386 | "person_0033",
387 | "person_0036",
388 | "person_0037",
389 | "person_0039",
390 | "person_0040",
391 | "person_0041",
392 | "person_0044",
393 | "person_0045",
394 | "person_0047",
395 | "person_0048",
396 | "person_0049",
397 | "person_0051",
398 | "person_0053",
399 | "person_0055",
400 | "person_0056",
401 | "person_0058",
402 | "person_0059",
403 | "person_0060",
404 | "person_0061",
405 | "person_0063",
406 | "person_0068",
407 | "person_0074",
408 | "person_0078",
409 | "person_0079",
410 | "person_0080",
411 | "person_0083",
412 | "person_0087",
413 | "person_0088",
414 | "person_0095",
415 | "person_0096",
416 | "person_0100",
417 | "person_0102",
418 | "person_0104",
419 | "person_0107",
420 | "person_0109",
421 | "person_0110",
422 | "person_0112",
423 | "person_0113",
424 | "person_0114",
425 | "person_0121",
426 | "person_0122",
427 | "person_0123",
428 | "person_0129",
429 | "person_0133",
430 | "person_0134",
431 | "person_0137",
432 | "person_0140",
433 | "person_0144",
434 | "person_0146",
435 | "person_0147",
436 | "person_0148",
437 | "person_0149",
438 | "person_0150",
439 | "person_0151",
440 | "person_0152",
441 | "person_0153",
442 | "person_0155",
443 | "person_0156",
444 | "person_0165",
445 | "person_0169",
446 | "person_0171",
447 | "person_0174",
448 | "person_0176",
449 | "person_0177",
450 | "person_0179",
451 | "person_0180",
452 | "person_0185",
453 | "person_0187",
454 | "person_0189",
455 | "person_0193",
456 | "person_0194",
457 | "person_0195",
458 | "person_0196",
459 | "person_0197",
460 | "person_0198"
461 | ],
462 | "train": [
463 | "person_0001",
464 | "person_0002",
465 | "person_0004",
466 | "person_0005",
467 | "person_0008",
468 | "person_0009",
469 | "person_0011",
470 | "person_0013",
471 | "person_0014",
472 | "person_0020",
473 | "person_0021",
474 | "person_0022",
475 | "person_0026",
476 | "person_0030",
477 | "person_0032",
478 | "person_0034",
479 | "person_0035",
480 | "person_0038",
481 | "person_0042",
482 | "person_0043",
483 | "person_0046",
484 | "person_0050",
485 | "person_0054",
486 | "person_0064",
487 | "person_0065",
488 | "person_0066",
489 | "person_0067",
490 | "person_0069",
491 | "person_0070",
492 | "person_0072",
493 | "person_0073",
494 | "person_0075",
495 | "person_0076",
496 | "person_0077",
497 | "person_0081",
498 | "person_0082",
499 | "person_0084",
500 | "person_0085",
501 | "person_0086",
502 | "person_0089",
503 | "person_0090",
504 | "person_0091",
505 | "person_0092",
506 | "person_0093",
507 | "person_0097",
508 | "person_0098",
509 | "person_0099",
510 | "person_0101",
511 | "person_0103",
512 | "person_0105",
513 | "person_0108",
514 | "person_0111",
515 | "person_0115",
516 | "person_0116",
517 | "person_0117",
518 | "person_0118",
519 | "person_0119",
520 | "person_0124",
521 | "person_0125",
522 | "person_0126",
523 | "person_0127",
524 | "person_0128",
525 | "person_0130",
526 | "person_0131",
527 | "person_0132",
528 | "person_0135",
529 | "person_0141",
530 | "person_0143",
531 | "person_0154",
532 | "person_0157",
533 | "person_0160",
534 | "person_0161",
535 | "person_0162",
536 | "person_0163",
537 | "person_0164",
538 | "person_0167",
539 | "person_0168",
540 | "person_0170",
541 | "person_0172",
542 | "person_0173",
543 | "person_0175",
544 | "person_0178",
545 | "person_0181",
546 | "person_0182",
547 | "person_0184",
548 | "person_0186",
549 | "person_0188",
550 | "person_0191",
551 | "person_0192"
552 | ]
553 | },
554 | {
555 | "test": [
556 | "person_0002",
557 | "person_0004",
558 | "person_0007",
559 | "person_0008",
560 | "person_0009",
561 | "person_0016",
562 | "person_0021",
563 | "person_0022",
564 | "person_0023",
565 | "person_0024",
566 | "person_0026",
567 | "person_0027",
568 | "person_0030",
569 | "person_0035",
570 | "person_0036",
571 | "person_0038",
572 | "person_0041",
573 | "person_0042",
574 | "person_0043",
575 | "person_0046",
576 | "person_0049",
577 | "person_0050",
578 | "person_0056",
579 | "person_0058",
580 | "person_0060",
581 | "person_0064",
582 | "person_0065",
583 | "person_0069",
584 | "person_0070",
585 | "person_0072",
586 | "person_0076",
587 | "person_0077",
588 | "person_0078",
589 | "person_0081",
590 | "person_0082",
591 | "person_0083",
592 | "person_0087",
593 | "person_0088",
594 | "person_0090",
595 | "person_0092",
596 | "person_0095",
597 | "person_0100",
598 | "person_0102",
599 | "person_0103",
600 | "person_0105",
601 | "person_0107",
602 | "person_0108",
603 | "person_0109",
604 | "person_0117",
605 | "person_0118",
606 | "person_0119",
607 | "person_0122",
608 | "person_0124",
609 | "person_0125",
610 | "person_0127",
611 | "person_0128",
612 | "person_0131",
613 | "person_0133",
614 | "person_0134",
615 | "person_0135",
616 | "person_0137",
617 | "person_0140",
618 | "person_0141",
619 | "person_0146",
620 | "person_0148",
621 | "person_0149",
622 | "person_0150",
623 | "person_0152",
624 | "person_0153",
625 | "person_0154",
626 | "person_0156",
627 | "person_0157",
628 | "person_0161",
629 | "person_0162",
630 | "person_0165",
631 | "person_0167",
632 | "person_0168",
633 | "person_0174",
634 | "person_0176",
635 | "person_0177",
636 | "person_0178",
637 | "person_0179",
638 | "person_0180",
639 | "person_0181",
640 | "person_0184",
641 | "person_0192",
642 | "person_0193",
643 | "person_0194",
644 | "person_0198"
645 | ],
646 | "train": [
647 | "person_0001",
648 | "person_0003",
649 | "person_0005",
650 | "person_0010",
651 | "person_0011",
652 | "person_0012",
653 | "person_0013",
654 | "person_0014",
655 | "person_0015",
656 | "person_0017",
657 | "person_0018",
658 | "person_0020",
659 | "person_0025",
660 | "person_0028",
661 | "person_0029",
662 | "person_0032",
663 | "person_0033",
664 | "person_0034",
665 | "person_0037",
666 | "person_0039",
667 | "person_0040",
668 | "person_0044",
669 | "person_0045",
670 | "person_0047",
671 | "person_0048",
672 | "person_0051",
673 | "person_0053",
674 | "person_0054",
675 | "person_0055",
676 | "person_0059",
677 | "person_0061",
678 | "person_0063",
679 | "person_0066",
680 | "person_0067",
681 | "person_0068",
682 | "person_0073",
683 | "person_0074",
684 | "person_0075",
685 | "person_0079",
686 | "person_0080",
687 | "person_0084",
688 | "person_0085",
689 | "person_0086",
690 | "person_0089",
691 | "person_0091",
692 | "person_0093",
693 | "person_0096",
694 | "person_0097",
695 | "person_0098",
696 | "person_0099",
697 | "person_0101",
698 | "person_0104",
699 | "person_0110",
700 | "person_0111",
701 | "person_0112",
702 | "person_0113",
703 | "person_0114",
704 | "person_0115",
705 | "person_0116",
706 | "person_0121",
707 | "person_0123",
708 | "person_0126",
709 | "person_0129",
710 | "person_0130",
711 | "person_0132",
712 | "person_0143",
713 | "person_0144",
714 | "person_0147",
715 | "person_0151",
716 | "person_0155",
717 | "person_0160",
718 | "person_0163",
719 | "person_0164",
720 | "person_0169",
721 | "person_0170",
722 | "person_0171",
723 | "person_0172",
724 | "person_0173",
725 | "person_0175",
726 | "person_0182",
727 | "person_0185",
728 | "person_0186",
729 | "person_0187",
730 | "person_0188",
731 | "person_0189",
732 | "person_0191",
733 | "person_0195",
734 | "person_0196",
735 | "person_0197"
736 | ]
737 | },
738 | {
739 | "test": [
740 | "person_0002",
741 | "person_0003",
742 | "person_0005",
743 | "person_0007",
744 | "person_0009",
745 | "person_0021",
746 | "person_0023",
747 | "person_0024",
748 | "person_0025",
749 | "person_0027",
750 | "person_0029",
751 | "person_0032",
752 | "person_0036",
753 | "person_0037",
754 | "person_0038",
755 | "person_0039",
756 | "person_0044",
757 | "person_0045",
758 | "person_0047",
759 | "person_0049",
760 | "person_0050",
761 | "person_0051",
762 | "person_0053",
763 | "person_0056",
764 | "person_0060",
765 | "person_0063",
766 | "person_0067",
767 | "person_0072",
768 | "person_0075",
769 | "person_0076",
770 | "person_0078",
771 | "person_0079",
772 | "person_0081",
773 | "person_0083",
774 | "person_0086",
775 | "person_0087",
776 | "person_0088",
777 | "person_0090",
778 | "person_0095",
779 | "person_0096",
780 | "person_0097",
781 | "person_0099",
782 | "person_0100",
783 | "person_0101",
784 | "person_0103",
785 | "person_0104",
786 | "person_0105",
787 | "person_0110",
788 | "person_0111",
789 | "person_0115",
790 | "person_0118",
791 | "person_0119",
792 | "person_0122",
793 | "person_0123",
794 | "person_0126",
795 | "person_0127",
796 | "person_0129",
797 | "person_0131",
798 | "person_0135",
799 | "person_0140",
800 | "person_0141",
801 | "person_0152",
802 | "person_0153",
803 | "person_0154",
804 | "person_0155",
805 | "person_0156",
806 | "person_0157",
807 | "person_0161",
808 | "person_0164",
809 | "person_0169",
810 | "person_0170",
811 | "person_0171",
812 | "person_0172",
813 | "person_0173",
814 | "person_0174",
815 | "person_0175",
816 | "person_0178",
817 | "person_0179",
818 | "person_0180",
819 | "person_0181",
820 | "person_0185",
821 | "person_0186",
822 | "person_0187",
823 | "person_0188",
824 | "person_0189",
825 | "person_0191",
826 | "person_0194",
827 | "person_0195",
828 | "person_0197"
829 | ],
830 | "train": [
831 | "person_0001",
832 | "person_0004",
833 | "person_0008",
834 | "person_0010",
835 | "person_0011",
836 | "person_0012",
837 | "person_0013",
838 | "person_0014",
839 | "person_0015",
840 | "person_0016",
841 | "person_0017",
842 | "person_0018",
843 | "person_0020",
844 | "person_0022",
845 | "person_0026",
846 | "person_0028",
847 | "person_0030",
848 | "person_0033",
849 | "person_0034",
850 | "person_0035",
851 | "person_0040",
852 | "person_0041",
853 | "person_0042",
854 | "person_0043",
855 | "person_0046",
856 | "person_0048",
857 | "person_0054",
858 | "person_0055",
859 | "person_0058",
860 | "person_0059",
861 | "person_0061",
862 | "person_0064",
863 | "person_0065",
864 | "person_0066",
865 | "person_0068",
866 | "person_0069",
867 | "person_0070",
868 | "person_0073",
869 | "person_0074",
870 | "person_0077",
871 | "person_0080",
872 | "person_0082",
873 | "person_0084",
874 | "person_0085",
875 | "person_0089",
876 | "person_0091",
877 | "person_0092",
878 | "person_0093",
879 | "person_0098",
880 | "person_0102",
881 | "person_0107",
882 | "person_0108",
883 | "person_0109",
884 | "person_0112",
885 | "person_0113",
886 | "person_0114",
887 | "person_0116",
888 | "person_0117",
889 | "person_0121",
890 | "person_0124",
891 | "person_0125",
892 | "person_0128",
893 | "person_0130",
894 | "person_0132",
895 | "person_0133",
896 | "person_0134",
897 | "person_0137",
898 | "person_0143",
899 | "person_0144",
900 | "person_0146",
901 | "person_0147",
902 | "person_0148",
903 | "person_0149",
904 | "person_0150",
905 | "person_0151",
906 | "person_0160",
907 | "person_0162",
908 | "person_0163",
909 | "person_0165",
910 | "person_0167",
911 | "person_0168",
912 | "person_0176",
913 | "person_0177",
914 | "person_0182",
915 | "person_0184",
916 | "person_0192",
917 | "person_0193",
918 | "person_0196",
919 | "person_0198"
920 | ]
921 | },
922 | {
923 | "test": [
924 | "person_0001",
925 | "person_0002",
926 | "person_0004",
927 | "person_0005",
928 | "person_0008",
929 | "person_0009",
930 | "person_0010",
931 | "person_0012",
932 | "person_0013",
933 | "person_0014",
934 | "person_0016",
935 | "person_0018",
936 | "person_0020",
937 | "person_0021",
938 | "person_0024",
939 | "person_0026",
940 | "person_0029",
941 | "person_0034",
942 | "person_0037",
943 | "person_0040",
944 | "person_0041",
945 | "person_0042",
946 | "person_0044",
947 | "person_0045",
948 | "person_0046",
949 | "person_0048",
950 | "person_0050",
951 | "person_0054",
952 | "person_0056",
953 | "person_0059",
954 | "person_0061",
955 | "person_0063",
956 | "person_0065",
957 | "person_0069",
958 | "person_0072",
959 | "person_0077",
960 | "person_0078",
961 | "person_0079",
962 | "person_0082",
963 | "person_0083",
964 | "person_0084",
965 | "person_0086",
966 | "person_0092",
967 | "person_0096",
968 | "person_0102",
969 | "person_0104",
970 | "person_0108",
971 | "person_0109",
972 | "person_0110",
973 | "person_0111",
974 | "person_0114",
975 | "person_0117",
976 | "person_0118",
977 | "person_0119",
978 | "person_0121",
979 | "person_0122",
980 | "person_0123",
981 | "person_0124",
982 | "person_0125",
983 | "person_0128",
984 | "person_0129",
985 | "person_0130",
986 | "person_0131",
987 | "person_0132",
988 | "person_0140",
989 | "person_0141",
990 | "person_0146",
991 | "person_0147",
992 | "person_0148",
993 | "person_0152",
994 | "person_0153",
995 | "person_0160",
996 | "person_0161",
997 | "person_0163",
998 | "person_0164",
999 | "person_0165",
1000 | "person_0169",
1001 | "person_0170",
1002 | "person_0178",
1003 | "person_0180",
1004 | "person_0184",
1005 | "person_0185",
1006 | "person_0187",
1007 | "person_0188",
1008 | "person_0189",
1009 | "person_0192",
1010 | "person_0194",
1011 | "person_0196",
1012 | "person_0198"
1013 | ],
1014 | "train": [
1015 | "person_0003",
1016 | "person_0007",
1017 | "person_0011",
1018 | "person_0015",
1019 | "person_0017",
1020 | "person_0022",
1021 | "person_0023",
1022 | "person_0025",
1023 | "person_0027",
1024 | "person_0028",
1025 | "person_0030",
1026 | "person_0032",
1027 | "person_0033",
1028 | "person_0035",
1029 | "person_0036",
1030 | "person_0038",
1031 | "person_0039",
1032 | "person_0043",
1033 | "person_0047",
1034 | "person_0049",
1035 | "person_0051",
1036 | "person_0053",
1037 | "person_0055",
1038 | "person_0058",
1039 | "person_0060",
1040 | "person_0064",
1041 | "person_0066",
1042 | "person_0067",
1043 | "person_0068",
1044 | "person_0070",
1045 | "person_0073",
1046 | "person_0074",
1047 | "person_0075",
1048 | "person_0076",
1049 | "person_0080",
1050 | "person_0081",
1051 | "person_0085",
1052 | "person_0087",
1053 | "person_0088",
1054 | "person_0089",
1055 | "person_0090",
1056 | "person_0091",
1057 | "person_0093",
1058 | "person_0095",
1059 | "person_0097",
1060 | "person_0098",
1061 | "person_0099",
1062 | "person_0100",
1063 | "person_0101",
1064 | "person_0103",
1065 | "person_0105",
1066 | "person_0107",
1067 | "person_0112",
1068 | "person_0113",
1069 | "person_0115",
1070 | "person_0116",
1071 | "person_0126",
1072 | "person_0127",
1073 | "person_0133",
1074 | "person_0134",
1075 | "person_0135",
1076 | "person_0137",
1077 | "person_0143",
1078 | "person_0144",
1079 | "person_0149",
1080 | "person_0150",
1081 | "person_0151",
1082 | "person_0154",
1083 | "person_0155",
1084 | "person_0156",
1085 | "person_0157",
1086 | "person_0162",
1087 | "person_0167",
1088 | "person_0168",
1089 | "person_0171",
1090 | "person_0172",
1091 | "person_0173",
1092 | "person_0174",
1093 | "person_0175",
1094 | "person_0176",
1095 | "person_0177",
1096 | "person_0179",
1097 | "person_0181",
1098 | "person_0182",
1099 | "person_0186",
1100 | "person_0191",
1101 | "person_0193",
1102 | "person_0195",
1103 | "person_0197"
1104 | ]
1105 | },
1106 | {
1107 | "test": [
1108 | "person_0004",
1109 | "person_0005",
1110 | "person_0008",
1111 | "person_0009",
1112 | "person_0012",
1113 | "person_0015",
1114 | "person_0016",
1115 | "person_0021",
1116 | "person_0023",
1117 | "person_0025",
1118 | "person_0028",
1119 | "person_0029",
1120 | "person_0032",
1121 | "person_0033",
1122 | "person_0035",
1123 | "person_0036",
1124 | "person_0040",
1125 | "person_0041",
1126 | "person_0044",
1127 | "person_0045",
1128 | "person_0048",
1129 | "person_0050",
1130 | "person_0051",
1131 | "person_0053",
1132 | "person_0054",
1133 | "person_0056",
1134 | "person_0059",
1135 | "person_0065",
1136 | "person_0067",
1137 | "person_0068",
1138 | "person_0072",
1139 | "person_0078",
1140 | "person_0080",
1141 | "person_0081",
1142 | "person_0082",
1143 | "person_0083",
1144 | "person_0084",
1145 | "person_0085",
1146 | "person_0086",
1147 | "person_0088",
1148 | "person_0090",
1149 | "person_0091",
1150 | "person_0092",
1151 | "person_0095",
1152 | "person_0097",
1153 | "person_0099",
1154 | "person_0107",
1155 | "person_0108",
1156 | "person_0109",
1157 | "person_0111",
1158 | "person_0113",
1159 | "person_0116",
1160 | "person_0118",
1161 | "person_0119",
1162 | "person_0121",
1163 | "person_0122",
1164 | "person_0124",
1165 | "person_0126",
1166 | "person_0127",
1167 | "person_0129",
1168 | "person_0130",
1169 | "person_0134",
1170 | "person_0135",
1171 | "person_0137",
1172 | "person_0143",
1173 | "person_0146",
1174 | "person_0149",
1175 | "person_0150",
1176 | "person_0153",
1177 | "person_0154",
1178 | "person_0157",
1179 | "person_0165",
1180 | "person_0168",
1181 | "person_0170",
1182 | "person_0172",
1183 | "person_0173",
1184 | "person_0174",
1185 | "person_0175",
1186 | "person_0176",
1187 | "person_0177",
1188 | "person_0178",
1189 | "person_0179",
1190 | "person_0180",
1191 | "person_0182",
1192 | "person_0186",
1193 | "person_0189",
1194 | "person_0192",
1195 | "person_0197",
1196 | "person_0198"
1197 | ],
1198 | "train": [
1199 | "person_0001",
1200 | "person_0002",
1201 | "person_0003",
1202 | "person_0007",
1203 | "person_0010",
1204 | "person_0011",
1205 | "person_0013",
1206 | "person_0014",
1207 | "person_0017",
1208 | "person_0018",
1209 | "person_0020",
1210 | "person_0022",
1211 | "person_0024",
1212 | "person_0026",
1213 | "person_0027",
1214 | "person_0030",
1215 | "person_0034",
1216 | "person_0037",
1217 | "person_0038",
1218 | "person_0039",
1219 | "person_0042",
1220 | "person_0043",
1221 | "person_0046",
1222 | "person_0047",
1223 | "person_0049",
1224 | "person_0055",
1225 | "person_0058",
1226 | "person_0060",
1227 | "person_0061",
1228 | "person_0063",
1229 | "person_0064",
1230 | "person_0066",
1231 | "person_0069",
1232 | "person_0070",
1233 | "person_0073",
1234 | "person_0074",
1235 | "person_0075",
1236 | "person_0076",
1237 | "person_0077",
1238 | "person_0079",
1239 | "person_0087",
1240 | "person_0089",
1241 | "person_0093",
1242 | "person_0096",
1243 | "person_0098",
1244 | "person_0100",
1245 | "person_0101",
1246 | "person_0102",
1247 | "person_0103",
1248 | "person_0104",
1249 | "person_0105",
1250 | "person_0110",
1251 | "person_0112",
1252 | "person_0114",
1253 | "person_0115",
1254 | "person_0117",
1255 | "person_0123",
1256 | "person_0125",
1257 | "person_0128",
1258 | "person_0131",
1259 | "person_0132",
1260 | "person_0133",
1261 | "person_0140",
1262 | "person_0141",
1263 | "person_0144",
1264 | "person_0147",
1265 | "person_0148",
1266 | "person_0151",
1267 | "person_0152",
1268 | "person_0155",
1269 | "person_0156",
1270 | "person_0160",
1271 | "person_0161",
1272 | "person_0162",
1273 | "person_0163",
1274 | "person_0164",
1275 | "person_0167",
1276 | "person_0169",
1277 | "person_0171",
1278 | "person_0181",
1279 | "person_0184",
1280 | "person_0185",
1281 | "person_0187",
1282 | "person_0188",
1283 | "person_0191",
1284 | "person_0193",
1285 | "person_0194",
1286 | "person_0195",
1287 | "person_0196"
1288 | ]
1289 | },
1290 | {
1291 | "test": [
1292 | "person_0002",
1293 | "person_0003",
1294 | "person_0009",
1295 | "person_0010",
1296 | "person_0012",
1297 | "person_0013",
1298 | "person_0015",
1299 | "person_0025",
1300 | "person_0028",
1301 | "person_0030",
1302 | "person_0033",
1303 | "person_0035",
1304 | "person_0036",
1305 | "person_0041",
1306 | "person_0047",
1307 | "person_0048",
1308 | "person_0049",
1309 | "person_0051",
1310 | "person_0053",
1311 | "person_0056",
1312 | "person_0058",
1313 | "person_0059",
1314 | "person_0066",
1315 | "person_0075",
1316 | "person_0078",
1317 | "person_0079",
1318 | "person_0080",
1319 | "person_0082",
1320 | "person_0086",
1321 | "person_0087",
1322 | "person_0091",
1323 | "person_0096",
1324 | "person_0097",
1325 | "person_0098",
1326 | "person_0099",
1327 | "person_0101",
1328 | "person_0102",
1329 | "person_0103",
1330 | "person_0104",
1331 | "person_0105",
1332 | "person_0107",
1333 | "person_0108",
1334 | "person_0109",
1335 | "person_0112",
1336 | "person_0115",
1337 | "person_0118",
1338 | "person_0121",
1339 | "person_0122",
1340 | "person_0126",
1341 | "person_0127",
1342 | "person_0128",
1343 | "person_0129",
1344 | "person_0130",
1345 | "person_0132",
1346 | "person_0134",
1347 | "person_0135",
1348 | "person_0137",
1349 | "person_0141",
1350 | "person_0143",
1351 | "person_0144",
1352 | "person_0146",
1353 | "person_0148",
1354 | "person_0149",
1355 | "person_0151",
1356 | "person_0152",
1357 | "person_0153",
1358 | "person_0154",
1359 | "person_0156",
1360 | "person_0157",
1361 | "person_0162",
1362 | "person_0164",
1363 | "person_0167",
1364 | "person_0168",
1365 | "person_0170",
1366 | "person_0171",
1367 | "person_0172",
1368 | "person_0173",
1369 | "person_0174",
1370 | "person_0175",
1371 | "person_0178",
1372 | "person_0179",
1373 | "person_0181",
1374 | "person_0184",
1375 | "person_0188",
1376 | "person_0191",
1377 | "person_0192",
1378 | "person_0193",
1379 | "person_0194",
1380 | "person_0198"
1381 | ],
1382 | "train": [
1383 | "person_0001",
1384 | "person_0004",
1385 | "person_0005",
1386 | "person_0007",
1387 | "person_0008",
1388 | "person_0011",
1389 | "person_0014",
1390 | "person_0016",
1391 | "person_0017",
1392 | "person_0018",
1393 | "person_0020",
1394 | "person_0021",
1395 | "person_0022",
1396 | "person_0023",
1397 | "person_0024",
1398 | "person_0026",
1399 | "person_0027",
1400 | "person_0029",
1401 | "person_0032",
1402 | "person_0034",
1403 | "person_0037",
1404 | "person_0038",
1405 | "person_0039",
1406 | "person_0040",
1407 | "person_0042",
1408 | "person_0043",
1409 | "person_0044",
1410 | "person_0045",
1411 | "person_0046",
1412 | "person_0050",
1413 | "person_0054",
1414 | "person_0055",
1415 | "person_0060",
1416 | "person_0061",
1417 | "person_0063",
1418 | "person_0064",
1419 | "person_0065",
1420 | "person_0067",
1421 | "person_0068",
1422 | "person_0069",
1423 | "person_0070",
1424 | "person_0072",
1425 | "person_0073",
1426 | "person_0074",
1427 | "person_0076",
1428 | "person_0077",
1429 | "person_0081",
1430 | "person_0083",
1431 | "person_0084",
1432 | "person_0085",
1433 | "person_0088",
1434 | "person_0089",
1435 | "person_0090",
1436 | "person_0092",
1437 | "person_0093",
1438 | "person_0095",
1439 | "person_0100",
1440 | "person_0110",
1441 | "person_0111",
1442 | "person_0113",
1443 | "person_0114",
1444 | "person_0116",
1445 | "person_0117",
1446 | "person_0119",
1447 | "person_0123",
1448 | "person_0124",
1449 | "person_0125",
1450 | "person_0131",
1451 | "person_0133",
1452 | "person_0140",
1453 | "person_0147",
1454 | "person_0150",
1455 | "person_0155",
1456 | "person_0160",
1457 | "person_0161",
1458 | "person_0163",
1459 | "person_0165",
1460 | "person_0169",
1461 | "person_0176",
1462 | "person_0177",
1463 | "person_0180",
1464 | "person_0182",
1465 | "person_0185",
1466 | "person_0186",
1467 | "person_0187",
1468 | "person_0189",
1469 | "person_0195",
1470 | "person_0196",
1471 | "person_0197"
1472 | ]
1473 | },
1474 | {
1475 | "test": [
1476 | "person_0003",
1477 | "person_0009",
1478 | "person_0010",
1479 | "person_0011",
1480 | "person_0012",
1481 | "person_0015",
1482 | "person_0017",
1483 | "person_0022",
1484 | "person_0023",
1485 | "person_0026",
1486 | "person_0027",
1487 | "person_0034",
1488 | "person_0035",
1489 | "person_0036",
1490 | "person_0038",
1491 | "person_0041",
1492 | "person_0043",
1493 | "person_0044",
1494 | "person_0047",
1495 | "person_0048",
1496 | "person_0050",
1497 | "person_0051",
1498 | "person_0053",
1499 | "person_0055",
1500 | "person_0056",
1501 | "person_0064",
1502 | "person_0066",
1503 | "person_0069",
1504 | "person_0070",
1505 | "person_0072",
1506 | "person_0076",
1507 | "person_0078",
1508 | "person_0083",
1509 | "person_0084",
1510 | "person_0086",
1511 | "person_0088",
1512 | "person_0089",
1513 | "person_0092",
1514 | "person_0093",
1515 | "person_0095",
1516 | "person_0097",
1517 | "person_0098",
1518 | "person_0099",
1519 | "person_0101",
1520 | "person_0102",
1521 | "person_0103",
1522 | "person_0107",
1523 | "person_0108",
1524 | "person_0109",
1525 | "person_0115",
1526 | "person_0117",
1527 | "person_0119",
1528 | "person_0122",
1529 | "person_0123",
1530 | "person_0124",
1531 | "person_0126",
1532 | "person_0128",
1533 | "person_0129",
1534 | "person_0130",
1535 | "person_0131",
1536 | "person_0133",
1537 | "person_0141",
1538 | "person_0144",
1539 | "person_0146",
1540 | "person_0148",
1541 | "person_0151",
1542 | "person_0154",
1543 | "person_0155",
1544 | "person_0157",
1545 | "person_0160",
1546 | "person_0161",
1547 | "person_0162",
1548 | "person_0163",
1549 | "person_0168",
1550 | "person_0172",
1551 | "person_0174",
1552 | "person_0175",
1553 | "person_0179",
1554 | "person_0184",
1555 | "person_0185",
1556 | "person_0186",
1557 | "person_0187",
1558 | "person_0188",
1559 | "person_0189",
1560 | "person_0191",
1561 | "person_0193",
1562 | "person_0194",
1563 | "person_0195",
1564 | "person_0196"
1565 | ],
1566 | "train": [
1567 | "person_0001",
1568 | "person_0002",
1569 | "person_0004",
1570 | "person_0005",
1571 | "person_0007",
1572 | "person_0008",
1573 | "person_0013",
1574 | "person_0014",
1575 | "person_0016",
1576 | "person_0018",
1577 | "person_0020",
1578 | "person_0021",
1579 | "person_0024",
1580 | "person_0025",
1581 | "person_0028",
1582 | "person_0029",
1583 | "person_0030",
1584 | "person_0032",
1585 | "person_0033",
1586 | "person_0037",
1587 | "person_0039",
1588 | "person_0040",
1589 | "person_0042",
1590 | "person_0045",
1591 | "person_0046",
1592 | "person_0049",
1593 | "person_0054",
1594 | "person_0058",
1595 | "person_0059",
1596 | "person_0060",
1597 | "person_0061",
1598 | "person_0063",
1599 | "person_0065",
1600 | "person_0067",
1601 | "person_0068",
1602 | "person_0073",
1603 | "person_0074",
1604 | "person_0075",
1605 | "person_0077",
1606 | "person_0079",
1607 | "person_0080",
1608 | "person_0081",
1609 | "person_0082",
1610 | "person_0085",
1611 | "person_0087",
1612 | "person_0090",
1613 | "person_0091",
1614 | "person_0096",
1615 | "person_0100",
1616 | "person_0104",
1617 | "person_0105",
1618 | "person_0110",
1619 | "person_0111",
1620 | "person_0112",
1621 | "person_0113",
1622 | "person_0114",
1623 | "person_0116",
1624 | "person_0118",
1625 | "person_0121",
1626 | "person_0125",
1627 | "person_0127",
1628 | "person_0132",
1629 | "person_0134",
1630 | "person_0135",
1631 | "person_0137",
1632 | "person_0140",
1633 | "person_0143",
1634 | "person_0147",
1635 | "person_0149",
1636 | "person_0150",
1637 | "person_0152",
1638 | "person_0153",
1639 | "person_0156",
1640 | "person_0164",
1641 | "person_0165",
1642 | "person_0167",
1643 | "person_0169",
1644 | "person_0170",
1645 | "person_0171",
1646 | "person_0173",
1647 | "person_0176",
1648 | "person_0177",
1649 | "person_0178",
1650 | "person_0180",
1651 | "person_0181",
1652 | "person_0182",
1653 | "person_0192",
1654 | "person_0197",
1655 | "person_0198"
1656 | ]
1657 | },
1658 | {
1659 | "test": [
1660 | "person_0001",
1661 | "person_0004",
1662 | "person_0009",
1663 | "person_0010",
1664 | "person_0011",
1665 | "person_0014",
1666 | "person_0015",
1667 | "person_0016",
1668 | "person_0021",
1669 | "person_0024",
1670 | "person_0025",
1671 | "person_0030",
1672 | "person_0032",
1673 | "person_0034",
1674 | "person_0035",
1675 | "person_0041",
1676 | "person_0042",
1677 | "person_0043",
1678 | "person_0044",
1679 | "person_0045",
1680 | "person_0047",
1681 | "person_0049",
1682 | "person_0053",
1683 | "person_0054",
1684 | "person_0056",
1685 | "person_0058",
1686 | "person_0059",
1687 | "person_0060",
1688 | "person_0064",
1689 | "person_0066",
1690 | "person_0067",
1691 | "person_0070",
1692 | "person_0073",
1693 | "person_0074",
1694 | "person_0075",
1695 | "person_0076",
1696 | "person_0077",
1697 | "person_0080",
1698 | "person_0081",
1699 | "person_0082",
1700 | "person_0083",
1701 | "person_0085",
1702 | "person_0086",
1703 | "person_0087",
1704 | "person_0088",
1705 | "person_0095",
1706 | "person_0096",
1707 | "person_0097",
1708 | "person_0100",
1709 | "person_0104",
1710 | "person_0108",
1711 | "person_0110",
1712 | "person_0111",
1713 | "person_0113",
1714 | "person_0114",
1715 | "person_0115",
1716 | "person_0117",
1717 | "person_0119",
1718 | "person_0122",
1719 | "person_0126",
1720 | "person_0127",
1721 | "person_0129",
1722 | "person_0134",
1723 | "person_0140",
1724 | "person_0143",
1725 | "person_0146",
1726 | "person_0147",
1727 | "person_0150",
1728 | "person_0151",
1729 | "person_0153",
1730 | "person_0156",
1731 | "person_0157",
1732 | "person_0160",
1733 | "person_0162",
1734 | "person_0168",
1735 | "person_0174",
1736 | "person_0176",
1737 | "person_0179",
1738 | "person_0180",
1739 | "person_0181",
1740 | "person_0182",
1741 | "person_0184",
1742 | "person_0185",
1743 | "person_0186",
1744 | "person_0189",
1745 | "person_0192",
1746 | "person_0193",
1747 | "person_0196",
1748 | "person_0198"
1749 | ],
1750 | "train": [
1751 | "person_0002",
1752 | "person_0003",
1753 | "person_0005",
1754 | "person_0007",
1755 | "person_0008",
1756 | "person_0012",
1757 | "person_0013",
1758 | "person_0017",
1759 | "person_0018",
1760 | "person_0020",
1761 | "person_0022",
1762 | "person_0023",
1763 | "person_0026",
1764 | "person_0027",
1765 | "person_0028",
1766 | "person_0029",
1767 | "person_0033",
1768 | "person_0036",
1769 | "person_0037",
1770 | "person_0038",
1771 | "person_0039",
1772 | "person_0040",
1773 | "person_0046",
1774 | "person_0048",
1775 | "person_0050",
1776 | "person_0051",
1777 | "person_0055",
1778 | "person_0061",
1779 | "person_0063",
1780 | "person_0065",
1781 | "person_0068",
1782 | "person_0069",
1783 | "person_0072",
1784 | "person_0078",
1785 | "person_0079",
1786 | "person_0084",
1787 | "person_0089",
1788 | "person_0090",
1789 | "person_0091",
1790 | "person_0092",
1791 | "person_0093",
1792 | "person_0098",
1793 | "person_0099",
1794 | "person_0101",
1795 | "person_0102",
1796 | "person_0103",
1797 | "person_0105",
1798 | "person_0107",
1799 | "person_0109",
1800 | "person_0112",
1801 | "person_0116",
1802 | "person_0118",
1803 | "person_0121",
1804 | "person_0123",
1805 | "person_0124",
1806 | "person_0125",
1807 | "person_0128",
1808 | "person_0130",
1809 | "person_0131",
1810 | "person_0132",
1811 | "person_0133",
1812 | "person_0135",
1813 | "person_0137",
1814 | "person_0141",
1815 | "person_0144",
1816 | "person_0148",
1817 | "person_0149",
1818 | "person_0152",
1819 | "person_0154",
1820 | "person_0155",
1821 | "person_0161",
1822 | "person_0163",
1823 | "person_0164",
1824 | "person_0165",
1825 | "person_0167",
1826 | "person_0169",
1827 | "person_0170",
1828 | "person_0171",
1829 | "person_0172",
1830 | "person_0173",
1831 | "person_0175",
1832 | "person_0177",
1833 | "person_0178",
1834 | "person_0187",
1835 | "person_0188",
1836 | "person_0191",
1837 | "person_0194",
1838 | "person_0195",
1839 | "person_0197"
1840 | ]
1841 | }
1842 | ]
--------------------------------------------------------------------------------
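Note: the split file above is a list of cross-validation folds, each a JSON object with a `"train"` and a `"test"` list of person IDs. A minimal sketch of loading and inspecting it, using the `read_json` helper defined in `utils.py` below (the path is illustrative):

```python
# Sketch: load the PRID-2011 split file and inspect each fold.
# read_json is defined in utils.py; the file path below is illustrative.
from utils import read_json

splits = read_json('splits_prid2011.json')  # a list of fold dicts
for i, split in enumerate(splits):
    train_ids = split['train']
    test_ids = split['test']
    print('fold {}: {} train ids, {} test ids'.format(i, len(train_ids), len(test_ids)))
```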
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-25 8:13 PM
5 |
6 |
7 | from __future__ import print_function, absolute_import
8 | import time
9 | import torch
10 | from torch import nn
11 | from eval.eva_functions import accuracy
12 | from utils import AverageMeter
13 | import torch.nn.functional as F
14 | from utils import to_numpy
15 | # from tensorboardX import SummaryWriter
16 | # writer = SummaryWriter('/media/ying/0BDD17830BDD1783/video_reid/logs/generate_q_by_lstm_100epoch')
17 | # The comments below sketch two alternative ways to backpropagate the loss.
18 | """ Separate loss backpropagation (each loss gets its own backward pass and optimizer):
19 | optimizer1.zero_grad()
20 | loss1.backward(retain_graph=True)
21 | optimizer1.step()
22 |
23 | optimizer2.zero_grad()
24 | loss2.backward()
25 | optimizer2.step()
26 | """
27 | """ 累积梯度回传
28 | # if (i + 1) % accumulation_steps == 0:
29 | # optimizer1.step()
30 | # optimizer2.step()
31 | # optimizer1.zero_grad()
32 | # optimizer2.zero_grad()
33 | # # loss.backward()
34 | """
35 |
36 |
37 | class BaseTrainer(object):
38 |
39 | def __init__(self, model, criterion):
40 | super(BaseTrainer, self).__init__()
41 | self.model = model
42 | self.criterion = criterion
43 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44 |
45 | def train(self, epoch, data_loader, optimizer):
46 | self.model.train()
47 |
48 | batch_time = AverageMeter()
49 | data_time = AverageMeter()
50 | losses = AverageMeter()
51 | precisions = AverageMeter()  # defined but not updated in this loop
52 | precisions1 = AverageMeter()  # defined but not updated in this loop
53 | accumulation_steps = 8  # only used by the commented-out accumulation variant above
54 | # total_loss = 0
55 |
56 | end = time.time()
57 | for i, inputs in enumerate(data_loader):
58 | data_time.update(time.time() - end)
59 |
60 | netinputs0, netinputs1, targets = self._parse_data(inputs)
61 |
62 | loss = self._forward(netinputs0, netinputs1, targets)  # 1. forward pass
63 | # loss = loss1 + loss2
64 | # total_loss += loss.item()
65 | losses.update(loss.item(), len(targets[0]))
66 |
67 | optimizer.zero_grad()
68 | loss.backward()
69 | optimizer.step()
70 |
71 | batch_time.update(time.time() - end)
72 | end = time.time()
73 | print_freq = 10
74 | if (i + 1) % print_freq == 0:
75 | print('Epoch: [{}][{}/{}]\t'
76 | 'Loss {:.3f} ({:.3f})\t'
77 | .format(epoch, i + 1, len(data_loader),
78 | losses.val, losses.avg))
79 |
80 | def _parse_data(self, inputs):
81 | raise NotImplementedError
82 |
83 | def _forward(self, netinputs0, netinputs1, targets):
84 | raise NotImplementedError
85 |
86 |
87 | class SEQTrainer(BaseTrainer):
88 |
89 | def __init__(self, model, criterion):
90 | super(SEQTrainer, self).__init__(model, criterion)
91 | self.criterion = criterion
92 | # self.rate = rate
93 | # self.softmax = nn.Softmax()
94 |
95 | def _parse_data(self, inputs):
96 | seq0, seq1, targets = inputs
97 | seq0 = seq0.to(self.device)
98 | seq1 = seq1.to(self.device)
99 |
100 | return seq0, seq1, targets
101 |
102 | def _forward(self, netinputs0, netinputs1, targets):
103 | # TODO: check whether `targets` is handled correctly here
104 | feature_p, feature_g, identity_p, identity_g = self.model(netinputs0, netinputs1)
105 | loss = self.criterion(feature_p, feature_g, identity_p, identity_g, targets)
106 |
107 | return loss
108 |
109 | def train(self, epoch, data_loader, optimizer):
110 | # self.rate = rate
111 | super(SEQTrainer, self).train(epoch, data_loader, optimizer)
112 |
--------------------------------------------------------------------------------
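The two header docstrings in `train.py` describe alternative update schemes. Below is a minimal, self-contained sketch of the accumulated-gradient variant; the model, data, and criterion are dummy stand-ins for the real CNN-RNN pipeline, not the repository's actual training loop.

```python
# Sketch: accumulated-gradient backpropagation, as outlined in the
# train.py header docstring. Model/data/criterion are dummy stand-ins.
import torch
from torch import nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 8

optimizer.zero_grad()
for i in range(32):  # stands in for enumerate(data_loader)
    x = torch.randn(4, 8)
    y = torch.randint(0, 2, (4,))
    # Scale the loss so the accumulated gradient matches one 8x-larger batch.
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate across iterations
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Accumulation emulates a larger effective batch at constant memory cost, which is one way to avoid updating on a single sequence pair at a time.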
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # @author = yuci
4 | # Date: 19-3-18 11:11 AM
5 | """Step 2: utility collection.
6 | Defines some utilities used throughout this project.
7 | Code imported from https://github.com/Cysu/open-reid/.
8 | """
9 | import torch
10 | import os
11 | import os.path as osp
12 | import sys
13 | import errno
14 | import json
15 | import shutil
16 |
17 |
18 | # 0. to_numpy/to_torch: convert between torch tensors and numpy arrays
19 | def to_numpy(tensor):
20 | if torch.is_tensor(tensor):
21 | return tensor.cpu().numpy()
22 | elif type(tensor).__module__ != 'numpy':
23 | raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
24 | return tensor
25 |
26 |
27 | def to_torch(ndarray):
28 | if type(ndarray).__module__ == 'numpy':
29 | return torch.from_numpy(ndarray)
30 | elif not torch.is_tensor(ndarray):
31 | raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
32 | return ndarray
33 |
34 |
35 | # 1. mkdir_if_missing: create the directory at dir_path if it does not already exist
36 | def mkdir_if_missing(dir_path):
37 | try:
38 | os.makedirs(dir_path)
39 | except OSError as e:
40 | if e.errno != errno.EEXIST:
41 | raise
42 |
43 |
44 | # 2. AverageMeter: computes and stores the current value and running average; used when reporting losses
45 | class AverageMeter(object):
46 | def __init__(self):
47 | self.val = 0
48 | self.avg = 0
49 | self.sum = 0
50 | self.count = 0
51 |
52 | def reset(self):
53 | self.val = 0
54 | self.avg = 0
55 | self.sum = 0
56 | self.count = 0
57 |
58 | def update(self, val, n=1):
59 | self.val = val
60 | self.sum += val * n
61 | self.count += n
62 | self.avg = self.sum / self.count
63 |
64 |
65 | # 3. Logger: log management; tees console output to an external text file
66 | class Logger(object):
67 | def __init__(self, fpath=None):
68 | self.console = sys.stdout
69 | self.file = None
70 | if fpath is not None:
71 | mkdir_if_missing(os.path.dirname(fpath))
72 | self.file = open(fpath, 'w')
73 |
74 | def __del__(self):
75 | self.close()
76 |
77 | def __enter__(self):
78 | return self
79 |
80 | def __exit__(self, exc_type, exc_value, traceback):
81 | self.close()
82 |
83 | def write(self, msg):
84 | self.console.write(msg)
85 | if self.file is not None:
86 | self.file.write(msg)
87 |
88 | def flush(self):
89 | self.console.flush()
90 | if self.file is not None:
91 | self.file.flush()
92 | os.fsync(self.file.fileno())
93 |
94 | def close(self):
95 | self.console.close()
96 | if self.file is not None:
97 | self.file.close()
98 |
99 |
100 | # 4.serialization
101 | def read_json(fpath):
102 | with open(fpath, 'r') as f:
103 | obj = json.load(f)  # decode: parse the JSON text into a Python object
104 | return obj
105 |
106 |
107 | def write_json(obj, fpath):
108 | mkdir_if_missing(osp.dirname(fpath))  # create the parent directory if it does not exist
109 | with open(fpath, 'w') as f:
110 | json.dump(obj, f, indent=4, separators=(',', ':'))  # ',' separates items, ':' separates keys from values
111 | # json.dump writes obj to file f, with newlines and `indent` spaces of leading whitespace per level
112 |
113 |
114 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):  # save a checkpoint file
115 | mkdir_if_missing(osp.dirname(fpath))
116 | torch.save(state, fpath)
117 | if is_best:
118 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
119 |
120 |
121 | def load_checkpoint(fpath):  # load a checkpoint file
122 | if osp.isfile(fpath):
123 | checkpoint = torch.load(fpath)
124 | print("=> Loaded checkpoint '{}'".format(fpath))
125 | return checkpoint
126 | else:
127 | raise ValueError("=> No checkpoint found at '{}'".format(fpath))
128 |
--------------------------------------------------------------------------------
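A brief sketch of how these helpers fit together in a typical training script; the values and paths are illustrative, not taken from the repository.

```python
# Sketch: typical use of Logger, AverageMeter, and the checkpoint helpers above.
import sys
from utils import Logger, AverageMeter, save_checkpoint, load_checkpoint

sys.stdout = Logger('logs/train.log')  # tee all subsequent prints to a file

losses = AverageMeter()
for step, batch_loss in enumerate([0.9, 0.7, 0.5], start=1):
    losses.update(batch_loss, n=4)  # n is the batch size
    print('step {}: val {:.2f}, avg {:.2f}'.format(step, losses.val, losses.avg))

# Persist training state; is_best=True also copies it to model_best.pth.tar.
save_checkpoint({'epoch': 1}, is_best=True, fpath='logs/checkpoint.pth.tar')
state = load_checkpoint('logs/checkpoint.pth.tar')
```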