├── Data └── PUT_UNZIPPED_DATASETS_HERE ├── README.md ├── SeqLists ├── list.txt ├── otb-vot14.txt ├── otb-vot15.txt ├── vot13-otb.txt ├── vot14-otb.txt └── vot15-otb.txt └── pretrain ├── load.py ├── network.py └── visualize.py /Data/PUT_UNZIPPED_DATASETS_HERE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorpro/MDNet/6ba896d7f89c22c05e510f94e114c0664dbcba5f/Data/PUT_UNZIPPED_DATASETS_HERE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MDNet 2 | Multi-Domain Convolutional Neural Network Tracker 3 | Depends on Keras and TensorFlow. 4 | 5 | Currently, there is a basic implementation of pretraining. 6 | 7 | Running `python network.py` from the `pretrain` directory pretrains the network. 8 | The trained model is saved to the top-level `Models` directory. 9 | 10 | -------------------------------------------------------------------------------- /SeqLists/list.txt: -------------------------------------------------------------------------------- 1 | ball 2 | basketball 3 | bicycle 4 | bolt 5 | car 6 | david 7 | diving 8 | drunk 9 | fernando 10 | fish1 11 | fish2 12 | gymnastics 13 | hand1 14 | hand2 15 | jogging 16 | motocross 17 | polarbear 18 | skating 19 | sphere 20 | sunshade 21 | surfing 22 | torus 23 | trellis 24 | tunnel 25 | woman 26 | -------------------------------------------------------------------------------- /SeqLists/otb-vot14.txt: -------------------------------------------------------------------------------- 1 | Biker 2 | Bird1 3 | BlurBody 4 | BlurCar2 5 | BlurFace 6 | BlurOwl 7 | Box 8 | Car1 9 | Car4 10 | CarDark 11 | ClifBar 12 | Couple 13 | Crowds 14 | Deer 15 | DragonBaby 16 | Dudek 17 | Football 18 | Freeman4 19 | Girl 20 | Human3 21 | Human4 22 | Human6 23 | Human9 24 | Ironman 25 | Jump 26 | Jumping 27 | Liquor 28 | Matrix 29 | Panda 30 | RedTeam 31 | Shaking 32 | 
Singer2 33 | Skating2-1 34 | Skating2-2 35 | Skiing 36 | Soccer 37 | Surfer 38 | Sylvester 39 | Tiger2 40 | Walking 41 | Walking2 42 | Bird2 43 | BlurCar1 44 | BlurCar3 45 | BlurCar4 46 | Board 47 | Bolt2 48 | Boy 49 | Car2 50 | Car24 51 | Coke 52 | Coupon 53 | Crossing 54 | Dancer 55 | Dancer2 56 | David2 57 | David3 58 | Dog 59 | Dog1 60 | Doll 61 | FaceOcc1 62 | FaceOcc2 63 | Fish 64 | FleetFace 65 | Football1 66 | Freeman1 67 | Freeman3 68 | Girl2 69 | Gym 70 | Human2 71 | Human5 72 | Human7 73 | Human8 74 | KiteSurf 75 | Lemming 76 | Man 77 | Mhyang 78 | MountainBike 79 | Rubik 80 | Singer1 81 | Skater 82 | Skater2 83 | Subway 84 | Suv 85 | Tiger1 86 | Toy 87 | Trans 88 | Twinnings 89 | Vase 90 | 91 | -------------------------------------------------------------------------------- /SeqLists/otb-vot15.txt: -------------------------------------------------------------------------------- 1 | Biker 2 | Bird1 3 | BlurBody 4 | BlurCar2 5 | BlurFace 6 | BlurOwl 7 | Box 8 | Car1 9 | Car4 10 | CarScale 11 | ClifBar 12 | Crowds 13 | David 14 | Deer 15 | Diving 16 | DragonBaby 17 | Dudek 18 | Football 19 | Freeman4 20 | Girl 21 | Human4 22 | Human9 23 | Ironman 24 | Jump 25 | Jumping 26 | Liquor 27 | Panda 28 | RedTeam 29 | Skating1 30 | Skiing 31 | Surfer 32 | Sylvester 33 | Trellis 34 | Walking 35 | Walking2 36 | Woman 37 | Bird2 38 | BlurCar3 39 | BlurCar4 40 | Board 41 | Boy 42 | Car2 43 | Car24 44 | Coke 45 | Coupon 46 | Crossing 47 | Dancer 48 | Dancer2 49 | David2 50 | David3 51 | Dog 52 | Dog1 53 | Doll 54 | FaceOcc1 55 | FaceOcc2 56 | Fish 57 | FleetFace 58 | Football1 59 | Freeman1 60 | Freeman3 61 | Gym 62 | Human2 63 | Human7 64 | Human8 65 | Jogging-1 66 | Jogging-2 67 | KiteSurf 68 | Lemming 69 | Man 70 | Mhyang 71 | MountainBike 72 | Rubik 73 | Skater 74 | Skater2 75 | Subway 76 | Suv 77 | Toy 78 | Trans 79 | Twinnings 80 | Vase 81 | 82 | -------------------------------------------------------------------------------- /SeqLists/vot13-otb.txt: 
-------------------------------------------------------------------------------- 1 | cup 2 | iceskater 3 | juice -------------------------------------------------------------------------------- /SeqLists/vot14-otb.txt: -------------------------------------------------------------------------------- 1 | ball 2 | bicycle 3 | drunk 4 | fish1 5 | hand1 6 | polarbear 7 | sphere 8 | sunshade 9 | surfing 10 | torus 11 | tunnel -------------------------------------------------------------------------------- /SeqLists/vot15-otb.txt: -------------------------------------------------------------------------------- 1 | bag 2 | ball1 3 | ball2 4 | birds1 5 | birds2 6 | blanket 7 | bmx 8 | book 9 | butterfly 10 | crossing 11 | dinosaur 12 | fernando 13 | fish1 14 | fish2 15 | fish3 16 | fish4 17 | glove 18 | godfather 19 | graduate 20 | gymnastics1 21 | gymnastics2 22 | gymnastics3 23 | gymnastics4 24 | hand 25 | handball1 26 | handball2 27 | helicopter 28 | iceskater1 29 | leaves 30 | marching 31 | motocross2 32 | nature 33 | octopus 34 | rabbit 35 | racing 36 | road 37 | sheep 38 | singer3 39 | soccer2 40 | soldier 41 | sphere 42 | traffic 43 | tunnel 44 | wiper 45 | -------------------------------------------------------------------------------- /pretrain/load.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import wget 3 | import wgetter 4 | from zipfile import ZipFile 5 | from os.path import isfile, join 6 | from os import listdir 7 | from time import sleep 8 | from itertools import repeat, count, islice, cycle 9 | from random import random, choice 10 | import numpy as np 11 | import pandas as pd 12 | from skimage.io import imread, imshow, imshow_collection 13 | from skimage.transform import resize 14 | 15 | 16 | data_path = '../data/' 17 | green = [0, 255, 0] 18 | red = [255, 0, 0] 19 | blue = [0, 0, 255] 20 | 21 | 22 | def vot_url(year): 23 | """Gets the url to load the VOT dataset 
for a given year""" 24 | return r"http://box.vicos.si/vot/vot{0}.zip".format(year) 25 | 26 | 27 | def download_vot(years=[2013, 2014, 2015], out_dir="../Data"): 28 | """ 29 | Dowloads VOT for given years and places then into the 30 | datasets directory. 31 | """ 32 | 33 | for year in years: 34 | url = vot_url(year) 35 | print("Downloading VOT {0}".format(year)) 36 | wgetter.download(url, outdir=out_dir) 37 | 38 | 39 | def unzip_files(zip_dir): 40 | """ 41 | Unzips the files from a given directory. 42 | """ 43 | zip_names = [f for f in listdir(zip_dir) if '.zip' in f] 44 | for zn in zip_names: 45 | print('Extracting files for {}'.format(zn)) 46 | zf = ZipFile(join(zip_dir, zn)) 47 | name, extension = zn.split(".") 48 | zf.extractall(join(zip_dir, name)) 49 | 50 | 51 | def load_seq(dataset, name): 52 | """ 53 | Loads sequence called `name` from `dataset`. 54 | """ 55 | fullpath = join(data_path, dataset, name) 56 | bbox = pd.read_csv(join(fullpath, 'groundtruth.txt')) 57 | img_names = [join(fullpath, f) for f in listdir(fullpath) if '.jpg' in f] 58 | imgs = np.array([imread(f) for f in img_names]) 59 | return standardize_bbox(bbox.values), imgs 60 | 61 | 62 | def intbb(bb): 63 | """ 64 | Ensures bb has integer values. 65 | """ 66 | return np.round(bb).astype(int) 67 | 68 | 69 | def standardize_bbox(bbox): 70 | """ 71 | Standardize bbox representation across different years of VOT 72 | VOT2013 had 4 values, and vot 2014/2015 had 8. 
73 | """ 74 | bbox = np.round(bbox).astype(int) 75 | if len(bbox[0]) == 4: 76 | return bbox 77 | elif len(bbox[0]) == 8: 78 | left, right, top, bottom = bbox[:, [0, 6, 3, 1]].T 79 | width = right - left 80 | height = bottom - top 81 | return np.array([left, top, width, height]).T 82 | else: 83 | raise ValueError("bbox must have `len` 4 or 8") 84 | 85 | 86 | def roi_sample(bb, image_size, valid, scale_factor=1.05, trans_range=.1, 87 | scale_range=3, overlap_range=[-np.inf, np.inf], max_tries=4): 88 | """ 89 | Single sample for a given bounding box and image_size 90 | """ 91 | h, w, _ = image_size 92 | left, top, width, height = np.round(bb) 93 | 94 | found = False 95 | min_r, max_r = overlap_range 96 | tries = 0 97 | while not found: 98 | sample = np.array([left + width / 2, top + height / 2, width, height]) 99 | sample[:2] += trans_range * bb[2:] * (np.random.random(2) * 2 - 1) 100 | to_mul = np.power(scale_factor, scale_range * np.random.random()) 101 | sample[2:] = np.multiply( 102 | sample[2:], [scale_factor**(scale_range * random() * 2 - 1)] * 2) 103 | sample[2] = max(5, min(w - 5, sample[2])) 104 | sample[3] = max(5, min(h - 5, sample[3])) 105 | sample[:2] -= sample[2:] / 2 106 | 107 | if valid: 108 | sample[0] = max(1, min(w - sample[2], sample[0])) 109 | sample[1] = max(1, min(h - sample[3], sample[1])) 110 | else: 111 | sample[0] = max(1 - sample[2] / 2, 112 | min(w - sample[2] / 2, sample[0])) 113 | sample[1] = max(1 - sample[3] / 2, 114 | min(h - sample[3] / 2, sample[1])) 115 | r = overlap_ratio(bb, sample) 116 | tries+=1 117 | allpos = np.all(sample>=0) 118 | found = min_r <= r <= max_r and allpos 119 | if tries >= max_tries and allpos: 120 | found = True 121 | return np.round(sample).astype(int) 122 | 123 | 124 | 125 | 126 | def crop(img, bb, outshape=[107, 107]): 127 | """ 128 | Crops the `bb` from `img`, and resizes to `outshape` 129 | """ 130 | left, top, width, height = bb 131 | crop=img[top:top + height, left:left + width] 132 | return 
resize(crop, outshape) 133 | 134 | 135 | def overlap(rect1, rect2): 136 | """ 137 | Returns overlapping area between r1, r2 138 | Both rectangles are represent x1,y1,width,height 139 | """ 140 | l1, t1, w1, h1 = rect1 141 | l2, t2, w2, h2 = rect2 142 | r1, r2 = l1 + w1, l2 + w2 143 | b1, b2 = t1 + h1, t2 + h2 144 | dx = min(r1, r2) - max(l1, l2) 145 | dy = min(b1, b2) - max(t1, t2) 146 | return max(dx * dy, 0) 147 | 148 | 149 | def overlap_ratio(r1, r2): 150 | """ 151 | Returns overlapping ratio between r1, r2 152 | Both rectangles are represent x1,y1,width,height 153 | """ 154 | intersect = overlap(r1, r2) 155 | w1, h1 = r1[2:] 156 | w2, h2 = r2[2:] 157 | return intersect / (w1 * h1 + w2 * h2 - intersect) 158 | 159 | 160 | def load_seqs(dataset, seqsfile, seqsdir='../SeqLists'): 161 | """ 162 | loads the sequences in seqsfile from a given dataset 163 | """ 164 | f = open(join(seqsdir, seqsfile), 'r') 165 | seqnames = f.read().strip().split('\n') 166 | seqs = [zip(*load_seq(dataset, sn)) for sn in seqnames] 167 | return seqs 168 | 169 | 170 | def create_roidb(toload): 171 | """ 172 | Creates RoiDB, where there is a list of (bounding box, img) tuples 173 | for each sequence 174 | """ 175 | roidb = [] 176 | for (dataset, seqsfile) in toload: 177 | roidb.extend(load_seqs(dataset, seqsfile)) 178 | return roidb 179 | 180 | 181 | def generator(roidb, batchsize=128, minisize=8, posprob=1 / 3, show=False): 182 | """ 183 | Helper generator for batch_generator. 
184 | 185 | After yielding `batchsize` elements, it moves on to the next 186 | sequence in roidb 187 | 188 | Samples examples that are positive with probability `posprob` 189 | 190 | If `show` is True, will display bounding box 191 | """ 192 | for (D, seq) in cycle(enumerate(roidb)): 193 | for _ in range(int(batchsize / minisize)): 194 | bb, img = choice(seq) 195 | for _ in range(minisize): 196 | if random() < posprob: 197 | label = 1 198 | smp= roi_sample(bb, img.shape, False, trans_range=.1, 199 | scale_range=5, overlap_range=[.7, 1]) 200 | else: 201 | label=0 202 | smp =roi_sample(bb, img.shape, False, trans_range=2, 203 | scale_range=10, overlap_range=[1e-7, .5]) 204 | if show: 205 | show_bb(img, smp, color = green if label is 1 else red) 206 | val = [0,0] 207 | val[label]=1 208 | yield D, crop(img,smp), label 209 | 210 | 211 | def batch_generator(roidb, batchsize=128, minisize=8, posprob=1 / 3, show=False): 212 | "generates a batch of cropped images" 213 | gen = generator(roidb, batchsize, minisize, 1 / 3, show) 214 | for _ in count(): 215 | batch_vals = zip(*islice(gen, batchsize)) 216 | ad, imgs, labs = batch_vals 217 | yield np.array(imgs), np.array(labs) 218 | # yield [np.array(imgs), np.array(list(ad))], np.array(labs) 219 | 220 | 221 | to_load = [('vot2013','vot13-otb.txt'),('vot2014','vot14-otb.txt')] 222 | 223 | if __name__=='__main__': 224 | download_vot() 225 | unzip_files("../Data") 226 | -------------------------------------------------------------------------------- /pretrain/network.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | import os 3 | from keras.applications.vgg16 import VGG16 4 | from keras.preprocessing import image 5 | from keras.applications.vgg16 import preprocess_input 6 | from keras.models import Model 7 | import numpy as np 8 | from keras.layers import Dense, GlobalAveragePooling2D, Flatten, Input, Lambda, Reshape 9 | from keras.layers.pooling import 
MaxPooling2D 10 | from random import randint 11 | from load import to_load, create_roidb 12 | from load import batch_generator 13 | 14 | ########################################################## 15 | # RoiDB setup 16 | print("Initializing RoiDB...") 17 | roidb = create_roidb(to_load) 18 | print("RoiDB initialization complete!") 19 | 20 | 21 | ########################################################## 22 | # MODEL SETUP 23 | 24 | # Setup input shape assuming Tensorflow backend 25 | print("Setting up models..") 26 | images = Input([107,107,3], name='Images') 27 | 28 | # Use VGG-16 As a base model 29 | base_model = VGG16(weights='imagenet', include_top=False, input_tensor=images) 30 | # for l in base_model.layers: 31 | # l.trainable=False # Disable for these convolutional layers 32 | 33 | # Add the shared fully connected layers 34 | shared = Flatten()(base_model.output) 35 | fc1 = Dense(512, activation='relu')(shared) 36 | fc2 = Dense(512, activation='relu')(fc1) 37 | 38 | D = len(roidb) # The number of domains 39 | logits = [Dense(1, activation='sigmoid')(fc2) for _ in range(D)] 40 | models = [Model(input=base_model.input, output=logit) for logit in logits] 41 | 42 | # Set the optimizer and the loss 43 | print("Compiling models") 44 | [m.compile(optimizer='adam', loss='binary_crossentropy') for m in models] 45 | print("Models are setup!") 46 | 47 | ########################################################## 48 | # TRAINING 49 | 50 | iters = 5 # The number of times each sequence should be trained on 51 | 52 | print("Starting Training") 53 | gen=batch_generator(roidb) 54 | for _ in range(iters): 55 | for (i,m) in enumerate(models): 56 | print("Training on dataset {}".format(i+1)) 57 | X,Y = gen.next() 58 | m.fit(X,Y, nb_epoch=1, verbose=0) 59 | # Will probably use fit_generator (https://keras.io/models/model/#methods) in the actual thing 60 | 61 | ########################################################## 62 | # Save model 63 | model_dir = '../Models' 64 | if not 
os.path.exists(model_dir): 65 | os.makedirs(model_dir) 66 | m.save(join(model_dir, 'mdnet.hd5')) 67 | 68 | # Saves the model for one of the branches 69 | 70 | print("Training complete!") 71 | ########################################################## 72 | -------------------------------------------------------------------------------- /pretrain/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from load import crop, roi_sample, overlap_ratio, intbb 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | 7 | def show_bb(img, bb, color): 8 | """ 9 | Shows image with `bb` drawn on it. 10 | """ 11 | img = img.copy() 12 | plt.imshow(draw_bbox(bb,img,color)) 13 | plt.show() 14 | 15 | def demo_gen(roidb,n=4): 16 | """ 17 | Shows how the generator samples from roidb. 18 | """ 19 | gen = generator(roidb,4,2, show=True) 20 | for _ in range(n): 21 | (D,img), label = gen.next() 22 | plt.imshow(img) 23 | plt.show() 24 | 25 | 26 | def crop_demo(img, bb, numpos=5, numneg=10): 27 | """ 28 | Shows some example crops of positive and negative examples 29 | """ 30 | pos = roi_sample(bb, img.shape, False, trans_range=.1, 31 | scale_range=5, overlap_range=[.7, 1]) 32 | img = img.copy() 33 | for _ in range(numpos): 34 | smp = roi_sample(bb, img.shape, False, trans_range=.1, 35 | scale_range=5, overlap_range=[.7, 1]) 36 | plt.imshow(crop(img, smp)) 37 | plt.show() 38 | 39 | for _ in range(numneg): 40 | smp = roi_sample(bb, img.shape, False, trans_range=2, 41 | scale_range=10, overlap_range=[0, .5]) 42 | plt.imshow(crop(img, smp)) 43 | plt.show() 44 | 45 | def draw_bbox(bbox, img, color=[255, 0, 255]): 46 | """ 47 | Draws bbox on image. 
48 | """ 49 | left, top, width, height = intbb(bbox) 50 | right = min(left + width, img.shape[1] - 1) 51 | bottom = min(top + height, img.shape[0] - 1) 52 | img[top, left:right] = color 53 | img[bottom, left:right] = color 54 | img[top:bottom, left] = color 55 | img[top:bottom, right] = color 56 | return img 57 | 58 | 59 | def track(bboxs, imgs): 60 | """ 61 | Shows bb across sequence. 62 | """ 63 | for bbox, img in zip(bboxs, imgs): 64 | plt.imshow(draw_bbox(bbox, img)) 65 | plt.show() 66 | 67 | 68 | def show_samps(img, bb, valid=True, n=5): 69 | """ 70 | Samples `n` times from img around `bb` 71 | and colors: 72 | 73 | positive examples: green 74 | negative examples: red 75 | discarded examples: black 76 | """ 77 | bb = np.round(bb).astype(int) 78 | plt.imshow(draw_bbox(bb, img)) 79 | img = draw_bbox(bb, img) 80 | plt.show() 81 | for _ in range(n): 82 | smp = roi_sample(bb, img.shape, valid, 83 | overlap_range=(.7, 1)).astype(int) 84 | r = (overlap_ratio(smp, bb)) 85 | if .7 <= r <= 1: 86 | color = green 87 | elif .1 <= r <= .5: 88 | color = red 89 | else: 90 | color = [0] * 3 91 | plt.imshow(draw_bbox(smp, img, color)) 92 | plt.show() 93 | 94 | 95 | def sample_demo(img, bb, numpos=5, numneg=10): 96 | """ 97 | Outputs an image showing the positive and negative 98 | examples 99 | """ 100 | img = img.copy() 101 | output_img = draw_bbox(bb, img) 102 | for _ in range(numpos): 103 | smp = roi_sample(bb, img.shape, False, trans_range=.1, 104 | scale_range=5, overlap_range=[.7, 1]) 105 | ouput_img = draw_bbox(smp, output_img, color=green) 106 | for _ in range(numneg): 107 | smp = roi_sample(bb, img.shape, False, trans_range=2, 108 | scale_range=10, overlap_range=[0, .5]) 109 | ouput_img = draw_bbox(smp, output_img, color=red) 110 | plt.imsave("samples", output_img) 111 | --------------------------------------------------------------------------------