├── ILSVRC15-curation ├── gen_image_crops_VID.py └── gen_imdb_VID.py ├── LICENSE ├── README.md ├── Tracking ├── Config.py ├── SiamNet.py ├── Tracking_Utils.py └── run_SiamFC.py ├── Train ├── Config.py ├── DataAugmentation.py ├── SiamNet.py ├── Utils.py ├── VIDDataset.py ├── model │ └── SiamFC_50_model.pth └── run_Train_SiamFC.py └── imgs └── result.PNG /ILSVRC15-curation/gen_image_crops_VID.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Written by Heng Fan 3 | ''' 4 | import numpy as np 5 | import os 6 | import glob 7 | import xml.etree.ElementTree as ET 8 | import cv2 9 | import datetime 10 | 11 | 12 | ''' 13 | # default setting for cropping 14 | ''' 15 | examplar_size = 127.0 16 | instance_size = 255.0 17 | context_amount = 0.5 18 | 19 | 20 | def get_subwindow_avg(im, pos, model_sz, original_sz): 21 | ''' 22 | # obtain image patch, padding with avg channel if area goes outside of border 23 | ''' 24 | avg_chans = [np.mean(im[:, :, 0]), np.mean(im[:, :, 1]), np.mean(im[:, :, 2])] 25 | 26 | if original_sz is None: 27 | original_sz = model_sz 28 | 29 | sz = original_sz 30 | im_sz = im.shape 31 | # make sure the size is not too small 32 | assert (im_sz[0] > 2) & (im_sz[1] > 2), "The size of image is too small!" 33 | c = (sz + 1) / 2 34 | 35 | # check out-of-bounds coordinates, and set them to black 36 | context_xmin = round(pos[1] - c) # floor(pos(2) - sz(2) / 2); 37 | context_xmax = context_xmin + sz - 1 38 | context_ymin = round(pos[0] - c) # floor(pos(1) - sz(1) / 2); 39 | context_ymax = context_ymin + sz - 1 40 | left_pad = max(0, 1 - context_xmin) # in python, index starts from 0 41 | top_pad = max(0, 1 - context_ymin) 42 | right_pad = max(0, context_xmax - im_sz[1]) 43 | bottom_pad = max(0, context_ymax - im_sz[0]) 44 | 45 | context_xmin = context_xmin + left_pad 46 | context_xmax = context_xmax + left_pad 47 | context_ymin = context_ymin + top_pad 48 | context_ymax = context_ymax + top_pad 49 | 50 | im_R = im[:, :, 0] 51 | im_G = im[:, :, 1] 52 | im_B = im[:, :, 2] 53 | 54 | # padding 55 | if (top_pad != 0) | (bottom_pad != 0) | (left_pad != 0) | (right_pad != 0): 56 | im_R = np.pad(im_R, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', 57 | constant_values=avg_chans[0]) 58 | im_G = np.pad(im_G, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', 59 | constant_values=avg_chans[1]) 60 | im_B = np.pad(im_B, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', 61 | constant_values=avg_chans[2]) 62 | 63 | im = np.stack((im_R, im_G, im_B), axis=2) 64 | 65 | im_patch_original = im[int(context_ymin) - 1:int(context_ymax), int(context_xmin) - 1:int(context_xmax), :] 66 | 67 | if model_sz != original_sz: 68 | im_patch = cv2.resize(im_patch_original, (int(model_sz), int(model_sz)), interpolation=cv2.INTER_CUBIC) 69 | else: 70 | im_patch = im_patch_original 71 | 72 | return im_patch 73 | 74 | 75 | def get_crops(img, bbox, size_z, size_x, context_amount): 76 | ''' 77 | # get examplar and search region crops 78 | ''' 79 | cx = bbox[0] + bbox[2]/2 80 | cy = bbox[1] + bbox[3]/2 81 | w = bbox[2] 82 | h = bbox[3] 83 | 84 | # for examplar 85 | wc_z = w + context_amount * (w + h) 86 | hc_z = h + context_amount * (w + h) 87 | s_z = np.sqrt(wc_z * hc_z) 88 | scale_z = size_z / s_z 89 | im_crop_z = get_subwindow_avg(img, np.array([cy, cx]), size_z, round(s_z)) 90 | 91 | # for search region 92 | d_search = (size_x - size_z) / 2 93 | pad = d_search / scale_z 94 | s_x = s_z + 
2 * pad 95 | scale_x = size_x / s_x 96 | im_crop_x = get_subwindow_avg(img, np.array([cy, cx]), size_x, round(s_x)) 97 | 98 | return im_crop_z, im_crop_x 99 | 100 | 101 | def generate_image_crops(vid_root_path, vid_curated_path): 102 | ''' 103 | # save image crops to the vid_curated_path 104 | ''' 105 | anno_str = "Annotations/VID/train/" 106 | data_str = "Data/VID/train/" 107 | vid_anno_path = os.path.join(vid_root_path, anno_str) 108 | vid_data_path = os.path.join(vid_root_path, data_str) 109 | 110 | cur_procesed_fraem = 0 111 | start_time = datetime.datetime.now() 112 | total_time = 0 113 | 114 | # dirs of level1: e.g., a/, b/, ... 115 | all_dirs_level1 = os.listdir(vid_anno_path) 116 | for i in range(len(all_dirs_level1)): 117 | all_dirs_level2 = os.listdir(os.path.join(vid_anno_path, all_dirs_level1[i])) 118 | 119 | # dirs of level2: e.g., a/ILSVRC2015_train_00000000/, a/ILSVRC2015_train_00001000/, ... 120 | for j in range(len(all_dirs_level2)): 121 | frame_list = glob.glob(os.path.join(vid_anno_path, all_dirs_level1[i], all_dirs_level2[j], "*.xml")) 122 | frame_list.sort() 123 | 124 | # level3: frame level 125 | for k in range(len(frame_list)): 126 | frame_xml_name = os.path.join(vid_anno_path, all_dirs_level1[i], all_dirs_level2[j], frame_list[k]) 127 | frame_xml_tree = ET.parse(frame_xml_name) 128 | frame_xml_root = frame_xml_tree.getroot() 129 | 130 | # image file path 131 | frame_img_name = (frame_list[k].replace(".xml", ".JPEG")).replace(vid_anno_path, vid_data_path) 132 | img = cv2.imread(frame_img_name) 133 | if img is None: 134 | print("Cannot find %s!"%frame_img_name) 135 | exit(0) 136 | 137 | # image file name 138 | frame_filename = frame_xml_root.find('filename').text 139 | 140 | # process (all objects in) each frame 141 | for object in frame_xml_root.iter("object"): 142 | # get trackid 143 | id = object.find("trackid").text 144 | 145 | # get bounding box 146 | bbox_node = object.find("bndbox") 147 | xmax = float(bbox_node.find('xmax').text) 148 | xmin = float(bbox_node.find('xmin').text) 149 | ymax = float(bbox_node.find('ymax').text) 150 | ymin = float(bbox_node.find('ymin').text) 151 | width = xmax - xmin + 1 152 | height = ymax - ymin + 1 153 | bbox = np.array([xmin, ymin, width, height]) 154 | 155 | # print("processing %s, %s, %s, %s ..." % (all_dirs_level1[i], all_dirs_level2[j], frame_filename+".JPEG", id)) 156 | 157 | # get crops 158 | im_crop_z, im_crop_x = get_crops(img, bbox, examplar_size, instance_size, context_amount) 159 | 160 | # save crops 161 | save_path = os.path.join(vid_curated_path, data_str, all_dirs_level1[i], all_dirs_level2[j]) 162 | if not os.path.exists(save_path): 163 | os.makedirs(save_path) 164 | 165 | savename_crop_z = os.path.join(save_path, '{}.{:02d}.crop.z.jpg'.format(frame_filename, int(id))) 166 | savename_crop_x = os.path.join(save_path, '{}.{:02d}.crop.x.jpg'.format(frame_filename, int(id))) 167 | 168 | cv2.imwrite(savename_crop_z, im_crop_z, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) 169 | cv2.imwrite(savename_crop_x, im_crop_x, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) 170 | 171 | cur_procesed_fraem = cur_procesed_fraem + 1 172 | 173 | if cur_procesed_fraem % 1000 == 0: 174 | end_time = datetime.datetime.now() 175 | total_time = total_time + int((end_time-start_time).seconds) 176 | print("finished processing %d frames in %d seconds (FPS: %d ) ..." 
% (cur_procesed_fraem, total_time, int(1000/(end_time-start_time).seconds))) 177 | start_time = datetime.datetime.now() 178 | 179 | 180 | if __name__ == "__main__": 181 | # path to your VID dataset 182 | vid_root_path = "/home/hfan/Dataset/ILSVRC2015" 183 | vid_curated_path = "/home/hfan/Dataset/ILSVRC2015_crops" 184 | if not os.path.exists(vid_curated_path): 185 | os.mkdir(vid_curated_path) 186 | generate_image_crops(vid_root_path, vid_curated_path) 187 | -------------------------------------------------------------------------------- /ILSVRC15-curation/gen_imdb_VID.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Written by Heng Fan 3 | ''' 4 | import numpy as np 5 | import os 6 | import glob 7 | import xml.etree.ElementTree as ET 8 | import json 9 | 10 | 11 | validation_ratio = 0.1 12 | 13 | 14 | def generate_image_imdb(vid_root_path, vid_curated_path): 15 | ''' 16 | # save image crops to the vid_curated_path 17 | ''' 18 | anno_str = "Annotations/VID/train/" 19 | data_str = "Data/VID/train/" 20 | vid_anno_path = os.path.join(vid_root_path, anno_str) 21 | vid_data_path = os.path.join(vid_root_path, data_str) 22 | 23 | num_videos = 0 24 | 25 | # dirs of level1: e.g., a/, b/, ... 26 | all_dirs_level1 = os.listdir(vid_anno_path) 27 | 28 | for i in range(len(all_dirs_level1)): 29 | all_dirs_level2 = os.listdir(os.path.join(vid_anno_path, all_dirs_level1[i])) 30 | num_videos = num_videos + len(all_dirs_level2) 31 | 32 | train_video_num = round(num_videos * (1-validation_ratio)) 33 | val_video_num = num_videos - train_video_num 34 | 35 | imdb_video_train = dict() 36 | imdb_video_train['num_videos'] = train_video_num 37 | imdb_video_train['data_str'] = data_str 38 | 39 | imdb_video_val = dict() 40 | imdb_video_val['num_videos'] = val_video_num 41 | imdb_video_val['data_str'] = data_str 42 | 43 | videos_train = dict() 44 | videos_val = dict() 45 | 46 | vid_idx = 0 47 | 48 | for i in range(len(all_dirs_level1)): 49 | all_dirs_level2 = os.listdir(os.path.join(vid_anno_path, all_dirs_level1[i])) 50 | 51 | # dirs of level2: e.g., a/ILSVRC2015_train_00000000/, a/ILSVRC2015_train_00001000/, ... 52 | for j in range(len(all_dirs_level2)): 53 | 54 | if vid_idx < train_video_num: 55 | if not videos_train.has_key(all_dirs_level2[j]): 56 | videos_train[all_dirs_level2[j]] = [] 57 | else: 58 | if not videos_val.has_key(all_dirs_level2[j]): 59 | videos_val[all_dirs_level2[j]] = [] 60 | 61 | 62 | 63 | frame_list = glob.glob(os.path.join(vid_anno_path, all_dirs_level1[i], all_dirs_level2[j], "*.xml")) 64 | frame_list.sort() 65 | 66 | video_ids = dict() # store frame level information 67 | 68 | # level3: frame level 69 | for k in range(len(frame_list)): 70 | # read xml file 71 | frame_id = k 72 | frame_xml_name = os.path.join(vid_anno_path, all_dirs_level1[i], all_dirs_level2[j], frame_list[k]) 73 | frame_xml_tree = ET.parse(frame_xml_name) 74 | frame_xml_root = frame_xml_tree.getroot() 75 | 76 | # crop_path = os.path.join(vid_curated_path, data_str, all_dirs_level1[i], all_dirs_level2[j]) 77 | crop_path = os.path.join(all_dirs_level1[i], all_dirs_level2[j]) 78 | frame_filename = frame_xml_root.find('filename').text 79 | 80 | print ("processing: %s, %s, %s ..." 
% (all_dirs_level1[i], all_dirs_level2[j], frame_filename)) 81 | 82 | for object in frame_xml_root.iter("object"): 83 | # get trackid 84 | id = object.find("trackid").text 85 | if not video_ids.has_key(id): 86 | video_ids[id] = [] 87 | # get bounding box 88 | bbox_node = object.find("bndbox") 89 | xmax = float(bbox_node.find('xmax').text) 90 | xmin = float(bbox_node.find('xmin').text) 91 | ymax = float(bbox_node.find('ymax').text) 92 | ymin = float(bbox_node.find('ymin').text) 93 | width = xmax - xmin + 1 94 | height = ymax - ymin + 1 95 | bbox = np.array([xmin, ymin, width, height]) 96 | 97 | tmp_instance = dict() 98 | tmp_instance['instance_path'] = os.path.join(all_dirs_level1[i], all_dirs_level2[j], '{}.{:02d}.crop.x.jpg'.format(frame_filename, int(id))) 99 | tmp_instance['bbox'] =bbox.tolist() 100 | 101 | video_ids[id].append(tmp_instance) 102 | 103 | # delete the object_id with less than 1 frame 104 | tmp_keys = video_ids.keys() 105 | for ki in range(len(tmp_keys)): 106 | if len(video_ids[tmp_keys[ki]]) < 2: 107 | del video_ids[tmp_keys[ki]] 108 | 109 | tmp_keys = video_ids.keys() 110 | 111 | if len(tmp_keys) > 0: 112 | 113 | if vid_idx < train_video_num: 114 | videos_train[all_dirs_level2[j]].append(video_ids) 115 | else: 116 | videos_val[all_dirs_level2[j]].append(video_ids) 117 | 118 | vid_idx = vid_idx + 1 119 | 120 | imdb_video_train['videos'] = videos_train 121 | imdb_video_val['videos'] = videos_val 122 | 123 | # save imdb information 124 | json.dump(imdb_video_train, open('imdb_video_train.json', 'w'), indent=2) 125 | json.dump(imdb_video_val, open('imdb_video_val.json', 'w'), indent=2) 126 | 127 | 128 | if __name__ == "__main__": 129 | vid_root_path = "/home/hfan/Dataset/ILSVRC2015" 130 | vid_curated_path = "/home/hfan/Dataset/ILSVRC2015_crops" 131 | generate_image_imdb(vid_root_path, vid_curated_path) 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## SiamFC-PyTorch 2 | This is a PyTorch (0.4.0) implementation of the SiamFC tracker [1], which was originally implemented in MATLAB using MatConvNet [2]. Our implementation obtains better performance than the original one. 3 | 4 | ## Goal 5 | 6 | * A more compact implementation of SiamFC [1] 7 | * Reproduce the results of SiamFC [1], including data generation, training and tracking 8 | 9 | ## Requirements 10 | 11 | * Python 2.7 (I use Anaconda 2.* here. If you use Python 3, you may get very different results!) 12 | * Python-opencv 13 | * PyTorch 0.4.0 14 | * other common packages such as `numpy`, etc. 15 | 16 | ## Data curation 17 | 18 | * Download ILSVRC15 and unzip it (let's assume that `$ILSVRC2015_Root` is the path to your ILSVRC2015) 19 | * Move `$ILSVRC2015_Root/Data/VID/val` into `$ILSVRC2015_Root/Data/VID/train/`, so we have five sub-folders in `$ILSVRC2015_Root/Data/VID/train/` 20 | * It is a good idea to change the names of the five sub-folders in `$ILSVRC2015_Root/Data/VID/train/` to `a`, `b`, `c`, `d`, and `e` 21 | * Move `$ILSVRC2015_Root/Annotations/VID/val` into `$ILSVRC2015_Root/Annotations/VID/train/`, so we have five sub-folders in `$ILSVRC2015_Root/Annotations/VID/train/` 22 | * Change the names of the five sub-folders in `$ILSVRC2015_Root/Annotations/VID/train/` to `a`, `b`, `c`, `d` and `e`, respectively 23 | 24 | * Generate image crops 25 | * cd `$SiamFC-PyTorch/ILSVRC15-curation/` (assume you've downloaded the repo and its path is `$SiamFC-PyTorch`) 26 | * change `vid_curated_path` in `gen_image_crops_VID.py` to the path where you want to save your crops 27 | * run `$python gen_image_crops_VID.py` (I run it in PyCharm), then you can check the cropped images in your saving path (i.e., `vid_curated_path`) 28 | 29 | * Generate imdb for training and validation 30 | * cd `$SiamFC-PyTorch/ILSVRC15-curation/` 31 | * change `vid_root_path` and `vid_curated_path` to your custom paths in `gen_imdb_VID.py` 32 | * run `$python gen_imdb_VID.py`, then you will get two json files, `imdb_video_train.json` (~430MB) and `imdb_video_val.json` (~28MB), in the current folder, which are used for training and validation (a quick sanity check is sketched below) 33 | 
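If you want a quick sanity check of the generated imdb files, a minimal sketch (assuming the two json files produced by `gen_imdb_VID.py` sit in the current folder) is:

```python
import json

# minimal sanity check of the imdb produced by gen_imdb_VID.py
imdb = json.load(open("imdb_video_train.json", "r"))
print(imdb["num_videos"])   # number of training videos
print(imdb["data_str"])     # relative data prefix, i.e. "Data/VID/train/"

# imdb["videos"] maps a video folder name to a list of per-track dicts;
# each track id maps to a list of {"instance_path", "bbox"} entries
first_video = sorted(imdb["videos"].keys())[0]
print(first_video, len(imdb["videos"][first_video]))
```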
34 | ## Train 35 | 36 | * cd `$SiamFC-PyTorch/Train/` 37 | * Change `data_dir`, `train_imdb` and `val_imdb` to your crop path and your training and validation json files 38 | * run `$python run_Train_SiamFC.py` 39 | * some notes on training 40 | * the parameters for training are in `Config.py` 41 | * by default, I use the GPU for training, and you can check the details in the function `train(data_dir, train_imdb, val_imdb, model_save_path="./model/", use_gpu=True)` (a calling sketch is shown after this section) 42 | * by default, the trained models will be saved to `$SiamFC-PyTorch/Train/model/` 43 | * each epoch (50 in total) may take 7-8 minutes (Nvidia 1080 GPU), and you can use the parallel utilities in PyTorch to speed this up 44 | * I tried to use fixed random seeds to get reproducible results, but it didn't work, so the results of each training run may be slightly different (still better than the original) 45 | * only color images are used for training, and better performance is expected if using color+gray images as in the original paper 46 | 
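If you prefer not to edit the script, a minimal sketch of calling the training entry point directly is shown below; it assumes `train()` is importable from `run_Train_SiamFC.py`, and all paths are placeholders for your own setup.

```python
# minimal sketch: call the training entry point with custom paths
# (assumes train() is defined at module level in run_Train_SiamFC.py;
#  the paths below are placeholders)
from run_Train_SiamFC import train

data_dir = "/path/to/ILSVRC2015_crops/Data/VID/train/"
train_imdb = "/path/to/imdb_video_train.json"
val_imdb = "/path/to/imdb_video_val.json"

train(data_dir, train_imdb, val_imdb, model_save_path="./model/", use_gpu=True)
```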
47 | ## Test (Tracking) 48 | 49 | * cd `$SiamFC-PyTorch/Tracking/` 50 | * First, take a look at `Config.py`, which contains all parameters for tracking 51 | * Change `self.net_base_path` to the path where your trained models are saved 52 | * Change `self.seq_base_path` to the path storing your test sequences (OTB format; otherwise you need to revise the function `load_sequence()` in `Tracking_Utils.py`) 53 | * Change `self.net` to indicate which model you want for evaluation (by default, the last one); I've uploaded a trained model `SiamFC_50_model.pth` in this repo (located in `$SiamFC-PyTorch/Train/model/`) 54 | * Change other parameters as you like :) 55 | * Now, let's run `$python run_SiamFC.py` 56 | * some notes on tracking 57 | * two evaluation types are provided: a single-video demo and evaluation on the whole OTB-100 benchmark 58 | * you can also change which net is used for evaluation in `run_SiamFC.py` 59 | 60 | ## Results 61 | I tested the trained model on OTB-100 using an Nvidia 1080 GPU. The results and a comparison to the original implementation are shown in the image below. The running speed of our implementation is 82 fps. Note that both models are trained from scratch. 62 | 63 | ![image](/imgs/result.PNG) 64 | 65 | ## References 66 | 67 | [1] L. Bertinetto, J. Valmadre, J. F. Henriques, A. Vedaldi, and P. H. Torr. Fully-convolutional Siamese networks for object tracking. In ECCV Workshops, 2016. 68 | 69 | [2] A. Vedaldi and K. Lenc. MatConvNet – convolutional neural networks for MATLAB. In ACM MM, 2015. 70 | 71 | ## Contact 72 | 73 | Any questions are welcome at hengfan@temple.edu. 74 | -------------------------------------------------------------------------------- /Tracking/Config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration for training SiamFC and tracking evaluation 3 | Written by Heng Fan 4 | """ 5 | 6 | 7 | class Config: 8 | def __init__(self): 9 | # parameters for training 10 | self.pos_pair_range = 100 11 | self.num_pairs = 5.32e4 12 | self.val_ratio = 0.1 13 | self.num_epoch = 50 14 | self.batch_size = 8 15 | self.examplar_size = 127 16 | self.instance_size = 255 17 | self.sub_mean = 0 18 | self.train_num_workers = 12 # number of threads to load data when training 19 | self.val_num_workers = 8 20 | self.stride = 8 21 | self.rPos = 16 22 | self.rNeg = 0 23 | self.label_weight_method = "balanced" 24 | 25 | self.lr = 1e-2 # learning rate of SGD 26 | self.momentum = 0.9 # momentum of SGD 27 | self.weight_decay = 5e-4 # weight decay of the optimizer 28 | self.step_size = 1 # step size of the LR scheduler 29 | self.gamma = 0.8685 # decay rate of the LR scheduler 30 | 31 | # parameters for tracking (SiamFC-3s by default) 32 | self.num_scale = 3 33 | self.scale_step = 1.0375 34 | self.scale_penalty = 0.9745 35 | self.scale_LR = 0.59 36 | self.response_UP = 16 37 | self.windowing = "cosine" 38 | self.w_influence = 0.176 39 | 40 | self.video = "Lemming" 41 | self.visualization = 0 42 | self.bbox_output = True 43 | self.bbox_output_path = "./tracking_result/" 44 | 45 | self.context_amount = 0.5 46 | self.scale_min = 0.2 47 | self.scale_max = 5 48 | self.score_size = 17 49 | 50 | # path to your trained model 51 | self.net_base_path = "/home/hfan/Desktop/PyTorch-SiamFC/Train/model/" 52 | # path to your sequences (sequence should be in OTB format) 53 | self.seq_base_path = "/home/hfan/Desktop/demo-sequences/" 54 | # which model to use 55 | self.net = "SiamFC_50_model.pth" -------------------------------------------------------------------------------- /Tracking/SiamNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | The architecture of SiamFC 3 | Written by Heng Fan 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from Config import * 10 | 11 | 12 | class SiamNet(nn.Module): 13 | 14 | def __init__(self): 15 | super(SiamNet, self).__init__() 16 | 17 | # architecture (AlexNet like) 18 | self.feat_extraction = nn.Sequential( 19 | nn.Conv2d(3, 96, 11, 2), # conv1 20 | nn.BatchNorm2d(96), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(3, 2), 23 | nn.Conv2d(96, 256, 
5, 1, groups=2), # conv2, group convolution 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(inplace=True), 26 | nn.MaxPool2d(3, 2), 27 | nn.Conv2d(256, 384, 3, 1), # conv3 28 | nn.BatchNorm2d(384), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(384, 384, 3, 1, groups=2), # conv4 31 | nn.BatchNorm2d(384), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(384, 256, 3, 1, groups=2) # conv5 34 | ) 35 | 36 | # adjust layer as in the original SiamFC in matconvnet 37 | self.adjust = nn.Conv2d(1, 1, 1, 1) 38 | 39 | # initialize weights 40 | self._initialize_weight() 41 | 42 | self.config = Config() 43 | 44 | def forward(self, z, x): 45 | """ 46 | forward pass 47 | z: examplare, BxCxHxW 48 | x: search region, BxCxH1xW1 49 | """ 50 | # get features for z and x 51 | z_feat = self.feat_extraction(z) 52 | x_feat = self.feat_extraction(x) 53 | 54 | # correlation of z and x 55 | xcorr_out = self.xcorr(z_feat, x_feat) 56 | 57 | score = self.adjust(xcorr_out) 58 | 59 | return score 60 | 61 | def xcorr(self, z, x): 62 | """ 63 | correlation layer as in the original SiamFC (convolution process in fact) 64 | """ 65 | batch_size_x, channel_x, w_x, h_x = x.shape 66 | x = torch.reshape(x, (1, batch_size_x * channel_x, w_x, h_x)) 67 | 68 | # group convolution 69 | out = F.conv2d(x, z, groups = batch_size_x) 70 | 71 | batch_size_out, channel_out, w_out, h_out = out.shape 72 | xcorr_out = torch.reshape(out, (channel_out, batch_size_out, w_out, h_out)) 73 | 74 | return xcorr_out 75 | 76 | def _initialize_weight(self): 77 | """ 78 | initialize network parameters 79 | """ 80 | tmp_layer_idx = 0 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | tmp_layer_idx = tmp_layer_idx + 1 84 | if tmp_layer_idx < 6: 85 | # kaiming initialization 86 | nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 87 | else: 88 | # initialization for adjust layer as in the original paper 89 | m.weight.data.fill_(1e-3) 90 | m.bias.data.zero_() 91 | elif isinstance(m, nn.BatchNorm2d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | 95 | def weight_loss(self, prediction, label, weight): 96 | """ 97 | weighted cross entropy loss 98 | """ 99 | return F.binary_cross_entropy_with_logits(prediction, 100 | label, 101 | weight, 102 | size_average=False) / self.config.batch_size 103 | -------------------------------------------------------------------------------- /Tracking/Tracking_Utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool functiond for tracking evaluation 3 | Written by Heng Fan 4 | """ 5 | 6 | import os 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patches as patches 10 | import cv2 11 | import glob 12 | 13 | 14 | def cat_img(image_cat1, image_cat2, image_cat3): 15 | """ 16 | concatenate three 1-channel images to one 3-channel image 17 | """ 18 | image = np.zeros(shape = (image_cat1.shape[0], image_cat1.shape[1], 3), dtype=np.double) 19 | image[:, :, 0] = image_cat1 20 | image[:, :, 1] = image_cat2 21 | image[:, :, 2] = image_cat3 22 | 23 | return image 24 | 25 | 26 | def load_sequence(seq_root_path, seq_name): 27 | """ 28 | load sequences; 29 | sequences should be in OTB format, or you can custom this function by yourself 30 | """ 31 | img_dir = os.path.join(seq_root_path, seq_name, 'img/') 32 | gt_path = os.path.join(seq_root_path, seq_name, 'groundtruth_rect.txt') 33 | 34 | img_list = glob.glob(img_dir + "*.jpg") 35 | img_list.sort() 36 | img_list = [os.path.join(img_dir, x) for x in img_list] 37 | 38 | gt = 
np.loadtxt(gt_path, delimiter=',') 39 | 40 | init_bbox = gt[0] 41 | if seq_name == "Tiger1": 42 | init_bbox = gt[5] 43 | 44 | init_x = init_bbox[0] 45 | init_y = init_bbox[1] 46 | init_w = init_bbox[2] 47 | init_h = init_bbox[3] 48 | 49 | target_position = np.array([init_y + init_h/2, init_x + init_w/2], dtype = np.double) 50 | target_sz = np.array([init_h, init_w], dtype = np.double) 51 | 52 | if seq_name == "David": 53 | img_list = img_list[299:] 54 | if seq_name == "Tiger1": 55 | img_list = img_list[5:] 56 | if seq_name == "Football1": 57 | img_list = img_list[0:74] 58 | 59 | return img_list, target_position, target_sz 60 | 61 | 62 | def visualize_tracking_result(img, bbox, fig_n): 63 | """ 64 | visualize tracking result 65 | """ 66 | fig = plt.figure(fig_n) 67 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 68 | ax.set_axis_off() 69 | fig.add_axes(ax) 70 | r = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], linewidth = 3, edgecolor = "#00ff00", zorder = 1, fill = False) 71 | ax.imshow(img) 72 | ax.add_patch(r) 73 | plt.ion() 74 | plt.show() 75 | plt.pause(0.00001) 76 | plt.clf() 77 | 78 | 79 | def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans): 80 | """ 81 | extract image crop 82 | """ 83 | if original_sz is None: 84 | original_sz = model_sz 85 | 86 | sz = original_sz 87 | im_sz = im.shape 88 | # make sure the size is not too small 89 | assert (im_sz[0] > 2) & (im_sz[1] > 2), "The size of image is too small!" 90 | c = (sz+1) / 2 91 | 92 | # check out-of-bounds coordinates, and set them to black 93 | context_xmin = round(pos[1] - c) # floor(pos(2) - sz(2) / 2); 94 | context_xmax = context_xmin + sz - 1 95 | context_ymin = round(pos[0] - c) # floor(pos(1) - sz(1) / 2); 96 | context_ymax = context_ymin + sz - 1 97 | left_pad = max(0, 1-context_xmin) # in python, index starts from 0 98 | top_pad = max(0, 1-context_ymin) 99 | right_pad = max(0, context_xmax - im_sz[1]) 100 | bottom_pad = max(0, context_ymax - im_sz[0]) 101 | 102 | context_xmin = context_xmin + left_pad 103 | context_xmax = context_xmax + left_pad 104 | context_ymin = context_ymin + top_pad 105 | context_ymax = context_ymax + top_pad 106 | 107 | im_R = im[:, :, 0] 108 | im_G = im[:, :, 1] 109 | im_B = im[:, :, 2] 110 | 111 | # padding 112 | if (top_pad !=0) | (bottom_pad !=0) | (left_pad !=0) | (right_pad !=0): 113 | im_R = np.pad(im_R, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', constant_values = avg_chans[0]) 114 | im_G = np.pad(im_G, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', constant_values = avg_chans[1]) 115 | im_B = np.pad(im_B, ((int(top_pad), int(bottom_pad)), (int(left_pad), int(right_pad))), 'constant', constant_values = avg_chans[2]) 116 | 117 | im = cat_img(im_R, im_G, im_B) 118 | 119 | im_patch_original = im[int(context_ymin)-1:int(context_ymax), int(context_xmin)-1:int(context_xmax), :] 120 | 121 | if model_sz != original_sz: 122 | im_patch = cv2.resize(im_patch_original, (int(model_sz), int(model_sz)), interpolation = cv2.INTER_CUBIC) 123 | else: 124 | im_patch = im_patch_original 125 | 126 | return im_patch 127 | 128 | 129 | def make_scale_pyramid(im, target_position, in_side_scaled, out_side, avg_chans, p): 130 | """ 131 | extract multi-scale image crops 132 | """ 133 | in_side_scaled = np.round(in_side_scaled) 134 | pyramid = np.zeros((out_side, out_side, 3, p.num_scale), dtype = np.double) 135 | max_target_side = in_side_scaled[in_side_scaled.size-1] 136 | min_target_side = in_side_scaled[0] 137 | beta = out_side / 
min_target_side 138 | # size_in_search_area = beta * size_in_image 139 | # e.g. out_side = beta * min_target_side 140 | search_side = round(beta * max_target_side) 141 | 142 | search_region = get_subwindow_tracking(im, target_position, search_side, max_target_side, avg_chans) 143 | 144 | assert (round(beta * min_target_side) == out_side), "Error!" 145 | 146 | for s in range(p.num_scale): 147 | target_side = round(beta * in_side_scaled[s]) 148 | search_target_position = np.array([1 + search_side/2, 1 + search_side/2], dtype = np.double) 149 | pyramid[:, :, :, s] = get_subwindow_tracking(search_region, search_target_position, out_side, 150 | target_side, avg_chans) 151 | 152 | return pyramid 153 | 154 | 155 | def tracker_eval(net, s_x, z_features, x_features, target_position, window, p): 156 | """ 157 | do evaluation (i.e., a forward pass for search region) 158 | (This part is implemented as in the original Matlab version) 159 | """ 160 | # compute scores search regions of different scales 161 | scores = net.xcorr(z_features, x_features) 162 | scores = scores.to("cpu") 163 | 164 | response_maps = scores.squeeze().permute(1, 2, 0).data.numpy() 165 | # for this one, the opencv resize function works fine 166 | response_maps_up = cv2.resize(response_maps, (response_maps.shape[0]*p.response_UP, response_maps.shape[0]*p.response_UP), interpolation=cv2.INTER_CUBIC) 167 | 168 | # choose the scale whose response map has the highest peak 169 | if p.num_scale > 1: 170 | current_scale_id =np.ceil(p.num_scale/2) 171 | best_scale = current_scale_id 172 | best_peak = float("-inf") 173 | for s in range(p.num_scale): 174 | this_response = response_maps_up[:, :, s] 175 | # penalize change of scale 176 | if s != current_scale_id: 177 | this_response = this_response * p.scale_penalty 178 | this_peak = np.max(this_response) 179 | if this_peak > best_peak: 180 | best_peak = this_peak 181 | best_scale = s 182 | response_map = response_maps_up[:, :, int(best_scale)] 183 | else: 184 | response_map = response_maps_up 185 | best_scale = 1 186 | # make the response map sum to 1 187 | response_map = response_map - np.min(response_map) 188 | response_map = response_map / sum(sum(response_map)) 189 | 190 | # apply windowing 191 | response_map = (1 - p.w_influence) * response_map + p.w_influence * window 192 | p_corr = np.asarray(np.unravel_index(np.argmax(response_map), np.shape(response_map))) 193 | 194 | # avoid empty 195 | if p_corr[0] is None: 196 | p_corr[0] = np.ceil(p.score_size/2) 197 | if p_corr[1] is None: 198 | p_corr[1] = np.ceil(p.score_size/2) 199 | 200 | # Convert to crop-relative coordinates to frame coordinates 201 | # displacement from the center in instance final representation ... 202 | disp_instance_final = p_corr - np.ceil(p.score_size * p.response_UP / 2) 203 | # ... in instance input ... 204 | disp_instance_input = disp_instance_final * p.stride / p.response_UP 205 | # ... 
in instance original crop (in frame coordinates) 206 | disp_instance_frame = disp_instance_input * s_x / p.instance_size 207 | # position within frame in frame coordinates 208 | new_target_position = target_position + disp_instance_frame 209 | 210 | return new_target_position, best_scale -------------------------------------------------------------------------------- /Tracking/run_SiamFC.py: -------------------------------------------------------------------------------- 1 | from Config import * 2 | from Tracking_Utils import * 3 | from SiamNet import * 4 | import os 5 | import numpy as np 6 | import torchvision.transforms.functional as F 7 | import cv2 8 | import datetime 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | # entry to evaluation of SiamFC 16 | def run_tracker(p): 17 | """ 18 | run tracker, return bounding result and speed 19 | """ 20 | # load model 21 | net = torch.load(os.path.join(p.net_base_path, p.net)) 22 | net = net.to(device) 23 | 24 | # evaluation mode 25 | net.eval() 26 | 27 | # load sequence 28 | img_list, target_position, target_size = load_sequence(p.seq_base_path, p.video) 29 | 30 | # first frame 31 | img_uint8 = cv2.imread(img_list[0]) 32 | img_uint8 = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2RGB) 33 | img_double = np.double(img_uint8) # uint8 to float 34 | 35 | # compute avg for padding 36 | avg_chans = np.mean(img_double, axis=(0, 1)) 37 | 38 | wc_z = target_size[1] + p.context_amount * sum(target_size) 39 | hc_z = target_size[0] + p.context_amount * sum(target_size) 40 | s_z = np.sqrt(wc_z * hc_z) 41 | scale_z = p.examplar_size / s_z 42 | 43 | # crop examplar z in the first frame 44 | z_crop = get_subwindow_tracking(img_double, target_position, p.examplar_size, round(s_z), avg_chans) 45 | 46 | z_crop = np.uint8(z_crop) # you need to convert it to uint8 47 | # convert image to tensor 48 | z_crop_tensor = 255.0 * F.to_tensor(z_crop).unsqueeze(0) 49 | 50 | d_search = (p.instance_size - p.examplar_size) / 2 51 | pad = d_search / scale_z 52 | s_x = s_z + 2 * pad 53 | # arbitrary scale saturation 54 | min_s_x = p.scale_min * s_x 55 | max_s_x = p.scale_max * s_x 56 | 57 | # generate cosine window 58 | if p.windowing == 'cosine': 59 | window = np.outer(np.hanning(p.score_size * p.response_UP), np.hanning(p.score_size * p.response_UP)) 60 | elif p.windowing == 'uniform': 61 | window = np.ones((p.score_size * p.response_UP, p.score_size * p.response_UP)) 62 | window = window / sum(sum(window)) 63 | 64 | # pyramid scale search 65 | scales = p.scale_step**np.linspace(-np.ceil(p.num_scale/2), np.ceil(p.num_scale/2), p.num_scale) 66 | 67 | # extract feature for examplar z 68 | z_features = net.feat_extraction(Variable(z_crop_tensor).to(device)) 69 | z_features = z_features.repeat(p.num_scale, 1, 1, 1) 70 | 71 | # do tracking 72 | bboxes = np.zeros((len(img_list), 4), dtype=np.double) # save tracking result 73 | start_time = datetime.datetime.now() 74 | for i in range(0, len(img_list)): 75 | if i > 0: 76 | # do detection 77 | # currently, we only consider RGB images for tracking 78 | img_uint8 = cv2.imread(img_list[i]) 79 | img_uint8 = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2RGB) 80 | img_double = np.double(img_uint8) # uint8 to float 81 | 82 | scaled_instance = s_x * scales 83 | scaled_target = np.zeros((2, scales.size), dtype = np.double) 84 | scaled_target[0, :] = target_size[0] * scales 85 | scaled_target[1, :] = target_size[1] * scales 86 | 87 | # extract scaled crops for 
search region x at previous target position 88 | x_crops = make_scale_pyramid(img_double, target_position, scaled_instance, p.instance_size, avg_chans, p) 89 | 90 | # get features of search regions 91 | x_crops_tensor = torch.FloatTensor(x_crops.shape[3], x_crops.shape[2], x_crops.shape[1], x_crops.shape[0]) 92 | # response_map = SiameseNet.get_response_map(z_features, x_crops) 93 | for k in range(x_crops.shape[3]): 94 | tmp_x_crop = x_crops[:, :, :, k] 95 | tmp_x_crop = np.uint8(tmp_x_crop) 96 | # numpy array to tensor 97 | x_crops_tensor[k, :, :, :] = 255.0 * F.to_tensor(tmp_x_crop).unsqueeze(0) 98 | 99 | # get features of search regions 100 | x_features = net.feat_extraction(Variable(x_crops_tensor).to(device)) 101 | 102 | # evaluate the offline-trained network for exemplar x features 103 | target_position, new_scale = tracker_eval(net, round(s_x), z_features, x_features, target_position, window, p) 104 | 105 | # scale damping and saturation 106 | s_x = max(min_s_x, min(max_s_x, (1 - p.scale_LR) * s_x + p.scale_LR * scaled_instance[int(new_scale)])) 107 | target_size = (1 - p.scale_LR) * target_size + p.scale_LR * np.array([scaled_target[0, int(new_scale)], scaled_target[1, int(new_scale)]]) 108 | 109 | rect_position = np.array([target_position[1]-target_size[1]/2, target_position[0]-target_size[0]/2, target_size[1], target_size[0]]) 110 | 111 | if p.visualization: 112 | visualize_tracking_result(img_uint8, rect_position, 1) 113 | 114 | # output bbox in the original frame coordinates 115 | o_target_position = target_position 116 | o_target_size = target_size 117 | bboxes[i,:] = np.array([o_target_position[1]-o_target_size[1]/2, o_target_position[0]-o_target_size[0]/2, o_target_size[1], o_target_size[0]]) 118 | 119 | end_time = datetime.datetime.now() 120 | fps = len(img_list)/max(1.0, (end_time-start_time).seconds) 121 | 122 | return bboxes, fps 123 | 124 | 125 | if __name__ == "__main__": 126 | 127 | # get the default parameters 128 | p = Config() 129 | 130 | # choose which model to run 131 | p.net = "SiamFC_50_model.pth" 132 | 133 | # choose demo type, single or all 134 | demp_type = "all" 135 | 136 | if demp_type == "single": # single video demo 137 | video = "Lemming" 138 | p.video = video 139 | print("Processing %s ... " % p.video) 140 | bbox_result, fps = run_tracker(p) 141 | print("FPS: %d " % fps) 142 | else: # evaluation the whole OTB benchmark 143 | # load all videos 144 | all_videos = os.listdir(p.seq_base_path) 145 | 146 | if p.bbox_output: 147 | if not os.path.exists(p.bbox_output_path): 148 | os.makedirs(p.bbox_output_path) 149 | 150 | fps_all = .0 151 | 152 | for video in all_videos: 153 | p.video = video 154 | print("Processing %s ... 
" % p.video) 155 | bbox_result, fps = run_tracker(p) 156 | # fps for this video 157 | print("FPS: %d " % fps) 158 | # saving tracking results 159 | if p.bbox_output: 160 | np.savetxt(p.bbox_output_path + p.video.lower() + '_SiamFC.txt', bbox_result, fmt='%.3f') 161 | fps_all = fps_all + fps 162 | 163 | avg_fps = fps_all / len(all_videos) 164 | print("Average FPS: %f" % avg_fps) 165 | -------------------------------------------------------------------------------- /Train/Config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration for training SiamFC and tracking evaluation 3 | Written by Heng Fan 4 | """ 5 | 6 | 7 | class Config: 8 | def __init__(self): 9 | # parameters for training 10 | self.pos_pair_range = 100 11 | self.num_pairs = 5.32e4 12 | self.val_ratio = 0.1 13 | self.num_epoch = 50 14 | self.batch_size = 8 15 | self.examplar_size = 127 16 | self.instance_size = 255 17 | self.sub_mean = 0 18 | self.train_num_workers = 12 # number of threads to load data when training 19 | self.val_num_workers = 8 20 | self.stride = 8 21 | self.rPos = 16 22 | self.rNeg = 0 23 | self.label_weight_method = "balanced" 24 | 25 | self.lr = 1e-2 # learning rate of SGD 26 | self.momentum = 0.9 # momentum of SGD 27 | self.weight_decay = 5e-4 # weight decay of optimizator 28 | self.step_size = 1 # step size of LR_Schedular 29 | self.gamma = 0.8685 # decay rate of LR_Schedular 30 | 31 | # parameters for tracking (SiamFC-3s by default) 32 | self.num_scale = 3 33 | self.scale_step = 1.0375 34 | self.scale_penalty = 0.9745 35 | self.scale_LR = 0.59 36 | self.response_UP = 16 37 | self.windowing = "cosine" 38 | self.w_influence = 0.176 39 | 40 | self.video = "Lemming" 41 | self.visualization = 1 42 | self.bbox_output = True 43 | self.bbox_output_path = "./tracking_result/" 44 | 45 | self.context_amount = 0.5 46 | self.scale_min = 0.2 47 | self.scale_max = 5 48 | self.score_size = 17 49 | 50 | # path to your trained model 51 | self.net_base_path = "/home/hfan/Desktop/PyTorch-SiamFC/Train/model/" 52 | # path to your sequences (sequence should be in OTB format) 53 | self.seq_base_path = "/home/hfan/Desktop/demo-sequences/" 54 | # which model to use 55 | self.net = "SiamFC_50_model.pth" -------------------------------------------------------------------------------- /Train/DataAugmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | class RandomStretch(object): 7 | def __init__(self, max_stretch=0.05): 8 | self.max_stretch = max_stretch 9 | 10 | def __call__(self, sample): 11 | scale_h = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch) 12 | scale_w = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch) 13 | h, w = sample.shape[:2] 14 | shape = (int(h * scale_h), int(w * scale_w)) 15 | return cv2.resize(sample, shape, cv2.INTER_CUBIC) 16 | 17 | 18 | class CenterCrop(object): 19 | ''' 20 | center crop for examplar z 21 | ''' 22 | def __init__(self, size): 23 | self.size = size 24 | 25 | def __call__(self, sample): 26 | shape = sample.shape[:2] 27 | cy, cx = shape[0] // 2, shape[1] // 2 28 | ymin, xmin = cy - self.size[0]//2, cx - self.size[1] // 2 29 | ymax, xmax = cy + self.size[0]//2 + 1, cx + self.size[1] // 2 + 1 30 | left = right = top = bottom = 0 31 | im_h, im_w = shape 32 | if xmin < 0: 33 | left = int(abs(xmin)) 34 | if xmax > im_w: 35 | right = int(xmax - im_w) 36 | if ymin < 0: 37 | top = int(abs(ymin)) 38 | if ymax > im_h: 
39 | bottom = int(ymax - im_h) 40 | 41 | xmin = int(max(0, xmin)) 42 | xmax = int(min(im_w, xmax)) 43 | ymin = int(max(0, ymin)) 44 | ymax = int(min(im_h, ymax)) 45 | im_patch = sample[ymin:ymax, xmin:xmax] 46 | if left != 0 or right !=0 or top!=0 or bottom!=0: 47 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right, 48 | cv2.BORDER_CONSTANT, value=0) 49 | return im_patch 50 | 51 | 52 | class RandomCrop(object): 53 | def __init__(self, size): 54 | self.size = size 55 | 56 | def __call__(self, sample): 57 | shape = sample.shape[:2] 58 | y1 = np.random.randint(0, shape[0] - self.size[0]) 59 | x1 = np.random.randint(0, shape[1] - self.size[1]) 60 | y2 = y1 + self.size[0] 61 | x2 = x1 + self.size[1] 62 | img_patch = sample[y1:y2, x1:x2] 63 | return img_patch 64 | 65 | 66 | class Normalize(object): 67 | def __init__(self): 68 | self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) 69 | self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) 70 | 71 | def __call__(self, sample): 72 | return (sample / 255. - self.mean) / self.std 73 | 74 | 75 | class ToTensor(object): 76 | def __call__(self, sample): 77 | sample = sample.transpose(2, 0, 1) 78 | return torch.from_numpy(sample.astype(np.float32)) 79 | -------------------------------------------------------------------------------- /Train/SiamNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | The architecture of SiamFC 3 | Written by Heng Fan 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from Config import * 10 | 11 | 12 | class SiamNet(nn.Module): 13 | 14 | def __init__(self): 15 | super(SiamNet, self).__init__() 16 | 17 | # architecture (AlexNet like) 18 | self.feat_extraction = nn.Sequential( 19 | nn.Conv2d(3, 96, 11, 2), # conv1 20 | nn.BatchNorm2d(96), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(3, 2), 23 | nn.Conv2d(96, 256, 5, 1, groups=2), # conv2, group convolution 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(inplace=True), 26 | nn.MaxPool2d(3, 2), 27 | nn.Conv2d(256, 384, 3, 1), # conv3 28 | nn.BatchNorm2d(384), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(384, 384, 3, 1, groups=2), # conv4 31 | nn.BatchNorm2d(384), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(384, 256, 3, 1, groups=2) # conv5 34 | ) 35 | 36 | # adjust layer as in the original SiamFC in matconvnet 37 | self.adjust = nn.Conv2d(1, 1, 1, 1) 38 | 39 | # initialize weights 40 | self._initialize_weight() 41 | 42 | self.config = Config() 43 | 44 | def forward(self, z, x): 45 | """ 46 | forward pass 47 | z: examplare, BxCxHxW 48 | x: search region, BxCxH1xW1 49 | """ 50 | # get features for z and x 51 | z_feat = self.feat_extraction(z) 52 | x_feat = self.feat_extraction(x) 53 | 54 | # correlation of z and z 55 | xcorr_out = self.xcorr(z_feat, x_feat) 56 | 57 | score = self.adjust(xcorr_out) 58 | 59 | return score 60 | 61 | def xcorr(self, z, x): 62 | """ 63 | correlation layer as in the original SiamFC (convolution process in fact) 64 | """ 65 | batch_size_x, channel_x, w_x, h_x = x.shape 66 | x = torch.reshape(x, (1, batch_size_x * channel_x, w_x, h_x)) 67 | 68 | # group convolution 69 | out = F.conv2d(x, z, groups = batch_size_x) 70 | 71 | batch_size_out, channel_out, w_out, h_out = out.shape 72 | xcorr_out = torch.reshape(out, (channel_out, batch_size_out, w_out, h_out)) 73 | 74 | return xcorr_out 75 | 76 | def _initialize_weight(self): 77 | """ 78 | initialize network parameters 79 | """ 80 | tmp_layer_idx = 0 81 | for m in self.modules(): 82 | if isinstance(m, 
nn.Conv2d): 83 | tmp_layer_idx = tmp_layer_idx + 1 84 | if tmp_layer_idx < 6: 85 | # kaiming initialization 86 | nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 87 | else: 88 | # initialization for adjust layer as in the original paper 89 | m.weight.data.fill_(1e-3) 90 | m.bias.data.zero_() 91 | elif isinstance(m, nn.BatchNorm2d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | 95 | def weight_loss(self, prediction, label, weight): 96 | """ 97 | weighted cross entropy loss 98 | """ 99 | return F.binary_cross_entropy_with_logits(prediction, 100 | label, 101 | weight, 102 | size_average=False) / self.config.batch_size 103 | -------------------------------------------------------------------------------- /Train/Utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define some tool functions 3 | Written by Heng Fan 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def create_logisticloss_label(label_size, rPos, rNeg): 12 | """ 13 | construct label for logistic loss (same for all pairs) 14 | """ 15 | label_side = int(label_size[0]) 16 | logloss_label = torch.zeros(label_side, label_side) 17 | label_origin = np.array([np.ceil(label_side / 2), np.ceil(label_side / 2)]) 18 | for i in range(label_side): 19 | for j in range(label_side): 20 | dist_from_origin = np.sqrt((i - label_origin[0]) ** 2 + (j - label_origin[1]) ** 2) 21 | if dist_from_origin <= rPos: 22 | logloss_label[i, j] = +1 23 | else: 24 | if dist_from_origin <= rNeg: 25 | logloss_label[i, j] = 0 26 | 27 | return logloss_label 28 | 29 | 30 | def create_label(fixed_label_size, config, use_gpu): 31 | """ 32 | create label with weight 33 | """ 34 | rPos = config.rPos / config.stride 35 | rNeg = config.rNeg / config.stride 36 | 37 | half = int(np.floor(fixed_label_size[0] / 2) + 1) 38 | 39 | if config.label_weight_method == "balanced": 40 | fixed_label = create_logisticloss_label(fixed_label_size, rPos, rNeg) 41 | # plt.imshow(fixed_label) 42 | # plt.colorbar() 43 | # plt.show() 44 | instance_weight = torch.ones(fixed_label.shape[0], fixed_label.shape[1]) 45 | tmp_idx_P = np.where(fixed_label == 1) 46 | sumP = tmp_idx_P[0].size 47 | tmp_idx_N = np.where(fixed_label == 0) 48 | sumN = tmp_idx_N[0].size 49 | instance_weight[tmp_idx_P] = 0.5 * instance_weight[tmp_idx_P] / sumP 50 | instance_weight[tmp_idx_N] = 0.5 * instance_weight[tmp_idx_N] / sumN 51 | # plt.imshow(instance_weight) 52 | # plt.colorbar() 53 | # plt.show() 54 | 55 | # reshape label 56 | fixed_label = torch.reshape(fixed_label, (1, 1, fixed_label.shape[0], fixed_label.shape[1])) 57 | # copy label to match batchsize 58 | fixed_label = fixed_label.repeat(config.batch_size, 1, 1, 1) 59 | 60 | # reshape weight 61 | instance_weight = torch.reshape(instance_weight, (1, instance_weight.shape[0], instance_weight.shape[1])) 62 | 63 | if use_gpu: 64 | return fixed_label.cuda(), instance_weight.cuda() 65 | else: 66 | return fixed_label, instance_weight 67 | 68 | 69 | def cv2_brg2rgb(bgr_img): 70 | """ 71 | convert brg image to rgb 72 | """ 73 | b, g, r = cv2.split(bgr_img) 74 | rgb_img = cv2.merge([r, g, b]) 75 | 76 | return rgb_img 77 | 78 | 79 | def float32_to_uint8(img): 80 | """ 81 | convert float32 array to uint8 82 | """ 83 | beyong_255 = np.where(img > 255) 84 | img[beyong_255] = 255 85 | less_0 = np.where(img < 0) 86 | img[less_0] = 0 87 | img = np.round(img) 88 | 89 | return img.astype(np.uint8) 90 | 
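A small usage sketch for the label helpers above (an editorial addition, not part of the original file): the 17x17 label size is an assumption, chosen because it is the score-map size SiamNet produces for a 127x127 exemplar and a 255x255 search region; substitute whatever score-map size your training pipeline actually outputs.

```python
# hedged usage sketch for create_label / create_logisticloss_label
from Config import Config
from Utils import create_label

config = Config()  # batch_size = 8, stride = 8, rPos = 16, rNeg = 0 by default
label, weight = create_label((17, 17), config, use_gpu=False)
print(label.shape)    # torch.Size([8, 1, 17, 17]) -- one label per pair in the batch
print(weight.shape)   # torch.Size([1, 17, 17])    -- per-pixel balancing weights
```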
--------------------------------------------------------------------------------
/Train/VIDDataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset for VID
3 | Written by Heng Fan
4 | """
5 | 
6 | from torch.utils.data.dataset import Dataset
7 | import json
8 | from Utils import *
9 | import os
10 | 
11 | 
12 | class VIDDataset(Dataset):
13 | 
14 |     def __init__(self, imdb, data_dir, config, z_transforms, x_transforms, mode="Train"):
15 |         imdb_video = json.load(open(imdb, 'r'))
16 |         self.videos = imdb_video['videos']
17 |         self.data_dir = data_dir
18 |         self.config = config
19 |         self.num_videos = int(imdb_video['num_videos'])
20 | 
21 |         self.z_transforms = z_transforms
22 |         self.x_transforms = x_transforms
23 | 
24 |         if mode == "Train":
25 |             self.num = self.config.num_pairs
26 |         else:
27 |             self.num = self.num_videos
28 | 
29 |     def __getitem__(self, rand_vid):
30 |         '''
31 |         read a pair of images z and x
32 |         '''
33 |         # randomly decide the id of video to get z and x
34 |         rand_vid = rand_vid % self.num_videos
35 | 
36 |         video_keys = list(self.videos.keys())   # list() so the keys can be indexed
37 |         video = self.videos[video_keys[rand_vid]]
38 | 
39 |         # get ids of this video
40 |         video_ids = video[0]
41 |         # how many ids in this video
42 |         video_id_keys = list(video_ids.keys())
43 | 
44 |         # randomly pick an id for z
45 |         rand_trackid_z = np.random.choice(list(range(len(video_id_keys))))
46 |         # get the video for this id
47 |         video_id_z = video_ids[video_id_keys[rand_trackid_z]]
48 | 
49 |         # pick a valid examplar z in the video
50 |         rand_z = np.random.choice(range(len(video_id_z)))
51 | 
52 |         # pick a valid instance within frame_range frames from the examplar, excluding the examplar itself
53 |         possible_x_pos = list(range(len(video_id_z)))
54 |         rand_x = np.random.choice(possible_x_pos[max(rand_z - self.config.pos_pair_range, 0):rand_z] + possible_x_pos[(rand_z + 1):min(rand_z + self.config.pos_pair_range, len(video_id_z))])
55 | 
56 |         z = video_id_z[rand_z].copy()   # use copy() here to avoid changing dictionary
57 |         x = video_id_z[rand_x].copy()
58 | 
59 |         # read z and x
60 |         img_z = cv2.imread(os.path.join(self.data_dir, z['instance_path']))
61 |         img_z = cv2.cvtColor(img_z, cv2.COLOR_BGR2RGB)
62 | 
63 |         img_x = cv2.imread(os.path.join(self.data_dir, x['instance_path']))
64 |         img_x = cv2.cvtColor(img_x, cv2.COLOR_BGR2RGB)
65 | 
66 |         # do data augmentation;
67 |         # note that we have done center crop for z in the data augmentation
68 |         img_z = self.z_transforms(img_z)
69 |         img_x = self.x_transforms(img_x)
70 | 
71 |         return img_z, img_x
72 | 
73 |     def __len__(self):
74 |         return self.num
75 | 
--------------------------------------------------------------------------------
/Train/model/SiamFC_50_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HengLan/SiamFC-PyTorch/6cb921f5ac58b612ebface176d9a84d5e033150a/Train/model/SiamFC_50_model.pth
--------------------------------------------------------------------------------
/Train/run_Train_SiamFC.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch implementation of SiamFC (Luca Bertinetto, et al., ECCVW, 2016)
3 | Written by Heng Fan
4 | """
5 | 
6 | from SiamNet import *
7 | from VIDDataset import *
8 | from torch.autograd import Variable
9 | from torch.optim.lr_scheduler import StepLR
10 | from Config import *
11 | from Utils import *
12 | import torchvision.transforms as transforms
13 | from DataAugmentation import *
14 | import os
15 | from tqdm import tqdm
16 | from torch.utils.data import DataLoader
17 | 
18 | # fix random seed
19 | np.random.seed(1357)
20 | torch.manual_seed(1234)
21 | 
22 | 
23 | def train(data_dir, train_imdb, val_imdb, model_save_path="./model/", use_gpu=True):
24 | 
25 |     # initialize training configuration
26 |     config = Config()
27 | 
28 |     # do data augmentation in PyTorch;
29 |     # you can also do complex data augmentation as in the original paper
30 |     center_crop_size = config.instance_size - config.stride
31 |     random_crop_size = config.instance_size - 2 * config.stride
32 | 
33 |     train_z_transforms = transforms.Compose([
34 |         RandomStretch(),
35 |         CenterCrop((config.examplar_size, config.examplar_size)),
36 |         ToTensor()
37 |     ])
38 |     train_x_transforms = transforms.Compose([
39 |         RandomStretch(),
40 |         CenterCrop((center_crop_size, center_crop_size)),
41 |         RandomCrop((random_crop_size, random_crop_size)),
42 |         ToTensor()
43 |     ])
44 |     valid_z_transforms = transforms.Compose([
45 |         CenterCrop((config.examplar_size, config.examplar_size)),
46 |         ToTensor(),
47 |     ])
48 |     valid_x_transforms = transforms.Compose([
49 |         ToTensor()
50 |     ])
51 | 
52 |     # load data (see details in VIDDataset.py)
53 |     train_dataset = VIDDataset(train_imdb, data_dir, config, train_z_transforms, train_x_transforms)
54 |     val_dataset = VIDDataset(val_imdb, data_dir, config, valid_z_transforms, valid_x_transforms, "Validation")
55 | 
56 |     # create dataloader
57 |     train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
58 |                               shuffle=True, num_workers=config.train_num_workers, drop_last=True)
59 |     val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
60 |                             shuffle=True, num_workers=config.val_num_workers, drop_last=True)
61 | 
62 |     # create SiamFC network architecture (see details in SiamNet.py)
63 |     net = SiamNet()
64 |     # move network to GPU if using GPU
65 |     if use_gpu:
66 |         net.cuda()
67 | 
68 |     # define training strategy;
69 |     # the learning rate of adjust layer (i.e., a conv layer)
70 |     # is set to 0 as in the original paper
71 |     optimizer = torch.optim.SGD([
72 |         {'params': net.feat_extraction.parameters()},
73 |         {'params': net.adjust.bias},
74 |         {'params': net.adjust.weight, 'lr': 0},
75 |     ], lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
76 | 
77 |     # adjust learning rate in each epoch
78 |     scheduler = StepLR(optimizer, config.step_size, config.gamma)
79 | 
80 |     # used to control generating label for training;
81 |     # once generated, they are fixed since the labels for each
82 |     # pair of images (examplar z and search region x) are the same
83 |     train_response_flag = False
84 |     valid_response_flag = False
85 | 
86 |     # ------------------------ training & validation process ------------------------
87 |     for i in range(config.num_epoch):
88 | 
89 |         # adjusting learning rate
90 |         scheduler.step()
91 | 
92 |         # ------------------------------ training ------------------------------
93 |         # indicating training (very important for batch normalization)
94 |         net.train()
95 | 
96 |         # used to collect loss
97 |         train_loss = []
98 | 
99 |         for j, data in enumerate(tqdm(train_loader)):
100 | 
101 |             # fetch data, i.e., B x C x H x W (batch size x channel x height x width)
102 |             exemplar_imgs, instance_imgs = data
103 | 
104 |             # forward pass
105 |             if use_gpu:
106 |                 exemplar_imgs = exemplar_imgs.cuda()
107 |                 instance_imgs = instance_imgs.cuda()
108 |             output = net.forward(Variable(exemplar_imgs), Variable(instance_imgs))
109 | 
110 |             # create label for training (only do it one time)
111 |             if not train_response_flag:
112 |                 # change control flag
113 |                 train_response_flag = True
114 |                 # get shape of output (i.e., response map)
115 |                 response_size = output.shape[2:4]
116 |                 # generate label and weight
117 |                 train_eltwise_label, train_instance_weight = create_label(response_size, config, use_gpu)
118 | 
119 |             # clear the gradient
120 |             optimizer.zero_grad()
121 | 
122 |             # loss
123 |             loss = net.weight_loss(output, train_eltwise_label, train_instance_weight)
124 | 
125 |             # backward
126 |             loss.backward()
127 | 
128 |             # update parameter
129 |             optimizer.step()
130 | 
131 |             # collect training loss
132 |             train_loss.append(loss.data)
133 | 
134 |         # ------------------------------ saving model ------------------------------
135 |         if not os.path.exists(model_save_path):
136 |             os.makedirs(model_save_path)
137 |         torch.save(net, model_save_path + "SiamFC_" + str(i + 1) + "_model.pth")
138 | 
139 |         # ------------------------------ validation ------------------------------
140 |         # indicate validation
141 |         net.eval()
142 | 
143 |         # used to collect validation loss
144 |         val_loss = []
145 | 
146 |         for j, data in enumerate(tqdm(val_loader)):
147 | 
148 |             exemplar_imgs, instance_imgs = data
149 | 
150 |             # forward pass
151 |             if use_gpu:
152 |                 exemplar_imgs = exemplar_imgs.cuda()
153 |                 instance_imgs = instance_imgs.cuda()
154 |             output = net.forward(Variable(exemplar_imgs), Variable(instance_imgs))
155 | 
156 |             # create label for validation (only do it one time)
157 |             if not valid_response_flag:
158 |                 valid_response_flag = True
159 |                 response_size = output.shape[2:4]
160 |                 valid_eltwise_label, valid_instance_weight = create_label(response_size, config, use_gpu)
161 | 
162 |             # loss
163 |             loss = net.weight_loss(output, valid_eltwise_label, valid_instance_weight)
164 | 
165 |             # collect validation loss
166 |             val_loss.append(loss.data)
167 | 
168 |         print("Epoch %d training loss: %f, validation loss: %f" % (i+1, np.mean(train_loss), np.mean(val_loss)))
169 | 
170 | 
171 | if __name__ == "__main__":
172 | 
173 |     data_dir = "/home/hfan/Dataset/ILSVRC2015_crops/Data/VID/train"
174 |     train_imdb = "/home/hfan/Desktop/PyTorch-SiamFC/ILSVRC15-curation/imdb_video_train.json"
175 |     val_imdb = "/home/hfan/Desktop/PyTorch-SiamFC/ILSVRC15-curation/imdb_video_val.json"
176 | 
177 |     # training SiamFC network, using GPU by default
178 |     train(data_dir, train_imdb, val_imdb)
179 | 
--------------------------------------------------------------------------------
/imgs/result.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HengLan/SiamFC-PyTorch/6cb921f5ac58b612ebface176d9a84d5e033150a/imgs/result.PNG
--------------------------------------------------------------------------------
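As a closing note, the grouped convolution inside SiamNet.xcorr computes all exemplar/search correlations of a batch in one call. The standalone sketch below (an editorial addition, not a repository file) checks that trick against per-sample cross-correlation; the 6x6 and 22x22 feature-map sizes are simply the shapes a SiamFC-style backbone typically produces for 127x127 and 255x255 inputs, and all tensor values here are random stand-ins:

import torch
import torch.nn.functional as F

B, C = 8, 256
z_feat = torch.randn(B, C, 6, 6)     # exemplar features (stand-in values)
x_feat = torch.randn(B, C, 22, 22)   # search-region features (stand-in values)

# grouped-convolution trick (same idea as SiamNet.xcorr): fold the batch into the
# channel dimension and use one group per sample, giving a (1, B, 17, 17) response
x_merged = torch.reshape(x_feat, (1, B * C, 22, 22))
out = F.conv2d(x_merged, z_feat, groups=B)
batched = torch.reshape(out, (B, 1, 17, 17))

# reference: correlate each exemplar/search pair separately and stack the results
reference = torch.cat([F.conv2d(x_feat[i:i + 1], z_feat[i:i + 1]) for i in range(B)], dim=0)

print(torch.allclose(batched, reference, atol=1e-4))   # True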