├── .gitignore ├── bin ├── train_siamfc.py ├── run_SiamFC.py ├── create_lmdb.py ├── demo_siamfc.py ├── create_dataset.py └── convert_pretrained_model.py ├── siamfc ├── __init__.py ├── config.py ├── utils.py ├── dataset.py ├── alexnet.py ├── train.py ├── tracker.py └── custom_transforms.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | */.idea/* 2 | *.so 3 | .DS_Store 4 | *.swp 5 | */build/* 6 | */__pycache__/* 7 | __pycache__/ 8 | build/ 9 | *.pyc 10 | test/* 11 | *.pth 12 | models/* 13 | models/logs/* 14 | *.mat 15 | -------------------------------------------------------------------------------- /bin/train_siamfc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | from fire import Fire 5 | 6 | from siamfc import train 7 | 8 | if __name__ == '__main__': 9 | Fire(train) 10 | -------------------------------------------------------------------------------- /siamfc/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import SiamFCTracker 2 | from .train import train 3 | from .config import config 4 | from .utils import get_instance_image 5 | from .dataset import ImagnetVIDDataset 6 | from .alexnet import SiameseAlexNet 7 | 8 | 9 | -------------------------------------------------------------------------------- /bin/run_SiamFC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from siamfc import SiamFCTracker, config 4 | import cv2 5 | import glob 6 | import os 7 | 8 | 9 | def run_SiamFC(seq, rp, saveimage): 10 | x = seq.init_rect[0] 11 | y = seq.init_rect[1] 12 | w = seq.init_rect[2] 13 | h = seq.init_rect[3] 14 | 15 | tic = time.clock() 16 | # starting tracking 17 | tracker = SiamFCTracker(config.model_path, config.gpu_id) 18 | res = [] 19 | for idx, frame in enumerate(seq.s_frames): 20 | frame = cv2.cvtColor(cv2.imread(frame), cv2.COLOR_BGR2RGB) 21 | if idx == 0: 22 | bbox = (x, y, w, h) 23 | tracker.init(frame, bbox) 24 | bbox = (bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]) # 1-idx 25 | else: 26 | bbox = tracker.update(frame) 27 | res.append((bbox[0], bbox[1], bbox[2]-bbox[0], bbox[3]-bbox[1])) # 1-idx 28 | duration = time.clock() - tic 29 | result = {} 30 | result['res'] = res 31 | result['type'] = 'rect' 32 | result['fps'] = round(seq.len / duration, 3) 33 | return result 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 StrangerZhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bin/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import cv2 3 | import numpy as np 4 | import os 5 | import hashlib 6 | import functools 7 | 8 | from glob import glob 9 | from fire import Fire 10 | from tqdm import tqdm 11 | from multiprocessing import Pool 12 | 13 | def worker(video_name): 14 | image_names = glob(video_name+'/*') 15 | kv = {} 16 | for image_name in image_names: 17 | img = cv2.imread(image_name) 18 | _, img_encode = cv2.imencode('.jpg', img) 19 | img_encode = img_encode.tobytes() 20 | kv[hashlib.md5(image_name.encode()).digest()] = img_encode 21 | return kv 22 | 23 | def create_lmdb(data_dir, output_dir, num_threads): 24 | video_names = glob(data_dir+'/*') 25 | video_names = [x for x in video_names if os.path.isdir(x)] 26 | db = lmdb.open(output_dir, map_size=int(50e9)) 27 | with Pool(processes=num_threads) as pool: 28 | for ret in tqdm(pool.imap_unordered( 29 | functools.partial(worker), video_names), total=len(video_names)): 30 | with db.begin(write=True) as txn: 31 | for k, v in ret.items(): 32 | txn.put(k, v) 33 | 34 | if __name__ == '__main__': 35 | Fire(create_lmdb) 36 | 37 | -------------------------------------------------------------------------------- /bin/demo_siamfc.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pandas as pd 4 | import argparse 5 | import numpy as np 6 | import cv2 7 | import time 8 | import sys 9 | sys.path.append(os.getcwd()) 10 | 11 | from fire import Fire 12 | from tqdm import tqdm 13 | 14 | from siamfc import SiamFCTracker 15 | 16 | def main(video_dir, gpu_id, model_path): 17 | # load videos 18 | filenames = sorted(glob.glob(os.path.join(video_dir, "img/*.jpg")), 19 | key=lambda x: int(os.path.basename(x).split('.')[0])) 20 | frames = [cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB) for filename in filenames] 21 | gt_bboxes = pd.read_csv(os.path.join(video_dir, "groundtruth_rect.txt"), sep='\t|,| ', 22 | header=None, names=['xmin', 'ymin', 'width', 'height'], 23 | engine='python') 24 | 25 | title = video_dir.split('/')[-1] 26 | # starting tracking 27 | tracker = SiamFCTracker(model_path, gpu_id) 28 | for idx, frame in enumerate(frames): 29 | if idx == 0: 30 | bbox = gt_bboxes.iloc[0].values 31 | tracker.init(frame, bbox) 32 | bbox = (bbox[0]-1, bbox[1]-1, 33 | bbox[0]+bbox[2]-1, bbox[1]+bbox[3]-1) 34 | else: 35 | bbox = tracker.update(frame) 36 | # bbox xmin ymin xmax ymax 37 | frame = cv2.rectangle(frame, 38 | (int(bbox[0]), int(bbox[1])), 39 | (int(bbox[2]), int(bbox[3])), 40 | (0, 255, 0), 41 | 2) 42 | gt_bbox = gt_bboxes.iloc[idx].values 43 | gt_bbox = (gt_bbox[0], gt_bbox[1], 44 | gt_bbox[0]+gt_bbox[2], gt_bbox[1]+gt_bbox[3]) 45 | frame = cv2.rectangle(frame, 46 | (int(gt_bbox[0]-1), int(gt_bbox[1]-1)), # 0-index 47 | (int(gt_bbox[2]-1), int(gt_bbox[3]-1)), 48 | (255, 0, 0), 49 | 1) 50 | if 
len(frame.shape) == 3: 51 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 52 | frame = cv2.putText(frame, str(idx), (5, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1) 53 | cv2.imshow(title, frame) 54 | cv2.waitKey(30) 55 | 56 | if __name__ == "__main__": 57 | Fire(main) 58 | -------------------------------------------------------------------------------- /siamfc/config.py: -------------------------------------------------------------------------------- 1 | 2 | class Config: 3 | # dataset related 4 | exemplar_size = 127 # exemplar size 5 | instance_size = 255 # instance size 6 | context_amount = 0.5 # context amount 7 | 8 | # training related 9 | num_per_epoch = 53200 # num of samples per epoch 10 | train_ratio = 0.9 # training ratio of VID dataset 11 | frame_range = 100 # frame range of choosing the instance 12 | train_batch_size = 8 # training batch size 13 | valid_batch_size = 8 # validation batch size 14 | train_num_workers = 8 # number of workers of train dataloader 15 | valid_num_workers = 8 # number of workers of validation dataloader 16 | lr = 1e-2 # learning rate of SGD 17 | momentum = 0.0 # momentum of SGD 18 | weight_decay = 0.0 # weight decay of optimizator 19 | step_size = 25 # step size of LR_Schedular 20 | gamma = 0.1 # decay rate of LR_Schedular 21 | epoch = 30 # total epoch 22 | seed = 1234 # seed to sample training videos 23 | log_dir = './models/logs' # log dirs 24 | radius = 16 # radius of positive label 25 | response_scale = 1e-3 # normalize of response 26 | max_translate = 3 # max translation of random shift 27 | 28 | # tracking related 29 | scale_step = 1.0375 # scale step of instance image 30 | num_scale = 3 # number of scales 31 | scale_lr = 0.59 # scale learning rate 32 | response_up_stride = 16 # response upsample stride 33 | response_sz = 17 # response size 34 | train_response_sz = 15 # train response size 35 | window_influence = 0.176 # window influence 36 | scale_penalty = 0.9745 # scale penalty 37 | total_stride = 8 # total stride of backbone 38 | sample_type = 'uniform' 39 | gray_ratio = 0.25 40 | blur_ratio = 0.15 41 | 42 | config = Config() 43 | -------------------------------------------------------------------------------- /siamfc/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def get_center(x): 5 | return (x - 1.) / 2. 
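# e.g. get_center(127) == 63.0: the zero-based center index of a patch that is 127 pixels wide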
6 | 7 | def xyxy2cxcywh(bbox): 8 | return get_center(bbox[0]+bbox[2]), \ 9 | get_center(bbox[1]+bbox[3]), \ 10 | (bbox[2]-bbox[0]), \ 11 | (bbox[3]-bbox[1]) 12 | 13 | def crop_and_pad(img, cx, cy, model_sz, original_sz, img_mean=None): 14 | xmin = cx - original_sz // 2 15 | xmax = cx + original_sz // 2 16 | ymin = cy - original_sz // 2 17 | ymax = cy + original_sz // 2 18 | im_h, im_w, _ = img.shape 19 | 20 | left = right = top = bottom = 0 21 | if xmin < 0: 22 | left = int(abs(xmin)) 23 | if xmax > im_w: 24 | right = int(xmax - im_w) 25 | if ymin < 0: 26 | top = int(abs(ymin)) 27 | if ymax > im_h: 28 | bottom = int(ymax - im_h) 29 | 30 | xmin = int(max(0, xmin)) 31 | xmax = int(min(im_w, xmax)) 32 | ymin = int(max(0, ymin)) 33 | ymax = int(min(im_h, ymax)) 34 | im_patch = img[ymin:ymax, xmin:xmax] 35 | if left != 0 or right !=0 or top!=0 or bottom!=0: 36 | if img_mean is None: 37 | img_mean = tuple(map(int, img.mean(axis=(0, 1)))) 38 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right, 39 | cv2.BORDER_CONSTANT, value=img_mean) 40 | if model_sz != original_sz: 41 | im_patch = cv2.resize(im_patch, (model_sz, model_sz)) 42 | return im_patch 43 | 44 | def get_exemplar_image(img, bbox, size_z, context_amount, img_mean=None): 45 | cx, cy, w, h = xyxy2cxcywh(bbox) 46 | wc_z = w + context_amount * (w+h) 47 | hc_z = h + context_amount * (w+h) 48 | s_z = np.sqrt(wc_z * hc_z) 49 | scale_z = size_z / s_z 50 | exemplar_img = crop_and_pad(img, cx, cy, size_z, s_z, img_mean) 51 | return exemplar_img, scale_z, s_z 52 | 53 | def get_instance_image(img, bbox, size_z, size_x, context_amount, img_mean=None): 54 | cx, cy, w, h = xyxy2cxcywh(bbox) 55 | wc_z = w + context_amount * (w+h) 56 | hc_z = h + context_amount * (w+h) 57 | s_z = np.sqrt(wc_z * hc_z) 58 | scale_z = size_z / s_z 59 | d_search = (size_x - size_z) / 2 60 | pad = d_search / scale_z 61 | s_x = s_z + 2 * pad 62 | scale_x = size_x / s_x 63 | instance_img = crop_and_pad(img, cx, cy, size_x, s_x, img_mean) 64 | return instance_img, scale_x, s_x 65 | 66 | def get_pyramid_instance_image(img, center, size_x, size_x_scales, img_mean=None): 67 | if img_mean is None: 68 | img_mean = tuple(map(int, img.mean(axis=(0, 1)))) 69 | pyramid = [crop_and_pad(img, center[0], center[1], size_x, size_x_scale, img_mean) 70 | for size_x_scale in size_x_scales] 71 | return pyramid 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of SiamFC 2 | 3 | ## Run demo 4 | ```bash 5 | cd SiamFC-Pytorch 6 | 7 | mkdir models 8 | 9 | # for color model 10 | wget http://www.robots.ox.ac.uk/%7Eluca/stuff/siam-fc_nets/2016-08-17.net.mat -P models/ 11 | # for color+gray model 12 | wget http://www.robots.ox.ac.uk/%7Eluca/stuff/siam-fc_nets/2016-08-17_gray025.net.mat -P models/ 13 | 14 | python bin/convert_pretrained_model.py 15 | 16 | # the video dir should contain a groundtruth_rect.txt in the same format as OTB 17 | python bin/demo_siamfc.py --gpu-id [gpu_id] --video-dir path/to/video 18 | ``` 19 | 20 | ## Training 21 | Download ILSVRC2015-VID 22 | 23 | ```bash 24 | cd SiamFC-Pytorch 25 | 26 | mkdir models 27 | 28 | # using 12 threads, this should take about an hour 29 | python bin/create_dataset.py --data-dir path/to/data/ILSVRC2015 \ 30 | --output-dir path/to/data/ILSVRC2015_VID_CURATION \ 31 | --num-threads 8 32 | 33 | # ILSVRC2015_VID_CURATION and ILSVRC2015_VID_CURATION.lmdb should be in the same directory 34 | # the
ILSVRC2015_VID_CURATION.lmdb should be about 34 GB 35 | python bin/create_lmdb.py --data-dir path/to/data/ILSVRC2015_VID_CURATION \ 36 | --output-dir path/to/data/ILSVRC2015_VID_CURATION.lmdb \ 37 | --num-threads 8 38 | 39 | # training should take about 1.5~2 hrs on a Titan Xp GPU for 30 epochs 40 | python bin/train_siamfc.py --gpu-id [gpu_id] --data-dir path/to/data/ILSVRC2015_VID_CURATION 41 | ``` 42 | ## Benchmark results 43 | #### OTB100 44 | 45 | | Tracker | AUC | 46 | | --------------------------------------------- | --------------- | 47 | | SiamFC-color (converted from MatConvNet) | 0.5544 | 48 | | SiamFC-color+gray (converted from MatConvNet) | 0.5818 (vs 0.582) | 49 | | SiamFC (trained from scratch) | 0.5820 (vs 0.582) | 50 | 51 | ## Note 52 | We use SGD without momentum and with weight decay set to 0; the detailed settings can be found in config.py. 53 | Training is unstable. In order to reproduce the result, you should evaluate all epochs between 54 | 10 and 30 on OTB100 and choose the best one. 55 | Below is one of my experiment results. 56 | ```bash 57 | Epoch 11 AUC: 0.5522 58 | Epoch 12 AUC: 0.5670 59 | Epoch 13 AUC: 0.5604 60 | Epoch 14 AUC: 0.5559 61 | Epoch 15 AUC: 0.5790 62 | Epoch 16 AUC: 0.5687 63 | Epoch 17 AUC: 0.5534 64 | Epoch 18 AUC: 0.5745 65 | Epoch 19 AUC: 0.5619 66 | Epoch 20 AUC: 0.5749 67 | Epoch 21 AUC: 0.5648 68 | Epoch 22 AUC: 0.5775 69 | Epoch 23 AUC: 0.5784 70 | Epoch 24 AUC: 0.5812 71 | Epoch 25 AUC: 0.5785 72 | Epoch 26 AUC: 0.5637 73 | Epoch 27 AUC: 0.5764 74 | Epoch 28 AUC: 0.5675 75 | Epoch 29 AUC: 0.5787 76 | Epoch 30 AUC: 0.5820 77 | ``` 78 | ## Reference 79 | [1] Bertinetto, Luca and Valmadre, Jack and Henriques, João F. and Vedaldi, Andrea and Torr, Philip H. S. 80 | Fully-Convolutional Siamese Networks for Object Tracking 81 | In ECCV 2016 Workshops 82 | -------------------------------------------------------------------------------- /bin/create_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import cv2 5 | import functools 6 | import xml.etree.ElementTree as ET 7 | import sys 8 | sys.path.append(os.getcwd()) 9 | 10 | from multiprocessing import Pool 11 | from fire import Fire 12 | from tqdm import tqdm 13 | from glob import glob 14 | 15 | from siamfc import config, get_instance_image 16 | 17 | def worker(output_dir, video_dir): 18 | image_names = glob(os.path.join(video_dir, '*.JPEG')) 19 | image_names = sorted(image_names, 20 | key=lambda x:int(x.split('/')[-1].split('.')[0])) 21 | video_name = video_dir.split('/')[-1] 22 | save_folder = os.path.join(output_dir, video_name) 23 | if not os.path.exists(save_folder): 24 | os.mkdir(save_folder) 25 | trajs = {} 26 | for image_name in image_names: 27 | img = cv2.imread(image_name) 28 | img_mean = tuple(map(int, img.mean(axis=(0, 1)))) 29 | anno_name = image_name.replace('Data', 'Annotations') 30 | anno_name = anno_name.replace('JPEG', 'xml') 31 | tree = ET.parse(anno_name) 32 | root = tree.getroot() 33 | bboxes = [] 34 | filename = root.find('filename').text 35 | for obj in root.iter('object'): 36 | bbox = obj.find('bndbox') 37 | bbox = list(map(int, [bbox.find('xmin').text, 38 | bbox.find('ymin').text, 39 | bbox.find('xmax').text, 40 | bbox.find('ymax').text])) 41 | trkid = int(obj.find('trackid').text) 42 | if trkid in trajs: 43 | trajs[trkid].append(filename) 44 | else: 45 | trajs[trkid] = [filename] 46 | instance_img, _, _ = get_instance_image(img, bbox, 47 | config.exemplar_size, config.instance_size,
config.context_amount, img_mean) 48 | instance_img_name = os.path.join(save_folder, filename+".{:02d}.x.jpg".format(trkid)) 49 | cv2.imwrite(instance_img_name, instance_img) 50 | return video_name, trajs 51 | 52 | def processing(data_dir, output_dir, num_threads=32): 53 | # get all 4417 videos 54 | video_dir = os.path.join(data_dir, 'Data/VID') 55 | all_videos = glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0000/*')) + \ 56 | glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0001/*')) + \ 57 | glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0002/*')) + \ 58 | glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0003/*')) + \ 59 | glob(os.path.join(video_dir, 'val/*')) 60 | meta_data = [] 61 | if not os.path.exists(output_dir): 62 | os.makedirs(output_dir) 63 | with Pool(processes=num_threads) as pool: 64 | for ret in tqdm(pool.imap_unordered( 65 | functools.partial(worker, output_dir), all_videos), total=len(all_videos)): 66 | meta_data.append(ret) 67 | 68 | # save meta data 69 | pickle.dump(meta_data, open(os.path.join(output_dir, "meta_data.pkl"), 'wb')) 70 | 71 | 72 | if __name__ == '__main__': 73 | Fire(processing) 74 | 75 | -------------------------------------------------------------------------------- /bin/convert_pretrained_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | import numpy as np 4 | import argparse 5 | 6 | from scipy import io as sio 7 | from tqdm import tqdm 8 | 9 | # code adapted from https://github.com/bilylee/SiamFC-TensorFlow/blob/master/utils/train_utils.py 10 | def convert(mat_path): 11 | """Get parameter from .mat file into parms(dict)""" 12 | 13 | def squeeze(vars_): 14 | # Matlab save some params with shape (*, 1) 15 | # However, we don't need the trailing dimension in TensorFlow. 
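# The same squeeze is applied here before the values are wrapped in PyTorch tensors.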
16 | if isinstance(vars_, (list, tuple)): 17 | return [np.squeeze(v, 1) for v in vars_] 18 | else: 19 | return np.squeeze(vars_, 1) 20 | 21 | netparams = sio.loadmat(mat_path)["net"]["params"][0][0] 22 | params = dict() 23 | 24 | name_map = {(1, 'conv'): 0, (1, 'bn'): 1, 25 | (2, 'conv'): 4, (2, 'bn'): 5, 26 | (3, 'conv'): 8, (3, 'bn'): 9, 27 | (4, 'conv'): 11, (4, 'bn'): 12, 28 | (5, 'conv'): 14} 29 | for i in tqdm(range(netparams.size)): 30 | param = netparams[0][i] 31 | name = param["name"][0] 32 | value = param["value"] 33 | value_size = param["value"].shape[0] 34 | 35 | match = re.match(r"([a-z]+)([0-9]+)([a-z]+)", name, re.I) 36 | if match: 37 | items = match.groups() 38 | elif name == 'adjust_f': 39 | continue 40 | elif name == 'adjust_b': 41 | params['corr_bias'] = torch.from_numpy(squeeze(value)) 42 | continue 43 | 44 | 45 | op, layer, types = items 46 | layer = int(layer) 47 | if layer in [1, 2, 3, 4, 5]: 48 | idx = name_map[(layer, op)] 49 | if op == 'conv': # convolution 50 | if types == 'f': 51 | params['features.{}.weight'.format(idx)] = torch.from_numpy(value.transpose(3, 2, 0, 1)) 52 | elif types == 'b':# and layer == 5: 53 | value = squeeze(value) 54 | params['features.{}.bias'.format(idx)] = torch.from_numpy(value) 55 | elif op == 'bn': # batch normalization 56 | if types == 'x': 57 | m, v = squeeze(np.split(value, 2, 1)) 58 | params['features.{}.running_mean'.format(idx)] = torch.from_numpy(m) 59 | params['features.{}.running_var'.format(idx)] = torch.from_numpy(np.square(v)) 60 | # params['features.{}.num_batches_tracked'.format(idx)] = torch.zeros(0) 61 | elif types == 'm': 62 | value = squeeze(value) 63 | params['features.{}.weight'.format(idx)] = torch.from_numpy(value) 64 | elif types == 'b': 65 | value = squeeze(value) 66 | params['features.{}.bias'.format(idx)] = torch.from_numpy(value) 67 | else: 68 | raise Exception 69 | return params 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--mat_path', type=str, default="./models/2016-08-17.net.mat") 74 | args = parser.parse_args() 75 | params = convert(args.mat_path) 76 | torch.save(params, "./models/siamfc_pretrained.pth") 77 | -------------------------------------------------------------------------------- /siamfc/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import os 4 | import numpy as np 5 | import pickle 6 | import lmdb 7 | import hashlib 8 | from torch.utils.data.dataset import Dataset 9 | 10 | from .config import config 11 | 12 | class ImagnetVIDDataset(Dataset): 13 | def __init__(self, db, video_names, data_dir, z_transforms, x_transforms, training=True): 14 | self.video_names = video_names 15 | self.data_dir = data_dir 16 | self.z_transforms = z_transforms 17 | self.x_transforms = x_transforms 18 | meta_data_path = os.path.join(data_dir, 'meta_data.pkl') 19 | self.meta_data = pickle.load(open(meta_data_path, 'rb')) 20 | self.meta_data = {x[0]:x[1] for x in self.meta_data} 21 | # filter traj len less than 2 22 | for key in self.meta_data.keys(): 23 | trajs = self.meta_data[key] 24 | for trkid in list(trajs.keys()): 25 | if len(trajs[trkid]) < 2: 26 | del trajs[trkid] 27 | 28 | self.txn = db.begin(write=False) 29 | self.num = len(self.video_names) if config.num_per_epoch is None or not training\ 30 | else config.num_per_epoch 31 | 32 | def imread(self, path): 33 | key = hashlib.md5(path.encode()).digest() 34 | img_buffer = self.txn.get(key) 35 | img_buffer = 
np.frombuffer(img_buffer, np.uint8) 36 | img = cv2.imdecode(img_buffer, cv2.IMREAD_COLOR) 37 | return img 38 | 39 | def _sample_weights(self, center, low_idx, high_idx, s_type='uniform'): 40 | weights = list(range(low_idx, high_idx)) 41 | weights.remove(center) 42 | weights = np.array(weights) 43 | if s_type == 'linear': 44 | weights = abs(weights - center) 45 | elif s_type == 'sqrt': 46 | weights = np.sqrt(abs(weights - center)) 47 | elif s_type == 'uniform': 48 | weights = np.ones_like(weights) 49 | return weights / sum(weights) 50 | 51 | def __getitem__(self, idx): 52 | idx = idx % len(self.video_names) 53 | video = self.video_names[idx] 54 | trajs = self.meta_data[video] 55 | # sample one trajs 56 | trkid = np.random.choice(list(trajs.keys())) 57 | traj = trajs[trkid] 58 | assert len(traj) > 1, "video_name: {}".format(video) 59 | # sample exemplar 60 | exemplar_idx = np.random.choice(list(range(len(traj)))) 61 | exemplar_name = os.path.join(self.data_dir, video, traj[exemplar_idx]+".{:02d}.x.jpg".format(trkid)) 62 | exemplar_img = self.imread(exemplar_name) 63 | exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_BGR2RGB) 64 | # sample instance 65 | low_idx = max(0, exemplar_idx - config.frame_range) 66 | up_idx = min(len(traj), exemplar_idx + config.frame_range) 67 | 68 | # create sample weight, if the sample are far away from center 69 | # the probability being choosen are high 70 | weights = self._sample_weights(exemplar_idx, low_idx, up_idx, config.sample_type) 71 | instance = np.random.choice(traj[low_idx:exemplar_idx] + traj[exemplar_idx+1:up_idx], p=weights) 72 | instance_name = os.path.join(self.data_dir, video, instance+".{:02d}.x.jpg".format(trkid)) 73 | instance_img = self.imread(instance_name) 74 | instance_img = cv2.cvtColor(instance_img, cv2.COLOR_BGR2RGB) 75 | if np.random.rand(1) < config.gray_ratio: 76 | exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_RGB2GRAY) 77 | exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_GRAY2RGB) 78 | instance_img = cv2.cvtColor(instance_img, cv2.COLOR_RGB2GRAY) 79 | instance_img = cv2.cvtColor(instance_img, cv2.COLOR_GRAY2RGB) 80 | exemplar_img = self.z_transforms(exemplar_img) 81 | instance_img = self.x_transforms(instance_img) 82 | return exemplar_img, instance_img 83 | 84 | def __len__(self): 85 | return self.num 86 | -------------------------------------------------------------------------------- /siamfc/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torchvision.transforms as transforms 5 | from .custom_transforms import ToTensor 6 | 7 | from torchvision.models import alexnet 8 | from torch.autograd import Variable 9 | from torch import nn 10 | 11 | from .config import config 12 | 13 | class SiameseAlexNet(nn.Module): 14 | def __init__(self, gpu_id, train=True): 15 | super(SiameseAlexNet, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Conv2d(3, 96, 11, 2), 18 | nn.BatchNorm2d(96), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(3, 2), 21 | nn.Conv2d(96, 256, 5, 1, groups=2), 22 | nn.BatchNorm2d(256), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(3, 2), 25 | nn.Conv2d(256, 384, 3, 1), 26 | nn.BatchNorm2d(384), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(384, 384, 3, 1, groups=2), 29 | nn.BatchNorm2d(384), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(384, 256, 3, 1, groups=2) 32 | ) 33 | self.corr_bias = nn.Parameter(torch.zeros(1)) 34 | if train: 35 | gt, weight = 
self._create_gt_mask((config.train_response_sz, config.train_response_sz)) 36 | with torch.cuda.device(gpu_id): 37 | self.train_gt = torch.from_numpy(gt).cuda() 38 | self.train_weight = torch.from_numpy(weight).cuda() 39 | gt, weight = self._create_gt_mask((config.response_sz, config.response_sz)) 40 | with torch.cuda.device(gpu_id): 41 | self.valid_gt = torch.from_numpy(gt).cuda() 42 | self.valid_weight = torch.from_numpy(weight).cuda() 43 | self.exemplar = None 44 | self.gpu_id = gpu_id 45 | 46 | def init_weights(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 50 | elif isinstance(m, nn.BatchNorm2d): 51 | m.weight.data.fill_(1) 52 | m.bias.data.zero_() 53 | 54 | def forward(self, x): 55 | exemplar, instance = x 56 | if exemplar is not None and instance is not None: 57 | batch_size = exemplar.shape[0] 58 | exemplar = self.features(exemplar) 59 | instance = self.features(instance) 60 | score_map = [] 61 | N, C, H, W = instance.shape 62 | instance = instance.view(1, -1, H, W) 63 | score = F.conv2d(instance, exemplar, groups=N) * config.response_scale \ 64 | + self.corr_bias 65 | return score.transpose(0, 1) 66 | elif exemplar is not None and instance is None: 67 | # inference used 68 | self.exemplar = self.features(exemplar) 69 | self.exemplar = torch.cat([self.exemplar for _ in range(3)], dim=0) 70 | else: 71 | # inference used we don't need to scale the reponse or add bias 72 | instance = self.features(instance) 73 | N, _, H, W = instance.shape 74 | instance = instance.view(1, -1, H, W) 75 | score = F.conv2d(instance, self.exemplar, groups=N) 76 | return score.transpose(0, 1) 77 | 78 | def loss(self, pred): 79 | return F.binary_cross_entropy_with_logits(pred, self.gt) 80 | 81 | def weighted_loss(self, pred): 82 | if self.training: 83 | return F.binary_cross_entropy_with_logits(pred, self.train_gt, 84 | self.train_weight, reduction='sum') / config.train_batch_size # normalize the batch_size 85 | else: 86 | return F.binary_cross_entropy_with_logits(pred, self.valid_gt, 87 | self.valid_weight, reduction='sum') / config.train_batch_size # normalize the batch_size 88 | 89 | def _create_gt_mask(self, shape): 90 | # same for all pairs 91 | h, w = shape 92 | y = np.arange(h, dtype=np.float32) - (h-1) / 2. 93 | x = np.arange(w, dtype=np.float32) - (w-1) / 2. 
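# x and y hold per-cell offsets from the center of the score map. The positive
# radius is given in input-image pixels, so dividing by total_stride converts it
# to score-map cells (16 / 8 = 2 with the default config); cells within that
# distance of the center are labeled 1, and the weights computed below give the
# positive and negative cells equal total contribution to the loss.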
94 | y, x = np.meshgrid(y, x) 95 | dist = np.sqrt(x**2 + y**2) 96 | mask = np.zeros((h, w)) 97 | mask[dist <= config.radius / config.total_stride] = 1 98 | mask = mask[np.newaxis, :, :] 99 | weights = np.ones_like(mask) 100 | weights[mask == 1] = 0.5 / np.sum(mask == 1) 101 | weights[mask == 0] = 0.5 / np.sum(mask == 0) 102 | mask = np.repeat(mask, config.train_batch_size, axis=0)[:, np.newaxis, :, :] 103 | return mask.astype(np.float32), weights.astype(np.float32) 104 | -------------------------------------------------------------------------------- /siamfc/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | import torchvision.transforms as transforms 5 | import torchvision 6 | import numpy as np 7 | import pandas as pd 8 | import os 9 | import cv2 10 | import pickle 11 | import lmdb 12 | 13 | from fire import Fire 14 | from torch.autograd import Variable 15 | from torch.optim.lr_scheduler import StepLR 16 | from torch.utils.data import DataLoader 17 | from glob import glob 18 | from tqdm import tqdm 19 | from sklearn.model_selection import train_test_split 20 | from tensorboardX import SummaryWriter 21 | 22 | from .config import config 23 | from .alexnet import SiameseAlexNet 24 | from .dataset import ImagnetVIDDataset 25 | from .custom_transforms import Normalize, ToTensor, RandomStretch, \ 26 | RandomCrop, CenterCrop, RandomBlur, ColorAug 27 | 28 | torch.manual_seed(1234) 29 | 30 | def train(gpu_id, data_dir): 31 | # loading meta data 32 | meta_data_path = os.path.join(data_dir, "meta_data.pkl") 33 | meta_data = pickle.load(open(meta_data_path,'rb')) 34 | all_videos = [x[0] for x in meta_data] 35 | 36 | # split train/valid dataset 37 | train_videos, valid_videos = train_test_split(all_videos, 38 | test_size=1-config.train_ratio, random_state=config.seed) 39 | 40 | # define transforms 41 | random_crop_size = config.instance_size - 2 * config.total_stride 42 | train_z_transforms = transforms.Compose([ 43 | RandomStretch(), 44 | CenterCrop((config.exemplar_size, config.exemplar_size)), 45 | ToTensor() 46 | ]) 47 | train_x_transforms = transforms.Compose([ 48 | RandomStretch(), 49 | RandomCrop((random_crop_size, random_crop_size), 50 | config.max_translate), 51 | ToTensor() 52 | ]) 53 | valid_z_transforms = transforms.Compose([ 54 | CenterCrop((config.exemplar_size, config.exemplar_size)), 55 | ToTensor() 56 | ]) 57 | valid_x_transforms = transforms.Compose([ 58 | ToTensor() 59 | ]) 60 | 61 | # open lmdb 62 | db = lmdb.open(data_dir+'.lmdb', readonly=True, map_size=int(50e9)) 63 | 64 | # create dataset 65 | train_dataset = ImagnetVIDDataset(db, train_videos, data_dir, 66 | train_z_transforms, train_x_transforms) 67 | valid_dataset = ImagnetVIDDataset(db, valid_videos, data_dir, 68 | valid_z_transforms, valid_x_transforms, training=False) 69 | 70 | # create dataloader 71 | trainloader = DataLoader(train_dataset, batch_size=config.train_batch_size, 72 | shuffle=True, pin_memory=True, num_workers=config.train_num_workers, drop_last=True) 73 | validloader = DataLoader(valid_dataset, batch_size=config.valid_batch_size, 74 | shuffle=False, pin_memory=True, num_workers=config.valid_num_workers, drop_last=True) 75 | 76 | # create summary writer 77 | if not os.path.exists(config.log_dir): 78 | os.mkdir(config.log_dir) 79 | summary_writer = SummaryWriter(config.log_dir) 80 | 81 | # start training 82 | with torch.cuda.device(gpu_id): 83 | model = SiameseAlexNet(gpu_id, 
train=True) 84 | model.init_weights() 85 | model = model.cuda() 86 | optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, 87 | momentum=config.momentum, weight_decay=config.weight_decay) 88 | scheduler = StepLR(optimizer, step_size=config.step_size, 89 | gamma=config.gamma) 90 | 91 | for epoch in range(config.epoch): 92 | train_loss = [] 93 | model.train() 94 | for i, data in enumerate(tqdm(trainloader)): 95 | exemplar_imgs, instance_imgs = data 96 | exemplar_var, instance_var = Variable(exemplar_imgs.cuda()), \ 97 | Variable(instance_imgs.cuda()) 98 | optimizer.zero_grad() 99 | outputs = model((exemplar_var, instance_var)) 100 | loss = model.weighted_loss(outputs) 101 | loss.backward() 102 | optimizer.step() 103 | step = epoch * len(trainloader) + i 104 | summary_writer.add_scalar('train/loss', loss.data, step) 105 | train_loss.append(loss.data) 106 | train_loss = np.mean(train_loss) 107 | 108 | valid_loss = [] 109 | model.eval() 110 | for i, data in enumerate(tqdm(validloader)): 111 | exemplar_imgs, instance_imgs = data 112 | exemplar_var, instance_var = Variable(exemplar_imgs.cuda()),\ 113 | Variable(instance_imgs.cuda()) 114 | outputs = model((exemplar_var, instance_var)) 115 | loss = model.weighted_loss(outputs) 116 | valid_loss.append(loss.data) 117 | valid_loss = np.mean(valid_loss) 118 | print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" % 119 | (epoch, valid_loss, train_loss)) 120 | summary_writer.add_scalar('valid/loss', 121 | valid_loss, (epoch+1)*len(trainloader)) 122 | torch.save(model.cpu().state_dict(), 123 | "./models/siamfc_{}.pth".format(epoch+1)) 124 | model.cuda() 125 | scheduler.step() 126 | -------------------------------------------------------------------------------- /siamfc/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.nn.functional as F 5 | import time 6 | import warnings 7 | import torchvision.transforms as transforms 8 | 9 | from torch.autograd import Variable 10 | 11 | from .alexnet import SiameseAlexNet 12 | from .config import config 13 | from .custom_transforms import ToTensor 14 | from .utils import get_exemplar_image, get_pyramid_instance_image, get_instance_image 15 | 16 | torch.set_num_threads(1) # otherwise pytorch will take all cpus 17 | 18 | class SiamFCTracker: 19 | def __init__(self, model_path, gpu_id): 20 | self.gpu_id = gpu_id 21 | with torch.cuda.device(gpu_id): 22 | self.model = SiameseAlexNet(gpu_id, train=False) 23 | self.model.load_state_dict(torch.load(model_path)) 24 | self.model = self.model.cuda() 25 | self.model.eval() 26 | self.transforms = transforms.Compose([ 27 | ToTensor() 28 | ]) 29 | 30 | def _cosine_window(self, size): 31 | """ 32 | get the cosine window 33 | """ 34 | cos_window = np.hanning(int(size[0]))[:, np.newaxis].dot(np.hanning(int(size[1]))[np.newaxis, :]) 35 | cos_window = cos_window.astype(np.float32) 36 | cos_window /= np.sum(cos_window) 37 | return cos_window 38 | 39 | def init(self, frame, bbox): 40 | """ initialize siamfc tracker 41 | Args: 42 | frame: an RGB image 43 | bbox: one-based bounding box [x, y, width, height] 44 | """ 45 | self.bbox = (bbox[0]-1, bbox[1]-1, bbox[0]-1+bbox[2], bbox[1]-1+bbox[3]) # zero based 46 | self.pos = np.array([bbox[0]-1+(bbox[2]-1)/2, bbox[1]-1+(bbox[3]-1)/2]) # center x, center y, zero based 47 | self.target_sz = np.array([bbox[2], bbox[3]]) # width, height 48 | # get exemplar img 49 | self.img_mean = tuple(map(int, frame.mean(axis=(0, 1)))) 50 | 
exemplar_img, scale_z, s_z = get_exemplar_image(frame, self.bbox, 51 | config.exemplar_size, config.context_amount, self.img_mean) 52 | 53 | # get exemplar feature 54 | exemplar_img = self.transforms(exemplar_img)[None,:,:,:] 55 | with torch.cuda.device(self.gpu_id): 56 | exemplar_img_var = Variable(exemplar_img.cuda()) 57 | self.model((exemplar_img_var, None)) 58 | 59 | self.penalty = np.ones((config.num_scale)) * config.scale_penalty 60 | self.penalty[config.num_scale//2] = 1 61 | 62 | # create cosine window 63 | self.interp_response_sz = config.response_up_stride * config.response_sz 64 | self.cosine_window = self._cosine_window((self.interp_response_sz, self.interp_response_sz)) 65 | 66 | # create scalse 67 | self.scales = config.scale_step ** np.arange(np.ceil(config.num_scale/2)-config.num_scale, 68 | np.floor(config.num_scale/2)+1) 69 | 70 | # create s_x 71 | self.s_x = s_z + (config.instance_size-config.exemplar_size) / scale_z 72 | 73 | # arbitrary scale saturation 74 | self.min_s_x = 0.2 * self.s_x 75 | self.max_s_x = 5 * self.s_x 76 | 77 | def update(self, frame): 78 | """track object based on the previous frame 79 | Args: 80 | frame: an RGB image 81 | 82 | Returns: 83 | bbox: tuple of 1-based bounding box(xmin, ymin, xmax, ymax) 84 | """ 85 | size_x_scales = self.s_x * self.scales 86 | pyramid = get_pyramid_instance_image(frame, self.pos, config.instance_size, size_x_scales, self.img_mean) 87 | instance_imgs = torch.cat([self.transforms(x)[None,:,:,:] for x in pyramid], dim=0) 88 | with torch.cuda.device(self.gpu_id): 89 | instance_imgs_var = Variable(instance_imgs.cuda()) 90 | response_maps = self.model((None, instance_imgs_var)) 91 | response_maps = response_maps.data.cpu().numpy().squeeze() 92 | response_maps_up = [cv2.resize(x, (self.interp_response_sz, self.interp_response_sz), cv2.INTER_CUBIC) 93 | for x in response_maps] 94 | # get max score 95 | max_score = np.array([x.max() for x in response_maps_up]) * self.penalty 96 | 97 | # penalty scale change 98 | scale_idx = max_score.argmax() 99 | response_map = response_maps_up[scale_idx] 100 | response_map -= response_map.min() 101 | response_map /= response_map.sum() 102 | response_map = (1 - config.window_influence) * response_map + \ 103 | config.window_influence * self.cosine_window 104 | max_r, max_c = np.unravel_index(response_map.argmax(), response_map.shape) 105 | # displacement in interpolation response 106 | disp_response_interp = np.array([max_c, max_r]) - (self.interp_response_sz-1) / 2. 
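# The peak offset above is measured in pixels of the upsampled response map; each
# such pixel corresponds to total_stride / response_up_stride (8 / 16 = 0.5 with the
# default config) pixels of the 255x255 instance image, which the following lines
# use to map the displacement back into frame coordinates.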
107 | # displacement in input 108 | disp_response_input = disp_response_interp * config.total_stride / config.response_up_stride 109 | # displacement in frame 110 | scale = self.scales[scale_idx] 111 | disp_response_frame = disp_response_input * (self.s_x * scale) / config.instance_size 112 | # position in frame coordinates 113 | self.pos += disp_response_frame 114 | # scale damping and saturation 115 | self.s_x *= ((1 - config.scale_lr) + config.scale_lr * scale) 116 | self.s_x = max(self.min_s_x, min(self.max_s_x, self.s_x)) 117 | self.target_sz = ((1 - config.scale_lr) + config.scale_lr * scale) * self.target_sz 118 | bbox = (self.pos[0] - self.target_sz[0]/2 + 1, # xmin convert to 1-based 119 | self.pos[1] - self.target_sz[1]/2 + 1, # ymin 120 | self.pos[0] + self.target_sz[0]/2 + 1, # xmax 121 | self.pos[1] + self.target_sz[1]/2 + 1) # ymax 122 | return bbox 123 | -------------------------------------------------------------------------------- /siamfc/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | class RandomStretch(object): 6 | def __init__(self, max_stretch=0.05): 7 | """Random resize image according to the stretch 8 | Args: 9 | max_stretch(float): 0 to 1 value 10 | """ 11 | self.max_stretch = max_stretch 12 | 13 | def __call__(self, sample): 14 | """ 15 | Args: 16 | sample(numpy array): 3 or 1 dim image 17 | """ 18 | scale_h = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch) 19 | scale_w = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch) 20 | h, w = sample.shape[:2] 21 | shape = (int(h * scale_h), int(w * scale_w)) 22 | return cv2.resize(sample, shape, cv2.INTER_LINEAR) 23 | 24 | class CenterCrop(object): 25 | def __init__(self, size): 26 | """Crop the image in the center according the given size 27 | if size greater than image size, zero padding will adpot 28 | Args: 29 | size (tuple): desired size 30 | """ 31 | self.size = size 32 | 33 | def __call__(self, sample): 34 | """ 35 | Args: 36 | sample(numpy array): 3 or 1 dim image 37 | """ 38 | shape = sample.shape[:2] 39 | cy, cx = (shape[0]-1) // 2, (shape[1]-1) // 2 40 | ymin, xmin = cy - self.size[0]//2, cx - self.size[1] // 2 41 | ymax, xmax = cy + self.size[0]//2 + self.size[0] % 2,\ 42 | cx + self.size[1]//2 + self.size[1] % 2 43 | left = right = top = bottom = 0 44 | im_h, im_w = shape 45 | if xmin < 0: 46 | left = int(abs(xmin)) 47 | if xmax > im_w: 48 | right = int(xmax - im_w) 49 | if ymin < 0: 50 | top = int(abs(ymin)) 51 | if ymax > im_h: 52 | bottom = int(ymax - im_h) 53 | 54 | xmin = int(max(0, xmin)) 55 | xmax = int(min(im_w, xmax)) 56 | ymin = int(max(0, ymin)) 57 | ymax = int(min(im_h, ymax)) 58 | im_patch = sample[ymin:ymax, xmin:xmax] 59 | if left != 0 or right !=0 or top!=0 or bottom!=0: 60 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right, 61 | cv2.BORDER_CONSTANT, value=0) 62 | return im_patch 63 | 64 | class RandomCrop(object): 65 | def __init__(self, size, max_translate): 66 | """Crop the image in the center according the given size 67 | if size greater than image size, zero padding will adpot 68 | Args: 69 | size (tuple): desired size 70 | max_translate: max translate of random shift 71 | """ 72 | self.size = size 73 | self.max_translate = max_translate 74 | 75 | def __call__(self, sample): 76 | """ 77 | Args: 78 | sample(numpy array): 3 or 1 dim image 79 | """ 80 | shape = sample.shape[:2] 81 | cy_o = (shape[0] - 1) // 2 82 | cx_o = (shape[1] - 1) 
// 2 83 | cy = np.random.randint(cy_o - self.max_translate, 84 | cy_o + self.max_translate+1) 85 | cx = np.random.randint(cx_o - self.max_translate, 86 | cx_o + self.max_translate+1) 87 | assert abs(cy-cy_o) <= self.max_translate and \ 88 | abs(cx-cx_o) <= self.max_translate 89 | ymin = cy - self.size[0] // 2 90 | xmin = cx - self.size[1] // 2 91 | ymax = cy + self.size[0] // 2 + self.size[0] % 2 92 | xmax = cx + self.size[1] // 2 + self.size[1] % 2 93 | left = right = top = bottom = 0 94 | im_h, im_w = shape 95 | if xmin < 0: 96 | left = int(abs(xmin)) 97 | if xmax > im_w: 98 | right = int(xmax - im_w) 99 | if ymin < 0: 100 | top = int(abs(ymin)) 101 | if ymax > im_h: 102 | bottom = int(ymax - im_h) 103 | 104 | xmin = int(max(0, xmin)) 105 | xmax = int(min(im_w, xmax)) 106 | ymin = int(max(0, ymin)) 107 | ymax = int(min(im_h, ymax)) 108 | im_patch = sample[ymin:ymax, xmin:xmax] 109 | if left != 0 or right !=0 or top!=0 or bottom!=0: 110 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right, 111 | cv2.BORDER_CONSTANT, value=0) 112 | return im_patch 113 | 114 | class ColorAug(object): 115 | def __init__(self, type_in='z'): 116 | if type_in == 'z': 117 | rgb_var = np.array([[3.2586416e+03,2.8992207e+03,2.6392236e+03], 118 | [2.8992207e+03,3.0958174e+03,2.9321748e+03], 119 | [2.6392236e+03,2.9321748e+03,3.4533721e+03]]) 120 | if type_in == 'x': 121 | rgb_var = np.array([[2.4847285e+03,2.1796064e+03,1.9766885e+03], 122 | [2.1796064e+03,2.3441289e+03,2.2357402e+03], 123 | [1.9766885e+03,2.2357402e+03,2.7369697e+03]]) 124 | self.v, _ = np.linalg.eig(rgb_var) 125 | self.v = np.sqrt(self.v) 126 | 127 | def __call__(self, sample): 128 | return sample + 0.1 * self.v * np.random.randn(3) 129 | 130 | 131 | class RandomBlur(object): 132 | def __init__(self, ratio): 133 | self.ratio = ratio 134 | 135 | def __call__(self, sample): 136 | if np.random.rand(1) < self.ratio: 137 | # random kernel size 138 | kernel_size = np.random.choice([3, 5, 7]) 139 | # random gaussian sigma 140 | sigma = np.random.rand() * 5 141 | return cv2.GaussianBlur(sample, (kernel_size,kernel_size), sigma) 142 | else: 143 | return sample 144 | 145 | class Normalize(object): 146 | def __init__(self): 147 | self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) 148 | self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) 149 | 150 | def __call__(self, sample): 151 | return (sample / 255. - self.mean) / self.std 152 | 153 | class ToTensor(object): 154 | def __call__(self, sample): 155 | sample = sample.transpose(2, 0, 1) 156 | return torch.from_numpy(sample.astype(np.float32)) 157 | --------------------------------------------------------------------------------