├── .gitignore ├── README.md ├── pairwise.py ├── siamfc.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | *.pyc 3 | data 4 | data/ 5 | __pycache__/ 6 | results*/ 7 | reports*/ 8 | cache*/ 9 | pretrained*/ 10 | !.gitignore 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SiamFC 2 | 3 | A minimal example showing how to train a tracker on GOT-10k/VID and evaluate its performance on GOT-10k as well as 6 other tracking datasets ([OTB (2013/2015)](http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html), [VOT (2013~2018)](http://votchallenge.net), [DTB70](https://github.com/flyers/drone-tracking), [TColor128](http://www.dabi.temple.edu/~hbling/data/TColor-128/TColor-128.html), [NfS](http://ci2cv.net/nfs/index.html) and [UAV123](https://ivul.kaust.edu.sa/Pages/pub-benchmark-simulator-uav.aspx)), using the [GOT-10k toolkit](https://github.com/got-10k/toolkit). 4 | 5 | ## Performance 6 | 7 | ### GOT-10k 8 | 9 | | Dataset | AO | SR0.50 | SR0.75 | 10 | |:------- |:-----:|:-----------------:|:-----------------:| 11 | | GOT-10k | 0.355 | 0.390 | 0.118 | 12 | 13 | The scores are comparable with state-of-the-art results on [GOT-10k leaderboard](http://got-10k.aitestunion.com/leaderboard). 
14 | 15 | ### OTB / UAV123 / DTB70 / TColor128 / NfS 16 | 17 | | Dataset | Success Score | Precision Score | 18 | |:----------- |:----------------:|:----------------:| 19 | | OTB2013 | 0.589 | 0.781 | 20 | | OTB2015 | 0.578 | 0.765 | 21 | | UAV123 | 0.523 | 0.731 | 22 | | UAV20L | 0.423 | 0.572 | 23 | | DTB70 | 0.493 | 0.731 | 24 | | TColor128 | 0.510 | 0.691 | 25 | | NfS (30 fps) | - | - | 26 | | NfS (240 fps) | 0.520 | 0.624 | 27 | 28 | ### VOT2018 29 | 30 | | Dataset | Accuracy | Robustness (unnormalized) | 31 | |:----------- |:-----------:|:-------------------------:| 32 | | VOT2018 | 0.502 | 37.25 | 33 | 34 | ## Dependencies 35 | 36 | Install PyTorch, torchvision, opencv-python and the GOT-10k toolkit: 37 | 38 | ```bash 39 | pip install torch torchvision opencv-python got10k 40 | ``` 41 | 42 | ## Running the tracker 43 | 44 | In the root directory of `siamfc`: 45 | 46 | 1. Download pretrained `model.pth` from [Baidu Yun](https://pan.baidu.com/s/1TT7ebFho63Lw2D7CXLqwjQ) or [Google Drive](https://drive.google.com/open?id=1Qu5K8bQhRAiexKdnwzs39lOko3uWxEKm), and put the file under `pretrained/siamfc`. 47 | 48 | 2. Create a symbolic link named `data` that points to your datasets folder (so that, e.g., `data/OTB`, `data/UAV123` and `data/GOT-10k` resolve): 49 | 50 | ``` 51 | ln -s /path/to/your/data/folder ./data 52 | ``` 53 | 54 | 3. Run: 55 | 56 | ``` 57 | python test.py 58 | ``` 59 | 60 | By default, the tracking experiments will be executed and evaluated over all 7 datasets. Comment out entries of the `experiments` list in `test.py` if you need to skip some experiments. 61 | 62 | ## Training the tracker 63 | 64 | 1. Assume the GOT-10k dataset is located at `data/GOT-10k`. 65 | 66 | 2. 
class RandomStretch(object):
    """Randomly rescale a PIL image by a factor drawn around 1.

    On every call a scale is sampled uniformly from
    ``[1 - max_stretch, 1 + max_stretch]`` and applied to both image
    dimensions, so the aspect ratio is preserved up to rounding.

    Args:
        max_stretch: half-width of the uniform scale range.
        interpolation: ``'bilinear'`` or ``'bicubic'``.

    Raises:
        ValueError: if ``interpolation`` is not one of the supported names.
    """

    _INTERPOLATIONS = ('bilinear', 'bicubic')

    def __init__(self, max_stretch=0.05, interpolation='bilinear'):
        # Raise explicitly rather than assert: assertions are stripped
        # under ``python -O``, which would defer the failure to
        # ``__call__`` as an unrelated unbound-variable error.
        if interpolation not in self._INTERPOLATIONS:
            raise ValueError(
                'interpolation must be one of %s, got %r'
                % (self._INTERPOLATIONS, interpolation))
        self.max_stretch = max_stretch
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` resized by a freshly sampled random scale."""
        scale = 1.0 + np.random.uniform(
            -self.max_stretch, self.max_stretch)
        size = np.round(np.array(img.size, float) * scale).astype(int)
        # a zero-sized target is rejected by PIL; keep at least one pixel
        size = np.maximum(size, 1)
        method = {
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC}[self.interpolation]
        # hand PIL plain Python ints instead of numpy integer scalars
        return img.resize(tuple(int(s) for s in size), method)
def __getitem__(self, index):
    """Return one augmented (exemplar, instance) image pair.

    ``index`` is folded onto the sequence list (``__len__`` reports
    ``pairs_per_seq`` entries per sequence) and mapped through the random
    permutation ``self.indices``. Two frames of that sequence are sampled,
    cropped around their annotated boxes, augmented with
    ``transform_z`` / ``transform_x`` and scaled by 255.

    Returns:
        tuple: ``(exemplar_image, instance_image)`` tensors.
    """
    index = self.indices[index % len(self.seq_dataset)]
    img_files, anno = self.seq_dataset[index]

    # remove too small objects
    valid = anno[:, 2:].prod(axis=1) >= 10
    if not valid.any():
        # Every frame was filtered out: draw another sequence instead of
        # tripping the ``n > 0`` assertion in ``_sample_pair`` and killing
        # the data-loader worker.
        return self.__getitem__(np.random.choice(len(self)))
    img_files = np.array(img_files)[valid]
    anno = anno[valid, :]

    rand_z, rand_x = self._sample_pair(len(img_files))

    def load_rgb(path):
        # ``with`` closes the file handle promptly; ``convert`` returns an
        # independent, fully loaded copy and guarantees three channels
        # even for grayscale/CMYK inputs.
        with Image.open(path) as img:
            return img.convert('RGB')

    exemplar_image = load_rgb(img_files[rand_z])
    instance_image = load_rgb(img_files[rand_x])
    exemplar_image = self._crop_and_resize(exemplar_image, anno[rand_z])
    instance_image = self._crop_and_resize(instance_image, anno[rand_x])
    # ToTensor yields values in [0, 1]; rescale to the [0, 255] range
    exemplar_image = 255.0 * self.transform_z(exemplar_image)
    instance_image = 255.0 * self.transform_x(instance_image)

    return exemplar_image, instance_image
class SiamFC(nn.Module):
    """Siamese network: shared conv backbone followed by cross-correlation.

    ``feature`` is a five-convolution stack (batch norm + ReLU after the
    first four, max-pooling after the first two). ``forward`` embeds an
    exemplar batch and an instance batch with the same weights and
    correlates each instance embedding with its own exemplar embedding.
    """

    def __init__(self):
        super(SiamFC, self).__init__()

        def batch_norm(channels):
            return nn.BatchNorm2d(channels, eps=1e-6, momentum=0.05)

        # NOTE: layer order fixes the ``feature.<idx>`` state-dict keys.
        layers = [
            # conv1
            nn.Conv2d(3, 96, 11, 2),
            batch_norm(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv2
            nn.Conv2d(96, 256, 5, 1, groups=2),
            batch_norm(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv3
            nn.Conv2d(256, 384, 3, 1),
            batch_norm(384),
            nn.ReLU(inplace=True),
            # conv4
            nn.Conv2d(384, 384, 3, 1, groups=2),
            batch_norm(384),
            nn.ReLU(inplace=True),
            # conv5
            nn.Conv2d(384, 256, 3, 1, groups=2),
        ]
        self.feature = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, z, x):
        """Return the scaled response map of exemplars ``z`` over ``x``.

        Both inputs go through ``self.feature``; the result has one
        channel per pair, i.e. shape ``(n, 1, H, W)`` for ``n`` pairs.
        """
        kernels = self.feature(z)
        search = self.feature(x)

        # Fold the batch into the channel axis so that a single grouped
        # convolution correlates pair i's search map with pair i's kernel.
        batch, channels, height, width = search.size()
        folded = search.view(1, batch * channels, height, width)
        response = F.conv2d(folded, kernels, groups=batch)
        response = response.view(
            batch, 1, response.size(-2), response.size(-1))

        # adjust the scale of responses
        return 0.001 * response + 0.0

    def _initialize_weights(self):
        """Kaiming-init conv weights, zero biases, unit batch-norm scale."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                init.kaiming_normal_(
                    module.weight.data, mode='fan_out',
                    nonlinearity='relu')
                module.bias.data.fill_(0)
def init(self, image, box):
    """Initialise tracker state from the first frame and its target box.

    ``box`` is indexed as ``box[0..3]`` and treated as a 1-indexed,
    top-left based ``[x, y, w, h]`` rectangle. Stores the target
    centre/size, the cosine window, the scale pyramid factors, the
    exemplar/search crop sizes, the frame's mean colour and the exemplar
    embedding (``self.kernel``).
    """
    frame = np.asarray(image)

    # 0-indexed, centre based [y, x, h, w]; centre and size are views of
    # one float32 array, exactly as later in-place updates expect
    state = np.array([
        box[1] - 1 + (box[3] - 1) / 2,
        box[0] - 1 + (box[2] - 1) / 2,
        box[3], box[2]], dtype=np.float32)
    self.center = state[:2]
    self.target_sz = state[2:]

    # normalised 2-D hanning window at the upsampled response size
    self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
    hann_1d = np.hanning(self.upscale_sz)
    window = np.outer(hann_1d, hann_1d)
    self.hann_window = window / window.sum()

    # multiplicative factors of the scale search pyramid
    half_span = self.cfg.scale_num // 2
    exponents = np.linspace(-half_span, half_span, self.cfg.scale_num)
    self.scale_factors = self.cfg.scale_step ** exponents

    # exemplar and search region sizes (target plus context margin)
    context = self.cfg.context * np.sum(self.target_sz)
    self.z_sz = np.sqrt(np.prod(self.target_sz + context))
    self.x_sz = self.z_sz * \
        self.cfg.instance_sz / self.cfg.exemplar_sz

    # exemplar patch, padded with the frame's mean colour
    self.avg_color = np.mean(frame, axis=(0, 1))
    patch = self._crop_and_resize(
        frame, self.center, self.z_sz,
        out_size=self.cfg.exemplar_sz,
        pad_color=self.avg_color)

    # HWC array -> 1xCxHxW float tensor on the tracker's device
    patch = torch.from_numpy(patch).to(self.device)
    patch = patch.permute([2, 0, 1]).unsqueeze(0).float()

    # exemplar features, computed once in eval mode without gradients
    self.net.eval()
    with torch.no_grad():
        self.kernel = self.net.feature(patch)
def step(self, batch, backward=True, update_lr=False):
    """Run one training or evaluation step on a batch of image pairs.

    Args:
        batch: indexable whose first item is the exemplar tensor and
            whose second item is the instance tensor.
        backward: when True the network is put in train mode and one
            optimizer update is applied; otherwise the network is put in
            eval mode and only the loss is computed.
        update_lr: when True (and ``backward`` is True) the learning-rate
            scheduler is advanced before the forward pass. Ignored when
            ``backward`` is False.

    Returns:
        float: the weighted binary cross-entropy loss of this batch.
    """
    if backward:
        self.net.train()
        if update_lr:
            # NOTE(review): PyTorch >= 1.1 expects the scheduler to be
            # stepped after ``optimizer.step()``; the ordering is kept
            # as-is here so the learning-rate schedule is unchanged.
            self.lr_scheduler.step()
    else:
        self.net.eval()

    z = batch[0].to(self.device)
    x = batch[1].to(self.device)

    with torch.set_grad_enabled(backward):
        responses = self.net(z, x)
        labels, weights = self._create_labels(responses.size())
        # ``size_average=True`` is deprecated; ``reduction='mean'`` is
        # the equivalent, supported spelling.
        loss = F.binary_cross_entropy_with_logits(
            responses, labels, weight=weights, reduction='mean')

    if backward:
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    return loss.item()
| if hasattr(self, 'labels') and self.labels.size() == size: 282 | return self.labels, self.weights 283 | 284 | def logistic_labels(x, y, r_pos, r_neg): 285 | dist = np.abs(x) + np.abs(y) # block distance 286 | labels = np.where(dist <= r_pos, 287 | np.ones_like(x), 288 | np.where(dist < r_neg, 289 | np.ones_like(x) * 0.5, 290 | np.zeros_like(x))) 291 | return labels 292 | 293 | # distances along x- and y-axis 294 | n, c, h, w = size 295 | x = np.arange(w) - w // 2 296 | y = np.arange(h) - h // 2 297 | x, y = np.meshgrid(x, y) 298 | 299 | # create logistic labels 300 | r_pos = self.cfg.r_pos / self.cfg.total_stride 301 | r_neg = self.cfg.r_neg / self.cfg.total_stride 302 | labels = logistic_labels(x, y, r_pos, r_neg) 303 | 304 | # pos/neg weights 305 | pos_num = np.sum(labels == 1) 306 | neg_num = np.sum(labels == 0) 307 | weights = np.zeros_like(labels) 308 | weights[labels == 1] = 0.5 / pos_num 309 | weights[labels == 0] = 0.5 / neg_num 310 | weights *= pos_num + neg_num 311 | 312 | # repeat to size 313 | labels = labels.reshape((1, 1, h, w)) 314 | weights = weights.reshape((1, 1, h, w)) 315 | labels = np.tile(labels, (n, c, 1, 1)) 316 | weights = np.tile(weights, [n, c, 1, 1]) 317 | 318 | # convert to tensors 319 | self.labels = torch.from_numpy(labels).to(self.device).float() 320 | self.weights = torch.from_numpy(weights).to(self.device).float() 321 | 322 | return self.labels, self.weights 323 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from got10k.experiments import * 4 | 5 | from siamfc import TrackerSiamFC 6 | 7 | 8 | if __name__ == '__main__': 9 | # setup tracker 10 | net_path = 'pretrained/siamfc/model.pth' 11 | tracker = TrackerSiamFC(net_path=net_path) 12 | 13 | # setup experiments 14 | experiments = [ 15 | ExperimentGOT10k('data/GOT-10k', subset='test'), 16 | 
if __name__ == '__main__':
    # pick the sequence dataset the training pairs are sampled from
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k']
    if name == 'VID':
        root_dir = 'data/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset=('train', 'val'))
    else:
        root_dir = 'data/GOT-10k'
        seq_dataset = GOT10k(root_dir, subset='train')
    pair_dataset = Pairwise(seq_dataset)

    # batches of image pairs; pinned memory only helps when a GPU exists
    cuda = torch.cuda.is_available()
    loader_kwargs = dict(
        batch_size=8, shuffle=True, pin_memory=cuda,
        drop_last=True, num_workers=4)
    loader = DataLoader(pair_dataset, **loader_kwargs)

    # tracker with a freshly initialised network (no pretrained weights)
    tracker = TrackerSiamFC()

    # directory the per-epoch checkpoints are written to
    net_dir = 'pretrained/siamfc_new'
    if not os.path.exists(net_dir):
        os.makedirs(net_dir)

    # training loop
    epoch_num = 50
    log_format = 'Epoch [{}][{}/{}]: Loss: {:.3f}'
    for epoch in range(epoch_num):
        for step, batch in enumerate(loader):
            # the learning rate is advanced on the first batch of an epoch
            is_first_batch = (step == 0)
            loss = tracker.step(
                batch, backward=True, update_lr=is_first_batch)
            if step % 20 == 0:
                print(log_format.format(
                    epoch + 1, step + 1, len(loader), loss))
                sys.stdout.flush()

        # save one checkpoint per finished epoch
        net_path = os.path.join(net_dir, 'model_e%d.pth' % (epoch + 1))
        torch.save(tracker.net.state_dict(), net_path)