├── .gitignore
├── README.md
├── pairwise.py
├── siamfc.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
.*
*.pyc
data
data/
__pycache__/
results*/
reports*/
cache*/
pretrained*/
!.gitignore
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SiamFC

A minimal example showing how to train a tracker on GOT-10k or ImageNet VID and evaluate its performance on GOT-10k as well as six other tracking datasets ([OTB (2013/2015)](http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html), [VOT (2013~2018)](http://votchallenge.net), [DTB70](https://github.com/flyers/drone-tracking), [TColor128](http://www.dabi.temple.edu/~hbling/data/TColor-128/TColor-128.html), [NfS](http://ci2cv.net/nfs/index.html) and [UAV123](https://ivul.kaust.edu.sa/Pages/pub-benchmark-simulator-uav.aspx)), using the [GOT-10k toolkit](https://github.com/got-10k/toolkit).

## Performance

### GOT-10k

| Dataset | AO    | SR<sub>0.50</sub> | SR<sub>0.75</sub> |
|:------- |:-----:|:-----------------:|:-----------------:|
| GOT-10k | 0.355 | 0.390             | 0.118             |

The scores are comparable with the SiamFC results reported on the [GOT-10k leaderboard](http://got-10k.aitestunion.com/leaderboard).

### OTB / UAV123 / DTB70 / TColor128 / NfS

| Dataset       | Success Score | Precision Score |
|:------------- |:-------------:|:---------------:|
| OTB2013       | 0.589         | 0.781           |
| OTB2015       | 0.578         | 0.765           |
| UAV123        | 0.523         | 0.731           |
| UAV20L        | 0.423         | 0.572           |
| DTB70         | 0.493         | 0.731           |
| TColor128     | 0.510         | 0.691           |
| NfS (30 fps)  | -             | -               |
| NfS (240 fps) | 0.520         | 0.624           |

### VOT2018

| Dataset | Accuracy | Robustness (unnormalized) |
|:------- |:--------:|:-------------------------:|
| VOT2018 | 0.502    | 37.25                     |

## Dependencies

Install PyTorch, opencv-python and the GOT-10k toolkit:

```bash
pip install torch opencv-python got10k
```
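
A quick sanity check of the installation (a minimal sketch; version strings will differ across environments):

```python
import torch
import cv2
import got10k  # noqa: F401 -- import check only

print(torch.__version__, cv2.__version__)  # installed library versions
print(torch.cuda.is_available())           # True if a CUDA build and GPU are available
```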

## Running the tracker

In the root directory of `siamfc`:

1. Download pretrained `model.pth` from [Baidu Yun](https://pan.baidu.com/s/1TT7ebFho63Lw2D7CXLqwjQ) or [Google Drive](https://drive.google.com/open?id=1Qu5K8bQhRAiexKdnwzs39lOko3uWxEKm), and put the file under `pretrained/siamfc`.

2. Create a symbolic link `data` pointing to your datasets folder (e.g., `data/OTB`, `data/UAV123`, `data/GOT-10k`):

   ```
   ln -s /path/to/your/data/folder ./data
   ```

3. Run:

   ```
   python test.py
   ```

By default, the tracking experiments will be executed and evaluated over all 7 datasets. Comment out experiments in `test.py` if you wish to skip some of them; a trimmed example follows below.
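
A minimal sketch of a trimmed `test.py` that evaluates only on OTB2015 (paths follow the layout assumed above):

```python
from got10k.experiments import ExperimentOTB

from siamfc import TrackerSiamFC

# load the pretrained tracker
tracker = TrackerSiamFC(net_path='pretrained/siamfc/model.pth')

# run and report a single experiment
experiment = ExperimentOTB('data/OTB', version=2015)
experiment.run(tracker, visualize=False)
experiment.report([tracker.name])
```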
## Training the tracker

1. Assume the GOT-10k dataset is located at `data/GOT-10k`.

2. Run:

   ```
   python train.py
   ```
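
By default `train.py` trains on GOT-10k. To train on ImageNet VID instead, switch the dataset name in `train.py` (the VID branch expects the dataset under `data/ILSVRC`):

```python
# in train.py
name = 'VID'  # instead of 'GOT-10k'
```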
--------------------------------------------------------------------------------
/pairwise.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division

import numpy as np
from collections import namedtuple
from torch.utils.data import Dataset
from torchvision.transforms import Compose, CenterCrop, RandomCrop, ToTensor
from PIL import Image, ImageStat, ImageOps


class RandomStretch(object):

    def __init__(self, max_stretch=0.05, interpolation='bilinear'):
        assert interpolation in ['bilinear', 'bicubic']
        self.max_stretch = max_stretch
        self.interpolation = interpolation

    def __call__(self, img):
        scale = 1.0 + np.random.uniform(
            -self.max_stretch, self.max_stretch)
        size = np.round(np.array(img.size, float) * scale).astype(int)
        if self.interpolation == 'bilinear':
            method = Image.BILINEAR
        elif self.interpolation == 'bicubic':
            method = Image.BICUBIC
        return img.resize(tuple(size), method)


class Pairwise(Dataset):

    def __init__(self, seq_dataset, **kargs):
        super(Pairwise, self).__init__()
        self.cfg = self.parse_args(**kargs)

        self.seq_dataset = seq_dataset
        self.indices = np.random.permutation(len(seq_dataset))
        # augmentation for exemplar and instance images
        self.transform_z = Compose([
            RandomStretch(max_stretch=0.05),
            CenterCrop(self.cfg.instance_sz - 8),
            RandomCrop(self.cfg.instance_sz - 2 * 8),
            CenterCrop(self.cfg.exemplar_sz),
            ToTensor()])
        self.transform_x = Compose([
            RandomStretch(max_stretch=0.05),
            CenterCrop(self.cfg.instance_sz - 8),
            RandomCrop(self.cfg.instance_sz - 2 * 8),
            ToTensor()])

    def parse_args(self, **kargs):
        # default parameters
        cfg = {
            'pairs_per_seq': 10,
            'max_dist': 100,
            'exemplar_sz': 127,
            'instance_sz': 255,
            'context': 0.5}

        for key, val in kargs.items():
            if key in cfg:
                cfg.update({key: val})
        return namedtuple('GenericDict', cfg.keys())(**cfg)

    def __getitem__(self, index):
        index = self.indices[index % len(self.seq_dataset)]
        img_files, anno = self.seq_dataset[index]

        # remove too small objects
        valid = anno[:, 2:].prod(axis=1) >= 10
        img_files = np.array(img_files)[valid]
        anno = anno[valid, :]

        rand_z, rand_x = self._sample_pair(len(img_files))

        # convert to RGB to guard against grayscale frames
        exemplar_image = Image.open(img_files[rand_z]).convert('RGB')
        instance_image = Image.open(img_files[rand_x]).convert('RGB')
        exemplar_image = self._crop_and_resize(exemplar_image, anno[rand_z])
        instance_image = self._crop_and_resize(instance_image, anno[rand_x])
        exemplar_image = 255.0 * self.transform_z(exemplar_image)
        instance_image = 255.0 * self.transform_x(instance_image)

        return exemplar_image, instance_image

    def __len__(self):
        return self.cfg.pairs_per_seq * len(self.seq_dataset)

    def _sample_pair(self, n):
        assert n > 0
        if n == 1:
            return 0, 0
        elif n == 2:
            return 0, 1
        else:
            max_dist = min(n - 1, self.cfg.max_dist)
            rand_dist = np.random.choice(max_dist) + 1
            rand_z = np.random.choice(n - rand_dist)
            rand_x = rand_z + rand_dist

            return rand_z, rand_x

    def _crop_and_resize(self, image, box):
        # convert box to 0-indexed and center based
        box = np.array([
            box[0] - 1 + (box[2] - 1) / 2,
            box[1] - 1 + (box[3] - 1) / 2,
            box[2], box[3]], dtype=np.float32)
        center, target_sz = box[:2], box[2:]

        # exemplar and search sizes
        context = self.cfg.context * np.sum(target_sz)
        z_sz = np.sqrt(np.prod(target_sz + context))
        x_sz = z_sz * self.cfg.instance_sz / self.cfg.exemplar_sz

        # convert box to corners (0-indexed)
        size = round(x_sz)
        corners = np.concatenate((
            np.round(center - (size - 1) / 2),
            np.round(center - (size - 1) / 2) + size))
        corners = np.round(corners).astype(int)

        # pad image if necessary
        pads = np.concatenate((
            -corners[:2], corners[2:] - image.size))
        npad = max(0, int(pads.max()))
        if npad > 0:
            avg_color = ImageStat.Stat(image).mean
            # PIL doesn't support float RGB image
            avg_color = tuple(int(round(c)) for c in avg_color)
            image = ImageOps.expand(image, border=npad, fill=avg_color)

        # crop image patch
        corners = tuple((corners + npad).astype(int))
        patch = image.crop(corners)

        # resize to instance_sz
        out_size = (self.cfg.instance_sz, self.cfg.instance_sz)
        patch = patch.resize(out_size, Image.BILINEAR)

        return patch
--------------------------------------------------------------------------------
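
A short sketch of how `Pairwise` is meant to be used (assuming GOT-10k is available under `data/GOT-10k`): it wraps a sequence dataset and yields exemplar/search tensor pairs scaled to the 0-255 range.

```python
from got10k.datasets import GOT10k
from pairwise import Pairwise

seq_dataset = GOT10k('data/GOT-10k', subset='train')
pair_dataset = Pairwise(seq_dataset)

z, x = pair_dataset[0]
print(z.shape)  # torch.Size([3, 127, 127]) - exemplar patch
print(x.shape)  # torch.Size([3, 239, 239]) - search patch (255 - 2 * 8)
```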
/siamfc.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import cv2
from collections import namedtuple
from torch.optim.lr_scheduler import ExponentialLR

from got10k.trackers import Tracker


class SiamFC(nn.Module):

    def __init__(self):
        super(SiamFC, self).__init__()
        self.feature = nn.Sequential(
            # conv1
            nn.Conv2d(3, 96, 11, 2),
            nn.BatchNorm2d(96, eps=1e-6, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv2
            nn.Conv2d(96, 256, 5, 1, groups=2),
            nn.BatchNorm2d(256, eps=1e-6, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv3
            nn.Conv2d(256, 384, 3, 1),
            nn.BatchNorm2d(384, eps=1e-6, momentum=0.05),
            nn.ReLU(inplace=True),
            # conv4
            nn.Conv2d(384, 384, 3, 1, groups=2),
            nn.BatchNorm2d(384, eps=1e-6, momentum=0.05),
            nn.ReLU(inplace=True),
            # conv5
            nn.Conv2d(384, 256, 3, 1, groups=2))
        self._initialize_weights()

    def forward(self, z, x):
        z = self.feature(z)
        x = self.feature(x)

        # fast cross correlation
        n, c, h, w = x.size()
        x = x.view(1, n * c, h, w)
        out = F.conv2d(x, z, groups=n)
        out = out.view(n, 1, out.size(-2), out.size(-1))

        # adjust the scale of responses
        out = 0.001 * out + 0.0

        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight.data, mode='fan_out',
                                     nonlinearity='relu')
                m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class TrackerSiamFC(Tracker):

    def __init__(self, net_path=None, **kargs):
        super(TrackerSiamFC, self).__init__(
            name='SiamFC', is_deterministic=True)
        self.cfg = self.parse_args(**kargs)

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = SiamFC()
        if net_path is not None:
            self.net.load_state_dict(torch.load(
                net_path, map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup optimizer
        self.optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.cfg.initial_lr,
            weight_decay=self.cfg.weight_decay,
            momentum=self.cfg.momentum)

        # setup lr scheduler
        self.lr_scheduler = ExponentialLR(
            self.optimizer, gamma=self.cfg.lr_decay)

    def parse_args(self, **kargs):
        # default parameters
        cfg = {
            # inference parameters
            'exemplar_sz': 127,
            'instance_sz': 255,
            'context': 0.5,
            'scale_num': 3,
            'scale_step': 1.0375,
            'scale_lr': 0.59,
            'scale_penalty': 0.9745,
            'window_influence': 0.176,
            'response_sz': 17,
            'response_up': 16,
            'total_stride': 8,
            'adjust_scale': 0.001,
            # train parameters
            'initial_lr': 0.01,
            'lr_decay': 0.8685113737513527,
            'weight_decay': 5e-4,
            'momentum': 0.9,
            'r_pos': 16,
            'r_neg': 0}

        for key, val in kargs.items():
            if key in cfg:
                cfg.update({key: val})
        return namedtuple('GenericDict', cfg.keys())(**cfg)

    def init(self, image, box):
        image = np.asarray(image)

        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2,
            box[0] - 1 + (box[2] - 1) / 2,
            box[3], box[2]], dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(
            np.hanning(self.upscale_sz),
            np.hanning(self.upscale_sz))
        self.hann_window /= self.hann_window.sum()

        # search scale factors
        self.scale_factors = self.cfg.scale_step ** np.linspace(
            -(self.cfg.scale_num // 2),
            self.cfg.scale_num // 2, self.cfg.scale_num)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image
        self.avg_color = np.mean(image, axis=(0, 1))
        exemplar_image = self._crop_and_resize(
            image, self.center, self.z_sz,
            out_size=self.cfg.exemplar_sz,
            pad_color=self.avg_color)

        # exemplar features
        exemplar_image = torch.from_numpy(exemplar_image).to(
            self.device).permute([2, 0, 1]).unsqueeze(0).float()
        with torch.set_grad_enabled(False):
            self.net.eval()
            self.kernel = self.net.feature(exemplar_image)

    def update(self, image):
        image = np.asarray(image)

        # search images
        instance_images = [self._crop_and_resize(
            image, self.center, self.x_sz * f,
            out_size=self.cfg.instance_sz,
            pad_color=self.avg_color) for f in self.scale_factors]
        instance_images = np.stack(instance_images, axis=0)
        instance_images = torch.from_numpy(instance_images).to(
            self.device).permute([0, 3, 1, 2]).float()

        # responses
        with torch.set_grad_enabled(False):
            self.net.eval()
            instances = self.net.feature(instance_images)
            responses = F.conv2d(instances, self.kernel) * \
                self.cfg.adjust_scale
            responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = np.stack([cv2.resize(
            t, (self.upscale_sz, self.upscale_sz),
            interpolation=cv2.INTER_CUBIC) for t in responses], axis=0)
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center
        disp_in_response = np.array(loc) - self.upscale_sz // 2
        disp_in_instance = disp_in_response * \
            self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
            self.scale_factors[scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        # update target size
        scale = (1 - self.cfg.scale_lr) * 1.0 + \
            self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]])

        return box

    def step(self, batch, backward=True, update_lr=False):
        if backward:
            self.net.train()
            if update_lr:
                self.lr_scheduler.step()
        else:
            self.net.eval()

        z = batch[0].to(self.device)
        x = batch[1].to(self.device)

        with torch.set_grad_enabled(backward):
            responses = self.net(z, x)
            labels, weights = self._create_labels(responses.size())
            loss = F.binary_cross_entropy_with_logits(
                responses, labels, weight=weights, reduction='mean')

            if backward:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    def _crop_and_resize(self, image, center, size, out_size, pad_color):
        # convert box to corners (0-indexed)
        size = round(size)
        corners = np.concatenate((
            np.round(center - (size - 1) / 2),
            np.round(center - (size - 1) / 2) + size))
        corners = np.round(corners).astype(int)

        # pad image if necessary
        pads = np.concatenate((
            -corners[:2], corners[2:] - image.shape[:2]))
        npad = max(0, int(pads.max()))
        if npad > 0:
            image = cv2.copyMakeBorder(
                image, npad, npad, npad, npad,
                cv2.BORDER_CONSTANT, value=pad_color)

        # crop image patch
        corners = (corners + npad).astype(int)
        patch = image[corners[0]:corners[2], corners[1]:corners[3]]

        # resize to out_size
        patch = cv2.resize(patch, (out_size, out_size))

        return patch

    def _create_labels(self, size):
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels, self.weights

        def logistic_labels(x, y, r_pos, r_neg):
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(dist <= r_pos,
                              np.ones_like(x),
                              np.where(dist < r_neg,
                                       np.ones_like(x) * 0.5,
                                       np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - w // 2
        y = np.arange(h) - h // 2
        x, y = np.meshgrid(x, y)

        # create logistic labels
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # pos/neg weights
        pos_num = np.sum(labels == 1)
        neg_num = np.sum(labels == 0)
        weights = np.zeros_like(labels)
        weights[labels == 1] = 0.5 / pos_num
        weights[labels == 0] = 0.5 / neg_num
        weights *= pos_num + neg_num

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        weights = weights.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))
        weights = np.tile(weights, (n, c, 1, 1))

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()
        self.weights = torch.from_numpy(weights).to(self.device).float()

        return self.labels, self.weights
--------------------------------------------------------------------------------
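
`SiamFC.forward` relies on a batched ("fast") cross correlation: the batch of search features is folded into the channel dimension so that a single grouped convolution correlates each exemplar with its own search image. A standalone sketch verifying the trick against a naive per-pair loop (the tensor sizes match the 127/255 inputs, whose features are 6x6 and 22x22):

```python
import torch
import torch.nn.functional as F

n, c = 4, 256                  # batch size and feature channels
z = torch.randn(n, c, 6, 6)    # exemplar features (from 127x127 inputs)
x = torch.randn(n, c, 22, 22)  # search features (from 255x255 inputs)

# fast version: one grouped convolution over the folded batch
fast = F.conv2d(x.view(1, n * c, 22, 22), z, groups=n)
fast = fast.view(n, 1, fast.size(-2), fast.size(-1))

# naive version: correlate each pair separately
naive = torch.cat([F.conv2d(x[i:i + 1], z[i:i + 1]) for i in range(n)])

print(fast.shape)                               # torch.Size([4, 1, 17, 17])
print(torch.allclose(fast, naive, atol=1e-3))   # True
```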
/test.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from got10k.experiments import *

from siamfc import TrackerSiamFC


if __name__ == '__main__':
    # setup tracker
    net_path = 'pretrained/siamfc/model.pth'
    tracker = TrackerSiamFC(net_path=net_path)

    # setup experiments
    experiments = [
        ExperimentGOT10k('data/GOT-10k', subset='test'),
        ExperimentOTB('data/OTB', version=2013),
        ExperimentOTB('data/OTB', version=2015),
        ExperimentVOT('data/vot2018', version=2018),
        ExperimentDTB70('data/DTB70'),
        ExperimentTColor128('data/Temple-color-128'),
        ExperimentUAV123('data/UAV123', version='UAV123'),
        ExperimentUAV123('data/UAV123', version='UAV20L'),
        ExperimentNfS('data/nfs', fps=30),
        ExperimentNfS('data/nfs', fps=240)
    ]

    # run tracking experiments and report performance
    for e in experiments:
        e.run(tracker, visualize=True)
        e.report([tracker.name])
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, print_function

import os
import sys
import torch
from torch.utils.data import DataLoader

from got10k.datasets import ImageNetVID, GOT10k
from pairwise import Pairwise
from siamfc import TrackerSiamFC


if __name__ == '__main__':
    # setup dataset
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k']
    if name == 'GOT-10k':
        root_dir = 'data/GOT-10k'
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = 'data/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset=('train', 'val'))
    pair_dataset = Pairwise(seq_dataset)

    # setup data loader
    cuda = torch.cuda.is_available()
    loader = DataLoader(
        pair_dataset, batch_size=8, shuffle=True,
        pin_memory=cuda, drop_last=True, num_workers=4)

    # setup tracker
    tracker = TrackerSiamFC()

    # path for saving checkpoints
    net_dir = 'pretrained/siamfc_new'
    if not os.path.exists(net_dir):
        os.makedirs(net_dir)

    # training loop
    epoch_num = 50
    for epoch in range(epoch_num):
        for step, batch in enumerate(loader):
            loss = tracker.step(
                batch, backward=True, update_lr=(step == 0))
            if step % 20 == 0:
                print('Epoch [{}][{}/{}]: Loss: {:.3f}'.format(
                    epoch + 1, step + 1, len(loader), loss))
                sys.stdout.flush()

        # save checkpoint
        net_path = os.path.join(net_dir, 'model_e%d.pth' % (epoch + 1))
        torch.save(tracker.net.state_dict(), net_path)
--------------------------------------------------------------------------------
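
Checkpoints saved by `train.py` can be plugged straight into the evaluation flow from `test.py`; a minimal sketch (the final epoch-50 checkpoint and the GOT-10k validation subset are assumed here):

```python
from got10k.experiments import ExperimentGOT10k

from siamfc import TrackerSiamFC

# load a checkpoint produced by train.py
tracker = TrackerSiamFC(net_path='pretrained/siamfc_new/model_e50.pth')

# evaluate on the GOT-10k validation subset
experiment = ExperimentGOT10k('data/GOT-10k', subset='val')
experiment.run(tracker)
experiment.report([tracker.name])
```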