├── .idea
│   ├── MOT.iml
│   ├── deployment.xml
│   ├── dictionaries
│   │   └── lipeizhao.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── other.xml
├── Generator.py
├── README.md
├── Test.py
├── TestGenerate.py
├── config.yml
├── main.py
├── model
│   ├── EmbeddingNet.py
│   ├── FuckUpNet.py
│   ├── GCN.py
│   ├── __init__.py
│   ├── final.py
│   └── net_1024.py
├── requirements.txt
├── setting
│   ├── ADL-Rundle-1_config.yml
│   ├── ADL-Rundle-3_config.yml
│   ├── AVG-TownCentre_config.yml
│   ├── ETH-Crossing_config.yml
│   ├── ETH-Jelmoli_config.yml
│   ├── ETH-Linthescher_config.yml
│   ├── KITTI-16_config.yml
│   ├── KITTI-19_config.yml
│   ├── PETS09-S2L2_config.yml
│   ├── TUD-Crossing_config.yml
│   ├── Venice-1_config.yml
│   ├── seq01_config.yml
│   ├── seq03_config.yml
│   ├── seq06_config.yml
│   ├── seq07_config.yml
│   ├── seq08_config.yml
│   ├── seq12_config.yml
│   └── seq14_config.yml
├── tracking-MOT15.py
├── tracking.py
├── tracking_utils.py
├── train
│   ├── __init__.py
│   └── train_net_1024.py
└── utils.py
--------------------------------------------------------------------------------
/.idea/dictionaries/lipeizhao.xml:
--------------------------------------------------------------------------------
amatrix
bcount
cfile
dcount
dnow
embeddingnet
embnet
fourcc
framewise
frcnn
gmail
imglist
lstm
manu
ndarray
peizhaoli
predata
rawdata
savetxt
seqdata
seqres
sprop
stdv
testcheck
tracklet
--------------------------------------------------------------------------------
/Generator.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : Generator.py
# @Author : Peizhao Li
# @Contact : peizhaoli05@gmail.com
# @Date : 2018/10/11

import os, random
import os.path as osp
import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageDraw


def LoadImg(img_path):
    path = os.listdir(img_path)
    path.sort()
    imglist = []

    for i in range(len(path)):
        img = Image.open(osp.join(img_path, path[i]))
        imglist.append(img.copy())
        img.close()

    return imglist
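# Note: LoadImg decodes every frame of the sequence into memory up front (img.copy()
# keeps the pixel data after the file handle is closed); convenient for the short
# MOT clips used here, but memory-heavy for long videos.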

def FindMatch(list_id, list1, list2):
    """
    Find the positions of shared identities in two id lists.

    :param list_id: identities present in both list1 and list2
    :param list1: ids in the previous frame
    :param list2: ids in the current frame
    :return: flat list of index pairs [i1, j1, i2, j2, ...]
    """
    index_pair = []
    for index, id in enumerate(list_id):
        index1 = list1.index(id)
        index2 = list2.index(id)
        index_pair.append(index1)
        index_pair.append(index2)

    return index_pair


class VideoData(object):

    def __init__(self, seq_id):
        self.img = LoadImg("MOT17/MOT17/train/MOT17-{}-SDP/img1".format(seq_id))
        self.gt = np.loadtxt("MOT17/label/{}_gt.txt".format(seq_id))

        self.ImageWidth = self.img[0].size[0]
        self.ImageHeight = self.img[0].size[1]

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def CurData(self, frame):
        data = self.gt[self.gt[:, 0] == (frame + 1)]

        return data

    def PreData(self, frame):
        DataList = []
        for i in range(5):
            data = self.gt[self.gt[:, 0] == (frame + 1 - i)]
            DataList.append(data)

        return DataList

    def TotalFrame(self):

        return len(self.img)

    def CenterCoordinate(self, SingleLineData):
        x = (SingleLineData[2] + (SingleLineData[4] / 2)) / float(self.ImageWidth)
        y = (SingleLineData[3] + (SingleLineData[5] / 2)) / float(self.ImageHeight)

        return x, y

    def Appearance(self, data):
        """
        Crop the bounding box of every object in one frame.

        :param data: ground-truth rows of a single frame
        :return: list of cropped and resized image tensors
        """
        appearance = []
        img = self.img[int(data[0, 0]) - 1]
        for i in range(data.shape[0]):
            crop = img.crop((int(data[i, 2]), int(data[i, 3]), int(data[i, 2]) + int(data[i, 4]),
                             int(data[i, 3]) + int(data[i, 5])))
            crop = self.transforms(crop)
            appearance.append(crop)

        return appearance

    def CurMotion(self, data):
        motion = []
        for i in range(data.shape[0]):
            coordinate = torch.zeros([2])
            coordinate[0], coordinate[1] = self.CenterCoordinate(data[i])
            motion.append(coordinate)

        return motion

    def PreMotion(self, DataTuple):
        """
        Build a five-step center-coordinate tracklet for every object.

        :param DataTuple: ground-truth rows of the five previous frames, newest first
        :return: list of [5, 2] coordinate tensors
        """
        motion = []
        nameless = DataTuple[0]
        for i in range(nameless.shape[0]):
            coordinate = torch.zeros([5, 2])
            identity = nameless[i, 1]
            coordinate[4, 0], coordinate[4, 1] = self.CenterCoordinate(nameless[i])
            for j in range(1, 5):
                unknown = DataTuple[j]
                if identity in unknown[:, 1]:
                    coordinate[4 - j, 0], coordinate[4 - j, 1] = self.CenterCoordinate(
                        unknown[unknown[:, 1] == identity].squeeze())
                else:
                    coordinate[4 - j, :] = coordinate[5 - j, :]
            motion.append(coordinate)

        return motion

    def GetID(self, data):
        id = []
        for i in range(data.shape[0]):
            id.append(data[i, 1])

        return id

    def __call__(self, frame):
        """
        Gather crops, motion tracklets, ids, and the ground-truth association
        matrix for one pair of adjacent frames.

        :param frame: index of the current frame (0-based, must be >= 5)
        :return: data for the previous and the current frame
        """
        assert frame >= 5 and frame < self.TotalFrame()
        cur = self.CurData(frame)
        pre = self.PreData(frame - 1)

        cur_crop = self.Appearance(cur)
        pre_crop = self.Appearance(pre[0])

        cur_motion = self.CurMotion(cur)
        pre_motion = self.PreMotion(pre)

        cur_id = self.GetID(cur)
        pre_id = self.GetID(pre[0])

        list_id = [x for x in pre_id if x in cur_id]
        index_pair = FindMatch(list_id, pre_id, cur_id)
        gt_matrix = np.zeros([len(pre_id), len(cur_id)])
        for i in range(len(index_pair) // 2):
            gt_matrix[index_pair[2 * i], index_pair[2 * i + 1]] = 1

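        # gt_matrix[i, j] == 1 iff pre_id[i] and cur_id[j] belong to the same ground-truth
        # identity; unmatched rows are objects that leave the scene and unmatched columns
        # are objects that newly appear, so this matrix is the supervision target for the
        # association scores.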
        return cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id, gt_matrix


class Generator(object):

    def __init__(self, entirety=False):
        """
        :param entirety: use all seven training sequences if True, otherwise only MOT17-09
        """
        self.sequence = []

        if entirety:
            self.SequenceID = ["02", "04", "05", "09", "10", "11", "13"]
        else:
            self.SequenceID = ["09"]

        self.vis_save_path = "MOT17/visualize"

        print("\n-------------------------- initialization --------------------------")
        for id in self.SequenceID:
            print("initializing sequence {} ...".format(id))
            self.sequence.append(VideoData(id))
            print("initialize {} done".format(id))
        print("------------------------------ done --------------------------------\n")

    def visualize(self, seq_ID, frame, save_path=None):
        """
        :param seq_ID: index of the sequence to visualize
        :param frame: frame index
        :param save_path: output directory, defaults to self.vis_save_path
        """
        if save_path is None:
            save_path = self.vis_save_path

        print("visualize sequence {}: frame {}".format(self.SequenceID[seq_ID], frame + 1))
        print("video resolution: {} {}".format(self.sequence[seq_ID].ImageWidth, self.sequence[seq_ID].ImageHeight))
        cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id, gt_matrix = self.sequence[seq_ID](frame)

        for i in range(len(cur_crop)):
            img = cur_crop[i]
            img = transforms.functional.to_pil_image(img)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            # draw.text((0, 0), "id: {}\ncoord: {:3.2f}, {:3.2f}".format(int(cur_id[i]), cur_motion[i][0].item(),
            #                                                            cur_motion[i][1].item()), fill=(255, 0, 0))
            img.save(osp.join(save_path, "cur_crop_{}.png".format(str(i).zfill(2))))

        for i in range(len(pre_crop)):
            img = pre_crop[i]
            img = transforms.functional.to_pil_image(img)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            # draw.text((0, 0), "id: {}\ncoord: {:3.2f}, {:3.2f}".format(int(pre_id[i]), pre_motion[i][4, 0].item(),
            #                                                            pre_motion[i][4, 1].item()), fill=(255, 0, 0))
            img.save(osp.join(save_path, "pre_crop_{}.png".format(str(i).zfill(2))))

        np.savetxt(osp.join(save_path, "gt_matrix.txt"), gt_matrix, fmt="%d")
        np.savetxt(osp.join(save_path, "pre_id.txt"), np.array(pre_id).transpose(), fmt="%d")
        np.savetxt(osp.join(save_path, "cur_id.txt"), np.array(cur_id).transpose(), fmt="%d")

    def __call__(self):
        """
        :return: training data for a randomly chosen sequence and frame
        """
        seq = random.choice(self.sequence)
        frame = random.randint(5, seq.TotalFrame() - 1)
        cur_crop, pre_crop, cur_motion, pre_motion, cur_id, pre_id, gt_matrix = seq(frame)

        return cur_crop, pre_crop, cur_motion, pre_motion, gt_matrix
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Graph Neural Based End-to-end Data Association Framework for Online Multiple-Object Tracking
A PyTorch implementation combining a Siamese network with a graph neural network for online multiple-object tracking.

Dataset available at [https://motchallenge.net/]

The accompanying paper can be found at [https://arxiv.org/abs/1907.05315]

## How to run
Use `python main.py` to train a model from scratch. Settings for training are in `config.yml`.
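For reference, the switches in `config.yml` (listed in full further below) that matter most when training from scratch are:

```yaml
model         : net_1024   # which training routine main.py dispatches to
entirety      : True       # train on all seven MOT17 sequences instead of MOT17-09 only
device        : "3"        # value written to CUDA_VISIBLE_DEVICES
learning_rate : 0.001
epochs        : 40000
schedule      : [10000, 20000, 30000]   # epochs at which the learning rate is scaled by the gammas
```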
Use `python tracking.py` to track a test video; you also need to provide the detected objects and the tracking results for the first five frames. Settings for tracking are in `setting/`.

## Requirements
- Python 2.7.12
- numpy 1.11.0
- scipy 1.1.0
- torchvision 0.2.1
- opencv_python 3.3.0.10
- easydict 1.7
- torch 0.4.1
- Pillow 6.2.0
- PyYAML 5.1
--------------------------------------------------------------------------------
/Test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : Test.py
# @Author : Peizhao Li
# @Contact : lipeizhao1997@gmail.com
# @Date : 2018/10/6

import os
import os.path as osp
import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
from model import net_1024


def LoadImg(img_path):
    path = os.listdir(img_path)
    path.sort()
    imglist = []

    for i in range(len(path)):
        img = Image.open(osp.join(img_path, path[i]))
        imglist.append(img.copy())
        img.close()

    return imglist


def LoadModel(model, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda().eval()

    return model


class VideoData(object):

    def __init__(self, info, res_path):
        # MOT17
        # self.img = LoadImg("MOT17/MOT17/test/MOT17-{}-{}/img1".format(info[0], info[1]))
        # self.det = np.loadtxt("test/MOT17-{}-{}/det.txt".format(info[0], info[1]))

        # MOT15
        self.img = LoadImg("MOT15/test/{}/img1".format(info))
        self.det = np.loadtxt("test-MOT15/{}/det.txt".format(info))

        self.res_path = res_path

        self.ImageWidth = self.img[0].size[0]
        self.ImageHeight = self.img[0].size[1]
        self.transforms = transforms.Compose([
            transforms.Resize((84, 32)),
            transforms.ToTensor()
        ])

    def DetData(self, frame):
        data = self.det[self.det[:, 0] == (frame + 1)]

        return data

    def PreData(self, frame):
        res = np.loadtxt(self.res_path)
        DataList = []
        for i in range(5):
            data = res[res[:, 0] == (frame + 1 - i)]
            DataList.append(data)

        return DataList

    def TotalFrame(self):
        return len(self.img)

    def CenterCoordinate(self, SingleLineData):
        x = (SingleLineData[2] + (SingleLineData[4] / 2)) / float(self.ImageWidth)
        y = (SingleLineData[3] + (SingleLineData[5] / 2)) / float(self.ImageHeight)

        return x, y

    def Appearance(self, data):
        appearance = []

        img = self.img[int(data[0, 0]) - 1]
        for i in range(data.shape[0]):
            crop = img.crop((int(data[i, 2]), int(data[i, 3]), int(data[i, 2]) + int(data[i, 4]),
                             int(data[i, 3]) + int(data[i, 5])))
            crop = self.transforms(crop)
            appearance.append(crop)

        return appearance

    def DetMotion(self, data):
        motion = []
        for i in range(data.shape[0]):
            coordinate = torch.zeros([2])
            coordinate[0], coordinate[1] = self.CenterCoordinate(data[i])
            motion.append(coordinate)

        return motion

    def PreMotion(self, DataTuple):
        motion = []
        nameless = DataTuple[0]
        for i in range(nameless.shape[0]):
            coordinate = torch.zeros([5, 2])
            identity = nameless[i, 1]
            coordinate[4, 0], coordinate[4, 1] = self.CenterCoordinate(nameless[i])
            # print(identity)
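            # Walk back through the four older frames: if this identity is missing in an
            # older frame (occlusion or a track younger than five frames), reuse the
            # coordinate from the next newer frame so every tracklet stays a dense 5 x 2 tensor.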
            for j in range(1, 5):
                unknown = DataTuple[j]
                if identity in unknown[:, 1]:
                    coordinate[4 - j, 0], coordinate[4 - j, 1] = self.CenterCoordinate(
                        unknown[unknown[:, 1] == identity].squeeze())
                else:
                    coordinate[4 - j, :] = coordinate[5 - j, :]

            motion.append(coordinate)

        return motion

    def GetID(self, data):
        id = []
        for i in range(data.shape[0]):
            id.append(data[i, 1].copy())

        return id

    def __call__(self, frame):
        assert frame >= 5 and frame < self.TotalFrame()
        det = self.DetData(frame)
        pre = self.PreData(frame - 1)
        det_crop = self.Appearance(det)
        pre_crop = self.Appearance(pre[0])
        det_motion = self.DetMotion(det)
        pre_motion = self.PreMotion(pre)
        pre_id = self.GetID(pre[0])

        return det_crop, det_motion, pre_crop, pre_motion, pre_id


class TestGenerator(object):

    def __init__(self, res_path, info):
        net = net_1024.net_1024()
        net_path = "SaveModel/net_1024_beta2.pth"
        print("-------> loading net_1024")
        self.net = LoadModel(net, net_path)

        self.sequence = []

        print("-------> initializing MOT17-{}-{} ...".format(info[0], info[1]))
        self.sequence.append(VideoData(info, res_path))
        print("-------> initialize MOT17-{}-{} done".format(info[0], info[1]))

        self.vis_save_path = "test/visualize"

    def visualize(self, SeqID, frame, save_path=None):
        """
        :param SeqID: index of the sequence to visualize
        :param frame: frame index
        :param save_path: output directory, defaults to self.vis_save_path
        """
        if save_path is None:
            save_path = self.vis_save_path

        print("visualize sequence {}: frame {}".format(SeqID, frame + 1))
        print("video resolution: {} {}".format(self.sequence[SeqID].ImageWidth, self.sequence[SeqID].ImageHeight))
        det_crop, det_motion, pre_crop, pre_motion, pre_id = self.sequence[SeqID](frame)

        for i in range(len(det_crop)):
            img = det_crop[i]
            img = transforms.functional.to_pil_image(img)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            draw.text((0, 0), "num: {}\ncoord: {:3.2f}, {:3.2f}".format(int(i), det_motion[i][0].item(),
                                                                        det_motion[i][1].item()), fill=(255, 0, 0))
            img.save(osp.join(save_path, "det_crop_{}.png".format(str(i).zfill(2))))

        for i in range(len(pre_crop)):
            img = pre_crop[i]
            img = transforms.functional.to_pil_image(img)
            img = transforms.functional.resize(img, (420, 160))
            draw = ImageDraw.Draw(img)
            draw.text((0, 0), "num: {}\nid: {}\ncoord: {:3.2f}, {:3.2f}".format(int(i), int(pre_id[i]),
                                                                                pre_motion[i][4, 0].item(),
                                                                                pre_motion[i][4, 1].item()),
                      fill=(255, 0, 0))
            img.save(osp.join(save_path, "pre_crop_{}.png".format(str(i).zfill(2))))

        np.savetxt(osp.join(save_path, "pre_id.txt"), np.array(pre_id).transpose(), fmt="%d")

    def __call__(self, SeqID, frame):
        # frame starts at 5 here, while frame ids in the result files start from 1
        sequence = self.sequence[SeqID]
        det_crop, det_motion, pre_crop, pre_motion, pre_id = sequence(frame)
        with torch.no_grad():
            s0, s1, s2, s3, adj1, adj = self.net(pre_crop, det_crop, pre_motion, det_motion)

        return adj
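A minimal usage sketch for `TestGenerator` (illustrative only: it assumes the pretrained weights in `SaveModel/net_1024_beta2.pth`, a `res.txt` seeded with the first five frames, and that `VideoData` is switched to load the matching MOT17 sequence):

    from Test import TestGenerator

    generator = TestGenerator("test/MOT17-01-SDP/res.txt", info=("01", "SDP"))
    adj = generator(SeqID=0, frame=5)  # association matrix: rows = existing tracks, cols = detections
    print(adj.shape)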
--------------------------------------------------------------------------------
/TestGenerate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : TestGenerate.py
# @Author : Peizhao Li
# @Contact : peizhaoli05@gmail.com
# @Date : 2018/10/30

import os
import os.path as osp

import numpy as np
import scipy.io as sio

from Test import TestGenerator
from utils import *

np.set_printoptions(precision=2, suppress=True)


def MakeCell(data):
    cell = []
    frame_last = data[-1, 0]
    for i in range(1, int(frame_last) + 1):
        data_ = data[data[:, 0] == i]
        cell.append(data_.copy())

    return cell


def TrackletRotation(newFrame, oldTracklet):
    # This function does the tracklet rotation in prev_frame; it happens at each frame
    # while preparing the input to the model.
    # ---> oldTracklet is a 17-dim tuple
    # ---> newFrame is a 7-dim tuple

    # First, rotate the oldTracklet, i.e. split it into a 7-dim head and a 10-dim tail:
    # rotate the tail by popping the last two entries and insert(0) the X, Y just
    # tracked, i.e. newFrame[2, 3].
    # Second, put the newFrame X, Y, W, H into the oldTracklet head.
    head = list(oldTracklet[:7])
    tail = list(oldTracklet[7:])
    for i in range(2):
        tail.pop()
        tail.insert(0, newFrame[3 - i])
        head[5 - i] = newFrame[5 - i]
        head[3 - i] = newFrame[3 - i]

    output = head + tail
    return output
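# Worked example of one TrackletRotation step (illustrative numbers):
#   oldTracklet = [12, 7, 100, 50, 30, 80, 1, 100, 50, 98, 49, 96, 48, 94, 47, 92, 46]
#   newFrame    = [13, 7, 102, 51, 30, 80, 1]
#   -> head becomes [12, 7, 102, 51, 30, 80, 1]  (X, Y, W, H refreshed from newFrame)
#   -> tail becomes [102, 51, 100, 50, 98, 49, 96, 48, 94, 47]  (new pair pushed in
#      front, oldest pair dropped)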

def DistanceMeasure(prev_coord, current_coord, FrameW, FrameH):
    x_dist = abs(prev_coord[0] - current_coord[0])
    y_dist = abs(prev_coord[1] - current_coord[1])

    x_dist = float(x_dist) / float(FrameW)
    y_dist = float(y_dist) / float(FrameH)

    rst = x_dist + y_dist

    return rst


def Tracker(Amatrix, Prev_frame, PrevIDs, CurrentDRs, BirthBuffer, DeathBuffer, IDnow, FrameW, FrameH, NIDnow):
    '''
    This Tracker function outputs the CurrentDRs tensor as the thing to be written into the txt;
    CurrentDRs comes in with all TrID = -1.
    Then:
    ---> 1-to-1: write the trajID into the CurrentDRs correspondingly;
    ---> Birth: initialize a new temporal ID in the birth buffer for that DR; ---> needs a Birth Buffer: a list of tuples
    ---> Death: put the dead trajIDs into the death list; ---> needs a Death Buffer: a list of tuples
    '''

    '''
    ---> Trajectories: Len = 5, a list of 7 + 2 * 5 = 17-dim tuples: [FrID, TrID, X, Y, W, H, X1, Y1, X2, Y2, ..., X5, Y5]; W & H remain; just the trajectories in the prev_frame
    ---> PrevIDs: the IDs corresponding to the Trajectories, which index the rows of the Amatrix
    ---> CurrentDRs: a list of 7-dim tuples: [FrID, -1, X, Y, W, H]
    ---> BirthBuffer: a list of 7-dim tuples: [FrID, 0, X, Y, W, H]
    ---> DeathBuffer: a list of 7-dim tuples: [FrID, 'D', X, Y, W, H]; for flags, use 'DD', 'DDD', etc.

    RETURN ---> output, which is CurrentDRs with IDs assigned
    '''

    prev_num = Amatrix.shape[1]  # Amatrix.shape[0] is the batch size
    next_num = Amatrix.shape[2]

    DROccupiedIndexes = []
    ConfirmedBirthRowID = []
    ToDo_DoomedTrajID = []  # confirmed dead trajs to be taken out of prev_frame after each Tracker() iteration
    Fail2ConfirmedBirthKill = []

    '''
    The New Reading Matrix Logic
    '''
    if 'New Reading Logic for the AMatrix':  # always-true string literal, used only to label this block
        '''
        Set a configurable threshold param and read the matrix by iteratively localizing its
        largest item. Each hit finishes an association; set the associated row and column below
        the threshold, then repeat until the max value in the matrix is below the threshold.
        The remaining un-associated rows and columns are then reasoned into deaths and births.
        '''
        Th = 0
        dist_Th = 0.05
        DeathRows = [i for i in range(prev_num)]
        BirthCols = [j for j in range(next_num)]

        # Start the main loop here; it may never reach the upper bound prev_num * next_num
        for i in range(prev_num * next_num):
            # print('-------------------------- i = %d' % i)
            # Compute the row and column index of the max value in the matrix:
            # first find the max value of each row, then the max column index in that row.
            for k in range(prev_num):
                # print('--------------------------- prev_num %d' % prev_num)
                # print('--------------------------- next_num %d' % next_num)
                # print('------------------------------ k = %d' % k)
                row_maxValue = [max(Amatrix[0, j]) for j in range(prev_num)]  # index corresponds to the row index of Amatrix
                max_rowValue = max(row_maxValue)

                # 1-to-1 associations
                if max_rowValue > Th:
                    '''
                    Cases of 1-to-1 association
                    '''
                    max_rowIndex = row_maxValue.index(max_rowValue)
                    max_colIndex = list(Amatrix[0, max_rowIndex]).index(max_rowValue)

                    # mark the associated row and col out of the DeathRows and BirthCols
                    DeathRows.remove(max_rowIndex)
                    BirthCols.remove(max_colIndex)

                    print("Cases of 1-to-1 association, the selected max row and col index:")
                    print([max_rowIndex, max_colIndex])

                    print("DeathRows:")
                    print(DeathRows)

                    print("BirthCols:")
                    print(BirthCols)

                    '''
                    # ----------------------------- Case 1: Normal 1-to-1 association
                    '''
                    if PrevIDs[max_rowIndex] > -10 and PrevIDs[max_rowIndex] != -1:
                        associated_DrIndex = max_colIndex

                        '''
                        ------- ------- ------- Distance threshold
                        '''
                        prev_xy = Prev_frame[max_rowIndex][2:4]  # columns [2:4] are X, Y
                        current_xy = CurrentDRs[associated_DrIndex][2:4]
                        dist1 = DistanceMeasure(prev_xy, current_xy, FrameW, FrameH)

                        if dist1 <= dist_Th:
                            CurrentDRs[associated_DrIndex][1] = Prev_frame[max_rowIndex][1]
                            print("Normal 1-to-1 association at %d row, traj ID passed onto next frame %d DR_Index, is %d" % (
                                max_rowIndex, associated_DrIndex, CurrentDRs[associated_DrIndex][1]))
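                            # Gate illustration: with FrameW = 1920, FrameH = 1080 and dist_Th = 0.05,
                            # a match is kept only when |dx|/1920 + |dy|/1080 <= 0.05, i.e. at most a
                            # ~96 px purely horizontal or ~54 px purely vertical jump between frames.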
                            # Check whether this association revives anybody in the DeathBuffer, i.e. if the
                            # associated ID is someone in the DeathBuffer, it is revived and removed from the buffer.
                            # The just-associated trajID here is Prev_frame[k][1].
                            allTrajIDInDeathBuffer = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
                            if Prev_frame[k][1] in allTrajIDInDeathBuffer:
                                reviveTrajIndex = allTrajIDInDeathBuffer.index(Prev_frame[k][1])
                                revivedID = DeathBuffer[reviveTrajIndex][0]
                                print('Trajectory %d is about to be revived from the DeathBuffer' % revivedID)
                                DeathBuffer.remove(DeathBuffer[reviveTrajIndex])

                    '''
                    # --------------------------- Case 2: Birth confirmation 1-to-1 association
                    '''
                    if PrevIDs[max_rowIndex] <= -10:
                        print('Birth confirmation 1-to-1 association at %d row' % max_rowIndex)
                        associated_DrIndex = max_colIndex

                        '''
                        ------- ------- ------- Distance threshold
                        '''
                        prev_xy = Prev_frame[max_rowIndex][2:4]
                        current_xy = CurrentDRs[associated_DrIndex][2:4]
                        dist2 = DistanceMeasure(prev_xy, current_xy, FrameW, FrameH)

                        if dist2 <= dist_Th:
                            CurrentDRs[associated_DrIndex][1] = IDnow
                            print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ A new target id %d is generated $$$$$$$$$$$$$$$$$$$$$' % IDnow)
                            IDnow = IDnow + 1

                            ConfirmedBirthRowID.append(PrevIDs[max_rowIndex])

                        elif dist2 > dist_Th:
                            Fail2ConfirmedBirthKill.append(PrevIDs[max_rowIndex])

                    # Conservation check: once a column is taken it cannot be taken again, so remove it from the Amatrix
                    if associated_DrIndex in DROccupiedIndexes:
                        print('Conservation constraint violated at row %d when doing birth confirmation 1-to-1 association' % i)
                    else:
                        DROccupiedIndexes.append(associated_DrIndex)

                    # change the row & col of this 1-to-1 association to be all <= Th
                    Amatrix[0][max_rowIndex, :] = Th
                    Amatrix[0][:, max_colIndex] = Th

                # B & D for cols and rows
                if max_rowValue <= Th:
                    '''
                    The rest of the un-associated rows are deaths, the columns are births
                    '''
                    if not DeathRows:
                        # print('---> Tracker: No death at this association')
                        pass
                    if not BirthCols:
                        # print('---> Tracker: No birth at this association')
                        pass

                    '''
                    *************************** Death handles
                    # For a just-dead target that is still within the death window, just prolong it;
                    # do not report it as dead within the death-window frames.
                    '''
                    if DeathRows:
                        for t in range(len(DeathRows)):
                            if DeathRows[t] != -10000:
                                DeadTrajID = Prev_frame[DeathRows[t]][1]

                                '''
                                If a death has ID <= -10, terminate it right away.
                                A birth that is not confirmed in the next frame should be terminated;
                                this termination is handled by the death handle altogether.
                                '''
                                if DeadTrajID <= -10:
                                    # DeathBuffer.append([DeadTrajID, 5])
                                    Fail2ConfirmedBirthKill.append(DeadTrajID)
                                    print('_+_+_+_+_+_+_+_+_+_+_+ In death handles, a temporal birth %d failed to confirm, about to be killed right away in Fail2ConfirmedBirthKill' % DeadTrajID)

                                '''
                                Normal death: update the death flag, then invoke motion prediction
                                '''
                                if DeadTrajID > -10:
                                    '''
                                    # Step 1: Check and update the death counter of this dead trajectory in the DeathBuffer
                                    '''
                                    allTrajIDInDeathBuffer = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]

                                    if DeadTrajID == 1:
                                        print('Gotcha')  # debug marker

                                    if DeadTrajID not in allTrajIDInDeathBuffer:
                                        DeathBuffer.append([DeadTrajID, 0])
                                    else:
                                        # find the index of the dead ID and update its death counter by +1
                                        AllDeadIDs = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
                                        temp_index = AllDeadIDs.index(DeadTrajID)
                                        DeathBuffer[temp_index][1] += 1

                                    # Step 2: As this trajectory is dead, i.e. failed to associate to a DR to prolong
                                    # itself in this frame, use motion prediction to propagate a new dummy bbox to
                                    # add into the result CurrentDRs.
                                    DeadTraj = Prev_frame[DeathRows[t]][:]

                                    DeathRows[t] = -10000

                                    # Do the motion prediction using the tracklets in DeadTraj to predict an [X, Y];
                                    # together with the original W, H, this forms a new DR with ID to add into the CurrentDRs.
                                    '''
                                    -------- MOTION PREDICTION HERE ----------
                                    '''
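                                    # The block below is a constant-velocity fallback: with the tracklet stored
                                    # newest-first, it extrapolates p_next = p1 + (p1 - p2) from the two most
                                    # recent stored positions, i.e. it repeats the last observed displacement.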
                                    if 'Motion Prediction for Dummies':  # always-true label for this block
                                        print('----------------------------------------------- A target %d is not associated in this frame. Motion prediction at play now.' % DeadTrajID)
                                        temp_tracklet = list(DeadTraj[2:4]) + list(
                                            DeadTraj[7:15])  # temp_tracklet is fed into the LSTM: a list of 5 pairs of (x, y)

                                        dummyWH = list(DeadTraj[4:6])
                                        # Establish a simple linear model here, to be substituted by an LSTM
                                        temp_x_seq = [temp_tracklet[i] for i in range(0, len(temp_tracklet), 2)]
                                        temp_y_seq = [temp_tracklet[j] for j in range(1, len(temp_tracklet), 2)]

                                        predicted_x = temp_x_seq[1] + temp_x_seq[1] - temp_x_seq[2]
                                        predicted_y = temp_y_seq[1] + temp_y_seq[1] - temp_y_seq[2]

                                        dummy = list(DeadTraj[:2]) + list([predicted_x, predicted_y]) + list(dummyWH) + [0]

                                        if predicted_x < 0 or predicted_y < 0:
                                            print('Warning: negative predicted coordinates')  # debug marker

                                        dummy[0] += 1
                                        dummy = np.array(dummy)

                                        CurrentDRs.append(dummy)

                    else:
                        pass

                    '''
                    *************************** Birth handles
                    '''
                    if BirthCols:
                        for p in range(len(BirthCols)):
                            if BirthCols[p] != -10000:
                                Birth_DrIndex = BirthCols[p]
                                BirthCols[p] = -10000
                                CurrentDRs[Birth_DrIndex][1] = NIDnow  # the new birth is temporarily IDed, then moved into the BirthBuffer
                                print('_+_+_+_+_+_+_+_+_+_+_+_+_+_+ One temporal birth found, assigned NID is %d' % NIDnow)
                                NIDnow = NIDnow - 1
                                # BirthBuffer deprecated
                                # BirthBuffer.append(CurrentDRs[Birth_DrIndex][:])

                                print('Birth at %d column' % Birth_DrIndex)

                    else:
                        pass

                else:
                    break

    # -------------------------------------------------------- Auxiliary per-frame operations ---------------------------------------------------- #
    '''
    -------------------------- Death counter check and dead bbox termination ---------------------
    '''
    # Check real deaths (IDs in the DeathBuffer whose death counter reached DeathWindow) for termination;
    # DeathWindow and TrackletLen are module-level constants set in __main__.
    AllDeathCounter = [DeathBuffer[i][1] for i in range(len(DeathBuffer))]
    DoomedIDs = []
    DeathBufferRemoveOnSpot = []
    if DeathWindow in AllDeathCounter:
        for i in range(len(DeathBuffer)):
            if DeathBuffer[i][1] >= DeathWindow:
                print('One trajectory %d meets the DeathWindow = %d and is about to be terminated' % (DeathBuffer[i][0], DeathWindow))
                DoomedIDs.append(DeathBuffer[i][0])
                # print(' ! ! ! ! ! ! ! ! ! ! ! ! ! Trajectory %d is being terminated:' % DeathBuffer[i][0])
                DeathBufferRemoveOnSpot.append([DeathBuffer[i][0], DeathWindow])
                # DeathBuffer.remove([DeathBuffer[i][0], DeathWindow])

    # Remove the trajs that met the DeathWindow in this frame from the DeathBuffer
    for i in range(len(DeathBufferRemoveOnSpot)):
        DeathBuffer.remove(DeathBufferRemoveOnSpot[i])

    # -------------------- How to terminate all the trajectories with the DoomedIDs, i.e. remove them from prev_frame
    # DoomedTrajsIndex = []
    # for i in range(len(prev_frame)):
    #     if prev_frame[i][1] in DoomedIDs:
    #         DoomedTrajsIndex.append(i)

    # for i in range(len(DoomedTrajsIndex)):
    #     print(' ! ! ! ! ! ! ! ! ! ! ! ! ! Trajectory %d is being terminated:' % prev_frame[DoomedTrajsIndex[i]][1])
    #     ToDo_DoomedTrajID.append(DoomedTrajsIndex[i])
    #     list(prev_frame).pop(DoomedTrajsIndex[i])

    # # -------------------- Conservation check again
    # if DRavaliableIndexes != DRavaliableIndexes2:
    #     print('Conservation check violated by checking the birth by scenario and birth by removing all occupied DRs.')

    '''
    Post-tracking processing: do the rotation to get a 17-dim output.
    Now CurrentDRs holds 7-dim rows whose TrID has been updated.
    Prolong & rotate Prev_frame with CurrentDRs by associating on TrID.
    '''
    # ------------------------------- Case 1: 1-to-1 association
    # --------------- Rotate & update the Prev_frame in place
    for i in range(len(Prev_frame)):
        for j in range(len(CurrentDRs)):
            if Prev_frame[i][1] == CurrentDRs[j][1]:
                Prev_frame[i] = TrackletRotation(CurrentDRs[j], Prev_frame[i])

    # ------------------------------- Case 2: Birth confirmation
    # ------------ Confirmed births are now in the CurrentDRs with a newly assigned ID; find them and pad
    ConfirmedBirthProlonged = []
    allPrevIDs = [Prev_frame[i][1] for i in range(len(Prev_frame))]

    for k in range(len(CurrentDRs)):
        '''
        ---- The padding for the confirmed birth -----
        '''
        if CurrentDRs[k][1] not in allPrevIDs and CurrentDRs[k][1] > -10 and CurrentDRs[k][1] != -1:
            padding = np.zeros(2 * TrackletLen)
            for i in range(len(padding)):
                if i % 2 == 0:
                    padding[i] = CurrentDRs[k][2]
                if i % 2 != 0:
                    padding[i] = CurrentDRs[k][3]

            CurrentDRs[k] = np.concatenate((CurrentDRs[k], padding))

            ConfirmedBirthProlonged.append(CurrentDRs[k])

        '''
        ---- The padding for the new birth --------
        '''
        if CurrentDRs[k][1] not in allPrevIDs and CurrentDRs[k][1] <= -10 and CurrentDRs[k][1] != -1:
            padding = np.zeros(2 * TrackletLen)
            CurrentDRs[k] = np.concatenate((CurrentDRs[k], padding))

            ConfirmedBirthProlonged.append(CurrentDRs[k])

    if ConfirmedBirthProlonged:
        output = np.concatenate((Prev_frame, ConfirmedBirthProlonged))
    else:
        output = Prev_frame

    return output, BirthBuffer, DeathBuffer, IDnow, ConfirmedBirthRowID, DoomedIDs, NIDnow, Fail2ConfirmedBirthKill


if __name__ == "__main__":

    os.environ["CUDA_VISIBLE_DEVICES"] = "3"

    data_framewise = np.loadtxt("test/MOT17-01-SDP/det.txt")
    data_framewise = MakeCell(data_framewise)
    manual_init = np.loadtxt("test/MOT17-01-SDP/res.txt")
    manual_init = MakeCell(manual_init)

    batchSize = 1
    TrackletLen = 5
    DeathWindow = 3
    V_len = len(data_framewise)
    Bcount = 0
    Dcount = 0
    FrameWidth = 1920
    FrameHeight = 1080
    BirthBuffer = []
    DeathBuffer = []
    prev_frame = manual_init[4]
    IDnow = 20
    NIDnow = -10

    padding = np.zeros((prev_frame.shape[0], 2 * TrackletLen))
    prev_frame = np.concatenate((prev_frame, padding), axis=1)

    for i in range(prev_frame.shape[0]):
        identity = prev_frame[i, 1]
        for frame in range(5):
            data = manual_init[4 - frame]
            if identity in data[:, 1]:
                data_ = data[data[:, 1] == identity].squeeze()
                prev_frame[i, 7 + 2 * frame], prev_frame[i, 8 + 2 * frame] = data_[2], data_[3]
            else:
                prev_frame[i, 7 + 2 * frame], prev_frame[i, 8 + 2 * frame] = prev_frame[i, 5 + 2 * frame], prev_frame[
                    i, 6 + 2 * frame]
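    # prev_frame now holds 17-dim rows: the 7-dim detection head from frame 5 of the manual
    # initialization plus a 5-step (x, y) history filled from res.txt; identities missing in
    # an older frame reuse the coordinate of the next newer one, mirroring PreMotion in Test.py.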

    prev_frame = list(prev_frame)
    PrevIDs = [prev_frame[i][1] for i in range(len(prev_frame))]

    res_path = "test/MOT17-01-SDP/{}.txt".format(time_for_file())
    log_path = "test/MOT17-01-SDP/log.log"
    buffer_path = "test/MOT17-01-SDP/buffer.txt"
    if not osp.exists(res_path):
        os.mknod(res_path)
    res_init = np.loadtxt("test/MOT17-01-SDP/res.txt")

    generator = TestGenerator(res_path, info=("01", "SDP"))  # the sequence identifiers must match the data paths loaded in Test.VideoData

    with open(res_path, "a") as txt:
        for i in range(len(res_init)):
            temp = np.array(res_init[i]).reshape([1, -1])
            np.savetxt(txt, temp[:, 0:7], fmt='%12.3f')

    # ---------------------- start tracking ----------------------
    for v in range(4, V_len - 1):
        # v denotes the previous frame and starts with 0
        prev = v  # prev: 1 ~ TotalFrame - 1
        next = prev + 1
        next_frame = list(data_framewise[next])

        FrID4now = next_frame[0][0]

        print("------------------------------------------> Tracking Frame %d" % FrID4now)

        CurrentDRs = next_frame
        PrevIDs = [prev_frame[i][1] for i in range(len(prev_frame))]

        print("---> MAIN: Input to Model --> CurrentDRs")
        temp2 = [CurrentDRs[i][1] for i in range(len(CurrentDRs))]
        print(temp2)

        print("---> MAIN: Input to Model --> prev_frame")
        temp1 = [prev_frame[i][1] for i in range(len(prev_frame))]
        print(temp1)

        Amatrix = generator(SeqID=0, frame=next)
        Amatrix = Amatrix.unsqueeze(dim=0).cpu().numpy()

        for i in range(len(CurrentDRs)):
            CurrentDRs[i][1] = -1

        '''
        Initialize the trajectory sequences by appending the trajectories with 0s in the back
        '''
        prev_FrID = prev  # this FrID is only used to load the frames from file
        next_FrID = prev_FrID + 1

        # Update the FrID for prev_frame
        for i in range(len(prev_frame)):
            prev_frame[i][0] = FrID4now

        ToPlot, Bbuffer, Dbuffer, IDnow_out, ConfirmedBirthRowID_out, DoomedIDs_out, NIDnow_out, Fail2ConfirmedBirthKill_out = Tracker(
            Amatrix=Amatrix, Prev_frame=prev_frame, PrevIDs=PrevIDs, CurrentDRs=CurrentDRs, BirthBuffer=BirthBuffer,
            DeathBuffer=DeathBuffer, IDnow=IDnow, FrameW=FrameWidth, FrameH=FrameHeight, NIDnow=NIDnow)

        '''
        Also terminate the doomed trajs out of the DeathBuffer
        '''
        # for i in range(len(Dbuffer)):
        #     if Dbuffer[i][0] in ToDo_DoomedTrajID_out:
        #         del Dbuffer[i]

        '''
        When X or Y is negative
        '''
        for i in range(len(ToPlot)):
            if ToPlot[i][2] < 0:
                print('Warning: negative X, Y in ToPlot')

        '''
        ------------------------------------------------------------- Clean out the doomed IDs that were killed
        '''
        DoomedTraj2CleanIndex = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in DoomedIDs_out:
                DoomedTraj2CleanIndex.append(i)

        DoomedTraj2CleanIndex.sort(reverse=True)
        for i in range(len(DoomedTraj2CleanIndex)):
            print('---> MAIN: _+_+_+_+_+_+_+_ Popping DoomedTraj ID is %d' % ToPlot[DoomedTraj2CleanIndex[i]][1])
            ToPlot = np.delete(ToPlot, DoomedTraj2CleanIndex[i], 0)

        '''
        ----------------------------------------------------- Clean out the temporal births that have been confirmed in this frame
        '''
        ConfirmedBirthRowIndex = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in ConfirmedBirthRowID_out:
                ConfirmedBirthRowIndex.append(i)

        ConfirmedBirthRowIndex.sort(reverse=True)
        for i in range(len(ConfirmedBirthRowIndex)):
            print("---> MAIN: _+_+_+_+_+_+_+_ Popping the temporal birth out when a birth is confirmed, ID is %d" % ToPlot[ConfirmedBirthRowIndex[i], 1])
            ToPlot = np.delete(ToPlot, ConfirmedBirthRowIndex[i], axis=0)

        '''
        ---------------------------------------------------- Fail2ConfirmedBirthKill_out: kill right away
        '''
        Fail2ConfirmedBirthKill_out_index = []
        for i in range(len(ToPlot)):
            if ToPlot[i][1] in Fail2ConfirmedBirthKill_out:
                Fail2ConfirmedBirthKill_out_index.append(i)

        Fail2ConfirmedBirthKill_out_index.sort(reverse=True)
        for i in range(len(Fail2ConfirmedBirthKill_out_index)):
            print('---> MAIN: _+_+_+_+_+_+_+_ ToPlot delete Fail2ConfirmedBirthKill, ID is %d' % ToPlot[Fail2ConfirmedBirthKill_out_index[i], 1])
            ToPlot = np.delete(ToPlot, Fail2ConfirmedBirthKill_out_index[i], axis=0)

        print("---> MAIN: Tracking output --> ToPlot")
        temp3 = [ToPlot[i][1] for i in range(len(ToPlot))]
        print(temp3)

        temp6 = [ToPlot[i][2] for i in range(len(ToPlot))]
        for i in range(len(temp6)):
            if temp6[i] <= 0:
                print('Warning: non-positive X coordinate in ToPlot')

        print("---> MAIN: Tracking output --> Dbuffer")
        temp1 = [Dbuffer[i] for i in range(len(Dbuffer))]
        print(temp1)

        # print("---> MAIN: Tracking output --> BirthBuffer")
        # temp4 = [BirthBuffer[i][1] for i in range(len(BirthBuffer))]
        # print(temp4)

        print("---> MAIN: Tracking output --> DeathBuffer")
        temp4 = [DeathBuffer[i][0] for i in range(len(DeathBuffer))]
        temp5 = [DeathBuffer[i][1] for i in range(len(DeathBuffer))]
        print(temp4, temp5)

        BirthBuffer = Bbuffer
        DeathBuffer = Dbuffer
        IDnow = IDnow_out
        NIDnow = NIDnow_out
        prev_frame = ToPlot

        print("Writing")
        with open(res_path, "a") as txt:
            for i in range(len(ToPlot)):
                temp = np.array(ToPlot[i]).reshape([1, -1])
                np.savetxt(txt, temp[:, 0:7], fmt='%12.3f')

        with open(log_path, "a") as log:
            for i in range(len(ToPlot)):
                temp = np.array(ToPlot[i]).reshape([1, -1])
                np.savetxt(log, temp, fmt='%12.3f')

        with open(buffer_path, "a") as DeathBufferTXT:
            for i in range(len(DeathBuffer)):
                temp = np.array(DeathBuffer[i]).reshape([1, -1])
                np.savetxt(DeathBufferTXT, temp, fmt="%12.3f", delimiter=",")

        print("Finish Tracking Frame %d" % v)
        print("Current Initialized ID is up to: %d" % IDnow)
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
# ------------------------- Options -------------------------
mode          : 0 # 0 for debug
description   : net_1024
device        : "3"
epochs        : 40000
gammas        : [0.1, 0.1, 0.1]
schedule      : [10000, 20000, 30000]
learning_rate : 0.001
optimizer     : Adam
entirety      : True
model         : net_1024
# ----------------------- God Save Me -----------------------
save_model    : True
dampening     : 0.9
lr_patience   : 10
momentum      : 0.9
decay         : 0.0005
start_epoch   : 0
print_freq    : 10
checkpoint    : 10
n_threads     : 2
result        : result
# --------------------------- End ---------------------------
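`main.py` below reads this file through the `Config` helper from `utils.py`, which is not part of this listing; a minimal sketch of such a loader, assuming the PyYAML and easydict versions pinned in the README, could look like:

    # Hypothetical stand-in for utils.Config; the real helper is not shown in this snapshot.
    import yaml
    from easydict import EasyDict

    def Config(path):
        with open(path, "r") as f:
            cfg = EasyDict(yaml.safe_load(f))  # attribute access: cfg.model, cfg.epochs, ...
        settings_show = ["{} : {}".format(k, v) for k, v in cfg.items()]
        return cfg, settings_show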
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : main.py
# @Author : Peizhao Li
# @Contact : peizhaoli05@gmail.com
# @Date : 2018/9/27

import os
import os.path as osp
import sys
import random

import numpy as np
import torch

from utils import *
from train import train_EmbeddingNet  # note: only train_net_1024.py ships in this repository
from train import train_LSTM
from train import train_FuckUpNet
from train import train_net_1024
from Generator import Generator

seed = 1
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # benchmark mode selects algorithms non-deterministically, so keep it off for reproducible runs
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

config = osp.join(os.path.abspath(os.curdir), "config.yml")
parser, settings_show = Config(config)
os.environ["CUDA_VISIBLE_DEVICES"] = parser.device

if parser.mode == 0:
    log_path = osp.join(parser.result, 'debug')
else:
    log_path = osp.join(parser.result, '{}-{}'.format(time_for_file(), parser.description))
if not osp.exists(log_path):
    os.mkdir(log_path)
log = open(osp.join(log_path, 'log.log'), 'w')

print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
print_log("torch version : {}".format(torch.__version__), log)
print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)
for idx, data in enumerate(settings_show):
    print_log(data, log)

generator = Generator(entirety=parser.entirety)

if parser.model == "EmbeddingNet":
    train_EmbeddingNet.train(parser, generator, log, log_path)
elif parser.model == "lstm":
    train_LSTM.train(parser, generator, log, log_path)
elif parser.model == "FuckUpNet":
    train_FuckUpNet.train(parser, generator, log, log_path)
elif parser.model == "net_1024":
    train_net_1024.train(parser, generator, log, log_path)
else:
    raise NotImplementedError

log.close()
--------------------------------------------------------------------------------
/model/EmbeddingNet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : EmbeddingNet.py
# @Author : Peizhao Li
# @Contact : lipeizhao1997@gmail.com
# @Date : 2018/10/21

import torch
from torch import nn
import torch.nn.functional as F


class LSTM(nn.Module):

    def __init__(self, hidden_size):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = x[:, -1, :].clone()

        return x


class ANet(nn.Module):

    def __init__(self):
        super(ANet, self).__init__()
        self.ndf = 32

        self.conv1 = nn.Conv2d(3, self.ndf, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(self.ndf, int(self.ndf * 1.5), kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(int(self.ndf * 1.5), self.ndf * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(self.ndf * 2, self.ndf * 4, kernel_size=3, stride=1, padding=1, bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2)
        x = F.leaky_relu(x)
        x = F.max_pool2d(self.conv2(x), 2)
        x = F.leaky_relu(x)
        x = F.max_pool2d(self.conv3(x), 2)
        x = F.leaky_relu(x)
        x = F.max_pool2d(self.conv4(x), 2)
        x = F.leaky_relu(x)

        x = F.avg_pool2d(x, kernel_size=(5, 2))
        x = x.view(x.size(0), -1)

        return x


class EmbeddingNet_legacy(nn.Module):

    def __init__(self):
        super(EmbeddingNet_legacy, self).__init__()
        self.outsize = 128 + 2

        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=2)

        self.fc1 = nn.Linear(self.outsize, int(self.outsize / 2))
        self.fc2 = nn.Linear(int(self.outsize / 2), int(self.outsize / 4))
        self.fc3 = nn.Linear(int(self.outsize / 4), 2)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.drop1 = nn.Dropout(p=0.3)
        self.drop2 = nn.Dropout(p=0.2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)

        pre_coord = self.lstm(pre_coord)

        pre = torch.cat((pre_crop, pre_coord), dim=1)
        cur = torch.cat((cur_crop, cur_coord), dim=1)

        x = F.leaky_relu_(self.fc1(pre.add_(-cur)))
        x = self.drop1(x)
        x = F.leaky_relu_(self.fc2(x))
        x = self.drop2(x)
        x = self.fc3(x)

        return x, pre, cur


class EmbeddingNet_train(nn.Module):

    def __init__(self):
        super(EmbeddingNet_train, self).__init__()
        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=16)
        self.fc = nn.Linear(2, 16)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.crop_fc1 = nn.Linear(128, 64)
        self.crop_fc2 = nn.Linear(64, 32)
        self.crop_fc3 = nn.Linear(32, 2)

        self.coord_fc1 = nn.Linear(16, 8)
        self.coord_fc2 = nn.Linear(8, 2)

        self.com = nn.Linear(4, 2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)
        pre_coord = self.lstm(pre_coord)
        cur_coord = self.fc(cur_coord)

        crop = F.leaky_relu(self.crop_fc1(pre_crop.add_(-cur_crop)))
        crop = F.leaky_relu(self.crop_fc2(crop))
        crop = self.crop_fc3(crop)

        coord = F.leaky_relu(self.coord_fc1(pre_coord.add_(-cur_coord)))
        coord = self.coord_fc2(coord)

        com = torch.cat((crop, coord), dim=1)
        com = self.com(com)

        pre_feature = torch.cat((pre_crop, pre_coord), dim=1)
        cur_feature = torch.cat((cur_crop, cur_coord), dim=1)

        return com, pre_feature, cur_feature

class EmbeddingNet(nn.Module):
    # identical in structure to EmbeddingNet_train above

    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.ANet = ANet()
        self.lstm = LSTM(hidden_size=16)
        self.fc = nn.Linear(2, 16)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0.1)

        self.crop_fc1 = nn.Linear(128, 64)
        self.crop_fc2 = nn.Linear(64, 32)
        self.crop_fc3 = nn.Linear(32, 2)

        self.coord_fc1 = nn.Linear(16, 8)
        self.coord_fc2 = nn.Linear(8, 2)

        self.com = nn.Linear(4, 2)

    def forward(self, pre_crop, cur_crop, pre_coord, cur_coord):
        pre_crop = self.ANet(pre_crop)
        cur_crop = self.ANet(cur_crop)
        pre_coord = self.lstm(pre_coord)
        cur_coord = self.fc(cur_coord)

        crop = F.leaky_relu(self.crop_fc1(pre_crop.add_(-cur_crop)))
        crop = F.leaky_relu(self.crop_fc2(crop))
        crop = self.crop_fc3(crop)

        coord = F.leaky_relu(self.coord_fc1(pre_coord.add_(-cur_coord)))
        coord = self.coord_fc2(coord)

        com = torch.cat((crop, coord), dim=1)
        com = self.com(com)

        pre_feature = torch.cat((pre_crop, pre_coord), dim=1)
        cur_feature = torch.cat((cur_crop, cur_coord), dim=1)

        return com, pre_feature, cur_feature
--------------------------------------------------------------------------------
/model/FuckUpNet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : FuckUpNet.py
# @Author : Peizhao Li
# @Contact : lipeizhao1997@gmail.com
# @Date : 2018/10/13

import torch
from torch import nn
import torch.nn.functional as F
from EmbeddingNet import EmbeddingNet
from EmbeddingNet import EmbeddingNet_train
from GCN import GCN
from GCN import fuckupGCN


class FuckUpNet(nn.Module):

    def __init__(self):
        super(FuckUpNet, self).__init__()
        self.embeddingnet = EmbeddingNet()
        self.gc = GCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        matrix = torch.zeros(len(pre_crop), len(cur_crop)).cuda()
        pre_feature = torch.zeros(len(pre_crop), 144).cuda()
        cur_feature = torch.zeros(len(cur_crop), 144).cuda()

        cur_num = len(cur_crop)
        pre_num = len(pre_crop)

        for i in range(pre_num):
            pre_crop_ = pre_crop[i].cuda().unsqueeze_(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze_(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze_(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze_(dim=0)
                score, pre, cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                matrix[i, j] = score[0, 1]
                pre_feature[i, :] = pre
                cur_feature[j, :] = cur

        # matrix = self.gc(pre_feature, cur_feature, matrix)

        return matrix
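# In all of the wrappers in this file, matrix[i, j] is the pairwise match score between
# previous-frame box i and current-frame box j; uninet below feeds it to fuckupGCN as the
# weighted bipartite adjacency that the graph layers propagate over.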

class fuckupnet(nn.Module):

    def __init__(self):
        super(fuckupnet, self).__init__()
        self.embeddingnet = EmbeddingNet_train()
        self.gc = fuckupGCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        matrix = torch.zeros(len(pre_crop), len(cur_crop)).cuda()
        score0 = torch.zeros(len(pre_crop) * len(cur_crop), 2).cuda()
        pre_feature = torch.zeros(len(pre_crop), 144).cuda()
        cur_feature = torch.zeros(len(cur_crop), 144).cuda()

        cur_num = len(cur_crop)
        pre_num = len(pre_crop)

        for i in range(pre_num):
            pre_crop_ = pre_crop[i].cuda().unsqueeze_(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze_(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze_(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze_(dim=0)
                score, pre, cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                score = score.squeeze_()
                score0[i * cur_num + j] = score
                matrix[i, j] = score[1]
                pre_feature[i, :] = pre
                cur_feature[j, :] = cur

        # score1, score2, matrix = self.gc(pre_feature, cur_feature, matrix)

        return score0, matrix


class embnet(nn.Module):

    def __init__(self):
        super(embnet, self).__init__()
        self.embeddingnet = EmbeddingNet_train()

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        matrix = torch.zeros(len(pre_crop), len(cur_crop)).cuda()
        pre_feature = torch.zeros(len(pre_crop), 144).cuda()
        cur_feature = torch.zeros(len(cur_crop), 144).cuda()

        cur_num = len(cur_crop)
        pre_num = len(pre_crop)

        for i in range(pre_num):
            pre_crop_ = pre_crop[i].cuda().unsqueeze_(dim=0)
            pre_motion_ = pre_motion[i].cuda().unsqueeze_(dim=0)
            for j in range(cur_num):
                cur_crop_ = cur_crop[j].cuda().unsqueeze_(dim=0)
                cur_motion_ = cur_motion[j].cuda().unsqueeze_(dim=0)
                score, pre, cur = self.embeddingnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_)
                score = score.squeeze_()
                matrix[i, j] = score[1]
                pre_feature[i, :] = pre
                cur_feature[j, :] = cur

        return matrix, pre_feature, cur_feature


class uninet(nn.Module):

    def __init__(self):
        super(uninet, self).__init__()
        self.embnet = embnet()
        self.gc = fuckupGCN(planes=144)

    def forward(self, pre_crop, cur_crop, pre_motion, cur_motion):
        matrix, pre, cur = self.embnet(pre_crop, cur_crop, pre_motion, cur_motion)
        score1, score2, matrix = self.gc(pre, cur, matrix)

        return matrix, score1, score2
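A shape sketch of the bipartite graph-convolution step implemented by `GraphConvolution.forward` in `GCN.py` below (names are illustrative):

    # pre: [m, d] track features, cur: [n, d] detection features,
    # adj: [m, n] association scores, W: [d, d] shared weight.
    # pre_out = leaky_relu(adj   @ (cur @ W))  # each track aggregates detection features
    # cur_out = leaky_relu(adj.T @ (pre @ W))  # each detection aggregates track features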
--------------------------------------------------------------------------------
/model/GCN.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : GCN.py
# @Author : Peizhao Li
# @Contact : lipeizhao1997@gmail.com
# @Date : 2018/10/22

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn as nn
import torch.nn.functional as F


class GraphConvolution(Module):

    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, pre, cur, adj):
        # bipartite message passing: tracks gather detection features through adj,
        # detections gather track features through adj.t()
        pre_ = torch.mm(pre, self.weight)
        cur_ = torch.mm(cur, self.weight)

        pre = torch.mm(adj, cur_)
        cur = torch.mm(adj.t(), pre_)

        pre = F.leaky_relu(pre)
        cur = F.leaky_relu(cur)

        return pre, cur

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'


class GCN(nn.Module):

    def __init__(self, planes):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(planes, planes)
        self.gc2 = GraphConvolution(planes, planes)

    def cos_sim(self, pre, cur, adj):
        adj_ = torch.zeros_like(adj).cuda()
        for i in range(pre.size(0)):
            for j in range(cur.size(0)):
                adj_[i, j] = F.cosine_similarity(pre[i:i + 1], cur[j:j + 1])

        return adj_

    def forward(self, pre, cur, adj):
        pre, cur = self.gc1(pre, cur, adj)
        adj = self.cos_sim(pre, cur, adj)
        pre, cur = self.gc2(pre, cur, adj)
        adj = self.cos_sim(pre, cur, adj)

        return adj


class fuckupGCN(nn.Module):

    def __init__(self, planes):
        super(fuckupGCN, self).__init__()

        self.gc1 = GraphConvolution(planes, planes)
        self.gc2 = GraphConvolution(planes, planes)

        self.fc1 = nn.Sequential(nn.Linear(144, 72), nn.LeakyReLU(), nn.Linear(72, 36), nn.LeakyReLU(),
                                 nn.Linear(36, 2))
        self.fc2 = nn.Sequential(nn.Linear(144, 72), nn.LeakyReLU(), nn.Linear(72, 36), nn.LeakyReLU(),
                                 nn.Linear(36, 2))

    def MLP(self, fc, pre, cur, adj):
        score = torch.zeros(pre.size(0) * cur.size(0), 2).cuda()
        adj_ = torch.zeros_like(adj).cuda()

        for i in range(pre.size(0)):
            pre_ = pre[i].unsqueeze_(dim=0)
            for j in range(cur.size(0)):
                cur_ = cur[j].unsqueeze_(dim=0)
                tmp = pre_ - cur_
                score_ = fc(tmp)
                score_ = score_.squeeze_()
                score[i * cur.size(0) + j] = score_
                adj_[i, j] = score_[1]

        return score, adj_

    def forward(self, pre, cur, adj):
        pre, cur = self.gc1(pre, cur, adj)
        score1, adj = self.MLP(self.fc1, pre, cur, adj)
        pre, cur = self.gc2(pre, cur, adj)
        score2, adj = self.MLP(self.fc2, pre, cur, adj)

        return score1, score2, adj
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : __init__.py
# @Author : Peizhao Li
# @Contact : lipeizhao1997@gmail.com
# @Date : 2018/10/26
--------------------------------------------------------------------------------
/model/final.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : final.py
# @Author : Peizhao Li
# @Contact : peizhaoli05@gmail.com
# @Date : 2018/11/10

import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torchvision import models
nn.Linear): 26 | nn.init.normal_(m.weight, mean=0, std=0.01) 27 | 28 | def forward(self, pre_crop, cur_crop): 29 | pre_crop = self.ANet(pre_crop) 30 | cur_crop = self.ANet(cur_crop) 31 | 32 | crop = torch.tan(F.cosine_similarity(pre_crop, cur_crop)) 33 | 34 | pre_feature = pre_crop 35 | cur_feature = cur_crop 36 | 37 | return crop, pre_feature, cur_feature 38 | 39 | 40 | class GraphConvolution(Module): 41 | 42 | def __init__(self, in_features, out_features): 43 | super(GraphConvolution, self).__init__() 44 | self.in_features = in_features 45 | self.out_features = out_features 46 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 47 | 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self): 51 | stdv = 1. / math.sqrt(self.weight.size(1)) 52 | self.weight.data.uniform_(-stdv, stdv) 53 | 54 | def adj_norm(self, adj): 55 | adj_norm = adj 56 | adj_t_norm = adj.t() 57 | 58 | return adj_norm, adj_t_norm 59 | 60 | def forward(self, pre, cur, adj): 61 | pre_ = torch.mm(pre, self.weight) 62 | cur_ = torch.mm(cur, self.weight) 63 | 64 | adj_norm, adj_t_norm = self.adj_norm(adj) 65 | 66 | pre = torch.mm(adj_norm, cur_) 67 | cur = torch.mm(adj_t_norm, pre_) 68 | 69 | return pre, cur 70 | 71 | def __repr__(self): 72 | return self.__class__.__name__ + ' (' \ 73 | + str(self.in_features) + ' -> ' \ 74 | + str(self.out_features) + ')' 75 | 76 | 77 | class GCN(nn.Module): 78 | 79 | def __init__(self, planes): 80 | super(GCN, self).__init__() 81 | 82 | self.gc = GraphConvolution(planes, planes) 83 | 84 | for m in self.modules(): 85 | if isinstance(m, nn.Linear): 86 | nn.init.normal_(m.weight, mean=0, std=0.01) 87 | 88 | def edge_update(self, pre, cur, adj): 89 | score = torch.zeros(pre.size(0) * cur.size(0)).cuda() 90 | adj_ = torch.zeros_like(adj).cuda() 91 | 92 | for i in range(pre.size(0)): 93 | pre_ = pre[i].unsqueeze(dim=0) 94 | for j in range(cur.size(0)): 95 | cur_ = cur[j].unsqueeze(dim=0) 96 | score_ = torch.tan(F.cosine_similarity(pre_, cur_)) 97 | score[i * cur.size(0) + j] = score_ 98 | adj_[i, j] = score_ 99 | 100 | return score, adj_ 101 | 102 | def forward(self, pre, cur, adj): 103 | pre, cur = self.gc(pre, cur, adj) 104 | score, adj = self.edge_update(pre, cur, adj) 105 | 106 | return score, adj 107 | 108 | 109 | class final(nn.Module): 110 | 111 | def __init__(self): 112 | super(final, self).__init__() 113 | self.embnet = ANet() 114 | self.gc = GCN(planes=256) 115 | 116 | def forward(self, pre_crop, cur_crop): 117 | cur_num = len(cur_crop) 118 | pre_num = len(pre_crop) 119 | 120 | adj1 = torch.zeros(pre_num, cur_num).cuda() 121 | pre_feature = torch.zeros(pre_num, 256).cuda() 122 | cur_feature = torch.zeros(cur_num, 256).cuda() 123 | s0 = torch.zeros(pre_num * cur_num).cuda() 124 | 125 | for i in range(pre_num): 126 | pre_crop_ = pre_crop[i].cuda().unsqueeze(dim=0) 127 | for j in range(cur_num): 128 | cur_crop_ = cur_crop[j].cuda().unsqueeze(dim=0) 129 | 130 | score0_, pre, cur = self.embnet(pre_crop_, cur_crop_) 131 | adj1[i, j] = score0_ 132 | s0[i * cur_num + j] = score0_ 133 | pre_feature[i, :] = pre 134 | cur_feature[j, :] = cur 135 | 136 | s3, adj = self.gc(pre_feature, cur_feature, adj1) 137 | 138 | return s0, s3, adj 139 | -------------------------------------------------------------------------------- /model/net_1024.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : net_1024.py 3 | # @Author : Peizhao Li 4 | # @Contact : lipeizhao1997@gmail.com 5 | # @Date : 2018/10/24 6 | 7 | 
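# Module overview: embnet fuses a small appearance CNN (ANet, 128-d after
# pooling) with an LSTM over track coordinates (16-d) into 144-d node features
# plus a pairwise matching score. GCN then propagates node features over the
# previous/current bipartite graph, using the row- and column-softmaxed score
# matrix as the adjacency, and re-scores every pair with an MLP. net_1024
# chains the two stages and returns all intermediate scores.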
import math 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | from torch.nn.modules.module import Module 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class convblock(nn.Module): 22 | 23 | def __init__(self, inplanes, planes, stride=1): 24 | super(convblock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | 29 | def forward(self, x): 30 | out = self.conv1(x) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | out = self.relu(out) 34 | 35 | return out 36 | 37 | 38 | class ANet_2(nn.Module): 39 | 40 | def __init__(self): 41 | super(ANet_2, self).__init__() 42 | self.ndf = 32 43 | 44 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False) 45 | self.conv2 = convblock(32, 48) 46 | self.conv3 = convblock(48, 64) 47 | self.conv4 = convblock(64, 128) 48 | 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.xavier_uniform_(m.weight) 52 | 53 | def forward(self, x): 54 | x = F.relu(self.conv1(x)) 55 | x = F.max_pool2d(self.conv2(x), 2) 56 | x = F.max_pool2d(self.conv3(x), 2) 57 | x = F.max_pool2d(self.conv4(x), 2) 58 | 59 | x = F.avg_pool2d(x, kernel_size=(5, 2)) 60 | x = x.view(x.size(0), -1) 61 | 62 | return x 63 | 64 | 65 | class ANet(nn.Module): 66 | 67 | def __init__(self): 68 | super(ANet, self).__init__() 69 | self.ndf = 32 70 | 71 | self.conv1 = nn.Conv2d(3, self.ndf, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.conv2 = nn.Conv2d(self.ndf, int(self.ndf * 1.5), kernel_size=3, stride=1, padding=1, bias=False) 73 | self.conv3 = nn.Conv2d(int(self.ndf * 1.5), self.ndf * 2, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.conv4 = nn.Conv2d(self.ndf * 2, self.ndf * 4, kernel_size=3, stride=1, padding=1, bias=False) 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | nn.init.xavier_uniform_(m.weight) 79 | 80 | def forward(self, x): 81 | x = F.max_pool2d(self.conv1(x), 2) 82 | x = F.relu(x) 83 | x = F.max_pool2d(self.conv2(x), 2) 84 | x = F.relu(x) 85 | x = F.max_pool2d(self.conv3(x), 2) 86 | x = F.relu(x) 87 | x = F.max_pool2d(self.conv4(x), 2) 88 | x = F.relu(x) 89 | 90 | x = F.avg_pool2d(x, kernel_size=(5, 2)) 91 | x = x.view(x.size(0), -1) 92 | 93 | return x 94 | 95 | 96 | class LSTM(nn.Module): 97 | 98 | def __init__(self, hidden_size): 99 | super(LSTM, self).__init__() 100 | self.lstm = nn.LSTM(input_size=2, hidden_size=hidden_size, num_layers=1, batch_first=True) 101 | 102 | def forward(self, x): 103 | x, _ = self.lstm(x) 104 | x = x[:, -1, :].clone() 105 | 106 | return x 107 | 108 | 109 | class embnet(nn.Module): 110 | 111 | def __init__(self): 112 | super(embnet, self).__init__() 113 | self.ANet = ANet() 114 | self.lstm = LSTM(hidden_size=16) 115 | self.fc = nn.Linear(2, 16, bias=False) 116 | 117 | self.crop_fc1 = nn.Linear(128, 64, bias=False) 118 | self.crop_fc2 = nn.Linear(64, 32, bias=False) 119 | self.crop_fc3 = nn.Linear(32, 1, bias=False) 120 | 121 | self.coord_fc1 = nn.Linear(16, 8, bias=False) 122 | self.coord_fc2 = nn.Linear(8, 1, bias=False) 123 | 124 | self.com = nn.Linear(2, 1, bias=False) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Linear): 128 | nn.init.normal_(m.weight, mean=0, std=0.01) 129 | 130 | def 
forward(self, pre_crop, cur_crop, pre_coord, cur_coord): 131 | pre_crop = self.ANet(pre_crop) 132 | cur_crop = self.ANet(cur_crop) 133 | pre_coord = self.lstm(pre_coord) 134 | cur_coord = self.fc(cur_coord) 135 | 136 | temp_crop = pre_crop.sub(cur_crop) 137 | temp_coord = pre_coord.sub(cur_coord) 138 | 139 | crop = F.relu(self.crop_fc1(temp_crop)) 140 | crop = F.relu(self.crop_fc2(crop)) 141 | crop = self.crop_fc3(crop) 142 | 143 | coord = F.relu(self.coord_fc1(temp_coord)) 144 | coord = self.coord_fc2(coord) 145 | 146 | com = torch.cat((crop, coord), dim=1) 147 | com = self.com(com) 148 | 149 | pre_feature = torch.cat((pre_crop, pre_coord), dim=1) 150 | cur_feature = torch.cat((cur_crop, cur_coord), dim=1) 151 | 152 | return com, crop, coord, pre_feature, cur_feature 153 | 154 | 155 | class GraphConvolution(Module): 156 | 157 | def __init__(self, in_features, out_features): 158 | super(GraphConvolution, self).__init__() 159 | self.in_features = in_features 160 | self.out_features = out_features 161 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 162 | 163 | self.reset_parameters() 164 | 165 | def reset_parameters(self): 166 | stdv = 1. / math.sqrt(self.weight.size(1)) 167 | self.weight.data.uniform_(-stdv, stdv) 168 | 169 | def adj_norm(self, adj): 170 | adj_norm = F.softmax(adj, dim=1) 171 | adj_t_norm = F.softmax(adj.t(), dim=1) 172 | 173 | return adj_norm, adj_t_norm 174 | 175 | def forward(self, pre, cur, adj): 176 | pre_ = torch.mm(pre, self.weight) 177 | cur_ = torch.mm(cur, self.weight) 178 | 179 | adj_norm, adj_t_norm = self.adj_norm(adj) 180 | 181 | pre = torch.mm(adj_norm, cur_) 182 | cur = torch.mm(adj_t_norm, pre_) 183 | 184 | pre = F.relu_(pre) 185 | cur = F.relu_(cur) 186 | 187 | return pre, cur 188 | 189 | def __repr__(self): 190 | return self.__class__.__name__ + ' (' \ 191 | + str(self.in_features) + ' -> ' \ 192 | + str(self.out_features) + ')' 193 | 194 | 195 | class GCN(nn.Module): 196 | 197 | def __init__(self, planes): 198 | super(GCN, self).__init__() 199 | 200 | self.gc = GraphConvolution(planes, planes) 201 | # self.gc2 = GraphConvolution(planes, planes) 202 | 203 | self.fc1 = nn.Sequential(nn.Linear(144, 72, bias=False), nn.ReLU(), nn.Linear(72, 36, bias=False), nn.ReLU(), 204 | nn.Linear(36, 1, bias=False)) 205 | # self.fc2 = nn.Sequential(nn.Linear(144, 72, bias=False), nn.ReLU(), nn.Linear(72, 36, bias=False), nn.ReLU(), 206 | # nn.Linear(36, 1, bias=False)) 207 | 208 | for m in self.modules(): 209 | if isinstance(m, nn.Linear): 210 | nn.init.normal_(m.weight, mean=0, std=0.01) 211 | 212 | def MLP(self, fc, pre, cur, adj): 213 | score = torch.zeros(pre.size(0) * cur.size(0)).cuda() 214 | adj_ = torch.zeros_like(adj).cuda() 215 | 216 | for i in range(pre.size(0)): 217 | pre_ = pre[i].unsqueeze(dim=0) 218 | for j in range(cur.size(0)): 219 | cur_ = cur[j].unsqueeze(dim=0) 220 | temp = pre_.sub(cur_) 221 | score_ = fc(temp) 222 | score_ = score_.squeeze() 223 | score[i * cur.size(0) + j] = score_ 224 | adj_[i, j] = score_ 225 | 226 | return score, adj_ 227 | 228 | def forward(self, pre, cur, adj): 229 | pre, cur = self.gc(pre, cur, adj) 230 | score, adj = self.MLP(self.fc1, pre, cur, adj) 231 | # pre, cur = self.gc2(pre, cur, adj) 232 | # score2, adj = self.MLP(self.fc2, pre, cur, adj) 233 | 234 | return score, adj 235 | 236 | 237 | class net_1024(nn.Module): 238 | 239 | def __init__(self): 240 | super(net_1024, self).__init__() 241 | self.embnet = embnet() 242 | self.gc = GCN(planes=144) 243 | 244 | def forward(self, pre_crop, cur_crop, 
pre_motion, cur_motion): 245 | cur_num = len(cur_crop) 246 | pre_num = len(pre_crop) 247 | 248 | adj1 = torch.zeros(pre_num, cur_num).cuda() 249 | pre_feature = torch.zeros(pre_num, 144).cuda() 250 | cur_feature = torch.zeros(cur_num, 144).cuda() 251 | s0 = torch.zeros(pre_num * cur_num).cuda() 252 | s1 = torch.zeros(pre_num * cur_num).cuda() 253 | s2 = torch.zeros(pre_num * cur_num).cuda() 254 | 255 | for i in range(pre_num): 256 | pre_crop_ = pre_crop[i].cuda().unsqueeze(dim=0) 257 | pre_motion_ = pre_motion[i].cuda().unsqueeze(dim=0) 258 | for j in range(cur_num): 259 | cur_crop_ = cur_crop[j].cuda().unsqueeze(dim=0) 260 | cur_motion_ = cur_motion[j].cuda().unsqueeze(dim=0) 261 | 262 | score0_, score1_, score2_, pre, cur = self.embnet(pre_crop_, cur_crop_, pre_motion_, cur_motion_) 263 | adj1[i, j] = score0_ 264 | s0[i * cur_num + j] = score0_ 265 | s1[i * cur_num + j] = score1_ 266 | s2[i * cur_num + j] = score2_ 267 | pre_feature[i, :] = pre 268 | cur_feature[j, :] = cur 269 | 270 | s3, adj = self.gc(pre_feature, cur_feature, adj1) 271 | 272 | return s0, s1, s2, s3, adj1, adj 273 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.11.0 2 | scipy==1.1.0 3 | torchvision==0.2.1 4 | opencv_python==3.3.0.10 5 | easydict==1.7 6 | torch==0.4.1 7 | Pillow==6.0.0 8 | motmetrics==1.1.3 9 | PyYAML==5.1 10 | -------------------------------------------------------------------------------- /setting/ADL-Rundle-1_config.yml: -------------------------------------------------------------------------------- 1 | # -------- Tracking Setting for Sequence ADL-Rundle-1 -------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 30 14 | PredictThreshold : 0.003 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/ADL-Rundle-3_config.yml: -------------------------------------------------------------------------------- 1 | # -------- Tracking Setting for Sequence ADL-Rundle-3 -------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 8 7 | BirthCount : 4 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/AVG-TownCentre_config.yml: -------------------------------------------------------------------------------- 1 | # ------- Tracking Setting for Sequence AVG-TownCentre ------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 2 8 | Threshold : -3 9 | Distance : 0.075 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 10 14 | PredictThreshold : 0.0075 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/ETH-Crossing_config.yml: -------------------------------------------------------------------------------- 1 | # 
-------- Tracking Setting for Sequence ETH-Crossing -------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 8 7 | BirthCount : 1 8 | Threshold : -3 9 | Distance : 0.075 10 | BoxRation : 0.3 11 | FrameWidth : 640 12 | FrameHeight : 480 13 | fps : 14 14 | PredictThreshold : 0.0075 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/ETH-Jelmoli_config.yml: -------------------------------------------------------------------------------- 1 | # -------- Tracking Setting for Sequence ETH-Jelmoli --------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 5 7 | BirthCount : 3 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 640 12 | FrameHeight : 480 13 | fps : 14 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/ETH-Linthescher_config.yml: -------------------------------------------------------------------------------- 1 | # ------ Tracking Setting for Sequence ETH-Linthescher ------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 12 7 | BirthCount : 2 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 640 12 | FrameHeight : 480 13 | fps : 14 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/KITTI-16_config.yml: -------------------------------------------------------------------------------- 1 | # ---------- Tracking Setting for Sequence KITTI-16 ---------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 3 7 | BirthCount : 2 8 | Threshold : -3 9 | Distance : 0.075 10 | BoxRation : 0.3 11 | FrameWidth : 1224 12 | FrameHeight : 370 13 | fps : 10 14 | PredictThreshold : 0.015 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/KITTI-19_config.yml: -------------------------------------------------------------------------------- 1 | # ---------- Tracking Setting for Sequence KITTI-19 ---------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 8 7 | BirthCount : 2 8 | Threshold : -3 9 | Distance : 0.075 10 | BoxRation : 0.3 11 | FrameWidth : 1238 12 | FrameHeight : 374 13 | fps : 10 14 | PredictThreshold : 0.015 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/PETS09-S2L2_config.yml: -------------------------------------------------------------------------------- 1 | # --------- Tracking Setting for Sequence PETS09-S2L2 -------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 768 12 | FrameHeight : 576 13 | fps : 7 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ 
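Each of these per-sequence YAML files shares the same schema; they are read through the Config helper in utils.py and unpacked into the tracker constructor. A minimal sketch of that flow, mirroring tracking-MOT15.py (the PETS09-S2L2 sequence is just an example):

from utils import Config
from tracking_utils import tracker

# Config returns an easydict of the YAML fields plus the raw lines for display.
parser, settings_show = Config("setting/PETS09-S2L2_config.yml")
tracker_ = tracker(ID_assign_init=parser.ID_assign_init, ID_birth_init=parser.ID_birth_init,
                   DeathBufferLength=parser.DeathBufferLength, BirthBufferLength=parser.BirthBufferLength,
                   DeathCount=parser.DeathCount, BirthCount=parser.BirthCount, Threshold=parser.Threshold,
                   Distance=parser.Distance, BoxRation=parser.BoxRation, FrameWidth=parser.FrameWidth,
                   FrameHeight=parser.FrameHeight, PredictThreshold=parser.PredictThreshold)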
-------------------------------------------------------------------------------- /setting/TUD-Crossing_config.yml: -------------------------------------------------------------------------------- 1 | # -------- Tracking Setting for Sequence TUD-Crossing -------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 12 7 | BirthCount : 2 8 | Threshold : -3 9 | Distance : 0.04 10 | BoxRation : 0.3 11 | FrameWidth : 640 12 | FrameHeight : 480 13 | fps : 25 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/Venice-1_config.yml: -------------------------------------------------------------------------------- 1 | # ---------- Tracking Setting for Sequence Venice-1 ---------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 30 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq01_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 01 ------------- 2 | ID_assign_init : 20 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.003 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq03_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 03 ------------- 2 | ID_assign_init : 100 3 | ID_birth_init : -10 4 | DeathBufferLength : 4000 5 | BirthBufferLength : 8000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.04 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq06_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 06 ------------- 2 | ID_assign_init : 20 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 6 7 | BirthCount : 3 8 | Threshold : -3 9 | Distance : 0.06 10 | BoxRation : 0.3 11 | FrameWidth : 640 12 | FrameHeight : 480 13 | fps : 7 14 | PredictThreshold : 0.006 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq07_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 07 ------------- 2 | ID_assign_init : 30 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 |
BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq08_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 08 ------------- 2 | ID_assign_init : 20 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.06 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.005 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq12_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 12 ------------- 2 | ID_assign_init : 20 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 10 7 | BirthCount : 5 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 15 14 | PredictThreshold : 0.003 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /setting/seq14_config.yml: -------------------------------------------------------------------------------- 1 | # ------------- Tracking Setting for Sequence 14 ------------- 2 | ID_assign_init : 20 3 | ID_birth_init : -10 4 | DeathBufferLength : 2000 5 | BirthBufferLength : 5000 6 | DeathCount : 8 7 | BirthCount : 4 8 | Threshold : -3 9 | Distance : 0.05 10 | BoxRation : 0.3 11 | FrameWidth : 1920 12 | FrameHeight : 1080 13 | fps : 12 14 | PredictThreshold : 0.003 15 | # ------------------------------------------------------------ -------------------------------------------------------------------------------- /tracking-MOT15.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : tracking-MOT15.py 3 | # @Author : Peizhao Li 4 | # @Contact : peizhaoli05@gmail.com 5 | # @Date : 2018/11/14 6 | 7 | from Test import TestGenerator 8 | from utils import * 9 | from tracking_utils import * 10 | from mkvideo_MOT15 import MakeVideo 11 | 12 | np.set_printoptions(precision=2, suppress=True) 13 | 14 | 15 | def main(info, timer): 16 | "-------------------------------- initialize --------------------------------" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 18 | parser, settings_show = Config("setting/{}_config.yml".format(info)) 19 | for idx, data in enumerate(settings_show): 20 | print(data) 21 | detection = MakeCell(np.loadtxt("test-MOT15/{}/det.txt".format(info))) 22 | manual_init = MakeCell(np.loadtxt("test-MOT15/{}/res.txt".format(info))) 23 | StartFrame = len(manual_init) 24 | TotalFrame = len(detection) 25 | tracker_ = tracker(ID_assign_init=parser.ID_assign_init, ID_birth_init=parser.ID_birth_init, 26 | DeathBufferLength=parser.DeathBufferLength, BirthBufferLength=parser.BirthBufferLength, 27 | DeathCount=parser.DeathCount, BirthCount=parser.BirthCount, Threshold=parser.Threshold, 28 | Distance=parser.Distance, BoxRation=parser.BoxRation, FrameWidth=parser.FrameWidth, 29 | FrameHeight=parser.FrameHeight, PredictThreshold=parser.PredictThreshold) 30 | 
PrevData = manual_init[-1] 31 | PPrevData = manual_init[-2] 32 | PPPrevData = manual_init[-3] 33 | PPPPrevData = manual_init[-4] 34 | PPPPPrevData = manual_init[-5] 35 | 36 | res_path = "test-MOT15/{}/{}.txt".format(info, time_for_file()) 37 | res_init = np.loadtxt("test-MOT15/{}/res.txt".format(info)) 38 | with open(res_path, "w") as res: 39 | np.savetxt(res, res_init, fmt='%12.3f') 40 | res.close() 41 | 42 | generator = TestGenerator(res_path, info) 43 | BirthLog = [[], [], []] # prev frame, negative ID, assign ID 44 | DeathLog = [[], []] # death frame, death ID 45 | "----------------------------------------------------------------------------" 46 | 47 | print("############### Start Tracking ###############") 48 | for frame in range(StartFrame, TotalFrame): 49 | print ("-----------------------> Start Tracking Frame %d" % (frame + 1)) 50 | CurData = detection[frame] 51 | 52 | PrevIDs = PrevData[:, 1].copy() 53 | assert PrevIDs[PrevIDs == 0].size == 0 54 | CurIDs = CurData[:, 1].copy() 55 | assert CurIDs.max() == -1 and CurIDs.min() == -1 56 | 57 | tik = time.time() 58 | Amatrix = generator(SeqID=0, frame=frame) 59 | Amatrix = Amatrix.cpu().numpy() 60 | assert Amatrix.shape[0] == PrevIDs.shape[0] and Amatrix.shape[1] == CurIDs.shape[0] 61 | 62 | PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData, BirthLog, DeathLog = tracker_(Amatrix=Amatrix, 63 | PrevIDs=PrevIDs, 64 | CurData=CurData, 65 | PrevData=PrevData, 66 | PPrevData=PPrevData, 67 | PPPrevData=PPPrevData, 68 | PPPPrevData=PPPPrevData, 69 | PPPPPrevData=PPPPPrevData, 70 | BirthLog=BirthLog, 71 | DeathLog=DeathLog) 72 | tok = time.time() 73 | timer.sum(tok - tik) 74 | 75 | with open(res_path, "a") as res: 76 | np.savetxt(res, PrevData, fmt="%12.3f") 77 | res.close() 78 | 79 | print ("-----------------------> Finish Tracking Frame %d\n" % (frame + 1)) 80 | print("############### Finish Tracking ###############\n") 81 | 82 | assert len(BirthLog[0]) == len(BirthLog[1]) == len(BirthLog[2]) 83 | assert len(DeathLog[0]) == len(DeathLog[1]) 84 | 85 | res_data = np.loadtxt(res_path) 86 | 87 | print("cleaning birth...") 88 | for birth in range(len(BirthLog[0])): 89 | frame = BirthLog[0][birth] 90 | ID_index = np.where(res_data[:, 1] == BirthLog[1][birth]) 91 | assign_ID = BirthLog[2][birth] 92 | for i in range(parser.BirthCount): 93 | frame_ = frame - i 94 | frame_index = np.where(res_data[:, 0] == frame_) 95 | index = np.intersect1d(frame_index, ID_index) 96 | res_data[index, 1] = assign_ID 97 | 98 | print("cleaning death...") 99 | for death in range(len(DeathLog[0])): 100 | frame = DeathLog[0][death] 101 | ID_index = np.where(res_data[:, 1] == DeathLog[1][death]) 102 | for i in range(parser.DeathCount - 3): 103 | frame_ = frame - i 104 | frame_index = np.where(res_data[:, 0] == frame_) 105 | index = np.intersect1d(frame_index, ID_index) 106 | res_data[index, 1] = -1 107 | 108 | print("cleaning death sp...") 109 | DeathBuffer = tracker_.DeathBuffer 110 | death_sp_log = np.intersect1d(np.where(DeathBuffer > 3)[0], np.where(DeathBuffer < parser.DeathCount)[0]) 111 | for death_sp in range(death_sp_log.shape[0]): 112 | frame = TotalFrame 113 | ID_index = np.where(res_data[:, 1] == death_sp_log[death_sp]) 114 | for i in range(int(DeathBuffer[death_sp_log[death_sp]])): 115 | frame_ = frame - i 116 | frame_index = np.where(res_data[:, 0] == frame_) 117 | index = np.intersect1d(frame_index, ID_index) 118 | res_data[index, 1] = -1 119 | 120 | np.savetxt(res_path, res_data, fmt="%12.3f") 121 | 122 | MakeVideo(res_path, info, parser.fps, 
parser.FrameWidth, parser.FrameHeight) 123 | 124 | 125 | if __name__ == "__main__": 126 | # seq = ["ADL-Rundle-1", "ADL-Rundle-3", "AVG-TownCentre", "ETH-Crossing", "ETH-Jelmoli", "ETH-Linthescher", "KITTI-16", 127 | # "KITTI-19", "PETS09-S2L2", "TUD-Crossing", "Venice-1"] 128 | seq = ["PETS09-S2L2"] 129 | 130 | timer = timer() 131 | for s in range(len(seq)): 132 | main(seq[s], timer) 133 | print("total time: {} seconds".format(timer())) 134 | -------------------------------------------------------------------------------- /tracking.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : tracking.py 3 | # @Author : Peizhao Li 4 | # @Contact : peizhaoli05@gmail.com 5 | # @Date : 2018/11/2 6 | 7 | from Test import TestGenerator 8 | from tracking_utils import * 9 | from utils import * 10 | 11 | np.set_printoptions(precision=2, suppress=True) 12 | 13 | 14 | def main(info, timer): 15 | "-------------------------------- initialize --------------------------------" 16 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 17 | parser, settings_show = Config("setting/seq{}_config.yml".format(info[0])) 18 | for idx, data in enumerate(settings_show): 19 | print(data) 20 | detection = MakeCell(np.loadtxt("test/MOT17-{}-{}/det.txt".format(info[0], info[1]))) 21 | manual_init = MakeCell(np.loadtxt("test/MOT17-{}-{}/res.txt".format(info[0], info[1]))) 22 | StartFrame = len(manual_init) 23 | TotalFrame = len(detection) 24 | tracker_ = tracker(ID_assign_init=parser.ID_assign_init, ID_birth_init=parser.ID_birth_init, 25 | DeathBufferLength=parser.DeathBufferLength, BirthBufferLength=parser.BirthBufferLength, 26 | DeathCount=parser.DeathCount, BirthCount=parser.BirthCount, Threshold=parser.Threshold, 27 | Distance=parser.Distance, BoxRation=parser.BoxRation, FrameWidth=parser.FrameWidth, 28 | FrameHeight=parser.FrameHeight, PredictThreshold=parser.PredictThreshold) 29 | PrevData = manual_init[-1] 30 | PPrevData = manual_init[-2] 31 | PPPrevData = manual_init[-3] 32 | PPPPrevData = manual_init[-4] 33 | PPPPPrevData = manual_init[-5] 34 | 35 | res_path = "test/MOT17-{}-{}/{}.txt".format(info[0], info[1], time_for_file()) 36 | res_init = np.loadtxt("test/MOT17-{}-{}/res.txt".format(info[0], info[1])) 37 | with open(res_path, "w") as res: 38 | np.savetxt(res, res_init, fmt='%12.3f') 39 | res.close() 40 | 41 | generator = TestGenerator(res_path, info) 42 | BirthLog = [[], [], []] # prev frame, negative ID, assign ID 43 | DeathLog = [[], []] # death frame, death ID 44 | "----------------------------------------------------------------------------" 45 | 46 | print("############### Start Tracking ###############") 47 | for frame in range(StartFrame, TotalFrame): 48 | print ("-----------------------> Start Tracking Frame %d" % (frame + 1)) 49 | CurData = detection[frame] 50 | 51 | PrevIDs = PrevData[:, 1].copy() 52 | assert PrevIDs[PrevIDs == 0].size == 0 53 | CurIDs = CurData[:, 1].copy() 54 | assert CurIDs.max() == -1 and CurIDs.min() == -1 55 | 56 | tik = time.time() 57 | Amatrix = generator(SeqID=0, frame=frame) 58 | Amatrix = Amatrix.cpu().numpy() 59 | assert Amatrix.shape[0] == PrevIDs.shape[0] and Amatrix.shape[1] == CurIDs.shape[0] 60 | 61 | PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData, BirthLog, DeathLog = tracker_(Amatrix=Amatrix, 62 | PrevIDs=PrevIDs, 63 | CurData=CurData, 64 | PrevData=PrevData, 65 | PPrevData=PPrevData, 66 | PPPrevData=PPPrevData, 67 | PPPPrevData=PPPPrevData, 68 | PPPPPrevData=PPPPPrevData, 69 | BirthLog=BirthLog,
70 | DeathLog=DeathLog) 71 | tok = time.time() 72 | timer.sum(tok - tik) 73 | 74 | with open(res_path, "a") as res: 75 | np.savetxt(res, PrevData, fmt="%12.3f") 76 | res.close() 77 | 78 | print ("-----------------------> Finish Tracking Frame %d\n" % (frame + 1)) 79 | print("############### Finish Tracking ###############\n") 80 | 81 | assert len(BirthLog[0]) == len(BirthLog[1]) == len(BirthLog[2]) 82 | assert len(DeathLog[0]) == len(DeathLog[1]) 83 | 84 | res_data = np.loadtxt(res_path) 85 | 86 | # print("cleaning birth...") 87 | # for birth in range(len(BirthLog[0])): 88 | # frame = BirthLog[0][birth] 89 | # ID_index = np.where(res_data[:, 1] == BirthLog[1][birth]) 90 | # assign_ID = BirthLog[2][birth] 91 | # for i in range(parser.BirthCount): 92 | # frame_ = frame - i 93 | # frame_index = np.where(res_data[:, 0] == frame_) 94 | # index = np.intersect1d(frame_index, ID_index) 95 | # res_data[index, 1] = assign_ID 96 | 97 | # print("cleaning death...") 98 | # for death in range(len(DeathLog[0])): 99 | # frame = DeathLog[0][death] 100 | # ID_index = np.where(res_data[:, 1] == DeathLog[1][death]) 101 | # for i in range(parser.DeathCount - 2): 102 | # frame_ = frame - i 103 | # frame_index = np.where(res_data[:, 0] == frame_) 104 | # index = np.intersect1d(frame_index, ID_index) 105 | # res_data[index, 1] = -1 106 | 107 | # print("cleaning death sp...") 108 | # DeathBuffer = tracker_.DeathBuffer 109 | # death_sp_log = np.intersect1d(np.where(DeathBuffer > 3)[0], np.where(DeathBuffer < parser.DeathCount)[0]) 110 | # for death_sp in range(death_sp_log.shape[0]): 111 | # frame = TotalFrame 112 | # ID_index = np.where(res_data[:, 1] == death_sp_log[death_sp]) 113 | # for i in range(int(DeathBuffer[death_sp_log[death_sp]])): 114 | # frame_ = frame - i 115 | # frame_index = np.where(res_data[:, 0] == frame_) 116 | # index = np.intersect1d(frame_index, ID_index) 117 | # res_data[index, 1] = -1 118 | 119 | np.savetxt(res_path, res_data, fmt="%12.3f") 120 | 121 | # MakeVideo(res_path, info, parser.fps, parser.FrameWidth, parser.FrameHeight) 122 | 123 | 124 | if __name__ == "__main__": 125 | # seq = ["01", "06", "07", "08", "12", "14"] 126 | seq = ["03"] 127 | # detector = ["DPM", "FRCNN", "SDP"] 128 | detector = ["DPM"] 129 | # detector = ["FRCNN"] 130 | # detector = ["SDP"] 131 | timer = timer() 132 | for s in range(len(seq)): 133 | for d in range(len(detector)): 134 | main([seq[s], detector[d]], timer) 135 | print("total time: {} seconds".format(timer())) 136 | -------------------------------------------------------------------------------- /tracking_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : tracking_utils.py 3 | # @Author : Peizhao Li 4 | # @Contact : peizhaoli05@gmail.com 5 | # @Date : 2018/11/2 6 | 7 | import numpy as np 8 | 9 | 10 | def MakeCell(data): 11 | cell = [] 12 | frame_last = data[-1, 0] 13 | for i in range(1, int(frame_last) + 1): 14 | data_ = data[data[:, 0] == i] 15 | cell.append(data_.copy()) 16 | 17 | return cell 18 | 19 | 20 | class timer(): 21 | 22 | def __init__(self): 23 | self.time = 0 24 | 25 | def sum(self, time): 26 | self.time += time 27 | 28 | def __call__(self): 29 | return int(self.time) 30 | 31 | 32 | class ID_assign(): 33 | 34 | def __init__(self, ID_init): 35 | self.ID = ID_init - 1 36 | 37 | def curID(self): 38 | return self.ID 39 | 40 | def __call__(self): 41 | self.ID += 1 42 | return self.ID 43 | 44 | 45 | class ID_birth(): 46 | 47 | def __init__(self, ID_init): 48 | self.ID
= ID_init + 1 49 | 50 | def curID(self): 51 | return self.ID 52 | 53 | def __call__(self): 54 | self.ID -= 1 55 | return self.ID 56 | 57 | 58 | class tracker(): 59 | 60 | def __init__(self, ID_assign_init, ID_birth_init, DeathBufferLength, BirthBufferLength, DeathCount, BirthCount, 61 | Threshold, Distance, BoxRation, FrameWidth, FrameHeight, PredictThreshold): 62 | self.ID_assign = ID_assign(ID_init=ID_assign_init) 63 | self.ID_birth = ID_birth(ID_init=ID_birth_init) 64 | self.DeathBuffer = np.zeros(DeathBufferLength) 65 | self.BirthBuffer = np.zeros(BirthBufferLength) 66 | self.DeathCount = DeathCount 67 | self.BirthCount = BirthCount 68 | self.Threshold = Threshold 69 | self.Distance = Distance 70 | self.BoxRation = BoxRation 71 | self.FrameWidth = float(FrameWidth) 72 | self.FrameHeight = float(FrameHeight) 73 | self.PredictThreshold = PredictThreshold 74 | 75 | def DistanceMeasure(self, PrevData_, CurData_): 76 | x_dis = (np.abs(PrevData_[2] - CurData_[2])) / self.FrameWidth 77 | y_dis = (np.abs(PrevData_[3] - CurData_[3])) / self.FrameHeight 78 | dis = x_dis + y_dis 79 | 80 | PrevBoxSize = PrevData_[4] 81 | CurBoxSize = CurData_[4] 82 | 83 | if dis <= self.Distance and CurBoxSize < PrevBoxSize * (1 + self.BoxRation) and CurBoxSize > PrevBoxSize * ( 84 | 1 - self.BoxRation): 85 | return True 86 | else: 87 | return False 88 | 89 | def CoordPrediction_v1(self, PrevData_, PPrevData_): 90 | if PPrevData_ is not None: 91 | x = 2 * PrevData_[2] - PPrevData_[2] 92 | y = 2 * PrevData_[3] - PPrevData_[3] 93 | else: 94 | x = PrevData_[2] 95 | y = PrevData_[3] 96 | 97 | return x, y 98 | 99 | def CoordPrediction_v2(self, PrevData_, PPrevData_): 100 | if PPrevData_ is not None: 101 | w = (PrevData_[4] + PPrevData_[4]) / 2.0 102 | h = (PrevData_[5] + PPrevData_[5]) / 2.0 103 | 104 | x0 = PPrevData_[2] + (PPrevData_[4] / 2.0) 105 | y0 = PPrevData_[3] + (PPrevData_[5] / 2.0) 106 | 107 | x1 = PrevData_[2] + (PrevData_[4] / 2.0) 108 | y1 = PrevData_[3] + (PrevData_[5] / 2.0) 109 | 110 | x_move = x1 - x0 111 | x_move = min(x_move, self.PredictThreshold * self.FrameWidth) 112 | x_move = max(x_move, -self.PredictThreshold * self.FrameWidth) 113 | 114 | y_move = y1 - y0 115 | y_move = min(y_move, self.PredictThreshold * self.FrameHeight) 116 | y_move = max(y_move, -self.PredictThreshold * self.FrameHeight) 117 | 118 | x2 = x1 + x_move 119 | y2 = y1 + y_move 120 | 121 | x = x2 - (w / 2.0) 122 | y = y2 - (h / 2.0) 123 | 124 | else: 125 | x = PrevData_[2] 126 | y = PrevData_[3] 127 | w = PrevData_[4] 128 | h = PrevData_[5] 129 | 130 | return x, y, w, h 131 | 132 | def CoordPrediction_v3(self, PrevData_, PPrevData_, PPPrevData_): 133 | if PPrevData_ is not None and PPPrevData_ is not None: 134 | w = (PPrevData_[4] + PPPrevData_[4]) / 2.0 135 | h = (PPrevData_[5] + PPPrevData_[5]) / 2.0 136 | 137 | x0 = PPPrevData_[2] + (PPPrevData_[4] / 2.0) 138 | y0 = PPPrevData_[3] + (PPPrevData_[5] / 2.0) 139 | 140 | x1 = PPrevData_[2] + (PPrevData_[4] / 2.0) 141 | y1 = PPrevData_[3] + (PPrevData_[5] / 2.0) 142 | 143 | x2 = 3 * x1 - 2 * x0 144 | y2 = 3 * y1 - 2 * y0 145 | 146 | x = x2 - (w / 2.0) 147 | y = y2 - (h / 2.0) 148 | 149 | else: 150 | x = PrevData_[2] 151 | y = PrevData_[3] 152 | w = PrevData_[4] 153 | h = PrevData_[5] 154 | 155 | return x, y, w, h 156 | 157 | def CoordPrediction_v4(self, PrevData_, PPrevData_, PPPrevData_, PPPPrevData_, PPPPPrevData_): 158 | w = PrevData_[4] 159 | h = PrevData_[5] 160 | 161 | x0 = PPPPPrevData_[2] + (PPPPPrevData_[4] / 2.0) 162 | y0 = PPPPPrevData_[3] + (PPPPPrevData_[5] / 2.0) 
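# (x0, y0) .. (x4, y4) here are box centers over the five most recent frames,
# oldest to newest; the mean of the four frame-to-frame displacements, clamped
# to +/- PredictThreshold * frame size, extrapolates the next center, which is
# then converted back to a top-left corner.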
163 | 164 | x1 = PPPPrevData_[2] + (PPPPrevData_[4] / 2.0) 165 | y1 = PPPPrevData_[3] + (PPPPrevData_[5] / 2.0) 166 | 167 | x2 = PPPrevData_[2] + (PPPrevData_[4] / 2.0) 168 | y2 = PPPrevData_[3] + (PPPrevData_[5] / 2.0) 169 | 170 | x3 = PPrevData_[2] + (PPrevData_[4] / 2.0) 171 | y3 = PPrevData_[3] + (PPrevData_[5] / 2.0) 172 | 173 | x4 = PrevData_[2] + (PrevData_[4] / 2.0) 174 | y4 = PrevData_[3] + (PrevData_[5] / 2.0) 175 | 176 | x_move = ((x1 - x0) + (x2 - x1) + (x3 - x2) + (x4 - x3)) / 4.0 177 | y_move = ((y1 - y0) + (y2 - y1) + (y3 - y2) + (y4 - y3)) / 4.0 178 | 179 | x_move = min(x_move, self.PredictThreshold * self.FrameWidth) 180 | x_move = max(x_move, -self.PredictThreshold * self.FrameWidth) 181 | 182 | y_move = min(y_move, self.PredictThreshold * self.FrameHeight) 183 | y_move = max(y_move, -self.PredictThreshold * self.FrameHeight) 184 | 185 | x5 = x4 + x_move 186 | y5 = y4 + y_move 187 | 188 | x = x5 - (w / 2.0) 189 | y = y5 - (h / 2.0) 190 | 191 | return x, y, w, h 192 | 193 | def __call__(self, Amatrix, PrevIDs, CurData, PrevData, PPrevData, PPPrevData, PPPPrevData, PPPPPrevData, BirthLog, 194 | DeathLog): 195 | PreRange = np.arange(Amatrix.shape[0]) 196 | CurRange = np.arange(Amatrix.shape[1]) 197 | PrevMatchIndex = [] 198 | CurMatchIndex = [] 199 | 200 | # step 1: match 201 | while Amatrix.max() > self.Threshold: 202 | PrevIndex, CurIndex = np.unravel_index(Amatrix.argmax(), Amatrix.shape) 203 | if self.DistanceMeasure(PrevData[PrevIndex], CurData[CurIndex]): 204 | PrevMatchIndex.append(PrevIndex.copy()) 205 | CurMatchIndex.append(CurIndex.copy()) 206 | prevID = int(PrevIDs[PrevIndex]) 207 | 208 | # step 1.1: birth check 209 | if prevID < 0: 210 | self.BirthBuffer[prevID] += 1 211 | print("ID %d birth count %d" % (prevID, self.BirthBuffer[prevID])) 212 | 213 | if self.BirthBuffer[prevID] == self.BirthCount: 214 | CurData[CurIndex, 1] = self.ID_assign() 215 | BirthLog[0].append(PrevData[PrevIndex, 0]) 216 | BirthLog[1].append(prevID) 217 | BirthLog[2].append(CurData[CurIndex, 1]) 218 | print("---> New ID %d assigned to index %d" % (CurData[CurIndex, 1], CurIndex)) 219 | else: 220 | CurData[CurIndex, 1] = prevID 221 | 222 | # step 1.2: match 223 | else: 224 | # step 1.2.1: buffer clean 225 | self.DeathBuffer[prevID] = 0 226 | 227 | # step 1.2.2: copy ID 228 | CurData[CurIndex, 1] = prevID 229 | print("ID %d passed from index %d to index %d" % (prevID, PrevIndex, CurIndex)) 230 | 231 | Amatrix[PrevIndex, :] = self.Threshold 232 | Amatrix[:, CurIndex] = self.Threshold 233 | else: 234 | Amatrix[PrevIndex, CurIndex] = self.Threshold 235 | 236 | # step 2: find mismatch 237 | DeathIndex = np.setxor1d(np.array(PrevMatchIndex), PreRange).astype(int) 238 | BirthIndex = np.setxor1d(np.array(CurMatchIndex), CurRange).astype(int) 239 | print ("-----------------------> Birth and Death") 240 | print("DeathIndex: {}".format(DeathIndex)) 241 | print("BirthIndex: {}".format(BirthIndex)) 242 | 243 | # step 3: death process 244 | for i in range(len(DeathIndex)): 245 | deathID = int(PrevIDs[DeathIndex[i]]) 246 | if deathID < 0: 247 | pass 248 | else: 249 | self.DeathBuffer[deathID] += 1 250 | print("ID %d death count %d" % (deathID, self.DeathBuffer[deathID])) 251 | 252 | # step 3.1: terminate check 253 | if self.DeathBuffer[deathID] == self.DeathCount: 254 | DeathLog[0].append(PrevData[DeathIndex[i], 0]) 255 | DeathLog[1].append(deathID) 256 | print("terminate %d" % deathID) 257 | 258 | # step 3.2: death prediction 259 | else: 260 | PrevData_ = PrevData[PrevData[:, 1] == deathID].squeeze() 261 
| if deathID in PPrevData[:, 1]: 262 | PPrevData_ = PPrevData[PPrevData[:, 1] == deathID].squeeze() 263 | else: 264 | PPrevData_ = PrevData_ 265 | if deathID in PPPrevData[:, 1]: 266 | PPPrevData_ = PPPrevData[PPPrevData[:, 1] == deathID].squeeze() 267 | else: 268 | PPPrevData_ = PPrevData_ 269 | if deathID in PPPPrevData[:, 1]: 270 | PPPPrevData_ = PPPPrevData[PPPPrevData[:, 1] == deathID].squeeze() 271 | else: 272 | PPPPrevData_ = PPPrevData_ 273 | if deathID in PPPPPrevData[:, 1]: 274 | PPPPPrevData_ = PPPPPrevData[PPPPPrevData[:, 1] == deathID].squeeze() 275 | else: 276 | PPPPPrevData_ = PPPPrevData_ 277 | DeathData = PrevData[DeathIndex[i]].copy() 278 | 279 | CoordPrediction = self.CoordPrediction_v4(PrevData_, PPrevData_, PPPrevData_, PPPPrevData_, 280 | PPPPPrevData_) 281 | DeathData[2], DeathData[3], DeathData[4], DeathData[5] = CoordPrediction[0], CoordPrediction[1], \ 282 | CoordPrediction[2], CoordPrediction[3] 283 | 284 | DeathData[0] += 1 # frame update 285 | CurData = np.concatenate((CurData, DeathData.reshape(1, -1))) 286 | print("ID %d coordinates predicted" % deathID) 287 | 288 | # step 4: birth process: 289 | for j in range(len(BirthIndex)): 290 | CurData[BirthIndex[j], 1] = self.ID_birth() 291 | print("Pseudo ID %d assigned to index %d" % (CurData[BirthIndex[j], 1], BirthIndex[j])) 292 | 293 | assert self.DeathBuffer.max() <= self.DeathCount 294 | assert self.BirthBuffer.max() <= self.BirthCount 295 | 296 | print("-----------------------> ID info") 297 | print("ID up to %d" % self.ID_assign.curID()) 298 | print("Pseudo ID up to %d" % self.ID_birth.curID()) 299 | 300 | return CurData, PrevData, PPrevData, PPPrevData, PPPPrevData, BirthLog, DeathLog 301 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : __init__.py 3 | # @Author : Peizhao Li 4 | # @Contact : lipeizhao1997@gmail.com 5 | # @Date : 2018/10/26 6 | -------------------------------------------------------------------------------- /train/train_net_1024.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : train_net_1024.py 3 | # @Author : Peizhao Li 4 | # @Contact : peizhaoli05@gmail.com 5 | # @Date : 2018/10/24 6 | 7 | import os.path as osp 8 | 9 | from model import net_1024 10 | from utils import * 11 | 12 | 13 | def train(parser, generator, log, log_path): 14 | # print("training net_1024\n") 15 | # model = net_1024.net_1024() 16 | 17 | print("training net_1024\n") 18 | model = net_1024.net_1024() 19 | 20 | "----------------- pretrained model loading -----------------" 21 | # print("loading pretrained model") 22 | # checkpoint = torch.load("/home/lallazhao/MOT/result/Oct-25-at-02-17-net_1024/net_1024_88.4.pth") 23 | # model.load_state_dict(checkpoint["state_dict"]) 24 | "------------------------------------------------------------" 25 | 26 | model = model.cuda() 27 | net_param_dict = model.parameters() 28 | 29 | weight = torch.Tensor([10]) 30 | criterion_BCE = torch.nn.BCEWithLogitsLoss(pos_weight=weight).cuda() 31 | criterion_CE = torch.nn.CrossEntropyLoss().cuda() 32 | criterion_MSE = torch.nn.MSELoss().cuda() 33 | 34 | if parser.optimizer == "SGD": 35 | optimizer = torch.optim.SGD(net_param_dict, lr=parser.learning_rate, 36 | momentum=parser.momentum, weight_decay=parser.decay, nesterov=True) 37 | elif parser.optimizer == "Adam": 38 | optimizer =
torch.optim.Adam(net_param_dict, lr=parser.learning_rate, weight_decay=parser.decay) 39 | elif parser.optimizer == "RMSprop": 40 | optimizer = torch.optim.RMSprop(net_param_dict, lr=parser.learning_rate, weight_decay=parser.decay, 41 | momentum=parser.momentum) 42 | else: 43 | raise NotImplementedError 44 | 45 | # Main Training and Evaluation Loop 46 | start_time, epoch_time = time.time(), AverageMeter() 47 | 48 | Batch_time = AverageMeter() 49 | Loss = AverageMeter() 50 | Acc = AverageMeter() 51 | Acc_pos = AverageMeter() 52 | 53 | for epoch in range(parser.start_epoch, parser.epochs): 54 | all_lrs = adjust_learning_rate(optimizer, epoch, parser.gammas, parser.schedule) 55 | need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (parser.epochs - epoch)) 56 | 57 | # ----------------------------------- train for one epoch ----------------------------------- 58 | batch_time, loss, acc, acc_pos = train_net_1024(model, generator, optimizer, criterion_BCE, criterion_CE, 59 | criterion_MSE) 60 | 61 | Batch_time.update(batch_time) 62 | Loss.update(loss.item()) 63 | Acc.update(acc) 64 | Acc_pos.update(acc_pos) 65 | 66 | if epoch % parser.print_freq == 0 or epoch == parser.epochs - 1: 67 | print_log('Epoch: [{:03d}/{:03d}]\t' 68 | 'Time {batch_time.val:5.2f} ({batch_time.avg:5.2f})\t' 69 | 'Loss {loss.val:6.3f} ({loss.avg:6.3f})\t' 70 | "Acc {acc.val:6.3f} ({acc.avg:6.3f})\t" 71 | "Acc_pos {acc_pos.val:6.3f} ({acc_pos.avg:6.3f})\t".format( 72 | epoch, parser.epochs, batch_time=Batch_time, loss=Loss, acc=Acc, acc_pos=Acc_pos), log) 73 | 74 | Batch_time = AverageMeter() 75 | Loss = AverageMeter() 76 | 77 | if (epoch in parser.schedule): 78 | print_log("------------------- adjust learning rate -------------------", log) 79 | # ------------------------------------------------------------------------------------------- 80 | 81 | # measure elapsed time 82 | epoch_time.update(time.time() - start_time) 83 | start_time = time.time() 84 | 85 | if parser.save_model: 86 | save_file_path = osp.join(log_path, "net_1024.pth") 87 | states = { 88 | "state_dict": model.state_dict(), 89 | } 90 | torch.save(states, save_file_path) 91 | 92 | 93 | def train_net_1024(model, generator, optimizer, criterion_BCE, criterion_CE, criterion_MSE): 94 | # switch to train mode 95 | model.train() 96 | 97 | cur_crop, pre_crop, cur_motion, pre_motion, gt_matrix = generator() 98 | assert len(cur_crop) == len(cur_motion) 99 | assert len(pre_crop) == len(pre_motion) 100 | 101 | target = torch.from_numpy(gt_matrix).cuda().float().view(-1) 102 | 103 | end = time.time() 104 | 105 | s0, s1, s2, s3, adj1, adj = model(pre_crop, cur_crop, pre_motion, cur_motion) 106 | loss = criterion_BCE(s0, target) 107 | loss += criterion_BCE(s1, target) 108 | loss += criterion_BCE(s2, target) 109 | loss += criterion_BCE(s3, target) 110 | loss += matrix_loss(adj1, gt_matrix, criterion_CE, criterion_MSE) 111 | loss += matrix_loss(adj, gt_matrix, criterion_CE, criterion_MSE) 112 | 113 | # s0, s3, adj = model(pre_crop, cur_crop) 114 | # loss = criterion_BCE(s0, target) 115 | # loss = criterion_BCE(s3, target) 116 | # loss += matrix_loss(adj1, gt_matrix, criterion_CE, criterion_MSE) 117 | # loss += matrix_loss(adj, gt_matrix, criterion_CE, criterion_MSE) 118 | 119 | acc, acc_pos = accuracy(s3.clone(), target.clone()) 120 | 121 | optimizer.zero_grad() 122 | loss.backward() 123 | optimizer.step() 124 | 125 | batch_time = time.time() - end 126 | 127 | return batch_time, loss, acc, acc_pos 128 | 
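A note on the loss wiring above: gt_matrix is flattened row-major by .view(-1), which matches the s[i * cur_num + j] indexing inside net_1024.forward, so flat score k corresponds to the pair (k // cur_num, k % cur_num). A small self-contained sanity sketch of that correspondence (sizes are illustrative, and .cuda() is omitted):

import numpy as np
import torch

pre_num, cur_num = 3, 4
gt_matrix = np.random.randint(0, 2, size=(pre_num, cur_num))
target = torch.from_numpy(gt_matrix).float().view(-1)  # row-major flattening
for i in range(pre_num):
    for j in range(cur_num):
        assert target[i * cur_num + j] == float(gt_matrix[i, j])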
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : utils.py 3 | # @Author : Peizhao Li 4 | # @Contact : peizhaoli05@gmail.com 5 | # @Date : 2018/9/27 6 | 7 | import yaml, torch, time, os 8 | from easydict import EasyDict as edict 9 | import numpy as np 10 | 11 | 12 | def Config(filename): 13 | with open(filename, 'r') as listfile: 14 | parser = edict(yaml.safe_load(listfile))  # safe_load avoids the PyYAML load() warning and closes the handle 15 | listfile.seek(0) 16 | settings_show = listfile.read().splitlines() 17 | return parser, settings_show 18 | 19 | 20 | def adjust_learning_rate(optimizer, epoch, gammas, schedule): 21 | assert len(gammas) == len(schedule), "length of gammas and schedule should be equal" 22 | multiple = 1 23 | for (gamma, step) in zip(gammas, schedule): 24 | if (epoch == step): 25 | multiple = gamma 26 | break 27 | all_lrs = [] 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = multiple * param_group['lr'] 30 | all_lrs.append(param_group['lr']) 31 | return set(all_lrs) 32 | 33 | 34 | def print_log(print_string, log, true_string=None): 35 | print("{}".format(print_string)) 36 | if true_string is not None: 37 | print_string = true_string 38 | if log is not None: 39 | log.write('{}\n'.format(print_string)) 40 | log.flush() 41 | 42 | 43 | def time_string(): 44 | ISOTIMEFORMAT = '%Y-%m-%d %X' 45 | string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.localtime(time.time()))) 46 | return string 47 | 48 | 49 | def time_for_file(): 50 | ISOTIMEFORMAT = '%h-%d-at-%H-%M' 51 | return '{}'.format(time.strftime(ISOTIMEFORMAT, time.localtime(time.time()))) 52 | 53 | 54 | def convert_secs2time(epoch_time): 55 | need_hour = int(epoch_time / 3600) 56 | need_mins = int((epoch_time - 3600 * need_hour) / 60) 57 | need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) 58 | return need_hour, need_mins, need_secs 59 | 60 | 61 | def extract_label(matrix): 62 | index = np.argwhere(matrix == 1) 63 | target = index[:, 1] 64 | target = torch.from_numpy(target).cuda() 65 | 66 | return target 67 | 68 | 69 | def matrix_loss(matrix, gt_matrix, criterion_CE, criterion_MSE): 70 | index_row_match = np.where([np.sum(gt_matrix, axis=1) == 1])[1] 71 | index_col_match = np.where([np.sum(gt_matrix, axis=0) == 1])[1] 72 | index_row_miss = np.where([np.sum(gt_matrix, axis=1) == 0])[1] 73 | index_col_miss = np.where([np.sum(gt_matrix, axis=0) == 0])[1] 74 | 75 | gt_matrix_row_match = np.take(gt_matrix, index_row_match, axis=0) 76 | gt_matrix_col_match = np.take(gt_matrix.transpose(), index_col_match, axis=0) 77 | 78 | index_row_match = torch.from_numpy(index_row_match).cuda() 79 | index_col_match = torch.from_numpy(index_col_match).cuda() 80 | 81 | matrix_row_match = torch.index_select(matrix, dim=0, index=index_row_match) 82 | matrix_col_match = torch.index_select(matrix.t(), dim=0, index=index_col_match) 83 | 84 | label_row_CE = extract_label(gt_matrix_row_match) 85 | label_col_CE = extract_label(gt_matrix_col_match) 86 | 87 | loss = criterion_CE(matrix_row_match, label_row_CE) 88 | loss += criterion_CE(matrix_col_match, label_col_CE) 89 | 90 | if index_row_miss.size != 0: 91 | index_row_miss = torch.from_numpy(index_row_miss).cuda() 92 | matrix_row_miss = torch.index_select(matrix, dim=0, index=index_row_miss) 93 | loss += criterion_MSE(torch.sigmoid(matrix_row_miss), torch.zeros_like(matrix_row_miss)) 94 | 95 | if index_col_miss.size != 0: 96 | index_col_miss =
torch.from_numpy(index_col_miss).cuda() 97 | matrix_col_miss = torch.index_select(matrix.t(), dim=0, index=index_col_miss) 98 | loss += criterion_MSE(torch.sigmoid(matrix_col_miss), torch.zeros_like(matrix_col_miss)) 99 | 100 | return loss 101 | 102 | 103 | def accuracy(input, target): 104 | assert input.size() == target.size() 105 | 106 | input[input < 0] = 0 107 | input[input > 0] = 1 108 | batch_size = input.size(0) 109 | pos_size = torch.sum(target) 110 | 111 | dis = input.sub(target) 112 | wrong = torch.sum(torch.abs(dis)) 113 | acc = (batch_size - wrong.item()) / batch_size 114 | 115 | index = torch.nonzero(target) 116 | input_pos = torch.sum(input[index]) 117 | acc_pos = input_pos.item() / pos_size 118 | 119 | return acc, acc_pos 120 | 121 | 122 | class AverageMeter(object): 123 | """Computes and stores the average and current value""" 124 | 125 | def __init__(self): 126 | self.reset() 127 | 128 | def reset(self): 129 | self.val = 0 130 | self.avg = 0 131 | self.sum = 0 132 | self.count = 0 133 | 134 | def update(self, val, n=1): 135 | self.val = val 136 | self.sum += val * n 137 | self.count += n 138 | self.avg = self.sum / self.count 139 | --------------------------------------------------------------------------------
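For orientation, a minimal end-to-end forward pass of net_1024 on dummy inputs. This is a sketch under stated assumptions, not part of the repository: the crop size (3, 80, 32) is chosen so ANet's closing avg_pool2d(kernel_size=(5, 2)) sees exactly a 5x2 feature map after four 2x poolings, the motion history length of 8 is arbitrary, and a CUDA device is required because the model hard-codes .cuda() on its buffers.

import torch
from model.net_1024 import net_1024

model = net_1024().cuda().eval()
pre_num, cur_num, hist = 3, 4, 8
pre_crop = [torch.randn(3, 80, 32) for _ in range(pre_num)]    # one crop per previous-frame track
cur_crop = [torch.randn(3, 80, 32) for _ in range(cur_num)]    # one crop per current-frame detection
pre_motion = [torch.randn(hist, 2) for _ in range(pre_num)]    # (T, 2) coordinate history -> LSTM
cur_motion = [torch.randn(2) for _ in range(cur_num)]          # (2,) current coordinate -> linear layer

with torch.no_grad():
    s0, s1, s2, s3, adj1, adj = model(pre_crop, cur_crop, pre_motion, cur_motion)
print(adj.shape)  # torch.Size([3, 4]): refined previous-to-current affinity matrix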