├── .gitignore ├── LICENSE ├── README.md ├── build_tfrecords ├── VID_Info │ └── synsets.txt ├── build_data_vid.py ├── collect_vid_info.py ├── generate_vidb.py ├── process_data.sh └── process_xml.py ├── config.py ├── estimator.py ├── experiment.py ├── feature.py ├── input.py ├── memnet ├── access.py ├── addressing.py ├── memnet.py └── rnn.py ├── model.py └── tracking ├── demo.py └── tracker.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | output/ 3 | .DS_Store 4 | .idea 5 | .svn 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tianyu Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Learning Dynamic Memory Networks for Object Tracking 2 | >*We extend our MemTrack with a Distractor Template Canceling mechanism in our journal version; please check our new method [MemDTC](https://tianyu-yang.com/resources/memdtc.pdf). Code is available at [MemDTC-code](https://github.com/skyoung/MemDTC)* 3 | ### Introduction 4 | This is the Tensorflow implementation of our [**MemTrack**](https://arxiv.org/pdf/1803.07268.pdf) tracker published at ECCV 2018. Detailed comparison results can be found on the author's [webpage](https://tianyu-yang.com). 5 | 6 | ### Prerequisites 7 | 8 | * Python 3.5 or higher 9 | * Tensorflow 1.2.1 or higher 10 | * CUDA 8.0 11 | 12 | ### Path setting 13 | Set a proper `home_path` in `config.py` in order to proceed with the following steps. Make sure that you place the tracking data according to your path setting. 14 | 15 | ### Tracking Demo 16 | You can use our pretrained model to test our tracker first. 17 | 1. Download the model from the link: [GoogleDrive](https://drive.google.com/open?id=1ybywFIzVbflj2-n3gDM9kqEP2-LZ4f3l) 18 | 2. Put the model into the directory `./output/models` 19 | 3. Run `python3 demo.py` in the directory `./tracking` 20 | 21 | ### Training 22 | 1. Download the ILSVRC data from the official website and extract it to the proper place according to the paths in `config.py`. 23 | 2. Then run `sh process_data.sh` in the `./build_tfrecords` directory to convert the ILSVRC data to tfrecords. 24 | 3. 
Run `python3 experiment.py` to train the model. 25 | 26 | ### Citing MemTrack 27 | If you find the code is helpful, please cite 28 | ``` 29 | @inproceedings{Yang2018, 30 | author = {Yang, Tianyu and Chan, Antoni B.}, 31 | booktitle = {ECCV}, 32 | title = {{Learning Dynamic Memory Networks for Object 33 | Tracking}}, 34 | year = {2018} 35 | } 36 | ``` 37 | -------------------------------------------------------------------------------- /build_tfrecords/VID_Info/synsets.txt: -------------------------------------------------------------------------------- 1 | n02691156 1 airplane 2 | n02419796 2 antelope 3 | n02131653 3 bear 4 | n02834778 4 bicycle 5 | n01503061 5 bird 6 | n02924116 6 bus 7 | n02958343 7 car 8 | n02402425 8 cattle 9 | n02084071 9 dog 10 | n02121808 10 domestic_cat 11 | n02503517 11 elephant 12 | n02118333 12 fox 13 | n02510455 13 giant_panda 14 | n02342885 14 hamster 15 | n02374451 15 horse 16 | n02129165 16 lion 17 | n01674464 17 lizard 18 | n02484322 18 monkey 19 | n03790512 19 motorcycle 20 | n02324045 20 rabbit 21 | n02509815 21 red_panda 22 | n02411705 22 sheep 23 | n01726692 23 snake 24 | n02355227 24 squirrel 25 | n02129604 25 tiger 26 | n04468005 26 train 27 | n01662784 27 turtle 28 | n04530566 28 watercraft 29 | n02062744 29 whale 30 | n02391049 30 zebra -------------------------------------------------------------------------------- /build_tfrecords/build_data_vid.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import pickle 9 | from PIL import Image, ImageDraw 10 | import numpy as np 11 | import os 12 | import tensorflow as tf 13 | import threading 14 | from datetime import datetime 15 | import sys 16 | sys.path.append('../') 17 | import config 18 | import time 19 | 20 | class Vid(): 21 | pass 22 | 23 | class EncodeJpeg(): 24 | def __init__(self, sess): 25 | self.img = tf.placeholder(dtype=tf.uint8) 26 | self.img_buffer = tf.image.encode_jpeg(self.img) 27 | self.sess = sess 28 | def get_img_buffer(self, img): 29 | img_buffer = self.sess.run(self.img_buffer, {self.img: img}) 30 | return img_buffer 31 | 32 | def partition_vid(f_vid_list, num_threads): 33 | """To make it evenly distributed""" 34 | f_vid_list = sorted(f_vid_list, key=lambda x: x.n_frame) 35 | n_vid = len(f_vid_list) 36 | part_num = int(n_vid / num_threads) 37 | remain_num = n_vid - part_num * num_threads 38 | 39 | ranges_list = [[] for _ in range(num_threads)] 40 | for i in range(num_threads): 41 | for j in range(part_num): 42 | if j < int(part_num / 2): 43 | ranges_list[i].append(j * num_threads + i + remain_num) 44 | else: 45 | j_rev = part_num - (j - int(part_num / 2)) 46 | ranges_list[i].append(j_rev * num_threads - i - 1 + remain_num) 47 | for i in range(remain_num): 48 | ranges_list[i].append(i) 49 | 50 | # Launch a thread for each spacing. 
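# Note: ranges_list[i] now holds the video indices assigned to thread i. Because f_vid_list is sorted by frame count and indices are dealt out in an alternating pattern, each thread receives a roughly equal total number of frames (printed below); the worker threads themselves are launched later in build_tfrecords().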
51 | print('Launching %d threads' % (num_threads)) 52 | 53 | ranges_list_flat = sum(ranges_list, []) 54 | unique_value = np.unique(np.array(ranges_list_flat)) 55 | assert len(unique_value) == n_vid 56 | frames = [] 57 | for ranges in ranges_list: 58 | frame_sum = 0 59 | for x in ranges: 60 | frame_sum += f_vid_list[x].n_frame 61 | frames.append(frame_sum) 62 | 63 | print('Total frames for each threads\n', frames) 64 | return ranges_list 65 | 66 | def build_tfrecords(vidb_f, data_path, num_threads, encode_jpeg): 67 | vidb = pickle.load(open(vidb_f, 'rb')) 68 | 69 | n_vid = len(vidb) 70 | part_idexes = partition_vid(vidb, num_threads) 71 | 72 | sys.stdout.flush() 73 | coord = tf.train.Coordinator() 74 | 75 | threads = [] 76 | data_name = os.path.basename(data_path) 77 | if not os.path.exists(config.tfrecords_path): 78 | os.makedirs(config.tfrecords_path) 79 | 80 | for i, thread_idxes in enumerate(part_idexes): 81 | output_filename = '%s-%.3d-of-%.3d.tfrecords' % (data_name, i, num_threads) 82 | output_file = os.path.join(config.tfrecords_path, output_filename) 83 | writer = tf.python_io.TFRecordWriter(output_file) 84 | args = (thread_idxes, vidb, data_path, writer, encode_jpeg) 85 | t = threading.Thread(target=process_videos, args=args) 86 | t.start() 87 | threads.append(t) 88 | 89 | # Wait for all the threads to terminate. 90 | coord.join(threads) 91 | print('%s: Finished writing all %d videos in data set which contains %d valid objects in total.' % 92 | (datetime.now(), n_vid, n_valid_objs)) 93 | sys.stdout.flush() 94 | 95 | 96 | def process_videos_area(video_idxes, vidb, data_path, writer, encode_jpeg): 97 | for idx in video_idxes: 98 | one_video = vidb[idx] 99 | img_buffers_ids = [[] for _ in range(config.max_trackid)] 100 | bboxes_ids = [[] for _ in range(config.max_trackid)] 101 | occupy_area_ids = [[] for _ in range(config.max_trackid)] 102 | vid_name = one_video.dir 103 | for objs in one_video.objs: 104 | sum_area = 0 105 | len_area = 0 106 | for j, obj in enumerate(objs): 107 | if obj is not None: 108 | occupy_area = (obj.xmax - obj.xmin) * (obj.ymax - obj.ymin) / (obj.width * obj.height) 109 | sum_area += occupy_area 110 | len_area += 1 111 | roi_patch, bbox = process(obj, data_path) 112 | img_buffer = encode_jpeg.get_img_buffer(roi_patch) 113 | img_buffers_ids[j].append(img_buffer) 114 | bboxes_ids[j].append(bbox.tolist()) 115 | occupy_area_ids[j].append(occupy_area) 116 | 117 | for id, area_id in enumerate(occupy_area_ids): 118 | sum_area = 0 119 | len_area = 0 120 | for area in area_id: 121 | sum_area += area 122 | len_area += 1 123 | if len_area>0: 124 | avg_area = sum_area / len_area 125 | else: 126 | avg_area = 0 127 | 128 | if avg_area > 0.25: 129 | img_buffers_ids[id] = [] 130 | bboxes_ids[id] = [] 131 | 132 | valid_objs = save_to_tfrecords(img_buffers_ids, bboxes_ids, vid_name, writer) 133 | # if is_train: 134 | # valid_objs = save_to_tfrecords(img_buffers_ids, bboxes_ids, vid_name, writer, is_train) 135 | # else: 136 | # valid_objs = save_to_tfrecords_eval(img_buffers_ids, bboxes_ids, vid_name, writer) 137 | lock.acquire() 138 | global n_valid_objs 139 | n_valid_objs += valid_objs 140 | lock.release() 141 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 142 | 'Finish processing video: %s video id: %4d, record %2d valid objects' % ( 143 | vid_name, one_video.id, valid_objs)) 144 | 145 | def process_videos(video_idxes, vidb, data_path, writer, encode_jpeg): 146 | for idx in video_idxes: 147 | one_video = vidb[idx] 148 | img_buffers_ids = [[] for _ in 
range(config.max_trackid)] 149 | bboxes_ids = [[] for _ in range(config.max_trackid)] 150 | vid_name = one_video.dir 151 | for objs in one_video.objs: 152 | for j, obj in enumerate(objs): 153 | if obj is not None: 154 | roi_patch, bbox = process(obj, data_path) 155 | img_buffer = encode_jpeg.get_img_buffer(roi_patch) 156 | img_buffers_ids[j].append(img_buffer) 157 | bboxes_ids[j].append(bbox.tolist()) 158 | valid_objs = save_to_tfrecords(img_buffers_ids, bboxes_ids, vid_name, writer) 159 | # if is_train: 160 | # valid_objs = save_to_tfrecords(img_buffers_ids, bboxes_ids, vid_name, writer, is_train) 161 | # else: 162 | # valid_objs = save_to_tfrecords_eval(img_buffers_ids, bboxes_ids, vid_name, writer) 163 | lock.acquire() 164 | global n_valid_objs 165 | n_valid_objs += valid_objs 166 | lock.release() 167 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 168 | 'Finish processing video: %s video id: %4d, record %2d valid objects' %(vid_name, one_video.id, valid_objs)) 169 | 170 | def process(obj, data_path): 171 | img = Image.open(os.path.join(data_path, obj.img_path)) 172 | avg_color = tuple(np.mean(np.array(img), (0,1)).astype(int)) 173 | 174 | # calculate roi region 175 | bb_c = np.array([(obj.xmin + obj.xmax)/2, (obj.ymin + obj.ymax)/2]) 176 | bb_size = np.array([obj.xmax - obj.xmin + 1, obj.ymax - obj.ymin + 1]) 177 | if config.fix_aspect: 178 | extend_size = bb_size + config.context_amount * (bb_size[0] + bb_size[1]) 179 | z_size = np.sqrt(np.prod(extend_size)) 180 | else: 181 | z_size = bb_size * config.z_scale 182 | z_scale = config.z_exemplar_size / z_size 183 | delta_size = config.patch_size - config.z_exemplar_size 184 | x_size = delta_size / z_scale + z_size 185 | x_sclae = config.patch_size / x_size 186 | roi = np.floor(np.concatenate([bb_c - (x_size - 1) / 2, bb_c + (x_size - 1) / 2], 0)).astype(int) 187 | 188 | bbox = np.array([obj.xmin, obj.ymin, obj.xmax, obj.ymax], dtype=np.float32) 189 | 190 | # calculate the image padding 191 | img_size = np.array([img.width, img.height]) 192 | img_pad_xymin = np.maximum(0, -roi[0:2]) 193 | img_pad_xymax = np.maximum(0, roi[2:4]-img_size+1) 194 | 195 | # if need padding 196 | if np.any(img_pad_xymin) or np.any(img_pad_xymax): 197 | pad_img_size = img_pad_xymax+img_pad_xymin+img_size 198 | img_pad = Image.new(img.mode, tuple(pad_img_size), avg_color) 199 | img_pad.paste(img, tuple(img_pad_xymin)) 200 | 201 | # shift roi coordinate 202 | shift_xy = np.tile(img_pad_xymin, 2) 203 | roi += shift_xy 204 | bbox += shift_xy 205 | else: 206 | img_pad = img 207 | 208 | roi_patch = img_pad.crop(roi) 209 | roi_patch = roi_patch.resize([config.patch_size, config.patch_size]) 210 | 211 | # shift bbox relative to roi 212 | roi_patch_size = np.array([config.patch_size, config.patch_size]) 213 | bb_c_on_roi = (roi_patch_size - 1) / 2 214 | bb_size_on_roi = np.floor(bb_size * x_sclae) 215 | bbox = np.hstack([bb_c_on_roi - (bb_size_on_roi-1)/2, bb_c_on_roi + (bb_size_on_roi-1)/2]) 216 | 217 | # img_draw = ImageDraw.Draw(roi_patch) 218 | # img_draw.rectangle(bbox.tolist(), outline=(255,0,0)) 219 | # roi_patch.show() 220 | return np.array(roi_patch), bbox 221 | 222 | def save_to_tfrecords(img_buffers_ids, bboxes_ids, seq_name, writer): 223 | 224 | valid_objs = 0 225 | for id, (imgs_per_id, bboxes_per_id) in enumerate(zip(img_buffers_ids, bboxes_ids)): 226 | seq_len = len(imgs_per_id) 227 | 228 | if seq_len < config.min_frames: 229 | continue 230 | valid_objs += 1 231 | example = convert_to_example(imgs_per_id, bboxes_per_id, seq_name, seq_len, id) 232 | 
writer.write(example.SerializeToString()) 233 | return valid_objs 234 | 235 | def convert_to_example(img_buffers, bboxes, seq_name, seq_len, trackid): 236 | 237 | context = tf.train.Features(feature={ 238 | 'seq_name': _bytes_feature(seq_name.encode('utf-8')), 239 | 'seq_len': _int64_feature(seq_len), 240 | 'trackid': _int64_feature(trackid) 241 | }) 242 | feature_lists = tf.train.FeatureLists(feature_list={ 243 | 'images': _bytes_feature_list(img_buffers), 244 | 'bboxes': _float_feature_list(bboxes) 245 | }) 246 | example = tf.train.SequenceExample(context=context, feature_lists=feature_lists) 247 | 248 | return example 249 | 250 | def _int64_feature(value): 251 | if not isinstance(value, list): 252 | value = [value] 253 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 254 | 255 | def _float_feature(value): 256 | if not isinstance(value, list): 257 | value = [value] 258 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 259 | 260 | def _bytes_feature(value): 261 | if not isinstance(value, list): 262 | value = [value] 263 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 264 | 265 | def _float_feature_list(values): 266 | return tf.train.FeatureList(feature=[_float_feature(v) for v in values]) 267 | 268 | def _bytes_feature_list(values): 269 | return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values]) 270 | 271 | 272 | if __name__=='__main__': 273 | 274 | lock = threading.Lock() 275 | 276 | config_proto = tf.ConfigProto() 277 | config_proto.gpu_options.allow_growth = True 278 | with tf.Graph().as_default(), tf.Session(config=config_proto) as sess: 279 | encode_jpeg = EncodeJpeg(sess) 280 | global n_valid_objs 281 | n_valid_objs = 0 282 | print('Start building training tfrecords.......') 283 | t_start = time.time() 284 | build_tfrecords(config.vidb_t, config.data_path_t, config.num_threads_t, encode_jpeg) 285 | t_end = time.time() 286 | print('The time for building training tfrecords is %f seconds:' % (t_end - t_start)) 287 | 288 | n_valid_objs = 0 289 | print('Start building validation tfrecords.......') 290 | t_start = time.time() 291 | build_tfrecords(config.vidb_v, config.data_path_v, config.num_threads_v, encode_jpeg) 292 | t_end = time.time() 293 | print('The time for building validation tfrecords is %f seconds:' % (t_end - t_start)) -------------------------------------------------------------------------------- /build_tfrecords/collect_vid_info.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import os 9 | import sys 10 | sys.path.append('../') 11 | import config 12 | import time 13 | 14 | 15 | def collect_video_info(data_path, save_path, is_train): 16 | dirs1 = sorted(os.listdir(data_path)) 17 | 18 | video_id = 0 19 | f_vid_info = open(save_path, 'w') 20 | if is_train: 21 | for dir1 in dirs1: 22 | dirs2 = sorted(os.listdir(os.path.join(data_path, dir1))) 23 | for dir2 in dirs2: 24 | files = os.listdir(os.path.join(data_path, dir1, dir2)) 25 | video_dir = os.path.join(dir1, dir2) 26 | video_id += 1 27 | n_frames = len(files) 28 | f_vid_info.write('%s %d %d\n' %(video_dir, video_id, n_frames)) 29 | else: 30 | for 
dir1 in dirs1: 31 | files = os.listdir(os.path.join(data_path, dir1)) 32 | video_id += 1 33 | n_frames = len(files) 34 | f_vid_info.write('%s %d %d\n' %(dir1, video_id, n_frames)) 35 | 36 | f_vid_info.close() 37 | 38 | 39 | if __name__=='__main__': 40 | t_start = time.time() 41 | collect_video_info(config.data_path_t, config.vid_info_t, True) 42 | t_end = time.time() 43 | print('The time for collecting training information is %f seconds:' % (t_end - t_start)) 44 | 45 | t_start = time.time() 46 | collect_video_info(config.data_path_v, config.vid_info_v, False) 47 | t_end = time.time() 48 | print('The time for collecting validation information is %f seconds:' % (t_end - t_start)) -------------------------------------------------------------------------------- /build_tfrecords/generate_vidb.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import sys 9 | sys.path.append('../') 10 | import os 11 | from build_tfrecords.process_xml import process_xml 12 | import pickle 13 | import time 14 | from datetime import datetime 15 | import config 16 | 17 | class Vid(object): 18 | pass 19 | 20 | def generate_vidb(root_anno_path, vid_info_path, vidb_f): 21 | 22 | f_vid = open(vid_info_path) 23 | vid_info = f_vid.read().split('\n') 24 | if vid_info[-1] == '': 25 | vid_info.pop() 26 | 27 | vidb = [] 28 | object_num = 0 29 | frame_num = 0 30 | for line in vid_info: 31 | frame_info = line.split(' ') 32 | vid_dir = frame_info[0] 33 | vid_id = int(frame_info[1]) 34 | vid_n_frame = int(frame_info[2]) 35 | vid = Vid() 36 | anno_path = os.path.join(root_anno_path, vid_dir) 37 | xml_files = sorted(os.listdir(anno_path)) 38 | objs_one_video = [] 39 | for xml_file in xml_files: 40 | frame_num += 1 41 | bboxes = process_xml(os.path.join(anno_path, xml_file)) 42 | objs_one_frame = config.max_trackid * [None] 43 | for obj in bboxes: 44 | object_num += 1 45 | id = obj.trackid 46 | if id >= config.max_trackid: 47 | print(obj.img_path) 48 | objs_one_frame[id] = obj 49 | objs_one_video.append(objs_one_frame) 50 | vid.objs = objs_one_video 51 | vid.dir = vid_dir 52 | vid.id = vid_id 53 | vid.n_frame = vid_n_frame 54 | vidb.append(vid) 55 | print(datetime.now(), 'Finish video %d' %vid_id) 56 | print('Start pickling the vidb into file') 57 | pickle.dump(vidb, open(vidb_f, 'wb')) 58 | 59 | return object_num, frame_num 60 | if __name__ == '__main__': 61 | 62 | t_start = time.time() 63 | object_num_t, frame_num_t = generate_vidb(config.anno_path_t, config.vid_info_t, config.vidb_t) 64 | t_end = time.time() 65 | print('The time for generating training vidb is %f seconds:' %(t_end - t_start)) 66 | 67 | t_start = time.time() 68 | object_num_v, frame_num_v = generate_vidb(config.anno_path_v, config.vid_info_v, config.vidb_v) 69 | t_end = time.time() 70 | print('The time for generating validation vidb is %f seconds:' %(t_end - t_start)) 71 | 72 | f = open('VID_Info/vid_summary.txt', 'w') 73 | f.write('Train\n object_num: %d frame_num: %d\n\nVal\n object_num: %d frame_num: %d' 74 | % (object_num_t, frame_num_t, object_num_v, frame_num_v)) 75 | --------------------------------------------------------------------------------
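After `generate_vidb.py` has run, the pickled video database can be loaded back for a quick sanity check before building the tfrecords. The snippet below is only an illustrative sketch and not a script shipped with this repo; it assumes the `Vid` attributes (`dir`, `id`, `n_frame`, `objs`) written above, re-declares the empty `Vid` class the same way `build_data_vid.py` does so that unpickling works, and is meant to be run from inside `./build_tfrecords`:
```
# Minimal sketch (not part of the repo): sanity-check the pickled video database
# produced by generate_vidb.py. Run it from inside ./build_tfrecords.
import pickle
import sys
sys.path.append('../')  # so build_tfrecords.process_xml (BoundingBox) can be found when unpickling

class Vid:
    # Empty placeholder matching the Vid class pickled by generate_vidb.py,
    # the same trick used at the top of build_data_vid.py.
    pass

with open('./VID_Info/vidb_train.pk', 'rb') as f:
    vidb = pickle.load(f)

print('videos:', len(vidb))
for vid in vidb[:3]:
    # vid.objs is a per-frame list; each frame holds max_trackid slots,
    # each slot being a BoundingBox or None.
    n_boxes = sum(obj is not None for frame in vid.objs for obj in frame)
    print(vid.id, vid.dir, 'frames:', vid.n_frame, 'boxes:', n_boxes)
```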
/build_tfrecords/process_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | VID_DIR='./VID_Info' 3 | VID_INFO_TRAIN='vid_info_train.txt' 4 | VID_INFO_VAL='vid_info_val.txt' 5 | VIDB_TRAIN='vidb_train.pk' 6 | VIDB_VAL='vidb_val.pk' 7 | 8 | if [ ! -d $VID_DIR ]; then 9 | mkdir $VID_DIR 10 | fi 11 | 12 | if [ ! -f $VID_DIR/$VID_INFO_TRAIN ] || [ ! -f $VID_DIR/$VID_INFO_VAL ]; then 13 | echo 'Start collecting video data information.....' 14 | python3 collect_vid_info.py 15 | fi 16 | 17 | if [ ! -f $VID_DIR/$VIDB_TRAIN ] || [ ! -f $VID_DIR/$VIDB_VAL ]; then 18 | echo 'Start generating video database.....' 19 | python3 generate_vidb.py 20 | fi 21 | 22 | echo 'Start building tfrecords.....' 23 | python3 build_data_vid.py -------------------------------------------------------------------------------- /build_tfrecords/process_xml.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import xml.etree.ElementTree as ET 9 | import sys 10 | import os 11 | 12 | 13 | class BoundingBox(object): 14 | pass 15 | 16 | def get_item(name, root, index=0): 17 | count = 0 18 | for item in root.iter(name): 19 | if count == index: 20 | return item.text 21 | count += 1 22 | # Failed to find "index" occurrence of item. 23 | return -1 24 | 25 | 26 | def get_int(name, root, index=0): 27 | return int(get_item(name, root, index)) 28 | 29 | 30 | def find_num_bb(root): 31 | index = 0 32 | while True: 33 | if get_int('xmin', root, index) == -1: 34 | break 35 | index += 1 36 | return index 37 | 38 | def process_xml(xml_file): 39 | """Process a single XML annotation file, which may contain multiple bounding boxes.""" 40 | try: 41 | tree = ET.parse(xml_file) 42 | except Exception: 43 | print('Failed to parse: ' + xml_file, file=sys.stderr) 44 | return None 45 | 46 | root = tree.getroot() 47 | num_boxes = find_num_bb(root) 48 | boxes = [] 49 | for index in range(num_boxes): 50 | box = BoundingBox() 51 | # Grab the 'index' annotation.
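# Note: get_item/get_int return the index-th occurrence of a tag in the XML tree, so all fields read with the same index below describe the same <object> entry of the annotation file.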
52 | box.xmin = get_int('xmin', root, index) 53 | box.ymin = get_int('ymin', root, index) 54 | box.xmax = get_int('xmax', root, index) 55 | box.ymax = get_int('ymax', root, index) 56 | box.trackid = get_int('trackid', root, index) 57 | 58 | file_name = get_item('filename', root) + '.JPEG' 59 | folder = get_item('folder', root) 60 | 61 | box.width = get_int('width', root) 62 | box.height = get_int('height', root) 63 | box.img_path = os.path.join(folder, file_name) 64 | box.label = get_item('name', root) 65 | boxes.append(box) 66 | 67 | return boxes -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import os 9 | import socket 10 | 11 | #================= data preprocessing ========================== 12 | home_path = '/home/tianyu' 13 | root_path = home_path+'/Data/ILSVRC' 14 | tfrecords_path = home_path+'/Data/ILSVRC-TF' 15 | otb_data_dir = home_path+'/Data/Benchmark/OTB' 16 | data_path_t = os.path.join(root_path, 'Data/VID/train') 17 | data_path_v = os.path.join(root_path, 'Data/VID/val') 18 | anno_path_t = os.path.join(root_path, 'Annotations/VID/train/') 19 | anno_path_v = os.path.join(root_path, 'Annotations/VID/val/') 20 | 21 | vid_info_t = './VID_Info/vid_info_train.txt' 22 | vid_info_v = './VID_Info/vid_info_val.txt' 23 | vidb_t = './VID_Info/vidb_train.pk' 24 | vidb_v = './VID_Info/vidb_val.pk' 25 | 26 | max_trackid = 50 27 | min_frames = 50 28 | 29 | num_threads_t = 16 30 | num_threads_v = 2 31 | 32 | patch_size = 255+2*8 33 | 34 | fix_aspect = True 35 | enlarge_patch = True 36 | if fix_aspect: 37 | context_amount = 0.5 38 | else: 39 | z_scale = 2 40 | 41 | #========================== data input ============================ 42 | min_queue_examples = 500 43 | num_readers = 2 44 | num_preprocess_threads = 8 45 | 46 | z_exemplar_size = 127 47 | x_instance_size = 255 48 | 49 | is_limit_search = False 50 | max_search_range = 200 51 | 52 | is_augment = True 53 | max_strech_x = 0.05 54 | max_translate_x = 4 55 | max_strech_z = 0.1 56 | max_translate_z = 8 57 | 58 | label_type= 0 # 0: overlap: 1 dist 59 | overlap_thres = 0.7 60 | dist_thre = 2 61 | 62 | #========================== Memnet =============================== 63 | hidden_size = 512 64 | memory_size = 8 65 | slot_size = [6, 6, 256] 66 | usage_decay = 0.99 67 | 68 | clip_gradients = 20.0 69 | keep_prob = 0.8 70 | weight_decay = 0.0001 71 | use_attention_read = False 72 | use_fc_key = False 73 | key_dim = 256 74 | 75 | 76 | #========================== train ================================= 77 | batch_size = 8 78 | time_step = 16 79 | 80 | decay_circles = 10000 81 | lr_decay = 0.8 82 | learning_rate = 0.0001 83 | use_focal_loss = False 84 | 85 | summaries_dir = 'output/summary/' 86 | checkpoint_dir = 'output/models/' 87 | 88 | summary_save_step = 500 89 | model_save_step = 5000 90 | validate_step = 5000 91 | max_iterations = 100000 92 | summary_display_step = 8 93 | 94 | #========================== evaluation ================================== 95 | batch_size_eval = 2 96 | time_step_eval = 48 97 | num_example_epoch_eval = 1073 98 | 
max_iterations_eval = num_example_epoch_eval//batch_size_eval 99 | 100 | #========================== tracking ==================================== 101 | num_scale = 3 102 | scale_multipler = 1.05 103 | scale_penalty = 0.97 104 | scale_damp = 0.6 105 | 106 | response_up = 16 107 | response_size = 17 108 | window = 'cosine' 109 | win_weights = 0.15 110 | stride = 8 111 | avg_num = 1 112 | 113 | is_save = False 114 | save_path = './tracking/snapshots' 115 | -------------------------------------------------------------------------------- /estimator.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import tensorflow as tf 9 | from model import ModeKeys 10 | import config 11 | import time 12 | import os 13 | 14 | class Estimator(): 15 | 16 | def __init__(self, train_input_fn, eval_input_fn, model_fn): 17 | 18 | self._train_input_fn = train_input_fn 19 | self._eval_input_fn = eval_input_fn 20 | self._model_fn = model_fn 21 | tf.set_random_seed(1234) 22 | self._max_patience = 10 * config.validate_step 23 | self._best_value = None 24 | self._best_step = None 25 | self.build_eval() 26 | 27 | def train(self): 28 | config_proto = tf.ConfigProto() 29 | config_proto.gpu_options.allow_growth = True 30 | with tf.Graph().as_default(), tf.Session(config=config_proto) as sess: 31 | features, labels = self._train_input_fn() 32 | train_spec = self._model_fn(features, labels, ModeKeys.TRAIN) 33 | summary_writer = tf.summary.FileWriter(config.summaries_dir + 'train', sess.graph) 34 | 35 | global_step = tf.train.get_or_create_global_step() 36 | sess.run(tf.global_variables_initializer()) 37 | coord = tf.train.Coordinator() 38 | enqueue_threads = tf.train.start_queue_runners(sess, coord=coord) 39 | 40 | idx = sess.run(global_step) + 1 41 | while not coord.should_stop() and idx <= config.max_iterations: 42 | start_time = time.time() 43 | 44 | if idx % config.summary_save_step == 0: 45 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 46 | run_metadata = tf.RunMetadata() 47 | summary, dist_error, loss, _ = sess.run( 48 | [train_spec.summary, train_spec.dist_error, train_spec.loss, train_spec.train], 49 | options=run_options, 50 | run_metadata=run_metadata) 51 | 52 | summary_writer.add_run_metadata(run_metadata, 'step%03d' % idx) 53 | summary_writer.add_summary(summary, idx) 54 | print('Adding run metadata for', idx) 55 | else: 56 | dist_error, loss, _ = sess.run([train_spec.dist_error, train_spec.loss, train_spec.train]) 57 | 58 | print("Step: %d Loss: %f, Dist error: %f Speed: %.0f examples per second" % 59 | (idx, loss, dist_error, config.batch_size * config.time_step / (time.time() - start_time))) 60 | 61 | if idx % config.model_save_step == 0 or idx == config.max_iterations or idx % config.validate_step == 0: 62 | checkpoint_path = os.path.join(config.checkpoint_dir, 'model.ckpt') 63 | train_spec.saver.save(sess, checkpoint_path, global_step=idx, write_meta_graph=False) 64 | print('Save to checkpoint at step %d' % (idx)) 65 | 66 | if idx % config.validate_step == 0: 67 | if self.evaluate(idx, 'loss'): 68 | coord.request_stop() 69 | 70 | idx = sess.run(tf.train.get_or_create_global_step()) + 1 71 
| 72 | summary_writer.close() 73 | coord.join(enqueue_threads) 74 | 75 | def build_eval(self): 76 | 77 | with tf.Graph().as_default() as graph: 78 | features, labels = self._eval_input_fn() 79 | self._eval_spec = self._model_fn(features, labels, ModeKeys.EVAL) 80 | self._eval_summary_writer = tf.summary.FileWriter(config.summaries_dir + 'eval', graph) 81 | self._eval_graph = graph 82 | 83 | def evaluate(self, global_step, stop_metric='loss'): 84 | 85 | config_proto = tf.ConfigProto() 86 | config_proto.gpu_options.allow_growth = True 87 | with self._eval_graph.as_default(), tf.Session(config=config_proto) as sess: 88 | ckpt = tf.train.get_checkpoint_state(config.checkpoint_dir) 89 | if ckpt and ckpt.model_checkpoint_path: 90 | self._eval_spec.saver.restore(sess, ckpt.model_checkpoint_path) 91 | print('Checkpoint restored from %s' % (config.checkpoint_dir)) 92 | 93 | coord = tf.train.Coordinator() 94 | enqueue_threads = tf.train.start_queue_runners(sess, coord=coord) 95 | 96 | totoal_dist_error = 0 97 | totoal_loss = 0 98 | i = 0 99 | print('Starting validate current network......') 100 | while i < config.max_iterations_eval: 101 | dist_error, loss = sess.run([self._eval_spec.dist_error, self._eval_spec.loss]) 102 | totoal_dist_error += dist_error 103 | totoal_loss += loss 104 | i += 1 105 | print('Examples %d dist error: %f loss: %f' % (i, dist_error, loss)) 106 | 107 | coord.request_stop() 108 | coord.join(enqueue_threads) 109 | avg_dist_error = totoal_dist_error / config.max_iterations_eval 110 | avg_loss = totoal_loss / config.max_iterations_eval 111 | print('val_dist_error: %f' % (avg_dist_error)) 112 | print('val_loss: %f' % (avg_loss)) 113 | 114 | summary = tf.Summary() 115 | # summary.ParseFromString(sess.run(self._eval_spec.summary)) 116 | summary.value.add(tag='dist_error', simple_value=avg_dist_error) 117 | summary.value.add(tag='loss', simple_value=avg_loss) 118 | self._eval_summary_writer.add_summary(summary, global_step) 119 | 120 | coord.request_stop() 121 | coord.join(enqueue_threads) 122 | 123 | if stop_metric == 'loss': 124 | value = avg_loss 125 | elif stop_metric == 'dist_error': 126 | value = avg_dist_error 127 | else: 128 | value = avg_dist_error 129 | 130 | if (self._best_value is None) or \ 131 | (value < self._best_value): 132 | self._best_value = value 133 | self._best_step = global_step 134 | 135 | should_stop = (global_step - self._best_step >= self._max_patience) 136 | if should_stop: 137 | print('Stopping... Best step: {} with {} = {}.' 
\ 138 | .format(self._best_step, stop_metric, self._best_value)) 139 | return should_stop 140 | 141 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import tensorflow as tf 9 | from input import generate_input_fn 10 | from model import model_fn 11 | from estimator import Estimator 12 | import config 13 | 14 | 15 | def experiment(): 16 | 17 | train_input_fn = generate_input_fn( 18 | is_train=True, 19 | tfrecords_path=config.tfrecords_path, 20 | batch_size=config.batch_size, 21 | time_step=config.time_step) 22 | 23 | eval_input_fn = generate_input_fn( 24 | is_train=False, 25 | tfrecords_path=config.tfrecords_path, 26 | batch_size=config.batch_size_eval, 27 | time_step=config.time_step_eval) 28 | 29 | estimator = Estimator( 30 | train_input_fn=train_input_fn, 31 | eval_input_fn=eval_input_fn, 32 | model_fn=model_fn) 33 | 34 | estimator.train() 35 | 36 | 37 | if __name__ == '__main__': 38 | if tf.gfile.Exists(config.summaries_dir): 39 | tf.gfile.DeleteRecursively(config.summaries_dir) 40 | tf.gfile.MakeDirs(config.summaries_dir) 41 | 42 | experiment() -------------------------------------------------------------------------------- /feature.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import tensorflow as tf 9 | import config 10 | 11 | def extract_feature(is_train, img_patch): 12 | 13 | conv1 = conv2d_bn_relu(is_train, img_patch, 96, [11, 11], [2, 2], 'valid', name='conv1') 14 | pool1 = tf.layers.max_pooling2d(conv1, [3, 3], [2, 2], 'valid', name='pool1') 15 | conv2 = conv2d_bn_relu(is_train, pool1, 256, [5, 5], [1, 1], 'valid', name='conv2') 16 | pool2 = tf.layers.max_pooling2d(conv2, [3, 3], [2, 2], 'valid', name='pool2') 17 | conv3 = conv2d_bn_relu(is_train, pool2, 384, [3, 3], [1, 1], 'valid', name='conv3') 18 | conv4 = conv2d_bn_relu(is_train, conv3, 384, [3, 3], [1, 1], 'valid', name='conv4') 19 | conv5 = tf.layers.conv2d(conv4, 256, [3, 3], [1, 1], 'valid', name='conv5') 20 | 21 | return conv5 22 | 23 | def conv2d(input, filters, kernel_size, strides, padding, name, group=1): 24 | 25 | if group == 1: 26 | conv = tf.layers.conv2d(input, filters, kernel_size, strides, padding, name=name) 27 | else: 28 | input_group = tf.split(input, group, 3) 29 | conv_group = [tf.layers.conv2d(input, filters//group, kernel_size, strides, padding, name=name+'group_{}'.format(i)) 30 | for i, input in enumerate(input_group)] 31 | conv = tf.concat(conv_group, 3) 32 | return conv 33 | 34 | def conv2d_bn_relu(is_train, input, filters, kernel_size, strides, padding, name, group=1): 35 | 36 | conv = conv2d(input, filters, kernel_size, strides, padding, name, group) 37 | bn = tf.layers.batch_normalization(conv, 
training=is_train, name=name+'_bn') 38 | return tf.nn.relu(bn, name=name+'_relu') 39 | 40 | def bn_relu_conv2d(is_train, input, filters, kernel_size, strides, padding, name): 41 | 42 | bn = tf.layers.batch_normalization(input, training=is_train, name=name+'_bn') 43 | relu = tf.nn.relu(bn, name=name+'_relu') 44 | return tf.layers.conv2d(relu, filters, kernel_size, strides, padding, name=name) 45 | 46 | def get_key_feature(input, is_train, name): 47 | 48 | input_shape = input.get_shape().as_list() 49 | if len(input_shape) > 4: 50 | input = tf.reshape(input, [-1] + input_shape[2:]) 51 | 52 | if config.use_fc_key: 53 | contrloller_input = bn_relu_conv2d(is_train, input, config.key_dim, config.slot_size[0:2], [1, 1], 'valid', name=name) 54 | else: 55 | contrloller_input = tf.layers.average_pooling2d(input, config.slot_size[0:2], [1, 1], 'valid', name=name) 56 | 57 | if len(input_shape) > 4: 58 | c_shape = contrloller_input.get_shape().as_list() 59 | contrloller_input = tf.reshape(contrloller_input, input_shape[0:2]+c_shape[1:]) 60 | 61 | return contrloller_input 62 | -------------------------------------------------------------------------------- /input.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import glob 9 | import os 10 | import time 11 | import numpy as np 12 | import tensorflow as tf 13 | import config 14 | 15 | DEBUG = False 16 | 17 | def generate_input_fn(is_train, tfrecords_path, batch_size, time_step): 18 | "Return _input_fn for use with Experiment." 
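# Note: _input_fn is a closure over is_train, tfrecords_path, batch_size and time_step; the Estimator calls it inside its own tf.Graph/Session, so the readers and queue runners are created in that graph.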
19 | def _input_fn(): 20 | 21 | with tf.device('/cpu:0'): 22 | query_patch, search_patch, bbox = _batch_input(is_train, tfrecords_path, batch_size, time_step) 23 | 24 | patches = { 25 | 'query': query_patch, 26 | 'search': search_patch, 27 | } 28 | return patches, bbox 29 | 30 | return _input_fn 31 | 32 | def _batch_input(is_train, tfrecords_path, batch_size, time_step): 33 | 34 | if is_train: 35 | tf_files = glob.glob(os.path.join(tfrecords_path, 'train-*.tfrecords')) 36 | filename_queue = tf.train.string_input_producer(tf_files, shuffle=True, capacity=16) 37 | 38 | min_queue_examples = config.min_queue_examples 39 | examples_queue = tf.RandomShuffleQueue( 40 | capacity=min_queue_examples + 3 * batch_size, 41 | min_after_dequeue=min_queue_examples, 42 | dtypes=[tf.string]) 43 | enqueue_ops = [] 44 | for _ in range(config.num_readers): 45 | _, value = tf.TFRecordReader().read(filename_queue) 46 | enqueue_ops.append(examples_queue.enqueue([value])) 47 | 48 | tf.train.add_queue_runner( 49 | tf.train.QueueRunner(examples_queue, enqueue_ops)) 50 | example_serialized = examples_queue.dequeue() 51 | else: 52 | tf_files = sorted(glob.glob(os.path.join(tfrecords_path, 'val-*.tfrecords'))) 53 | filename_queue = tf.train.string_input_producer(tf_files, shuffle=False, capacity=8) 54 | _, example_serialized = tf.TFRecordReader().read(filename_queue) 55 | # example_serialized = next(tf.python_io.tf_record_iterator(self._tf_files[0])) 56 | images_and_labels = [] 57 | for thread_id in range(config.num_preprocess_threads): 58 | sequence, context = _parse_example_proto(example_serialized) 59 | image_buffers = sequence['images'] 60 | bboxes = sequence['bboxes'] 61 | seq_len = tf.cast(context['seq_len'][0], tf.int32) 62 | z_exemplars, x_crops, y_crops = _process_images(image_buffers, bboxes, seq_len, thread_id, time_step, is_train) 63 | images_and_labels.append([z_exemplars, x_crops, y_crops]) 64 | 65 | batch_z, batch_x, batch_y = tf.train.batch_join(images_and_labels, 66 | batch_size=batch_size, 67 | capacity=2 * config.num_preprocess_threads * batch_size) 68 | if is_train: 69 | tf.summary.image('exemplars', batch_z[0], 5) 70 | tf.summary.image('crops', batch_x[0], 5) 71 | 72 | return batch_z, batch_x, batch_y 73 | 74 | def _process_images(image_buffers, bboxes, seq_len, thread_id, time_step, is_train): 75 | if config.is_limit_search: 76 | search_range = tf.minimum(config.max_search_range, seq_len - 1) 77 | else: 78 | search_range = seq_len-1 79 | rand_start_idx = tf.random_uniform([], 0, seq_len-search_range, dtype=tf.int32) 80 | selected_len = time_step + 1 81 | if is_train: 82 | frame_idxes = tf.range(rand_start_idx, rand_start_idx+search_range) 83 | shuffle_idxes = tf.random_shuffle(frame_idxes) 84 | selected_idxes = shuffle_idxes[0:selected_len] 85 | selected_idxes, _ = tf.nn.top_k(selected_idxes, selected_len) 86 | selected_idxes = selected_idxes[::-1] 87 | else: 88 | selected_idxes = tf.to_int32(tf.linspace(0.0, tf.to_float(seq_len - 1), selected_len)) 89 | # self.seq_len = seq_len 90 | # self.search_range = search_range 91 | # self.selected_idxes = selected_idxes 92 | z_exemplars, y_exemplars, x_crops, y_crops = [], [], [], [] 93 | shift = int((config.patch_size - config.z_exemplar_size) / 2) 94 | for i in range(selected_len): 95 | idx = selected_idxes[i] 96 | image_buffer = tf.gather(image_buffers, idx) 97 | image = tf.image.decode_jpeg(image_buffer, channels=3) 98 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 99 | image.set_shape([config.patch_size, config.patch_size, 3]) 100 | 
101 | # # Randomly distort the colors. 102 | # if is_train: 103 | # image = _distort_color(image, thread_id) 104 | 105 | if i < time_step: 106 | # if self._is_train: 107 | exemplar = tf.image.crop_to_bounding_box(image, shift, shift, config.z_exemplar_size, 108 | config.z_exemplar_size) 109 | if config.is_augment and i > 0: 110 | exemplar = _translate_and_strech(image, 111 | [config.z_exemplar_size, config.z_exemplar_size], 112 | config.max_strech_z, config.max_translate_z) 113 | z_exemplars.append(exemplar) 114 | if i > 0: 115 | bbox = tf.gather(bboxes, idx) 116 | if config.is_augment: 117 | image, bbox = _translate_and_strech(image, [config.x_instance_size, config.x_instance_size], 118 | config.max_strech_x, config.max_translate_x, bbox) 119 | x_crops.append(image) 120 | y_crops.append(bbox) 121 | x_crops = tf.stack(x_crops, 0) 122 | y_crops = tf.stack(y_crops, 0) 123 | z_exemplars = tf.stack(z_exemplars, 0) 124 | return z_exemplars, x_crops, y_crops 125 | 126 | def _translate_and_strech(image, m_sz, max_strech, max_translate=None, bbox=None, rgb_variance=None): 127 | 128 | m_sz_f = tf.convert_to_tensor(m_sz, dtype=tf.float32) 129 | img_sz = tf.convert_to_tensor(image.get_shape().as_list()[0:2],dtype=tf.float32) 130 | scale = 1+max_strech*tf.random_uniform([2], -1, 1, dtype=tf.float32) 131 | scale_sz = tf.round(tf.minimum(scale*m_sz_f, img_sz)) 132 | 133 | if max_translate is None: 134 | shift_range = (img_sz - scale_sz) / 2 135 | else: 136 | shift_range = tf.minimum(float(max_translate), (img_sz-scale_sz)/2) 137 | 138 | start = (img_sz - scale_sz)/2 139 | shift_row = start[0] + tf.random_uniform([1], -shift_range[0], shift_range[0], dtype=tf.float32) 140 | shift_col = start[1] + tf.random_uniform([1], -shift_range[1], shift_range[1], dtype=tf.float32) 141 | 142 | x1 = shift_col/(img_sz[1]-1) 143 | y1 = shift_row/(img_sz[0]-1) 144 | x2 = (shift_col + scale_sz[1]-1)/(img_sz[1]-1) 145 | y2 = (shift_row + scale_sz[0]-1)/(img_sz[0]-1) 146 | crop_img = tf.image.crop_and_resize(tf.expand_dims(image,0), 147 | tf.expand_dims(tf.concat(axis=0, values=[y1, x1, y2, x2]), 0), 148 | [0], m_sz) 149 | crop_img = tf.squeeze(crop_img) 150 | if rgb_variance is not None: 151 | crop_img = crop_img + rgb_variance*tf.random_normal([1,1,3]) 152 | 153 | if bbox is not None: 154 | new_bbox = bbox - tf.concat(axis=0, values=[shift_col, shift_row, shift_col, shift_row]) 155 | scale_ratio = m_sz_f/tf.reverse(scale_sz, [0]) 156 | new_bbox = new_bbox*tf.tile(scale_ratio,[2]) 157 | return crop_img, new_bbox 158 | else: 159 | return crop_img 160 | 161 | def _distort_color(image, thread_id=0): 162 | """Distort the color of the image. 163 | """ 164 | color_ordering = thread_id % 2 165 | 166 | if color_ordering == 0: 167 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 168 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 169 | image = tf.image.random_hue(image, max_delta=0.2) 170 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 171 | elif color_ordering == 1: 172 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 173 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 174 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 175 | image = tf.image.random_hue(image, max_delta=0.2) 176 | 177 | # The random_* ops do not necessarily clamp. 
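# Clip back to the valid [0, 1] range before returning the distorted image.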
178 | image = tf.clip_by_value(image, 0.0, 1.0) 179 | return image 180 | 181 | def _parse_example_proto(example_serialized): 182 | 183 | context_features = { 184 | 'seq_name': tf.FixedLenFeature([], dtype=tf.string), 185 | 'seq_len': tf.FixedLenFeature(1, dtype=tf.int64), 186 | 'trackid': tf.FixedLenFeature(1, dtype=tf.int64), 187 | } 188 | sequence_features = { 189 | 'images': tf.FixedLenSequenceFeature([],dtype=tf.string), 190 | 'bboxes': tf.FixedLenSequenceFeature([4],dtype=tf.float32) 191 | } 192 | context_parsed, sequence_parsed = tf.parse_single_sequence_example(example_serialized, context_features, sequence_features) 193 | 194 | return sequence_parsed, context_parsed 195 | 196 | 197 | def generate_labels_dist(batch_size, feat_size): 198 | dist = lambda i,j,orgin: np.linalg.norm(np.array([i,j])-orgin) 199 | labels = -np.ones(feat_size, dtype=np.int32) 200 | orgin = (np.array(feat_size) -1)/2 201 | for i in range(feat_size[0]): 202 | for j in range(feat_size[1]): 203 | distance = dist(i,j,orgin) 204 | if distance <= config.dist_thre: 205 | labels[i,j] = 1 206 | else: 207 | labels[i,j] = 0 208 | num_pos = np.count_nonzero(labels == 1) 209 | num_neg = np.count_nonzero(labels == 0) 210 | weights = np.zeros(feat_size, dtype=np.float32) 211 | weights[labels==1] = 0.5/num_pos 212 | weights[labels==0] = 0.5/num_neg 213 | batch_labels = np.tile(labels, [batch_size, 1, 1]) 214 | batch_weights = np.tile(weights, [batch_size, 1, 1]) 215 | return tf.convert_to_tensor(batch_labels, tf.float32), tf.convert_to_tensor(batch_weights) 216 | 217 | def generate_labels_overlap(feat_size, bboxes, neg_flag=0): 218 | bboxes = tf.reshape(bboxes, [-1, 4]) 219 | batch_labels, batch_weights = \ 220 | tf.py_func(_generate_labels_overlap_py, 221 | [feat_size, bboxes, (feat_size - 1)/2, neg_flag], 222 | [tf.float32, tf.float32]) 223 | bboxes_shape = bboxes.get_shape().as_list() 224 | batch_labels.set_shape([bboxes_shape[0]]+feat_size.tolist()) 225 | batch_weights.set_shape([bboxes_shape[0]]+feat_size.tolist()) 226 | return batch_labels, batch_weights 227 | 228 | def _generate_labels_overlap_py(feat_size, y_crops, orgin, neg_flag=0): 229 | orig_size = feat_size*config.stride 230 | x = np.arange(0, orig_size[0], config.stride)+config.stride/2 231 | y = np.arange(0, orig_size[1], config.stride)+config.stride/2 232 | x, y = np.meshgrid(x, y) 233 | orgin = orgin*config.stride + config.stride/2 234 | batch_labels, batch_weights, batch_keep = [], [], [] 235 | for gt_bb_cur in y_crops: 236 | gt_size_cur = gt_bb_cur[2:4] - gt_bb_cur[0:2] + 1 237 | gt_bb_cur_new = np.hstack([orgin - (gt_size_cur - 1) / 2, orgin + (gt_size_cur - 1) / 2]) 238 | sample_centers = np.vstack([x.ravel(), y.ravel(), x.ravel(), y.ravel()]).transpose() 239 | sample_bboxes = sample_centers + np.hstack([-(gt_size_cur-1)/2, (gt_size_cur-1)/2]) 240 | 241 | overlaps = _bbox_overlaps(sample_bboxes, gt_bb_cur_new) 242 | 243 | pos_idxes = overlaps > config.overlap_thres 244 | neg_idxes = overlaps < config.overlap_thres 245 | labels = -np.ones(np.prod(feat_size), dtype=np.float32) 246 | labels[pos_idxes] = 1 247 | labels[neg_idxes] = neg_flag 248 | labels = np.reshape(labels, feat_size) 249 | 250 | num_pos = np.count_nonzero(labels == 1) 251 | num_neg = np.count_nonzero(labels == neg_flag) 252 | 253 | if DEBUG: 254 | print(gt_bb_cur) 255 | print((gt_bb_cur[0:2]+gt_bb_cur[2:4])/2) 256 | print('Positive samples:', num_pos, 'Negative samples:', num_neg) 257 | from matplotlib import pyplot as plt 258 | plt.imshow(labels) 259 | # # plt.imshow(np.reshape(overlaps, 
feat_size)) 260 | plt.pause(1) 261 | 262 | weights = np.zeros(feat_size, dtype=np.float32) 263 | if num_pos != 0: 264 | weights[labels == 1] = 0.5 / num_pos 265 | if num_neg != 0: 266 | weights[labels == neg_flag] = 0.5 / num_neg 267 | batch_weights.append(np.expand_dims(weights, 0)) 268 | batch_labels.append(np.expand_dims(labels, 0)) 269 | 270 | batch_labels = np.concatenate(batch_labels, 0) 271 | batch_weights = np.concatenate(batch_weights, 0) 272 | return batch_labels, batch_weights 273 | 274 | def _bbox_overlaps(sample_bboxes, gt_bbox): 275 | lt = np.maximum(sample_bboxes[:, 0:2], gt_bbox[0:2]) 276 | rb = np.minimum(sample_bboxes[:, 2:4], gt_bbox[2:4]) 277 | inter_area = np.maximum(rb - lt + 1, 0) 278 | inter_area = np.prod(inter_area, 1) 279 | union_area = np.prod(sample_bboxes[:, 2:4] - sample_bboxes[:, 0:2] + 1, 1) + np.prod(gt_bbox[2:4]-gt_bbox[0:2]+1, 0) - inter_area 280 | return inter_area / union_area 281 | -------------------------------------------------------------------------------- /memnet/access.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import collections 9 | import tensorflow as tf 10 | from memnet.addressing import cosine_similarity, attention_read, update_usage, calc_allocation_weight 11 | from feature import get_key_feature 12 | import config 13 | 14 | AccessState = collections.namedtuple('AccessState', ( 15 | 'init_memory', 'memory', 'read_weight', 'write_weight', 'control_factors', 'write_decay', 'usage')) 16 | 17 | 18 | def _reset_and_write(memory, write_weight, write_decay, control_factors, values): 19 | 20 | weight_shape = write_weight.get_shape().as_list() 21 | write_weight = tf.reshape(write_weight, weight_shape+[1,1,1]) 22 | decay = write_decay*tf.expand_dims(control_factors[:, 1], 1) + tf.expand_dims(control_factors[:, 2], 1) 23 | decay_expand = tf.expand_dims(tf.expand_dims(tf.expand_dims(decay, 1), 2), 3) 24 | decay_weight = write_weight*decay_expand 25 | 26 | memory *= 1 - decay_weight 27 | values = tf.expand_dims(values, 1) 28 | memory += decay_weight * values 29 | 30 | return memory 31 | 32 | 33 | class MemoryAccess(tf.nn.rnn_cell.RNNCell): 34 | 35 | def __init__(self, memory_size, slot_size, is_train): 36 | super(MemoryAccess, self).__init__() 37 | self._memory_size = memory_size 38 | self._slot_size = slot_size 39 | self._is_train = is_train 40 | 41 | 42 | def __call__(self, inputs, prev_state, scope=None): 43 | 44 | memory_for_writing = inputs[0] 45 | controller_output = inputs[1] 46 | read_key, read_strength, control_factors, write_decay, residual_vector = self._transform_input(controller_output) 47 | 48 | # Write previous template to memory. 49 | memory = _reset_and_write(prev_state.memory, prev_state.write_weight, 50 | prev_state.write_decay, prev_state.control_factors, memory_for_writing) 51 | 52 | # Read from memory. 
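# Note: the read weight is a soft attention distribution over the memory slots computed from the controller's read key; the residual vector then gates the retrieved slot channel-wise before it is added back to the initial template feature.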
53 | read_weight = self._read_weights(read_key, read_strength, memory) 54 | read_weight_expand = tf.reshape(read_weight, [-1, self._memory_size, 1, 1, 1]) 55 | residual_vector = tf.reshape(residual_vector, [-1, 1, 1, 1, self._slot_size[2]]) 56 | read_memory = tf.reduce_sum(residual_vector*read_weight_expand*memory, [1]) 57 | 58 | # calculate the allocation weight 59 | allocation_weight = calc_allocation_weight(prev_state.usage, self._memory_size) 60 | 61 | # calculate the write weight for next frame writing 62 | write_weight = self._write_weights(control_factors, read_weight, allocation_weight) 63 | 64 | # update usage using read & write weights and previous usage 65 | usage = update_usage(write_weight, read_weight, prev_state.usage) 66 | 67 | # summary 68 | if int(scope) < config.summary_display_step: 69 | tf.summary.histogram('write_factor/{}'.format(scope), control_factors[:, 0]) 70 | tf.summary.histogram('read_factor/{}'.format(scope), control_factors[:, 1]) 71 | tf.summary.histogram('allocation_factor/{}'.format(scope), control_factors[:, 2]) 72 | tf.summary.histogram('residual_vector/{}'.format(scope), residual_vector) 73 | tf.summary.histogram('write_decay/{}'.format(scope), write_decay) 74 | tf.summary.histogram('read_key/{}'.format(scope), read_key) 75 | if not config.use_attention_read: 76 | tf.summary.histogram('read_strength/{}'.format(scope), read_strength) 77 | 78 | return read_memory+prev_state.init_memory, AccessState( 79 | init_memory=prev_state.init_memory, 80 | memory=memory, 81 | write_weight=write_weight, 82 | read_weight=read_weight, 83 | control_factors=control_factors, 84 | write_decay=write_decay, 85 | usage=usage) 86 | 87 | def _transform_input(self, input): 88 | 89 | control_factors = tf.nn.softmax(tf.layers.dense(input, 3, name='control_factors')) 90 | write_decay = tf.sigmoid(tf.layers.dense(input, 1, name='write_decay')) 91 | residual_vector = tf.sigmoid(tf.layers.dense(input, self._slot_size[2], name='add_vector')) 92 | 93 | read_key = tf.layers.dense(input, config.key_dim, name='read_key') 94 | if config.use_attention_read: 95 | read_strength = None 96 | else: 97 | read_strength = tf.layers.dense(input, 1, bias_initializer=tf.ones_initializer(), name='write_strengths') 98 | 99 | return read_key, read_strength, control_factors, write_decay, residual_vector 100 | 101 | def _write_weights(self, control_factors, read_weight, allocation_weight): 102 | 103 | return tf.expand_dims(control_factors[:, 1], 1) * read_weight + tf.expand_dims(control_factors[:, 2], 1) * allocation_weight 104 | 105 | def _read_weights(self, read_key, read_strength, memory): 106 | 107 | memory_key = tf.squeeze(get_key_feature(memory, self._is_train, 'memory_key'),[2,3]) 108 | if config.use_attention_read: 109 | return attention_read(read_key, memory_key) 110 | else: 111 | return cosine_similarity(memory_key, read_key, read_strength) 112 | 113 | 114 | @property 115 | def state_size(self): 116 | 117 | return AccessState(init_memory=tf.TensorShape([self._memory_size]+self._slot_size), 118 | memory=tf.TensorShape([self._memory_size]+self._slot_size), 119 | read_weight=tf.TensorShape([self._memory_size]), 120 | write_weight=tf.TensorShape([self._memory_size]), 121 | write_decay=tf.TensorShape([1]), 122 | control_factors=tf.TensorShape([3]), 123 | usage=tf.TensorShape([self._memory_size])) 124 | 125 | @property 126 | def output_size(self): 127 | 128 | return tf.TensorShape(self._slot_size) 129 | -------------------------------------------------------------------------------- 
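Before moving on to the addressing functions, here is a toy NumPy sketch of the erase-and-write rule implemented by `_reset_and_write` above (illustration only, not repository code): each selected slot is faded by `write_weight * decay` and the new template feature is blended in with the same coefficient. The shapes mirror `config.memory_size` and `config.slot_size`, while the chosen slot, weights and decay value are made up for the example:
```
# Toy NumPy illustration (not repo code) of the erase-and-write update in _reset_and_write:
# memory := memory * (1 - w*d) + (w*d) * new_template
import numpy as np

batch, mem_size = 2, 8                        # config.memory_size = 8
slot = (6, 6, 256)                            # config.slot_size

memory = np.random.rand(batch, mem_size, *slot).astype(np.float32)
original = memory.copy()
new_template = np.random.rand(batch, *slot).astype(np.float32)

write_weight = np.zeros((batch, mem_size), np.float32)
write_weight[:, 3] = 1.0                      # pretend slot 3 was selected for writing
decay = np.full((batch, 1), 0.7, np.float32)  # stands in for write_decay*control[1] + control[2]

w = (write_weight * decay)[:, :, None, None, None]   # broadcast to the slot shape
memory = memory * (1 - w) + w * new_template[:, None]

print(np.abs(memory[:, 0] - original[:, 0]).max())                                # 0.0: untouched slot
print(np.abs(memory[:, 3] - (0.3 * original[:, 3] + 0.7 * new_template)).max())   # ~0: blended slot
```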
/memnet/addressing.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import config 9 | import tensorflow as tf 10 | 11 | # Ensure values are greater than epsilon to avoid numerical instability. 12 | _EPSILON = 1e-6 13 | 14 | 15 | def _vector_norms(m): 16 | 17 | squared_norms = tf.reduce_sum(m * m, axis=2, keep_dims=True) 18 | return tf.sqrt(squared_norms + _EPSILON) 19 | 20 | def _weighted_softmax(activations, strengths, strengths_op): 21 | 22 | sharp_activations = activations * strengths_op(strengths) 23 | softmax_weights = tf.nn.softmax(sharp_activations) 24 | # softmax_weights = tf.nn.l2_normalize(sharp_activations, 1) 25 | return softmax_weights 26 | 27 | def cosine_similarity(memory, keys, strengths, strength_op=tf.nn.softplus): 28 | 29 | # Calculates the inner product between the query vector and words in memory. 30 | keys = tf.expand_dims(keys, 1) 31 | dot = tf.matmul(keys, memory, adjoint_b=True) 32 | 33 | # Outer product to compute denominator (euclidean norm of query and memory). 34 | memory_norms = _vector_norms(memory) 35 | key_norms = _vector_norms(keys) 36 | norm = tf.matmul(key_norms, memory_norms, adjoint_b=True) 37 | 38 | # Calculates cosine similarity between the query vector and words in memory. 39 | similarity = dot / (norm + _EPSILON) 40 | 41 | return _weighted_softmax(tf.squeeze(similarity, [1]), strengths, strength_op) 42 | 43 | def attention_read(read_key, memory_key): 44 | 45 | memory_key = tf.expand_dims(memory_key, 1) 46 | input_transform = tf.layers.conv2d(memory_key, 256, [1, 1], [1, 1], use_bias=False, name='memory_key_layer') 47 | query_transform = tf.layers.dense(read_key, 256, name='read_key_layer') 48 | query_transform = tf.expand_dims(tf.expand_dims(query_transform, 1), 1) 49 | addition = tf.nn.tanh(input_transform + query_transform, name='addition_layer') 50 | addition_transform = tf.layers.conv2d(addition, 1, [1, 1], [1, 1], use_bias=False, name='score_layer') 51 | addition_shape = addition_transform.get_shape().as_list() 52 | return tf.nn.softmax(tf.reshape(addition_transform, [addition_shape[0], -1])) 53 | 54 | def update_usage(write_weights, read_weights, prev_usage): 55 | 56 | # write_weights = tf.stop_gradient(write_weights) 57 | # read_weights = tf.stop_gradient(read_weights) 58 | usage = config.usage_decay*prev_usage + write_weights + read_weights 59 | return usage 60 | 61 | def calc_allocation_weight(usage, memory_size): 62 | 63 | usage = tf.stop_gradient(usage) 64 | nonusage = 1 - usage 65 | sorted_nonusage, indices = tf.nn.top_k(nonusage, k=1, name='sort') 66 | allocation_weights = tf.one_hot(tf.squeeze(indices, [1]), memory_size) 67 | 68 | return allocation_weights 69 | -------------------------------------------------------------------------------- /memnet/memnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # 
------------------------------------------------------------------ 7 | 8 | import collections 9 | import tensorflow as tf 10 | from memnet.access import MemoryAccess, AccessState 11 | from feature import get_key_feature 12 | import config 13 | 14 | MemNetState = collections.namedtuple('MemNetState', ('controller_state', 'access_state')) 15 | 16 | def attention(input, query, scope=None): 17 | 18 | input_shape = input.get_shape().as_list() 19 | input_transform = tf.layers.conv2d(input, input_shape[-1], [1, 1], [1, 1], use_bias=False, name='input_layer') 20 | query_transform = tf.layers.dense(query, input_shape[-1], name='query_layer') 21 | query_transform = tf.expand_dims(tf.expand_dims(query_transform, 1), 1) 22 | addition = tf.nn.tanh(input_transform + query_transform, name='addition') 23 | addition_transform = tf.layers.conv2d(addition, 1, [1, 1], [1, 1], use_bias=False, name='score') 24 | addition_shape = addition_transform.get_shape().as_list() 25 | score = tf.nn.softmax(tf.reshape(addition_transform, [addition_shape[0], -1])) 26 | 27 | if int(scope) < config.summary_display_step: 28 | max_idxes = tf.argmax(score, 1) 29 | tf.summary.histogram('max_idxes_{}'.format(scope),max_idxes) 30 | max_value = tf.reduce_max(score, 1) 31 | tf.summary.histogram('max_value_{}'.format(scope), max_value) 32 | 33 | score = tf.reshape(score, addition_shape) 34 | return tf.reduce_sum(input*score, [1,2]), score 35 | 36 | class MemNet(tf.nn.rnn_cell.RNNCell): 37 | 38 | def __init__(self, hidden_size, memory_size, slot_size, is_train): 39 | super(MemNet, self).__init__() 40 | # self._controller = tf.nn.rnn_cell.BasicLSTMCell(hidden_size) 41 | # if is_train and config.keep_prob < 1: 42 | # self._controller = tf.nn.rnn_cell.DropoutWrapper(self._controller, 43 | # input_keep_prob=config.keep_prob, 44 | # output_keep_prob=config.keep_prob) 45 | keep_prob = config.keep_prob if is_train else 1.0 46 | self._controller = tf.contrib.rnn.LayerNormBasicLSTMCell(hidden_size, layer_norm=True, dropout_keep_prob=keep_prob) 47 | self._memory_access = MemoryAccess(memory_size, slot_size, is_train) 48 | self._hidden_size = hidden_size 49 | self._memory_size = memory_size 50 | self._slot_size = slot_size 51 | self._is_train = is_train 52 | 53 | def __call__(self, inputs, prev_state, scope=None): 54 | 55 | prev_controller_state = prev_state.controller_state 56 | prev_access_state = prev_state.access_state 57 | 58 | search_feature = inputs[0] 59 | memory_for_writing = inputs[1] 60 | 61 | # get lstm controller input 62 | controller_input = get_key_feature(search_feature, self._is_train, 'search_key') 63 | 64 | attention_input, self.att_score = attention(controller_input, prev_controller_state[1], scope) 65 | 66 | controller_output, controller_state = self._controller(attention_input, prev_controller_state, scope) 67 | 68 | access_inputs = (memory_for_writing, controller_output) 69 | access_output, access_state = self._memory_access(access_inputs, prev_access_state, scope) 70 | 71 | return access_output, MemNetState(access_state=access_state, controller_state=controller_state) 72 | 73 | def initial_state(self, init_feature): 74 | 75 | init_key = tf.squeeze(get_key_feature(init_feature, self._is_train, 'init_memory_key'), [1, 2]) 76 | c_state = tf.layers.dense(init_key, self._hidden_size, activation=tf.nn.tanh, name='c_state') 77 | h_state = tf.layers.dense(init_key, self._hidden_size, activation=tf.nn.tanh, name='h_state') 78 | batch_size = init_key.get_shape().as_list()[0] 79 | controller_state = 
tf.nn.rnn_cell.LSTMStateTuple(c_state, h_state) 80 | write_weights = tf.one_hot([0]*batch_size, self._memory_size, axis=-1, dtype=tf.float32) 81 | read_weight = tf.zeros([batch_size, self._memory_size], tf.float32) 82 | control_factors = tf.one_hot([2]*batch_size, 3, axis=-1, dtype=tf.float32) 83 | write_decay = tf.zeros([batch_size, 1], tf.float32) 84 | usage = tf.one_hot([0]*batch_size, self._memory_size, axis=-1, dtype=tf.float32) 85 | memory = tf.zeros([batch_size, self._memory_size]+self._slot_size, tf.float32) 86 | access_state = AccessState(init_memory=init_feature, 87 | memory=memory, 88 | read_weight=read_weight, 89 | write_weight=write_weights, 90 | control_factors=control_factors, 91 | write_decay = write_decay, 92 | usage=usage) 93 | return MemNetState(controller_state=controller_state, access_state=access_state) 94 | 95 | @property 96 | def state_size(self): 97 | return MemNetState(controller_state=self._controller.state_size, access_state=self._memory_access.state_size) 98 | 99 | @property 100 | def output_size(self): 101 | return tf.TensorShape(self._slot_size) -------------------------------------------------------------------------------- /memnet/rnn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import tensorflow as tf 9 | from tensorflow.python.util import nest 10 | import config 11 | 12 | 13 | def weights_summay(weight, name): 14 | weight_shape = weight.get_shape().as_list() 15 | 16 | for i in range(weight_shape[1]): 17 | tf.summary.histogram('Memory_{}/{}'.format(i, name), weight[:,i]) 18 | 19 | def rnn(cell, inputs, initial_state, scope=None): 20 | """Creates a recurrent neural network specified by RNNCell `cell`.""" 21 | 22 | if not isinstance(cell, tf.contrib.rnn.RNNCell): 23 | raise TypeError("cell must be an instance of RNNCell") 24 | if not nest.is_sequence(inputs): 25 | raise TypeError("inputs must be a sequence") 26 | if not inputs: 27 | raise ValueError("inputs must not be empty") 28 | 29 | input_shape = inputs[0].get_shape().as_list() 30 | input_list = [] 31 | for input in inputs: 32 | input_list.append([tf.squeeze(input_, [1]) 33 | for input_ in tf.split(axis=1, num_or_size_splits=input_shape[1], value=input)]) 34 | 35 | num_input = len(inputs) 36 | inputs = [] 37 | for i in range(input_shape[1]): 38 | inputs.append(tuple([input_list[j][i] for j in range(num_input)])) 39 | 40 | outputs = [] 41 | states = [] 42 | # Create a new scope in which the caching device is either 43 | # determined by the parent scope, or is set to place the cached 44 | # Variable using the same placement as for the rest of the RNN. 
45 | with tf.variable_scope(scope or "RNN") as varscope: 46 | if varscope.caching_device is None: 47 | varscope.set_caching_device(lambda op: op.device) 48 | 49 | state = initial_state 50 | 51 | for time, input_ in enumerate(inputs): 52 | if time > 0: varscope.reuse_variables() 53 | call_cell = lambda: cell(input_, state, str(time)) 54 | output, state = call_cell() 55 | outputs.append(output) 56 | states.append(state) 57 | 58 | # summary for all these weights 59 | if len(inputs) >= config.summary_display_step: 60 | for i in range(config.summary_display_step): 61 | state = states[i] 62 | weights_summay(state.access_state.memory, 'memory_slot/{}'.format(i)) 63 | weights_summay(state.access_state.read_weight, 'read_weight/{}'.format(i)) 64 | weights_summay(state.access_state.write_weight, 'write_weight/{}'.format(i)) 65 | weights_summay(state.access_state.usage, 'usage/{}'.format(i)) 66 | output_shape = outputs[0].get_shape().as_list() 67 | outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, input_shape[1]] + output_shape[1:]) 68 | return (outputs, state) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import config 9 | import tensorflow as tf 10 | import numpy as np 11 | from feature import extract_feature 12 | from input import generate_labels_overlap, generate_labels_dist 13 | from memnet.memnet import MemNet 14 | from memnet.rnn import rnn 15 | import collections 16 | 17 | class ModeKeys(): 18 | TRAIN = 'train' 19 | EVAL = 'eval' 20 | PREDICT = 'predict' 21 | 22 | EstimatorSpec = collections.namedtuple('EstimatorSpec', ['predictions', 'loss', 'dist_error', 'train', 'summary', 'saver']) 23 | 24 | def get_cnn_feature(input, reuse, mode): 25 | 26 | input_shape = input.get_shape().as_list() 27 | if len(input_shape) > 4: 28 | input = tf.reshape(input, [-1] + input_shape[2:]) 29 | 30 | is_train = True if mode == ModeKeys.TRAIN else False 31 | with tf.variable_scope('feature_extraction', reuse=reuse): 32 | cnn_feature = extract_feature(is_train, input) 33 | 34 | if len(input_shape) > 4: 35 | cnn_feature_shape = cnn_feature.get_shape().as_list() 36 | cnn_feature = tf.reshape(cnn_feature, input_shape[0:2]+cnn_feature_shape[1:]) 37 | 38 | return cnn_feature 39 | 40 | def batch_conv(A, B, mode): 41 | 42 | a_shape = A.get_shape().as_list() 43 | if len(a_shape) > 4: 44 | A = tf.reshape(A, [-1] + a_shape[2:]) 45 | b_shape = B.get_shape().as_list() 46 | if len(b_shape) > 4: 47 | B = tf.reshape(B, [-1] + b_shape[2:]) 48 | batch_size = A.get_shape().as_list()[0] 49 | 50 | output = tf.map_fn(lambda inputs: tf.nn.conv2d(tf.expand_dims(inputs[0], 0), tf.expand_dims(inputs[1], 3), [1,1,1,1], 'VALID'), 51 | elems=[A, B], 52 | dtype=tf.float32, 53 | parallel_iterations=batch_size) 54 | is_train = True if mode == ModeKeys.TRAIN else False 55 | output = tf.layers.batch_normalization(tf.squeeze(output, [1]), training=is_train, name='bn_response') 56 | return tf.squeeze(output, [3]) 57 | 58 | def get_predictions(query_feature, search_feature, mode): 59 | 60 | with tf.variable_scope('mann'): 61 | mann_cell = 
MemNet(config.hidden_size, config.memory_size, config.slot_size, True) 62 | 63 | initial_state = mann_cell.initial_state(query_feature[:, 0]) 64 | 65 | inputs = (search_feature, query_feature) 66 | outputs, final_state = rnn(cell=mann_cell, inputs=inputs, initial_state=initial_state) 67 | 68 | response = batch_conv(search_feature, outputs, mode) 69 | 70 | return response 71 | 72 | 73 | def focal_loss(labels, predictions, gamma=2, epsilon=1e-7, scope=None): 74 | 75 | with tf.name_scope(scope, "focal_loss", (predictions, labels)) as scope: 76 | predictions = tf.to_float(predictions) 77 | labels = tf.to_float(labels) 78 | predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 79 | preds = tf.where( 80 | tf.equal(labels, 1), predictions, 1. - predictions) 81 | losses = -(1. - preds) ** gamma * tf.log(preds + epsilon) 82 | return losses 83 | 84 | def get_loss(outputs, labels, mode): 85 | 86 | if mode == tf.estimator.ModeKeys.PREDICT: 87 | return None 88 | outputs_shape = outputs.get_shape().as_list() 89 | if config.label_type == 0: 90 | labels_response, weights = generate_labels_overlap(np.array(outputs_shape[1:3]), labels) 91 | else: 92 | labels_response, weights = generate_labels_dist(outputs_shape[0], np.array(outputs_shape[1:3])) 93 | if config.use_focal_loss: 94 | loss = tf.reduce_sum(weights * focal_loss(labels=labels_response, predictions=tf.nn.sigmoid(outputs))) / outputs_shape[0] 95 | else: 96 | loss = tf.reduce_sum(weights*tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_response, logits=outputs))/outputs_shape[0] 97 | tf.summary.scalar('loss', loss) 98 | return loss 99 | 100 | def get_dist_error(outputs, mode): 101 | 102 | if mode == tf.estimator.ModeKeys.PREDICT: 103 | return None 104 | outputs_shape = outputs.get_shape().as_list() 105 | outputs = tf.reshape(outputs, [outputs_shape[0], -1]) 106 | pred_loc_idx = tf.argmax(outputs, 1) 107 | loc_x = pred_loc_idx%outputs_shape[1] 108 | loc_y = pred_loc_idx//outputs_shape[1] 109 | pred_loc = tf.stack([loc_x, loc_y], 1) 110 | gt_loc = tf.tile(tf.expand_dims([outputs_shape[1]/2, outputs_shape[1]/2], 0), [outputs_shape[0], 1]) 111 | dist_error = tf.losses.mean_squared_error(predictions=pred_loc, labels=gt_loc) 112 | tf.summary.scalar('dist_error', dist_error) 113 | return dist_error 114 | 115 | def get_train_op(loss, mode): 116 | 117 | if mode != ModeKeys.TRAIN: 118 | return None 119 | 120 | global_step = tf.train.get_or_create_global_step() 121 | learning_rate = tf.train.exponential_decay(config.learning_rate, global_step, config.decay_circles, config.lr_decay, staircase=True) 122 | tf.summary.scalar('learning_rate', learning_rate) 123 | 124 | tvars = tf.trainable_variables() 125 | regularizer = tf.contrib.layers.l2_regularizer(config.weight_decay) 126 | regularizer_loss = tf.contrib.layers.apply_regularization(regularizer, tvars) 127 | loss += regularizer_loss 128 | grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), config.clip_gradients) 129 | # optimizer = tf.train.GradientDescentOptimizer(self.lr) 130 | optimizer = tf.train.AdamOptimizer(learning_rate) 131 | 132 | batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 133 | with tf.control_dependencies(batchnorm_update_ops): 134 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 135 | 136 | return train_op 137 | 138 | def get_summary(mode): 139 | 140 | if mode == ModeKeys.PREDICT: 141 | return None 142 | return tf.summary.merge_all() 143 | 144 | def get_saver(): 145 | 146 | return tf.train.Saver(tf.global_variables(), 
max_to_keep=15) 147 | 148 | def model_fn(features, labels, mode): 149 | # get cnn feature for query and search 150 | query_feature = get_cnn_feature(features['query'], None, mode) 151 | search_feature = get_cnn_feature(features['search'], True, mode) 152 | 153 | predictions = get_predictions(query_feature, search_feature, mode) 154 | loss = get_loss(predictions, labels, mode) 155 | dist_error = get_dist_error(predictions, mode) 156 | train_op = get_train_op(loss, mode) 157 | summary = get_summary(mode) 158 | saver = get_saver() 159 | 160 | return EstimatorSpec(predictions, loss, dist_error, train_op, summary, saver) 161 | 162 | def build_initial_state(init_query, mem_cell, mode): 163 | 164 | query_feature = get_cnn_feature(init_query, None, mode) 165 | return mem_cell.initial_state(query_feature[:,0]) 166 | 167 | def build_model(query, search, mem_cell, initial_state, mode): 168 | # get cnn feature for query and search 169 | query_feature = get_cnn_feature(query, True, mode) 170 | search_feature = get_cnn_feature(search, True, mode) 171 | 172 | inputs = (search_feature, query_feature) 173 | outputs, final_state = rnn(cell=mem_cell, inputs=inputs, initial_state=initial_state) 174 | 175 | response = batch_conv(search_feature, outputs, mode) 176 | saver = get_saver() 177 | 178 | return response, saver, final_state 179 | 180 | 181 | if __name__=='__main__': 182 | query_patch = tf.placeholder(tf.float32, [10, 5, config.z_exemplar_size, config.z_exemplar_size, 3]) 183 | search_patch = tf.placeholder(tf.float32, [10, 5, config.x_instance_size, config.x_instance_size, 3]) 184 | features = { 185 | 'query': query_patch, 186 | 'search': search_patch 187 | } 188 | labels = tf.placeholder(tf.float32, [10, 5, 4]) 189 | mode = ModeKeys.TRAIN 190 | 191 | esti_spec = model_fn(features, labels, mode) 192 | pass 193 | -------------------------------------------------------------------------------- /tracking/demo.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import tensorflow as tf 9 | import time 10 | import sys 11 | sys.path.append('../') 12 | import os 13 | import config 14 | from tracking.tracker import Tracker, Model 15 | import cv2 16 | 17 | def load_seq_config(seq_name): 18 | 19 | src = os.path.join(config.otb_data_dir,seq_name,'groundtruth_rect.txt') 20 | gt_file = open(src) 21 | lines = gt_file.readlines() 22 | gt_rects = [] 23 | for gt_rect in lines: 24 | rect = [int(v) for v in gt_rect[:-1].split(',')] 25 | gt_rects.append(rect) 26 | 27 | init_rect= gt_rects[0] 28 | img_path = os.path.join(config.otb_data_dir,seq_name,'img') 29 | img_names = sorted(os.listdir(img_path)) 30 | s_frames = [os.path.join(img_path, img_name) for img_name in img_names] 31 | 32 | return init_rect, s_frames 33 | 34 | def display_result(image, pred_boxes, frame_idx, seq_name=None): 35 | if len(image.shape) == 3: 36 | r, g, b = cv2.split(image) 37 | image = cv2.merge([b, g, r]) 38 | pred_boxes = pred_boxes.astype(int) 39 | cv2.rectangle(image, tuple(pred_boxes[0:2]), tuple(pred_boxes[0:2] + pred_boxes[2:4]), (0, 0, 255), 2) 40 | 41 | cv2.putText(image, 'Frame: %d' % frame_idx, (20, 30), cv2.FONT_HERSHEY_DUPLEX, 0.8, (0, 
255, 255)) 42 | cv2.imshow('tracker', image) 43 | if cv2.waitKey(1) & 0xFF == ord('q'): 44 | return True 45 | if config.is_save: 46 | cv2.imwrite(os.path.join(config.save_path, seq_name, '%04d.jpg' % frame_idx), image) 47 | 48 | def run_tracker(): 49 | config_proto = tf.ConfigProto() 50 | config_proto.gpu_options.allow_growth = True 51 | with tf.Graph().as_default(), tf.Session(config=config_proto) as sess: 52 | os.chdir('../') 53 | model = Model(sess) 54 | tracker = Tracker(model) 55 | init_rect, s_frames = load_seq_config('Basketball') 56 | bbox = init_rect 57 | res = [] 58 | res.append(bbox) 59 | start_time = time.time() 60 | tracker.initialize(s_frames[0], bbox) 61 | 62 | for idx in range(1, len(s_frames)): 63 | tracker.idx = idx 64 | bbox, cur_frame = tracker.track(s_frames[idx]) 65 | display_result(cur_frame, bbox, idx) 66 | res.append(bbox.tolist()) 67 | end_time = time.time() 68 | type = 'rect' 69 | fps = idx/(end_time-start_time) 70 | 71 | return res, type, fps 72 | 73 | if __name__ == '__main__': 74 | run_tracker() -------------------------------------------------------------------------------- /tracking/tracker.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Tensorflow implementation of 3 | # "Learning Dynamic Memory Networks for Object Tracking", ECCV,2018 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Tianyu Yang (tianyu-yang.com) 6 | # ------------------------------------------------------------------ 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import sys 11 | import config 12 | from model import build_initial_state, build_model, ModeKeys 13 | from memnet.memnet import MemNet, AccessState, MemNetState 14 | import math 15 | sys.path.append('../') 16 | 17 | class Model(): 18 | def __init__(self, sess, checkpoint_dir=None): 19 | 20 | self.z_file_init = tf.placeholder(tf.string, [], name='z_filename_init') 21 | self.z_roi_init = tf.placeholder(tf.float32, [1, 4], name='z_roi_init') 22 | self.z_file = tf.placeholder(tf.string, [], name='z_filename') 23 | self.z_roi = tf.placeholder(tf.float32, [1, 4], name='z_roi') 24 | self.x_file = tf.placeholder(tf.string, [], name='x_filename') 25 | self.x_roi = tf.placeholder(tf.float32, [config.num_scale, 4], name='x_roi') 26 | 27 | init_z_exemplar,_ = self._read_and_crop_image(self.z_file_init, self.z_roi_init, [config.z_exemplar_size, config.z_exemplar_size]) 28 | init_z_exemplar = tf.reshape(init_z_exemplar, [1, 1, config.z_exemplar_size, config.z_exemplar_size, 3]) 29 | init_z_exemplar = tf.tile(init_z_exemplar, [config.num_scale, 1, 1, 1, 1]) 30 | z_exemplar,_ = self._read_and_crop_image(self.z_file, self.z_roi, [config.z_exemplar_size, config.z_exemplar_size]) 31 | z_exemplar = tf.reshape(z_exemplar, [1, 1, config.z_exemplar_size, config.z_exemplar_size, 3]) 32 | z_exemplar = tf.tile(z_exemplar, [config.num_scale, 1, 1, 1, 1]) 33 | self.x_instances, self.image = self._read_and_crop_image(self.x_file, self.x_roi, [config.x_instance_size, config.x_instance_size]) 34 | self.x_instances = tf.reshape(self.x_instances, [config.num_scale, 1, config.x_instance_size, config.x_instance_size, 3]) 35 | 36 | with tf.variable_scope('mann'): 37 | mem_cell = MemNet(config.hidden_size, config.memory_size, config.slot_size, False) 38 | 39 | self.initial_state = build_initial_state(init_z_exemplar, mem_cell, ModeKeys.PREDICT) 40 | self.response, saver, self.final_state = 
build_model(z_exemplar, self.x_instances, mem_cell, self.initial_state, ModeKeys.PREDICT) 41 | self.att_score = mem_cell.att_score 42 | 43 | up_response_size = config.response_size * config.response_up 44 | self.up_response = tf.squeeze(tf.image.resize_images(tf.expand_dims(self.response, -1), 45 | [up_response_size, up_response_size], 46 | method=tf.image.ResizeMethod.BICUBIC, 47 | align_corners=True), -1) 48 | if checkpoint_dir is not None: 49 | saver.restore(sess, checkpoint_dir) 50 | self._sess = sess 51 | else: 52 | ckpt = tf.train.get_checkpoint_state(config.checkpoint_dir) 53 | if ckpt and ckpt.model_checkpoint_path: 54 | saver.restore(sess, ckpt.model_checkpoint_path) 55 | self._sess = sess 56 | 57 | def _read_and_crop_image(self, filename, roi, model_sz): 58 | image_file = tf.read_file(filename) 59 | # Decode the image as a JPEG file, this will turn it into a Tensor 60 | image = tf.image.decode_jpeg(image_file, channels=3) 61 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 62 | frame_sz = tf.shape(image) 63 | # used to pad the crops 64 | avg_chan = tf.reduce_mean(image, axis=(0, 1), name='avg_chan') 65 | # pad with if necessary 66 | frame_padded, npad = self._pad_frame(image, frame_sz, roi, avg_chan) 67 | frame_padded = tf.cast(frame_padded, tf.float32) 68 | crop_patch = self._crop_image(frame_padded, npad, frame_sz, roi, model_sz) 69 | return crop_patch, image 70 | 71 | def _pad_frame(self, im, frame_sz, roi, avg_chan): 72 | pos_x = tf.reduce_max(roi[:, 0], axis=0) 73 | pos_y = tf.reduce_max(roi[:, 1], axis=0) 74 | patch_sz = tf.reduce_max(roi[:, 2:4], axis=0) 75 | c = patch_sz / 2 76 | xleft_pad = tf.maximum(0, -tf.cast(tf.round(pos_x - c[0]), tf.int32)) 77 | ytop_pad = tf.maximum(0, -tf.cast(tf.round(pos_y - c[1]), tf.int32)) 78 | xright_pad = tf.maximum(0, tf.cast(tf.round(pos_x + c[0]), tf.int32) - frame_sz[1]) 79 | ybottom_pad = tf.maximum(0, tf.cast(tf.round(pos_y + c[1]), tf.int32) - frame_sz[0]) 80 | npad = tf.reduce_max([xleft_pad, ytop_pad, xright_pad, ybottom_pad]) 81 | paddings = [[npad, npad], [npad, npad], [0, 0]] 82 | im_padded = im 83 | if avg_chan is not None: 84 | im_padded = im_padded - avg_chan 85 | im_padded = tf.pad(im_padded, paddings, mode='CONSTANT') 86 | if avg_chan is not None: 87 | im_padded = im_padded + avg_chan 88 | return im_padded, npad 89 | 90 | def _crop_image(self, im, npad, frame_sz, rois, model_sz): 91 | radius = (rois[:, 2:4]-1) / 2 92 | c_xy = rois[:, 0:2] 93 | self.pad_frame_sz = pad_frame_sz = tf.cast(tf.expand_dims(frame_sz[0:2]+2*npad,0), tf.float32) 94 | npad = tf.cast(npad, tf.float32) 95 | xy1 = (npad + c_xy - radius) 96 | xy2 = (npad + c_xy + radius) 97 | norm_rect = tf.stack([xy1[:,1], xy1[:,0], xy2[:,1], xy2[:,0]], axis=1)/tf.concat([pad_frame_sz, pad_frame_sz],1) 98 | crops = tf.image.crop_and_resize(tf.expand_dims(im, 0), norm_rect, tf.zeros([tf.shape(rois)[0]],tf.int32), model_sz, method='bilinear') 99 | 100 | return crops 101 | 102 | 103 | class Tracker(): 104 | 105 | def __init__(self, model): 106 | 107 | self._model = model 108 | self._sess = model._sess 109 | self.idx = 1 110 | 111 | # prepare constant things for tracking 112 | scale_steps = list(range(math.ceil(config.num_scale / 2) - config.num_scale, math.floor(config.num_scale / 2) + 1)) 113 | self.scales = np.power(config.scale_multipler, scale_steps) 114 | 115 | up_response_size = config.response_size * config.response_up 116 | if config.window == 'cosine': 117 | window = np.matmul(np.expand_dims(np.hanning(up_response_size), 1), 118 | 
np.expand_dims(np.hanning(up_response_size), 0)).astype(np.float32) 119 | else: 120 | window = np.ones([up_response_size, up_response_size], dtype=np.float32) 121 | self.window = window / np.sum(window) 122 | 123 | def estimate_bbox(self, responses, x_roi_size_origs, target_pos, target_size): 124 | 125 | up_response_size = config.response_size * config.response_up 126 | current_scale_idx = math.floor(config.num_scale / 2) 127 | best_scale_idx = current_scale_idx 128 | best_peak = -math.inf 129 | 130 | for s_idx in range(config.num_scale): 131 | this_response = responses[s_idx].copy() 132 | 133 | # penalize the change of scale 134 | if s_idx != current_scale_idx: 135 | this_response *= config.scale_penalty 136 | this_peak = np.max(this_response) 137 | if this_peak > best_peak: 138 | best_peak = this_peak 139 | best_scale_idx = s_idx 140 | response = responses[best_scale_idx] 141 | 142 | x_roi_size_orig = x_roi_size_origs[best_scale_idx] 143 | 144 | # make response sum to 1 145 | response -= np.min(response) 146 | response /= np.sum(response) 147 | 148 | self.norm_response = response 149 | # apply window 150 | self.win_response = response = (1 - config.win_weights) * response + config.win_weights * self.window 151 | 152 | max_idx = np.argsort(response.flatten()) 153 | max_idx = max_idx[-config.avg_num:] 154 | 155 | x = max_idx % up_response_size 156 | y = max_idx // up_response_size 157 | position = np.vstack([x, y]).transpose() 158 | 159 | shift_center = position - up_response_size / 2 160 | shift_center_instance = shift_center * config.stride / config.response_up 161 | shift_center_orig = shift_center_instance * np.expand_dims(x_roi_size_orig, 0) / config.x_instance_size 162 | target_pos = np.mean(target_pos + shift_center_orig, 0) 163 | 164 | target_size_new = target_size * self.scales[best_scale_idx] 165 | target_size = (1 - config.scale_damp) * target_size + config.scale_damp * target_size_new 166 | 167 | return target_pos, target_size, best_scale_idx 168 | 169 | def initialize(self, init_frame_file, init_box): 170 | bbox = np.array(init_box) 171 | self.target_pos = bbox[0:2] + bbox[2:4] / 2 172 | self.target_size = bbox[2:4] 173 | 174 | self.z_roi_size = calc_z_size(self.target_size) 175 | self.x_roi_size = calc_x_size(self.z_roi_size) 176 | z_roi = np.concatenate([self.target_pos, self.z_roi_size], 0) 177 | next_state = self._sess.run(self._model.initial_state, 178 | {self._model.z_file_init: init_frame_file, 179 | self._model.z_roi_init: [z_roi]}) 180 | self.next_state = next_state 181 | self.pre_frame_file = init_frame_file 182 | 183 | def track(self, cur_frame_file, display=False): 184 | # build pyramid of search images 185 | sx_roi_size = np.round(np.expand_dims(self.x_roi_size, 0) * np.expand_dims(self.scales, 1)) 186 | target_poses = np.tile(np.expand_dims(self.target_pos,axis=0), [config.num_scale,1]) 187 | x_rois = np.concatenate([target_poses, sx_roi_size], axis=1) 188 | z_roi = np.concatenate([self.target_pos, self.z_roi_size], 0) 189 | att_score, responses, cur_frame,\ 190 | x_instances, self.next_state = self._sess.run([self._model.att_score, 191 | self._model.up_response, 192 | self._model.image, 193 | self._model.x_instances, 194 | self._model.final_state], 195 | {self._model.x_file: cur_frame_file, 196 | self._model.x_roi: x_rois, 197 | self._model.z_file: self.pre_frame_file, 198 | self._model.z_roi: [z_roi], 199 | self._model.initial_state: self.next_state}) 200 | 201 | # estimate position and size 202 | self.target_pos, self.target_size, best_scale_idx = 
self.estimate_bbox(responses, sx_roi_size, self.target_pos, self.target_size) 203 | bbox = np.hstack([self.target_pos - self.target_size / 2, self.target_size]) 204 | 205 | self.next_state = get_new_state(self.next_state, best_scale_idx) 206 | # calculate new x and z roi size for next frame 207 | self.z_roi_size = calc_z_size(self.target_size) 208 | self.x_roi_size = calc_x_size(self.z_roi_size) 209 | self.pre_frame_file = cur_frame_file 210 | 211 | if display: 212 | return bbox, cur_frame, x_instances[best_scale_idx, 0], att_score[best_scale_idx], responses[best_scale_idx], self.next_state.access_state 213 | else: 214 | return bbox, cur_frame 215 | 216 | 217 | def calc_z_size(target_size): 218 | # calculate roi region 219 | if config.fix_aspect: 220 | extend_size = target_size + config.context_amount * (target_size[0] + target_size[1]) 221 | z_size = np.sqrt(np.prod(extend_size)) 222 | z_size = np.repeat(z_size, 2, 0) 223 | else: 224 | z_size = target_size * config.z_scale 225 | 226 | return z_size 227 | 228 | def calc_x_size(z_roi_size): 229 | # calculate roi region 230 | z_scale = config.z_exemplar_size / z_roi_size 231 | delta_size = config.x_instance_size - config.z_exemplar_size 232 | x_size = delta_size / z_scale + z_roi_size 233 | 234 | return x_size 235 | 236 | def get_new_state(state, best_scale): 237 | 238 | lstm_state = state[0] 239 | access_state = state[1] 240 | 241 | c_best = lstm_state[0][best_scale] 242 | h_best = lstm_state[1][best_scale] 243 | c = np.array([c_best]*config.num_scale) 244 | h = np.array([h_best]*config.num_scale) 245 | 246 | lstm_state = tf.nn.rnn_cell.LSTMStateTuple(c, h) 247 | 248 | s_list = [] 249 | for s in access_state: 250 | s_best = s[best_scale] 251 | s_list.append([s_best]*config.num_scale) 252 | access_state = AccessState(np.array(s_list[0]), np.array(s_list[1]), np.array(s_list[2]), 253 | np.array(s_list[3]), np.array(s_list[4]), np.array(s_list[5]), np.array(s_list[6])) 254 | return MemNetState(lstm_state, access_state) 255 | 256 | 257 | --------------------------------------------------------------------------------
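For reference, a minimal usage sketch of the `Model`/`Tracker` API defined in `tracking/tracker.py`, mirroring `run_tracker()` in `tracking/demo.py`. The frame directory and initial bounding box below are placeholder values, not part of the repository; adjust them to your own sequence and to the checkpoint location configured in `config.py`.

```python
# Sketch of tracking a custom image sequence with a pretrained MemTrack model.
# Paths and the initial box are hypothetical placeholders.
import os
import sys
import tensorflow as tf

sys.path.append('../')
import config
from tracking.tracker import Model, Tracker

frame_dir = '/path/to/my_sequence/img'                 # placeholder sequence directory
s_frames = sorted(os.path.join(frame_dir, f) for f in os.listdir(frame_dir))
init_bbox = [100, 120, 40, 80]                         # placeholder [x, y, w, h] in the first frame

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
with tf.Graph().as_default(), tf.Session(config=config_proto) as sess:
    os.chdir('../')                                    # as in demo.py: paths in config.py resolve from the repo root
    tracker = Tracker(Model(sess))                     # restores the checkpoint from config.checkpoint_dir
    tracker.initialize(s_frames[0], init_bbox)

    results = [init_bbox]
    for idx in range(1, len(s_frames)):
        tracker.idx = idx
        bbox, _ = tracker.track(s_frames[idx])         # returns (predicted box, decoded frame)
        results.append(bbox.tolist())
```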