├── App2
│   ├── copyfile.py
│   ├── dataset.py
│   ├── global_set.py
│   ├── mot_model.py
│   ├── munkres.py
│   ├── test.py
│   ├── test_dataset.py
│   └── train.py
├── GN
│   ├── copyfile.py
│   ├── global_set.py
│   ├── m_mot_model.py
│   ├── m_test_dataset.py
│   ├── mot_model.py
│   ├── munkres.py
│   ├── test.py
│   └── test_dataset.py
├── Motion1
│   ├── copyfile.py
│   ├── dataset.py
│   ├── global_set.py
│   ├── m_mot_model.py
│   ├── munkres.py
│   ├── test.py
│   ├── test_dataset.py
│   └── train.py
├── Pic
│   ├── Pipeline.png
│   └── mot.gif
├── README.md
└── requirements.txt

/App2/copyfile.py:
--------------------------------------------------------------------------------
import shutil, os
from global_set import mot_dataset_dir

name = 'motmetrics'
types = ['POI']

seqs = [2, 4, 5, 9, 10, 11, 13]  # the set of sequences
lengths = [600, 1050, 837, 525, 654, 900, 750]  # the length of each sequence

test_seqs = [1, 3, 6, 7, 8, 12, 14]
test_lengths = [450, 1500, 1194, 500, 625, 900, 750]

# copy the tracking results of the testing sequences
for type in types:
    for i in range(len(seqs)):
        src_dir = 'results/%02d/%d/%s_%s/res.txt' % (test_seqs[i], test_lengths[i], name, type)

        t = type
        if type == 'DPM0' or type == 'POI':
            t = 'DPM'

        des_d = 'mot16/'
        if not os.path.exists(des_d):
            os.mkdir(des_d)
        des_dir = des_d + 'MOT16-%02d.txt' % (test_seqs[i])

        print(src_dir)
        print(des_dir)
        shutil.copyfile(src_dir, des_dir)

# copy the ground truth of the training sequences
types = ['POI']
for type in types:
    for i in range(len(seqs)):
        src_dir = mot_dataset_dir + 'MOT17/train/MOT17-%02d-%s/gt/gt.txt' % (seqs[i], type)

        des_dir = des_d + 'MOT16-%02d.txt' % (seqs[i])

        print(src_dir)
        print(des_dir)
        shutil.copyfile(src_dir, des_dir)
--------------------------------------------------------------------------------
/App2/dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import cv2, random, torch, shutil, os
import numpy as np
from PIL import Image
import torch.nn.functional as F
from mot_model import appearance
from global_set import edge_initial, app_fine_tune, fine_tune_dir, overlap
from torchvision.transforms import ToTensor


def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img


class DatasetFromFolder(data.Dataset):
    def __init__(self, part, cuda=True, show=0):
        super(DatasetFromFolder, self).__init__()
        self.dir = part
        self.cleanPath(part)
        self.img_dir = part + '/img1/'
        self.gt_dir = part + '/gt/'
        self.det_dir = part + '/det/'
        self.device = torch.device("cuda" if cuda else "cpu")
        self.show = show

        self.loadAModel()
        self.getSeqL()
        self.readBBx()
        self.initBuffer()

    def cleanPath(self, part):
        if os.path.exists(part + '/gts/'):
            shutil.rmtree(part + '/gts/')
        if os.path.exists(part + '/dets/'):
            shutil.rmtree(part + '/dets/')

    def loadAModel(self):
        if app_fine_tune:
            self.Appearance = torch.load(fine_tune_dir)
        else:
            self.Appearance = appearance()
        self.Appearance.to(self.device)
        self.Appearance.eval()  # fix the BatchNorm layers

    def getSeqL(self):
        # get the length of the sequence from seqinfo.ini
        info = self.dir + '/seqinfo.ini'
        f = open(info, 'r')
        f.readline()
        for line in f.readlines():
            line = line.strip().split('=')
            if line[0] == 'seqLength':
                self.seqL = int(line[1])
        f.close()
        print('     The length of the sequence:', self.seqL)

    def generator(self, bbx):
        # randomly shift the box (keeping its size) so that the IoU between
        # the shifted and the original box equals `overlap`
        if random.randint(0, 1):
            x, y, w, h = bbx
            x, y, w = float(x), float(y), float(w)
            tmp = overlap * 2 / (1 + overlap)
            n_w = random.uniform(tmp * w, w)
            n_h = tmp * w * float(h) / n_w

            direction = random.randint(1, 4)
            if direction == 1:
                x = x + n_w - w
                y = y + n_h - h
            elif direction == 2:
                x = x - n_w + w
                y = y + n_h - h
            elif direction == 3:
                x = x + n_w - w
                y = y - n_h + h
            else:
                x = x - n_w + w
                y = y - n_h + h
            ans = [int(x), int(y), int(w), h]
            return ans
        return bbx

    def readBBx(self):
        # get the gt (rows: frame, id, x, y, w, h, conf, class, visibility)
        self.bbx = [[] for i in range(self.seqL + 1)]
        gt = self.gt_dir + 'gt.txt'
        f = open(gt, 'r')
        pre = -1
        for line in f.readlines():
            line = line.strip().split(',')
            if line[7] == '1':
                index = int(line[0])
                id = int(line[1])
                x, y = int(line[2]), int(line[3])
                w, h = int(line[4]), int(line[5])
                conf_score, l, vr = float(line[6]), int(line[7]), float(line[8])

                # sweep the invisible head-bbx from the training data
                if pre != id and vr == 0:
                    continue

                pre = id
                x, y, w, h = self.generator([x, y, w, h])
                self.bbx[index].append([x, y, w, h, id, vr])
        f.close()

    def initBuffer(self):
        if self.show:
            cv2.namedWindow('view', flags=0)
            cv2.namedWindow('crop', flags=0)
        self.f_step = 1  # the index of the next frame in the process
        self.cur = 0  # the index of the current frame in the detections
        self.nxt = 1  # the index of the next frame in the detections
        self.detections = [None, None]  # the buffer storing images: current & next frame
        self.feature(1)

    def setBuffer(self, f):
        self.f_step = f
        self.feature(1)

    def fixBB(self, x, y, w, h, size):
        width, height = size
        w = min(w + x, width)
        h = min(h + y, height)
        x = max(x, 0)
        y = max(y, 0)
        w -= x
        h -= y
        return x, y, w, h
    def IOU(self, Reframe, GTframe):
        """
        Compute the Intersection over Union
        :param Reframe: x, y, w, h
        :param GTframe: x, y, w, h
        :return: Ratio
        """
        if edge_initial == 1:
            return random.random()
        elif edge_initial == 3:
            return 0.5
        x1 = Reframe[0]
        y1 = Reframe[1]
        width1 = Reframe[2]
        height1 = Reframe[3]

        x2 = GTframe[0]
        y2 = GTframe[1]
        width2 = GTframe[2]
        height2 = GTframe[3]

        endx = max(x1 + width1, x2 + width2)
        startx = min(x1, x2)
        width = width1 + width2 - (endx - startx)

        endy = max(y1 + height1, y2 + height2)
        starty = min(y1, y2)
        height = height1 + height2 - (endy - starty)

        if width <= 0 or height <= 0:
            ratio = 0
        else:
            Area = width * height
            Area1 = width1 * height1
            Area2 = width2 * height2
            ratio = Area * 1. / (Area1 + Area2 - Area)
        return ratio

    def getMN(self, m, n):
        ans = [[None for j in range(n)] for i in range(m)]
        for i in range(m):
            Reframe = self.bbx[self.f_step - 1][i]
            for j in range(n):
                GTframe = self.bbx[self.f_step][j]
                p = self.IOU(Reframe, GTframe)
                # 1 - match, 0 - mismatch
                ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0])
        return ans

    def aggregate(self, set):
        if len(set):
            rho = sum(set)
            return rho / len(set)
        print('     The set is empty!')
        return None

    def getApp(self, tag, index):
        cur = self.cur if tag else self.nxt
        if torch.is_tensor(index):
            n = index.numel()
            if n == 0:
                print('The tensor is empty!')
                return None
            if n == 1:
                return self.detections[cur][index[0]][0]
            ans = torch.cat((self.detections[cur][index[0]][0], self.detections[cur][index[1]][0]), dim=0)
            for i in range(2, n):
                ans = torch.cat((ans, self.detections[cur][index[i]][0]), dim=0)
            return ans
        return self.detections[cur][index][0]

    def swapFC(self):
        # swap self.cur and self.nxt via XOR
        self.cur = self.cur ^ self.nxt
        self.nxt = self.cur ^ self.nxt
        self.cur = self.cur ^ self.nxt

    def resnet34(self, img):
        # extract one appearance feature vector for a cropped box
        bbx = ToTensor()(img)
        bbx = bbx.to(self.device)
        bbx = bbx.view(-1, bbx.size(0), bbx.size(1), bbx.size(2))
        ret = self.Appearance(bbx)
        ret = ret.view(1, -1)
        return ret

    def feature(self, tag=0):
        '''
        Get the appearance features of the detections in the current frame
        (self.show displays the cropped & source images).
        :param tag: 1 - write to the current buffer (initiating), 0 - write to the next buffer
        :return: None
        '''
        apps = []
        with torch.no_grad():
            bbx_container = []
            img = load_img(self.img_dir + '%06d.jpg' % self.f_step)  # load the current frame
            for bbx in self.bbx[self.f_step]:
                """
                Conditions needed to be taken into consideration:
                    x, y < 0 and x + w > W, y + h > H
                """
                x, y, w, h, id, vr = bbx
                x, y, w, h = self.fixBB(x, y, w, h, img.size)
                bbx_container.append([x, y, w, h, id, vr])
                crop = img.crop([x, y, x + w, y + h])
                bbx = crop.resize((224, 224), Image.ANTIALIAS)
                ret = self.resnet34(bbx)
                app = ret.data
                apps.append([app, id])

                if self.show:
                    img = np.asarray(img)
                    crop = np.asarray(crop)
                    print('%06d' % self.f_step, id, vr, '***', end=' ')
                    print(w, h, '-', end=' ')
                    print(len(crop[0]), len(crop))
                    cv2.imshow('crop', crop)
                    cv2.imshow('view', img)
                    cv2.waitKey(34)
                    input('Continue?')
                    # cv2.waitKey(34)
        self.bbx[self.f_step] = bbx_container
        if tag:
            self.detections[self.cur] = apps
        else:
            self.detections[self.nxt] = apps
    def initEC(self):
        self.m = len(self.detections[self.cur])
        self.n = len(self.detections[self.nxt])
        self.candidates = []
        self.edges = self.getMN(self.m, self.n)
        self.gts = [[None for j in range(self.n)] for i in range(self.m)]
        self.step_gt = 0.0
        for i in range(self.m):
            for j in range(self.n):
                tag = int(self.detections[self.cur][i][1] == self.detections[self.nxt][j][1])
                self.gts[i][j] = torch.LongTensor([tag])
                self.step_gt += tag * 1.0

        for i in range(self.m):
            for j in range(self.n):
                e = self.edges[i][j]
                gt = self.gts[i][j]
                self.candidates.append([e, gt, i, j])

    def loadNext(self):
        self.f_step += 1
        self.feature()
        self.initEC()
        # print ' The index of the next frame', self.f_step
        # print self.detections[self.cur]
        # print self.detections[self.nxt]

    def __getitem__(self, index):
        return self.candidates[index]

    def __len__(self):
        return len(self.candidates)
--------------------------------------------------------------------------------
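App2/train.py appears in the directory tree above but is not part of this
excerpt. A minimal sketch of how DatasetFromFolder is presumably driven during
training -- the sequence path, batch size, and loop body are illustrative
assumptions, not code from the repository:

    from torch.utils.data import DataLoader
    from dataset import DatasetFromFolder

    train_set = DatasetFromFolder('../MOT/MOT16/train/MOT16-02')  # hypothetical path
    train_set.loadNext()  # build candidate edges between frames 1 and 2
    loader = DataLoader(train_set, batch_size=8, shuffle=True)
    for e, gt, i, j in loader:  # each candidate is [edge, ground truth, row, column]
        pass  # feed e and gt to ephi/uphi and a criterion here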
/App2/global_set.py:
--------------------------------------------------------------------------------
# mot_dataset_dir = '/media/codinglee/DATA/Ubuntu16.04/Desktop/MOT/'
mot_dataset_dir = '../MOT/'

model_dir = 'model/'

u_initial = 1  # 1 - random, 0 - zeros

edge_initial = 0  # 1 - random, 0 - IoU

criterion_s = 0  # 1 - MSELoss, 0 - CrossEntropyLoss

test_gt_det = 0  # 1 - detections of gt, 0 - detections of det

u_update = 1  # 1 - update u when testing, 0 - without updating
if u_update:
    u_dir = '_uupdate'
else:
    u_dir = ''

app_fine_tune = 0  # 1 - fine-tuned appearance model, 0 - pre-trained appearance model
if app_fine_tune:
    fine_tune_dir = mot_dataset_dir + 'Fine-tune_GPU_5_3_60_aug/appearance_19.pth'
    app_dir = 'Finetuned'
else:
    fine_tune_dir = ''
    app_dir = 'Pretrained'

decay = 1.3
decay_dir = '_decay'

f_gap = 5
if f_gap:
    recover_dir = '_Recover'
else:
    recover_dir = '_NoRecover'

tau_threshold = 1.0  # the threshold of the matching cost
tau_dis = 2.0  # the times of the current bbx's scale
gap = 25  # max frame gap for a side connection

show_recovering = 0  # 1 - append a recovering flag (11 columns), 0 - standard 10-column rows

overlap = 0.85  # the IoU kept when jittering gt boxes in dataset.generator
--------------------------------------------------------------------------------
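mot_dataset_dir is expected to point at a MOTChallenge data folder. Pieced
together from the paths used in copyfile.py, dataset.py, and test.py, the
assumed layout is roughly the following (sequence names are examples):

    ../MOT/
    ├── MOT16
    │   ├── train/MOT16-02/{img1/, gt/gt.txt, det/det.txt, seqinfo.ini}
    │   └── test/MOT16-01/...
    └── MOT17
        ├── train/MOT17-02-POI/...
        └── test/MOT17-01-POI/...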
/App2/mot_model.py:
--------------------------------------------------------------------------------
import torch, torchvision
import torch.nn as nn
from global_set import criterion_s

v_num = 512  # Only take the appearance into consideration, and add velocity when the basic model works
u_num = 100
e_num = 1 if criterion_s else 2


class appearance(nn.Module):
    def __init__(self):
        super(appearance, self).__init__()
        features = list(torchvision.models.resnet34(pretrained=True).children())[:-1]
        # print features
        self.features = nn.Sequential(*features)

    def forward(self, x):
        return self.features(x)


class uphi(nn.Module):
    def __init__(self):
        super(uphi, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(u_num + v_num + e_num, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, u_num),
        )

    def forward(self, e, v, u):
        """
        The network which updates the global variable u
        :param e: the aggregation of the probability
        :param v: the aggregation of the vertices
        :param u: the global variable
        """
        # print 'U:', e.size(), v.size(), u.size()
        bs = e.size()[0]
        # tile u so that it matches the batch size
        if bs == 1:
            tmp = u
        else:
            tmp = torch.cat((u, u), dim=0)
            for i in range(2, bs):
                tmp = torch.cat((tmp, u), dim=0)
        x = torch.cat((e, v), dim=1)
        x = torch.cat((x, tmp), dim=1)
        return self.features(x)


class ephi(nn.Module):
    def __init__(self):
        super(ephi, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(u_num + v_num * 2 + e_num, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, e_num),
        )

    def forward(self, e, v1, v2, u):
        """
        The network which updates the probability e
        :param e: the probability between two detections
        :param v1: the sender
        :param v2: the receiver
        :param u: the global variable
        """
        # print 'E:', e.size(), v1.size(), v2.size(), u.size()
        bs = e.size()[0]
        # tile u so that it matches the batch size
        if bs == 1:
            tmp = u
        else:
            tmp = torch.cat((u, u), dim=0)
            for i in range(2, bs):
                tmp = torch.cat((tmp, u), dim=0)
        x = torch.cat((e, v1), dim=1)
        x = torch.cat((x, v2), dim=1)
        x = torch.cat((x, tmp), dim=1)
        return self.features(x)


class vphi(nn.Module):
    def __init__(self):
        super(vphi, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(u_num + v_num * 2 + e_num, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, v_num),
        )

    def forward(self, e, v1, v2, u):
        """
        The network which updates the vertex v (it outputs a refreshed
        receiver appearance of size v_num)
        :param e: the probability between two detections
        :param v1: the sender
        :param v2: the receiver
        :param u: the global variable
        """
        # print 'V:', e.size(), v1.size(), v2.size(), u.size()
        bs = e.size()[0]
        # tile u so that it matches the batch size
        if bs == 1:
            tmp = u
        else:
            tmp = torch.cat((u, u), dim=0)
            for i in range(2, bs):
                tmp = torch.cat((tmp, u), dim=0)
        x = torch.cat((e, v1), dim=1)
        x = torch.cat((x, v2), dim=1)
        x = torch.cat((x, tmp), dim=1)
        return self.features(x)
--------------------------------------------------------------------------------
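A minimal sketch of how these modules compose into one message-passing round,
mirroring the order used in App2/test.py (Ephi1 -> Vphi -> Uphi -> Ephi2); the
batch size and random tensors are illustrative stand-ins for real features:

    import torch
    from mot_model import uphi, ephi, vphi, u_num, v_num, e_num

    Uphi, Vphi, Ephi1, Ephi2 = uphi(), vphi(), ephi(), ephi()
    bs = 4                              # number of candidate edges
    e = torch.rand(bs, e_num)           # edge features (IoU-initialized in dataset.py)
    vs = torch.rand(bs, v_num)          # sender appearance features (ResNet-34)
    vr = torch.rand(bs, v_num)          # receiver appearance features
    u = torch.rand(1, u_num)            # the global variable

    e1 = Ephi1(e, vs, vr, u)            # first edge update
    vr1 = Vphi(e1, vs, vr, u)           # refresh the receiver vertices
    E = e1.mean(dim=0, keepdim=True)    # aggregate, as dataset.aggregate does
    V = torch.cat((vs, vr1)).mean(dim=0, keepdim=True)
    u1 = Uphi(E, V, u)                  # update the global variable
    e2 = Ephi2(e1, vs, vr1, u1)         # rescore the edges with the new u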
/App2/munkres.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-

# Documentation is intended to be processed by Epydoc.

"""
Introduction
============

The Munkres module provides an implementation of the Munkres algorithm
(also called the Hungarian algorithm or the Kuhn-Munkres algorithm),
useful for solving the Assignment Problem.

Assignment Problem
==================

Let *C* be an *n* by *n* matrix representing the costs of each of *n* workers
to perform any of *n* jobs. The assignment problem is to assign jobs to
workers in a way that minimizes the total cost. Since each worker can perform
only one job and each job can be assigned to only one worker the assignments
represent an independent set of the matrix *C*.

One way to generate the optimal set is to create all permutations of
the indexes necessary to traverse the matrix so that no row and column
are used more than once. For instance, given this matrix (expressed in
Python):

    matrix = [[5, 9, 1],
              [10, 3, 2],
              [8, 7, 4]]

You could use this code to generate the traversal indexes:

    def permute(a, results):
        if len(a) == 1:
            results.insert(len(results), a)

        else:
            for i in range(0, len(a)):
                element = a[i]
                a_copy = [a[j] for j in range(0, len(a)) if j != i]
                subresults = []
                permute(a_copy, subresults)
                for subresult in subresults:
                    result = [element] + subresult
                    results.insert(len(results), result)

    results = []
    permute(list(range(len(matrix))), results)  # [0, 1, 2] for a 3x3 matrix

After the call to permute(), the results matrix would look like this:

    [[0, 1, 2],
     [0, 2, 1],
     [1, 0, 2],
     [1, 2, 0],
     [2, 0, 1],
     [2, 1, 0]]

You could then use that index matrix to loop over the original cost matrix
and calculate the smallest cost of the combinations:

    minval = sys.maxsize
    for indexes in results:
        cost = 0
        for row, col in enumerate(indexes):
            cost += matrix[row][col]
        minval = min(cost, minval)

    print(minval)

While this approach works fine for small matrices, it does not scale. It
executes in O(*n*!) time: Calculating the permutations for an *n*\ x\ *n*
matrix requires *n*! operations. For a 12x12 matrix, that's 479,001,600
traversals. Even if you could manage to perform each traversal in just one
millisecond, it would still take more than 133 hours to perform the entire
traversal. A 20x20 matrix would take 2,432,902,008,176,640,000 operations. At
an optimistic millisecond per operation, that's more than 77 million years.

The Munkres algorithm runs in O(*n*\ ^3) time, rather than O(*n*!). This
package provides an implementation of that algorithm.

This version is based on
http://csclab.murraystate.edu/~bob.pilgrim/445/munkres.html

This version was written for Python by Brian Clapper from the algorithm
at the above web site. (The ``Algorithm:Munkres`` Perl version, in CPAN, was
clearly adapted from the same web site.)

Usage
=====

Construct a Munkres object:

    from munkres import Munkres

    m = Munkres()

Then use it to compute the lowest cost assignment from a cost matrix. Here's
a sample program:

    from munkres import Munkres, print_matrix

    matrix = [[5, 9, 1],
              [10, 3, 2],
              [8, 7, 4]]
    m = Munkres()
    indexes = m.compute(matrix)
    print_matrix(matrix, msg='Lowest cost through this matrix:')
    total = 0
    for row, column in indexes:
        value = matrix[row][column]
        total += value
        print('(%d, %d) -> %d' % (row, column, value))
    print('total cost: %d' % total)

Running that program produces:

    Lowest cost through this matrix:
    [5, 9, 1]
    [10, 3, 2]
    [8, 7, 4]
    (0, 0) -> 5
    (1, 1) -> 3
    (2, 2) -> 4
    total cost: 12

The instantiated Munkres object can be used multiple times on different
matrices.
Non-square Cost Matrices
========================

The Munkres algorithm assumes that the cost matrix is square. However, it's
possible to use a rectangular matrix if you first pad it with 0 values to make
it square. This module automatically pads rectangular cost matrices to make
them square.

Notes:

- The module operates on a *copy* of the caller's matrix, so any padding will
  not be seen by the caller.
- The cost matrix must be rectangular or square. An irregular matrix will
  *not* work.

Calculating Profit, Rather than Cost
====================================

The cost matrix is just that: A cost matrix. The Munkres algorithm finds
the combination of elements (one from each row and column) that results in
the smallest cost. It's also possible to use the algorithm to maximize
profit. To do that, however, you have to convert your profit matrix to a
cost matrix. The simplest way to do that is to subtract all elements from a
large value. For example:

    from munkres import Munkres, print_matrix

    matrix = [[5, 9, 1],
              [10, 3, 2],
              [8, 7, 4]]
    cost_matrix = []
    for row in matrix:
        cost_row = []
        for col in row:
            cost_row += [sys.maxsize - col]
        cost_matrix += [cost_row]

    m = Munkres()
    indexes = m.compute(cost_matrix)
    print_matrix(matrix, msg='Highest profit through this matrix:')
    total = 0
    for row, column in indexes:
        value = matrix[row][column]
        total += value
        print('(%d, %d) -> %d' % (row, column, value))

    print('total profit=%d' % total)

Running that program produces:

    Highest profit through this matrix:
    [5, 9, 1]
    [10, 3, 2]
    [8, 7, 4]
    (0, 1) -> 9
    (1, 0) -> 10
    (2, 2) -> 4
    total profit=23

The ``munkres`` module provides a convenience method for creating a cost
matrix from a profit matrix. By default, it calculates the maximum profit
and subtracts every profit from it to obtain a cost. If, however, you
need a more general function, you can provide the
conversion function; but the convenience method takes care of the actual
creation of the matrix:

    import munkres

    cost_matrix = munkres.make_cost_matrix(
        matrix,
        lambda profit: 1000.0 - math.sqrt(profit))

So, the above profit-calculation program can be recast as:

    from munkres import Munkres, print_matrix, make_cost_matrix

    matrix = [[5, 9, 1],
              [10, 3, 2],
              [8, 7, 4]]
    cost_matrix = make_cost_matrix(matrix)
    # cost_matrix == [[5, 1, 9],
    #                 [0, 7, 8],
    #                 [2, 3, 6]]
    m = Munkres()
    indexes = m.compute(cost_matrix)
    print_matrix(matrix, msg='Highest profits through this matrix:')
    total = 0
    for row, column in indexes:
        value = matrix[row][column]
        total += value
        print('(%d, %d) -> %d' % (row, column, value))
    print('total profit=%d' % total)

Disallowed Assignments
======================

You can also mark assignments in your cost or profit matrix as disallowed.
Simply use the munkres.DISALLOWED constant.

    from munkres import Munkres, print_matrix, make_cost_matrix, DISALLOWED

    matrix = [[5, 9, DISALLOWED],
              [10, DISALLOWED, 2],
              [8, 7, 4]]
    cost_matrix = make_cost_matrix(matrix, lambda cost: (sys.maxsize - cost) if
                                   (cost != DISALLOWED) else DISALLOWED)
    m = Munkres()
    indexes = m.compute(cost_matrix)
    print_matrix(matrix, msg='Highest profit through this matrix:')
    total = 0
    for row, column in indexes:
        value = matrix[row][column]
        total += value
        print('(%d, %d) -> %d' % (row, column, value))
    print('total profit=%d' % total)

Running this program produces:

    Highest profit through this matrix:
    [ 5,  9,  D]
    [10,  D,  2]
    [ 8,  7,  4]
    (0, 1) -> 9
    (1, 0) -> 10
    (2, 2) -> 4
    total profit=23
Clapper" 317 | __license__ = "Apache Software License" 318 | 319 | # Constants 320 | class DISALLOWED_OBJ(object): 321 | pass 322 | DISALLOWED = DISALLOWED_OBJ() 323 | DISALLOWED_PRINTVAL = "D" 324 | 325 | # --------------------------------------------------------------------------- 326 | # Exceptions 327 | # --------------------------------------------------------------------------- 328 | 329 | class UnsolvableMatrix(Exception): 330 | """ 331 | Exception raised for unsolvable matrices 332 | """ 333 | pass 334 | 335 | # --------------------------------------------------------------------------- 336 | # Classes 337 | # --------------------------------------------------------------------------- 338 | 339 | class Munkres: 340 | """ 341 | Calculate the Munkres solution to the classical assignment problem. 342 | See the module documentation for usage. 343 | """ 344 | 345 | def __init__(self): 346 | """Create a new instance""" 347 | self.C = None 348 | self.row_covered = [] 349 | self.col_covered = [] 350 | self.n = 0 351 | self.Z0_r = 0 352 | self.Z0_c = 0 353 | self.marked = None 354 | self.path = None 355 | 356 | def make_cost_matrix(profit_matrix, inversion_function): 357 | """ 358 | **DEPRECATED** 359 | 360 | Please use the module function ``make_cost_matrix()``. 361 | """ 362 | import munkres 363 | return munkres.make_cost_matrix(profit_matrix, inversion_function) 364 | 365 | make_cost_matrix = staticmethod(make_cost_matrix) 366 | 367 | def pad_matrix(self, matrix, pad_value=0): 368 | """ 369 | Pad a possibly non-square matrix to make it square. 370 | 371 | :Parameters: 372 | matrix : list of lists 373 | matrix to pad 374 | 375 | pad_value : int 376 | value to use to pad the matrix 377 | 378 | :rtype: list of lists 379 | :return: a new, possibly padded, matrix 380 | """ 381 | max_columns = 0 382 | total_rows = len(matrix) 383 | 384 | for row in matrix: 385 | max_columns = max(max_columns, len(row)) 386 | 387 | total_rows = max(max_columns, total_rows) 388 | 389 | new_matrix = [] 390 | for row in matrix: 391 | row_len = len(row) 392 | new_row = row[:] 393 | if total_rows > row_len: 394 | # Row too short. Pad it. 395 | new_row += [pad_value] * (total_rows - row_len) 396 | new_matrix += [new_row] 397 | 398 | while len(new_matrix) < total_rows: 399 | new_matrix += [[pad_value] * total_rows] 400 | 401 | return new_matrix 402 | 403 | def compute(self, cost_matrix): 404 | """ 405 | Compute the indexes for the lowest-cost pairings between rows and 406 | columns in the database. Returns a list of (row, column) tuples 407 | that can be used to traverse the matrix. 408 | 409 | :Parameters: 410 | cost_matrix : list of lists 411 | The cost matrix. If this cost matrix is not square, it 412 | will be padded with zeros, via a call to ``pad_matrix()``. 413 | (This method does *not* modify the caller's matrix. It 414 | operates on a copy of the matrix.) 415 | 416 | **WARNING**: This code handles square and rectangular 417 | matrices. It does *not* handle irregular matrices. 
# ---------------------------------------------------------------------------
# Classes
# ---------------------------------------------------------------------------

class Munkres:
    """
    Calculate the Munkres solution to the classical assignment problem.
    See the module documentation for usage.
    """

    def __init__(self):
        """Create a new instance"""
        self.C = None
        self.row_covered = []
        self.col_covered = []
        self.n = 0
        self.Z0_r = 0
        self.Z0_c = 0
        self.marked = None
        self.path = None

    def make_cost_matrix(profit_matrix, inversion_function):
        """
        **DEPRECATED**

        Please use the module function ``make_cost_matrix()``.
        """
        import munkres
        return munkres.make_cost_matrix(profit_matrix, inversion_function)

    make_cost_matrix = staticmethod(make_cost_matrix)

    def pad_matrix(self, matrix, pad_value=0):
        """
        Pad a possibly non-square matrix to make it square.

        :Parameters:
            matrix : list of lists
                matrix to pad

            pad_value : int
                value to use to pad the matrix

        :rtype: list of lists
        :return: a new, possibly padded, matrix
        """
        max_columns = 0
        total_rows = len(matrix)

        for row in matrix:
            max_columns = max(max_columns, len(row))

        total_rows = max(max_columns, total_rows)

        new_matrix = []
        for row in matrix:
            row_len = len(row)
            new_row = row[:]
            if total_rows > row_len:
                # Row too short. Pad it.
                new_row += [pad_value] * (total_rows - row_len)
            new_matrix += [new_row]

        while len(new_matrix) < total_rows:
            new_matrix += [[pad_value] * total_rows]

        return new_matrix

    def compute(self, cost_matrix):
        """
        Compute the indexes for the lowest-cost pairings between rows and
        columns in the matrix. Returns a list of (row, column) tuples
        that can be used to traverse the matrix.

        :Parameters:
            cost_matrix : list of lists
                The cost matrix. If this cost matrix is not square, it
                will be padded with zeros, via a call to ``pad_matrix()``.
                (This method does *not* modify the caller's matrix. It
                operates on a copy of the matrix.)

                **WARNING**: This code handles square and rectangular
                matrices. It does *not* handle irregular matrices.

        :rtype: list
        :return: A list of ``(row, column)`` tuples that describe the lowest
                 cost path through the matrix

        """
        self.C = self.pad_matrix(cost_matrix)
        self.n = len(self.C)
        self.original_length = len(cost_matrix)
        self.original_width = len(cost_matrix[0])
        self.row_covered = [False for i in range(self.n)]
        self.col_covered = [False for i in range(self.n)]
        self.Z0_r = 0
        self.Z0_c = 0
        self.path = self.__make_matrix(self.n * 2, 0)
        self.marked = self.__make_matrix(self.n, 0)

        done = False
        step = 1

        steps = { 1 : self.__step1,
                  2 : self.__step2,
                  3 : self.__step3,
                  4 : self.__step4,
                  5 : self.__step5,
                  6 : self.__step6 }

        while not done:
            try:
                func = steps[step]
                step = func()
            except KeyError:
                done = True

        # Look for the starred columns
        results = []
        for i in range(self.original_length):
            for j in range(self.original_width):
                if self.marked[i][j] == 1:
                    results += [(i, j)]

        return results
    def __copy_matrix(self, matrix):
        """Return an exact copy of the supplied matrix"""
        return copy.deepcopy(matrix)

    def __make_matrix(self, n, val):
        """Create an *n*x*n* matrix, populating it with the specific value."""
        matrix = []
        for i in range(n):
            matrix += [[val for j in range(n)]]
        return matrix

    def __step1(self):
        """
        For each row of the matrix, find the smallest element and
        subtract it from every element in its row. Go to Step 2.
        """
        C = self.C
        n = self.n
        for i in range(n):
            vals = [x for x in self.C[i] if x is not DISALLOWED]
            if len(vals) == 0:
                # All values in this row are DISALLOWED. This matrix is
                # unsolvable.
                raise UnsolvableMatrix(
                    "Row {0} is entirely DISALLOWED.".format(i)
                )
            minval = min(vals)
            # Find the minimum value for this row and subtract that minimum
            # from every element in the row.
            for j in range(n):
                if self.C[i][j] is not DISALLOWED:
                    self.C[i][j] -= minval
        return 2

    def __step2(self):
        """
        Find a zero (Z) in the resulting matrix. If there is no starred
        zero in its row or column, star Z. Repeat for each element in the
        matrix. Go to Step 3.
        """
        n = self.n
        for i in range(n):
            for j in range(n):
                if (self.C[i][j] == 0) and \
                        (not self.col_covered[j]) and \
                        (not self.row_covered[i]):
                    self.marked[i][j] = 1
                    self.col_covered[j] = True
                    self.row_covered[i] = True
                    break

        self.__clear_covers()
        return 3

    def __step3(self):
        """
        Cover each column containing a starred zero. If K columns are
        covered, the starred zeros describe a complete set of unique
        assignments. In this case, Go to DONE, otherwise, Go to Step 4.
        """
        n = self.n
        count = 0
        for i in range(n):
            for j in range(n):
                if self.marked[i][j] == 1 and not self.col_covered[j]:
                    self.col_covered[j] = True
                    count += 1

        if count >= n:
            step = 7  # done
        else:
            step = 4

        return step

    def __step4(self):
        """
        Find a noncovered zero and prime it. If there is no starred zero
        in the row containing this primed zero, Go to Step 5. Otherwise,
        cover this row and uncover the column containing the starred
        zero. Continue in this manner until there are no uncovered zeros
        left. Save the smallest uncovered value and Go to Step 6.
        """
        step = 0
        done = False
        row = 0
        col = 0
        star_col = -1
        while not done:
            (row, col) = self.__find_a_zero(row, col)
            if row < 0:
                done = True
                step = 6
            else:
                self.marked[row][col] = 2
                star_col = self.__find_star_in_row(row)
                if star_col >= 0:
                    col = star_col
                    self.row_covered[row] = True
                    self.col_covered[col] = False
                else:
                    done = True
                    self.Z0_r = row
                    self.Z0_c = col
                    step = 5

        return step
611 | """ 612 | minval = self.__find_smallest() 613 | events = 0 # track actual changes to matrix 614 | for i in range(self.n): 615 | for j in range(self.n): 616 | if self.C[i][j] is DISALLOWED: 617 | continue 618 | if self.row_covered[i]: 619 | self.C[i][j] += minval 620 | events += 1 621 | if not self.col_covered[j]: 622 | self.C[i][j] -= minval 623 | events += 1 624 | if self.row_covered[i] and not self.col_covered[j]: 625 | events -= 2 # change reversed, no real difference 626 | if (events == 0): 627 | raise UnsolvableMatrix("Matrix cannot be solved!") 628 | return 4 629 | 630 | def __find_smallest(self): 631 | """Find the smallest uncovered value in the matrix.""" 632 | minval = sys.maxsize 633 | for i in range(self.n): 634 | for j in range(self.n): 635 | if (not self.row_covered[i]) and (not self.col_covered[j]): 636 | if self.C[i][j] is not DISALLOWED and minval > self.C[i][j]: 637 | minval = self.C[i][j] 638 | return minval 639 | 640 | 641 | def __find_a_zero(self, i0=0, j0=0): 642 | """Find the first uncovered element with value 0""" 643 | row = -1 644 | col = -1 645 | i = i0 646 | n = self.n 647 | done = False 648 | 649 | while not done: 650 | j = j0 651 | while True: 652 | if (self.C[i][j] == 0) and \ 653 | (not self.row_covered[i]) and \ 654 | (not self.col_covered[j]): 655 | row = i 656 | col = j 657 | done = True 658 | j = (j + 1) % n 659 | if j == j0: 660 | break 661 | i = (i + 1) % n 662 | if i == i0: 663 | done = True 664 | 665 | return (row, col) 666 | 667 | def __find_star_in_row(self, row): 668 | """ 669 | Find the first starred element in the specified row. Returns 670 | the column index, or -1 if no starred element was found. 671 | """ 672 | col = -1 673 | for j in range(self.n): 674 | if self.marked[row][j] == 1: 675 | col = j 676 | break 677 | 678 | return col 679 | 680 | def __find_star_in_col(self, col): 681 | """ 682 | Find the first starred element in the specified row. Returns 683 | the row index, or -1 if no starred element was found. 684 | """ 685 | row = -1 686 | for i in range(self.n): 687 | if self.marked[i][col] == 1: 688 | row = i 689 | break 690 | 691 | return row 692 | 693 | def __find_prime_in_row(self, row): 694 | """ 695 | Find the first prime element in the specified row. Returns 696 | the column index, or -1 if no starred element was found. 697 | """ 698 | col = -1 699 | for j in range(self.n): 700 | if self.marked[row][j] == 2: 701 | col = j 702 | break 703 | 704 | return col 705 | 706 | def __convert_path(self, path, count): 707 | for i in range(count+1): 708 | if self.marked[path[i][0]][path[i][1]] == 1: 709 | self.marked[path[i][0]][path[i][1]] = 0 710 | else: 711 | self.marked[path[i][0]][path[i][1]] = 1 712 | 713 | def __clear_covers(self): 714 | """Clear all covered matrix cells""" 715 | for i in range(self.n): 716 | self.row_covered[i] = False 717 | self.col_covered[i] = False 718 | 719 | def __erase_primes(self): 720 | """Erase all prime markings""" 721 | for i in range(self.n): 722 | for j in range(self.n): 723 | if self.marked[i][j] == 2: 724 | self.marked[i][j] = 0 725 | 726 | # --------------------------------------------------------------------------- 727 | # Functions 728 | # --------------------------------------------------------------------------- 729 | 730 | def make_cost_matrix(profit_matrix, inversion_function=None): 731 | """ 732 | Create a cost matrix from a profit matrix by calling 733 | 'inversion_function' to invert each value. 
    def __find_smallest(self):
        """Find the smallest uncovered value in the matrix."""
        minval = sys.maxsize
        for i in range(self.n):
            for j in range(self.n):
                if (not self.row_covered[i]) and (not self.col_covered[j]):
                    if self.C[i][j] is not DISALLOWED and minval > self.C[i][j]:
                        minval = self.C[i][j]
        return minval


    def __find_a_zero(self, i0=0, j0=0):
        """Find the first uncovered element with value 0"""
        row = -1
        col = -1
        i = i0
        n = self.n
        done = False

        while not done:
            j = j0
            while True:
                if (self.C[i][j] == 0) and \
                        (not self.row_covered[i]) and \
                        (not self.col_covered[j]):
                    row = i
                    col = j
                    done = True
                j = (j + 1) % n
                if j == j0:
                    break
            i = (i + 1) % n
            if i == i0:
                done = True

        return (row, col)

    def __find_star_in_row(self, row):
        """
        Find the first starred element in the specified row. Returns
        the column index, or -1 if no starred element was found.
        """
        col = -1
        for j in range(self.n):
            if self.marked[row][j] == 1:
                col = j
                break

        return col

    def __find_star_in_col(self, col):
        """
        Find the first starred element in the specified column. Returns
        the row index, or -1 if no starred element was found.
        """
        row = -1
        for i in range(self.n):
            if self.marked[i][col] == 1:
                row = i
                break

        return row

    def __find_prime_in_row(self, row):
        """
        Find the first primed element in the specified row. Returns
        the column index, or -1 if no primed element was found.
        """
        col = -1
        for j in range(self.n):
            if self.marked[row][j] == 2:
                col = j
                break

        return col

    def __convert_path(self, path, count):
        for i in range(count+1):
            if self.marked[path[i][0]][path[i][1]] == 1:
                self.marked[path[i][0]][path[i][1]] = 0
            else:
                self.marked[path[i][0]][path[i][1]] = 1

    def __clear_covers(self):
        """Clear all covered matrix cells"""
        for i in range(self.n):
            self.row_covered[i] = False
            self.col_covered[i] = False

    def __erase_primes(self):
        """Erase all prime markings"""
        for i in range(self.n):
            for j in range(self.n):
                if self.marked[i][j] == 2:
                    self.marked[i][j] = 0

# ---------------------------------------------------------------------------
# Functions
# ---------------------------------------------------------------------------

def make_cost_matrix(profit_matrix, inversion_function=None):
    """
    Create a cost matrix from a profit matrix by calling
    'inversion_function' to invert each value. The inversion
    function must take one numeric argument (of any type) and return
    another numeric argument which is presumed to be the cost inverse
    of the original profit. In case the inversion function is not provided,
    calculate it as max(matrix) - matrix.

    This is a static method. Call it like this:

    .. python:

        cost_matrix = Munkres.make_cost_matrix(matrix, inversion_func)

    For example:

    .. python:

        cost_matrix = Munkres.make_cost_matrix(matrix, lambda x : sys.maxsize - x)

    :Parameters:
        profit_matrix : list of lists
            The matrix to convert from a profit to a cost matrix

        inversion_function : function
            The function to use to invert each entry in the profit matrix.
            In case it is not provided, calculate it as max(matrix) - matrix.

    :rtype: list of lists
    :return: The converted matrix
    """
    if not inversion_function:
        maximum = max(max(row) for row in profit_matrix)
        inversion_function = lambda x: maximum - x

    cost_matrix = []
    for row in profit_matrix:
        cost_matrix.append([inversion_function(value) for value in row])
    return cost_matrix

def print_matrix(matrix, msg=None):
    """
    Convenience function: Displays the contents of a matrix of integers.

    :Parameters:
        matrix : list of lists
            Matrix to print

        msg : str
            Optional message to print before displaying the matrix
    """
    import math

    if msg is not None:
        print(msg)

    # Calculate the appropriate format width.
    width = 0
    for row in matrix:
        for val in row:
            if val is DISALLOWED:
                val = DISALLOWED_PRINTVAL
            width = max(width, len(str(val)))

    # Make the format string
    format = ('%%%d' % width)

    # Print the matrix
    for row in matrix:
        sep = '['
        for val in row:
            if val is DISALLOWED:
                formatted = ((format + 's') % DISALLOWED_PRINTVAL)
            else: formatted = ((format + 'd') % val)
            sys.stdout.write(sep + formatted)
            sep = ', '
        sys.stdout.write(']\n')

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def samples():

    matrices = [
        # Square
        ([[400, 150, 400],
          [400, 450, 600],
          [300, 225, 300]],
         850),  # expected cost

        # Rectangular variant
        ([[400, 150, 400, 1],
          [400, 450, 600, 2],
          [300, 225, 300, 3]],
         452),  # expected cost


        # Square
        ([[10, 10, 8],
          [9, 8, 1],
          [9, 7, 4]],
         18),

        # Rectangular variant
        ([[10, 10, 8, 11],
          [9, 8, 1, 1],
          [9, 7, 4, 10]],
         15),

        # Rectangular with DISALLOWED
        ([[4, 5, 6, DISALLOWED],
          [1, 9, 12, 11],
          [DISALLOWED, 5, 4, DISALLOWED],
          [12, 12, 12, 10]],
         20),

        # DISALLOWED to force pairings
        ([[1, DISALLOWED, DISALLOWED, DISALLOWED],
          [DISALLOWED, 2, DISALLOWED, DISALLOWED],
          [DISALLOWED, DISALLOWED, 3, DISALLOWED],
          [DISALLOWED, DISALLOWED, DISALLOWED, 4]],
         10)]

    m = Munkres()
    for cost_matrix, expected_total in matrices:
        print_matrix(cost_matrix, msg='cost matrix')
        indexes = m.compute(cost_matrix)
        total_cost = 0
        for r, c in indexes:
            x = cost_matrix[r][c]
            total_cost += x
            print('(%d, %d) -> %d' % (r, c, x))
        print('lowest cost=%d' % total_cost)
        assert expected_total == total_cost

# samples()
--------------------------------------------------------------------------------
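test.py feeds Munkres a rectangular matching-cost matrix ret (rows: tracked
boxes in the current frame, columns: detections in the next frame), where
tau_threshold = 1.0 marks pairs already ruled out by the distance gate. A
self-contained illustration with made-up costs:

    from munkres import Munkres

    ret = [[0.10, 0.90, 1.00],
           [0.80, 0.15, 1.00]]  # 1.00 = gated out (see test_dataset.getRet)
    pairs = Munkres().compute(ret)  # compute() pads a copy to square
    matches = [(i, j) for (i, j) in pairs if ret[i][j] < 1.0]
    print(matches)  # [(0, 0), (1, 1)]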
/App2/test.py:
--------------------------------------------------------------------------------
# from __future__ import print_function
import numpy as np
from munkres import Munkres
import torch.nn.functional as F
import time, os, shutil
from global_set import test_gt_det, \
    tau_threshold, gap, f_gap, show_recovering, decay, u_update, mot_dataset_dir, model_dir
from test_dataset import DatasetFromFolder
from mot_model import *

torch.manual_seed(123)
np.random.seed(123)


def deleteDir(del_dir):
    shutil.rmtree(del_dir)


year = 17

type = ''
t_dir = ''  # the dir of the tracking results
sequence_dir = ''  # the dir of the testing video sequence

seqs = [2, 4, 5, 9, 10, 11, 13]  # the set of sequences
lengths = [600, 1050, 837, 525, 654, 900, 750]  # the length of each sequence

test_seqs = [1, 3, 6, 7, 8, 12, 14]
test_lengths = [450, 1500, 1194, 500, 625, 900, 750]

tt_tag = 1  # 1 - test, 0 - train

tau_conf_score = 0.0


class GN():
    def __init__(self, seq_index, seq_len, cuda=True):
        '''
        Evaluating with the MotMetrics
        :param seq_index: the number of the sequence
        :param seq_len: the length of the sequence
        :param cuda: True - GPU, False - CPU
        '''
        self.seq_index = seq_index
        self.hungarian = Munkres()
        self.device = torch.device("cuda" if cuda else "cpu")
        self.seq_len = seq_len
        self.missingCounter = 0
        self.sideConnection = 0

        print('     Loading the model...')
        self.loadModel()

        self.out_dir = t_dir + 'motmetrics_%s/' % (type)
        print('     The dir of the output file:', self.out_dir)

        if not os.path.exists(self.out_dir):
            os.mkdir(self.out_dir)
        else:
            deleteDir(self.out_dir)
            os.mkdir(self.out_dir)
        self.initOut()

    def initOut(self):
        print('     Loading Data...')
        self.train_set = DatasetFromFolder(sequence_dir, mot_dataset_dir + 'MOT16/test/MOT16-%02d' % self.seq_index,
                                           tau_conf_score)

        detection_dir = self.out_dir + 'res_det.txt'
        res_training = self.out_dir + 'res.txt'  # the tracking results
        self.createTxt(detection_dir)
        self.createTxt(res_training)
        self.copyLines(self.seq_index, 1, detection_dir, self.seq_len, 1)

        self.evaluation(1, self.seq_len, detection_dir, res_training)

    def getSeqL(self, info):
        # get the length of the sequence
        f = open(info, 'r')
        f.readline()
        for line in f.readlines():
            line = line.strip().split('=')
            if line[0] == 'seqLength':
                seqL = int(line[1])
        f.close()
        return seqL

    def copyLines(self, seq, head, gt_seq, tail=-1, tag=0):
        '''
        Copy the ground truth within [head, tail]
        :param seq: the number of the sequence
        :param head: the head frame number
        :param gt_seq: the dir of the output file
        :param tail: the tail frame number (-1 - to the end of the sequence)
        :param tag: 1 - copy the detections, 0 - copy the ground truth
        :return: the length of the copied clip
        '''
        if tt_tag:
            basic_dir = mot_dataset_dir + 'MOT%d/test/MOT%d-%02d-%s/' % (year, year, seq, type)
        else:
            basic_dir = mot_dataset_dir + 'MOT%d/train/MOT%d-%02d-%s/' % (year, year, seq, type)
        print('     Testing on', basic_dir, 'Length:', self.seq_len)
        seqL = tail if tail != -1 else self.getSeqL(basic_dir + 'seqinfo.ini')

        det_dir = 'gt/gt_det.txt' if test_gt_det else 'det/det.txt'
        seq_dir = basic_dir + ('gt/gt.txt' if tag == 0 else det_dir)
        inStream = open(seq_dir, 'r')

        outStream = open(gt_seq, 'w')
        for line in inStream.readlines():
            line = line.strip()
            attrs = line.split(',')
            f_num = int(attrs[0])
            if f_num >= head and f_num <= seqL:
                outStream.write(line + '\n')
        outStream.close()

        inStream.close()
        return seqL

    def createTxt(self, out_file):
        f = open(out_file, 'w')
        f.close()

    def loadModel(self):
        tail = 13
        self.Uphi = torch.load(model_dir + 'uphi_%02d.pth' % tail).to(self.device)
        self.Vphi = torch.load(model_dir + 'vphi_%02d.pth' % tail).to(self.device)
        self.Ephi1 = torch.load(model_dir + 'ephi1_%02d.pth' % tail).to(self.device)
        self.Ephi2 = torch.load(model_dir + 'ephi2_%02d.pth' % tail).to(self.device)
        self.u = torch.load(model_dir + 'u_%02d.pth' % tail).to(self.device)

    def swapFC(self):
        # swap self.cur and self.nxt via XOR
        self.cur = self.cur ^ self.nxt
        self.nxt = self.cur ^ self.nxt
        self.cur = self.cur ^ self.nxt
    def linearModel(self, out, attr1, attr2):
        # linearly interpolate the box across the t missing frames
        # print 'I got you! *.*'
        t = attr1[-1]
        self.sideConnection += 1
        if t > f_gap:
            return
        frame = int(attr1[0])
        x1, y1, w1, h1 = float(attr1[2]), float(attr1[3]), float(attr1[4]), float(attr1[5])
        x2, y2, w2, h2 = float(attr2[2]), float(attr2[3]), float(attr2[4]), float(attr2[5])

        x_delta = (x2 - x1) / t
        y_delta = (y2 - y1) / t
        w_delta = (w2 - w1) / t
        h_delta = (h2 - h1) / t

        for i in range(1, t):
            frame += 1
            x1 += x_delta
            y1 += y_delta
            w1 += w_delta
            h1 += h_delta
            attr1[0] = str(frame)
            attr1[2] = str(x1)
            attr1[3] = str(y1)
            attr1[4] = str(w1)
            attr1[5] = str(h1)
            line = ''
            for attr in attr1[:-1]:
                line += attr + ','
            if show_recovering:
                line += '1'
            else:
                line = line[:-1]
            out.write(line + '\n')
        self.missingCounter += t - 1

    def evaluation(self, head, tail, gtFile, outFile):
        '''
        Evaluation on dets
        :param head: the head frame number
        :param tail: the tail frame number
        :param gtFile: the ground truth file name
        :param outFile: the name of the output file
        :return: None
        '''
        with torch.no_grad():
            gtIn = open(gtFile, 'r')
            self.cur, self.nxt = 0, 1
            line_con = [[], []]
            id_con = [[], []]
            id_step = 1

            step = head + self.train_set.setBuffer(head)
            while step < tail:
                # print ' ', step
                print(step, end=' ')
                if step % 100 == 0:
                    print('')
                t_gap = self.train_set.loadNext()
                step += t_gap
                # print 'F', 't_gap=%d'%t_gap,
                # print 'Fo',

                m = self.train_set.m
                n = self.train_set.n
                # print m, n,
                if n == 0:
                    print('There is no detection in the rest of the sequence!')
                    break

                if id_step == 1:
                    out = open(outFile, 'a')
                    i = 0
                    while i < m:
                        attrs = gtIn.readline().strip().split(',')
                        if float(attrs[6]) >= tau_conf_score:
                            attrs.append(1)
                            attrs[1] = str(id_step)
                            line = ''
                            for attr in attrs[:-1]:
                                line += attr + ','
                            if show_recovering:
                                line += '0'
                            else:
                                line = line[:-1]
                            out.write(line + '\n')
                            line_con[self.cur].append(attrs)
                            id_con[self.cur].append(id_step)
                            id_step += 1
                            i += 1
                    out.close()
                # print '0.0',
                i = 0
                while i < n:
                    attrs = gtIn.readline().strip().split(',')
                    if float(attrs[6]) >= tau_conf_score:
                        attrs.append(1)
                        line_con[self.nxt].append(attrs)
                        id_con[self.nxt].append(-1)
                        i += 1
                # update the edges
                # print 'T',
                candidates = []
                E_CON, V_CON = [], []
                ret = self.train_set.getRet()

                decay_tag = [0 for i in range(m)]
                for i in range(m):
                    for j in range(n):
                        if ret[i][j] == 0:
                            decay_tag[i] += 1

                for edge in self.train_set.candidates:
                    e, vs_index, vr_index = edge
                    e = e.view(1, -1).to(self.device)
                    vs = self.train_set.getApp(1, vs_index)
                    vr = self.train_set.getApp(0, vr_index)

                    e1 = self.Ephi1(e, vs, vr, self.u)
                    vr1 = self.Vphi(e1, vs, vr, self.u)
                    candidates.append((e1, vs, vr1, vs_index, vr_index))
                    E_CON.append(e1)
                    V_CON.append(vs)
                    V_CON.append(vr1)

                E = self.train_set.aggregate(E_CON).view(1, -1)
                V = self.train_set.aggregate(V_CON).view(1, -1)
                u1 = self.Uphi(E, V, self.u)

                for iteration in candidates:
                    e1, vs, vr1, vs_index, vr_index = iteration
                    if ret[vs_index][vr_index] == tau_threshold:
                        continue

                    e2 = self.Ephi2(e1, vs, vr1, u1)
                    self.train_set.edges[vs_index][vr_index] = e1.data.view(-1)

                    tmp = F.softmax(e2, dim=1)
                    tmp = tmp.cpu().data.numpy()[0]

                    t = line_con[self.cur][vs_index][-1]
                    # ret[vs_index][vr_index] = float(tmp[0]) * pow(decay, t-1)
                    if decay_tag[vs_index] > 0:
                        ret[vs_index][vr_index] = min(float(tmp[0]) * pow(decay, t - 1), 1.0)
                    else:
                        ret[vs_index][vr_index] = float(tmp[0])

                # for j in ret:
                #     print j
                # print ret
                results = self.hungarian.compute(ret)

                out = open(outFile, 'a')
                nxt = self.train_set.nxt
                for (i, j) in results:
                    # print (i, j)
                    if ret[i][j] >= tau_threshold:
                        continue
                    e1 = self.train_set.edges[i][j].view(1, -1).to(self.device)
                    vs = self.train_set.getApp(1, i)
                    vr = self.train_set.getApp(0, j)

                    vr1 = self.Vphi(e1, vs, vr, self.u)
                    self.train_set.detections[nxt][j][0] = vr1.data
                    e2 = self.Ephi2(e1, vs, vr1, u1)
                    self.train_set.edges[i][j] = e2.data.view(-1)

                    id = id_con[self.cur][i]
                    id_con[self.nxt][j] = id
                    attr1 = line_con[self.cur][i]
                    attr2 = line_con[self.nxt][j]
                    # print attrs
                    attr2[1] = str(id)
                    if attr1[-1] > 1:
                        # for the missing detections
                        self.linearModel(out, attr1, attr2)
                    line = ''
                    for attr in attr2[:-1]:
                        line += attr + ','
                    if show_recovering:
                        line += '0'
                    else:
                        line = line[:-1]
                    out.write(line + '\n')

                if u_update:
                    self.u = u1.data

                for i in range(n):
                    if id_con[self.nxt][i] == -1:
                        id_con[self.nxt][i] = id_step
                        attrs = line_con[self.nxt][i]
                        attrs[1] = str(id_step)
                        line = ''
                        for attr in attrs[:-1]:
                            line += attr + ','
                        if show_recovering:
                            line += '0'
                        else:
                            line = line[:-1]
                        out.write(line + '\n')
                        id_step += 1
                out.close()
                # For missing & Occlusion
                index = 0
                for (i, j) in results:
                    while i != index:
                        attrs = line_con[self.cur][index]
                        # print '*', attrs, '*'
                        if attrs[-1] + t_gap <= gap:
                            attrs[-1] += t_gap
                            line_con[self.nxt].append(attrs)
                            id_con[self.nxt].append(id_con[self.cur][index])
                            self.train_set.moveApp(index)
                        index += 1

                    if ret[i][j] >= tau_threshold:
                        attrs = line_con[self.cur][index]
                        # print '*', attrs, '*'
                        if attrs[-1] + t_gap <= gap:
                            attrs[-1] += t_gap
                            line_con[self.nxt].append(attrs)
                            id_con[self.nxt].append(id_con[self.cur][index])
                            self.train_set.moveApp(index)
                    index += 1
                while index < m:
                    attrs = line_con[self.cur][index]
                    # print '*', attrs, '*'
                    if attrs[-1] + t_gap <= gap:
                        attrs[-1] += t_gap
                        line_con[self.nxt].append(attrs)
                        id_con[self.nxt].append(id_con[self.cur][index])
                        self.train_set.moveApp(index)
                    index += 1

                line_con[self.cur] = []
                id_con[self.cur] = []
                # print 'Results',
                self.train_set.swapFC()
                self.swapFC()
            # print '^.^|', step, tail
            gtIn.close()


if __name__ == '__main__':
    try:
        types = [['POI', 0.7]]
        # types = [['DPM', -0.6], ['SDP', 0.5], ['FRCNN', 0.5]]
        for t in types:
            type, tau_conf_score = t
            head = time.time()

            f_dir = 'results/'
            if not os.path.exists(f_dir):
                os.mkdir(f_dir)

            for i in range(len(seqs)):
                if tt_tag:
                    seq_index = test_seqs[i]
                    seq_len = test_lengths[i]
                else:
                    seq_index = seqs[i]
                    seq_len = lengths[i]

                print('The sequence:', seq_index, '- The length of the training data:', seq_len)

                s_dir = f_dir + '%02d/' % seq_index
                if not os.path.exists(s_dir):
                    os.mkdir(s_dir)
                    print(s_dir, 'created!')

                t_dir = s_dir + '%d/' % seq_len
                if not os.path.exists(t_dir):
                    os.mkdir(t_dir)
                    print(t_dir, 'created!')

                if tt_tag:
                    seq_dir = 'MOT%d-%02d-%s' % (year, test_seqs[i], type)
                    sequence_dir = mot_dataset_dir + 'MOT%d/test/' % year + seq_dir
                    print('    ', sequence_dir)

                    start = time.time()
                    print('     Evaluating Graph Network...')
                    gn = GN(test_seqs[i], test_lengths[i])
                else:
                    seq_dir = 'MOT%d-%02d-%s' % (year, seqs[i], type)
                    sequence_dir = mot_dataset_dir + 'MOT%d/train/' % year + seq_dir
                    print('    ', sequence_dir)

                    start = time.time()
                    print('     Evaluating Graph Network...')
                    gn = GN(seqs[i], lengths[i])
                print('     Recovered missing detections:', gn.missingCounter)
                print('     The number of sideConnections:', gn.sideConnection)
                print('Time consuming:', (time.time() - start) / 60.0)
            print('Time consuming:', (time.time() - head) / 60.0)
    except KeyboardInterrupt:
        print('Time consuming:', time.time() - start)
        print('')
        print('-' * 90)
        print('Exiting from training early.')
--------------------------------------------------------------------------------
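Each row that evaluation() writes to res.txt keeps the 10-column MOT layout of
the input det.txt, with the second field replaced by the assigned track id;
when show_recovering is set, an 11th flag column marks interpolated boxes.
The following row is illustrative, not taken from real output:

    <frame>,<id>,<x>,<y>,<w>,<h>,<conf>,-1,-1,-1
    42,7,912.5,241.0,56.0,168.0,0.94,-1,-1,-1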
/App2/test_dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import cv2, random, torch, shutil, os
import numpy as np
from math import *
from PIL import Image
import torch.nn.functional as F
from mot_model import appearance
from global_set import edge_initial, test_gt_det, tau_dis, app_fine_tune, fine_tune_dir, tau_threshold
from torchvision.transforms import ToTensor


def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img


class DatasetFromFolder(data.Dataset):
    def __init__(self, part, part_I, tau, cuda=True, show=0):
        super(DatasetFromFolder, self).__init__()
        self.dir = part
        self.cleanPath(part)
        self.img_dir = part_I + '/img1/'
        self.gt_dir = part + '/gt/'
        self.det_dir = part + '/det/'
        self.device = torch.device("cuda" if cuda else "cpu")
        self.tau_conf_score = tau
        self.show = show

        self.loadAModel()
        self.getSeqL()
        if test_gt_det:
            self.readBBx_gt()
        else:
            self.readBBx_det()
        self.initBuffer()

    def cleanPath(self, part):
        if os.path.exists(part + '/gts/'):
            shutil.rmtree(part + '/gts/')
        if os.path.exists(part + '/dets/'):
            shutil.rmtree(part + '/dets/')

    def loadAModel(self):
        if app_fine_tune:
            self.Appearance = torch.load(fine_tune_dir)
        else:
            self.Appearance = appearance()
        self.Appearance.to(self.device)
        self.Appearance.eval()  # fix the BatchNorm layers

    def getSeqL(self):
        # get the length of the sequence
        info = self.dir + '/seqinfo.ini'
        f = open(info, 'r')
        f.readline()
        for line in f.readlines():
            line = line.strip().split('=')
            if line[0] == 'seqLength':
                self.seqL = int(line[1])
        f.close()
        # print 'The length of the sequence:', self.seqL

    def readBBx_gt(self):
        # get the gt (rows: frame, id, x, y, w, h, conf, class, visibility)
        self.bbx = [[] for i in range(self.seqL + 1)]
        gt = self.gt_dir + 'gt.txt'
        f = open(gt, 'r')
        pre = -1
        for line in f.readlines():
            line = line.strip().split(',')
            if line[7] == '1':
                index = int(line[0])
                id = int(line[1])
                x, y = int(line[2]), int(line[3])
                w, h = int(line[4]), int(line[5])
                conf_score, l, vr = float(line[6]), int(line[7]), float(line[8])

                # sweep the invisible head-bbx from the training data
                if pre != id and vr == 0:
                    continue

                pre = id
                self.bbx[index].append([x, y, w, h, id, conf_score, vr])
        f.close()

        gt_out = open(self.gt_dir + 'gt_det.txt', 'w')
        for index in range(1, self.seqL + 1):
            for bbx in self.bbx[index]:
                x, y, w, h, id, conf_score, vr = bbx
                gt_out.write('%d,-1,%d,%d,%d,%d,%f,-1,-1,-1\n' % (index, x, y, w, h, conf_score))
        gt_out.close()

    def readBBx_det(self):
        # get the detections
        self.bbx = [[] for i in range(self.seqL + 1)]
        det = self.det_dir + 'det.txt'
        f = open(det, 'r')
        for line in f.readlines():
            line = line.strip().split(',')
            index = int(line[0])
            id = int(line[1])
            x, y = int(float(line[2])), int(float(line[3]))
            w, h = int(float(line[4])), int(float(line[5]))
            conf_score = float(line[6])
            if conf_score >= self.tau_conf_score:
                self.bbx[index].append([x, y, w, h, conf_score])
        f.close()

    def initBuffer(self):
        if self.show:
            cv2.namedWindow('view', flags=0)
            cv2.namedWindow('crop', flags=0)
        self.f_step = 1  # the index of the next frame in the process
        self.cur = 0  # the index of the current frame in the detections
        self.nxt = 1  # the index of the next frame in the detections
        self.detections = [None, None]  # the buffer storing images: current & next frame
        self.feature(1)

    def setBuffer(self, f):
        # step forward until a frame with at least one detection is found
        self.m = 0
        counter = -1
        while self.m == 0:
            counter += 1
            self.f_step = f + counter
            self.feature(1)
            self.m = len(self.detections[self.cur])
        if counter > 0:
            print('     Empty in setBuffer:', counter)
        return counter

    def fixBB(self, x, y, w, h, size):
        width, height = size
        w = min(w + x, width)
        h = min(h + y, height)
        x = max(x, 0)
        y = max(y, 0)
        w -= x
        h -= y
        return x, y, w, h
131 | def fixBB(self, x, y, w, h, size):
132 | width, height = size
133 | w = min(w + x, width)
134 | h = min(h + y, height)
135 | x = max(x, 0)
136 | y = max(y, 0)
137 | w -= x
138 | h -= y
139 | return x, y, w, h
140 | 
141 | def IOU(self, Reframe, GTframe):
142 | """
143 | Compute the Intersection of Union
144 | :param Reframe: x, y, w, h
145 | :param GTframe: x, y, w, h
146 | :return: Ratio
147 | """
148 | if edge_initial == 1:
149 | return random.random()
150 | elif edge_initial == 3:
151 | return 0.5
152 | x1 = Reframe[0]
153 | y1 = Reframe[1]
154 | width1 = Reframe[2]
155 | height1 = Reframe[3]
156 | 
157 | x2 = GTframe[0]
158 | y2 = GTframe[1]
159 | width2 = GTframe[2]
160 | height2 = GTframe[3]
161 | 
162 | endx = max(x1 + width1, x2 + width2)
163 | startx = min(x1, x2)
164 | width = width1 + width2 - (endx - startx)
165 | 
166 | endy = max(y1 + height1, y2 + height2)
167 | starty = min(y1, y2)
168 | height = height1 + height2 - (endy - starty)
169 | 
170 | if width <= 0 or height <= 0:
171 | ratio = 0
172 | else:
173 | Area = width * height
174 | Area1 = width1 * height1
175 | Area2 = width2 * height2
176 | ratio = Area * 1. / (Area1 + Area2 - Area)
177 | return ratio
178 | 
179 | def getMN(self, m, n):
180 | ans = [[None for i in range(n)] for i in range(m)]
181 | for i in range(m):
182 | Reframe = self.bbx[self.f_step - self.gap][i]
183 | for j in range(n):
184 | GTframe = self.bbx[self.f_step][j]
185 | p = self.IOU(Reframe, GTframe)
186 | # 1 - match, 0 - mismatch
187 | ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0]).to(self.device)
188 | return ans
189 | 
190 | def aggregate(self, set):
191 | if len(set):
192 | rho = sum(set)
193 | return rho / len(set)
194 | print(' The set is empty!')
195 | return None
196 | 
197 | def distance(self, a_bbx, b_bbx):
198 | w1 = float(a_bbx[2]) * tau_dis
199 | w2 = float(b_bbx[2]) * tau_dis
200 | dx = float(a_bbx[0] + a_bbx[2] / 2) - float(b_bbx[0] + b_bbx[2] / 2)
201 | dy = float(a_bbx[1] + a_bbx[3] / 2) - float(b_bbx[1] + b_bbx[3] / 2)
202 | d = sqrt(dx * dx + dy * dy)
203 | if d <= w1 and d <= w2:
204 | return 0.0
205 | return tau_threshold
206 | 
207 | def getRet(self):
208 | cur = self.f_step - self.gap
209 | ret = [[0.0 for i in range(self.n)] for j in range(self.m)]
210 | for i in range(self.m):
211 | bbx1 = self.bbx[cur][i]
212 | for j in range(self.n):
213 | ret[i][j] = self.distance(bbx1, self.bbx[self.f_step][j])
214 | return ret
215 | 
216 | def getApp(self, tag, index):
217 | cur = self.cur if tag else self.nxt
218 | if torch.is_tensor(index):
219 | n = index.numel()
220 | if n == 0:  # numel() is never negative; guard against an empty index tensor
221 | print('The tensor is empty!')
222 | return None
223 | if n == 1:
224 | return self.detections[cur][index[0]][0]
225 | ans = torch.cat((self.detections[cur][index[0]][0], self.detections[cur][index[1]][0]), dim=0)
226 | for i in range(2, n):
227 | ans = torch.cat((ans, self.detections[cur][index[i]][0]), dim=0)
228 | return ans
229 | return self.detections[cur][index][0]
230 | 
231 | def moveApp(self, index):
232 | self.bbx[self.f_step].append(self.bbx[self.f_step - self.gap][index])  # add the bbx
233 | self.detections[self.nxt].append(self.detections[self.cur][index])  # add the appearance
234 | 
235 | def swapFC(self):
236 | self.cur = self.cur ^ self.nxt
237 | self.nxt = self.cur ^ self.nxt
238 | self.cur = self.cur ^ self.nxt
239 | 
240 | def resnet34(self, img):
241 | bbx = ToTensor()(img)
242 | bbx = bbx.to(self.device)
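# Note: ToTensor() yields a CHW tensor, so the view below just prepends a
# batch dimension before the ResNet-34 backbone is applied, i.e.
# (3, 224, 224) -> (1, 3, 224, 224); bbx.unsqueeze(0) would be the more
# idiomatic equivalent.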
243 | bbx = bbx.view(-1, bbx.size(0), bbx.size(1), bbx.size(2)) 244 | ret = self.Appearance(bbx) 245 | ret = ret.view(1, -1) 246 | return ret 247 | 248 | def feature(self, tag=0): 249 | ''' 250 | Getting the appearance of the detections in current frame 251 | :param tag: 1 - initiating 252 | :param show: 1 - show the cropped & src image 253 | :return: None 254 | ''' 255 | apps = [] 256 | with torch.no_grad(): 257 | bbx_container = [] 258 | for bbx in self.bbx[self.f_step]: 259 | """ 260 | Bellow Conditions needed be taken into consideration: 261 | x, y < 0 and x+w > W, y+h > H 262 | """ 263 | img = load_img(self.img_dir + '%06d.jpg' % self.f_step) # initial with loading the first frame 264 | if test_gt_det: 265 | x, y, w, h, id, conf_score, vr = bbx 266 | else: 267 | x, y, w, h, conf_score = bbx 268 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 269 | if test_gt_det: 270 | bbx_container.append([x, y, w, h, id, conf_score, vr]) 271 | else: 272 | bbx_container.append([x, y, w, h, conf_score]) 273 | crop = img.crop([x, y, x + w, y + h]) 274 | bbx = crop.resize((224, 224), Image.ANTIALIAS) 275 | ret = self.resnet34(bbx) 276 | app = ret.data 277 | apps.append([app, conf_score]) 278 | 279 | if self.show: 280 | img = np.asarray(img) 281 | crop = np.asarray(crop) 282 | if test_gt_det: 283 | print('%06d' % self.f_step, id, vr, '***', ) 284 | else: 285 | print('%06d' % self.f_step, conf_score, vr, '***', ) 286 | print(w, h, '-', ) 287 | print(len(crop[0]), len(crop)) 288 | cv2.imshow('crop', crop) 289 | cv2.imshow('view', img) 290 | cv2.waitKey(34) 291 | input('Continue?') 292 | # cv2.waitKey(34) 293 | self.bbx[self.f_step] = bbx_container 294 | if tag: 295 | self.detections[self.cur] = apps 296 | else: 297 | self.detections[self.nxt] = apps 298 | 299 | def loadNext(self): 300 | self.m = len(self.detections[self.cur]) 301 | 302 | self.gap = 0 303 | self.n = 0 304 | while self.n == 0: 305 | self.f_step += 1 306 | self.feature() 307 | self.n = len(self.detections[self.nxt]) 308 | self.gap += 1 309 | 310 | if self.gap > 1: 311 | print(' Empty in loadNext:', self.f_step - self.gap + 1, '-', self.gap - 1) 312 | 313 | self.candidates = [] 314 | self.edges = self.getMN(self.m, self.n) 315 | 316 | for i in range(self.m): 317 | for j in range(self.n): 318 | e = self.edges[i][j] 319 | self.candidates.append([e, i, j]) 320 | 321 | # print ' The index of the next frame', self.f_step, len(self.bbx) 322 | return self.gap 323 | 324 | def __getitem__(self, index): 325 | return self.candidates[index] 326 | 327 | def __len__(self): 328 | return len(self.candidates) 329 | -------------------------------------------------------------------------------- /App2/train.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from dataset import DatasetFromFolder 7 | import time, random, os, shutil 8 | from munkres import Munkres 9 | from global_set import u_initial, mot_dataset_dir, model_dir 10 | from mot_model import * 11 | from tensorboardX import SummaryWriter 12 | 13 | torch.manual_seed(123) 14 | np.random.seed(123) 15 | 16 | 17 | def deleteDir(del_dir): 18 | shutil.rmtree(del_dir) 19 | 20 | 21 | class GN(): 22 | def __init__(self, lr=1e-5, batchs=8, cuda=True): 23 | ''' 24 | :param tt: train_test 25 | :param tag: 1 - evaluation on testing data, 0 - without evaluation on testing data 26 | :param lr: 27 | :param 
batchs:
28 | :param cuda:
29 | '''
30 | # all the tensors should have 'volatile' set to True, and False when updating the network
31 | self.hungarian = Munkres()
32 | self.device = torch.device("cuda" if cuda else "cpu")
33 | self.nEpochs = 999
34 | self.lr = lr
35 | self.batchsize = batchs
36 | self.numWorker = 4
37 | 
38 | self.show_process = 0  # interaction
39 | self.step_input = 1
40 | 
41 | print(' Preparing the model...')
42 | self.resetU()
43 | 
44 | self.Uphi = uphi().to(self.device)
45 | self.Vphi = vphi().to(self.device)
46 | self.Ephi1 = ephi().to(self.device)
47 | self.Ephi2 = ephi().to(self.device)
48 | 
49 | self.criterion = nn.MSELoss() if criterion_s else nn.CrossEntropyLoss()
50 | self.criterion = self.criterion.to(self.device)
51 | 
52 | self.criterion_v = nn.MSELoss().to(self.device)
53 | 
54 | self.optimizer1 = optim.Adam([
55 | {'params': self.Ephi1.parameters()}],
56 | lr=lr)
57 | self.optimizer2 = optim.Adam([
58 | {'params': self.Uphi.parameters()},
59 | {'params': self.Vphi.parameters()},
60 | {'params': self.Ephi2.parameters()}],
61 | lr=lr)
62 | 
63 | self.writer = SummaryWriter()
64 | 
65 | seqs = [2, 4, 5, 9, 10, 11, 13]
66 | lengths = [600, 1050, 837, 525, 654, 900, 750]
67 | 
68 | for i in range(len(seqs)):
69 | # print ' Loading Data...'
70 | seq = seqs[i]
71 | self.seq_index = seq
72 | start = time.time()
73 | sequence_dir = mot_dataset_dir + 'MOT16/train/MOT16-%02d' % seq
74 | 
75 | self.train_set = DatasetFromFolder(sequence_dir)
76 | 
77 | self.train_test = lengths[i]
78 | # self.train_test = int(self.train_test * 0.8)  # For training the model without the validation set
79 | 
80 | self.loss_threhold = 0.03
81 | self.update('seq/%02d' % seq)
82 | 
83 | def getEdges(self):  # statistics of the graph between two frames' detections
84 | self.train_set.setBuffer(1)
85 | step = 1
86 | edge_counter = 0.0
87 | for head in range(1, self.train_test):
88 | self.train_set.loadNext()  # Get the next frame
89 | edge_counter += self.train_set.m * self.train_set.n
90 | step += 1
91 | self.train_set.swapFC()
92 | 
93 | def resetU(self):
94 | if u_initial:
95 | self.u = torch.FloatTensor([random.random() for i in range(u_num)]).view(1, -1)
96 | else:
97 | self.u = torch.FloatTensor([0.0 for i in range(u_num)]).view(1, -1)
98 | self.u = self.u.to(self.device)
99 | 
100 | def updateNetwork(self, seqName):
101 | self.train_set.setBuffer(1)
102 | step = 1
103 | loss_step = 0
104 | edge_counter = 0.0
105 | for head in range(1, self.train_test):
106 | self.train_set.loadNext()  # Get the next frame
107 | edge_counter += self.train_set.m * self.train_set.n
108 | # print(' Step -', step)
109 | data_loader = DataLoader(dataset=self.train_set, num_workers=self.numWorker, batch_size=self.batchsize,
110 | shuffle=True)
111 | 
112 | for epoch_i in range(1, self.nEpochs):
113 | num = 0
114 | epoch_loss_i = 0.0
115 | for iteration in enumerate(data_loader, 1):
116 | index, (e, gt, vs_index, vr_index) = iteration
117 | e = e.to(self.device)
118 | gt = gt.to(self.device)
119 | vs = self.train_set.getApp(1, vs_index)
120 | vr = self.train_set.getApp(0, vr_index)
121 | 
122 | self.optimizer1.zero_grad()
123 | e1 = self.Ephi1(e, vs, vr, self.u)
124 | # update the Ephi1
125 | loss = self.criterion(e1, gt.squeeze(1))
126 | loss.backward()
127 | self.optimizer1.step()
128 | epoch_loss_i += loss.item(); num += self.batchsize  # accumulate the epoch loss (this was missing, so the check below always saw 0 and broke immediately)
129 | if epoch_loss_i / num < self.loss_threhold:
130 | break
131 | # print(' Updating the Ephi1: %d times.'
% epoch_i) 132 | 133 | for epoch in range(1, self.nEpochs): 134 | num = 0 135 | epoch_loss = 0.0 136 | v_loss = 0.0 137 | arpha_loss = 0.0 138 | 139 | candidates = [] 140 | E_CON, V_CON = [], [] 141 | for iteration in enumerate(data_loader, 1): 142 | index, (e, gt, vs_index, vr_index) = iteration 143 | e = e.to(self.device) 144 | gt = gt.to(self.device) 145 | vs = self.train_set.getApp(1, vs_index) 146 | vr = self.train_set.getApp(0, vr_index) 147 | 148 | e1 = self.Ephi1(e, vs, vr, self.u) 149 | 150 | e1 = e1.data 151 | vr1 = self.Vphi(e1, vs, vr, self.u) 152 | candidates.append((e1, gt, vs, vr, vr1)) 153 | E_CON.append(torch.mean(e1, 0)) 154 | V_CON.append(torch.mean(vs, 0)) 155 | V_CON.append(torch.mean(vr1.data, 0)) 156 | 157 | E = self.train_set.aggregate(E_CON).view(1, -1) # This part is the aggregation for Edge 158 | V = self.train_set.aggregate(V_CON).view(1, -1) # This part is the aggregation for vertex 159 | # print E.view(1, -1) 160 | 161 | n = len(candidates) 162 | for i in range(n): 163 | e1, gt, vs, vr, vr1 = candidates[i] 164 | tmp_gt = 1 - torch.FloatTensor(gt.cpu().numpy()).to(self.device) 165 | 166 | self.optimizer2.zero_grad() 167 | 168 | u1 = self.Uphi(E, V, self.u) 169 | e2 = self.Ephi2(e1, vs, vr1, u1) 170 | 171 | # Penalize the u to let its value not too big 172 | arpha = torch.mean(torch.abs(u1)) 173 | arpha_loss += arpha.item() 174 | arpha.backward(retain_graph=True) 175 | 176 | v_l = self.criterion_v(tmp_gt * vr, tmp_gt * vr1) 177 | v_loss += v_l.item() 178 | v_l.backward(retain_graph=True) 179 | 180 | # The regular loss 181 | loss = self.criterion(e2, gt.squeeze(1)) 182 | epoch_loss += loss.item() 183 | loss.backward() 184 | 185 | # update the network: Uphi and Ephi 186 | self.optimizer2.step() 187 | 188 | num += e1.size()[0] 189 | 190 | if self.show_process and self.step_input: 191 | a = input('Continue(0-step, 1-run, 2-run with showing)?') 192 | if a == '1': 193 | self.show_process = 0 194 | elif a == '2': 195 | self.step_input = 0 196 | 197 | epoch_loss /= num 198 | 199 | # print(' Loss of epoch {}: {}.'.format(epoch, epoch_loss)) 200 | self.writer.add_scalar(seqName, epoch_loss, loss_step) 201 | loss_step += 1 202 | if epoch_loss < self.loss_threhold: 203 | break 204 | 205 | self.updateUVE() 206 | step += 1 207 | self.train_set.swapFC() 208 | 209 | def saveModel(self): 210 | print('Saving the Uphi model...') 211 | torch.save(self.Uphi, model_dir + 'uphi_%02d.pth' % self.seq_index) 212 | print('Saving the Vphi model...') 213 | torch.save(self.Vphi, model_dir + 'vphi_%02d.pth' % self.seq_index) 214 | print('Saving the Ephi1 model...') 215 | torch.save(self.Ephi1, model_dir + 'ephi1_%02d.pth' % self.seq_index) 216 | print('Saving the Ephi model...') 217 | torch.save(self.Ephi2, model_dir + 'ephi2_%02d.pth' % self.seq_index) 218 | print('Saving the global variable u...') 219 | torch.save(self.u, model_dir + 'u_%02d.pth' % self.seq_index) 220 | print('Done!') 221 | 222 | def updateUVE(self): 223 | with torch.no_grad(): 224 | candidates = [] 225 | E_CON, V_CON = [], [] 226 | for edge in self.train_set: 227 | e, gt, vs_index, vr_index = edge 228 | e = e.view(1, -1).to(self.device) 229 | vs = self.train_set.getApp(1, vs_index) 230 | vr = self.train_set.getApp(0, vr_index) 231 | 232 | e1 = self.Ephi1(e, vs, vr, self.u) 233 | vr1 = self.Vphi(e1, vs, vr, self.u) 234 | candidates.append((e1, gt, vs, vr1, vs_index, vr_index)) 235 | E_CON.append(e1) 236 | V_CON.append(vs) 237 | V_CON.append(vr1) 238 | 239 | E = self.train_set.aggregate(E_CON).view(1, -1) # This part is the 
aggregation for Edge 240 | V = self.train_set.aggregate(V_CON).view(1, -1) # This part is the aggregation for vertex 241 | u1 = self.Uphi(E, V, self.u) 242 | self.u = u1.data 243 | 244 | nxt = self.train_set.nxt 245 | for iteration in candidates: 246 | e1, gt, vs, vr1, vs_index, vr_index = iteration 247 | e2 = self.Ephi2(e1, vs, vr1, u1) 248 | if gt.item(): 249 | self.train_set.detections[nxt][vr_index][0] = vr1.data 250 | self.train_set.edges[vs_index][vr_index] = e2.data.view(-1) 251 | 252 | def update(self, seqName): 253 | print(' Train the model with the sequence: ', seqName) 254 | self.updateNetwork(seqName) 255 | self.saveModel() 256 | 257 | 258 | if __name__ == '__main__': 259 | try: 260 | deleteDir(model_dir) 261 | if not os.path.exists(model_dir): 262 | os.mkdir(model_dir) 263 | start = time.time() 264 | print(' Starting Graph Network...') 265 | gn = GN() 266 | print('Time consuming:', time.time() - start) 267 | else: 268 | # deleteDir(model_dir) 269 | # os.mkdir(model_dir) 270 | print('The model has been here!') 271 | except KeyboardInterrupt: 272 | print('Time consuming:', time.time() - start) 273 | print('') 274 | print('-' * 90) 275 | print('Existing from training early.') 276 | -------------------------------------------------------------------------------- /GN/copyfile.py: -------------------------------------------------------------------------------- 1 | import shutil, os 2 | from global_set import mot_dataset_dir 3 | 4 | name = 'motmetrics' 5 | types = ['POI'] 6 | 7 | seqs = [2, 4, 5, 9, 10, 11, 13] # the set of sequences 8 | lengths = [600, 1050, 837, 525, 654, 900, 750] # the length of the sequence 9 | 10 | test_seqs = [1, 3, 6, 7, 8, 12, 14] 11 | test_lengths = [450, 1500, 1194, 500, 625, 900, 750] 12 | 13 | # copy the results for testing sets 14 | for type in types: 15 | for i in range(len(seqs)): 16 | src_dir = 'results/%02d/%d/%s_%s/res.txt' % (test_seqs[i], test_lengths[i], name, type) 17 | 18 | t = type 19 | if type == 'DPM0' or type == 'POI': 20 | t = 'DPM' 21 | 22 | des_d = 'mot16/' 23 | if not os.path.exists(des_d): 24 | os.mkdir(des_d) 25 | des_dir = des_d + 'MOT16-%02d.txt' % (test_seqs[i]) 26 | 27 | print(src_dir) 28 | print(des_dir) 29 | shutil.copyfile(src_dir, des_dir) 30 | 31 | # Copy the ground truth of training sets 32 | # types = ['POI'] 33 | # for type in types: 34 | # for i in range(len(seqs)): 35 | # src_dir = mot_dataset_dir + 'MOT17/train/MOT17-%02d-%s/gt/gt.txt' % (seqs[i], type) 36 | # 37 | # des_dir = des_d + 'MOT16-%02d.txt' % (seqs[i]) 38 | # 39 | # print(src_dir) 40 | # print(des_dir) 41 | # shutil.copyfile(src_dir, des_dir) 42 | -------------------------------------------------------------------------------- /GN/global_set.py: -------------------------------------------------------------------------------- 1 | #mot_dataset_dir = '/media/codinglee/DATA/Ubuntu16.04/Desktop/MOT/' 2 | mot_dataset_dir = '../MOT/' 3 | 4 | model_dir = 'model/' 5 | 6 | u_initial = 1 # 1 - random, 0 - 0 7 | 8 | edge_initial = 0 # 1 - random, 0 - IoU 9 | 10 | criterion_s = 0 # 1 - MSELoss, 0 - CrossEntropyLoss 11 | 12 | test_gt_det = 0 # 1 - detections of gt, 0 - detections of det 13 | 14 | u_update = 1 # 1 - update when testing, 0 - without updating 15 | if u_update: 16 | u_dir = '_uupdate' 17 | else: 18 | u_dir = '' 19 | 20 | app_fine_tune = 0 # 1 - fine-tunedthe appearance model, 0 - pre-trained appearance model 21 | if app_fine_tune: 22 | fine_tune_dir = '../MOT/Fine-tune_GPU_5_3_60_aug/appearance_19.pth' 23 | app_dir = 'Finetuned' 24 | else: 25 | fine_tune_dir 
= '' 26 | app_dir = 'Pretrained' 27 | 28 | decay = 1.3 29 | decay_dir = '_decay' 30 | 31 | f_gap = 5 32 | if f_gap: 33 | recover_dir = '_Recover' 34 | else: 35 | recover_dir = '_NoRecover' 36 | 37 | tau_threshold = 1.0 # The threshold of matching cost 38 | tau_dis = 2.0 # The times of the current bbx's scale 39 | gap = 25 # max frame number for side connection 40 | 41 | show_recovering = 0 # 1 - 11, 0 - 10 42 | 43 | overlap = 0.85 # the IoU 44 | -------------------------------------------------------------------------------- /GN/m_mot_model.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | from global_set import criterion_s 4 | 5 | v_num = 6 # Only take the appearance into consideration, and add velocity when basic model works 6 | u_num = 100 7 | e_num = 1 if criterion_s else 2 8 | 9 | uphi_n = 256 10 | ephi_n = 256 11 | 12 | 13 | class uphi(nn.Module): 14 | def __init__(self): 15 | super(uphi, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Linear(u_num + v_num + e_num, uphi_n), 18 | nn.LeakyReLU(inplace=True), 19 | nn.Linear(uphi_n, u_num), 20 | ) 21 | 22 | def forward(self, e, v, u): 23 | """ 24 | The network which updates the global variable u 25 | :param e: the aggregation of the probability 26 | :param v: the aggregation of the node 27 | :param u: global variable 28 | """ 29 | # print 'U:', e.size(), v.size(), u.size() 30 | bs = e.size()[0] 31 | if u.size()[0] == 1: 32 | if bs == 1: 33 | tmp = u 34 | else: 35 | tmp = torch.cat((u, u), dim=0) 36 | for i in range(2, bs): 37 | tmp = torch.cat((tmp, u), dim=0) 38 | else: 39 | tmp = u 40 | x = torch.cat((e, v), dim=1) 41 | x = torch.cat((x, tmp), dim=1) 42 | return self.features(x) 43 | 44 | 45 | class ephi(nn.Module): 46 | def __init__(self): 47 | super(ephi, self).__init__() 48 | self.features = nn.Sequential( 49 | nn.Linear(u_num + v_num * 2 + e_num, ephi_n), 50 | nn.LeakyReLU(inplace=True), 51 | nn.Linear(ephi_n, e_num), 52 | ) 53 | 54 | def forward(self, e, v1, v2, u): 55 | """ 56 | The network which updates the probability e 57 | :param e: the probability between two detections 58 | :param v1: the sender 59 | :param v2: the receiver 60 | :param u: global variable 61 | """ 62 | # print 'E:', e.size(), v1.size(), v2.size(), u.size() 63 | bs = e.size()[0] 64 | if u.size()[0] == 1: 65 | if bs == 1: 66 | tmp = u 67 | else: 68 | tmp = torch.cat((u, u), dim=0) 69 | for i in range(2, bs): 70 | tmp = torch.cat((tmp, u), dim=0) 71 | else: 72 | tmp = u 73 | x = torch.cat((e, v1), dim=1) 74 | x = torch.cat((x, v2), dim=1) 75 | x = torch.cat((x, tmp), dim=1) 76 | return self.features(x) 77 | -------------------------------------------------------------------------------- /GN/m_test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import random, torch, shutil, os, gc 3 | from math import * 4 | from PIL import Image 5 | import torch.nn.functional as F 6 | from global_set import edge_initial, test_gt_det, tau_dis, tau_threshold # , tau_conf_score 7 | 8 | 9 | def load_img(filepath): 10 | img = Image.open(filepath).convert('RGB') 11 | return img 12 | 13 | 14 | class MDatasetFromFolder(data.Dataset): 15 | def __init__(self, part, part_I, tau, cuda=True): 16 | super(MDatasetFromFolder, self).__init__() 17 | self.dir = part 18 | self.cleanPath(part) 19 | self.img_dir = part_I + '/img1/' 20 | self.gt_dir = part + '/gt/' 21 | self.det_dir = part + '/det/' 22 | self.device = 
torch.device("cuda" if cuda else "cpu") 23 | self.tau_conf_score = tau 24 | 25 | self.getSeqL() 26 | if test_gt_det: 27 | self.readBBx_gt() 28 | else: 29 | self.readBBx_det() 30 | self.initBuffer() 31 | 32 | def cleanPath(self, part): 33 | if os.path.exists(part + '/gts/'): 34 | shutil.rmtree(part + '/gts/') 35 | if os.path.exists(part + '/dets/'): 36 | shutil.rmtree(part + '/dets/') 37 | 38 | def getSeqL(self): 39 | # get the length of the sequence 40 | info = self.dir + '/seqinfo.ini' 41 | f = open(info, 'r') 42 | f.readline() 43 | for line in f.readlines(): 44 | line = line.strip().split('=') 45 | if line[0] == 'seqLength': 46 | self.seqL = int(line[1]) 47 | f.close() 48 | # print 'The length of the sequence:', self.seqL 49 | 50 | def fixBB(self, x, y, w, h, size): 51 | width, height = size 52 | w = min(w + x, width) 53 | h = min(h + y, height) 54 | x = max(x, 0) 55 | y = max(y, 0) 56 | w -= x 57 | h -= y 58 | return x, y, w, h 59 | 60 | def readBBx_gt(self): 61 | # get the gt 62 | self.bbx = [[] for i in range(self.seqL + 1)] 63 | bbxs = [[] for i in range(self.seqL + 1)] 64 | imgs = [None for i in range(self.seqL + 1)] 65 | for i in range(1, self.seqL + 1): 66 | imgs[i] = load_img(self.img_dir + '%06d.jpg' % i) 67 | gt = self.gt_dir + 'gt.txt' 68 | f = open(gt, 'r') 69 | pre = -1 70 | for line in f.readlines(): 71 | line = line.strip().split(',') 72 | if line[7] == '1': 73 | index = int(line[0]) 74 | id = int(line[1]) 75 | x, y = float(line[2]), float(line[3]) 76 | w, h = float(line[4]), float(line[5]) 77 | conf_score, l, vr = float(line[6]), int(line[7]), float(line[8]) 78 | 79 | # sweep the invisible head-bbx from the training data 80 | if pre != id and vr == 0: 81 | continue 82 | 83 | pre = id 84 | img = imgs[index] 85 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 86 | width, height = float(img.size[0]), float(img.size[1]) 87 | self.bbx[index].append([x / width, y / height, w / width, h / height, id, conf_score, vr]) 88 | bbxs[index].append([x, y, w, h, id, conf_score, vr]) 89 | f.close() 90 | 91 | gt_out = open(self.gt_dir + 'gt_det.txt', 'w') 92 | for index in range(1, self.seqL + 1): 93 | for bbx in bbxs[index]: 94 | x, y, w, h, id, conf_score, vr = bbx 95 | print >> gt_out, '%d,-1,%d,%d,%d,%d,%f,-1,-1,-1' % (index, x, y, w, h, conf_score) 96 | gt_out.close() 97 | 98 | def readBBx_det(self): 99 | # get the gt 100 | self.bbx = [[] for i in range(self.seqL + 1)] 101 | imgs = [None for i in range(self.seqL + 1)] 102 | for i in range(1, self.seqL + 1): 103 | imgs[i] = load_img(self.img_dir + '%06d.jpg' % i) 104 | det = self.det_dir + 'det.txt' 105 | f = open(det, 'r') 106 | for line in f.readlines(): 107 | line = line.strip().split(',') 108 | index = int(line[0]) 109 | id = int(line[1]) 110 | x, y = float(line[2]), float(line[3]) 111 | w, h = float(line[4]), float(line[5]) 112 | conf_score = float(line[6]) 113 | if conf_score >= self.tau_conf_score: 114 | img = imgs[i] 115 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 116 | width, height = float(img.size[0]), float(img.size[1]) 117 | self.bbx[index].append([x / width, y / height, w / width, h / height, id, conf_score]) 118 | f.close() 119 | 120 | def initBuffer(self): 121 | self.f_step = 1 # the index of next frame in the process 122 | self.cur = 0 # the index of current frame in the detections 123 | self.nxt = 1 # the index of next frame in the detections 124 | self.detections = [None, None] # the buffer to storing images: current & next frame 125 | self.feature(1) 126 | 127 | def setBuffer(self, f): 128 | self.m = 0 129 | 
counter = -1 130 | while self.m == 0: 131 | counter += 1 132 | self.f_step = f + counter 133 | self.feature(1) 134 | self.m = len(self.detections[self.cur]) 135 | if counter > 0: 136 | print(' Empty in setBuffer:', counter) 137 | return counter 138 | 139 | def IOU(self, Reframe, GTframe): 140 | """ 141 | Compute the Intersection of Union 142 | :param Reframe: x, y, w, h 143 | :param GTframe: x, y, w, h 144 | :return: Ratio 145 | """ 146 | if edge_initial == 1: 147 | return random.random() 148 | elif edge_initial == 3: 149 | return 0.5 150 | x1 = Reframe[0] 151 | y1 = Reframe[1] 152 | width1 = Reframe[2] 153 | height1 = Reframe[3] 154 | 155 | x2 = GTframe[0] 156 | y2 = GTframe[1] 157 | width2 = GTframe[2] 158 | height2 = GTframe[3] 159 | 160 | endx = max(x1 + width1, x2 + width2) 161 | startx = min(x1, x2) 162 | width = width1 + width2 - (endx - startx) 163 | 164 | endy = max(y1 + height1, y2 + height2) 165 | starty = min(y1, y2) 166 | height = height1 + height2 - (endy - starty) 167 | 168 | if width <= 0 or height <= 0: 169 | ratio = 0 170 | else: 171 | Area = width * height 172 | Area1 = width1 * height1 173 | Area2 = width2 * height2 174 | ratio = Area * 1. / (Area1 + Area2 - Area) 175 | return ratio 176 | 177 | def aggregate(self, set): 178 | if len(set): 179 | rho = sum(set) 180 | return rho / len(set) 181 | print(' The set is empty!') 182 | return None 183 | 184 | def distance(self, a_bbx, b_bbx): 185 | w = min(float(a_bbx[2]) * tau_dis, float(b_bbx[2]) * tau_dis) 186 | dx = float(a_bbx[0] + a_bbx[2] / 2) - float(b_bbx[0] + b_bbx[2] / 2) 187 | dy = float(a_bbx[1] + a_bbx[3] / 2) - float(b_bbx[1] + b_bbx[3] / 2) 188 | d = sqrt(dx * dx + dy * dy) 189 | if d <= w: 190 | return 0.0 191 | return tau_threshold 192 | 193 | def getRet(self): 194 | cur = self.f_step - self.gap 195 | ret = [[0.0 for i in range(self.n)] for j in range(self.m)] 196 | for i in range(self.m): 197 | bbx1 = self.bbx[cur][i] 198 | for j in range(self.n): 199 | ret[i][j] = self.distance(bbx1, self.bbx[self.f_step][j]) 200 | return ret 201 | 202 | def getMotion(self, tag, index, pre_index=None, t=None): 203 | cur = self.cur if tag else self.nxt 204 | if tag == 0: 205 | self.updateVelocity(pre_index, index, t) 206 | return self.detections[cur][index][0][pre_index] 207 | return self.detections[cur][index][0][0] 208 | 209 | def moveMotion(self, index): 210 | self.bbx[self.f_step].append(self.bbx[self.f_step - self.gap][index]) # add the bbx: x, y, w, h, id, conf_score 211 | self.detections[self.nxt].append( 212 | self.detections[self.cur][index]) # add the motion: [[x, y, w, h, v_x, v_y], id] 213 | 214 | def cleanEdge(self): 215 | con = [] 216 | index = 0 217 | for det in self.detections[self.nxt]: 218 | motion, id = det 219 | x = motion[0][0].item() + motion[0][4].item() 220 | y = motion[0][1].item() + motion[0][5].item() 221 | if (x < 0.0 or x > 1.0) or (y < 0.0 or y > 1.0): 222 | con.append(index) 223 | index += 1 224 | 225 | for i in range(len(con) - 1, -1, -1): 226 | index = con[i] 227 | del self.bbx[self.f_step][index] 228 | del self.detections[self.nxt][index] 229 | return con 230 | 231 | def swapFC(self): 232 | self.cur = self.cur ^ self.nxt 233 | self.nxt = self.cur ^ self.nxt 234 | self.cur = self.cur ^ self.nxt 235 | 236 | def updateVelocity(self, i, j, t=None, tag=True): 237 | v_x = 0.0 238 | v_y = 0.0 239 | if i != -1: 240 | if test_gt_det: 241 | x1, y1, w1, h1, id1, conf_score1, vr1 = self.bbx[self.f_step - self.gap][i] 242 | x2, y2, w2, h2, id2, conf_score2, vr2 = self.bbx[self.f_step][j] 243 | else: 244 | 
x1, y1, w1, h1, id1, conf_score1 = self.bbx[self.f_step - self.gap][i] 245 | x2, y2, w2, h2, id2, conf_score2 = self.bbx[self.f_step][j] 246 | v_x = (x2 + w2 / 2 - (x1 + w1 / 2)) / t 247 | v_y = (y2 + h2 / 2 - (y1 + h1 / 2)) / t 248 | if tag: 249 | # print 'm=%d,n=%d; i=%d, j=%d'%(len(self.detections[self.cur]), len(self.detections[self.nxt]), i, j) 250 | self.detections[self.nxt][j][0][i][0][4] = v_x 251 | self.detections[self.nxt][j][0][i][0][5] = v_y 252 | else: 253 | cur_m = self.detections[self.nxt][j][0][0] 254 | cur_m[0][4] = v_x 255 | cur_m[0][5] = v_y 256 | self.detections[self.nxt][j][0] = [cur_m] 257 | 258 | def getMN(self, m, n): 259 | cur = self.f_step - self.gap 260 | ans = [[None for i in range(n)] for i in range(m)] 261 | for i in range(m): 262 | Reframe = self.bbx[cur][i] 263 | for j in range(n): 264 | GTframe = self.bbx[self.f_step][j] 265 | p = self.IOU(Reframe, GTframe) 266 | # 1 - match, 0 - mismatch 267 | ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0]) 268 | return ans 269 | 270 | def feature(self, tag=0): 271 | ''' 272 | Getting the appearance of the detections in current frame 273 | :param tag: 1 - initiating 274 | :param show: 1 - show the cropped & src image 275 | :return: None 276 | ''' 277 | motions = [] 278 | with torch.no_grad(): 279 | m = 1 if tag else self.m 280 | for bbx in self.bbx[self.f_step]: 281 | """ 282 | Bellow Conditions needed be taken into consideration: 283 | x, y < 0 and x+w > W, y+h > H 284 | """ 285 | if test_gt_det: 286 | x, y, w, h, id, conf_score, vr = bbx 287 | else: 288 | x, y, w, h, id, conf_score = bbx 289 | cur_m = [] 290 | for i in range(m): 291 | cur_m.append(torch.FloatTensor([[x, y, w, h, 0.0, 0.0]]).to(self.device)) 292 | motions.append([cur_m, id]) 293 | if tag: 294 | self.detections[self.cur] = motions 295 | else: 296 | self.detections[self.nxt] = motions 297 | 298 | def loadNext(self): 299 | self.m = len(self.detections[self.cur]) 300 | 301 | self.gap = 0 302 | self.n = 0 303 | while self.n == 0: 304 | self.f_step += 1 305 | self.feature() 306 | self.n = len(self.detections[self.nxt]) 307 | self.gap += 1 308 | 309 | if self.gap > 1: 310 | print(' Empty in loadNext:', self.f_step - self.gap + 1, '-', self.gap - 1) 311 | 312 | self.candidates = [] 313 | self.edges = self.getMN(self.m, self.n) 314 | 315 | es = [] 316 | # vs_index = 0 317 | for i in range(self.m): 318 | # vr_index = self.m 319 | for j in range(self.n): 320 | e = self.edges[i][j] 321 | es.append(e) 322 | self.candidates.append([e, i, j]) 323 | # vr_index += 1 324 | # vs_index += 1 325 | 326 | vs = [] 327 | for i in range(2): 328 | n = len(self.detections[i]) 329 | for j in range(n): 330 | v = self.detections[i][j][0][0] 331 | vs.append(v) 332 | 333 | self.E = self.aggregate(es).to(self.device).view(1, -1) 334 | self.V = self.aggregate(vs).to(self.device) 335 | 336 | # print ' The index of the next frame', self.f_step, len(self.bbx) 337 | return self.gap 338 | 339 | def __getitem__(self, index): 340 | return self.candidates[index] 341 | 342 | def __len__(self): 343 | return len(self.candidates) 344 | -------------------------------------------------------------------------------- /GN/mot_model.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision 2 | import torch.nn as nn 3 | from global_set import criterion_s 4 | 5 | v_num = 512 # Only take the appearance into consideration, and add velocity when basic model works 6 | u_num = 100 7 | e_num = 1 if criterion_s else 2 8 | 9 | 10 | class 
appearance(nn.Module): 11 | def __init__(self): 12 | super(appearance, self).__init__() 13 | features = list(torchvision.models.resnet34(pretrained=True).children())[:-1] 14 | # print features 15 | self.features = nn.Sequential(*features) 16 | 17 | def forward(self, x): 18 | return self.features(x) 19 | 20 | 21 | class uphi(nn.Module): 22 | def __init__(self): 23 | super(uphi, self).__init__() 24 | self.features = nn.Sequential( 25 | nn.Linear(u_num + v_num + e_num, 256), 26 | nn.LeakyReLU(inplace=True), 27 | nn.Linear(256, u_num), 28 | ) 29 | 30 | def forward(self, e, v, u): 31 | """ 32 | The network which updates the global variable u 33 | :param e: the aggregation of the probability 34 | :param v: the aggregation of the vertice 35 | :param u: global variable 36 | """ 37 | # print 'U:', e.size(), v.size(), u.size() 38 | bs = e.size()[0] 39 | if bs == 1: 40 | tmp = u 41 | else: 42 | tmp = torch.cat((u, u), dim=0) 43 | for i in range(2, bs): 44 | tmp = torch.cat((tmp, u), dim=0) 45 | x = torch.cat((e, v), dim=1) 46 | x = torch.cat((x, tmp), dim=1) 47 | return self.features(x) 48 | 49 | 50 | class ephi(nn.Module): 51 | def __init__(self): 52 | super(ephi, self).__init__() 53 | self.features = nn.Sequential( 54 | nn.Linear(u_num + v_num * 2 + e_num, 256), 55 | nn.LeakyReLU(inplace=True), 56 | nn.Linear(256, e_num), 57 | ) 58 | 59 | def forward(self, e, v1, v2, u): 60 | """ 61 | The network which updates the probability e 62 | :param e: the probability between two detections 63 | :param v1: the sender 64 | :param v2: the receiver 65 | :param u: global variable 66 | """ 67 | # print 'E:', e.size(), v1.size(), v2.size(), u.size() 68 | bs = e.size()[0] 69 | if bs == 1: 70 | tmp = u 71 | else: 72 | tmp = torch.cat((u, u), dim=0) 73 | for i in range(2, bs): 74 | tmp = torch.cat((tmp, u), dim=0) 75 | x = torch.cat((e, v1), dim=1) 76 | x = torch.cat((x, v2), dim=1) 77 | x = torch.cat((x, tmp), dim=1) 78 | return self.features(x) 79 | 80 | 81 | class vphi(nn.Module): 82 | def __init__(self): 83 | super(vphi, self).__init__() 84 | self.features = nn.Sequential( 85 | nn.Linear(u_num + v_num * 2 + e_num, 256), 86 | nn.LeakyReLU(inplace=True), 87 | nn.Linear(256, v_num), 88 | ) 89 | 90 | def forward(self, e, v1, v2, u): 91 | """ 92 | The network which updates the probability e 93 | :param e: the probability between two detections 94 | :param v1: the sender 95 | :param v2: the receiver 96 | :param u: global variable 97 | """ 98 | # print 'E:', e.size(), v1.size(), v2.size(), u.size() 99 | bs = e.size()[0] 100 | if bs == 1: 101 | tmp = u 102 | else: 103 | tmp = torch.cat((u, u), dim=0) 104 | for i in range(2, bs): 105 | tmp = torch.cat((tmp, u), dim=0) 106 | x = torch.cat((e, v1), dim=1) 107 | x = torch.cat((x, v2), dim=1) 108 | x = torch.cat((x, tmp), dim=1) 109 | return self.features(x) 110 | -------------------------------------------------------------------------------- /GN/test.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import numpy as np 3 | from m_mot_model import * 4 | from munkres import Munkres 5 | import torch.nn.functional as F 6 | import time, os, shutil, random 7 | from global_set import test_gt_det, \ 8 | tau_threshold, gap, f_gap, show_recovering, decay, u_update, mot_dataset_dir, model_dir 9 | from mot_model import appearance 10 | from test_dataset import ADatasetFromFolder 11 | from m_test_dataset import MDatasetFromFolder 12 | 13 | torch.manual_seed(123) 14 | np.random.seed(123) 15 | 16 | 17 | def 
deleteDir(del_dir): 18 | shutil.rmtree(del_dir) 19 | 20 | 21 | year = 17 22 | 23 | type = '' # detection method 24 | t_dir = '' # the dir of tracking results 25 | sequence_dir = '' # the dir of the testing video sequence 26 | 27 | seqs = [2, 4, 5, 9, 10, 11, 13] # the set of sequences 28 | lengths = [600, 1050, 837, 525, 654, 900, 750] # the length of the sequence 29 | 30 | test_seqs = [1, 3, 6, 7, 8, 12, 14] 31 | test_lengths = [450, 1500, 1194, 500, 625, 900, 750] 32 | 33 | tt_tag = 1 # 1 - test, 0 - train 34 | 35 | tau_conf_score = 0.0 36 | 37 | 38 | class GN(): 39 | def __init__(self, seq_index, seq_len, cuda=True): 40 | ''' 41 | Evaluating with the MotMetrics 42 | :param seq_index: the number of the sequence 43 | :param seq_len: the length of the sequence 44 | :param cuda: True - GPU, False - CPU 45 | ''' 46 | self.bbx_counter = 0 47 | self.seq_index = seq_index 48 | self.hungarian = Munkres() 49 | self.device = torch.device("cuda" if cuda else "cpu") 50 | self.seq_len = seq_len 51 | self.alpha = 0.3 52 | self.missingCounter = 0 53 | self.sideConnection = 0 54 | 55 | print(' Loading the model...') 56 | self.loadAModel() 57 | self.loadMModel() 58 | 59 | self.out_dir = t_dir + 'motmetrics_%s/' % (type) 60 | 61 | print(self.out_dir) 62 | if not os.path.exists(self.out_dir): 63 | os.mkdir(self.out_dir) 64 | else: 65 | deleteDir(self.out_dir) 66 | os.mkdir(self.out_dir) 67 | self.initOut() 68 | 69 | def initOut(self): 70 | print(' Loading Data...') 71 | self.a_train_set = ADatasetFromFolder(sequence_dir, mot_dataset_dir + 'MOT16/test/MOT16-%02d' % self.seq_index, 72 | tau_conf_score) 73 | self.m_train_set = MDatasetFromFolder(sequence_dir, mot_dataset_dir + 'MOT16/test/MOT16-%02d' % self.seq_index, 74 | tau_conf_score) 75 | 76 | detection_dir = self.out_dir + 'res_det.txt' 77 | res_training = self.out_dir + 'res.txt' # the tracking results 78 | self.createTxt(detection_dir) 79 | self.createTxt(res_training) 80 | self.copyLines(self.seq_index, 1, detection_dir, self.seq_len, 1) 81 | 82 | self.evaluation(1, self.seq_len, detection_dir, res_training) 83 | 84 | def getSeqL(self, info): 85 | # get the length of the sequence 86 | f = open(info, 'r') 87 | f.readline() 88 | for line in f.readlines(): 89 | line = line.strip().split('=') 90 | if line[0] == 'seqLength': 91 | seqL = int(line[1]) 92 | f.close() 93 | return seqL 94 | 95 | def copyLines(self, seq, head, gt_seq, tail=-1, tag=0): 96 | ''' 97 | Copy the groun truth within [head, head+num] 98 | :param seq: the number of the sequence 99 | :param head: the head frame number 100 | :param tail: the number the clipped sequence 101 | :param gt_seq: the dir of the output file 102 | :return: None 103 | ''' 104 | if tt_tag: 105 | basic_dir = mot_dataset_dir + 'MOT%d/test/MOT%d-%02d-%s/' % (year, year, seq, type) 106 | else: 107 | basic_dir = mot_dataset_dir + 'MOT%d/train/MOT%d-%02d-%s/' % (year, year, seq, type) 108 | print(' Testing on', basic_dir, 'Length:', self.seq_len) 109 | seqL = tail if tail != -1 else self.getSeqL(basic_dir + 'seqinfo.ini') 110 | 111 | det_dir = 'gt/gt_det.txt' if test_gt_det else 'det/det.txt' 112 | seq_dir = basic_dir + ('gt/gt.txt' if tag == 0 else det_dir) 113 | inStream = open(seq_dir, 'r') 114 | 115 | outStream = open(gt_seq, 'w') 116 | for line in inStream.readlines(): 117 | line = line.strip() 118 | attrs = line.split(',') 119 | f_num = int(attrs[0]) 120 | if f_num >= head and f_num <= seqL: 121 | outStream.write(line + '\n') 122 | outStream.close() 123 | 124 | inStream.close() 125 | return seqL 126 | 127 | def 
createTxt(self, out_file):
128 | f = open(out_file, 'w')
129 | f.close()
130 | 
131 | def loadAModel(self):
132 | from mot_model import uphi, ephi, vphi
133 | tail = 13
134 | self.AUphi = torch.load('../App2/' + model_dir + 'uphi_%02d.pth' % tail).to(self.device)
135 | self.AVphi = torch.load('../App2/' + model_dir + 'vphi_%02d.pth' % tail).to(self.device)
136 | self.AEphi1 = torch.load('../App2/' + model_dir + 'ephi1_%02d.pth' % tail).to(self.device)
137 | self.AEphi2 = torch.load('../App2/' + model_dir + 'ephi2_%02d.pth' % tail).to(self.device)
138 | self.Au = torch.load('../App2/' + model_dir + 'u_%02d.pth' % tail).to(self.device)
139 | 
140 | def loadMModel(self):
141 | from m_mot_model import uphi, ephi
142 | tail = 13
143 | self.MUphi = torch.load('../Motion1/' + model_dir + 'uphi_%d.pth' % tail).to(self.device)
144 | self.MEphi = torch.load('../Motion1/' + model_dir + 'ephi_%d.pth' % tail).to(self.device)
145 | self.Mu = torch.load('../Motion1/' + model_dir + 'u_%d.pth' % tail).to(self.device)
146 | 
147 | def swapFC(self):
148 | self.cur = self.cur ^ self.nxt
149 | self.nxt = self.cur ^ self.nxt
150 | self.cur = self.cur ^ self.nxt
151 | 
152 | def linearModel(self, out, attr1, attr2):
153 | # print 'I got you! *.*'
154 | t = attr1[-1]
155 | self.sideConnection += 1
156 | if t > f_gap:
157 | return
158 | frame = int(attr1[0])
159 | x1, y1, w1, h1 = float(attr1[2]), float(attr1[3]), float(attr1[4]), float(attr1[5])
160 | x2, y2, w2, h2 = float(attr2[2]), float(attr2[3]), float(attr2[4]), float(attr2[5])
161 | 
162 | x_delta = (x2 - x1) / t
163 | y_delta = (y2 - y1) / t
164 | w_delta = (w2 - w1) / t
165 | h_delta = (h2 - h1) / t
166 | 
167 | for i in range(1, t):
168 | frame += 1
169 | x1 += x_delta
170 | y1 += y_delta
171 | w1 += w_delta
172 | h1 += h_delta
173 | attr1[0] = str(frame)
174 | attr1[2] = str(x1)
175 | attr1[3] = str(y1)
176 | attr1[4] = str(w1)
177 | attr1[5] = str(h1)
178 | line = ''
179 | for attr in attr1[:-1]:
180 | line += attr + ','
181 | if show_recovering:
182 | line += '1'
183 | else:
184 | line = line[:-1]
185 | out.write(line + '\n')
186 | self.bbx_counter += 1
187 | self.missingCounter += t - 1
188 | 
189 | def evaluation(self, head, tail, gtFile, outFile):
190 | '''
191 | Evaluation on dets
192 | :param head: the head frame number
193 | :param tail: the tail frame number
194 | :param gtFile: the ground truth file name
195 | :param outFile: the name of output file
196 | :return: None
197 | '''
198 | gtIn = open(gtFile, 'r')
199 | self.cur, self.nxt = 0, 1
200 | line_con = [[], []]
201 | id_con = [[], []]
202 | id_step = 1
203 | 
204 | a_step = head + self.a_train_set.setBuffer(head)
205 | m_step = head + self.m_train_set.setBuffer(head)
206 | if a_step != m_step:
207 | print('Something is wrong!')
208 | print('a_step =', a_step, ', m_step =', m_step)
209 | input('Continue?')
210 | 
211 | while a_step < tail:
212 | # print '*********************************'
213 | a_t_gap = self.a_train_set.loadNext()
214 | m_t_gap = self.m_train_set.loadNext()
215 | if a_t_gap != m_t_gap:
216 | print('Something is wrong!')
217 | print('a_t_gap =', a_t_gap, ', m_t_gap =', m_t_gap)
218 | input('Continue?')
219 | a_step += a_t_gap
220 | m_step += m_t_gap  # was 'm_step += m_step', which doubled the counter instead of advancing it by the frame gap
221 | print(a_step, end=' ')
222 | if a_step % 100 == 0:
223 | print('')
224 | 
225 | m_u_ = self.MUphi(self.m_train_set.E, self.m_train_set.V, self.Mu)
226 | 
227 | # print 'Fo'
228 | a_m = self.a_train_set.m
229 | a_n = self.a_train_set.n
230 | m_m = self.m_train_set.m
231 | m_n = self.m_train_set.n
232 | 
233 | if a_m != m_m or a_n !=
m_n: 234 | print('Something is wrong!') 235 | print('a_m = %d, m_m = %d' % (a_m, m_m), ', a_n = %d, m_n = %d' % (a_n, m_n)) 236 | input('Continue?') 237 | # print 'm = %d, n = %d'%(m, n) 238 | if a_n == 0: 239 | print('There is no detection in the rest of sequence!') 240 | break 241 | 242 | if id_step == 1: 243 | out = open(outFile, 'a') 244 | i = 0 245 | while i < a_m: 246 | attrs = gtIn.readline().strip().split(',') 247 | if float(attrs[6]) >= tau_conf_score: 248 | attrs.append(1) 249 | attrs[1] = str(id_step) 250 | line = '' 251 | for attr in attrs[:-1]: 252 | line += attr + ',' 253 | if show_recovering: 254 | line += '0' 255 | else: 256 | line = line[:-1] 257 | out.write(line + '\n') 258 | self.bbx_counter += 1 259 | line_con[self.cur].append(attrs) 260 | id_con[self.cur].append(id_step) 261 | id_step += 1 262 | i += 1 263 | out.close() 264 | 265 | i = 0 266 | while i < a_n: 267 | attrs = gtIn.readline().strip().split(',') 268 | if float(attrs[6]) >= tau_conf_score: 269 | attrs.append(1) 270 | line_con[self.nxt].append(attrs) 271 | id_con[self.nxt].append(-1) 272 | i += 1 273 | 274 | # update the edges 275 | # print 'T', 276 | candidates = [] 277 | E_CON, V_CON = [], [] 278 | for edge in self.a_train_set.candidates: 279 | e, vs_index, vr_index = edge 280 | e = e.view(1, -1).to(self.device) 281 | vs = self.a_train_set.getApp(1, vs_index) 282 | vr = self.a_train_set.getApp(0, vr_index) 283 | 284 | e1 = self.AEphi1(e, vs, vr, self.Au) 285 | vr1 = self.AVphi(e1, vs, vr, self.Au) 286 | candidates.append((e1, vs, vr1, vs_index, vr_index)) 287 | E_CON.append(e1) 288 | V_CON.append(vs) 289 | V_CON.append(vr1) 290 | 291 | E = self.a_train_set.aggregate(E_CON).view(1, -1) 292 | V = self.a_train_set.aggregate(V_CON).view(1, -1) 293 | u1 = self.AUphi(E, V, self.Au) 294 | 295 | ret = self.a_train_set.getRet() 296 | decay_tag = [0 for i in range(a_m)] 297 | for i in range(a_m): 298 | for j in range(a_n): 299 | if ret[i][j] == 0: 300 | decay_tag[i] += 1 301 | 302 | for i in range(len(self.a_train_set.candidates)): 303 | e1, vs, vr1, a_vs_index, a_vr_index = candidates[i] 304 | m_e, m_vs_index, m_vr_index = self.m_train_set.candidates[i] 305 | if a_vs_index != m_vs_index or a_vr_index != m_vr_index: 306 | print('Something is wrong!') 307 | print('a_vs_index = %d, m_vs_index = %d' % (a_vs_index, m_vs_index)) 308 | print('a_vr_index = %d, m_vr_index = %d' % (a_vr_index, m_vr_index)) 309 | input('Continue?') 310 | if ret[a_vs_index][a_vr_index] == tau_threshold: 311 | continue 312 | 313 | e2 = self.AEphi2(e1, vs, vr1, u1) 314 | self.a_train_set.edges[a_vs_index][a_vr_index] = e1.data.view(-1) 315 | 316 | a_tmp = F.softmax(e2) 317 | a_tmp = a_tmp.cpu().data.numpy()[0] 318 | 319 | m_e = m_e.to(self.device).view(1, -1) 320 | m_v1 = self.m_train_set.getMotion(1, m_vs_index) 321 | m_v2 = self.m_train_set.getMotion(0, m_vr_index, m_vs_index, 322 | line_con[self.cur][m_vs_index][-1] + a_t_gap - 1) 323 | m_e_ = self.MEphi(m_e, m_v1, m_v2, m_u_) 324 | self.m_train_set.edges[m_vs_index][m_vr_index] = m_e_.data.view(-1) 325 | m_tmp = F.softmax(m_e_) 326 | m_tmp = m_tmp.cpu().data.numpy()[0] 327 | 328 | t = line_con[self.cur][a_vs_index][-1] 329 | if decay_tag[a_vs_index] > 0: 330 | A = min(float(a_tmp[0]) * pow(decay, t - 1), 1.0) 331 | M = min(float(m_tmp[0]) * pow(decay, t - 1), 1.0) 332 | else: 333 | A = float(a_tmp[0]) 334 | M = float(m_tmp[0]) 335 | ret[a_vs_index][a_vr_index] = A * self.alpha + M * (1 - self.alpha) 336 | 337 | # for j in ret: 338 | # print j 339 | results = self.hungarian.compute(ret) 340 | 
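# self.hungarian is the bundled Munkres (Hungarian algorithm) solver:
# compute() takes a cost matrix and returns the (row, column) pairs of a
# minimum-cost assignment. A minimal usage sketch with made-up costs:
#
#     from munkres import Munkres
#     pairs = Munkres().compute([[0.1, 0.9],
#                                [0.8, 0.2]])  # -> [(0, 0), (1, 1)]
#
# Here ret holds the fused appearance/motion matching costs, and any pair
# whose cost reaches tau_threshold is rejected in the loop below.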
341 | out = open(outFile, 'a') 342 | look_up = set(j for j in range(a_n)) 343 | nxt = self.a_train_set.nxt 344 | for (i, j) in results: 345 | # print (i,j) 346 | if ret[i][j] >= tau_threshold: 347 | continue 348 | e1 = self.a_train_set.edges[i][j].view(1, -1).to(self.device) 349 | vs = self.a_train_set.getApp(1, i) 350 | vr = self.a_train_set.getApp(0, j) 351 | 352 | vr1 = self.AVphi(e1, vs, vr, self.Au) 353 | self.a_train_set.detections[nxt][j][0] = vr1.data 354 | 355 | look_up.remove(j) 356 | self.m_train_set.updateVelocity(i, j, line_con[self.cur][i][-1], False) 357 | 358 | id = id_con[self.cur][i] 359 | id_con[self.nxt][j] = id 360 | attr1 = line_con[self.cur][i] 361 | attr2 = line_con[self.nxt][j] 362 | # print attrs 363 | attr2[1] = str(id) 364 | if attr1[-1] + a_t_gap - 1 > 1: 365 | # for the missing detections 366 | self.linearModel(out, attr1, attr2) 367 | line = '' 368 | for attr in attr2[:-1]: 369 | line += attr + ',' 370 | if show_recovering: 371 | line += '0' 372 | else: 373 | line = line[:-1] 374 | out.write(line + '\n') 375 | self.bbx_counter += 1 376 | 377 | if u_update: 378 | self.Mu = m_u_.data 379 | self.Au = u1.data 380 | 381 | for j in look_up: 382 | self.m_train_set.updateVelocity(-1, j, tag=False) 383 | 384 | for i in range(a_n): 385 | if id_con[self.nxt][i] == -1: 386 | id_con[self.nxt][i] = id_step 387 | attrs = line_con[self.nxt][i] 388 | attrs[1] = str(id_step) 389 | line = '' 390 | for attr in attrs[:-1]: 391 | line += attr + ',' 392 | if show_recovering: 393 | line += '0' 394 | else: 395 | line = line[:-1] 396 | out.write(line + '\n') 397 | self.bbx_counter += 1 398 | id_step += 1 399 | out.close() 400 | 401 | # For missing & Occlusion 402 | index = 0 403 | for (i, j) in results: 404 | while i != index: 405 | attrs = line_con[self.cur][index] 406 | # print '*', attrs, '*' 407 | if attrs[-1] + a_t_gap <= gap: 408 | attrs[-1] += a_t_gap 409 | line_con[self.nxt].append(attrs) 410 | id_con[self.nxt].append(id_con[self.cur][index]) 411 | self.a_train_set.moveApp(index) 412 | self.m_train_set.moveMotion(index) 413 | index += 1 414 | if ret[i][j] >= tau_threshold: 415 | attrs = line_con[self.cur][index] 416 | # print '*', attrs, '*' 417 | if attrs[-1] + a_t_gap <= gap: 418 | attrs[-1] += a_t_gap 419 | line_con[self.nxt].append(attrs) 420 | id_con[self.nxt].append(id_con[self.cur][index]) 421 | self.a_train_set.moveApp(index) 422 | self.m_train_set.moveMotion(index) 423 | index += 1 424 | while index < a_m: 425 | attrs = line_con[self.cur][index] 426 | # print '*', attrs, '*' 427 | if attrs[-1] + a_t_gap <= gap: 428 | attrs[-1] += a_t_gap 429 | line_con[self.nxt].append(attrs) 430 | id_con[self.nxt].append(id_con[self.cur][index]) 431 | self.a_train_set.moveApp(index) 432 | self.m_train_set.moveMotion(index) 433 | index += 1 434 | 435 | # con = self.m_train_set.cleanEdge() 436 | # for i in range(len(con)-1, -1, -1): 437 | # index = con[i] 438 | # del line_con[self.nxt][index] 439 | # del id_con[self.nxt][index] 440 | 441 | line_con[self.cur] = [] 442 | id_con[self.cur] = [] 443 | # print head+step, results 444 | self.a_train_set.swapFC() 445 | self.m_train_set.swapFC() 446 | self.swapFC() 447 | gtIn.close() 448 | print(' The results:', id_step, self.bbx_counter) 449 | 450 | 451 | if __name__ == '__main__': 452 | try: 453 | types = [['POI', 0.7]] 454 | # types = [['DPM', -0.6], ['SDP', 0.5], ['FRCNN', 0.5]] 455 | 456 | for t in types: 457 | type, tau_conf_score = t 458 | head = time.time() 459 | 460 | f_dir = 'results/' 461 | if not os.path.exists(f_dir): 462 | 
os.mkdir(f_dir) 463 | 464 | for i in range(len(seqs)): 465 | if tt_tag: 466 | seq_index = test_seqs[i] 467 | seq_len = test_lengths[i] 468 | else: 469 | seq_index = seqs[i] 470 | seq_len = lengths[i] 471 | 472 | print('The sequence:', seq_index, '- The length of the training data:', seq_len) 473 | 474 | s_dir = f_dir + '%02d/' % seq_index 475 | if not os.path.exists(s_dir): 476 | os.mkdir(s_dir) 477 | print(s_dir, 'does not exist!') 478 | 479 | t_dir = s_dir + '%d/' % seq_len 480 | if not os.path.exists(t_dir): 481 | os.mkdir(t_dir) 482 | print(t_dir, 'does not exist!') 483 | 484 | if tt_tag: 485 | seq_dir = 'MOT%d-%02d-%s' % (year, test_seqs[i], type) 486 | sequence_dir = mot_dataset_dir + 'MOT%d/test/' % year + seq_dir 487 | print(' ', sequence_dir) 488 | 489 | start = time.time() 490 | print(' Evaluating Graph Network...') 491 | gn = GN(test_seqs[i], test_lengths[i]) 492 | else: 493 | seq_dir = 'MOT%d-%02d-%s' % (year, seqs[i], type) 494 | sequence_dir = mot_dataset_dir + 'MOT%d/train/' % year + seq_dir 495 | print(' ', sequence_dir) 496 | 497 | start = time.time() 498 | print(' Evaluating Graph Network...') 499 | gn = GN(seqs[i], lengths[i]) 500 | print(' Recover the number missing detections:', gn.missingCounter) 501 | print(' The number of sideConnections:', gn.sideConnection) 502 | print('Time consuming:', (time.time() - start) / 60.0) 503 | print('Time consuming:', (time.time() - head) / 60.0) 504 | except KeyboardInterrupt: 505 | print('Time consuming:', time.time() - start) 506 | print('') 507 | print('-' * 90) 508 | print('Existing from training early.') 509 | -------------------------------------------------------------------------------- /GN/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import cv2, random, torch, shutil, os 3 | import numpy as np 4 | from math import * 5 | from PIL import Image 6 | import torch.nn.functional as F 7 | from mot_model import appearance 8 | from global_set import edge_initial, test_gt_det, tau_dis, app_fine_tune, fine_tune_dir, \ 9 | tau_threshold # , tau_conf_score 10 | from torchvision.transforms import ToTensor 11 | 12 | 13 | def load_img(filepath): 14 | img = Image.open(filepath).convert('RGB') 15 | return img 16 | 17 | 18 | class ADatasetFromFolder(data.Dataset): 19 | def __init__(self, part, part_I, tau, cuda=True, show=0): 20 | super(ADatasetFromFolder, self).__init__() 21 | self.dir = part 22 | self.cleanPath(part) 23 | self.img_dir = part_I + '/img1/' 24 | self.gt_dir = part + '/gt/' 25 | self.det_dir = part + '/det/' 26 | self.device = torch.device("cuda" if cuda else "cpu") 27 | self.tau_conf_score = tau 28 | self.show = show 29 | 30 | self.loadAModel() 31 | self.getSeqL() 32 | if test_gt_det: 33 | self.readBBx_gt() 34 | else: 35 | self.readBBx_det() 36 | self.initBuffer() 37 | 38 | def cleanPath(self, part): 39 | if os.path.exists(part + '/gts/'): 40 | shutil.rmtree(part + '/gts/') 41 | if os.path.exists(part + '/dets/'): 42 | shutil.rmtree(part + '/dets/') 43 | 44 | def loadAModel(self): 45 | if app_fine_tune: 46 | self.Appearance = torch.load(fine_tune_dir) 47 | else: 48 | self.Appearance = appearance() 49 | self.Appearance.to(self.device) 50 | self.Appearance.eval() # fixing the BatchN layer 51 | 52 | def getSeqL(self): 53 | # get the length of the sequence 54 | info = self.dir + '/seqinfo.ini' 55 | f = open(info, 'r') 56 | f.readline() 57 | for line in f.readlines(): 58 | line = line.strip().split('=') 59 | if line[0] == 'seqLength': 60 | 
self.seqL = int(line[1])
61 | f.close()
62 | # print 'The length of the sequence:', self.seqL
63 | 
64 | def readBBx_gt(self):
65 | # get the gt
66 | self.bbx = [[] for i in range(self.seqL + 1)]
67 | gt = self.gt_dir + 'gt.txt'
68 | f = open(gt, 'r')
69 | pre = -1
70 | for line in f.readlines():
71 | line = line.strip().split(',')
72 | if line[7] == '1':
73 | index = int(line[0])
74 | id = int(line[1])
75 | x, y = int(line[2]), int(line[3])
76 | w, h = int(line[4]), int(line[5])
77 | conf_score, l, vr = float(line[6]), int(line[7]), float(line[8])
78 | 
79 | # sweep the invisible head-bbx from the training data
80 | if pre != id and vr == 0:
81 | continue
82 | 
83 | pre = id
84 | self.bbx[index].append([x, y, w, h, id, conf_score, vr])
85 | f.close()
86 | 
87 | gt_out = open(self.gt_dir + 'gt_det.txt', 'w')
88 | for index in range(1, self.seqL + 1):
89 | for bbx in self.bbx[index]:
90 | x, y, w, h, id, conf_score, vr = bbx
91 | print('%d,-1,%d,%d,%d,%d,%f,-1,-1,-1' % (index, x, y, w, h, conf_score), file=gt_out)  # Python 3 form of the old 'print >> gt_out'
92 | gt_out.close()
93 | 
94 | def readBBx_det(self):
95 | # get the detections
96 | self.bbx = [[] for i in range(self.seqL + 1)]
97 | det = self.det_dir + 'det.txt'
98 | f = open(det, 'r')
99 | for line in f.readlines():
100 | line = line.strip().split(',')
101 | index = int(line[0])
102 | id = int(line[1])
103 | x, y = int(float(line[2])), int(float(line[3]))
104 | w, h = int(float(line[4])), int(float(line[5]))
105 | conf_score = float(line[6])
106 | if conf_score >= self.tau_conf_score:
107 | self.bbx[index].append([x, y, w, h, conf_score])
108 | f.close()
109 | 
110 | def initBuffer(self):
111 | if self.show:
112 | cv2.namedWindow('view', flags=0)
113 | cv2.namedWindow('crop', flags=0)
114 | self.f_step = 1  # the index of next frame in the process
115 | self.cur = 0  # the index of current frame in the detections
116 | self.nxt = 1  # the index of next frame in the detections
117 | self.detections = [None, None]  # the buffer storing the current & next frames
118 | self.feature(1)
119 | 
120 | def setBuffer(self, f):
121 | self.m = 0
122 | counter = -1
123 | while self.m == 0:
124 | counter += 1
125 | self.f_step = f + counter
126 | self.feature(1)
127 | self.m = len(self.detections[self.cur])
128 | if counter > 0:
129 | print(' Empty in setBuffer:', counter)
130 | return counter
131 | 
132 | def fixBB(self, x, y, w, h, size):
133 | width, height = size
134 | w = min(w + x, width)
135 | h = min(h + y, height)
136 | x = max(x, 0)
137 | y = max(y, 0)
138 | w -= x
139 | h -= y
140 | return x, y, w, h
141 | 
142 | def IOU(self, Reframe, GTframe):
143 | """
144 | Compute the Intersection of Union
145 | :param Reframe: x, y, w, h
146 | :param GTframe: x, y, w, h
147 | :return: Ratio
148 | """
149 | if edge_initial == 1:
150 | return random.random()
151 | elif edge_initial == 3:
152 | return 0.5
153 | x1 = Reframe[0]
154 | y1 = Reframe[1]
155 | width1 = Reframe[2]
156 | height1 = Reframe[3]
157 | 
158 | x2 = GTframe[0]
159 | y2 = GTframe[1]
160 | width2 = GTframe[2]
161 | height2 = GTframe[3]
162 | 
163 | endx = max(x1 + width1, x2 + width2)
164 | startx = min(x1, x2)
165 | width = width1 + width2 - (endx - startx)
166 | 
167 | endy = max(y1 + height1, y2 + height2)
168 | starty = min(y1, y2)
169 | height = height1 + height2 - (endy - starty)
170 | 
171 | if width <= 0 or height <= 0:
172 | ratio = 0
173 | else:
174 | Area = width * height
175 | Area1 = width1 * height1
176 | Area2 = width2 * height2
177 | ratio = Area * 1. / (Area1 + Area2 - Area)
178 | return ratio
179 | 
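# A quick sanity check of the IOU arithmetic above, as a standalone sketch;
# iou is illustrative and not part of this repository. Two 10x10 boxes
# offset by 5 px in x overlap in a 5x10 strip: inter = 50 and
# union = 100 + 100 - 50 = 150, so IoU = 1/3.
def iou(a, b):  # a, b: (x, y, w, h)
    ix = min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0])  # overlap width
    iy = min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1])  # overlap height
    inter = max(ix, 0) * max(iy, 0)
    return inter / float(a[2] * a[3] + b[2] * b[3] - inter)

assert abs(iou((0, 0, 10, 10), (5, 0, 10, 10)) - 1.0 / 3) < 1e-9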
180 | def getMN(self, m, n):
181 | ans = [[None for i in range(n)] for i in range(m)]
182 | for i in range(m):
183 | Reframe = self.bbx[self.f_step - self.gap][i]
184 | for j in range(n):
185 | GTframe = self.bbx[self.f_step][j]
186 | p = self.IOU(Reframe, GTframe)
187 | # 1 - match, 0 - mismatch
188 | ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0]).to(self.device)
189 | return ans
190 | 
191 | def aggregate(self, set):
192 | if len(set):
193 | rho = sum(set)
194 | return rho / len(set)
195 | print(' The set is empty!')
196 | return None
197 | 
198 | def distance(self, a_bbx, b_bbx):
199 | w = min(float(a_bbx[2]) * tau_dis, float(b_bbx[2]) * tau_dis)
200 | dx = float(a_bbx[0] + a_bbx[2] / 2) - float(b_bbx[0] + b_bbx[2] / 2)
201 | dy = float(a_bbx[1] + a_bbx[3] / 2) - float(b_bbx[1] + b_bbx[3] / 2)
202 | d = sqrt(dx * dx + dy * dy)
203 | if d <= w:
204 | return 0.0
205 | return tau_threshold
206 | 
207 | def getRet(self):
208 | cur = self.f_step - self.gap
209 | ret = [[0.0 for i in range(self.n)] for j in range(self.m)]
210 | for i in range(self.m):
211 | bbx1 = self.bbx[cur][i]
212 | for j in range(self.n):
213 | ret[i][j] = self.distance(bbx1, self.bbx[self.f_step][j])
214 | return ret
215 | 
216 | def getApp(self, tag, index):
217 | cur = self.cur if tag else self.nxt
218 | if torch.is_tensor(index):
219 | n = index.numel()
220 | if n == 0:  # numel() is never negative; guard against an empty index tensor
221 | print('The tensor is empty!')
222 | return None
223 | if n == 1:
224 | return self.detections[cur][index[0]][0]
225 | ans = torch.cat((self.detections[cur][index[0]][0], self.detections[cur][index[1]][0]), dim=0)
226 | for i in range(2, n):
227 | ans = torch.cat((ans, self.detections[cur][index[i]][0]), dim=0)
228 | return ans
229 | return self.detections[cur][index][0]
230 | 
231 | def moveApp(self, index):
232 | self.bbx[self.f_step].append(self.bbx[self.f_step - self.gap][index])  # add the bbx
233 | self.detections[self.nxt].append(self.detections[self.cur][index])  # add the appearance
234 | 
235 | def swapFC(self):
236 | self.cur = self.cur ^ self.nxt
237 | self.nxt = self.cur ^ self.nxt
238 | self.cur = self.cur ^ self.nxt
239 | 
240 | def resnet34(self, img):
241 | bbx = ToTensor()(img)
242 | bbx = bbx.to(self.device)
243 | bbx = bbx.view(-1, bbx.size(0), bbx.size(1), bbx.size(2))
244 | ret = self.Appearance(bbx)
245 | ret = ret.view(1, -1)
246 | return ret
247 | 
248 | def feature(self, tag=0):
249 | '''
250 | Getting the appearance of the detections in the current frame
251 | :param tag: 1 - initiating
252 | :param show: 1 - show the cropped & src image
253 | :return: None
254 | '''
255 | apps = []
256 | with torch.no_grad():
257 | bbx_container = []
258 | for bbx in self.bbx[self.f_step]:
259 | """
260 | The conditions below need to be taken into consideration:
261 | x, y < 0 and x+w > W, y+h > H
262 | """
263 | img = load_img(self.img_dir + '%06d.jpg' % self.f_step)  # load the image of the current frame
264 | if test_gt_det:
265 | x, y, w, h, id, conf_score, vr = bbx
266 | else:
267 | x, y, w, h, conf_score = bbx
268 | x, y, w, h = self.fixBB(x, y, w, h, img.size)
269 | if test_gt_det:
270 | bbx_container.append([x, y, w, h, id, conf_score, vr])
271 | else:
272 | bbx_container.append([x, y, w, h, conf_score])
273 | crop = img.crop([x, y, x + w, y + h])
274 | bbx = crop.resize((224, 224), Image.ANTIALIAS)
275 | ret = self.resnet34(bbx)
276 | app = ret.data
277 | apps.append([app, conf_score])
278 | 
279 | if self.show:
280 | img = np.asarray(img)
281 | crop = np.asarray(crop)
282 | if
test_gt_det: 283 | print('%06d' % self.f_step, id, vr, '***', end=' ') 284 | else: 285 | print('%06d' % self.f_step, conf_score, '***', end=' ') # no vr here: det-based boxes carry no visibility ratio 286 | print(w, h, '-', end=' ') 287 | print(len(crop[0]), len(crop)) 288 | cv2.imshow('crop', crop) 289 | cv2.imshow('view', img) 290 | cv2.waitKey(34) 291 | input('Continue?') 292 | # cv2.waitKey(34) 293 | self.bbx[self.f_step] = bbx_container 294 | if tag: 295 | self.detections[self.cur] = apps 296 | else: 297 | self.detections[self.nxt] = apps 298 | 299 | def loadNext(self): 300 | self.m = len(self.detections[self.cur]) 301 | 302 | self.gap = 0 303 | self.n = 0 304 | while self.n == 0: 305 | self.f_step += 1 306 | self.feature() 307 | self.n = len(self.detections[self.nxt]) 308 | self.gap += 1 309 | 310 | if self.gap > 1: 311 | print(' Empty in loadNext:', self.f_step - self.gap + 1, '-', self.gap - 1) 312 | 313 | self.candidates = [] 314 | self.edges = self.getMN(self.m, self.n) 315 | 316 | for i in range(self.m): 317 | for j in range(self.n): 318 | e = self.edges[i][j] 319 | self.candidates.append([e, i, j]) 320 | 321 | # print ' The index of the next frame', self.f_step, len(self.bbx) 322 | return self.gap 323 | 324 | def __getitem__(self, index): 325 | return self.candidates[index] 326 | 327 | def __len__(self): 328 | return len(self.candidates) 329 | -------------------------------------------------------------------------------- /Motion1/copyfile.py: -------------------------------------------------------------------------------- 1 | import shutil, os 2 | from global_set import mot_dataset_dir 3 | 4 | name = 'motmetrics' 5 | types = ['POI'] 6 | 7 | seqs = [2, 4, 5, 9, 10, 11, 13] # the set of sequences 8 | lengths = [600, 1050, 837, 525, 654, 900, 750] # the length of the sequence 9 | 10 | test_seqs = [1, 3, 6, 7, 8, 12, 14] 11 | test_lengths = [450, 1500, 1194, 500, 625, 900, 750] 12 | 13 | # copy the results for testing sets 14 | for type in types: 15 | for i in range(len(seqs)): 16 | src_dir = 'results/%02d/%d/%s_%s/res.txt' % (test_seqs[i], test_lengths[i], name, type) 17 | 18 | t = type 19 | if type == 'DPM0' or type == 'POI': 20 | t = 'DPM' 21 | 22 | des_d = 'mot16/' 23 | if not os.path.exists(des_d): 24 | os.mkdir(des_d) 25 | des_dir = des_d + 'MOT16-%02d.txt' % (test_seqs[i]) 26 | 27 | print(src_dir) 28 | print(des_dir) 29 | shutil.copyfile(src_dir, des_dir) 30 | 31 | types = ['POI'] 32 | for type in types: 33 | for i in range(len(seqs)): 34 | src_dir = mot_dataset_dir + 'MOT17/train/MOT17-%02d-%s/gt/gt.txt' % (seqs[i], type) 35 | 36 | des_dir = des_d + 'MOT16-%02d.txt' % (seqs[i]) 37 | 38 | print(src_dir) 39 | print(des_dir) 40 | shutil.copyfile(src_dir, des_dir) 41 | -------------------------------------------------------------------------------- /Motion1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import random, torch, shutil, os 3 | from PIL import Image 4 | import torch.nn.functional as F 5 | from global_set import edge_initial, overlap 6 | 7 | 8 | def load_img(filepath): 9 | img = Image.open(filepath).convert('RGB') 10 | return img 11 | 12 | 13 | class DatasetFromFolder(data.Dataset): 14 | def __init__(self, part, cuda=True): 15 | super(DatasetFromFolder, self).__init__() 16 | self.dir = part 17 | self.cleanPath(part) 18 | self.img_dir = part + '/img1/' 19 | self.gt_dir = part + '/gt/' 20 | self.det_dir = part + '/det/' 21 | 22 | self.device = torch.device("cuda" if cuda else "cpu") 23 | 24 | self.getSeqL() 25 | self.readBBx() 26 | self.initBuffer() 27 |
print(' Data loader is ready!') 28 | 29 | def cleanPath(self, part): 30 | if os.path.exists(part + '/gts/'): 31 | shutil.rmtree(part + '/gts/') 32 | if os.path.exists(part + '/dets/'): 33 | shutil.rmtree(part + '/dets/') 34 | 35 | def getSeqL(self): 36 | # get the length of the sequence 37 | info = self.dir + '/seqinfo.ini' 38 | f = open(info, 'r') 39 | f.readline() 40 | for line in f.readlines(): 41 | line = line.strip().split('=') 42 | if line[0] == 'seqLength': 43 | self.seqL = int(line[1]) 44 | f.close() 45 | print(' The length of the sequence:', self.seqL) 46 | 47 | def fixBB(self, x, y, w, h, size): 48 | width, height = size 49 | w = min(w + x, width) 50 | h = min(h + y, height) 51 | x = max(x, 0) 52 | y = max(y, 0) 53 | w -= x 54 | h -= y 55 | return x, y, w, h 56 | 57 | def generator(self, bbx): # randomly translate a gt box; the shifted box keeps an IoU of exactly 'overlap' with the original 58 | if random.randint(0, 1): 59 | x, y, w, h = bbx 60 | tmp = overlap * 2 / (1 + overlap) 61 | n_w = random.uniform(tmp * w, w) 62 | n_h = tmp * w * h / n_w 63 | 64 | direction = random.randint(1, 4) 65 | if direction == 1: 66 | x = x + n_w - w 67 | y = y + n_h - h 68 | elif direction == 2: 69 | x = x - n_w + w 70 | y = y + n_h - h 71 | elif direction == 3: 72 | x = x + n_w - w 73 | y = y - n_h + h 74 | else: 75 | x = x - n_w + w 76 | y = y - n_h + h 77 | ans = [x, y, w, h] 78 | return ans 79 | return bbx 80 | 81 | def readBBx(self): 82 | # get the gt 83 | self.bbx = [[] for i in range(self.seqL + 1)] 84 | imgs = [None for i in range(self.seqL + 1)] 85 | for i in range(1, self.seqL + 1): 86 | imgs[i] = load_img(self.img_dir + '%06d.jpg' % i) 87 | gt = self.gt_dir + 'gt.txt' 88 | f = open(gt, 'r') 89 | for line in f.readlines(): 90 | line = line.strip().split(',') 91 | if line[7] == '1': 92 | index = int(line[0]) 93 | id = int(line[1]) 94 | x, y = float(line[2]), float(line[3]) 95 | w, h = float(line[4]), float(line[5]) 96 | conf_score, l, vr = float(line[6]), int(line[7]), float(line[8]) 97 | 98 | img = imgs[index] 99 | x, y, w, h = self.generator([x, y, w, h]) 100 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 101 | width, height = float(img.size[0]), float(img.size[1]) 102 | # self.bbx[index].append([x, y, w, h, id, vr]) 103 | self.bbx[index].append([x / width, y / height, w / width, h / height, id, vr]) 104 | f.close() 105 | 106 | def initBuffer(self): 107 | self.f_step = 1 # the index of next frame in the process 108 | self.cur = 0 # the index of current frame in the detections 109 | self.nxt = 1 # the index of next frame in the detections 110 | self.detections = [None, None] # the buffer storing the detections of the current & next frame 111 | self.feature(1) 112 | 113 | def IOU(self, Reframe, GTframe): 114 | """ 115 | Compute the Intersection over Union 116 | :param Reframe: x, y, w, h 117 | :param GTframe: x, y, w, h 118 | :return: Ratio 119 | """ 120 | if edge_initial == 1: 121 | return random.random() 122 | elif edge_initial == 3: 123 | return 0.5 124 | x1 = Reframe[0] 125 | y1 = Reframe[1] 126 | width1 = Reframe[2] 127 | height1 = Reframe[3] 128 | 129 | x2 = GTframe[0] 130 | y2 = GTframe[1] 131 | width2 = GTframe[2] 132 | height2 = GTframe[3] 133 | 134 | endx = max(x1 + width1, x2 + width2) 135 | startx = min(x1, x2) 136 | width = width1 + width2 - (endx - startx) 137 | 138 | endy = max(y1 + height1, y2 + height2) 139 | starty = min(y1, y2) 140 | height = height1 + height2 - (endy - starty) 141 | 142 | if width <= 0 or height <= 0: 143 | ratio = 0.0 144 | else: 145 | Area = width * height 146 | Area1 = width1 * height1 147 | Area2 = width2 * height2 148 | ratio = Area * 1.
/ (Area1 + Area2 - Area) 149 | return ratio 150 | 151 | def swapFC(self): 152 | self.getVelocity() 153 | self.cur = self.cur ^ self.nxt # XOR swap: exchanges the 0/1 buffer indices 154 | self.nxt = self.cur ^ self.nxt 155 | self.cur = self.cur ^ self.nxt 156 | 157 | def getMotion(self, tag, index, pre_index=None): 158 | cur = self.cur if tag else self.nxt 159 | if torch.is_tensor(index): 160 | n = index.numel() 161 | if n == 0: # was 'n < 0', which can never be true for numel() 162 | print('The tensor is empty!') 163 | return None 164 | if tag == 0: 165 | for k in range(n): 166 | i, j = pre_index[k].item(), index[k].item() 167 | self.updateVelocity(i, j) 168 | if n == 1: 169 | return self.detections[cur][index[0]][0][pre_index[0]] 170 | ans = torch.cat((self.detections[cur][index[0]][0][pre_index[0]], 171 | self.detections[cur][index[1]][0][pre_index[1]]), dim=0) 172 | for i in range(2, n): 173 | ans = torch.cat((ans, self.detections[cur][index[i]][0][pre_index[i]]), dim=0) 174 | return ans 175 | if n == 1: 176 | return self.detections[cur][index[0]][0][0] 177 | ans = torch.cat((self.detections[cur][index[0]][0][0], 178 | self.detections[cur][index[1]][0][0]), dim=0) 179 | for i in range(2, n): 180 | ans = torch.cat((ans, self.detections[cur][index[i]][0][0]), dim=0) 181 | return ans 182 | if tag == 0: 183 | self.updateVelocity(pre_index, index) 184 | return self.detections[cur][index][0][pre_index] 185 | return self.detections[cur][index][0][0] 186 | 187 | def updateVelocity(self, i, j, tag=True): 188 | ''' 189 | :param i: cur_index, -1 - birth 190 | :param j: nxt_index 191 | :param tag: True - update the velocity in the next frame, False - write down the final velocity 192 | :return: 193 | ''' 194 | v_x, v_y = 0.0, 0.0 195 | if i >= 0: 196 | x1, y1, w1, h1, id1, vr1 = self.bbx[self.f_step - 1][i] 197 | x2, y2, w2, h2, id2, vr2 = self.bbx[self.f_step][j] 198 | v_x = x2 + w2 / 2 - (x1 + w1 / 2) 199 | v_y = y2 + h2 / 2 - (y1 + h1 / 2) 200 | if tag: 201 | self.detections[self.nxt][j][0][i][0][4] = v_x 202 | self.detections[self.nxt][j][0][i][0][5] = v_y 203 | else: 204 | cur_m = self.detections[self.nxt][j][0][0] 205 | cur_m[0][4] = v_x 206 | cur_m[0][5] = v_y 207 | self.detections[self.nxt][j][0] = [cur_m] 208 | 209 | def getVelocity(self): 210 | remaining = set(j for j in range(self.n)) 211 | # For the connection between two detections 212 | for (i, j) in self.matches: 213 | remaining.remove(j) 214 | self.updateVelocity(i, j, False) 215 | 216 | # For the birth of objects 217 | for j in remaining: 218 | self.updateVelocity(-1, j, False) 219 | 220 | def getMN(self): 221 | cur = self.f_step - 1 222 | ans = [[None for j in range(self.n)] for i in range(self.m)] 223 | for i in range(self.m): 224 | Reframe = self.bbx[cur][i] 225 | for j in range(self.n): 226 | GTframe = self.bbx[self.f_step][j] 227 | p = self.IOU(Reframe, GTframe) 228 | # 1 - match, 0 - mismatch 229 | ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0]) 230 | return ans 231 | 232 | def aggregate(self, sets): 233 | n = len(sets) 234 | if n: 235 | rho = sum(sets) 236 | return rho / n 237 | print(' The set is empty!') 238 | return None 239 | 240 | def initEC(self): 241 | self.m = len(self.detections[self.cur]) 242 | self.n = len(self.detections[self.nxt]) 243 | self.edges = self.getMN() 244 | self.candidates = [] 245 | self.matches = [] 246 | self.gts = [[None for j in range(self.n)] for i in range(self.m)] 247 | self.step_gt = 0.0 248 | for i in range(self.m): 249 | for j in range(self.n): 250 | tag = int(self.detections[self.cur][i][1] == self.detections[self.nxt][j][1]) 251 | if tag: 252 | self.matches.append((i, j))
253 | self.gts[i][j] = torch.LongTensor([tag]) 254 | self.step_gt += tag * 1.0 255 | 256 | es = [] 257 | # vs_index = 0 258 | for i in range(self.m): 259 | # vr_index = self.m 260 | for j in range(self.n): 261 | e = self.edges[i][j] 262 | gt = self.gts[i][j] 263 | es.append(e) 264 | self.candidates.append([e, gt, i, j]) 265 | # vr_index += 1 266 | # vs_index += 1 267 | 268 | vs = [] 269 | for i in range(2): 270 | n = len(self.detections[i]) 271 | for j in range(n): 272 | v = self.detections[i][j][0][0] 273 | vs.append(v) 274 | 275 | self.E = self.aggregate(es).to(self.device).view(1, -1) 276 | self.V = self.aggregate(vs) 277 | 278 | def setBuffer(self, f): 279 | self.f_step = f 280 | self.feature(1) 281 | 282 | def feature(self, tag=0): 283 | ''' 284 | Getting the motion of the detections in the current frame 285 | :param tag: 1 - initiating 286 | :return: None 287 | ''' 288 | motions = [] 289 | with torch.no_grad(): 290 | m = 1 if tag else len(self.bbx[self.f_step - 1]) 291 | for bbx in self.bbx[self.f_step]: 292 | """ 293 | Below conditions need to be taken into consideration: 294 | x, y < 0 and x+w > W, y+h > H 295 | """ 296 | x, y, w, h, id, vr = bbx 297 | x += w / 2 298 | y += h / 2 299 | cur_m = [] 300 | for i in range(m): 301 | cur_m.append(torch.FloatTensor([[x, y, w, h, 0.0, 0.0]]).to(self.device)) 302 | motions.append([cur_m, id]) 303 | if tag: 304 | self.detections[self.cur] = motions 305 | else: 306 | self.detections[self.nxt] = motions 307 | 308 | def loadNext(self): 309 | self.f_step += 1 310 | self.feature() 311 | self.initEC() 312 | # print ' The index of the next frame', self.f_step 313 | # print self.detections[self.cur] 314 | # print self.detections[self.nxt] 315 | 316 | def __getitem__(self, index): 317 | return self.candidates[index] 318 | 319 | def __len__(self): 320 | return len(self.candidates) 321 | -------------------------------------------------------------------------------- /Motion1/global_set.py: -------------------------------------------------------------------------------- 1 | #mot_dataset_dir = '/media/codinglee/DATA/Ubuntu16.04/Desktop/MOT/' 2 | mot_dataset_dir = '../MOT/' 3 | 4 | model_dir = 'model/' 5 | 6 | u_initial = 1 # 1 - random, 0 - 0 7 | 8 | edge_initial = 0 # 1 - random, 0 - IoU 9 | 10 | criterion_s = 0 # 1 - MSELoss, 0 - CrossEntropyLoss 11 | 12 | u_evaluation = 0 # 1 - initiate randomly, 0 - initiate with the u learned 13 | 14 | test_gt_det = 0 # 1 - detections of gt, 0 - detections of det 15 | 16 | u_update = 1 # 1 - update when testing, 0 - without updating 17 | if u_update: 18 | u_dir = '_uupdate' 19 | else: 20 | u_dir = '' 21 | 22 | decay = 1.3 23 | decay_dir = '_decay' 24 | 25 | f_gap = 5 26 | if f_gap: 27 | recover_dir = '_Recover' 28 | else: 29 | recover_dir = '_NoRecover' 30 | 31 | tau_threshold = 1.0 # The threshold of the matching cost 32 | tau_dis = 2.0 # the multiple of the current bbx's width used as the distance gate 33 | gap = 25 # max frame number for side connection 34 | 35 | show_recovering = 0 # 1 - append an extra recovering-flag column (11 columns), 0 - the standard 10 columns 36 | 37 | overlap = 0.85 38 | 39 | debug = 1 # 1 - debugging 40 | -------------------------------------------------------------------------------- /Motion1/m_mot_model.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | from global_set import criterion_s 4 | 5 | v_num = 6 # the dimension of the node (motion) feature: [x, y, w, h, v_x, v_y] 6 | u_num = 100 7 | e_num = 1 if criterion_s else 2 8 | 9 | uphi_n = 256 10 | ephi_n = 256 11 | 12 | 13 | class
uphi(nn.Module): 14 | def __init__(self): 15 | super(uphi, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Linear(u_num + v_num + e_num, uphi_n), 18 | nn.LeakyReLU(inplace=True), 19 | nn.Linear(uphi_n, u_num), 20 | ) 21 | 22 | def forward(self, e, v, u): 23 | """ 24 | The network which updates the global variable u 25 | :param e: the aggregation of the probability 26 | :param v: the aggregation of the node 27 | :param u: global variable 28 | """ 29 | # print 'U:', e.size(), v.size(), u.size() 30 | bs = e.size()[0] 31 | if u.size()[0] == 1: 32 | if bs == 1: 33 | tmp = u 34 | else: 35 | tmp = torch.cat((u, u), dim=0) 36 | for i in range(2, bs): 37 | tmp = torch.cat((tmp, u), dim=0) 38 | else: 39 | tmp = u 40 | x = torch.cat((e, v), dim=1) 41 | x = torch.cat((x, tmp), dim=1) 42 | return self.features(x) 43 | 44 | 45 | class ephi(nn.Module): 46 | def __init__(self): 47 | super(ephi, self).__init__() 48 | self.features = nn.Sequential( 49 | nn.Linear(u_num + v_num * 2 + e_num, ephi_n), 50 | nn.LeakyReLU(inplace=True), 51 | nn.Linear(ephi_n, e_num), 52 | ) 53 | 54 | def forward(self, e, v1, v2, u): 55 | """ 56 | The network which updates the probability e 57 | :param e: the probability between two detections 58 | :param v1: the sender 59 | :param v2: the receiver 60 | :param u: global variable 61 | """ 62 | # print 'E:', e.size(), v1.size(), v2.size(), u.size() 63 | bs = e.size()[0] 64 | if u.size()[0] == 1: 65 | if bs == 1: 66 | tmp = u 67 | else: 68 | tmp = torch.cat((u, u), dim=0) 69 | for i in range(2, bs): 70 | tmp = torch.cat((tmp, u), dim=0) 71 | else: 72 | tmp = u 73 | x = torch.cat((e, v1), dim=1) 74 | x = torch.cat((x, v2), dim=1) 75 | x = torch.cat((x, tmp), dim=1) 76 | return self.features(x) 77 | -------------------------------------------------------------------------------- /Motion1/test.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import numpy as np 3 | from munkres import Munkres 4 | import torch.nn.functional as F 5 | import time, os, shutil 6 | from global_set import edge_initial, test_gt_det, \ 7 | tau_threshold, gap, f_gap, show_recovering, decay, u_update, model_dir, mot_dataset_dir 8 | from test_dataset import DatasetFromFolder 9 | from m_mot_model import * 10 | 11 | torch.manual_seed(123) 12 | np.random.seed(123) 13 | 14 | 15 | def deleteDir(del_dir): 16 | shutil.rmtree(del_dir) 17 | 18 | 19 | year = 17 20 | 21 | type = '' 22 | t_dir = '' # the dir of tracking results 23 | sequence_dir = '' # the dir of the testing video sequence 24 | 25 | seqs = [2, 4, 5, 9, 10, 11, 13] # the set of sequences 26 | lengths = [600, 1050, 837, 525, 654, 900, 750] # the length of the sequence 27 | 28 | test_seqs = [1, 3, 6, 7, 8, 12, 14] 29 | test_lengths = [450, 1500, 1194, 500, 625, 900, 750] 30 | 31 | tt_tag = 1 # 1 - test, 0 - train 32 | 33 | tau_conf_score = 0.0 34 | 35 | 36 | class GN(): 37 | def __init__(self, seq_index, seq_len, cuda=True): 38 | ''' 39 | Evaluating with the MotMetrics 40 | :param seq_index: the number of the sequence 41 | :param seq_len: the length of the sequence 42 | :param cuda: True - GPU, False - CPU 43 | ''' 44 | self.bbx_counter = 0 45 | self.seq_index = seq_index 46 | self.hungarian = Munkres() 47 | self.device = torch.device("cuda" if cuda else "cpu") 48 | self.seq_len = seq_len 49 | self.missingCounter = 0 50 | self.sideConnection = 0 51 | 52 | print(' Loading the model...') 53 | self.loadModel() 54 | 55 | self.out_dir = t_dir + 'motmetrics_%s/' % (type) 56 
| print(self.out_dir) 57 | 58 | if not os.path.exists(self.out_dir): 59 | os.mkdir(self.out_dir) 60 | else: 61 | deleteDir(self.out_dir) 62 | os.mkdir(self.out_dir) 63 | self.initOut() 64 | 65 | def initOut(self): 66 | print(' Loading Data...') 67 | self.train_set = DatasetFromFolder(sequence_dir, mot_dataset_dir + 'MOT16/test/MOT16-%02d' % self.seq_index, 68 | tau_conf_score) 69 | 70 | detection_dir = self.out_dir + 'res_det.txt' 71 | res_training = self.out_dir + 'res.txt' # the tracking results 72 | self.createTxt(detection_dir) 73 | self.createTxt(res_training) 74 | self.copyLines(self.seq_index, 1, detection_dir, self.seq_len, 1) 75 | 76 | self.evaluation(1, self.seq_len, detection_dir, res_training) 77 | 78 | def getSeqL(self, info): 79 | # get the length of the sequence 80 | f = open(info, 'r') 81 | f.readline() 82 | for line in f.readlines(): 83 | line = line.strip().split('=') 84 | if line[0] == 'seqLength': 85 | seqL = int(line[1]) 86 | f.close() 87 | return seqL 88 | 89 | def copyLines(self, seq, head, gt_seq, tail=-1, tag=0): 90 | ''' 91 | Copy the ground truth within [head, tail] 92 | :param seq: the number of the sequence 93 | :param head: the head frame number 94 | :param tail: the tail frame number of the clipped sequence 95 | :param gt_seq: the dir of the output file 96 | :return: None 97 | ''' 98 | if tt_tag: 99 | basic_dir = mot_dataset_dir + 'MOT%d/test/MOT%d-%02d-%s/' % (year, year, seq, type) 100 | else: 101 | basic_dir = mot_dataset_dir + 'MOT%d/train/MOT%d-%02d-%s/' % (year, year, seq, type) 102 | print(' Testing on', basic_dir, 'Length:', self.seq_len) 103 | seqL = tail if tail != -1 else self.getSeqL(basic_dir + 'seqinfo.ini') 104 | 105 | det_dir = 'gt/gt_det.txt' if test_gt_det else 'det/det.txt' 106 | seq_dir = basic_dir + ('gt/gt.txt' if tag == 0 else det_dir) 107 | inStream = open(seq_dir, 'r') 108 | 109 | outStream = open(gt_seq, 'w') 110 | for line in inStream.readlines(): 111 | line = line.strip() 112 | attrs = line.split(',') 113 | f_num = int(attrs[0]) 114 | if f_num >= head and f_num <= seqL: 115 | outStream.write(line + '\n') 116 | outStream.close() 117 | 118 | inStream.close() 119 | return seqL 120 | 121 | def createTxt(self, out_file): 122 | f = open(out_file, 'w') 123 | f.close() 124 | 125 | def loadModel(self): 126 | tail = 13 # load the snapshot saved after training through sequence 13, the last training sequence 127 | self.Uphi = torch.load(model_dir + 'uphi_%d.pth' % tail).to(self.device) 128 | self.Ephi = torch.load(model_dir + 'ephi_%d.pth' % tail).to(self.device) 129 | self.u = torch.load(model_dir + 'u_%d.pth' % tail).to(self.device) 130 | 131 | def swapFC(self): 132 | self.cur = self.cur ^ self.nxt # XOR swap: exchanges the 0/1 buffer indices 133 | self.nxt = self.cur ^ self.nxt 134 | self.cur = self.cur ^ self.nxt 135 | 136 | def linearModel(self, out, attr1, attr2): # linearly interpolate the box through the t skipped frames 137 | # print 'I got you!
*.*' 138 | t = attr1[-1] 139 | self.sideConnection += 1 140 | if t > f_gap: 141 | return 142 | frame = int(attr1[0]) 143 | x1, y1, w1, h1 = float(attr1[2]), float(attr1[3]), float(attr1[4]), float(attr1[5]) 144 | x2, y2, w2, h2 = float(attr2[2]), float(attr2[3]), float(attr2[4]), float(attr2[5]) 145 | 146 | x_delta = (x2 - x1) / t 147 | y_delta = (y2 - y1) / t 148 | w_delta = (w2 - w1) / t 149 | h_delta = (h2 - h1) / t 150 | 151 | for i in range(1, t): 152 | frame += 1 153 | x1 += x_delta 154 | y1 += y_delta 155 | w1 += w_delta 156 | h1 += h_delta 157 | attr1[0] = str(frame) 158 | attr1[2] = str(x1) 159 | attr1[3] = str(y1) 160 | attr1[4] = str(w1) 161 | attr1[5] = str(h1) 162 | line = '' 163 | for attr in attr1[:-1]: 164 | line += attr + ',' 165 | if show_recovering: 166 | line += '1' 167 | else: 168 | line = line[:-1] 169 | out.write(line + '\n') 170 | self.bbx_counter += 1 171 | self.missingCounter += t - 1 172 | 173 | def evaluation(self, head, tail, gtFile, outFile): 174 | ''' 175 | Evaluation on detections 176 | :param head: the head frame number 177 | :param tail: the tail frame number 178 | :param gtFile: the ground truth file name 179 | :param outFile: the name of output file 180 | :return: None 181 | ''' 182 | gtIn = open(gtFile, 'r') 183 | self.cur, self.nxt = 0, 1 184 | line_con = [[], []] 185 | id_con = [[], []] 186 | id_step = 1 187 | 188 | step = head + self.train_set.setBuffer(head) 189 | while step < tail: 190 | # print '*********************************' 191 | t_gap = self.train_set.loadNext() 192 | step += t_gap 193 | # print head+step, 'F', 194 | print(step, end=' ') 195 | if step % 100 == 0: 196 | print('') 197 | 198 | # print 'Fo' 199 | m = self.train_set.m 200 | n = self.train_set.n 201 | # print 'm = %d, n = %d'%(m, n) 202 | if n == 0: 203 | print('There is no detection in the rest of the sequence!') 204 | break 205 | 206 | if id_step == 1: 207 | out = open(outFile, 'a') 208 | i = 0 209 | while i < m: 210 | attrs = gtIn.readline().strip().split(',') 211 | if float(attrs[6]) >= tau_conf_score: 212 | attrs.append(1) # the trailing element counts the frames since this target was last matched 213 | attrs[1] = str(id_step) 214 | line = '' 215 | for attr in attrs[:-1]: 216 | line += attr + ',' 217 | if show_recovering: 218 | line += '0' 219 | else: 220 | line = line[:-1] 221 | out.write(line + '\n') 222 | self.bbx_counter += 1 223 | line_con[self.cur].append(attrs) 224 | id_con[self.cur].append(id_step) 225 | id_step += 1 226 | i += 1 227 | out.close() 228 | 229 | i = 0 230 | while i < n: 231 | attrs = gtIn.readline().strip().split(',') 232 | if float(attrs[6]) >= tau_conf_score: 233 | attrs.append(1) # the trailing element counts the frames since this target was last matched 234 | line_con[self.nxt].append(attrs) 235 | id_con[self.nxt].append(-1) 236 | i += 1 237 | 238 | # update the edges 239 | # print 'T', 240 | u_ = self.Uphi(self.train_set.E, self.train_set.V, self.u) 241 | if u_update: 242 | self.u = u_.data 243 | 244 | ret = self.train_set.getRet() 245 | decay_tag = [0 for i in range(m)] 246 | for i in range(m): 247 | for j in range(n): 248 | if ret[i][j] == 0: 249 | decay_tag[i] += 1 250 | 251 | for edge in self.train_set.candidates: 252 | e, vs_index, vr_index = edge 253 | if ret[vs_index][vr_index] == tau_threshold: 254 | continue 255 | e = e.to(self.device).view(1, -1) 256 | v1 = self.train_set.getMotion(1, vs_index) 257 | v2 = self.train_set.getMotion(0, vr_index, vs_index, line_con[self.cur][vs_index][-1]) 258 | e_ = self.Ephi(e, v1, v2, u_) 259 | tmp = F.softmax(e_, dim=1) # the dim argument is required by modern PyTorch; dim=1 keeps the original row-wise behaviour 260 | tmp = tmp.cpu().data.numpy()[0] 261 | 262 | t = line_con[self.cur][vs_index][-1] 263 | # ret[vs_index][vr_index] = float(tmp[0])*pow(decay,
t-1) 264 | if decay_tag[vs_index] > 0: 265 | ret[vs_index][vr_index] = min(float(tmp[0]) * pow(decay, t - 1), 1.0) 266 | else: 267 | ret[vs_index][vr_index] = float(tmp[0]) 268 | 269 | # for j in ret: 270 | # print j 271 | results = self.hungarian.compute(ret) 272 | 273 | out = open(outFile, 'a') 274 | look_up = set(j for j in range(n)) 275 | for (i, j) in results: 276 | # print (i,j) 277 | if ret[i][j] >= tau_threshold: 278 | continue 279 | look_up.remove(j) 280 | self.train_set.updateVelocity(i, j, line_con[self.cur][i][-1], False) 281 | 282 | id = id_con[self.cur][i] 283 | id_con[self.nxt][j] = id 284 | attr1 = line_con[self.cur][i] 285 | attr2 = line_con[self.nxt][j] 286 | # print attrs 287 | attr2[1] = str(id) 288 | if attr1[-1] > 1: 289 | # for the missing detections 290 | self.linearModel(out, attr1, attr2) 291 | line = '' 292 | for attr in attr2[:-1]: 293 | line += attr + ',' 294 | if show_recovering: 295 | line += '0' 296 | else: 297 | line = line[:-1] 298 | out.write(line + '\n') 299 | self.bbx_counter += 1 300 | 301 | for j in look_up: 302 | self.train_set.updateVelocity(-1, j, tag=False) 303 | 304 | for i in range(n): 305 | if id_con[self.nxt][i] == -1: 306 | id_con[self.nxt][i] = id_step 307 | attrs = line_con[self.nxt][i] 308 | attrs[1] = str(id_step) 309 | line = '' 310 | for attr in attrs[:-1]: 311 | line += attr + ',' 312 | if show_recovering: 313 | line += '0' 314 | else: 315 | line = line[:-1] 316 | out.write(line + '\n') 317 | self.bbx_counter += 1 318 | id_step += 1 319 | out.close() 320 | 321 | # For missing & occluded targets 322 | index = 0 323 | for (i, j) in results: 324 | while i != index: 325 | attrs = line_con[self.cur][index] 326 | # print '*', attrs, '*' 327 | if attrs[-1] + t_gap <= gap: 328 | attrs[-1] += t_gap 329 | line_con[self.nxt].append(attrs) 330 | id_con[self.nxt].append(id_con[self.cur][index]) 331 | self.train_set.moveMotion(index) 332 | index += 1 333 | 334 | if ret[i][j] >= tau_threshold: 335 | attrs = line_con[self.cur][index] 336 | # print '*', attrs, '*' 337 | if attrs[-1] + t_gap <= gap: 338 | attrs[-1] += t_gap 339 | line_con[self.nxt].append(attrs) 340 | id_con[self.nxt].append(id_con[self.cur][index]) 341 | self.train_set.moveMotion(index) 342 | index += 1 343 | while index < m: 344 | attrs = line_con[self.cur][index] 345 | # print '*', attrs, '*' 346 | if attrs[-1] + t_gap <= gap: 347 | attrs[-1] += t_gap 348 | line_con[self.nxt].append(attrs) 349 | id_con[self.nxt].append(id_con[self.cur][index]) 350 | self.train_set.moveMotion(index) 351 | index += 1 352 | 353 | # con = self.train_set.cleanEdge() 354 | # for i in range(len(con)-1, -1, -1): 355 | # index = con[i] 356 | # del line_con[self.nxt][index] 357 | # del id_con[self.nxt][index] 358 | 359 | line_con[self.cur] = [] 360 | id_con[self.cur] = [] 361 | # print head+step, results 362 | self.train_set.swapFC() 363 | self.swapFC() 364 | gtIn.close() 365 | print(' The results:', id_step, self.bbx_counter) 366 | 367 | 368 | if __name__ == '__main__': 369 | try: 370 | types = [['POI', 0.7]] 371 | # types = [['DPM', -0.6], ['SDP', 0.5], ['FRCNN', 0.5]] 372 | for t in types: 373 | type, tau_conf_score = t 374 | head = time.time() 375 | 376 | f_dir = 'results/' 377 | if not os.path.exists(f_dir): 378 | os.mkdir(f_dir) 379 | print(f_dir, 'did not exist; created it.') 380 | 381 | for i in range(len(seqs)): 382 | if tt_tag: 383 | seq_index = test_seqs[i] 384 | seq_len = test_lengths[i] 385 | else: 386 | seq_index = seqs[i] 387 | seq_len = lengths[i] 388 | 389 | print('The sequence:', seq_index, '- The length of
the training data:', seq_len) 390 | 391 | s_dir = f_dir + '%02d/' % seq_index 392 | if not os.path.exists(s_dir): 393 | os.mkdir(s_dir) 394 | print(s_dir, 'did not exist; created it.') 395 | 396 | t_dir = s_dir + '%d/' % seq_len 397 | if not os.path.exists(t_dir): 398 | os.mkdir(t_dir) 399 | print(t_dir, 'did not exist; created it.') 400 | 401 | if tt_tag: 402 | seq_dir = 'MOT%d-%02d-%s' % (year, test_seqs[i], type) 403 | sequence_dir = mot_dataset_dir + 'MOT%d/test/' % year + seq_dir 404 | print(' ', sequence_dir) 405 | 406 | start = time.time() 407 | print(' Evaluating Graph Network...') 408 | gn = GN(test_seqs[i], test_lengths[i]) 409 | else: 410 | seq_dir = 'MOT%d-%02d-%s' % (year, seqs[i], type) 411 | sequence_dir = mot_dataset_dir + 'MOT%d/train/' % year + seq_dir 412 | print(' ', sequence_dir) 413 | 414 | start = time.time() 415 | print(' Evaluating Graph Network...') 416 | gn = GN(seqs[i], lengths[i]) 417 | print(' Recovered number of missing detections:', gn.missingCounter) 418 | print(' The number of sideConnections:', gn.sideConnection) 419 | print('Time consuming:', (time.time() - start) / 60.0) 420 | print('Time consuming:', (time.time() - head) / 60.0) 421 | except KeyboardInterrupt: 422 | print('Time consuming:', (time.time() - start) / 60.0) 423 | print('') 424 | print('-' * 90) 425 | print('Exiting from testing early.') 426 | -------------------------------------------------------------------------------- /Motion1/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import random, torch, shutil, os, gc 3 | from math import * 4 | from PIL import Image 5 | import torch.nn.functional as F 6 | from global_set import edge_initial, test_gt_det, tau_dis, tau_threshold 7 | 8 | 9 | def load_img(filepath): 10 | img = Image.open(filepath).convert('RGB') 11 | return img 12 | 13 | 14 | class DatasetFromFolder(data.Dataset): 15 | def __init__(self, part, part_I, tau, cuda=True): 16 | super(DatasetFromFolder, self).__init__() 17 | self.dir = part 18 | self.cleanPath(part) 19 | self.img_dir = part_I + '/img1/' 20 | self.gt_dir = part + '/gt/' 21 | self.det_dir = part + '/det/' 22 | self.device = torch.device("cuda" if cuda else "cpu") 23 | self.tau_conf_score = tau 24 | 25 | self.getSeqL() 26 | if test_gt_det: 27 | self.readBBx_gt() 28 | else: 29 | self.readBBx_det() 30 | self.initBuffer() 31 | 32 | def cleanPath(self, part): 33 | if os.path.exists(part + '/gts/'): 34 | shutil.rmtree(part + '/gts/') 35 | if os.path.exists(part + '/dets/'): 36 | shutil.rmtree(part + '/dets/') 37 | 38 | def getSeqL(self): 39 | # get the length of the sequence 40 | info = self.dir + '/seqinfo.ini' 41 | f = open(info, 'r') 42 | f.readline() 43 | for line in f.readlines(): 44 | line = line.strip().split('=') 45 | if line[0] == 'seqLength': 46 | self.seqL = int(line[1]) 47 | f.close() 48 | # print 'The length of the sequence:', self.seqL 49 | 50 | def fixBB(self, x, y, w, h, size): 51 | width, height = size 52 | w = min(w + x, width) 53 | h = min(h + y, height) 54 | x = max(x, 0) 55 | y = max(y, 0) 56 | w -= x 57 | h -= y 58 | return x, y, w, h 59 | 60 | def readBBx_gt(self): 61 | # get the gt 62 | self.bbx = [[] for i in range(self.seqL + 1)] 63 | bbxs = [[] for i in range(self.seqL + 1)] 64 | imgs = [None for i in range(self.seqL + 1)] 65 | for i in range(1, self.seqL + 1): 66 | imgs[i] = load_img(self.img_dir + '%06d.jpg' % i) 67 | gt = self.gt_dir + 'gt.txt' 68 | f = open(gt, 'r') 69 | pre = -1 70 | for line in f.readlines(): 71 | line =
line.strip().split(',') 72 | if line[7] == '1': 73 | index = int(line[0]) 74 | id = int(line[1]) 75 | x, y = float(line[2]), float(line[3]) 76 | w, h = float(line[4]), float(line[5]) 77 | conf_score, l, vr = float(line[6]), int(line[7]), float(line[8]) 78 | 79 | # sweep the invisible head-bbx from the training data 80 | if pre != id and vr == 0: 81 | continue 82 | 83 | pre = id 84 | img = imgs[index] 85 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 86 | width, height = float(img.size[0]), float(img.size[1]) 87 | self.bbx[index].append([x / width, y / height, w / width, h / height, id, conf_score, vr]) 88 | bbxs[index].append([x, y, w, h, id, conf_score, vr]) 89 | f.close() 90 | 91 | gt_out = open(self.gt_dir + 'gt_det.txt', 'w') 92 | for index in range(1, self.seqL + 1): 93 | for bbx in bbxs[index]: 94 | x, y, w, h, id, conf_score, vr = bbx 95 | gt_out.write('%d,-1,%d,%d,%d,%d,%f,-1,-1,-1\n' % (index, x, y, w, h, conf_score)) # Python 3: write() replaces the Python 2 'print >> gt_out' statement 96 | gt_out.close() 97 | 98 | def readBBx_det(self): 99 | # get the detections 100 | self.bbx = [[] for i in range(self.seqL + 1)] 101 | imgs = [None for i in range(self.seqL + 1)] 102 | for i in range(1, self.seqL + 1): 103 | imgs[i] = load_img(self.img_dir + '%06d.jpg' % i) 104 | det = self.det_dir + 'det.txt' 105 | f = open(det, 'r') 106 | for line in f.readlines(): 107 | line = line.strip().split(',') 108 | index = int(line[0]) 109 | id = int(line[1]) 110 | x, y = float(line[2]), float(line[3]) 111 | w, h = float(line[4]), float(line[5]) 112 | conf_score = float(line[6]) 113 | if conf_score >= self.tau_conf_score: 114 | img = imgs[index] # bugfix: was imgs[i], where i is a stale loop variable left over from preloading the frames 115 | x, y, w, h = self.fixBB(x, y, w, h, img.size) 116 | 117 | width, height = float(img.size[0]), float(img.size[1]) 118 | self.bbx[index].append([x / width, y / height, w / width, h / height, id, conf_score]) 119 | f.close() 120 | 121 | def initBuffer(self): 122 | self.f_step = 1 # the index of next frame in the process 123 | self.cur = 0 # the index of current frame in the detections 124 | self.nxt = 1 # the index of next frame in the detections 125 | self.detections = [None, None] # the buffer storing the detections of the current & next frame 126 | self.feature(1) 127 | 128 | def setBuffer(self, f): 129 | self.m = 0 130 | counter = -1 131 | while self.m == 0: 132 | counter += 1 133 | self.f_step = f + counter 134 | self.feature(1) 135 | self.m = len(self.detections[self.cur]) 136 | if counter > 0: 137 | print(' Empty in setBuffer:', counter) 138 | return counter 139 | 140 | def IOU(self, Reframe, GTframe): 141 | """ 142 | Compute the Intersection over Union 143 | :param Reframe: x, y, w, h 144 | :param GTframe: x, y, w, h 145 | :return: Ratio 146 | """ 147 | if edge_initial == 1: 148 | return random.random() 149 | elif edge_initial == 3: 150 | return 0.5 151 | x1 = Reframe[0] 152 | y1 = Reframe[1] 153 | width1 = Reframe[2] 154 | height1 = Reframe[3] 155 | 156 | x2 = GTframe[0] 157 | y2 = GTframe[1] 158 | width2 = GTframe[2] 159 | height2 = GTframe[3] 160 | 161 | endx = max(x1 + width1, x2 + width2) 162 | startx = min(x1, x2) 163 | width = width1 + width2 - (endx - startx) 164 | 165 | endy = max(y1 + height1, y2 + height2) 166 | starty = min(y1, y2) 167 | height = height1 + height2 - (endy - starty) 168 | 169 | if width <= 0 or height <= 0: 170 | ratio = 0 171 | else: 172 | Area = width * height 173 | Area1 = width1 * height1 174 | Area2 = width2 * height2 175 | ratio = Area * 1.
/ (Area1 + Area2 - Area) 176 | return ratio 177 | 178 | def aggregate(self, sets): 179 | if len(sets): 180 | rho = sum(sets) 181 | return rho / len(sets) 182 | print(' The set is empty!') 183 | return None 184 | 185 | def distance(self, a_bbx, b_bbx): # gate: zero cost if the two centers are within tau_dis times the box width, tau_threshold otherwise 186 | w = min(float(a_bbx[2]) * tau_dis, float(b_bbx[2]) * tau_dis) 187 | dx = float(a_bbx[0] + a_bbx[2] / 2) - float(b_bbx[0] + b_bbx[2] / 2) 188 | dy = float(a_bbx[1] + a_bbx[3] / 2) - float(b_bbx[1] + b_bbx[3] / 2) 189 | d = sqrt(dx * dx + dy * dy) 190 | if d <= w: 191 | return 0.0 192 | return tau_threshold 193 | 194 | def getRet(self): 195 | cur = self.f_step - self.gap 196 | ret = [[0.0 for i in range(self.n)] for j in range(self.m)] 197 | for i in range(self.m): 198 | bbx1 = self.bbx[cur][i] 199 | for j in range(self.n): 200 | ret[i][j] = self.distance(bbx1, self.bbx[self.f_step][j]) 201 | return ret 202 | 203 | def getMotion(self, tag, index, pre_index=None, t=None): 204 | cur = self.cur if tag else self.nxt 205 | if tag == 0: 206 | self.updateVelocity(pre_index, index, t) 207 | return self.detections[cur][index][0][pre_index] 208 | return self.detections[cur][index][0][0] 209 | 210 | def moveMotion(self, index): 211 | self.bbx[self.f_step].append(self.bbx[self.f_step - self.gap][index]) # add the bbx: x, y, w, h, id, conf_score 212 | self.detections[self.nxt].append( 213 | self.detections[self.cur][index]) # add the motion: [[x, y, w, h, v_x, v_y], id] 214 | 215 | def cleanEdge(self): 216 | con = [] 217 | index = 0 218 | for det in self.detections[self.nxt]: 219 | motion, id = det 220 | x = motion[0][0].item() + motion[0][4].item() 221 | y = motion[0][1].item() + motion[0][5].item() 222 | if (x < 0.0 or x > 1.0) or (y < 0.0 or y > 1.0): 223 | con.append(index) 224 | index += 1 225 | 226 | for i in range(len(con) - 1, -1, -1): 227 | index = con[i] 228 | del self.bbx[self.f_step][index] 229 | del self.detections[self.nxt][index] 230 | return con 231 | 232 | def swapFC(self): 233 | self.cur = self.cur ^ self.nxt # XOR swap: exchanges the 0/1 buffer indices 234 | self.nxt = self.cur ^ self.nxt 235 | self.cur = self.cur ^ self.nxt 236 | 237 | def updateVelocity(self, i, j, t=None, tag=True): 238 | v_x = 0.0 239 | v_y = 0.0 240 | if i != -1: 241 | if test_gt_det: 242 | x1, y1, w1, h1, id1, conf_score1, vr1 = self.bbx[self.f_step - self.gap][i] 243 | x2, y2, w2, h2, id2, conf_score2, vr2 = self.bbx[self.f_step][j] 244 | else: 245 | x1, y1, w1, h1, id1, conf_score1 = self.bbx[self.f_step - self.gap][i] 246 | x2, y2, w2, h2, id2, conf_score2 = self.bbx[self.f_step][j] 247 | v_x = (x2 + w2 / 2 - (x1 + w1 / 2)) / t 248 | v_y = (y2 + h2 / 2 - (y1 + h1 / 2)) / t 249 | 250 | if tag: 251 | # print 'm=%d,n=%d; i=%d, j=%d'%(len(self.detections[self.cur]), len(self.detections[self.nxt]), i, j) 252 | self.detections[self.nxt][j][0][i][0][4] = v_x 253 | self.detections[self.nxt][j][0][i][0][5] = v_y 254 | else: 255 | cur_m = self.detections[self.nxt][j][0][0] 256 | cur_m[0][4] = v_x 257 | cur_m[0][5] = v_y 258 | self.detections[self.nxt][j][0] = [cur_m] 259 | 260 | def getMN(self, m, n): 261 | cur = self.f_step - self.gap 262 | ans = [[None for j in range(n)] for i in range(m)] 263 | for i in range(m): 264 | Reframe = self.bbx[cur][i] 265 | for j in range(n): 266 | GTframe = self.bbx[self.f_step][j] 267 | p = self.IOU(Reframe, GTframe) 268 | # 1 - match, 0 - mismatch 269 | ans[i][j] = torch.FloatTensor([(1 - p) / 100.0, p / 100.0]) 270 | return ans 271 | 272 | def feature(self, tag=0): 273 | ''' 274 | Getting the motion of the detections in the current frame 275 | :param tag: 1 - initiating 276 |
(no image is cropped here: only the box geometry is used) 277 | :return: None 278 | ''' 279 | motions = [] 280 | with torch.no_grad(): 281 | m = 1 if tag else self.m 282 | for bbx in self.bbx[self.f_step]: 283 | """ 284 | Below conditions need to be taken into consideration: 285 | x, y < 0 and x+w > W, y+h > H 286 | """ 287 | if test_gt_det: 288 | x, y, w, h, id, conf_score, vr = bbx 289 | else: 290 | x, y, w, h, id, conf_score = bbx 291 | cur_m = [] 292 | for i in range(m): 293 | cur_m.append(torch.FloatTensor([[x, y, w, h, 0.0, 0.0]]).to(self.device)) 294 | motions.append([cur_m, id]) 295 | if tag: 296 | self.detections[self.cur] = motions 297 | else: 298 | self.detections[self.nxt] = motions 299 | 300 | def loadNext(self): 301 | self.m = len(self.detections[self.cur]) 302 | 303 | self.gap = 0 304 | self.n = 0 305 | while self.n == 0: 306 | self.f_step += 1 307 | self.feature() 308 | self.n = len(self.detections[self.nxt]) 309 | self.gap += 1 310 | 311 | if self.gap > 1: 312 | print(' Empty in loadNext:', self.f_step - self.gap + 1, '-', self.gap - 1) 313 | 314 | self.candidates = [] 315 | self.edges = self.getMN(self.m, self.n) 316 | 317 | es = [] 318 | # vs_index = 0 319 | for i in range(self.m): 320 | # vr_index = self.m 321 | for j in range(self.n): 322 | e = self.edges[i][j] 323 | es.append(e) 324 | self.candidates.append([e, i, j]) 325 | # vr_index += 1 326 | # vs_index += 1 327 | 328 | vs = [] 329 | for i in range(2): 330 | n = len(self.detections[i]) 331 | for j in range(n): 332 | v = self.detections[i][j][0][0] 333 | vs.append(v) 334 | 335 | self.E = self.aggregate(es).to(self.device).view(1, -1) 336 | self.V = self.aggregate(vs).to(self.device) 337 | 338 | # print ' The index of the next frame', self.f_step, len(self.bbx) 339 | return self.gap 340 | 341 | def __getitem__(self, index): 342 | return self.candidates[index] 343 | 344 | def __len__(self): 345 | return len(self.candidates) 346 | -------------------------------------------------------------------------------- /Motion1/train.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from dataset import DatasetFromFolder 7 | import time, random, os, shutil 8 | from munkres import Munkres 9 | from global_set import u_initial, mot_dataset_dir, model_dir 10 | from m_mot_model import * 11 | from tensorboardX import SummaryWriter 12 | 13 | torch.manual_seed(123) 14 | np.random.seed(123) 15 | 16 | 17 | def deleteDir(del_dir): 18 | shutil.rmtree(del_dir) 19 | 20 | 21 | class GN(): 22 | def __init__(self, lr=5e-4, batchs=8, cuda=True): 23 | ''' 24 | :param lr: the learning rate 25 | :param batchs: the batch size 26 | :param cuda: True - GPU, False - CPU 27 | 28 | 29 | ''' 30 | # tensors used for evaluation should not track gradients; gradients are enabled only when the network is updated 31 | self.hungarian = Munkres() 32 | self.device = torch.device("cuda" if cuda else "cpu") 33 | self.nEpochs = 999 34 | self.lr = lr 35 | self.batchsize = batchs 36 | self.numWorker = 4 37 | 38 | self.show_process = 0 # interaction 39 | self.step_input = 1 40 | 41 | print(' Preparing the model...') 42 | self.resetU() 43 | 44 | self.Uphi = uphi().to(self.device) 45 | self.Ephi = ephi().to(self.device) 46 | 47 | self.criterion = nn.MSELoss() if criterion_s else nn.CrossEntropyLoss() 48 | self.criterion =
self.criterion.to(self.device) 49 | 50 | self.optimizer = optim.Adam([ 51 | {'params': self.Uphi.parameters()}, 52 | {'params': self.Ephi.parameters()}], 53 | lr=lr) 54 | 55 | self.writer = SummaryWriter() 56 | 57 | seqs = [2, 4, 5, 9, 10, 11, 13] 58 | lengths = [600, 1050, 837, 525, 654, 900, 750] 59 | 60 | for i in range(len(seqs)): 61 | # print ' Loading Data...' 62 | seq = seqs[i] 63 | self.seq_index = seq 64 | start = time.time() 65 | sequence_dir = mot_dataset_dir + 'MOT16/train/MOT16-%02d' % seq 66 | 67 | self.train_set = DatasetFromFolder(sequence_dir) 68 | 69 | self.train_test = lengths[i] 70 | # self.train_test = int(self.train_test * 0.8) # For training the model without the validation set 71 | 72 | self.loss_threshold = 0.03 73 | self.update('seq/%02d' % seq) 74 | 75 | def resetU(self): 76 | if u_initial: 77 | self.u = torch.FloatTensor([random.random() for i in range(u_num)]).view(1, -1) 78 | else: 79 | self.u = torch.FloatTensor([0.0 for i in range(u_num)]).view(1, -1) 80 | self.u = self.u.to(self.device) 81 | 82 | def updateNetwork(self, seqName): 83 | self.train_set.setBuffer(1) 84 | step = 1 85 | loss_step = 0 86 | average_epoch = 0 87 | edge_counter = 0.0 88 | for head in range(1, self.train_test): 89 | self.train_set.loadNext() # Get the next frame 90 | edge_counter += self.train_set.m * self.train_set.n 91 | start = time.time() 92 | # print(' Step -', step) 93 | data_loader = DataLoader(dataset=self.train_set, num_workers=self.numWorker, batch_size=self.batchsize, 94 | shuffle=True) 95 | for epoch in range(1, self.nEpochs): 96 | num = 0 97 | epoch_loss = 0.0 98 | alpha_loss = 0.0 99 | for iteration in enumerate(data_loader, 1): 100 | index, (e, gt, vs_index, vr_index) = iteration 101 | e = e.to(self.device) 102 | gt = gt.to(self.device) 103 | 104 | self.optimizer.zero_grad() 105 | 106 | u_ = self.Uphi(self.train_set.E, self.train_set.V, self.u) 107 | v1 = self.train_set.getMotion(1, vs_index) 108 | v2 = self.train_set.getMotion(0, vr_index, vs_index) 109 | e_ = self.Ephi(e, v1, v2, u_) 110 | 111 | if self.show_process: 112 | print('-' * 66) 113 | print(vs_index, vr_index) 114 | print('e:', e.cpu().data.numpy()[0][0], end=' ') 115 | print('e_:', e_.cpu().data.numpy()[0][0], end=' ') 116 | if criterion_s: 117 | print('GT:', gt.cpu().data.numpy()[0][0]) 118 | else: 119 | print('GT:', gt.cpu().data.numpy()[0]) 120 | 121 | # penalize u so that its values do not grow too large 122 | alpha = torch.mean(torch.abs(u_)) 123 | alpha_loss += alpha.item() 124 | alpha.backward(retain_graph=True) 125 | 126 | # The regular loss 127 | # print e_.size(), e_ 128 | # print gt.size(), gt 129 | loss = self.criterion(e_, gt.squeeze(1)) 130 | epoch_loss += loss.item() 131 | loss.backward() 132 | 133 | # update the network: Uphi and Ephi 134 | self.optimizer.step() 135 | 136 | # Show the parameters of the Uphi and Ephi to check the progress of the optimizer 137 | # print self.Uphi.features[0].weight.data 138 | # print self.Ephi.features[0].weight.data 139 | # raw_input('continue?') 140 | 141 | num += e.size()[0] 142 | 143 | if self.show_process and self.step_input: 144 | a = input('Continue(0-step, 1-run, 2-run with showing)?') 145 | if a == '1': 146 | self.show_process = 0 147 | elif a == '2': 148 | self.step_input = 0 149 | 150 | epoch_loss /= num 151 | 152 | self.writer.add_scalar(seqName, epoch_loss, loss_step) # Show the training loss 153 | loss_step += 1 154 | # print(' Loss of epoch {}: {}.'.format(epoch, epoch_loss)) 155 | if epoch_loss < self.loss_threshold: 156 | break 157 | 158 | # print(' Time
consuming:{}\n\n'.format(time.time() - start)) 159 | self.updateUE() 160 | average_epoch += epoch 161 | step += 1 162 | self.train_set.swapFC() 163 | 164 | def saveModel(self): 165 | print('Saving the Uphi model...') 166 | torch.save(self.Uphi, model_dir + 'uphi_%02d.pth' % self.seq_index) 167 | print('Saving the Ephi model...') 168 | torch.save(self.Ephi, model_dir + 'ephi_%02d.pth' % self.seq_index) 169 | print('Saving the global variable u...') 170 | torch.save(self.u, model_dir + 'u_%02d.pth' % self.seq_index) 171 | print('Done!') 172 | 173 | def updateUE(self): 174 | u_ = self.Uphi(self.train_set.E, self.train_set.V, self.u) 175 | 176 | self.u = u_.data 177 | 178 | # update the edges 179 | for edge in self.train_set: 180 | e, gt, vs_index, vr_index = edge 181 | e = e.to(self.device).view(1, -1) 182 | v1 = self.train_set.getMotion(1, vs_index) 183 | v2 = self.train_set.getMotion(0, vr_index, vs_index) 184 | e_ = self.Ephi(e, v1, v2, u_) 185 | self.train_set.edges[vs_index][vr_index] = e_.data.view(-1) 186 | 187 | def update(self, seqName): 188 | print(' Train the model with the sequence: ', seqName) 189 | self.updateNetwork(seqName) 190 | self.saveModel() 191 | 192 | 193 | if __name__ == '__main__': 194 | try: 195 | # deleteDir(model_dir) 196 | if not os.path.exists(model_dir): 197 | os.mkdir(model_dir) 198 | start = time.time() 199 | print(' Starting Graph Network...') 200 | gn = GN() 201 | print('Time consuming:', time.time() - start) 202 | else: 203 | # deleteDir(model_dir) 204 | # os.mkdir(model_dir) 205 | print('The model directory already exists!') 206 | 207 | except KeyboardInterrupt: 208 | print('Time consuming:', time.time() - start) 209 | print('') 210 | print('-' * 90) 211 | print('Exiting from training early.') 212 | -------------------------------------------------------------------------------- /Pic/Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinizhizhu/GNMOT/f827d850e61385d833270c8e204d7304df44e2b2/Pic/Pipeline.png -------------------------------------------------------------------------------- /Pic/mot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinizhizhu/GNMOT/f827d850e61385d833270c8e204d7304df44e2b2/Pic/mot.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Networks for Multiple Object Tracking (WACV 2020) 2 | 3 | ## Introduction 4 | This is the official code of '**Graph Networks for Multiple Object Tracking**'. 5 | The multiple object tracking (MOT) task requires reasoning about the states of all targets and associating them in a global way. However, existing MOT methods mostly focus on the local relationship among objects and ignore the global relationship. Some methods formulate the MOT problem as a graph optimization problem. However, these methods are based on static graphs, which are seldom updated. To solve these problems, we design a new near-online MOT method with an end-to-end graph network. Specifically, we design an appearance graph network and a motion graph network to capture the appearance and the motion similarity separately. The updating mechanism is carefully designed in our graph network, which means that nodes, edges and the global variable in the graph can be updated. The global variable can capture the global relationship to help tracking.
Finally, a strategy to handle missing detections is proposed to remedy the defects of the detector. Our method is evaluated on both the MOT16 and the MOT17 benchmarks, and the experimental results show the encouraging performance of our method. 6 | 7 | 
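In code, one step of this updating mechanism boils down to two small MLPs: `uphi` updates the global variable from the aggregated edge/node features, and `ephi` rescores every candidate edge between two frames. Below is a minimal sketch with random tensors; the shapes follow the defaults `u_num = 100`, `v_num = 6`, `e_num = 2` defined in `Motion1/m_mot_model.py`, and it assumes it is run from inside `Motion1/` so the import resolves.

```python
import torch
from m_mot_model import uphi, ephi, u_num, v_num, e_num

Uphi, Ephi = uphi(), ephi()

E = torch.rand(1, e_num)   # aggregated (averaged) edge feature of the current graph
V = torch.rand(1, v_num)   # aggregated node (motion) feature
u = torch.rand(1, u_num)   # global variable

u_ = Uphi(E, V, u)         # updated global variable, shape (1, u_num)

e = torch.rand(3, e_num)   # three hypothetical candidate edges between the two frames
v1 = torch.rand(3, v_num)  # sender nodes (detections in the current frame)
v2 = torch.rand(3, v_num)  # receiver nodes (detections in the next frame)

e_ = Ephi(e, v1, v2, u_)   # rescored edges, shape (3, e_num)
```

During training (`Motion1/train.py`) the rescored edges are supervised with a cross-entropy loss against the 0/1 association ground truth; at test time `F.softmax(e_, dim=1)` turns them into matching probabilities.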
8 | ![Pipeline](Pic/Pipeline.png)
9 | 
20 | ![Tracking demo](Pic/mot.gif)
21 | 
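At test time, once the edge probabilities are written into a cost matrix, association reduces to a bipartite assignment solved with the bundled Munkres (Hungarian) implementation, and `tau_threshold` rejects over-priced matches. Below is a minimal sketch of that final step; the 2x3 cost matrix is made up for illustration, whereas in `test.py` the costs come from the softmaxed edge scores.

```python
from munkres import Munkres

tau_threshold = 1.0  # matching-cost threshold, as in Motion1/global_set.py

# hypothetical costs: rows = currently tracked targets, columns = new detections
ret = [[0.05, 0.90, 1.00],
       [0.80, 0.10, 1.00]]

for i, j in Munkres().compute([row[:] for row in ret]):  # defensive copy keeps ret intact
    if ret[i][j] >= tau_threshold:
        continue  # too expensive: the target stays missing and the detection becomes a birth
    print('target %d -> detection %d (cost %.2f)' % (i, j, ret[i][j]))
```

Targets that stay unmatched for at most `gap` frames are carried over to the next step (`moveMotion`), and once a target is re-found, gaps of up to `f_gap` frames are filled in by the linear interpolation in `linearModel`.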