├── .idea
├── MOT.iml
├── deployment.xml
├── dictionaries
│ └── lipeizhao.xml
├── inspectionProfiles
│ └── Project_Default.xml
├── misc.xml
├── modules.xml
└── other.xml
├── Generator.py
├── README.md
├── Test.py
├── TestGenerate.py
├── config.yml
├── main.py
├── model
├── EmbeddingNet.py
├── FuckUpNet.py
├── GCN.py
├── __init__.py
├── final.py
└── net_1024.py
├── requirements.txt
├── setting
├── ADL-Rundle-1_config.yml
├── ADL-Rundle-3_config.yml
├── AVG-TownCentre_config.yml
├── ETH-Crossing_config.yml
├── ETH-Jelmoli_config.yml
├── ETH-Linthescher_config.yml
├── KITTI-16_config.yml
├── KITTI-19_config.yml
├── PETS09-S2L2_config.yml
├── TUD-Crossing_config.yml
├── Venice-1_config.yml
├── seq01_config.yml
├── seq03_config.yml
├── seq06_config.yml
├── seq07_config.yml
├── seq08_config.yml
├── seq12_config.yml
└── seq14_config.yml
├── tracking-MOT15.py
├── tracking.py
├── tracking_utils.py
├── train
├── __init__.py
└── train_net_1024.py
└── utils.py
/.idea/MOT.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/dictionaries/lipeizhao.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | amatrix
5 | bcount
6 | cfile
7 | dcount
8 | dnow
9 | embeddingnet
10 | embnet
11 | fourcc
12 | framewise
13 | frcnn
14 | gmail
15 | imglist
16 | lstm
17 | manu
18 | ndarray
19 | peizhaoli
20 | predata
21 | rawdata
22 | savetxt
23 | seqdata
24 | seqres
25 | sprop
26 | stdv
27 | testcheck
28 | tracklet
29 |
30 |
31 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/Generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : Generator.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/10/11
6 |
7 | import os, random
8 | import os.path as osp
9 | import numpy as np
10 | import torch
11 | from torchvision import transforms
12 | from PIL import Image, ImageDraw
13 |
14 |
def LoadImg(img_path):
    """Load every entry of a directory as an image, in sorted filename order.

    :param img_path: directory whose entries are all opened as images
    :return: list of in-memory image copies, one per directory entry
    """
    imglist = []

    for name in sorted(os.listdir(img_path)):
        # The context manager releases the file handle even if copying fails;
        # the copy stays usable after the source image is closed.
        with Image.open(osp.join(img_path, name)) as img:
            imglist.append(img.copy())

    return imglist
26 |
27 |
def FindMatch(list_id, list1, list2):
    """Locate each listed identity in two identity lists.

    :param list_id: identities to look up; each must occur in both lists
    :param list1: first list of identities
    :param list2: second list of identities
    :return: flat list [i1, j1, i2, j2, ...] where, for every entry of
        list_id in order, i is its first position in list1 and j its first
        position in list2
    """
    index_pair = []
    for identity in list_id:
        # Resolve both positions before appending so a missing identity
        # raises without leaving a half-written pair behind.
        first, second = list1.index(identity), list2.index(identity)
        index_pair.extend((first, second))

    return index_pair
44 |
45 |
class VideoData(object):
    """Frames and ground truth of one MOT17 training sequence.

    Ground-truth rows are read as [frame, id, x, y, w, h, ...]: column 0 is
    the 1-based frame number, column 1 the identity, and columns 2-5 the box
    that is cropped as (x, y, x + w, y + h).
    """

    def __init__(self, seq_id):
        """
        :param seq_id: sequence number as used in the dataset paths, e.g. "09"
        """
        self.img = LoadImg("MOT17/MOT17/train/MOT17-{}-SDP/img1".format(seq_id))
        self.gt = np.loadtxt("MOT17/label/{}_gt.txt".format(seq_id))

        self.ImageWidth = self.img[0].size[0]
        self.ImageHeight = self.img[0].size[1]

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def CurData(self, frame):
        """Return the ground-truth rows of the zero-based frame index."""
        return self.gt[self.gt[:, 0] == (frame + 1)]

    def PreData(self, frame):
        """Return the rows of `frame` and of the four frames before it.

        :param frame: zero-based frame index
        :return: list of five arrays, newest frame first
        """
        return [self.gt[self.gt[:, 0] == (frame + 1 - back)] for back in range(5)]

    def TotalFrame(self):
        """Return the number of loaded frames."""
        return len(self.img)

    def CenterCoordinate(self, SingleLineData):
        """Return the box centre of one row, normalised by the image size.

        :param SingleLineData: one ground-truth row
        :return: (x, y) centre divided by image width and height
        """
        x = (SingleLineData[2] + (SingleLineData[4] / 2)) / float(self.ImageWidth)
        y = (SingleLineData[3] + (SingleLineData[5] / 2)) / float(self.ImageHeight)

        return x, y

    def Appearance(self, data):
        """Crop and transform every box of a single frame.

        :param data: rows of one frame; the frame is taken from the first row,
            so the array must not be empty
        :return: list of transformed crops, one per row
        """
        appearance = []
        img = self.img[int(data[0, 0]) - 1]
        for row in data:
            left, top = int(row[2]), int(row[3])
            box = (left, top, left + int(row[4]), top + int(row[5]))
            appearance.append(self.transforms(img.crop(box)))

        return appearance

    def CurMotion(self, data):
        """Return one 2-element tensor (normalised centre) per row."""
        motion = []
        for row in data:
            coordinate = torch.zeros([2])
            coordinate[0], coordinate[1] = self.CenterCoordinate(row)
            motion.append(coordinate)

        return motion

    def PreMotion(self, DataTuple):
        """Build a five-step centre history for every object of the newest frame.

        :param DataTuple: list of five row arrays, newest frame first (as
            returned by PreData)
        :return: list of [5, 2] tensors; row 4 is the newest frame and row
            4 - j the frame j steps earlier
        """
        motion = []
        nameless = DataTuple[0]
        for i in range(nameless.shape[0]):
            coordinate = torch.zeros([5, 2])
            identity = nameless[i, 1]
            coordinate[4, 0], coordinate[4, 1] = self.CenterCoordinate(nameless[i])
            for j in range(1, 5):
                unknown = DataTuple[j]
                if identity in unknown[:, 1]:
                    coordinate[4 - j, 0], coordinate[4 - j, 1] = self.CenterCoordinate(
                        unknown[unknown[:, 1] == identity].squeeze())
                else:
                    # Identity absent in that frame: repeat the next newer position.
                    coordinate[4 - j, :] = coordinate[5 - j, :]
            motion.append(coordinate)

        return motion

    def GetID(self, data):
        """Return the identity column of `data` as a list."""
        return [data[i, 1] for i in range(data.shape[0])]

    def __call__(self, frame):
        """Assemble one training sample for a frame and its predecessor.

        :param frame: zero-based frame index, 5 <= frame < TotalFrame()
        :return: (cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id,
            gt_matrix) where gt_matrix[i, j] is 1 when previous object i and
            current object j share an identity
        """
        assert frame >= 5 and frame < self.TotalFrame()
        cur = self.CurData(frame)
        pre = self.PreData(frame - 1)

        cur_crop = self.Appearance(cur)
        pre_crop = self.Appearance(pre[0])

        cur_motion = self.CurMotion(cur)
        pre_motion = self.PreMotion(pre)

        cur_id = self.GetID(cur)
        pre_id = self.GetID(pre[0])

        list_id = [x for x in pre_id if x in cur_id]
        index_pair = FindMatch(list_id, pre_id, cur_id)
        gt_matrix = np.zeros([len(pre_id), len(cur_id)])
        # index_pair is flat: [pre_0, cur_0, pre_1, cur_1, ...]. Walking it in
        # pairs avoids `len(...) / 2`, which yields a float (and breaks
        # range()) whenever true division is in effect.
        for pre_index, cur_index in zip(index_pair[0::2], index_pair[1::2]):
            gt_matrix[pre_index, cur_index] = 1

        return cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id, gt_matrix
164 |
165 |
class Generator(object):
    """Random training-sample source backed by one or more VideoData sequences."""

    def __init__(self, entirety=False):
        """
        :param entirety: when True, load all seven listed training sequences;
            otherwise only sequence "09"
        """
        self.sequence = []

        self.SequenceID = ["02", "04", "05", "09", "10", "11", "13"] if entirety == True else ["09"]

        self.vis_save_path = "MOT17/visualize"

        print("\n-------------------------- initialization --------------------------")
        for seq_name in self.SequenceID:
            print("initializing sequence {} ...".format(seq_name))
            self.sequence.append(VideoData(seq_name))
            print("initialize {} done".format(seq_name))
        print("------------------------------ done --------------------------------\n")

    def visualize(self, seq_ID, frame, save_path=None):
        """Dump the crops, identities and ground-truth matrix of one sample.

        :param seq_ID: index into the loaded sequences
        :param frame: zero-based frame index passed to the sequence
        :param save_path: output directory; defaults to self.vis_save_path
        """
        out_dir = self.vis_save_path if save_path is None else save_path

        print("visualize sequence {}: frame {}".format(self.SequenceID[seq_ID], frame + 1))
        video = self.sequence[seq_ID]
        print("video solution: {} {}".format(video.ImageWidth, video.ImageHeight))
        cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id, gt_matrix = video(frame)

        # Current-frame crops are written first, then previous-frame crops.
        for name_template, crops in (("cur_crop_{}.png", cur_crop), ("pre_crop_{}.png", pre_crop)):
            for number, tensor in enumerate(crops):
                picture = transforms.functional.to_pil_image(tensor)
                picture = transforms.functional.resize(picture, (420, 160))
                # Drawing context is still created although the text overlay
                # (id and centre coordinate) is currently disabled.
                ImageDraw.Draw(picture)
                picture.save(osp.join(out_dir, name_template.format(str(number).zfill(2))))

        np.savetxt(osp.join(out_dir, "gt_matrix.txt"), gt_matrix, fmt="%d")
        np.savetxt(osp.join(out_dir, "pre_id.txt"), np.array(pre_id).transpose(), fmt="%d")
        np.savetxt(osp.join(out_dir, "cur_id.txt"), np.array(cur_id).transpose(), fmt="%d")

    def __call__(self):
        """Draw a random sequence and a random frame (index >= 5) from it.

        :return: (cur_crop, pre_crop, cur_motion, pre_motion, gt_matrix)
        """
        chosen = random.choice(self.sequence)
        frame = random.randint(5, chosen.TotalFrame() - 1)
        cur_crop, pre_crop, cur_motion, pre_motion, _cur_id, _pre_id, gt_matrix = chosen(frame)

        return cur_crop, pre_crop, cur_motion, pre_motion, gt_matrix
234 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Neural Based End-to-end Data Association Framework for Online Multiple-Object Tracking
2 | A PyTorch implementation that combines a Siamese network with a graph neural network for online multiple-object tracking.
3 |
4 | Dataset available at [https://motchallenge.net/]
5 |
6 | The corresponding paper can be found at [https://arxiv.org/abs/1907.05315]
7 |
8 | ## How to run
9 | Use `python main.py` to train a model from scratch. Settings for training are in `config.yml`.
10 | Use `python tracking.py` to track a test video; you also need to provide the detected objects and the tracking results for the first five frames. Settings for tracking are in `setting/`.
11 |
12 | ## Requirements
13 | - Python 2.7.12
14 | - numpy 1.11.0
15 | - scipy 1.1.0
16 | - torchvision 0.2.1
17 | - opencv_python 3.3.0.10
18 | - easydict 1.7
19 | - torch 0.4.1
20 | - Pillow 6.2.0
21 | - PyYAML 5.1
22 |
--------------------------------------------------------------------------------
/Test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : Test.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/6
6 |
7 | import os
8 | import os.path as osp
9 | import numpy as np
10 | import torch
11 | from torchvision import transforms
12 | from PIL import Image, ImageDraw
13 | from model import net_1024
14 |
15 |
def LoadImg(img_path):
    """Load every entry of a directory as an image, in sorted filename order.

    :param img_path: directory whose entries are all opened as images
    :return: list of in-memory image copies, one per directory entry
    """
    imglist = []

    for name in sorted(os.listdir(img_path)):
        # The context manager releases the file handle even if copying fails;
        # the copy stays usable after the source image is closed.
        with Image.open(osp.join(img_path, name)) as img:
            imglist.append(img.copy())

    return imglist
27 |
28 |
def LoadModel(model, path):
    """Restore weights from a checkpoint and switch the model to GPU eval mode.

    :param model: network whose parameters are overwritten
    :param path: checkpoint file holding a "state_dict" entry
    :return: the same model object that was passed in
    """
    weights = torch.load(path)["state_dict"]
    model.load_state_dict(weights)
    model.cuda().eval()

    return model
35 |
36 |
class VideoData(object):
    """Frames, detections and previous tracking results of one test sequence.

    Rows are read as [frame, id, x, y, w, h, ...]: column 0 is the 1-based
    frame number, column 1 the identity, and columns 2-5 the box that is
    cropped as (x, y, x + w, y + h).
    """

    def __init__(self, info, res_path):
        """
        :param info: sequence name inserted into the MOT15 image/detection paths
        :param res_path: tracking-result text file that PreData re-reads on
            every call
        """
        # MOT17
        # self.img = LoadImg("MOT17/MOT17/test/MOT17-{}-{}/img1".format(info[0], info[1]))
        # self.det = np.loadtxt("test/MOT17-{}-{}/det.txt".format(info[0], info[1]))

        # MOT15
        self.img = LoadImg("MOT15/test/{}/img1".format(info))
        # ndmin=2 keeps the table two-dimensional even for a single-row file,
        # so the column indexing below never sees a 1-D array.
        self.det = np.loadtxt("test-MOT15/{}/det.txt".format(info), ndmin=2)

        self.res_path = res_path

        self.ImageWidth = self.img[0].size[0]
        self.ImageHeight = self.img[0].size[1]
        self.transforms = transforms.Compose([
            transforms.Resize((84, 32)),
            transforms.ToTensor()
        ])

    def DetData(self, frame):
        """Return the detection rows of the zero-based frame index."""
        return self.det[self.det[:, 0] == (frame + 1)]

    def PreData(self, frame):
        """Return the result rows of `frame` and of the four frames before it.

        The result file is loaded on every call.

        :param frame: zero-based frame index
        :return: list of five arrays, newest frame first
        """
        # ndmin=2: a result file with exactly one row would otherwise load as
        # a 1-D array and make `res[:, 0]` raise IndexError.
        res = np.loadtxt(self.res_path, ndmin=2)

        return [res[res[:, 0] == (frame + 1 - back)] for back in range(5)]

    def TotalFrame(self):
        """Return the number of loaded frames."""
        return len(self.img)

    def CenterCoordinate(self, SingleLineData):
        """Return the box centre of one row, normalised by the image size.

        :param SingleLineData: one detection/result row
        :return: (x, y) centre divided by image width and height
        """
        x = (SingleLineData[2] + (SingleLineData[4] / 2)) / float(self.ImageWidth)
        y = (SingleLineData[3] + (SingleLineData[5] / 2)) / float(self.ImageHeight)

        return x, y

    def Appearance(self, data):
        """Crop and transform every box of a single frame.

        :param data: rows of one frame; the frame is taken from the first row,
            so the array must not be empty
        :return: list of transformed crops, one per row
        """
        appearance = []

        img = self.img[int(data[0, 0]) - 1]
        for row in data:
            left, top = int(row[2]), int(row[3])
            box = (left, top, left + int(row[4]), top + int(row[5]))
            appearance.append(self.transforms(img.crop(box)))

        return appearance

    def DetMotion(self, data):
        """Return one 2-element tensor (normalised centre) per detection row."""
        motion = []
        for row in data:
            coordinate = torch.zeros([2])
            coordinate[0], coordinate[1] = self.CenterCoordinate(row)
            motion.append(coordinate)

        return motion

    def PreMotion(self, DataTuple):
        """Build a five-step centre history for every object of the newest frame.

        :param DataTuple: list of five row arrays, newest frame first (as
            returned by PreData)
        :return: list of [5, 2] tensors; row 4 is the newest frame and row
            4 - j the frame j steps earlier
        """
        motion = []
        nameless = DataTuple[0]
        for i in range(nameless.shape[0]):
            coordinate = torch.zeros([5, 2])
            identity = nameless[i, 1]
            coordinate[4, 0], coordinate[4, 1] = self.CenterCoordinate(nameless[i])

            for j in range(1, 5):
                unknown = DataTuple[j]
                if identity in unknown[:, 1]:
                    coordinate[4 - j, 0], coordinate[4 - j, 1] = self.CenterCoordinate(
                        unknown[unknown[:, 1] == identity].squeeze())
                else:
                    # Identity absent in that frame: repeat the next newer position.
                    coordinate[4 - j, :] = coordinate[5 - j, :]

            motion.append(coordinate)

        return motion

    def GetID(self, data):
        """Return copies of the identity column of `data` as a list."""
        return [data[i, 1].copy() for i in range(data.shape[0])]

    def __call__(self, frame):
        """Assemble the network inputs for a frame and its predecessor.

        :param frame: zero-based frame index, 5 <= frame < TotalFrame()
        :return: (det_crop, det_motion, pre_crop, pre_motion, pre_id)
        """
        assert frame >= 5 and frame < self.TotalFrame()
        det = self.DetData(frame)
        pre = self.PreData(frame - 1)
        det_crop = self.Appearance(det)
        pre_crop = self.Appearance(pre[0])
        det_motion = self.DetMotion(det)
        pre_motion = self.PreMotion(pre)
        pre_id = self.GetID(pre[0])

        return det_crop, det_motion, pre_crop, pre_motion, pre_id
140 |
141 |
class TestGenerator(object):
    """Runs the trained net_1024 model on the frames of one test sequence."""

    def __init__(self, res_path, info):
        """
        :param res_path: tracking-result text file handed to VideoData
        :param info: sequence descriptor handed to VideoData; its first two
            items are also used in the progress messages
        """
        net = net_1024.net_1024()
        net_path = "SaveModel/net_1024_beta2.pth"
        print("-------> loading net_1024")
        self.net = LoadModel(net, net_path)

        self.sequence = []
        # visualize() looks the sequence name up by index, so keep a name list
        # parallel to self.sequence (it was previously never defined, which
        # made visualize() fail with AttributeError).
        self.SequenceID = [info]

        print("-------> initializing MOT17-{}-{} ...".format(info[0], info[1]))
        self.sequence.append(VideoData(info, res_path))
        print("-------> initialize MOT17-{}-{} done".format(info[0], info[1]))

        self.vis_save_path = "test/visualize"

    def visualize(self, SeqID, frame, save_path=None):
        """Dump annotated detection/previous crops and previous IDs to disk.

        :param SeqID: index into the loaded sequences
        :param frame: zero-based frame index passed to the sequence
        :param save_path: output directory; defaults to self.vis_save_path
        """
        if save_path is None:
            save_path = self.vis_save_path

        print("visualize sequence {}: frame {}".format(self.SequenceID[SeqID], frame + 1))
        print("video solution: {} {}".format(self.sequence[SeqID].ImageWidth, self.sequence[SeqID].ImageHeight))
        det_crop, det_motion, pre_crop, pre_motion, pre_id = self.sequence[SeqID](frame)

        for i, crop in enumerate(det_crop):
            img = transforms.functional.to_pil_image(crop)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            draw.text((0, 0), "num: {}\ncoord: {:3.2f}, {:3.2f}".format(int(i), det_motion[i][0].item(),
                                                                       det_motion[i][1].item()), fill=(255, 0, 0))
            img.save(osp.join(save_path, "det_crop_{}.png".format(str(i).zfill(2))))

        for i, crop in enumerate(pre_crop):
            img = transforms.functional.to_pil_image(crop)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            # Row 4 of a previous-motion tensor is the newest position.
            draw.text((0, 0), "num: {}\nid: {}\ncoord: {:3.2f}, {:3.2f}".format(int(i), int(pre_id[i]),
                                                                                pre_motion[i][4, 0].item(),
                                                                                pre_motion[i][4, 1].item()),
                      fill=(255, 0, 0))
            img.save(osp.join(save_path, "pre_crop_{}.png".format(str(i).zfill(2))))

        np.savetxt(osp.join(save_path, "pre_id.txt"), np.array(pre_id).transpose(), fmt="%d")

    def __call__(self, SeqID, frame):
        """Run the network on one frame and return its last output.

        :param SeqID: index into the loaded sequences
        :param frame: zero-based frame index; the sequence asserts frame >= 5
        :return: the sixth value returned by the network (named adj here)
        """
        # frame start with 5, exist frame start from 1
        sequence = self.sequence[SeqID]
        det_crop, det_motion, pre_crop, pre_motion, pre_id = sequence(frame)
        with torch.no_grad():
            s0, s1, s2, s3, adj1, adj = self.net(pre_crop, det_crop, pre_motion, det_motion)

        return adj
202 |
--------------------------------------------------------------------------------
/TestGenerate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : TestGenerate.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/10/30
6 |
7 | from Test import TestGenerator
8 | import scipy.io as sio
9 | from utils import *
10 | import os.path as osp
11 |
12 | np.set_printoptions(precision=2, suppress=True)
13 |
14 |
def MakeCell(data):
    """Group rows by frame number.

    :param data: 2-D array whose column 0 is the frame number; the frame
        number of the last row is taken as the final frame
    :return: list with one array per frame 1..last, each a copy of the rows
        belonging to that frame (empty for frames without rows)
    """
    last_frame = int(data[-1, 0])

    return [data[data[:, 0] == frame].copy() for frame in range(1, last_frame + 1)]
23 |
24 |
def TrackletRotation(newFrame, oldTracklet):
    """Shift a tracklet one step forward using a newly tracked frame.

    The tracklet is split into its first 7 entries (head) and the rest (tail).
    The two oldest tail values are dropped, newFrame[2] and newFrame[3] are
    pushed to the front of the tail, and head entries 2-5 are overwritten with
    the corresponding newFrame entries.

    :param newFrame: sequence for the new frame, indexed at positions 2..5
    :param oldTracklet: sequence holding the head followed by the tail
    :return: list made of the updated head followed by the rotated tail
    """
    head = list(oldTracklet[:7])
    tail = list(oldTracklet[7:])

    # Handle index 3 first, then index 2, so newFrame[2] ends up in front.
    for coord_index, size_index in ((3, 5), (2, 4)):
        tail.pop()
        tail.insert(0, newFrame[coord_index])
        head[size_index] = newFrame[size_index]
        head[coord_index] = newFrame[coord_index]

    return head + tail
43 |
44 |
def DistanceMeasure(prev_coord, current_coord, FrameW, FrameH):
    """Return the size-normalised L1 distance between two points.

    :param prev_coord: (x, y) of the previous position
    :param current_coord: (x, y) of the current position
    :param FrameW: value the horizontal offset is divided by
    :param FrameH: value the vertical offset is divided by
    :return: |dx| / FrameW + |dy| / FrameH
    """
    horizontal = abs(prev_coord[0] - current_coord[0])
    vertical = abs(prev_coord[1] - current_coord[1])

    return float(horizontal) / float(FrameW) + float(vertical) / float(FrameH)
55 |
56 |
57 | def Tracker(Amatrix, Prev_frame, PrevIDs, CurrentDRs, BirthBuffer, DeathBuffer, IDnow, FrameW, FrameH, NIDnow):
58 | '''
59 | This Tracker function outputs the CurrentDRs tensor as the thing to be written into the txt;
60 | CurrentDRs is input with all TrID = -1;
61 | Then:
62 | ---> 1-to-1, write the trajID into the CurrentDRs correspondingly;
63 | ---> Birth, initialize a new temporary ID in the birth buffer for that DR; ---> Need a Birth Buffer: buffer is a list of tuples
64 | ---> Death, put the dead trajIDs into the death list; ---> Need a Death Buffer: buffer is a list of tuples
65 | '''
66 |
67 | '''
68 | ---> Trajectories: Len = 5, a list of 7 + 2 * 5 = 17-dim tuple: [FrID, TrID, X, Y, W, H, X1, Y1, X2, Y2, ...., X5, Y5], W & H remains; just trajectories in the prev_frame
69 | ---> PrevIDs: the IDs corresponds to the Trajectories, who index the IDs in the AMatrix rows
70 | ---> CurrentDRs: a list of 7-dim tuple: [FrID, -1, X, Y, W, H]
71 | ---> BirthBuffer: a list of 7-dim tuple: [FrID, 0, X, Y, W, H]
72 | ---> DeathBuffer: a list of 7-dim tuple: [FrID, 'D', X, Y, W, H] for flags, use 'DD', 'DDD, etc
73 |
74 | RETURN ---> output, which is the ID assigned CurrentDRs
75 | '''
76 |
77 | prev_num = Amatrix.shape[1] # Amatrix.shape[0] is the batch size
78 | next_num = Amatrix.shape[2]
79 |
80 | DROccupiedIndexes = []
81 | ConfirmedBirthRowID = []
82 | ToDo_DoomedTrajID = [] # Confirmed Death Trajs to be taken out of prev_frame after each Tracker() iteration
83 | Fail2ConfirmedBirthKill = []
84 | '''
85 | The New Reading Matrix Logic
86 | '''
87 | if 'New Reading Logic for the AMatrix':
88 | '''
89 | set a configurable threshold params, reading by iteratively localizing the largest item in the matrix,
90 | Then finish an association, set the associated row and column to a very small negative value,
91 | Then repeat, until the max value in the matrix is below Threshold, then the left un-associated left &
92 | column to reason death & birth
93 | '''
94 | Th = 0
95 | dist_Th = 0.05
96 | DeathRows = [i for i in range(prev_num)]
97 | BirthCols = [j for j in range(next_num)]
98 |
99 | # Start the main loop here, may never reach the upper bound prev_num * next_num
100 | for i in range(prev_num * next_num):
101 | # print '-------------------------- i = %d'%i
102 | # compute the row and column index for the max value in a matrix
103 | # As AMatrix is a list of tuples, then has to find the max value in each list (row)
104 | # Then find the max column index in that row
105 | for k in range(prev_num):
106 | # print '--------------------------- prev_num %d'%prev_num
107 | # print '--------------------------- next_num %d'%next_num
108 | # print '------------------------------ k = %d'%k
109 | row_maxValue = [] # the index of row_maxValue corresponds to the row index of AMatrix
110 | [row_maxValue.append(max(Amatrix[0, j])) for j in range(prev_num)]
111 | max_rowValue = max(row_maxValue)
112 |
113 | # 1-to-1 Associations
114 | if max_rowValue > Th:
115 | '''
116 | Cases of 1-to-1 Associations
117 | '''
118 | max_rowIndex = row_maxValue.index(max_rowValue)
119 | max_colIndex = list(Amatrix[0, max_rowIndex]).index(max_rowValue)
120 |
121 | # Mark these associated row and col out of the DeathRows and BirthCols
122 | DeathRows.remove(max_rowIndex)
123 | BirthCols.remove(max_colIndex)
124 |
125 | print("Cases of 1-to-1 Association, the selected max row and col index:")
126 | print [max_rowIndex, max_colIndex]
127 |
128 | print("DeathRows:")
129 | print DeathRows
130 |
131 | print("BirthCols:")
132 | print BirthCols
133 |
134 | '''
135 | # ----------------------------- Case 1: Normal 1-to-1 association
136 | '''
137 | if PrevIDs[max_rowIndex] > -10 and PrevIDs[max_rowIndex] != -1:
138 | associated_DrIndex = max_colIndex
139 |
140 | '''
141 | ------- ------- ------- Distance Threshold
142 | '''
143 | prev_wh = Prev_frame[max_rowIndex][2:4]
144 | current_wh = CurrentDRs[associated_DrIndex][2:4]
145 | dist1 = DistanceMeasure(prev_wh, current_wh, FrameW, FrameH)
146 |
147 | if dist1 <= dist_Th:
148 | CurrentDRs[associated_DrIndex][1] = Prev_frame[max_rowIndex][1]
149 | print(
150 | "Normal 1-to-1 Association at %d row, traj ID paased onto next frame %d DR_Index, is %d") % (
151 | max_rowIndex, associated_DrIndex, CurrentDRs[associated_DrIndex][1])
152 | # Check if this association revive anybody in the DeathBuffer. i.e. if the associated ID is someone in the DeathBuffer, then it is revived, removed from DeathBuffer.
153 | # the just associated trajID here is Prev_frame[i][1]
154 | allTrajIDInDeathBuffer = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
155 | if Prev_frame[k][1] in allTrajIDInDeathBuffer:
156 | reviveTrajIndex = allTrajIDInDeathBuffer.index(Prev_frame[k][1])
157 | revivedID = DeathBuffer[reviveTrajIndex][0]
158 | print 'Trajectory %d is about to be revived from the DeathBuffer' % revivedID
159 | DeathBuffer.remove(DeathBuffer[reviveTrajIndex])
160 |
161 | '''
162 | # --------------------------- Case 2: Birth confirmation 1-to-1 association
163 | '''
164 | if PrevIDs[max_rowIndex] <= -10:
165 | print 'Birth confirmation 1-to-1 Association at %d row' % max_rowIndex
166 | associated_DrIndex = max_colIndex
167 |
168 | '''
169 | ------- ------- ------- Distance Threshold
170 | '''
171 | prev_wh = Prev_frame[max_rowIndex][2:4]
172 | current_wh = CurrentDRs[associated_DrIndex][2:4]
173 | dist2 = DistanceMeasure(prev_wh, current_wh, FrameW, FrameH)
174 |
175 | if dist2 <= dist_Th:
176 | CurrentDRs[associated_DrIndex][1] = IDnow
177 | print '$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ A New Target id as %d is generated $$$$$$$$$$$$$$$$$$$$$' % IDnow
178 | IDnow = IDnow + 1
179 |
180 | ConfirmedBirthRowID.append(PrevIDs[max_rowIndex])
181 |
182 | elif dist2 > dist_Th:
183 | Fail2ConfirmedBirthKill.append(PrevIDs[max_rowIndex])
184 | # Conservation Check here, once a column is taken, it cannot be taken again, so remove it from the AMatrix
185 | if associated_DrIndex in DROccupiedIndexes:
186 | print 'Conservation Constraint violated at row %d when doing Birth confirmation 1-to-1 Association' % i
187 | else:
188 | DROccupiedIndexes.append(associated_DrIndex)
189 |
190 | # change the row & col in this instance of 1-to-1 association to be all < Th
191 | Amatrix[0][max_rowIndex, :] = Th
192 | Amatrix[0][:, max_colIndex] = Th
193 | # print 'Hi'
194 |
195 | # B & D for cols and rows
196 | if max_rowValue <= Th:
197 | '''
198 | The Rest un-associated rows are death, columns are birth
199 | '''
200 | if not DeathRows:
201 | # print '---> Tracker: No Death at this association'
202 | pass
203 | if not BirthCols:
204 | # print '---> Tracker: No Birth at this association'
205 | pass
206 |
207 | '''
208 | *************************** Death handles
209 | # for a just dead target, if it is within the death window, just prolong it, do not report it to be death
210 | # within the death windows frames.
211 | '''
212 | if DeathRows:
213 | for t in range(len(DeathRows)):
214 | if DeathRows[t] != -10000:
215 | DeadTrajID = prev_frame[DeathRows[t]][1]
216 |
217 | '''
218 | If a Death is ID <= -10, then terminate it right away
219 | The Birth if not confirmed in the next frame, then it should be terminated
220 | This termination is handled by the death handle altogether
221 | '''
222 | if DeadTrajID <= -10:
223 | #DeathBuffer.append([DeadTrajID, 5])
224 | Fail2ConfirmedBirthKill.append(DeadTrajID)
225 | print '_+_+_+_+_+_+_+_+_+_+_+ In death handles, a temporal birth %d fail to confirm, about to be killed right away in Fail2ConfirmedBirthKill'%DeadTrajID
226 |
227 | '''
228 | Normal death, update the death flag, then invoke motion prediciton
229 | '''
230 | if DeadTrajID > -10:
231 | '''
232 | # Step 1: Check and update the death counter of this dead trajectory in the DeathBuffer
233 | '''
234 | allTrajIDInDeathBuffer = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
235 |
236 | if DeadTrajID == 1:
237 | print 'Gotcha'
238 |
239 | if DeadTrajID not in allTrajIDInDeathBuffer:
240 | DeathBuffer.append([DeadTrajID, 0])
241 | else:
242 | # find the index of the DeadID, update its death counter by + 1
243 | AllDeadIDs = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
244 | temp_index = AllDeadIDs.index(DeadTrajID)
245 | DeathBuffer[temp_index][1] += 1
246 |
247 | # Step 2: As this trajectory is dead, i.e. fail to associate to a DR to prolong itself in this frame, then use a Motion Prediction
248 | # to propagate into a new dummy bbox to add into the result CurrentDRs
249 | DeadTraj = Prev_frame[DeathRows[t]][:]
250 |
251 | DeathRows[t] = -10000
252 |
253 | # Do the motion prediction using the tracklets in DeadTraj to predict a [X, Y], together with the original W, H, form a new DR wit ID to add into the CurrentDRs
254 | '''
255 | -------- MOTION PREDICTION HERE ----------
256 | '''
257 | if 'Motion Prediction for Dummies':
258 | print '----------------------------------------------- A target %d is not associated in this frame Motion Prediction at play now.' % DeadTrajID
259 | temp_tracklet = list(DeadTraj[2:4]) + list(DeadTraj[
260 | 7:15]) # temp_tracklet is to be feed into the LSTM, a list of 5 pairs of (x, y)
261 |
262 | dummyWH = list(DeadTraj[4:6])
263 | # Establish a simple Linear model here, which is about to be substitued by a LSTM
264 | temp_x_seq = [temp_tracklet[i] for i in range(0, len(temp_tracklet), 2)]
265 | temp_y_seq = [temp_tracklet[j] for j in range(1, len(temp_tracklet), 2)]
266 |
267 | predicted_x = temp_x_seq[1] + temp_x_seq[1] - temp_x_seq[2]
268 | predicted_y = temp_y_seq[1] + temp_y_seq[1] - temp_y_seq[2]
269 |
270 | dummy = list(DeadTraj[:2]) + list([predicted_x, predicted_y]) + list(dummyWH) + [0]
271 |
272 | if predicted_x < 0 or predicted_y < 0:
273 | print 'Im the Storm'
274 |
275 |
276 |
277 | dummy[0] += 1
278 | dummy = np.array(dummy)
279 |
280 | CurrentDRs.append(dummy)
281 |
282 | else:
283 | pass
284 |
285 | '''
286 | *************************** Birth handles
287 | '''
288 | if BirthCols:
289 | for p in range(len(BirthCols)):
290 | if BirthCols[p] != -10000:
291 | Birth_DrIndex = BirthCols[p]
292 | BirthCols[p] = -10000
293 | CurrentDRs[Birth_DrIndex][1] = NIDnow # The new birth is temporally IDed as 0, then moved into the BirthBuffer
294 | print '_+_+_+_+_+_+_+_+_+_+_+_+_+_+ One temporal birth found, assigned NID is %d' %NIDnow
295 | NIDnow = NIDnow - 1
296 | # BirthBuffer deprecated
297 | # BirthBuffer.append(CurrentDRs[Birth_DrIndex][:])
298 |
299 | print 'Birth at %d column' % Birth_DrIndex
300 |
301 | else:
302 | pass
303 |
304 | else:
305 | break
306 |
307 | # -------------------------------------------------------- Auxiliary Per-Frame Operations ---------------------------------------------------- #
308 | '''
309 | -------------------------- Death Counter Check and dead Bbox termination ---------------------
310 | '''
311 | # Check Real Death (them IDs in the DeathBuffer that with death counter up to DeathWindow) for termination
312 | AllDeathCounter = [DeathBuffer[i][1] for i in range(len(DeathBuffer))]
313 | DoomedIDs = []
314 | DeathBufferRemoveOnSpot = []
315 | if DeathWindow in AllDeathCounter:
316 | for i in range(len(DeathBuffer)):
317 | if DeathBuffer[i][1] >= DeathWindow:
318 | print 'One Trajectory %d meets the DeathWindow = %d, and is about to be terminate' % (DeathBuffer[i][0], DeathWindow)
319 | DoomedIDs.append(DeathBuffer[i][0])
320 | #print ' ! ! ! ! ! ! ! ! ! ! ! ! ! Trajectory %d is being terminated:' % DeathBuffer[i][0]
321 | DeathBufferRemoveOnSpot.append([DeathBuffer[i][0], DeathWindow])
322 | #DeathBuffer.remove([DeathBuffer[i][0], DeathWindow])
323 | #print 'IM the Storm'
324 |
325 | # Remove those trajs that have met DeathWindow in this frame out of the DeathBuffer
326 | for i in range(len(DeathBufferRemoveOnSpot)):
327 | DeathBuffer.remove(DeathBufferRemoveOnSpot[i])
328 |
329 | #print 'line 312'
330 |
331 | # -------------------- How to terminate all the trajectories with the DoomedIDs: i.e. remove it from the prev_frame
332 | # DoomedTrajsIndex = []
333 | # for i in range(len(prev_frame)):
334 | # if prev_frame[i][1] in DoomedIDs:
335 | # DoomedTrajsIndex.append(i)
336 |
337 | # for i in range(len(DoomedTrajsIndex)):
338 | # print ' ! ! ! ! ! ! ! ! ! ! ! ! ! Trajectory %d is being terminated:' % prev_frame[DoomedTrajsIndex[i]][1]
339 | # ToDo_DoomedTrajID.append(DoomedTrajsIndex[i])
340 | # list(prev_frame).pop(DoomedTrajsIndex[i])
341 |
342 | # # -------------------- Conservation Check again
343 | # if DRavaliableIndexes != DRavaliableIndexes2:
344 | # print 'Conservation check violated by checking the birth by scenario and birth by removing all occupied DRs.'
345 |
346 | '''
347 | Post-Tracking processing, do the rotation to get a 17_dim output
348 | Now CurrentDRs hold 7-dim where the TrID has been updated
349 | Prolong & rotate the prev_frame with CurrentDRs by associating with TrID
350 | '''
351 | # ------------------------------- Case 1: 1-to-1 Association
352 | # --------------- Rotate & update the prev_frame
353 | for i in range(len(prev_frame)):
354 | for j in range(len(CurrentDRs)):
355 | if prev_frame[i][1] == CurrentDRs[j][1]:
356 | prev_frame[i] = TrackletRotation(CurrentDRs[j], prev_frame[i])
357 |
358 | # ------------------------------- Case 2: Birth Confirmation
359 | # ------------ Birth Confirmation now is in the CurrentDRs with a newly assigned ID, need to find them, and pad 0
360 | ConfirmedBirthProlonged = []
361 | allPrevIDs = [prev_frame[i][1] for i in range(len(prev_frame))]
362 |
363 | for k in range(len(CurrentDRs)):
364 | '''
365 | ---- The padding for the confirmed birth -----
366 | '''
367 | if CurrentDRs[k][1] not in allPrevIDs and CurrentDRs[k][1] > -10 and CurrentDRs[k][1] != -1:
368 | padding = np.zeros(2 * TrackletLen)
369 | for i in range(len(padding)):
370 | if i % 2 == 0:
371 | padding[i] = CurrentDRs[k][2]
372 | if i % 2 != 0:
373 | padding[i] = CurrentDRs[k][3]
374 |
375 | CurrentDRs[k] = np.concatenate((CurrentDRs[k], padding))
376 |
377 | ConfirmedBirthProlonged.append(CurrentDRs[k])
378 |
379 | '''
380 | ---- The padding for the newly birth --------
381 | '''
382 | if CurrentDRs[k][1] not in allPrevIDs and CurrentDRs[k][1] <= -10 and CurrentDRs[k][1] != -1:
383 | padding = np.zeros(2 * TrackletLen)
384 | CurrentDRs[k] = np.concatenate((CurrentDRs[k], padding))
385 |
386 | ConfirmedBirthProlonged.append(CurrentDRs[k])
387 |
388 | if ConfirmedBirthProlonged:
389 | # print 'ConfirmedBirthProlonged size:'
390 | # print len(ConfirmedBirthProlonged[0])
391 | output = np.concatenate((Prev_frame, ConfirmedBirthProlonged))
392 | else:
393 | output = Prev_frame
394 |
395 | return output, BirthBuffer, DeathBuffer, IDnow, ConfirmedBirthRowID, DoomedIDs, NIDnow, Fail2ConfirmedBirthKill
396 |
397 |
if __name__ == "__main__":
    # Script entry point: seeds the tracker from the first five frames of the
    # manually initialised result file, then walks the detection file frame
    # by frame, calling Tracker and appending its output to the result, log
    # and death-buffer text files.

    os.environ["CUDA_VISIBLE_DEVICES"] = "3"

    data_framewise = np.loadtxt("test/MOT17-01-SDP/det.txt")
    data_framewise = MakeCell(data_framewise)
    manual_init = np.loadtxt("test/MOT17-01-SDP/res.txt")
    manual_init = MakeCell(manual_init)

    # NOTE(review): Tracker appears to read TrackletLen and DeathWindow as
    # module-level names (they are not passed in the call below), so none of
    # the names defined here are renamed.
    batchSize = 1
    TrackletLen = 5
    DeathWindow = 3
    V_len = len(data_framewise)
    Bcount = 0
    Dcount = 0
    FrameWidth = 1920
    FrameHeight = 1080
    BirthBuffer = []
    DeathBuffer = []
    prev_frame = manual_init[4]
    IDnow = 20
    NIDnow = -10

    # Append 2 * TrackletLen zero columns to every row of the seed frame.
    padding = np.zeros((prev_frame.shape[0], 2 * TrackletLen))
    prev_frame = np.concatenate((prev_frame, padding), axis=1)

    # Fill the appended columns from the five initial frames (newest first).
    # When an identity is missing from a frame, the two preceding columns of
    # the same row are copied instead.
    for i in range(prev_frame.shape[0]):
        identity = prev_frame[i, 1]
        for frame in range(5):
            data = manual_init[4 - frame]
            if identity in data[:, 1]:
                data_ = data[data[:, 1] == identity].squeeze()
                prev_frame[i, 7 + 2 * frame], prev_frame[i, 8 + 2 * frame] = data_[2], data_[3]
            else:
                prev_frame[i, 7 + 2 * frame], prev_frame[i, 8 + 2 * frame] = (
                    prev_frame[i, 5 + 2 * frame], prev_frame[i, 6 + 2 * frame])

    prev_frame = list(prev_frame)
    PrevIDs = [prev_frame[i][1] for i in range(len(prev_frame))]

    res_path = "test/MOT17-01-SDP/{}.txt".format(time_for_file())
    log_path = "test/MOT17-01-SDP/log.log"
    buffer_path = "test/MOT17-01-SDP/buffer.txt"
    if not osp.exists(res_path):
        # Create the empty result file portably (os.mknod is Unix-only and
        # may require elevated privileges on some platforms).
        open(res_path, "a").close()
    res_init = np.loadtxt("test/MOT17-01-SDP/res.txt")

    generator = TestGenerator(res_path, entirety=False)

    # Copy the first seven columns of the manual initialisation into the
    # result file before tracking starts.
    with open(res_path, "a") as txt:
        for i in range(len(res_init)):
            temp = np.array(res_init[i]).reshape([1, -1])
            np.savetxt(txt, temp[:, 0:7], fmt='%12.3f')

    # ---------------------- start tracking ----------------------
    for v in range(4, V_len - 1):
        # v indexes the previous frame; the frame being tracked is v + 1.
        prev = v  # prev: 1~TotalFrame-1
        next = prev + 1  # NOTE: shadows the builtin; name kept for compatibility
        next_frame = list(data_framewise[next])

        FrID4now = next_frame[0][0]

        print("------------------------------------------> Tracking Frame %d" % FrID4now)

        CurrentDRs = next_frame
        PrevIDs = [prev_frame[i][1] for i in range(len(prev_frame))]

        print("---> MAIN: Input to Model --> CurrentDRs")
        temp2 = [CurrentDRs[i][1] for i in range(len(CurrentDRs))]
        print(temp2)

        print("---> MAIN: Input to Model --> prev_frame")
        temp1 = [prev_frame[i][1] for i in range(len(prev_frame))]
        print(temp1)

        Amatrix = generator(SeqID=0, frame=next)
        Amatrix = Amatrix.unsqueeze(dim=0).cpu().numpy()

        # Reset the ID column of every current detection to -1.
        for i in range(len(CurrentDRs)):
            CurrentDRs[i][1] = -1

        prev_FrID = prev  # This FrID is only used to load in the frames from file
        next_FrID = prev_FrID + 1

        # Update the FrID for prev_frame
        for i in range(len(prev_frame)):
            prev_frame[i][0] = FrID4now

        ToPlot, Bbuffer, Dbuffer, IDnow_out, ConfirmedBirthRowID_out, DoomedIDs_out, NIDnow_out, Fail2ConfirmedBirthKill_out = Tracker(
            Amatrix=Amatrix, Prev_frame=prev_frame, PrevIDs=PrevIDs, CurrentDRs=CurrentDRs, BirthBuffer=BirthBuffer,
            DeathBuffer=DeathBuffer, IDnow=IDnow, FrameW=FrameWidth, FrameH=FrameHeight, NIDnow=NIDnow)

        # Report rows whose third column is negative.
        for i in range(len(ToPlot)):
            if ToPlot[i][2] < 0:
                print('Now negative X, Y')

        # ---- Clean out the doomed IDs that were killed in this frame.
        DoomedTraj2CleanIndex = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in DoomedIDs_out:
                DoomedTraj2CleanIndex.append(i)

        # Delete from the highest index down so earlier indices stay valid.
        DoomedTraj2CleanIndex.sort(reverse=True)
        for i in range(len(DoomedTraj2CleanIndex)):
            print('---> MAIN: _+_+_+_+_+_+_+_ Popping DoomedTraj ID is %d' % ToPlot[DoomedTraj2CleanIndex[i]][1])
            ToPlot = np.delete(ToPlot, DoomedTraj2CleanIndex[i], 0)

        # ---- Clean out the temporal births that have been confirmed in this frame.
        ConfirmedBirthRowIndex = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in ConfirmedBirthRowID_out:
                ConfirmedBirthRowIndex.append(i)

        ConfirmedBirthRowIndex.sort(reverse=True)
        for i in range(len(ConfirmedBirthRowIndex)):
            # ToPlot may still be a plain list here, so index row-then-column
            # rather than with a 2-D tuple index.
            print("---> MAIN: _+_+_+_+_+_+_+_ Popping the temporal birth out when a birth is confirmed, ID is %d"
                  % ToPlot[ConfirmedBirthRowIndex[i]][1])
            ToPlot = np.delete(ToPlot, ConfirmedBirthRowIndex[i], axis=0)

        # ---- Fail2ConfirmedBirthKill_out: kill right away.
        Fail2ConfirmedBirthKill_out_index = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in Fail2ConfirmedBirthKill_out:
                Fail2ConfirmedBirthKill_out_index.append(i)

        Fail2ConfirmedBirthKill_out_index.sort(reverse=True)
        for i in range(len(Fail2ConfirmedBirthKill_out_index)):
            print('---> MAIN: _+_+_+_+_+_+_+_ ToPlot delete Fail2ConfirmedBirthKill, ID is %d'
                  % ToPlot[Fail2ConfirmedBirthKill_out_index[i]][1])
            ToPlot = np.delete(ToPlot, Fail2ConfirmedBirthKill_out_index[i], axis=0)

        print("---> MAIN: Tracking output--> ToPlot")
        temp3 = [ToPlot[i][1] for i in range(len(ToPlot))]
        print(temp3)

        temp6 = [ToPlot[i][2] for i in range(len(ToPlot))]
        for i in range(len(temp6)):
            if temp6[i] <= 0:
                print('Im the storm')

        print("---> MAIN: Tracking output--> DeathBuffer")
        temp1 = [Dbuffer[i] for i in range(len(Dbuffer))]
        print(temp1)

        # Printed before DeathBuffer is rebound to Dbuffer below.
        print("---> MAIN: Tracking output--> DeathBuffer")
        temp4 = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
        temp5 = [DeathBuffer[i][1] for i in range(len(DeathBuffer))]
        print("%s %s" % (temp4, temp5))

        # Carry the tracker state over to the next iteration.
        BirthBuffer = Bbuffer
        DeathBuffer = Dbuffer
        IDnow = IDnow_out
        NIDnow = NIDnow_out
        prev_frame = ToPlot

        print("Writing")
        with open(res_path, "a") as txt:
            for i in range(len(ToPlot)):
                temp = np.array(ToPlot[i]).reshape([1, -1])
                np.savetxt(txt, temp[:, 0:7], fmt='%12.3f')

        with open(log_path, "a") as log:
            for i in range(len(ToPlot)):
                temp = np.array(ToPlot[i]).reshape([1, -1])
                np.savetxt(log, temp, fmt='%12.3f')

        with open(buffer_path, "a") as DeathBufferTXT:
            for i in range(len(DeathBuffer)):
                temp = np.array(DeathBuffer[i]).reshape([1, -1])
                np.savetxt(DeathBufferTXT, temp, fmt="%12.3f", delimiter=",")

        print("Finish Tracking Frame %d" % v)
        print("Current Initialized ID is up to: %d" % IDnow)
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | # ------------------------- Options -------------------------
2 | mode : 0 # 0 for debug
3 | description : net_1024
4 | device : "3"
5 | epochs : 40000
6 | gammas : [0.1, 0.1, 0.1]
7 | schedule : [10000, 20000, 30000]
8 | learning_rate : 0.001
9 | optimizer : Adam
10 | entirety : True
11 | model : net_1024
12 | # ----------------------- God Save Me -----------------------
13 | save_model : True
14 | dampening : 0.9
15 | lr_patience : 10
16 | momentum : 0.9
17 | decay : 0.0005
18 | start_epoch : 0
19 | print_freq : 10
20 | checkpoint : 10
21 | n_threads : 2
22 | result : result
23 | # --------------------------- End ---------------------------
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : main.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/9/27
6 |
7 | import os.path as osp
8 | from utils import *
9 | from train import train_EmbeddingNet
10 | from train import train_LSTM
11 | from train import train_FuckUpNet
12 | from train import train_net_1024
13 | import sys
14 | from Generator import Generator
15 | import numpy as np
16 | import random
17 |
# Seed every random number generator used by the training code.
seed = 1
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# NOTE(review): benchmark=True lets cuDNN auto-select algorithms, which may
# work against the deterministic flag above -- confirm this is intended.
torch.backends.cudnn.benchmark = True
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load the options from config.yml in the current working directory.
config = osp.join(os.path.abspath(os.curdir), "config.yml")
parser, settings_show = Config(config)
os.environ["CUDA_VISIBLE_DEVICES"] = parser.device

# mode 0 logs into a fixed "debug" folder; any other mode gets a folder named
# after the current time and the run description.
if parser.mode == 0:
    log_path = osp.join(parser.result, 'debug')
else:
    log_path = osp.join(parser.result, '{}-{}'.format(time_for_file(), parser.description))
if not osp.exists(log_path):
    # makedirs also creates the parent result directory when it is missing.
    os.makedirs(log_path)

log = open(osp.join(log_path, 'log.log'), 'w')
try:
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)
    for data in settings_show:
        print_log(data, log)

    generator = Generator(entirety=parser.entirety)

    # Dispatch to the training routine selected by the `model` option.
    if parser.model == "EmbeddingNet":
        train_EmbeddingNet.train(parser, generator, log, log_path)
    elif parser.model == "lstm":
        train_LSTM.train(parser, generator, log, log_path)
    elif parser.model == "FuckUpNet":
        train_FuckUpNet.train(parser, generator, log, log_path)
    elif parser.model == "net_1024":
        train_net_1024.train(parser, generator, log, log_path)
    else:
        raise NotImplementedError
finally:
    # Close the log even when training raises, so buffered lines are flushed.
    log.close()
59 |
--------------------------------------------------------------------------------
/model/EmbeddingNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : EmbeddingNet.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/21
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 |
11 |
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM over sequences with 2 features per step.

    ``forward`` returns a copy of the output at the last time step.
    """

    def __init__(self, hidden_size):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        outputs, _state = self.lstm(x)
        last_step = outputs[:, -1, :]
        # clone() detaches the slice from the full output buffer's storage.
        return last_step.clone()
23 |
24 |
class ANet(nn.Module):
    """Four-stage convolutional encoder for 3-channel inputs.

    Each stage is a bias-free 3x3 convolution, a 2x2 max pooling and a leaky
    ReLU. The result is average-pooled with a (5, 2) window and flattened per
    sample.
    """

    def __init__(self):
        super(ANet, self).__init__()
        self.ndf = 32
        width = self.ndf

        self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(width, int(width * 1.5), kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(int(width * 1.5), width * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(width * 2, width * 4, kernel_size=3, stride=1, padding=1, bias=False)

        # Xavier-uniform initialisation for every convolution, in definition order.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(F.max_pool2d(conv(x), 2))

        pooled = F.avg_pool2d(x, kernel_size=(5, 2))
        return pooled.view(pooled.size(0), -1)
54 |
55 |
class EmbeddingNet_legacy(nn.Module):
    """Legacy pairwise scoring network.

    Encodes both crops with ``ANet``, summarises the previous coordinates
    with a 2-unit ``LSTM``, and scores the difference of the concatenated
    previous and current features with a three-layer MLP.

    ``forward`` returns ``(x, pre, cur)``: the 2-way score and the
    concatenated previous / current feature vectors.
    """

    def __init__(self):
        super(EmbeddingNet_legacy, self).__init__()
        self.outsize = 128 + 2

        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=2)

        self.fc1 = nn.Linear(self.outsize, int(self.outsize / 2))
        self.fc2 = nn.Linear(int(self.outsize / 2), int(self.outsize / 4))
        self.fc3 = nn.Linear(int(self.outsize / 4), 2)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.drop1 = nn.Dropout(p=0.3)
        self.drop2 = nn.Dropout(p=0.2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)

        pre_coord = self.lstm(pre_coord)

        pre = torch.cat((pre_crop, pre_coord), dim=1)
        cur = torch.cat((cur_crop, cur_coord), dim=1)

        # Out-of-place subtraction: an in-place add_ here would overwrite
        # `pre`, and the caller would receive the difference instead of the
        # previous-frame feature vector.
        x = F.leaky_relu_(self.fc1(pre - cur))
        x = self.drop1(x)
        x = F.leaky_relu_(self.fc2(x))
        x = self.drop2(x)
        x = self.fc3(x)

        return x, pre, cur
93 |
94 |
class EmbeddingNet_train(nn.Module):
    """Pairwise scoring network with separate crop and coordinate heads.

    ``forward`` returns ``(com, pre_feature, cur_feature)``: the combined
    2-way score, and the concatenated crop + coordinate embeddings of the
    previous and current inputs.
    """

    def __init__(self):
        super(EmbeddingNet_train, self).__init__()
        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=16)
        self.fc = nn.Linear(2, 16)

        # NOTE(review): this loop runs before the heads below are created, so
        # only self.fc receives the custom initialisation -- confirm intent.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.crop_fc1 = nn.Linear(128, 64)
        self.crop_fc2 = nn.Linear(64, 32)
        self.crop_fc3 = nn.Linear(32, 2)

        self.coord_fc1 = nn.Linear(16, 8)
        self.coord_fc2 = nn.Linear(8, 2)

        self.com = nn.Linear(4, 2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)
        pre_coord = self.lstm(pre_coord)
        cur_coord = self.fc(cur_coord)

        # Differences are taken out of place so pre_crop / pre_coord still
        # hold the previous-frame embeddings when they are concatenated into
        # pre_feature below.
        crop = F.leaky_relu(self.crop_fc1(pre_crop - cur_crop))
        crop = F.leaky_relu(self.crop_fc2(crop))
        crop = self.crop_fc3(crop)

        coord = F.leaky_relu(self.coord_fc1(pre_coord - cur_coord))
        coord = self.coord_fc2(coord)

        com = torch.cat((crop, coord), dim=1)
        com = self.com(com)

        pre_feature = torch.cat((pre_crop, pre_coord), dim=1)
        cur_feature = torch.cat((cur_crop, cur_coord), dim=1)

        return com, pre_feature, cur_feature
137 |
138 |
class EmbeddingNet(nn.Module):
    """Pairwise scoring network with separate crop and coordinate heads.

    ``forward`` returns ``(com, pre_feature, cur_feature)``: the combined
    2-way score, and the concatenated crop + coordinate embeddings of the
    previous and current inputs.
    """

    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=16)
        self.fc = nn.Linear(2, 16)

        # NOTE(review): this loop runs before the heads below are created, so
        # only self.fc receives the custom initialisation -- confirm intent.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.crop_fc1 = nn.Linear(128, 64)
        self.crop_fc2 = nn.Linear(64, 32)
        self.crop_fc3 = nn.Linear(32, 2)

        self.coord_fc1 = nn.Linear(16, 8)
        self.coord_fc2 = nn.Linear(8, 2)

        self.com = nn.Linear(4, 2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)
        pre_coord = self.lstm(pre_coord)
        cur_coord = self.fc(cur_coord)

        # Differences are taken out of place so pre_crop / pre_coord still
        # hold the previous-frame embeddings when they are concatenated into
        # pre_feature below.
        crop = F.leaky_relu(self.crop_fc1(pre_crop - cur_crop))
        crop = F.leaky_relu(self.crop_fc2(crop))
        crop = self.crop_fc3(crop)

        coord = F.leaky_relu(self.coord_fc1(pre_coord - cur_coord))
        coord = self.coord_fc2(coord)

        com = torch.cat((crop, coord), dim=1)
        com = self.com(com)

        pre_feature = torch.cat((pre_crop, pre_coord), dim=1)
        cur_feature = torch.cat((cur_crop, cur_coord), dim=1)

        return com, pre_feature, cur_feature
181 |
--------------------------------------------------------------------------------
/model/FuckUpNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : FuckUpNet.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/13
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from EmbeddingNet import EmbeddingNet
11 | from EmbeddingNet import EmbeddingNet_train
12 | from GCN import GCN
13 | from GCN import fuckupGCN
14 |
15 |
class FuckUpNet(nn.Module):
    """Scores every (previous, current) pair with ``EmbeddingNet``.

    ``forward`` returns a ``(len(pre_crop), len(cur_crop))`` CUDA matrix
    whose entry ``[i, j]`` is element ``[0, 1]`` of the embedding network's
    score for that pair.
    """

    def __init__(self):
        super(FuckUpNet, self).__init__()
        self.embeddingnet = EmbeddingNet()
        # Registered but not applied in forward(); kept so the module's
        # parameter set does not change.
        self.gc = GCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        pre_num = len(pre_crop)
        cur_num = len(cur_crop)
        matrix = torch.zeros(pre_num, cur_num).cuda()

        for i in range(pre_num):
            # Non in-place unsqueeze: .cuda() returns the very same tensor
            # when it already lives on the GPU, so unsqueeze_ would reshape
            # the caller's tensor.
            pre_crop_ = pre_crop[i].cuda().unsqueeze(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze(dim=0)
                score, _pre, _cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                matrix[i, j] = score[0, 1]

        return matrix
45 |
46 |
class fuckupnet(nn.Module):
    """Scores every (previous, current) pair with ``EmbeddingNet_train``.

    ``forward`` returns ``(score0, matrix)``: ``score0`` stacks the 2-way
    score of every pair row-major (row ``i * len(cur_crop) + j``), and
    ``matrix[i, j]`` holds element 1 of that score. Both are CUDA tensors.
    """

    def __init__(self):
        super(fuckupnet, self).__init__()
        self.embeddingnet = EmbeddingNet_train()
        # Registered but not applied in forward(); kept so the module's
        # parameter set does not change.
        self.gc = fuckupGCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        pre_num = len(pre_crop)
        cur_num = len(cur_crop)
        matrix = torch.zeros(pre_num, cur_num).cuda()
        score0 = torch.zeros(pre_num * cur_num, 2).cuda()

        for i in range(pre_num):
            # Non in-place unsqueeze: .cuda() returns the very same tensor
            # when it already lives on the GPU, so unsqueeze_ would reshape
            # the caller's tensor.
            pre_crop_ = pre_crop[i].cuda().unsqueeze(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze(dim=0)
                score, _pre, _cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                score = score.squeeze()
                score0[i * cur_num + j] = score
                matrix[i, j] = score[1]

        return score0, matrix
79 |
80 |
class embnet(nn.Module):
    """Pairwise scorer that also returns per-object feature tables.

    ``forward`` returns ``(matrix, pre_feature, cur_feature)``:
    ``matrix[i, j]`` is element 1 of the pair score, and the two feature
    tables hold one 144-wide row per previous / current object. All three
    are CUDA tensors.
    """

    def __init__(self):
        super(embnet, self).__init__()
        self.embeddingnet = EmbeddingNet_train()

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        pre_num = len(pre_crop)
        cur_num = len(cur_crop)
        matrix = torch.zeros(pre_num, cur_num).cuda()
        pre_feature = torch.zeros(pre_num, 144).cuda()
        cur_feature = torch.zeros(cur_num, 144).cuda()

        for i in range(pre_num):
            # Non in-place unsqueeze: .cuda() returns the very same tensor
            # when it already lives on the GPU, so unsqueeze_ would reshape
            # the caller's tensor.
            pre_crop_ = pre_crop[i].cuda().unsqueeze(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze(dim=0)
                score, pre, cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                score = score.squeeze()
                matrix[i, j] = score[1]
                # Rows are rewritten on every pair; the last write wins.
                pre_feature[i, :] = pre
                cur_feature[j, :] = cur

        return matrix, pre_feature, cur_feature
108 |
109 |
class uninet(nn.Module):
    """Chains ``embnet`` and ``fuckupGCN``.

    ``forward`` returns ``(matrix, score1, score2)``: the matrix produced by
    the graph module followed by its two score tensors.
    """

    def __init__(self):
        super(uninet, self).__init__()
        self.embnet = embnet()
        self.gc = fuckupGCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        # First pass: pairwise scores plus per-object features.
        initial_matrix, pre_feat, cur_feat = self.embnet(pre_crop, cur_crop, pre_motion, cur_motion)
        # Second pass: the graph module consumes them.
        first_score, second_score, refined_matrix = self.gc(pre_feat, cur_feat, initial_matrix)
        return refined_matrix, first_score, second_score
122 |
--------------------------------------------------------------------------------
/model/GCN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : GCN.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/22
6 |
7 | import math
8 | import torch
9 | from torch.nn.parameter import Parameter
10 | from torch.nn.modules.module import Module
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
class GraphConvolution(Module):
    """Bipartite graph convolution with one shared weight matrix.

    Both node sets are projected by the same weight; each set is then
    replaced by the adjacency-weighted sum of the other set's projections and
    passed through a leaky ReLU.
    """

    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(out_features), 1/sqrt(out_features)].
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, pre, cur, adj):
        projected_pre = torch.mm(pre, self.weight)
        projected_cur = torch.mm(cur, self.weight)

        new_pre = F.leaky_relu(torch.mm(adj, projected_cur))
        new_cur = F.leaky_relu(torch.mm(adj.t(), projected_pre))

        return new_pre, new_cur

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)
45 |
46 |
class GCN(nn.Module):
    """Two graph-convolution rounds, each followed by a cosine-similarity
    refresh of the adjacency matrix.

    ``forward`` returns the adjacency matrix after the second round.
    """

    def __init__(self, planes):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(planes, planes)
        self.gc2 = GraphConvolution(planes, planes)

    def cos_sim(self, pre, cur, adj):
        """Return the matrix of cosine similarities between rows of ``pre`` and ``cur``.

        Entry ``[i, j]`` is the similarity of ``pre[i]`` and ``cur[j]``.
        ``adj`` only supplies the dtype of the result. The result lives on
        the same device as ``pre`` / ``cur`` rather than on a hard-coded
        CUDA device.
        """
        # Broadcast (P, 1, D) against (1, C, D) and reduce over the feature axis.
        sim = F.cosine_similarity(pre.unsqueeze(1), cur.unsqueeze(0), dim=2)
        return sim.to(dtype=adj.dtype)

    def forward(self, pre, cur, adj):
        pre, cur = self.gc1(pre, cur, adj)
        adj = self.cos_sim(pre, cur, adj)
        pre, cur = self.gc2(pre, cur, adj)
        adj = self.cos_sim(pre, cur, adj)

        return adj
70 |
71 |
class fuckupGCN(nn.Module):
    """Two graph-convolution rounds, each followed by an MLP that re-scores
    every (previous, current) pair from the difference of their features.

    ``forward`` returns ``(score1, score2, adj)``: the pair scores after each
    round and the adjacency matrix after the second round.
    """

    def __init__(self, planes):
        super(fuckupGCN, self).__init__()

        self.gc1 = GraphConvolution(planes, planes)
        self.gc2 = GraphConvolution(planes, planes)

        # Layer widths follow `planes` (144 -> 72 -> 36 -> 2 for planes=144).
        self.fc1 = nn.Sequential(nn.Linear(planes, planes // 2), nn.LeakyReLU(),
                                 nn.Linear(planes // 2, planes // 4), nn.LeakyReLU(),
                                 nn.Linear(planes // 4, 2))
        self.fc2 = nn.Sequential(nn.Linear(planes, planes // 2), nn.LeakyReLU(),
                                 nn.Linear(planes // 2, planes // 4), nn.LeakyReLU(),
                                 nn.Linear(planes // 4, 2))

    def MLP(self, fc, pre, cur, adj):
        """Score every pair ``(pre[i], cur[j])`` with ``fc(pre[i] - cur[j])``.

        Returns ``(score, adj_)``: ``score`` has one 2-element row per pair
        in row-major order (row ``i * cur.size(0) + j``), and ``adj_[i, j]``
        is element 1 of that row. ``adj`` only supplies the dtype of
        ``adj_``. Results live on the inputs' device rather than on a
        hard-coded CUDA device.
        """
        num_pre, num_cur = pre.size(0), cur.size(0)

        # (P, 1, D) - (1, C, D) -> (P, C, D): every pairwise difference at once.
        diff = pre.unsqueeze(1) - cur.unsqueeze(0)
        score = fc(diff.reshape(-1, pre.size(1)))

        # clone() so adj_ does not share storage with score.
        adj_ = score[:, 1].clone().reshape(num_pre, num_cur).to(dtype=adj.dtype)

        return score, adj_

    def forward(self, pre, cur, adj):
        pre, cur = self.gc1(pre, cur, adj)
        score1, adj = self.MLP(self.fc1, pre, cur, adj)
        pre, cur = self.gc2(pre, cur, adj)
        score2, adj = self.MLP(self.fc2, pre, cur, adj)

        return score1, score2, adj
108 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : __init__.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/26
6 |
--------------------------------------------------------------------------------
/model/final.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : final.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/11/10
6 |
7 | import math
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from torch.nn.parameter import Parameter
12 | from torch.nn.modules.module import Module
13 | from torchvision import models
14 |
15 |
class ANet(nn.Module):
    """ResNet-18 encoder with a 256-unit output layer.

    ``forward`` embeds both inputs with the same network and returns
    ``(crop, pre_feature, cur_feature)``, where ``crop`` is the tangent of
    the cosine similarity between the two embeddings.
    """

    def __init__(self):
        super(ANet, self).__init__()

        self.ANet = models.resnet18(pretrained=True)
        self.ANet.fc = nn.Linear(512, 256)

        # Re-initialise the weight of every Linear layer with a narrow normal.
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.01)

    def forward(self, pre_crop, cur_crop):
        pre_feature = self.ANet(pre_crop)
        cur_feature = self.ANet(cur_crop)

        similarity = F.cosine_similarity(pre_feature, cur_feature)
        crop = torch.tan(similarity)

        return crop, pre_feature, cur_feature
38 |
39 |
class GraphConvolution(Module):
    """Bipartite graph convolution with one shared weight and no activation.

    Both node sets are projected by the same weight; each set is then
    replaced by the adjacency-weighted sum of the other set's projections.
    """

    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(out_features), 1/sqrt(out_features)].
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def adj_norm(self, adj):
        # No normalisation is applied: the matrix and its transpose are
        # returned unchanged.
        return adj, adj.t()

    def forward(self, pre, cur, adj):
        projected_pre = torch.mm(pre, self.weight)
        projected_cur = torch.mm(cur, self.weight)

        forward_adj, backward_adj = self.adj_norm(adj)

        return torch.mm(forward_adj, projected_cur), torch.mm(backward_adj, projected_pre)

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)
75 |
76 |
class GCN(nn.Module):
    """One graph-convolution round followed by an edge re-scoring step.

    ``forward`` returns ``(score, adj)``: the flattened pair scores and the
    refreshed adjacency matrix.
    """

    def __init__(self, planes):
        super(GCN, self).__init__()

        self.gc = GraphConvolution(planes, planes)

    def edge_update(self, pre, cur, adj):
        """Re-score every pair as ``tan(cosine_similarity(pre[i], cur[j]))``.

        Returns ``(score, adj_)``: ``adj_[i, j]`` is the score of the pair
        and ``score`` is the same values flattened row-major (index
        ``i * cur.size(0) + j``). ``adj`` only supplies the dtype of
        ``adj_``. Results live on the inputs' device rather than on a
        hard-coded CUDA device.
        """
        # Broadcast (P, 1, D) against (1, C, D) and reduce over the feature axis.
        pairwise = torch.tan(F.cosine_similarity(pre.unsqueeze(1), cur.unsqueeze(0), dim=2))

        adj_ = pairwise.to(dtype=adj.dtype)
        # clone() so score does not share storage with adj_.
        score = pairwise.reshape(-1).clone()

        return score, adj_

    def forward(self, pre, cur, adj):
        pre, cur = self.gc(pre, cur, adj)
        score, adj = self.edge_update(pre, cur, adj)

        return score, adj
107 |
108 |
class final(nn.Module):
    """Pairwise crop scorer followed by one graph refinement step.

    ``forward`` returns ``(s0, s3, adj)``: the flattened pair scores from the
    embedding network, the flattened scores after the graph step, and the
    refined adjacency matrix.
    """

    def __init__(self):
        super(final, self).__init__()
        self.embnet = ANet()
        self.gc = GCN(planes=256)

    def forward(self, pre_crop, cur_crop):
        num_cur = len(cur_crop)
        num_pre = len(pre_crop)

        # Output buffers, allocated on the GPU.
        adj1 = torch.zeros(num_pre, num_cur).cuda()
        pre_feature = torch.zeros(num_pre, 256).cuda()
        cur_feature = torch.zeros(num_cur, 256).cuda()
        s0 = torch.zeros(num_pre * num_cur).cuda()

        for row in range(num_pre):
            pre_batch = pre_crop[row].cuda().unsqueeze(dim=0)
            for col in range(num_cur):
                cur_batch = cur_crop[col].cuda().unsqueeze(dim=0)

                pair_score, pre_emb, cur_emb = self.embnet(pre_batch, cur_batch)

                # The same score is stored in matrix form and in the
                # row-major flattened vector.
                adj1[row, col] = pair_score
                s0[row * num_cur + col] = pair_score
                pre_feature[row, :] = pre_emb
                cur_feature[col, :] = cur_emb

        s3, adj = self.gc(pre_feature, cur_feature, adj1)

        return s0, s3, adj
139 |
--------------------------------------------------------------------------------
/model/net_1024.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : net_1024.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/24
6 |
7 | import math
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from torch.nn.parameter import Parameter
12 | from torch.nn.modules.module import Module
13 |
14 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
19 |
20 |
class convblock(nn.Module):
    """Two 3x3 convolutions, each followed by the same in-place ReLU.

    Only the first convolution uses ``stride``.
    """

    def __init__(self, inplanes, planes, stride=1):
        super(convblock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))
36 |
37 |
class ANet_2(nn.Module):
    """Encoder built from a strided 7x7 stem and three ``convblock`` stages.

    The stem output goes through a ReLU; each block is followed by a 2x2 max
    pooling. The result is average-pooled with a (5, 2) window and flattened
    per sample.
    """

    def __init__(self):
        super(ANet_2, self).__init__()
        self.ndf = 32

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv2 = convblock(32, 48)
        self.conv3 = convblock(48, 64)
        self.conv4 = convblock(64, 128)

        # Xavier-uniform initialisation for every Conv2d in the module tree.
        conv_layers = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
        for layer in conv_layers:
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        for block in (self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(block(x), 2)

        pooled = F.avg_pool2d(x, kernel_size=(5, 2))
        return pooled.view(pooled.size(0), -1)
63 |
64 |
class ANet(nn.Module):
    """Four-stage convolutional encoder for 3-channel inputs.

    Each stage is a bias-free 3x3 convolution, a 2x2 max pooling and a ReLU.
    The result is average-pooled with a (5, 2) window and flattened per
    sample.
    """

    def __init__(self):
        super(ANet, self).__init__()
        self.ndf = 32
        width = self.ndf

        self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(width, int(width * 1.5), kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(int(width * 1.5), width * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(width * 2, width * 4, kernel_size=3, stride=1, padding=1, bias=False)

        # Xavier-uniform initialisation for every convolution, in definition order.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(F.max_pool2d(conv(x), 2))

        pooled = F.avg_pool2d(x, kernel_size=(5, 2))
        return pooled.view(pooled.size(0), -1)
94 |
95 |
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM over 2-feature inputs; returns the last step."""

    def __init__(self, hidden_size):
        """Create the LSTM.

        Args:
            hidden_size: Width of the hidden state, and therefore of the output.
        """
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        """Run the LSTM over ``x`` and return a copy of the final time step's output."""
        sequence_out = self.lstm(x)[0]
        last_step = sequence_out[:, -1, :]
        return last_step.clone()
107 |
108 |
class embnet(nn.Module):
    """Pairwise embedding network scoring a (previous, current) detection pair.

    Appearance crops go through a shared ``ANet``; the previous coordinates go
    through an ``LSTM`` and the current coordinates through a linear layer.
    The differences of the two branches are scored separately and then fused.
    """

    def __init__(self):
        """Create all sub-modules and draw every ``nn.Linear`` weight from N(0, 0.01)."""
        super(embnet, self).__init__()
        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=16)
        self.fc = nn.Linear(2, 16, bias=False)

        # Appearance-difference scoring head.
        self.crop_fc1 = nn.Linear(128, 64, bias=False)
        self.crop_fc2 = nn.Linear(64, 32, bias=False)
        self.crop_fc3 = nn.Linear(32, 1, bias=False)

        # Coordinate-difference scoring head.
        self.coord_fc1 = nn.Linear(16, 8, bias=False)
        self.coord_fc2 = nn.Linear(8, 1, bias=False)

        # Fuses the two scalar scores into one.
        self.com = nn.Linear(2, 1, bias=False)

        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.01)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        """Score one batch of (previous, current) pairs.

        Args:
            pre_crop: Input for ``ANet`` (previous detection crop).
            cur_crop: Input for ``ANet`` (current detection crop).
            pre_coord: Input for the ``LSTM`` (2 features per step).
            cur_coord: Input for ``fc`` (2 features).

        Returns:
            Tuple ``(com, crop, coord, pre_feature, cur_feature)``: the fused
            score, the appearance score, the coordinate score, and the
            concatenated appearance+coordinate embeddings of each side.
        """
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)
        pre_coord = self.lstm(pre_coord)
        cur_coord = self.fc(cur_coord)

        # Appearance branch: MLP over the embedding difference.
        crop_hidden = F.relu(self.crop_fc1(pre_crop - cur_crop))
        crop_hidden = F.relu(self.crop_fc2(crop_hidden))
        crop = self.crop_fc3(crop_hidden)

        # Coordinate branch: MLP over the embedding difference.
        coord_hidden = F.relu(self.coord_fc1(pre_coord - cur_coord))
        coord = self.coord_fc2(coord_hidden)

        com = self.com(torch.cat((crop, coord), dim=1))

        pre_feature = torch.cat((pre_crop, pre_coord), dim=1)
        cur_feature = torch.cat((cur_crop, cur_coord), dim=1)

        return com, crop, coord, pre_feature, cur_feature
153 |
154 |
class GraphConvolution(Module):
    """Bipartite graph convolution exchanging features between two node sets.

    Both node sets share one weight matrix; each side aggregates the projected
    features of the other side, weighted by a softmax-normalised adjacency.
    """

    def __init__(self, in_features, out_features):
        """Allocate the ``(in_features, out_features)`` weight and initialise it."""
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight uniformly in ``[-b, b]`` with ``b = 1 / sqrt(out_features)``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def adj_norm(self, adj):
        """Return the row-wise softmax of ``adj`` and of its transpose."""
        return F.softmax(adj, dim=1), F.softmax(adj.t(), dim=1)

    def forward(self, pre, cur, adj):
        """Propagate features across the bipartite graph.

        Args:
            pre: Features of the first node set, ``(P, in_features)``.
            cur: Features of the second node set, ``(C, in_features)``.
            adj: ``(P, C)`` adjacency scores.

        Returns:
            ``(pre_out, cur_out)`` of shapes ``(P, out_features)`` and
            ``(C, out_features)``, both passed through ReLU.
        """
        row_weights, col_weights = self.adj_norm(adj)

        pre_projected = torch.mm(pre, self.weight)
        cur_projected = torch.mm(cur, self.weight)

        # Each side aggregates the projected features of the *other* side.
        pre_out = F.relu_(torch.mm(row_weights, cur_projected))
        cur_out = F.relu_(torch.mm(col_weights, pre_projected))

        return pre_out, cur_out

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)
193 |
194 |
class GCN(nn.Module):
    """One graph-convolution layer followed by a pairwise scoring MLP."""

    def __init__(self, planes):
        """Create the graph convolution and the scoring MLP.

        Args:
            planes: Feature width of the nodes. It is used for both the graph
                convolution and the MLP input, so the two always agree.
        """
        super(GCN, self).__init__()

        self.gc = GraphConvolution(planes, planes)
        # self.gc2 = GraphConvolution(planes, planes)

        self.fc1 = nn.Sequential(nn.Linear(planes, 72, bias=False), nn.ReLU(), nn.Linear(72, 36, bias=False),
                                 nn.ReLU(), nn.Linear(36, 1, bias=False))
        # self.fc2 = nn.Sequential(nn.Linear(144, 72, bias=False), nn.ReLU(), nn.Linear(72, 36, bias=False), nn.ReLU(),
        #                          nn.Linear(36, 1, bias=False))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)

    def MLP(self, fc, pre, cur, adj):
        """Score every (pre[i], cur[j]) pair with ``fc``.

        All pairwise differences are built by broadcasting and scored in one
        batched call, so the result lives on the same device as the inputs.

        Args:
            fc: Module mapping ``(..., features)`` to ``(..., 1)``.
            pre: ``(P, features)`` tensor.
            cur: ``(C, features)`` tensor.
            adj: Previous ``(P, C)`` adjacency; kept for interface
                compatibility, its values are not read.

        Returns:
            ``(score, adj_)`` where ``adj_[i, j] = fc(pre[i] - cur[j])`` and
            ``score`` is an independent row-major flattening of ``adj_``.
        """
        pair_diff = pre.unsqueeze(1) - cur.unsqueeze(0)  # (P, C, features)
        adj_ = fc(pair_diff).squeeze(-1)  # (P, C)
        # clone() keeps score and adj_ as separate tensors, so an in-place
        # edit of one by a caller cannot leak into the other.
        score = adj_.reshape(-1).clone()

        return score, adj_

    def forward(self, pre, cur, adj):
        """Run the graph convolution, then re-score all pairs.

        Returns:
            ``(score, adj)``: flat pair scores and the same scores as a matrix.
        """
        pre, cur = self.gc(pre, cur, adj)
        score, adj = self.MLP(self.fc1, pre, cur, adj)
        # pre, cur = self.gc2(pre, cur, adj)
        # score2, adj = self.MLP(self.fc2, pre, cur, adj)

        return score, adj
235 |
236 |
class net_1024(nn.Module):
    """Full association network: pairwise ``embnet`` scores refined by a ``GCN``."""

    def __init__(self):
        super(net_1024, self).__init__()
        self.embnet = embnet()
        self.gc = GCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        """Score every (previous, current) detection pair.

        Buffers and inputs are placed on the device that holds this module's
        weights rather than unconditionally on CUDA.

        Args:
            pre_crop: Indexable collection of crops for the previous detections.
            cur_crop: Indexable collection of crops for the current detections.
            pre_motion: Per-previous-detection input for the embnet LSTM.
            cur_motion: Per-current-detection input for the embnet linear layer.

        Returns:
            ``(s0, s1, s2, s3, adj1, adj)``: flat fused/appearance/coordinate
            scores from ``embnet``, flat scores from the GCN, the embnet score
            matrix and the GCN score matrix.
        """
        # Read the device from a weight tensor so it follows wherever the
        # module has been moved.
        device = self.embnet.fc.weight.device

        cur_num = len(cur_crop)
        pre_num = len(pre_crop)

        adj1 = torch.zeros(pre_num, cur_num, device=device)
        pre_feature = torch.zeros(pre_num, 144, device=device)
        cur_feature = torch.zeros(cur_num, 144, device=device)
        s0 = torch.zeros(pre_num * cur_num, device=device)
        s1 = torch.zeros(pre_num * cur_num, device=device)
        s2 = torch.zeros(pre_num * cur_num, device=device)

        for i in range(pre_num):
            pre_crop_ = pre_crop[i].to(device).unsqueeze(dim=0)
            pre_motion_ = pre_motion[i].to(device).unsqueeze(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].to(device).unsqueeze(dim=0)
                cur_motion_ = cur_motion[j].to(device).unsqueeze(dim=0)

                score0_, score1_, score2_, pre, cur = self.embnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                flat_index = i * cur_num + j  # row-major position of pair (i, j)
                adj1[i, j] = score0_
                s0[flat_index] = score0_
                s1[flat_index] = score1_
                s2[flat_index] = score2_
                pre_feature[i, :] = pre
                cur_feature[j, :] = cur

        s3, adj = self.gc(pre_feature, cur_feature, adj1)

        return s0, s1, s2, s3, adj1, adj
273 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.11.0
2 | scipy==1.1.0
3 | torchvision==0.2.1
4 | opencv_python==3.3.0.10
5 | easydict==1.7
6 | torch==0.4.1
7 | Pillow==6.0.0
8 | motmetrics==1.1.3
9 | PyYAML==5.1
10 |
--------------------------------------------------------------------------------
/setting/ADL-Rundle-1_config.yml:
--------------------------------------------------------------------------------
1 | # -------- Tracking Setting for Sequence ADL-Rundle-1 --------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 30
14 | PredictThreshold : 0.003
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/ADL-Rundle-3_config.yml:
--------------------------------------------------------------------------------
1 | # -------- Tracking Setting for Sequence ADL-Rundle-3 --------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 8
7 | BirthCount : 4
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/AVG-TownCentre_config.yml:
--------------------------------------------------------------------------------
1 | # ------- Tracking Setting for Sequence AVG-TownCentre -------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 2
8 | Threshold : -3
9 | Distance : 0.075
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 10
14 | PredictThreshold : 0.0075
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/ETH-Crossing_config.yml:
--------------------------------------------------------------------------------
1 | # -------- Tracking Setting for Sequence ETH-Crossing --------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 8
7 | BirthCount : 1
8 | Threshold : -3
9 | Distance : 0.075
10 | BoxRation : 0.3
11 | FrameWidth : 640
12 | FrameHeight : 480
13 | fps : 14
14 | PredictThreshold : 0.0075
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/ETH-Jelmoli_config.yml:
--------------------------------------------------------------------------------
1 | # -------- Tracking Setting for Sequence ETH-Jelmoli ---------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 5
7 | BirthCount : 3
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 640
12 | FrameHeight : 480
13 | fps : 14
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/ETH-Linthescher_config.yml:
--------------------------------------------------------------------------------
1 | # ------ Tracking Setting for Sequence ETH-Linthescher -------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 12
7 | BirthCount : 2
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 640
12 | FrameHeight : 480
13 | fps : 14
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/KITTI-16_config.yml:
--------------------------------------------------------------------------------
1 | # ---------- Tracking Setting for Sequence KITTI-16 ----------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 3
7 | BirthCount : 2
8 | Threshold : -3
9 | Distance : 0.075
10 | BoxRation : 0.3
11 | FrameWidth : 1224
12 | FrameHeight : 370
13 | fps : 10
14 | PredictThreshold : 0.015
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/KITTI-19_config.yml:
--------------------------------------------------------------------------------
1 | # ---------- Tracking Setting for Sequence KITTI-19 ----------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 8
7 | BirthCount : 2
8 | Threshold : -3
9 | Distance : 0.075
10 | BoxRation : 0.3
11 | FrameWidth : 1238
12 | FrameHeight : 374
13 | fps : 10
14 | PredictThreshold : 0.015
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/PETS09-S2L2_config.yml:
--------------------------------------------------------------------------------
1 | # --------- Tracking Setting for Sequence PETS09-S2L2 --------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 768
12 | FrameHeight : 576
13 | fps : 7
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/TUD-Crossing_config.yml:
--------------------------------------------------------------------------------
1 | # -------- Tracking Setting for Sequence TUD-Crossing --------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 12
7 | BirthCount : 2
8 | Threshold : -3
9 | Distance : 0.04
10 | BoxRation : 0.3
11 | FrameWidth : 640
12 | FrameHeight : 480
13 | fps : 25
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/Venice-1_config.yml:
--------------------------------------------------------------------------------
1 | # ---------- Tracking Setting for Sequence Venice-1 ----------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 30
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq01_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 01 -------------
2 | ID_assign_init : 20
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.003
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq03_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 03 -------------
2 | ID_assign_init : 100
3 | ID_birth_init : -10
4 | DeathBufferLength : 4000
5 | BirthBufferLength : 8000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.04
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq06_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 06 -------------
2 | ID_assign_init : 20
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 6
7 | BirthCount : 3
8 | Threshold : -3
9 | Distance : 0.06
10 | BoxRation : 0.3
11 | FrameWidth : 640
12 | FrameHeight : 480
13 | fps : 7
14 | PredictThreshold : 0.006
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq07_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 07 -------------
2 | ID_assign_init : 30
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq08_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 08 -------------
2 | ID_assign_init : 20
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.06
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.005
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq12_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 12 -------------
2 | ID_assign_init : 20
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 10
7 | BirthCount : 5
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 15
14 | PredictThreshold : 0.003
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/setting/seq14_config.yml:
--------------------------------------------------------------------------------
1 | # ------------- Tracking Setting for Sequence 14 -------------
2 | ID_assign_init : 20
3 | ID_birth_init : -10
4 | DeathBufferLength : 2000
5 | BirthBufferLength : 5000
6 | DeathCount : 8
7 | BirthCount : 4
8 | Threshold : -3
9 | Distance : 0.05
10 | BoxRation : 0.3
11 | FrameWidth : 1920
12 | FrameHeight : 1080
13 | fps : 12
14 | PredictThreshold : 0.003
15 | # ------------------------------------------------------------
--------------------------------------------------------------------------------
/tracking-MOT15.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : tracking-MOT15.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/11/14
6 |
7 | from Test import TestGenerator
8 | from utils import *
9 | from tracking_utils import *
10 | from mkvideo_MOT15 import MakeVideo
11 |
12 | np.set_printoptions(precision=2, suppress=True)
13 |
14 |
def main(info, timer):
    """Track one MOT15 sequence, clean up the IDs and render a video.

    Args:
        info: Sequence name; selects ``setting/<info>_config.yml`` and the
            ``test-MOT15/<info>/`` directory holding ``det.txt`` and ``res.txt``.
        timer: Object whose ``sum(seconds)`` method receives the time spent on
            each frame (association matrix + tracker update).
    """
    # -------------------------------- initialize --------------------------------
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    parser, settings_show = Config("setting/{}_config.yml".format(info))
    for line in settings_show:
        print(line)
    detection = MakeCell(np.loadtxt("test-MOT15/{}/det.txt".format(info)))
    manual_init = MakeCell(np.loadtxt("test-MOT15/{}/res.txt".format(info)))
    StartFrame = len(manual_init)
    TotalFrame = len(detection)
    tracker_ = tracker(
        ID_assign_init=parser.ID_assign_init,
        ID_birth_init=parser.ID_birth_init,
        DeathBufferLength=parser.DeathBufferLength,
        BirthBufferLength=parser.BirthBufferLength,
        DeathCount=parser.DeathCount,
        BirthCount=parser.BirthCount,
        Threshold=parser.Threshold,
        Distance=parser.Distance,
        BoxRation=parser.BoxRation,
        FrameWidth=parser.FrameWidth,
        FrameHeight=parser.FrameHeight,
        PredictThreshold=parser.PredictThreshold,
    )
    # The five most recent initialised frames seed the tracker history.
    PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData = (
        manual_init[-1], manual_init[-2], manual_init[-3], manual_init[-4], manual_init[-5])

    # Start the result file as a copy of the initial results.
    res_path = "test-MOT15/{}/{}.txt".format(info, time_for_file())
    res_init = np.loadtxt("test-MOT15/{}/res.txt".format(info))
    with open(res_path, "w") as res:
        np.savetxt(res, res_init, fmt='%12.3f')

    generator = TestGenerator(res_path, info)
    BirthLog = [[], [], []]  # prev frame, negative ID, assign ID
    DeathLog = [[], []]  # death frame, death ID
    # ----------------------------------------------------------------------------

    print("############### Start Tracking ###############")
    for frame in range(StartFrame, TotalFrame):
        print("-----------------------> Start Tracking Frame %d" % (frame + 1))
        CurData = detection[frame]

        PrevIDs = PrevData[:, 1].copy()
        assert PrevIDs[PrevIDs == 0].size == 0
        CurIDs = CurData[:, 1].copy()
        assert CurIDs.max() == -1 and CurIDs.min() == -1

        tik = time.time()
        Amatrix = generator(SeqID=0, frame=frame).cpu().numpy()
        assert Amatrix.shape[0] == PrevIDs.shape[0] and Amatrix.shape[1] == CurIDs.shape[0]

        (PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData,
         BirthLog, DeathLog) = tracker_(Amatrix=Amatrix,
                                        PrevIDs=PrevIDs,
                                        CurData=CurData,
                                        PrevData=PrevData,
                                        PPrevData=PPrevData,
                                        PPPrevData=PPPrevData,
                                        PPPPrevData=PPPPrevData,
                                        PPPPPrevData=PPPPPrevData,
                                        BirthLog=BirthLog,
                                        DeathLog=DeathLog)
        timer.sum(time.time() - tik)

        with open(res_path, "a") as res:
            np.savetxt(res, PrevData, fmt="%12.3f")

        print("-----------------------> Finish Tracking Frame %d\n" % (frame + 1))
    print("############### Finish Tracking ###############\n")

    assert len(BirthLog[0]) == len(BirthLog[1]) == len(BirthLog[2])
    assert len(DeathLog[0]) == len(DeathLog[1])

    res_data = np.loadtxt(res_path)

    print("cleaning birth...")
    for birth_frame, negative_id, assign_ID in zip(BirthLog[0], BirthLog[1], BirthLog[2]):
        # The ID mask is taken once per entry, before any row is rewritten.
        id_mask = res_data[:, 1] == negative_id
        for back in range(parser.BirthCount):
            frame_mask = res_data[:, 0] == birth_frame - back
            res_data[frame_mask & id_mask, 1] = assign_ID

    print("cleaning death...")
    for death_frame, death_id in zip(DeathLog[0], DeathLog[1]):
        id_mask = res_data[:, 1] == death_id
        for back in range(parser.DeathCount - 3):
            frame_mask = res_data[:, 0] == death_frame - back
            res_data[frame_mask & id_mask, 1] = -1

    print("cleaning death sp...")
    DeathBuffer = tracker_.DeathBuffer
    # IDs whose death counter is above 3 but below DeathCount when the sequence ends.
    death_sp_log = np.where((DeathBuffer > 3) & (DeathBuffer < parser.DeathCount))[0]
    for death_id in death_sp_log:
        id_mask = res_data[:, 1] == death_id
        for back in range(int(DeathBuffer[death_id])):
            frame_mask = res_data[:, 0] == TotalFrame - back
            res_data[frame_mask & id_mask, 1] = -1

    np.savetxt(res_path, res_data, fmt="%12.3f")

    MakeVideo(res_path, info, parser.fps, parser.FrameWidth, parser.FrameHeight)
123 |
124 |
if __name__ == "__main__":
    # Run tracking on each listed MOT15 sequence and report the accumulated time.
    # seq = ["ADL-Rundle-1", "ADL-Rundle-3", "AVG-TownCentre", "ETH-Crossing", "ETH-Jelmoli", "ETH-Linthescher", "KITTI-16",
    #        "KITTI-19", "PETS09-S2L2", "TUD-Crossing", "Venice-1"]
    seq = ["PETS09-S2L2"]

    # A distinct name keeps the imported `timer` class usable after this line.
    total_timer = timer()
    for sequence_name in seq:
        main(sequence_name, total_timer)
    print("total time: {} second".format(total_timer()))
134 |
--------------------------------------------------------------------------------
/tracking.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : tracking.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/11/2
6 |
7 | from Test import TestGenerator
8 | from tracking_utils import *
9 | from utils import *
10 |
11 | np.set_printoptions(precision=2, suppress=True)
12 |
13 |
def main(info, timer):
    """Track one MOT17 sequence/detector pair and write the result file.

    Args:
        info: Two-item sequence ``[sequence number, detector name]``; selects
            ``setting/seq<number>_config.yml`` and the
            ``test/MOT17-<number>-<detector>/`` directory holding ``det.txt``
            and ``res.txt``.
        timer: Object whose ``sum(seconds)`` method receives the time spent on
            each frame (association matrix + tracker update).
    """
    # -------------------------------- initialize --------------------------------
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    parser, settings_show = Config("setting/seq{}_config.yml".format(info[0]))
    for line in settings_show:
        print(line)
    detection = MakeCell(np.loadtxt("test/MOT17-{}-{}/det.txt".format(info[0], info[1])))
    manual_init = MakeCell(np.loadtxt("test/MOT17-{}-{}/res.txt".format(info[0], info[1])))
    StartFrame = len(manual_init)
    TotalFrame = len(detection)
    tracker_ = tracker(
        ID_assign_init=parser.ID_assign_init,
        ID_birth_init=parser.ID_birth_init,
        DeathBufferLength=parser.DeathBufferLength,
        BirthBufferLength=parser.BirthBufferLength,
        DeathCount=parser.DeathCount,
        BirthCount=parser.BirthCount,
        Threshold=parser.Threshold,
        Distance=parser.Distance,
        BoxRation=parser.BoxRation,
        FrameWidth=parser.FrameWidth,
        FrameHeight=parser.FrameHeight,
        PredictThreshold=parser.PredictThreshold,
    )
    # The five most recent initialised frames seed the tracker history.
    PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData = (
        manual_init[-1], manual_init[-2], manual_init[-3], manual_init[-4], manual_init[-5])

    # Start the result file as a copy of the initial results.
    res_path = "test/MOT17-{}-{}/{}.txt".format(info[0], info[1], time_for_file())
    res_init = np.loadtxt("test/MOT17-{}-{}/res.txt".format(info[0], info[1]))
    with open(res_path, "w") as res:
        np.savetxt(res, res_init, fmt='%12.3f')

    generator = TestGenerator(res_path, info)
    BirthLog = [[], [], []]  # prev frame, negative ID, assign ID
    DeathLog = [[], []]  # death frame, death ID
    # ----------------------------------------------------------------------------

    print("############### Start Tracking ###############")
    for frame in range(StartFrame, TotalFrame):
        print("-----------------------> Start Tracking Frame %d" % (frame + 1))
        CurData = detection[frame]

        PrevIDs = PrevData[:, 1].copy()
        assert PrevIDs[PrevIDs == 0].size == 0
        CurIDs = CurData[:, 1].copy()
        assert CurIDs.max() == -1 and CurIDs.min() == -1

        tik = time.time()
        Amatrix = generator(SeqID=0, frame=frame).cpu().numpy()
        assert Amatrix.shape[0] == PrevIDs.shape[0] and Amatrix.shape[1] == CurIDs.shape[0]

        (PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData,
         BirthLog, DeathLog) = tracker_(Amatrix=Amatrix,
                                        PrevIDs=PrevIDs,
                                        CurData=CurData,
                                        PrevData=PrevData,
                                        PPrevData=PPrevData,
                                        PPPrevData=PPPrevData,
                                        PPPPrevData=PPPPrevData,
                                        PPPPPrevData=PPPPPrevData,
                                        BirthLog=BirthLog,
                                        DeathLog=DeathLog)
        timer.sum(time.time() - tik)

        with open(res_path, "a") as res:
            np.savetxt(res, PrevData, fmt="%12.3f")

        print("-----------------------> Finish Tracking Frame %d\n" % (frame + 1))
    print("############### Finish Tracking ###############\n")

    assert len(BirthLog[0]) == len(BirthLog[1]) == len(BirthLog[2])
    assert len(DeathLog[0]) == len(DeathLog[1])

    res_data = np.loadtxt(res_path)

    # The ID clean-up passes below are commented out in this script.
    # print("cleaning birth...")
    # for birth in range(len(BirthLog[0])):
    #     frame = BirthLog[0][birth]
    #     ID_index = np.where(res_data[:, 1] == BirthLog[1][birth])
    #     assign_ID = BirthLog[2][birth]
    #     for i in range(parser.BirthCount):
    #         frame_ = frame - i
    #         frame_index = np.where(res_data[:, 0] == frame_)
    #         index = np.intersect1d(frame_index, ID_index)
    #         res_data[index, 1] = assign_ID

    # print("cleaning death...")
    # for death in range(len(DeathLog[0])):
    #     frame = DeathLog[0][death]
    #     ID_index = np.where(res_data[:, 1] == DeathLog[1][death])
    #     for i in range(parser.DeathCount - 2):
    #         frame_ = frame - i
    #         frame_index = np.where(res_data[:, 0] == frame_)
    #         index = np.intersect1d(frame_index, ID_index)
    #         res_data[index, 1] = -1

    # print("cleaning death sp...")
    # DeathBuffer = tracker_.DeathBuffer
    # death_sp_log = np.intersect1d(np.where(DeathBuffer > 3)[0], np.where(DeathBuffer < parser.DeathCount)[0])
    # for death_sp in range(death_sp_log.shape[0]):
    #     frame = TotalFrame
    #     ID_index = np.where(res_data[:, 1] == death_sp_log[death_sp])
    #     for i in range(int(DeathBuffer[death_sp_log[death_sp]])):
    #         frame_ = frame - i
    #         frame_index = np.where(res_data[:, 0] == frame_)
    #         index = np.intersect1d(frame_index, ID_index)
    #         res_data[index, 1] = -1

    np.savetxt(res_path, res_data, fmt="%12.3f")

    # MakeVideo(res_path, info, parser.fps, parser.FrameWidth, parser.FrameHeight)
122 |
123 |
if __name__ == "__main__":
    # Run tracking for every (sequence, detector) pair and report the accumulated time.
    # seq = ["01", "06", "07", "08", "12", "14"]
    seq = ["03"]
    # detector = ["DPM", "FRCNN", "SDP"]
    detector = ["DPM"]
    # detector = ["FRCNN"]
    # detector = ["SDP"]

    # A distinct name keeps the imported `timer` class usable after this line.
    total_timer = timer()
    for sequence_id in seq:
        for detector_name in detector:
            main([sequence_id, detector_name], total_timer)
    print("total time: {} second".format(total_timer()))
136 |
--------------------------------------------------------------------------------
/tracking_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : tracking_utils.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/11/2
6 |
7 | import numpy as np
8 |
9 |
def MakeCell(data):
    """Split a 2-D array into one array per frame number.

    Column 0 holds the frame number. Frames run from 1 up to the frame number
    found in the *last row*; a frame with no rows yields an empty array.

    Args:
        data: 2-D ``numpy`` array whose first column is the frame number.

    Returns:
        List of independent arrays; entry ``k`` holds the rows of frame ``k + 1``.
    """
    last_frame = int(data[-1, 0])
    frame_column = data[:, 0]
    return [data[frame_column == frame_no].copy() for frame_no in range(1, last_frame + 1)]
18 |
19 |
class timer():
    """Accumulates elapsed time handed to it piece by piece."""

    def __init__(self):
        # Running total, in whatever unit the caller passes to sum().
        self.time = 0

    def sum(self, time):
        """Add ``time`` to the running total."""
        self.time += time

    def __call__(self):
        """Return the running total truncated to an ``int``."""
        total = int(self.time)
        return total
30 |
31 |
class ID_assign():
    """Increasing ID generator; the first call returns ``ID_init``."""

    def __init__(self, ID_init):
        # Start one below so the first __call__ hands out ID_init itself.
        self.ID = ID_init - 1

    def curID(self):
        """Return the most recently issued ID without advancing."""
        return self.ID

    def __call__(self):
        """Advance by one and return the new ID."""
        self.ID = self.ID + 1
        return self.ID
43 |
44 |
class ID_birth():
    """Decreasing ID generator; the first call returns ``ID_init``."""

    def __init__(self, ID_init):
        # Start one above so the first __call__ hands out ID_init itself.
        self.ID = ID_init + 1

    def curID(self):
        """Return the most recently issued ID without advancing."""
        return self.ID

    def __call__(self):
        """Step down by one and return the new ID."""
        self.ID = self.ID - 1
        return self.ID
56 |
57 |
class tracker():
    """Greedy frame-to-frame ID association with birth and death buffering.

    Data rows are indexed as ``[0]`` frame, ``[1]`` ID, ``[2]``/``[3]`` x/y and
    ``[4]``/``[5]`` width/height (the prediction helpers add half of columns
    4/5 to columns 2/3 to obtain a box centre, so x/y is presumably the
    top-left corner -- confirm against the data loader).

    Non-negative IDs are confirmed tracks. Negative IDs are pseudo IDs handed
    out by ``ID_birth``; they index ``BirthBuffer`` from its end.
    """

    def __init__(self, ID_assign_init, ID_birth_init, DeathBufferLength, BirthBufferLength, DeathCount, BirthCount,
                 Threshold, Distance, BoxRation, FrameWidth, FrameHeight, PredictThreshold):
        """Store the association hyper-parameters and allocate the buffers.

        :param ID_assign_init: first confirmed ID that will be issued
        :param ID_birth_init: first pseudo ID that will be issued
        :param DeathBufferLength: size of the per-ID miss counter array
        :param BirthBufferLength: size of the per-pseudo-ID match counter array
        :param DeathCount: consecutive misses after which a track is terminated
        :param BirthCount: matches after which a pseudo ID gets a confirmed ID
        :param Threshold: affinity must be strictly above this to be matched
        :param Distance: maximum normalised |dx| + |dy| between matched boxes
        :param BoxRation: maximum relative change of column 4 between matched boxes
        :param FrameWidth: frame width used to normalise x distances
        :param FrameHeight: frame height used to normalise y distances
        :param PredictThreshold: per-step motion limit as a fraction of the frame size
        """
        self.ID_assign = ID_assign(ID_init=ID_assign_init)
        self.ID_birth = ID_birth(ID_init=ID_birth_init)
        self.DeathBuffer = np.zeros(DeathBufferLength)
        self.BirthBuffer = np.zeros(BirthBufferLength)
        self.DeathCount = DeathCount
        self.BirthCount = BirthCount
        self.Threshold = Threshold
        self.Distance = Distance
        self.BoxRation = BoxRation
        self.FrameWidth = float(FrameWidth)
        self.FrameHeight = float(FrameHeight)
        self.PredictThreshold = PredictThreshold

    @staticmethod
    def _box_center(row):
        """Return (x + w / 2, y + h / 2) for one data row."""
        return row[2] + (row[4] / 2.0), row[3] + (row[5] / 2.0)

    def _clamp_move(self, move, frame_size):
        """Limit a displacement to +/- PredictThreshold * frame_size."""
        move = min(move, self.PredictThreshold * frame_size)
        return max(move, -self.PredictThreshold * frame_size)

    def DistanceMeasure(self, PrevData_, CurData_):
        """Return True when two boxes are close enough to be the same target.

        The normalised L1 distance of columns 2/3 must not exceed ``Distance``
        and column 4 of the current box must lie strictly within
        ``(1 -/+ BoxRation)`` times column 4 of the previous box.
        """
        x_dis = np.abs(PrevData_[2] - CurData_[2]) / self.FrameWidth
        y_dis = np.abs(PrevData_[3] - CurData_[3]) / self.FrameHeight
        dis = x_dis + y_dis

        PrevBoxSize = PrevData_[4]
        CurBoxSize = CurData_[4]
        upper = PrevBoxSize * (1 + self.BoxRation)
        lower = PrevBoxSize * (1 - self.BoxRation)

        return bool(dis <= self.Distance and CurBoxSize < upper and CurBoxSize > lower)

    def CoordPrediction_v1(self, PrevData_, PPrevData_):
        """Linearly extrapolate x/y from the last two rows (or repeat the last one)."""
        if PPrevData_ is None:
            return PrevData_[2], PrevData_[3]

        x = 2 * PrevData_[2] - PPrevData_[2]
        y = 2 * PrevData_[3] - PPrevData_[3]
        return x, y

    def CoordPrediction_v2(self, PrevData_, PPrevData_):
        """Extrapolate the box centre by the last (clamped) step; average the size.

        Falls back to the last row unchanged when ``PPrevData_`` is None.
        """
        if PPrevData_ is None:
            return PrevData_[2], PrevData_[3], PrevData_[4], PrevData_[5]

        w = (PrevData_[4] + PPrevData_[4]) / 2.0
        h = (PrevData_[5] + PPrevData_[5]) / 2.0

        x0, y0 = self._box_center(PPrevData_)
        x1, y1 = self._box_center(PrevData_)

        x2 = x1 + self._clamp_move(x1 - x0, self.FrameWidth)
        y2 = y1 + self._clamp_move(y1 - y0, self.FrameHeight)

        return x2 - (w / 2.0), y2 - (h / 2.0), w, h

    def CoordPrediction_v3(self, PrevData_, PPrevData_, PPPrevData_):
        """Extrapolate two steps ahead from the two rows preceding the last one.

        Falls back to the last row unchanged when either older row is None.
        """
        if PPrevData_ is None or PPPrevData_ is None:
            return PrevData_[2], PrevData_[3], PrevData_[4], PrevData_[5]

        w = (PPrevData_[4] + PPPrevData_[4]) / 2.0
        h = (PPrevData_[5] + PPPrevData_[5]) / 2.0

        x0, y0 = self._box_center(PPPrevData_)
        x1, y1 = self._box_center(PPrevData_)

        x2 = 3 * x1 - 2 * x0
        y2 = 3 * y1 - 2 * y0

        return x2 - (w / 2.0), y2 - (h / 2.0), w, h

    def CoordPrediction_v4(self, PrevData_, PPrevData_, PPPrevData_, PPPPrevData_, PPPPPrevData_):
        """Predict the next box from the mean centre step over the last five rows.

        The mean step is clamped per axis and added to the newest centre; the
        size of the newest row is kept. Returns ``(x, y, w, h)``.
        """
        w = PrevData_[4]
        h = PrevData_[5]

        # oldest -> newest
        x0, y0 = self._box_center(PPPPPrevData_)
        x1, y1 = self._box_center(PPPPrevData_)
        x2, y2 = self._box_center(PPPrevData_)
        x3, y3 = self._box_center(PPrevData_)
        x4, y4 = self._box_center(PrevData_)

        x_move = ((x1 - x0) + (x2 - x1) + (x3 - x2) + (x4 - x3)) / 4.0
        y_move = ((y1 - y0) + (y2 - y1) + (y3 - y2) + (y4 - y3)) / 4.0

        x5 = x4 + self._clamp_move(x_move, self.FrameWidth)
        y5 = y4 + self._clamp_move(y_move, self.FrameHeight)

        return x5 - (w / 2.0), y5 - (h / 2.0), w, h

    def __call__(self, Amatrix, PrevIDs, CurData, PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData, BirthLog,
                 DeathLog):
        """Associate the current frame with the previous one.

        :param Amatrix: affinity matrix, rows = previous rows, columns = current
            rows; it is overwritten in place while matches are consumed
        :param PrevIDs: ID of every previous row, aligned with the rows of Amatrix
        :param CurData: current-frame rows; column 1 is filled in with IDs
        :param PrevData: previous-frame rows
        :param PPrevData, PPPrevData, PPPPrevData, PPPPPrevData: progressively
            older frames, used only to look up the history of a lost ID
        :param BirthLog: three lists receiving (frame, pseudo ID, new ID) on confirmation
        :param DeathLog: two lists receiving (frame, ID) on termination
        :return: ``(CurData, PrevData, PPrevData, PPPrevData, PPPPrevData,
            BirthLog, DeathLog)``; ``CurData`` may have gained predicted rows
        """
        PreRange = np.arange(Amatrix.shape[0])
        CurRange = np.arange(Amatrix.shape[1])
        PrevMatchIndex = []
        CurMatchIndex = []

        # step 1: match
        # An empty matrix (no previous rows or no current rows) has no maximum,
        # so it must be skipped explicitly instead of calling .max() on it.
        while Amatrix.size > 0 and Amatrix.max() > self.Threshold:
            PrevIndex, CurIndex = np.unravel_index(Amatrix.argmax(), Amatrix.shape)

            if not self.DistanceMeasure(PrevData[PrevIndex], CurData[CurIndex]):
                # geometrically implausible pair: discard only this entry
                Amatrix[PrevIndex, CurIndex] = self.Threshold
                continue

            PrevMatchIndex.append(int(PrevIndex))
            CurMatchIndex.append(int(CurIndex))
            prevID = int(PrevIDs[PrevIndex])

            if prevID < 0:
                # step 1.1: birth check
                self.BirthBuffer[prevID] += 1
                print("ID %d birth count %d" % (prevID, self.BirthBuffer[prevID]))

                if self.BirthBuffer[prevID] == self.BirthCount:
                    CurData[CurIndex, 1] = self.ID_assign()
                    BirthLog[0].append(PrevData[PrevIndex, 0])
                    BirthLog[1].append(prevID)
                    BirthLog[2].append(CurData[CurIndex, 1])
                    print("---> New ID %d assigned to index %d" % (CurData[CurIndex, 1], CurIndex))
                else:
                    CurData[CurIndex, 1] = prevID
            else:
                # step 1.2: match -- clear the miss counter and copy the ID
                self.DeathBuffer[prevID] = 0
                CurData[CurIndex, 1] = prevID
                print("ID %d passed from index %d to index %d" % (prevID, PrevIndex, CurIndex))

            # both rows are consumed
            Amatrix[PrevIndex, :] = self.Threshold
            Amatrix[:, CurIndex] = self.Threshold

        # step 2: find mismatch
        DeathIndex = np.setdiff1d(PreRange, PrevMatchIndex).astype(int)
        BirthIndex = np.setdiff1d(CurRange, CurMatchIndex).astype(int)
        print ("-----------------------> Birth and Death")
        print("DeathIndex: {}".format(DeathIndex))
        print("BirthIndex: {}".format(BirthIndex))

        # step 3: death process
        for death_idx in DeathIndex:
            deathID = int(PrevIDs[death_idx])
            if deathID < 0:
                # unmatched pseudo IDs are simply dropped
                continue

            self.DeathBuffer[deathID] += 1
            print("ID %d death count %d" % (deathID, self.DeathBuffer[deathID]))

            # step 3.1: terminate check
            if self.DeathBuffer[deathID] == self.DeathCount:
                DeathLog[0].append(PrevData[death_idx, 0])
                DeathLog[1].append(deathID)
                print("terminate %d" % deathID)
                continue

            # step 3.2: death prediction
            # Collect the row of this ID from newest to oldest frame; a frame in
            # which the ID is absent reuses the next newer row.
            history = [PrevData[PrevData[:, 1] == deathID].squeeze()]
            for older in (PPrevData, PPPrevData, PPPPrevData, PPPPPrevData):
                if deathID in older[:, 1]:
                    history.append(older[older[:, 1] == deathID].squeeze())
                else:
                    history.append(history[-1])

            DeathData = PrevData[death_idx].copy()
            CoordPrediction = self.CoordPrediction_v4(*history)
            DeathData[2], DeathData[3], DeathData[4], DeathData[5] = CoordPrediction

            DeathData[0] += 1  # frame update
            CurData = np.concatenate((CurData, DeathData.reshape(1, -1)))
            print("ID %d coordinates predicted" % deathID)

        # step 4: birth process:
        # BirthIndex only refers to the original detections, so the predicted
        # rows appended above are never touched here.
        for birth_idx in BirthIndex:
            CurData[birth_idx, 1] = self.ID_birth()
            print("Pseudo ID %d assigned to index %d" % (CurData[birth_idx, 1], birth_idx))

        assert self.DeathBuffer.max() <= self.DeathCount
        assert self.BirthBuffer.max() <= self.BirthCount

        print("-----------------------> ID info")
        print("ID up to %d" % self.ID_assign.curID())
        print("Pseudo ID up to %d" % self.ID_birth.curID())

        return CurData, PrevData, PPrevData, PPPrevData, PPPPrevData, BirthLog, DeathLog
301 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : __init__.py
3 | # @Author : Peizhao Li
4 | # @Contact : lipeizhao1997@gmail.com
5 | # @Date : 2018/10/26
6 |
--------------------------------------------------------------------------------
/train/train_net_1024.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : train_net_1024.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/10/24
6 |
7 | import os.path as osp
8 |
9 | from model import net_1024
10 | from utils import *
11 |
12 |
def train(parser, generator, log, log_path):
    """Train ``net_1024`` on the GPU and optionally save its weights.

    :param parser: configuration object; reads ``optimizer``, ``learning_rate``,
        ``momentum``, ``decay``, ``start_epoch``, ``epochs``, ``gammas``,
        ``schedule``, ``print_freq`` and ``save_model``
    :param generator: callable producing one training sample per call
    :param log: log sink forwarded to ``print_log``
    :param log_path: directory that receives ``net_1024.pth``
    :raises NotImplementedError: for an optimizer name other than
        ``"SGD"``, ``"Adam"`` or ``"RMSprop"``
    """
    print("training final\n")
    model = net_1024.net_1024()

    # A pretrained checkpoint can be restored at this point, before the model
    # is moved to the GPU:
    #   model.load_state_dict(torch.load(<path>)["state_dict"])

    model = model.cuda()
    net_param_dict = model.parameters()

    # positive pairs are weighted 10x in the binary loss
    criterion_BCE = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([10])).cuda()
    criterion_CE = torch.nn.CrossEntropyLoss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    chosen = parser.optimizer
    if chosen == "SGD":
        optimizer = torch.optim.SGD(net_param_dict, lr=parser.learning_rate,
                                    momentum=parser.momentum, weight_decay=parser.decay, nesterov=True)
    elif chosen == "Adam":
        optimizer = torch.optim.Adam(net_param_dict, lr=parser.learning_rate, weight_decay=parser.decay)
    elif chosen == "RMSprop":
        optimizer = torch.optim.RMSprop(net_param_dict, lr=parser.learning_rate, weight_decay=parser.decay,
                                        momentum=parser.momentum)
    else:
        raise NotImplementedError

    report = ('Epoch: [{:03d}/{:03d}]\t'
              'Time {batch_time.val:5.2f} ({batch_time.avg:5.2f})\t'
              'Loss {loss.val:6.3f} ({loss.avg:6.3f})\t'
              "Acc {acc.val:6.3f} ({acc.avg:6.3f})\t"
              "Acc_pos {acc_pos.val:6.3f} ({acc_pos.avg:6.3f})\t")

    # Main Training and Evaluation Loop
    start_time, epoch_time = time.time(), AverageMeter()

    Batch_time = AverageMeter()
    Loss = AverageMeter()
    Acc = AverageMeter()
    Acc_pos = AverageMeter()

    for epoch in range(parser.start_epoch, parser.epochs):
        # the learning rate is rescaled in place; the returned values are not used
        all_lrs = adjust_learning_rate(optimizer, epoch, parser.gammas, parser.schedule)
        # remaining-time estimate, computed but not reported
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (parser.epochs - epoch))

        # one optimisation step on one generated sample
        batch_time, loss, acc, acc_pos = train_net_1024(model, generator, optimizer, criterion_BCE, criterion_CE,
                                                        criterion_MSE)

        Batch_time.update(batch_time)
        Loss.update(loss.item())
        Acc.update(acc)
        Acc_pos.update(acc_pos)

        if epoch % parser.print_freq == 0 or epoch == parser.epochs - 1:
            print_log(report.format(epoch, parser.epochs, batch_time=Batch_time, loss=Loss, acc=Acc,
                                    acc_pos=Acc_pos), log)

            # NOTE(review): time/loss meters restart per report interval while
            # the accuracy meters keep accumulating -- nesting reconstructed,
            # confirm against the repository.
            Batch_time = AverageMeter()
            Loss = AverageMeter()

        if epoch in parser.schedule:
            print_log("------------------- adjust learning rate -------------------", log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    if parser.save_model:
        save_file_path = osp.join(log_path, "net_1024.pth")
        torch.save({"state_dict": model.state_dict()}, save_file_path)
91 |
92 |
def train_net_1024(model, generator, optimizer, criterion_BCE, criterion_CE, criterion_MSE):
    """Run one optimisation step on a single sample drawn from ``generator``.

    :return: ``(batch_time, loss, acc, acc_pos)`` where ``batch_time`` is the
        wall-clock time of the forward/backward pass in seconds, ``loss`` the
        summed loss tensor and ``acc``/``acc_pos`` the values returned by
        ``accuracy`` for the last score vector
    """
    # switch to train mode
    model.train()

    cur_crop, pre_crop, cur_motion, pre_motion, gt_matrix = generator()
    assert len(cur_crop) == len(cur_motion)
    assert len(pre_crop) == len(pre_motion)

    # flattened ground-truth association matrix, one target per score entry
    target = torch.from_numpy(gt_matrix).cuda().float().view(-1)

    tic = time.time()

    s0, s1, s2, s3, adj1, adj = model(pre_crop, cur_crop, pre_motion, cur_motion)

    # binary loss on every score vector, matrix loss on both adjacency outputs
    loss = criterion_BCE(s0, target)
    for scores in (s1, s2, s3):
        loss += criterion_BCE(scores, target)
    for adjacency in (adj1, adj):
        loss += matrix_loss(adjacency, gt_matrix, criterion_CE, criterion_MSE)

    # accuracy() thresholds its input in place, hence the clones
    acc, acc_pos = accuracy(s3.clone(), target.clone())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_time = time.time() - tic

    return batch_time, loss, acc, acc_pos
128 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : utils.py
3 | # @Author : Peizhao Li
4 | # @Contact : peizhaoli05@gmail.com
5 | # @Date : 2018/9/27
6 |
7 | import yaml, torch, time, os
8 | from easydict import EasyDict as edict
9 | import numpy as np
10 |
11 |
def Config(filename):
    """Load a YAML configuration file.

    :param filename: path of the YAML file
    :return: ``(parser, settings_show)`` -- the parsed settings wrapped in an
        ``EasyDict`` and the raw file content split into lines
    """
    # Read the file once and close it deterministically; the text is used both
    # for parsing and for the line listing.
    with open(filename, 'r') as cfile:
        text = cfile.read()

    # safe_load instead of yaml.load: calling yaml.load without a Loader is
    # deprecated since PyYAML 5.1 and rejected by PyYAML 6, and the settings
    # files only need plain YAML types.
    parser = edict(yaml.safe_load(text))
    settings_show = text.splitlines()
    return parser, settings_show
18 |
19 |
def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    """Rescale the learning rate of every parameter group for this epoch.

    The factor is the gamma paired with the first schedule entry equal to
    ``epoch``, or 1 when the epoch is not scheduled.

    :return: set of the learning rates after rescaling
    """
    assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"

    multiple = next((gamma for gamma, step in zip(gammas, schedule) if epoch == step), 1)

    for group in optimizer.param_groups:
        group['lr'] = multiple * group['lr']

    return {group['lr'] for group in optimizer.param_groups}
32 |
33 |
def print_log(print_string, log, true_string=None):
    """Print a message and mirror it to a log file.

    :param print_string: text written to stdout
    :param log: object with ``write``/``flush``, or None to skip logging
    :param true_string: when given, written to the log instead of ``print_string``
    """
    print("{}".format(print_string))

    if log is None:
        return

    logged = print_string if true_string is None else true_string
    log.write('{}\n'.format(logged))
    log.flush()
41 |
42 |
def time_string():
    """Return the current local time as ``[YYYY-MM-DD <locale time>]``."""
    now = time.localtime(time.time())
    stamp = time.strftime('%Y-%m-%d %X', now)
    return '[{}]'.format(stamp)
47 |
48 |
def time_for_file():
    """Return the current local time formatted for use in a file name."""
    now = time.localtime(time.time())
    stamp = time.strftime('%h-%d-at-%H-%M', now)
    return '{}'.format(stamp)
52 |
53 |
def convert_secs2time(epoch_time):
    """Split a duration in seconds into whole ``(hours, minutes, seconds)``."""
    need_hour = int(epoch_time / 3600)
    left_over = epoch_time - 3600 * need_hour
    need_mins = int(left_over / 60)
    need_secs = int(left_over - 60 * need_mins)
    return need_hour, need_mins, need_secs
59 |
60 |
def extract_label(matrix, device="cuda"):
    """Return the column index of every entry equal to 1, in row-major order.

    :param matrix: 2-D numpy array
    :param device: device the resulting tensor is moved to; the default keeps
        the previous behaviour of always placing the labels on the GPU
    :return: 1-D integer tensor of column indices
    """
    index = np.argwhere(matrix == 1)
    target = index[:, 1]
    return torch.from_numpy(target).to(device)
67 |
68 |
def matrix_loss(matrix, gt_matrix, criterion_CE, criterion_MSE):
    """Row-wise and column-wise loss of a predicted association matrix.

    Rows (and columns) whose ground truth sums to 1 are scored with
    ``criterion_CE`` against the position of their single 1. Rows (and
    columns) whose ground truth sums to 0 are pushed towards zero with
    ``criterion_MSE`` on their sigmoid.

    All index and label tensors are created on the device of ``matrix`` rather
    than unconditionally on the GPU, and a term whose selection is empty is
    skipped so that an empty cross-entropy batch cannot turn the loss into NaN.

    :param matrix: 2-D score tensor
    :param gt_matrix: numpy array of the same shape holding the ground truth
    :param criterion_CE: classification criterion ``(scores, labels) -> loss``
    :param criterion_MSE: regression criterion ``(input, target) -> loss``
    :return: 0-dim loss tensor
    """

    def _rows(scores, positions):
        # pick the given rows, with the index living next to the scores
        index = torch.from_numpy(positions).to(scores.device)
        return torch.index_select(scores, dim=0, index=index)

    # (scores, ground truth) seen row-wise, then column-wise via transposition
    views = ((matrix, gt_matrix), (matrix.t(), gt_matrix.transpose()))

    loss = matrix.new_zeros(())

    # matched rows / columns: classification of the partner index
    for scores, gt in views:
        matched = np.flatnonzero(np.sum(gt, axis=1) == 1)
        if matched.size == 0:
            continue
        labels = np.argwhere(np.take(gt, matched, axis=0) == 1)[:, 1]
        labels = torch.from_numpy(labels).to(scores.device)
        loss = loss + criterion_CE(_rows(scores, matched), labels)

    # unmatched rows / columns: every score should vanish
    for scores, gt in views:
        missed = np.flatnonzero(np.sum(gt, axis=1) == 0)
        if missed.size == 0:
            continue
        picked = _rows(scores, missed)
        loss = loss + criterion_MSE(torch.sigmoid(picked), torch.zeros_like(picked))

    return loss
101 |
102 |
def accuracy(input, target):
    """Return overall and positive-only accuracy of thresholded scores.

    ``input`` is binarised in place (negative -> 0, positive -> 1; exact zeros
    stay 0), so callers that need the raw scores must pass a copy.

    :param input: 1-D score tensor
    :param target: tensor of the same size holding 0/1 labels
    :return: ``(acc, acc_pos)`` -- the fraction of entries that agree with the
        target as a float, and the sum of the binarised scores at the non-zero
        targets divided by ``torch.sum(target)`` (a 0-dim tensor)
    """
    assert input.size() == target.size()

    # binarise the scores in place
    input[input < 0] = 0
    input[input > 0] = 1

    total = input.size(0)
    mismatches = torch.abs(input.sub(target)).sum()
    acc = (total - mismatches.item()) / total

    positives = torch.nonzero(target)
    hits = torch.sum(input[positives])
    acc_pos = hits.item() / torch.sum(target)

    return acc, acc_pos
120 |
121 |
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
139 |
--------------------------------------------------------------------------------