├── README.md
├── dataset.py
├── dataset_loaders
│   ├── ECSSD_MSRA10K.py
│   ├── __init__.py
│   ├── coco.py
│   ├── custom_transforms.py
│   ├── custom_transforms_MTB.py
│   ├── dataset_utils.py
│   ├── davis17_v2.py
│   ├── davis17_v2_all.py
│   ├── davis17_v2_org.py
│   ├── helpers.py
│   ├── s
│   ├── ytvos_v2.py
│   └── ytvos_v2_org.py
├── eccv-framework.png
├── local_config.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── s
│   ├── graph_memory.py
│   ├── helpers.py
│   ├── s
│   └── units
│       ├── ConvGRU2.py
│       ├── __init__.py
│       └── s
├── requirements.txt
├── run_graph_memory_test.sh
├── runfiles
│   ├── eval_DAVIS_graph_memory.py
│   └── s
└── utils
    ├── __init__.py
    ├── debugging.py
    ├── readsaveimage.py
    ├── s
    ├── stats.py
    └── tensor_utils.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## GraphMemVOS
2 | Code for the ECCV 2020 spotlight paper: Video Object Segmentation with Episodic Graph Memory Networks
3 | #
4 | ![](../master/eccv-framework.png)
5 | 
6 | ## Testing
7 | 1. Install Python 3.6.5, PyTorch 1.0.1 and the packages listed in requirements.txt. Download the DAVIS-2017 dataset.
8 | 
9 | 2. Download the pretrained model from [Google Drive](https://drive.google.com/file/d/1HO6wlhAYSuBDx4Cnb8efQyLs357ycDz2/view?usp=sharing) and put it into the workspace_STM_alpha folder.
10 | 
11 | 3. Edit 'run_graph_memory_test.sh' to set the DAVIS dataset path, the pretrained model path and the result path, update the paths in local_config.py, and then run the script.
12 | 
13 | The segmentation results can be downloaded from [Google Drive](https://drive.google.com/file/d/1CYDtlQNlq2ZEKI29LLOb8TZq4eSpiRPJ/view?usp=sharing).
14 | 
15 | ## Results
16 | 1. **DAVIS** (Val 2017):
17 | 
18 | In the inference stage, we run at the default DAVIS resolution (480p).
19 | 
20 | **Mean J&F** | **J score** | **F score** |
21 | ---------| :---------: | :---------:
22 | **82.8** | **80.0** | **85.2** |
23 | 
24 | 2. **YouTube-VOS** (Val 2018):
25 | 
26 | **J Seen** | **F Seen** | **J Unseen** | **F Unseen** | **Mean** |
27 | ---------| :---------: | :---------: | :---------: | :---------:
28 | **80.7** | **85.1** | **74.0** | **80.9** | **80.2** |
29 | 
30 | 3. **DAVIS-2016**:
31 | 
32 | **J score** | **F score** | **Mean T** |
33 | ---------| :---------: | :---------:
34 | **82.5** | **81.2** | **19.8** |
35 | 
36 | 4. **Youtube-Objects**:
37 | 
38 | **Airplane** | **Bird** | **Boat** | **Car** | **Cat** | **Cow** | **Dog** | **Horse** | **Motorbike** | **Train** | **Mean** |
39 | ---------| :---------: | :---------: | :---------: | :---------: | :---------: | :---------: | :---------: | :---------: | :---------: | :---------:
40 | **86.1** | **75.7** | **68.6** | **82.4** | **65.9** | **70.5** | **77.1** | **72.2** | **63.8** | **47.8** | **71.4** |
41 | 
42 | ## Citation
43 | 
44 | If you find the code and dataset useful in your research, please consider citing:
45 | ```
46 | @inproceedings{lu2020video,
47 |   title={Video Object Segmentation with Episodic Graph Memory Networks},
48 |   author={Lu, Xiankai and Wang, Wenguan and Danelljan, Martin and Zhou, Tianfei and Shen, Jianbing and Van Gool, Luc},
49 |   booktitle={ECCV},
50 |   year={2020}
51 | }
52 | ```
53 | ## Other related projects/papers:
54 | 
55 | 1. Zero-shot Video Object Segmentation via Attentive Graph Neural Networks, ICCV 2019 (https://github.com/carrierlxk/AGNN)
56 | 
57 | ## Acknowledgements
58 | 
59 | 1. Video object segmentation using space-time memory networks, ICCV 2019 (https://github.com/seoungwugoh/STM)
60 | 2. 
A Generative Appearance Model for End-to-End Video Object Segmentation, CVPR2019 (https://github.com/joakimjohnander/agame-vos) 61 | 3. https://github.com/lyxok1/STM-Training 62 | 63 | Any comments, please email: carrierlxk@gmail.com 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | import torchvision 8 | from torch.utils import data 9 | 10 | import glob 11 | 12 | class DAVIS_MO_Test(data.Dataset): 13 | # for multi object, do shuffling 14 | 15 | def __init__(self, root, imset='2017/train.txt', resolution='480p', single_object=False): 16 | self.root = root 17 | self.mask_dir = os.path.join(root, 'Annotations', resolution) 18 | self.mask480_dir = os.path.join(root, 'Annotations', '480p') 19 | self.image_dir = os.path.join(root, 'JPEGImages', resolution) 20 | _imset_dir = os.path.join(root, 'ImageSets') 21 | _imset_f = os.path.join(_imset_dir, imset) 22 | 23 | self.videos = [] 24 | self.num_frames = {} 25 | self.num_objects = {} 26 | self.shape = {} 27 | self.size_480p = {} 28 | with open(os.path.join(_imset_f), "r") as lines: 29 | for line in lines: 30 | _video = line.rstrip('\n') 31 | self.videos.append(_video) 32 | self.num_frames[_video] = len(glob.glob(os.path.join(self.image_dir, _video, '*.jpg'))) 33 | _mask = np.array(Image.open(os.path.join(self.mask_dir, _video, '00000.png')).convert("P")) 34 | self.num_objects[_video] = np.max(_mask) 35 | self.shape[_video] = np.shape(_mask) 36 | _mask480 = np.array(Image.open(os.path.join(self.mask480_dir, _video, '00000.png')).convert("P")) 37 | self.size_480p[_video] = np.shape(_mask480) 38 | 39 | self.K = 11 40 | self.single_object = single_object 41 | 42 | def __len__(self): 43 | return len(self.videos) 44 | 45 | 46 | def To_onehot(self, mask): 47 | M = np.zeros((self.K, mask.shape[0], mask.shape[1]), dtype=np.uint8) 48 | for k in range(self.K): 49 | M[k] = (mask == k).astype(np.uint8) 50 | return M 51 | 52 | def All_to_onehot(self, masks): 53 | Ms = np.zeros((self.K, masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 54 | for n in range(masks.shape[0]): 55 | Ms[:,n] = self.To_onehot(masks[n]) 56 | return Ms 57 | 58 | def __getitem__(self, index): 59 | video = self.videos[index] 60 | info = {} 61 | info['name'] = video 62 | info['num_frames'] = self.num_frames[video] 63 | info['size_480p'] = self.size_480p[video] 64 | 65 | N_frames = np.empty((self.num_frames[video],)+self.shape[video]+(3,), dtype=np.float32) 66 | N_masks = np.empty((self.num_frames[video],)+self.shape[video], dtype=np.uint8) 67 | for f in range(self.num_frames[video]): 68 | img_file = os.path.join(self.image_dir, video, '{:05d}.jpg'.format(f)) 69 | N_frames[f] = np.array(Image.open(img_file).convert('RGB'))/255. 
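            # Annotation PNGs may be missing for some frames (e.g. on test splits); the
            # try/except below fills such frames' masks with 255, a value outside the
            # K = 11 object ids, so they yield all-zero one-hot channels downstream.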
70 | try: 71 | mask_file = os.path.join(self.mask_dir, video, '{:05d}.png'.format(f)) 72 | N_masks[f] = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8) 73 | except: 74 | # print('a') 75 | N_masks[f] = 255 76 | 77 | Fs = torch.from_numpy(np.transpose(N_frames.copy(), (3, 0, 1, 2)).copy()).float() 78 | if self.single_object: 79 | N_masks = (N_masks > 0.5).astype(np.uint8) * (N_masks < 255).astype(np.uint8) 80 | Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float() 81 | num_objects = torch.LongTensor([int(1)]) 82 | return Fs, Ms, num_objects, info 83 | else: 84 | Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float() 85 | num_objects = torch.LongTensor([int(self.num_objects[video])]) 86 | return Fs, Ms, num_objects, info 87 | 88 | 89 | 90 | if __name__ == '__main__': 91 | pass 92 | -------------------------------------------------------------------------------- /dataset_loaders/ECSSD_MSRA10K.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | import os, math, random 5 | from os.path import join 6 | import numpy as np 7 | 8 | import cv2 9 | from .custom_transforms_MTB import aug_batch 10 | from PIL import Image 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | class ECSSD(data.Dataset): 15 | def __init__(self, root='', replicates=1, aug=False): 16 | self.replicates = replicates 17 | self.aug = aug 18 | 19 | image_root = join(root, 'images') 20 | gt_root = join(root, 'saliencymaps') 21 | 22 | self.image_list = [] 23 | self.gt_list = [] 24 | files = sorted(os.listdir(image_root)) 25 | for i in range(len(files)): 26 | img = join(image_root, files[i]) 27 | gt = join(gt_root, files[i][:-4] + '.png') 28 | self.image_list += [img] 29 | self.gt_list += [gt] 30 | 31 | self.size = len(self.image_list) 32 | self.frame_size = cv2.imread(self.image_list[0], cv2.IMREAD_COLOR).shape 33 | 34 | assert (len(self.image_list) == len(self.gt_list)) 35 | 36 | def __getitem__(self, index): 37 | 38 | index = index % self.size 39 | img = cv2.imread(self.image_list[index], cv2.IMREAD_COLOR) 40 | 41 | gt = np.array(cv2.imread(self.gt_list[index], cv2.IMREAD_GRAYSCALE))#np.expand_dims(, axis=2) Image.open(self.gt_list[index]) 42 | gt[gt == 255] = 1 43 | 44 | image_size = img.shape[:2] 45 | 46 | images = [] 47 | segannos = [] 48 | for i in range(3): 49 | img_copy, gt_copy = np.copy(img), np.copy(gt) 50 | img_auged, gt_auged = aug_batch(img_copy, gt_copy) 51 | img_auged = img_auged.transpose(2, 0, 1) 52 | gt_auged = gt_auged.transpose(2, 0, 1) 53 | img_auged = torch.from_numpy(img_auged.astype(np.float32)) 54 | gt_auged = torch.from_numpy(gt_auged.astype(np.float32)) 55 | images.append(img_auged) 56 | segannos.append(gt_auged) 57 | 58 | images = torch.stack(images, dim=0).float().clamp(0, 1) 59 | segannos = torch.stack(segannos, dim=0).float() 60 | 61 | num_objects = int(segannos.max()) 62 | # save sample for checking, todo: need to delete 63 | if False: 64 | path_sample = '/home/cgv841/gwb/Code/agame-vos-master/ECSSD_sample/{}'.format(index) 65 | if not os.path.exists(path_sample): 66 | os.makedirs(path_sample) 67 | palette = Image.open( 68 | '/home/cgv841/gwb/DataSets/davis-2017/data/DAVIS/Annotations/480p/blackswan/00000.png').getpalette() 69 | for i in range(images.shape[0]): 70 | img, gt = 255 * images[i], segannos[i] 71 | img, gt = img.numpy().transpose((1, 2, 0)).astype(np.uint8), gt.numpy().transpose((1, 2, 0)).astype( 72 | np.uint8).squeeze() 73 | img, gt = Image.fromarray(img), 
Image.fromarray(gt) 74 | gt.putpalette(palette) 75 | img.save(os.path.join(path_sample, '{:05d}.jpg'.format(i))) 76 | gt.save(os.path.join(path_sample, '{:05d}.png'.format(i))) 77 | 78 | return {'images':images, 'segannos':segannos, 'seqname':'unknow', 'num_objects':num_objects} 79 | 80 | def __len__(self): 81 | return self.size * self.replicates 82 | 83 | 84 | class MSRA10K(data.Dataset): 85 | def __init__(self, root='', replicates=1, aug=False): 86 | self.replicates = replicates 87 | self.aug = aug 88 | 89 | image_root = join(root, 'images') 90 | gt_root = join(root, 'saliencymaps') 91 | 92 | self.image_list = [] 93 | self.gt_list = [] 94 | files = sorted(os.listdir(image_root)) 95 | for i in range(len(files)): 96 | img = join(image_root, files[i]) 97 | gt = join(gt_root, files[i][:-4] + '.png') 98 | self.image_list += [img] 99 | self.gt_list += [gt] 100 | 101 | self.size = len(self.image_list) 102 | self.frame_size = cv2.imread(self.image_list[0], cv2.IMREAD_COLOR).shape 103 | 104 | assert (len(self.image_list) == len(self.gt_list)) 105 | 106 | def __getitem__(self, index): 107 | 108 | index = index % self.size 109 | img = cv2.imread(self.image_list[index], cv2.IMREAD_COLOR) 110 | 111 | gt = np.array(cv2.imread(self.gt_list[index], cv2.IMREAD_GRAYSCALE))#, cv2.IMREAD_GRAYSCALE np.expand_dims(, axis=2) Image.open(self.gt_list[index]) 112 | #print('gt size:',gt.shape) 113 | #np.array(Image.open(self.gt_list[index]))#np.expand_dims(, axis=2) 114 | gt[gt != 255] = 0 115 | gt[gt == 255] = 1 116 | 117 | image_size = img.shape[:2] 118 | 119 | images = [] 120 | segannos = [] 121 | for i in range(3): 122 | img_copy, gt_copy = np.copy(img), np.copy(gt) 123 | img_auged, gt_auged = aug_batch(img_copy, gt_copy) 124 | img_auged = img_auged.transpose(2, 0, 1) 125 | gt_auged = gt_auged.transpose(2, 0, 1) 126 | img_auged = torch.from_numpy(img_auged.astype(np.float32)) 127 | gt_auged = torch.from_numpy(gt_auged.astype(np.float32)) 128 | images.append(img_auged) 129 | segannos.append(gt_auged) 130 | 131 | images = torch.stack(images, dim=0).float().clamp(0, 1) 132 | segannos = torch.stack(segannos, dim=0).float() 133 | 134 | num_objects = int(segannos.max()) 135 | # save sample for checking, todo: need to delete 136 | if False: 137 | path_sample = '/home/ubuntu/xiankai/STM_train_v1/MSRA10K_sample/{}'.format(index) 138 | if not os.path.exists(path_sample): 139 | os.makedirs(path_sample) 140 | palette = Image.open( 141 | '/raid/DAVIS/DAVIS-2017/DAVIS-train-val/Annotations/480p/blackswan/00000.png').getpalette() 142 | for i in range(images.shape[0]): 143 | img, gt = 255 * images[i], segannos[i] 144 | img, gt = img.numpy().transpose((1, 2, 0)).astype(np.uint8), gt.numpy().transpose((1, 2, 0)).astype( 145 | np.uint8).squeeze() 146 | img, gt = Image.fromarray(img), Image.fromarray(gt) 147 | gt.putpalette(palette) 148 | img.save(os.path.join(path_sample, '{:05d}.jpg'.format(i))) 149 | gt.save(os.path.join(path_sample, '{:05d}.png'.format(i))) 150 | 151 | return {'images': images, 'segannos': segannos, 'seqname': 'unknow', 'num_objects': num_objects} 152 | 153 | def __len__(self): 154 | return self.size * self.replicates -------------------------------------------------------------------------------- /dataset_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .davis17_v2 import DAVIS17V2 2 | from .davis17_v2_all import DAVIS17V2_all 3 | from .ytvos_v2 import YTVOSV2 4 | from .dataset_utils import * 5 | from .coco import SimpleCoCoDataset, 
SimpleSBDDataset 6 | from .ECSSD_MSRA10K import ECSSD, MSRA10K -------------------------------------------------------------------------------- /dataset_loaders/coco.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from pycocotools.coco import COCO 4 | import os 5 | import numpy as np 6 | import random 7 | import torch 8 | import scipy.io as scio 9 | from torchvision import transforms 10 | from.custom_transforms_MTB import aug_batch 11 | 12 | class SimpleCoCoDataset(torch.utils.data.Dataset): 13 | def __init__(self, rootdir, set_name='val2017', transform=None, max_num_objects=3): 14 | self.rootdir, self.set_name = rootdir, set_name 15 | self.transform = transform 16 | self.coco = COCO(os.path.join(self.rootdir, 'annotations', 'instances_' 17 | + self.set_name + '.json')) 18 | self.image_ids = self.coco.getImgIds() 19 | self.load_classes() 20 | self._max_num_objects = max_num_objects 21 | 22 | def load_classes(self): 23 | categories = self.coco.loadCats(self.coco.getCatIds()) 24 | categories.sort(key=lambda x: x['id']) 25 | 26 | # coco ids is not from 1, and not continue 27 | # make a new index from 0 to 79, continuely 28 | 29 | # classes: {names: new_index} 30 | # coco_labels: {new_index: coco_index} 31 | # coco_labels_inverse: {coco_index: new_index} 32 | self.classes, self.coco_labels, self.coco_labels_inverse = {}, {}, {} 33 | for c in categories: 34 | self.coco_labels[len(self.classes)] = c['id'] 35 | self.coco_labels_inverse[c['id']] = len(self.classes) 36 | self.classes[c['name']] = len(self.classes) 37 | 38 | # labels: {new_index: names} 39 | self.labels = {} 40 | for k, v in self.classes.items(): 41 | self.labels[v] = k 42 | 43 | def __len__(self): 44 | return len(self.image_ids) 45 | 46 | def __getitem__(self, index): 47 | img, gt = self.load_image_anns(index) 48 | 49 | sample = {'image': img, 'gt': gt} 50 | 51 | images = [] 52 | segannos = [] 53 | for i in range(3): 54 | sample_copy = {'image': np.copy(sample['image']), 'gt': np.copy(sample['gt'])} 55 | sample_transformed = self.transform(sample_copy) 56 | images.append(sample_transformed['crop_image']) 57 | segannos.append(sample_transformed['crop_gt']) 58 | 59 | images = torch.stack(images, dim=0).float().clamp(0, 1) 60 | segannos = torch.stack(segannos, dim=0).float() 61 | num_objects = int(segannos.max()) 62 | #print('coco size:',images.size(), segannos.size()) 63 | # save sample for checking, todo: need to delete 64 | if False: 65 | path_coco_sample = '/raid/STM_train_v1/coco_sample/{}'.format(index) 66 | if not os.path.exists(path_coco_sample): 67 | os.makedirs(path_coco_sample) 68 | palette = Image.open( 69 | '/raid/DAVIS/DAVIS-2017/DAVIS-train-val/Annotations/480p/blackswan/00000.png').getpalette() 70 | for i in range(images.shape[0]): 71 | img, gt = 255*images[i], segannos[i] 72 | img, gt = img.numpy().transpose((1, 2, 0)).astype(np.uint8), gt.numpy().transpose((1, 2, 0)).astype(np.uint8).squeeze() 73 | img, gt = Image.fromarray(img), Image.fromarray(gt) 74 | gt.putpalette(palette) 75 | img.save(os.path.join(path_coco_sample, '{:05d}.jpg'.format(i))) 76 | gt.save(os.path.join(path_coco_sample, '{:05d}.png'.format(i))) 77 | 78 | return {'images':images, 'segannos':segannos, 'seqname':'unknow', 'num_objects':num_objects} 79 | 80 | def load_image_anns(self, index): 81 | image_info = self.coco.loadImgs(self.image_ids[index])[0] 82 | imgpath = os.path.join(self.rootdir, self.set_name, 83 | image_info['file_name']) 84 | img = 
np.array(Image.open(imgpath).convert('RGB'))/255. 85 | 86 | annotation_ids = self.coco.getAnnIds(self.image_ids[index], iscrowd=False) 87 | coco_anns = self.coco.loadAnns(annotation_ids) 88 | target = np.zeros([img.shape[0], img.shape[1]]).astype(np.float32) 89 | if len(coco_anns) == 0: 90 | return img, target 91 | coco_anns_sample = random.sample(coco_anns, min(self._max_num_objects, len(coco_anns))) 92 | for i, a in enumerate(coco_anns_sample): 93 | target[self.coco.annToMask(a) == 1] = i+1 94 | target = target[:, :, np.newaxis] 95 | 96 | return img, target 97 | 98 | def image_aspect_ratio(self, index): 99 | image = self.coco.loadImgs(self.image_ids[index])[0] 100 | return float(image['width']) / float(image['height']) 101 | 102 | class SimpleSBDDataset(torch.utils.data.Dataset): 103 | def __init__(self, rootdir, transform=None, max_num_objects=3): 104 | self.rootdir = rootdir 105 | self.transform = transform 106 | self.image_list = [] 107 | self.gt_list = [] 108 | files = sorted(os.listdir(rootdir + '/img')) 109 | for i in range(len(files)): 110 | img = os.path.join(rootdir + '/img', files[i]) 111 | gt = os.path.join(rootdir + '/inst', files[i][:-4] + '.mat') 112 | self.image_list += [img] 113 | self.gt_list += [gt] 114 | 115 | self._max_num_objects = max_num_objects 116 | 117 | def __len__(self): 118 | return len(self.image_list) 119 | 120 | def __getitem__(self, index): 121 | img, gt = self.load_image_anns(index) 122 | 123 | sample = {'image': img, 'gt': gt} 124 | 125 | images = [] 126 | segannos = [] 127 | for i in range(3): 128 | sample_copy = {'image': np.copy(sample['image']), 'gt': np.copy(sample['gt'])} 129 | sample_transformed = self.transform(sample_copy) 130 | images.append(sample_transformed['crop_image']) 131 | segannos.append(sample_transformed['crop_gt']) 132 | 133 | images = torch.stack(images, dim=0).float().clamp(0, 1) 134 | segannos = torch.stack(segannos, dim=0).float() 135 | num_objects = int(segannos.max()) 136 | # save sample for checking, todo: need to delete 137 | if False: 138 | path_coco_sample = '/home/cgv841/gwb/Code/agame-vos-master/SBD_sample/{}'.format(index) 139 | if not os.path.exists(path_coco_sample): 140 | os.makedirs(path_coco_sample) 141 | palette = Image.open('/home/cgv841/gwb/DataSets/davis-2017/data/DAVIS/Annotations/480p/blackswan/00000.png').getpalette() 142 | for i in range(images.shape[0]): 143 | img, gt = 255*images[i], segannos[i] 144 | img, gt = img.numpy().transpose((1, 2, 0)).astype(np.uint8), gt.numpy().transpose((1, 2, 0)).astype(np.uint8).squeeze() 145 | img, gt = Image.fromarray(img), Image.fromarray(gt) 146 | gt.putpalette(palette) 147 | img.save(os.path.join(path_coco_sample, '{:05d}.jpg'.format(i))) 148 | gt.save(os.path.join(path_coco_sample, '{:05d}.png'.format(i))) 149 | 150 | return {'images':images, 'segannos':segannos, 'seqname':'unknow', 'num_objects':num_objects} 151 | 152 | def load_image_anns(self, index): 153 | imgpath = self.image_list[index] 154 | img = np.array(Image.open(imgpath).convert('RGB'))/255. 
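        # SBD instance annotations are stored as MATLAB .mat files: the code below reads
        # the 'GTinst' segmentation, samples at most self._max_num_objects instances and
        # relabels them to consecutive ids 1..N to match the training convention.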
155 | 156 | gtpath = self.gt_list[index] 157 | gt = scio.loadmat(gtpath) 158 | gt = gt['GTinst']['Segmentation'][0][0] 159 | 160 | old_idx_list = random.sample(range(1, gt.max()+1), min(self._max_num_objects, gt.max())) 161 | target = np.zeros([img.shape[0], img.shape[1]]).astype(np.float32) 162 | for i, old_idx in enumerate(old_idx_list): 163 | target[gt == old_idx] = i+1 164 | 165 | target = target[:, :, np.newaxis] 166 | 167 | return img, target 168 | -------------------------------------------------------------------------------- /dataset_loaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | 3 | import numpy.random as random 4 | import numpy as np 5 | import dataset_loaders.helpers as helpers 6 | 7 | 8 | class ScaleNRotate(object): 9 | """Scale (zoom-in, zoom-out) and Rotate the image and the ground truth. 10 | Args: 11 | two possibilities: 12 | 1. rots (tuple): (minimum, maximum) rotation angle 13 | scales (tuple): (minimum, maximum) scale 14 | 2. rots [list]: list of fixed possible rotation angles 15 | scales [list]: list of fixed possible scales 16 | """ 17 | def __init__(self, rots=(-30, 30), scales=(.75, 1.25), semseg=False): 18 | assert (isinstance(rots, type(scales))) 19 | self.rots = rots 20 | self.scales = scales 21 | self.semseg = semseg 22 | 23 | def __call__(self, sample): 24 | 25 | if type(self.rots) == tuple: 26 | # Continuous range of scales and rotations 27 | rot = (self.rots[1] - self.rots[0]) * random.random() - \ 28 | (self.rots[1] - self.rots[0])/2 29 | 30 | sc = (self.scales[1] - self.scales[0]) * random.random() - \ 31 | (self.scales[1] - self.scales[0]) / 2 + 1 32 | elif type(self.rots) == list: 33 | # Fixed range of scales and rotations 34 | rot = self.rots[random.randint(0, len(self.rots))] 35 | sc = self.scales[random.randint(0, len(self.scales))] 36 | 37 | for elem in sample.keys(): 38 | if 'meta' in elem: 39 | continue 40 | 41 | tmp = sample[elem] 42 | 43 | h, w = tmp.shape[:2] 44 | center = (w / 2, h / 2) 45 | assert(center != 0) # Strange behaviour warpAffine 46 | M = cv2.getRotationMatrix2D(center, rot, sc) 47 | 48 | if ((tmp == 0) | (tmp == 1)).all(): 49 | flagval = cv2.INTER_NEAREST 50 | elif 'gt' in elem and self.semseg: 51 | flagval = cv2.INTER_NEAREST 52 | else: 53 | flagval = cv2.INTER_CUBIC 54 | tmp = cv2.warpAffine(tmp, M, (w, h), flags=flagval) 55 | 56 | sample[elem] = tmp 57 | 58 | return sample 59 | 60 | def __str__(self): 61 | return 'ScaleNRotate:(rot='+str(self.rots)+',scale='+str(self.scales)+')' 62 | 63 | 64 | class FixedResize(object): 65 | """Resize the image and the ground truth to specified resolution. 
66 | Args: 67 | resolutions (dict): the list of resolutions 68 | """ 69 | def __init__(self, resolutions=None, flagvals=None): 70 | self.resolutions = resolutions 71 | self.flagvals = flagvals 72 | if self.flagvals is not None: 73 | assert(len(self.resolutions) == len(self.flagvals)) 74 | 75 | def __call__(self, sample): 76 | 77 | # Fixed range of scales 78 | if self.resolutions is None: 79 | return sample 80 | 81 | elems = list(sample.keys()) 82 | 83 | for elem in elems: 84 | 85 | if 'meta' in elem or 'bbox' in elem or ('extreme_points_coord' in elem and elem not in self.resolutions): 86 | continue 87 | if 'extreme_points_coord' in elem and elem in self.resolutions: 88 | bbox = sample['bbox'] 89 | crop_size = np.array([bbox[3]-bbox[1]+1, bbox[4]-bbox[2]+1]) 90 | res = np.array(self.resolutions[elem]).astype(np.float32) 91 | sample[elem] = np.round(sample[elem]*res/crop_size).astype(np.int) 92 | continue 93 | if elem in self.resolutions: 94 | if self.resolutions[elem] is None: 95 | continue 96 | if isinstance(sample[elem], list): 97 | if sample[elem][0].ndim == 3: 98 | output_size = np.append(self.resolutions[elem], [3, len(sample[elem])]) 99 | else: 100 | output_size = np.append(self.resolutions[elem], len(sample[elem])) 101 | tmp = sample[elem] 102 | sample[elem] = np.zeros(output_size, dtype=np.float32) 103 | for ii, crop in enumerate(tmp): 104 | if self.flagvals is None: 105 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem]) 106 | else: 107 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem], flagval=self.flagvals[elem]) 108 | else: 109 | if self.flagvals is None: 110 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem]) 111 | else: 112 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem], flagval=self.flagvals[elem]) 113 | else: 114 | del sample[elem] 115 | 116 | return sample 117 | 118 | def __str__(self): 119 | return 'FixedResize:'+str(self.resolutions) 120 | 121 | 122 | class RandomHorizontalFlip(object): 123 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 124 | 125 | def __call__(self, sample): 126 | 127 | if random.random() < 0.5: 128 | for elem in sample.keys(): 129 | if 'meta' in elem: 130 | continue 131 | tmp = sample[elem] 132 | tmp = cv2.flip(tmp, flipCode=1) 133 | sample[elem] = tmp 134 | 135 | return sample 136 | 137 | def __str__(self): 138 | return 'RandomHorizontalFlip' 139 | 140 | 141 | class ExtremePoints(object): 142 | """ 143 | Returns the four extreme points (left, right, top, bottom) (with some random perturbation) in a given binary mask 144 | sigma: sigma of Gaussian to create a heatmap from a point 145 | pert: number of pixels fo the maximum perturbation 146 | elem: which element of the sample to choose as the binary mask 147 | """ 148 | def __init__(self, sigma=10, pert=0, elem='gt'): 149 | self.sigma = sigma 150 | self.pert = pert 151 | self.elem = elem 152 | 153 | def __call__(self, sample): 154 | if sample[self.elem].ndim == 3: 155 | raise ValueError('ExtremePoints not implemented for multiple object per image.') 156 | _target = sample[self.elem] 157 | if np.max(_target) == 0: 158 | sample['extreme_points'] = np.zeros(_target.shape, dtype=_target.dtype) # TODO: handle one_mask_per_point case 159 | else: 160 | _points = helpers.extreme_points(_target, self.pert) 161 | sample['extreme_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False) 162 | 163 | return sample 164 | 165 | def 
__str__(self): 166 | return 'ExtremePoints:(sigma='+str(self.sigma)+', pert='+str(self.pert)+', elem='+str(self.elem)+')' 167 | 168 | 169 | class ConcatInputs(object): 170 | 171 | def __init__(self, elems=('image', 'point')): 172 | self.elems = elems 173 | 174 | def __call__(self, sample): 175 | 176 | res = sample[self.elems[0]] 177 | 178 | for elem in self.elems[1:]: 179 | assert(sample[self.elems[0]].shape[:2] == sample[elem].shape[:2]) 180 | 181 | # Check if third dimension is missing 182 | tmp = sample[elem] 183 | if tmp.ndim == 2: 184 | tmp = tmp[:, :, np.newaxis] 185 | 186 | res = np.concatenate((res, tmp), axis=2) 187 | 188 | sample['concat'] = res 189 | 190 | return sample 191 | 192 | def __str__(self): 193 | return 'ExtremePoints:'+str(self.elems) 194 | 195 | 196 | class CropFromMask(object): 197 | """ 198 | Returns image cropped in bounding box from a given mask 199 | """ 200 | def __init__(self, crop_elems=('image', 'gt'), 201 | mask_elem='gt', 202 | relax=0, 203 | zero_pad=False): 204 | 205 | self.crop_elems = crop_elems 206 | self.mask_elem = mask_elem 207 | self.relax = relax 208 | self.zero_pad = zero_pad 209 | 210 | def __call__(self, sample): 211 | _target = sample[self.mask_elem] 212 | if _target.ndim == 2: 213 | _target = np.expand_dims(_target, axis=-1) 214 | for elem in self.crop_elems: 215 | _img = sample[elem] 216 | _crop = [] 217 | if self.mask_elem == elem: 218 | if _img.ndim == 2: 219 | _img = np.expand_dims(_img, axis=-1) 220 | for k in range(0, _target.shape[-1]): 221 | _tmp_img = _img[..., k] 222 | _tmp_target = _target[..., k] 223 | if np.max(_target[..., k]) == 0: 224 | _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype)) 225 | else: 226 | _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 227 | else: 228 | for k in range(0, _target.shape[-1]): 229 | if np.max(_target[..., k]) == 0: 230 | _crop.append(np.zeros(_img.shape, dtype=_img.dtype)) 231 | else: 232 | _tmp_target = _target[..., k] 233 | _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 234 | if len(_crop) == 1: 235 | sample['crop_' + elem] = _crop[0] 236 | else: 237 | sample['crop_' + elem] = _crop 238 | return sample 239 | 240 | def __str__(self): 241 | return 'CropFromMask:(crop_elems='+str(self.crop_elems)+', mask_elem='+str(self.mask_elem)+\ 242 | ', relax='+str(self.relax)+',zero_pad='+str(self.zero_pad)+')' 243 | 244 | 245 | class ToImage(object): 246 | """ 247 | Return the given elements between 0 and 255 248 | """ 249 | def __init__(self, norm_elem='image', custom_max=255.): 250 | self.norm_elem = norm_elem 251 | self.custom_max = custom_max 252 | 253 | def __call__(self, sample): 254 | if isinstance(self.norm_elem, tuple): 255 | for elem in self.norm_elem: 256 | tmp = sample[elem] 257 | sample[elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 258 | else: 259 | tmp = sample[self.norm_elem] 260 | sample[self.norm_elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 261 | return sample 262 | 263 | def __str__(self): 264 | return 'NormalizeImage' 265 | 266 | 267 | class ToTensor(object): 268 | """Convert ndarrays in sample to Tensors.""" 269 | 270 | def __call__(self, sample): 271 | 272 | for elem in sample.keys(): 273 | if 'meta' in elem: 274 | continue 275 | elif 'bbox' in elem: 276 | tmp = sample[elem] 277 | sample[elem] = torch.from_numpy(tmp) 278 | continue 279 | 280 | tmp = sample[elem] 281 | 282 | if tmp.ndim == 2: 283 | tmp = tmp[:, :, 
np.newaxis] 284 | 285 | # swap color axis because 286 | # numpy image: H x W x C 287 | # torch image: C X H X W 288 | tmp = tmp.transpose((2, 0, 1)) 289 | sample[elem] = torch.from_numpy(tmp) 290 | 291 | return sample 292 | 293 | def __str__(self): 294 | return 'ToTensor' -------------------------------------------------------------------------------- /dataset_loaders/custom_transforms_MTB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torch.nn as nn 4 | 5 | import random 6 | from os.path import join 7 | import numpy as np 8 | import numpy 9 | #numpy.random.bit_generator = numpy.random._bit_generator 10 | import cv2 11 | import imgaug as ia 12 | from imgaug import augmenters as iaa 13 | 14 | 15 | def outS(i): 16 | """Given shape of input image as i,i,3 in deeplab-resnet model, this function 17 | returns j such that the shape of output blob of is j,j,21 (21 in case of VOC)""" 18 | j = int(i) 19 | j = (j + 1) / 2 20 | j = int(np.ceil((j + 1) / 2.0)) 21 | j = (j + 1) / 2 22 | return int(j) 23 | 24 | 25 | def resize_label_batch(label, size): 26 | #print('label shape:',label.shape,np.max(label)) 27 | label_resized = np.zeros((size, size, 1, label.shape[2])) 28 | interp = nn.Upsample(size=(size, size), mode='bilinear') 29 | labelVar = torch.from_numpy(label.transpose(3, 2, 0, 1)) 30 | #print('shape:',labelVar.size(),label_resized.shape) 31 | label_resized[:, :, :, :] = interp(labelVar).data.numpy().transpose(2, 3, 0, 1) 32 | label_resized[label_resized > 0.3] = 1 33 | label_resized[label_resized != 0] = 1 34 | 35 | return label_resized 36 | 37 | 38 | def flip(I, flip_p): 39 | if flip_p > 0.5: 40 | return np.fliplr(I) 41 | else: 42 | return I 43 | 44 | 45 | def aug_batch(img, gt): 46 | sometimes = lambda aug: iaa.Sometimes(0.5, aug) 47 | sometimes2 = lambda aug: iaa.Sometimes(0.9, aug) 48 | 49 | seq = iaa.Sequential( 50 | [ 51 | sometimes(iaa.Affine( 52 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 53 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, # translate by -20 to +20 percent (per axis) 54 | rotate=(-45, 45), # rotate by -45 to +45 degrees 55 | shear=(-16, 16), # shear by -16 to +16 degrees 56 | order=[0, 1], # use nearest neighbour or bilinear interpolation (fast) 57 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255 58 | mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples) 59 | )), 60 | iaa.Add((-10, 10), per_channel=0.5), 61 | sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))) 62 | ], random_order=True 63 | ) 64 | 65 | seq2 = iaa.Sequential( 66 | [ 67 | sometimes2(iaa.Affine( 68 | scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, 69 | translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, # translate by -20 to +20 percent (per axis) 70 | rotate=(-10, 10), # rotate by -45 to +45 degrees 71 | shear=(-10, 10), # shear by -16 to +16 degrees 72 | order=0, # use nearest neighbour or bilinear interpolation (fast) 73 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255 74 | mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples) 75 | )), 76 | # sometimes2(iaa.CoarseDropout(0.2, size_percent=(0.1, 0.5) 77 | # )) 78 | ], random_order=True 79 | ) 80 | scale = random.uniform(0.5, 81 | 1.3) # random.uniform(0.5,1.5) does not fit in a Titan X with the present version of pytorch, so we random scaling in the range (0.5,1.3), different than caffe implementation in that caffe used 
only 4 fixed scales. Refer to read me 82 | scale = 1 83 | dim = int(scale * 384) 84 | 85 | flip_p = random.uniform(0, 1) 86 | 87 | img_temp = flip(img, flip_p) 88 | gt_temp = flip(gt, flip_p) 89 | 90 | seq_det = seq.to_deterministic() 91 | img_temp = seq_det.augment_image(img_temp) 92 | img_temp = cv2.cvtColor(img_temp, cv2.COLOR_BGR2RGB).astype(float) / 255. 93 | img_temp = cv2.resize(img_temp, (dim, dim)) 94 | 95 | gt_temp = ia.SegmentationMapOnImage(gt_temp, shape=gt_temp.shape, nb_classes=2) 96 | gt_temp_map = seq_det.augment_segmentation_maps([gt_temp])[0] 97 | gt_temp = gt_temp_map.get_arr_int().astype(float) 98 | mask = seq2.augment_segmentation_maps([gt_temp_map])[0].get_arr_int() 99 | mask = cv2.resize(mask, (dim, dim), interpolation=cv2.INTER_NEAREST).astype(float) 100 | 101 | kernel = np.ones((int(scale * 5), int(scale * 5)), np.uint8) 102 | 103 | #bb = cv2.boundingRect(gt_temp.astype('uint8')) 104 | #if bb[2] != 0 and bb[3] != 0: 105 | # fc = np.ones([dim, dim, 1]) * -100 / 255. 106 | # fc[bb[1]:bb[1]+bb[3], bb[0]:bb[0]+bb[2], 0] = 100 107 | # if flip_p <= 1.0: 108 | # aug_p = random.uniform(0, 1) 109 | # it = random.randint(1, 5) 110 | 111 | # aug = np.expand_dims(cv2.dilate(mask, kernel, iterations=it), 2) 112 | # fc[np.where(aug == 1)] = 100 / 255. 113 | #else: 114 | # fc = np.ones([dim, dim, 1]) * -100 / 255. 115 | # image = np.dstack([img_temp, fc]) 116 | gt_temp = np.expand_dims(gt_temp, 2) 117 | image = img_temp 118 | gt = np.expand_dims(gt_temp,3) 119 | label = resize_label_batch(gt, dim) 120 | #print('label size:',label.shape) 121 | label = label.squeeze(2) 122 | 123 | return image, label 124 | 125 | 126 | def aug_pair(img_template, img_search, gt_template, gt_search): 127 | sometimes = lambda aug: iaa.Sometimes(0.5, aug) 128 | sometimes2 = lambda aug: iaa.Sometimes(0.9, aug) 129 | 130 | seq = iaa.Sequential( 131 | [ 132 | sometimes(iaa.Affine( 133 | scale={"x": (2 ** (-1 / 8), 2 ** (1 / 8)), "y": (2 ** (-1 / 8), 2 ** (1 / 8))}, 134 | translate_px={"x": (-8, 8), "y": (-8, 8)}, # translate by -20 to +20 percent (per axis) 135 | cval=(0, 0), # if mode is constant, use a cval between 0 and 255 136 | mode='edge' # use any of scikit-image's warping modes (see 2nd image from the top for examples) 137 | )), 138 | iaa.Add((-10, 10), per_channel=0.5), 139 | ], random_order=True 140 | ) 141 | 142 | seq2 = iaa.Sequential( 143 | [ 144 | sometimes2(iaa.Affine( 145 | scale={"x": (0.98, 1.02), "y": (0.98, 1.02)}, 146 | translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}, 147 | # translate by -20 to +20 percent (per axis) 148 | rotate=(-5, 5), # rotate by -45 to +45 degrees 149 | shear=(-5, 5), # shear by -16 to +16 degrees 150 | order=0, # use nearest neighbour or bilinear interpolation (fast) 151 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255 152 | mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples) 153 | )), 154 | # sometimes2(iaa.CoarseDropout(0.2, size_percent=(0.1, 0.5) 155 | # )) 156 | ], random_order=True 157 | ) 158 | 159 | scale = 1 160 | dim = int(scale * 328) 161 | 162 | flip_p = random.uniform(0, 1) 163 | 164 | img_template = flip(img_template, flip_p) 165 | gt_template = flip(gt_template, flip_p) 166 | img_search = flip(img_search, flip_p) 167 | gt_search = flip(gt_search, flip_p) 168 | 169 | # process template 170 | bb = cv2.boundingRect(gt_template) 171 | if bb[2] != 0 and bb[3] != 0: 172 | template = crop_and_padding(img_template, gt_template, (dim, dim)) 173 | t_h, t_w, _ = img_template.shape 
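        # Note: crop_and_padding is used here but is neither defined nor imported in this
        # file, so it is assumed to be provided elsewhere in the project (presumably it
        # crops around the mask's bounding box and pads to the target size). The
        # template_mask built below is a filled bounding-box mask of the cropped object,
        # not its exact segmentation.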
174 | # fg = np.ones([img_template.shape[1], img_template.shape[1], 1]) * 100 175 | template_mask = crop_and_padding(gt_template, gt_template, (dim, dim)) 176 | bb = cv2.boundingRect(template_mask) 177 | template_mask = np.zeros([dim, dim, 1]) 178 | if bb[2] != 0 and bb[3] != 0: 179 | template_mask[bb[1]:bb[1] + bb[3] + 1, bb[0]:bb[0] + bb[2] + 1, 0] = 1 180 | else: 181 | template = np.zeros([dim, dim, 3]).astype('uint8') 182 | template_mask = np.zeros([dim, dim, 1]) 183 | 184 | template = cv2.cvtColor(template, cv2.COLOR_BGR2RGB).astype(float) / 255. 185 | 186 | # Augment Search Image 187 | # img_search_o = crop_and_padding(img_search, gt_search, (dim, dim)).astype('uint8') 188 | # gt_search_o = crop_and_padding(gt_search, gt_search, (dim, dim)).astype('uint8') 189 | img_search_o = img_search.copy() 190 | gt_search_o = gt_search.copy() 191 | 192 | for i in range(10): 193 | seq_det = seq.to_deterministic() 194 | img_search = seq_det.augment_image(img_search_o) 195 | img_search = cv2.cvtColor(img_search, cv2.COLOR_BGR2RGB).astype(float) / 255. 196 | 197 | gt_search = ia.SegmentationMapOnImage(gt_search_o, shape=gt_search_o.shape, nb_classes=2) 198 | gt_search_map = seq_det.augment_segmentation_maps([gt_search])[0] 199 | gt_search = gt_search_map.get_arr_int() 200 | bb = cv2.boundingRect(gt_search.astype('uint8')) 201 | if bb[2] > 30 and bb[3] > 30: 202 | break 203 | else: 204 | img_search = img_search_o 205 | gt_search = gt_search_o 206 | 207 | img_search = cv2.resize(img_search, (dim, dim)) 208 | gt_search = cv2.resize(gt_search.astype('uint8'), (dim, dim), cv2.INTER_NEAREST) 209 | 210 | kernel = np.ones((int(scale * 5), int(scale * 5)), np.uint8) 211 | for i in range(10): 212 | mask = seq2.augment_segmentation_maps([gt_search_map])[0].get_arr_int().astype(float) 213 | bb = cv2.boundingRect(gt_search.astype('uint8')) 214 | if bb[2] > 10 and bb[3] > 10: 215 | break 216 | else: 217 | mask = gt_search.copy() 218 | 219 | fc = np.zeros([dim, dim, 1]) 220 | fc = mask.copy() 221 | if flip_p <= 0.8: 222 | aug_p = random.uniform(0, 1) 223 | it = random.randint(1, 3) 224 | 225 | aug = np.expand_dims(cv2.dilate(mask, kernel, iterations=it), 2) 226 | fc = aug 227 | # fc[np.where(aug==1)] = 1 228 | else: 229 | fc = np.expand_dims(fc, 2) 230 | fc = cv2.resize(fc, (dim, dim), cv2.INTER_NEAREST) 231 | fc = np.expand_dims(fc, 2) 232 | 233 | gt_search = np.expand_dims(gt_search, 2) 234 | gt = np.expand_dims(gt_search, 3) 235 | label = resize_label_batch(gt.astype(float), dim // 2) 236 | label = label.squeeze(3) 237 | 238 | return img_search, fc, template, template_mask, label 239 | 240 | 241 | def aug_mask_nodeform(img_template, img_search, gt_template, gt_search, p_mask): 242 | sometimes = lambda aug: iaa.Sometimes(0.8, aug) 243 | 244 | seq = iaa.Sequential( 245 | [ 246 | sometimes(iaa.Affine( 247 | scale={"x": (2 ** (-1 / 8), 2 ** (1 / 8)), "y": (2 ** (-1 / 8), 2 ** (1 / 8))}, 248 | translate_px={"x": (-8, 8), "y": (-8, 8)}, # translate by -20 to +20 percent (per axis) 249 | cval=(0, 0), # if mode is constant, use a cval between 0 and 255 250 | mode='edge' # use any of scikit-image's warping modes (see 2nd image from the top for examples) 251 | )), 252 | iaa.Add((-10, 10), per_channel=0.5), 253 | ], random_order=True 254 | ) 255 | 256 | scale = 1 257 | dim = int(scale * 328) 258 | 259 | # Create Template Image 260 | flip_p = random.uniform(0, 1) 261 | img_template = flip(img_template, flip_p) 262 | gt_template = flip(gt_template, flip_p) 263 | 264 | img_template = cv2.cvtColor(img_template, 
cv2.COLOR_BGR2RGB) 265 | target = crop_and_padding(img_template, gt_template, (dim, dim)) 266 | mask = crop_and_padding(gt_template, gt_template, (dim, dim)) 267 | 268 | seq_det = seq.to_deterministic() 269 | target = seq_det.augment_image(target).astype(float) / 255. 270 | mask_map = ia.SegmentationMapOnImage(mask, shape=mask.shape, nb_classes=2) 271 | mask = seq_det.augment_segmentation_maps([mask_map])[0].get_arr_int() 272 | bb = cv2.boundingRect(mask.astype('uint8')) 273 | template_mask = np.zeros(mask.shape) 274 | template_mask[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]] = 1 275 | 276 | # Create Search Image 277 | flip_p = random.uniform(0, 1) 278 | img_search = flip(img_search, flip_p) 279 | gt_search = flip(gt_search, flip_p) 280 | p_mask = flip(p_mask, flip_p) 281 | 282 | img_search = cv2.cvtColor(img_search, cv2.COLOR_BGR2RGB) 283 | img_search_o = crop_and_padding(img_search, p_mask, (dim, dim)) 284 | gt_search_o = crop_and_padding(gt_search, p_mask, (dim, dim)) 285 | p_mask_o = crop_and_padding(p_mask, p_mask, (dim, dim)) 286 | 287 | if len(np.unique(gt_search_o)) == 1: 288 | img_search_o = crop_and_padding(img_search, gt_search, (dim, dim)) 289 | gt_search_o = crop_and_padding(gt_search, gt_search, (dim, dim)) 290 | p_mask = gt_search 291 | for i in range(10): 292 | seq_det = seq.to_deterministic() 293 | img_search = seq_det.augment_image(img_search_o).astype(float) / 255. 294 | 295 | gt_searchmap = ia.SegmentationMapOnImage(gt_search_o, shape=gt_search.shape, nb_classes=2) 296 | gt_search = seq_det.augment_segmentation_maps([gt_searchmap])[0].get_arr_int().astype(float) 297 | mask_map = ia.SegmentationMapOnImage(p_mask_o, shape=p_mask.shape, nb_classes=2) 298 | p_mask = seq_det.augment_segmentation_maps([mask_map])[0].get_arr_int().astype(float) 299 | if bb[2] > 10 and bb[3] > 10: 300 | break 301 | 302 | kernel = np.ones((int(scale * 3), int(scale * 3)), np.uint8) 303 | it = random.randint(1, 3) 304 | p_mask = np.expand_dims(cv2.dilate(p_mask.astype(float), kernel, iterations=it), 2) 305 | template_mask = np.expand_dims(template_mask, 2) 306 | 307 | # image = np.dstack([img_search, p_mask]) 308 | gt_search = np.expand_dims(gt_search, 2) 309 | gt_search = np.expand_dims(gt_search, 3) 310 | label = resize_label_batch(gt_search, dim // 2) 311 | label = label.squeeze(3) 312 | 313 | return img_search, fc, template, template_mask, label -------------------------------------------------------------------------------- /dataset_loaders/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | random.seed() 5 | from itertools import accumulate 6 | import bisect 7 | from PIL import PILLOW_VERSION, Image 8 | import numpy as np 9 | 10 | env_path = os.path.join(os.path.dirname(__file__), '..') 11 | if env_path not in sys.path: 12 | sys.path.append(env_path) 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.utils.data 17 | import torchvision as tv 18 | 19 | import utils 20 | 21 | IMAGENET_MEAN = [.485,.456,.406] 22 | IMAGENET_STD = [.229,.224,.225] 23 | 24 | class LabelToLongTensor(object): 25 | """From Tiramisu github""" 26 | def __call__(self, pic): 27 | if isinstance(pic, np.ndarray): 28 | # handle numpy array 29 | label = torch.from_numpy(pic).long() 30 | elif pic.mode == '1': 31 | label = torch.from_numpy(np.array(pic, np.uint8, copy=False)).long().view(1, pic.size[1], pic.size[0]) 32 | else: 33 | label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 34 | 
if pic.mode == 'LA': # Hack to remove alpha channel if it exists 35 | label = label.view(pic.size[1], pic.size[0], 2) 36 | label = label.transpose(0, 1).transpose(0, 2).contiguous().long()[0] 37 | label = label.view(1, label.size(0), label.size(1)) 38 | else: 39 | label = label.view(pic.size[1], pic.size[0], -1) 40 | label = label.transpose(0, 1).transpose(0, 2).contiguous().long() 41 | return label 42 | 43 | class LabelLongTensorToFloat(object): 44 | def __call__(self, label): 45 | return label.float() 46 | 47 | class PadToDivisible(object): 48 | def __init__(self, divisibility): 49 | self.div = divisibility 50 | 51 | def __call__(self, tensor): 52 | size = tensor.size() 53 | assert tensor.dim() == 4 54 | height, width = size[-2:] 55 | height_pad = (self.div - height % self.div) % self.div 56 | width_pad = (self.div - width % self.div) % self.div 57 | padding = [(width_pad+1)//2, width_pad//2, (height_pad+1)//2, height_pad//2] 58 | tensor = F.pad(tensor, padding, mode='reflect') 59 | return tensor, padding 60 | 61 | class JointCompose(object): 62 | def __init__(self, transforms): 63 | self.transforms = transforms 64 | 65 | def __call__(self, images, labels): 66 | for t in self.transforms: 67 | images, labels = t(images, labels) 68 | return images, labels 69 | 70 | def __repr__(self): 71 | format_string = self.__class__.__name__ + '(' 72 | for t in self.transforms: 73 | format_string += '\n' 74 | format_string += ' {0}'.format(t) 75 | format_string += '\n)' 76 | return format_string 77 | 78 | class JointRandomHorizontalFlip(object): 79 | def __call__(self, *args): 80 | if random.choice([True, False]): 81 | out = [] 82 | for tensor in args: 83 | idx = [i for i in range(tensor.size(-1)-1, -1, -1)] 84 | idx = torch.LongTensor(idx) 85 | tensor_flip = tensor.index_select(-1, idx) 86 | out.append(tensor_flip) 87 | return out 88 | else: 89 | return args 90 | 91 | 92 | def centercrop(tensor, cropsize): 93 | _, _, H, W = tensor.size() 94 | A, B = cropsize 95 | # print((H,W), (A,B), (H-A)//2, (H+A)//2 96 | return tensor[:,:,(H-A)//2:(H+A)//2,(W-B)//2:(W+B)//2] 97 | 98 | class JointRandomScale(object): 99 | def __call__(self, images, labels): 100 | L, _, H, W = images.size() 101 | scales = ((1.0 + (torch.rand(1) < .5).float()*torch.rand(1)*.1)*torch.ones(L)).cumprod(0).tolist() 102 | images = torch.cat([centercrop(F.interpolate(images[l:l+1,:,:,:], scale_factor=scales[l], mode='bilinear', align_corners=False), (H, W)) for l in range(L)], dim=0) 103 | labels = torch.cat([centercrop(F.interpolate(labels[l,:,:].view(1,1,H,W).float(), scale_factor=scales[l], mode='nearest').long(), (H,W)) for l in range(L)], dim=0).view(L,1,H,W) 104 | return images, labels 105 | 106 | def centercrop(tensor, cropsize): 107 | _, _, H, W = tensor.size() 108 | A, B = cropsize 109 | # print((H,W), (A,B), (H-A)//2, (H+A)//2 110 | return tensor[:,:,(H-A)//2:(H+A)//2,(W-B)//2:(W+B)//2] 111 | 112 | class JointRandomScale(object): 113 | def __call__(self, images, labels): 114 | L, _, H, W = images.size() 115 | scales = ((1.0 + (torch.rand(1) < .5).float()*torch.rand(1)*.1)*torch.ones(L)).cumprod(0).tolist() 116 | images = torch.cat([centercrop(F.interpolate(images[l:l+1,:,:,:], scale_factor=scales[l], mode='bilinear', align_corners=False), (H, W)) for l in range(L)], dim=0) 117 | labels = torch.cat([centercrop(F.interpolate(labels[l,:,:].view(1,1,H,W).float(), scale_factor=scales[l], mode='nearest').long(), (H,W)) for l in range(L)], dim=0).view(L,1,H,W) 118 | return images, labels 
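The joint transforms above operate on a whole video snippet at once: images are a (L, C, H, W) float tensor, labels are a (L, 1, H, W) long tensor, and every transform returns the pair so that several of them can be chained. Below is a minimal usage sketch; the snippet length of 4, the 240x432 resolution and the stride of 16 are illustrative assumptions, not values fixed by this file.

```
import torch

from dataset_loaders.dataset_utils import (JointCompose, JointRandomHorizontalFlip,
                                            JointRandomScale, PadToDivisible)

# Illustrative snippet: 4 frames, 3-channel images, 240x432 resolution (assumed values).
images = torch.rand(4, 3, 240, 432)             # (L, C, H, W) float frames
labels = torch.randint(0, 3, (4, 1, 240, 432))  # (L, 1, H, W) integer object ids

# Chain the snippet-level augmentations defined in dataset_utils.py.
joint_aug = JointCompose([JointRandomHorizontalFlip(), JointRandomScale()])
images, labels = joint_aug(images, labels)

# Pad so height and width are divisible by the assumed network stride of 16.
images, padding = PadToDivisible(16)(images)
```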
-------------------------------------------------------------------------------- /dataset_loaders/davis17_v2.py: -------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | import torchvision as tv 10 | from dataset_loaders import custom_transforms as tr 11 | from dataset_loaders import dataset_utils 12 | import utils 13 | 14 | 15 | def get_sample_bernoulli(p): 16 | return (lambda lst: [elem for elem in lst if random.random() < p]) 17 | def get_sample_all(): 18 | return (lambda lst: lst) 19 | def get_sample_k_random(k): 20 | return (lambda lst: sorted(random.sample(lst, min(k,len(lst))))) 21 | 22 | def get_anno_ids(anno_path, pic_to_tensor_function, threshold): 23 | pic = Image.open(anno_path) 24 | tensor = pic_to_tensor_function(pic) 25 | values = (tensor.view(-1).bincount() > threshold).nonzero().view(-1).tolist() 26 | if 0 in values: values.remove(0) 27 | if 255 in values: values.remove(255) 28 | return values 29 | 30 | def get_default_image_read(size=(240,432)): 31 | def image_read(path): 32 | pic = Image.open(path) 33 | transform = tv.transforms.Compose( 34 | [tv.transforms.Resize(size, interpolation=Image.BILINEAR), 35 | tv.transforms.ToTensor(), 36 | tv.transforms.Normalize(mean=dataset_utils.IMAGENET_MEAN, std=dataset_utils.IMAGENET_STD)]) 37 | return transform(pic) 38 | return image_read 39 | def get_default_anno_read(size=(240,432)): 40 | def label_read(path): 41 | if os.path.exists(path): 42 | pic = Image.open(path) 43 | transform = tv.transforms.Compose( 44 | [tv.transforms.Resize(size, interpolation=Image.NEAREST), 45 | dataset_utils.LabelToLongTensor()]) 46 | label = transform(pic) 47 | else: 48 | label = torch.LongTensor(1,*size).fill_(255) # Put label that will be ignored 49 | return label 50 | return label_read 51 | 52 | class DAVIS17V2(torch.utils.data.Dataset): 53 | def __init__(self, root_path, version, image_set, image_read=get_default_image_read(), 54 | anno_read=get_default_anno_read(), image_label_read=None, 55 | joint_transform=None, samplelen=4, obj_selection=get_sample_all(), min_num_obj=1, start_frame='random', max_skip=None, load_all=False): 56 | self._min_num_objects = min_num_obj 57 | self._root_path = root_path 58 | self._version = version 59 | self._image_set = image_set 60 | self._image_read = image_read 61 | self._anno_read = anno_read 62 | self._image_anno_read = image_label_read 63 | self._joint_transform = joint_transform 64 | self._seqlen = samplelen 65 | self._obj_selection = obj_selection 66 | self._start_frame = start_frame 67 | self.transform = None 68 | assert version in ('2016', '2017') 69 | assert image_set in ('train', 'val', 'test-dev', 'test-challenge') 70 | # assert samplelen > 1, "samplelen must be at least 2" 71 | assert start_frame in ('random','first') 72 | self._init_data() 73 | self._max_skip = max_skip 74 | self._load_all = load_all 75 | 76 | def _init_data(self): 77 | """ Store some metadata that needs to be known during training. In order to sample, the viable sequences 78 | must be known. Sequences are viable if a snippet of given sample length can be selected, starting with 79 | an annotated frame and containing at least one more annotated frame. 
80 | """ 81 | print("-- DAVIS17 dataset initialization started.") 82 | framework_path = os.path.join(os.path.dirname(__file__), '..') 83 | cache_path = os.path.join(framework_path, 'cache', 'davis17_v2_visible_objects_100px_threshold.json') 84 | 85 | # First find visible objects in all annotated frames 86 | if os.path.exists(cache_path): 87 | with open(cache_path, 'r') as f: 88 | self._visible_objects = json.load(f) 89 | self._visible_objects = {seqname: OrderedDict((int(idx), objlst) for idx, objlst in val.items()) 90 | for seqname, val in self._visible_objects.items()} 91 | print("Datafile {} loaded, describing {} sequences.".format(cache_path, len(self._visible_objects))) 92 | else: 93 | # Grab all sequences in dataset 94 | seqnames = os.listdir(os.path.join(self._root_path, 'JPEGImages/', '480p')) 95 | 96 | # Construct meta-info 97 | self._visible_objects = {} 98 | for seqname in seqnames: 99 | anno_paths = sorted(glob.glob(self._full_anno_path(seqname, '*.png'))) 100 | self._visible_objects[seqname] = OrderedDict( 101 | (self._frame_name_to_idx(os.path.basename(path)), 102 | get_anno_ids(path, dataset_utils.LabelToLongTensor(), 100)) 103 | for path in anno_paths) 104 | 105 | if not os.path.exists(os.path.dirname(cache_path)): 106 | os.makedirs(os.path.dirname(cache_path)) 107 | with open(cache_path, 'w') as f: 108 | json.dump(self._visible_objects, f) 109 | print("Datafile {} was not found, creating it with {} sequences.".format(cache_path, len(self._visible_objects))) 110 | 111 | # Find sequences in the requested image_set 112 | with open(os.path.join(self._root_path, 'ImageSets', self._version, self._image_set + '.txt'), 'r') as f: 113 | self._all_seqs = f.read().splitlines() 114 | print("{} sequences found in image set \"{}\"".format(len(self._all_seqs), self._image_set)) 115 | 116 | # Filter out sequences that are too short from first frame with object, to last annotation 117 | self._nonempty_frame_ids = {seq: [frame_idx for frame_idx, obj_ids in lst.items() if len(obj_ids) >= self._min_num_objects] 118 | for seq, lst in self._visible_objects.items()} 119 | self._viable_seqs = [seq for seq in self._all_seqs if 120 | len(self._nonempty_frame_ids[seq]) > 0 121 | and len(self.get_image_frame_ids(seq)[min(self._nonempty_frame_ids[seq]) : 122 | max(self._visible_objects[seq].keys()) + 1]) 123 | >= self._seqlen] 124 | print("{} sequences remaining after filtering on length (from first anno obj appearance to last anno frame.".format(len(self._viable_seqs))) 125 | 126 | def __len__(self): 127 | return len(self._viable_seqs) 128 | 129 | 130 | def _frame_idx_to_image_fname(self, idx): 131 | return "{:05d}.jpg".format(idx) 132 | 133 | def _frame_idx_to_anno_fname(self, idx): 134 | return "{:05d}.png".format(idx) 135 | 136 | def _frame_name_to_idx(self, fname): 137 | return int(os.path.splitext(fname)[0]) 138 | 139 | def get_viable_seqnames(self): 140 | return self._viable_seqs 141 | 142 | def get_all_seqnames(self): 143 | return self._all_seqs 144 | 145 | def get_anno_frame_names(self, seqname): 146 | return os.listdir(os.path.join(self._root_path, "Annotations", "480p", seqname)) 147 | 148 | def get_anno_frame_ids(self, seqname): 149 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_anno_frame_names(seqname)]) 150 | 151 | def get_image_frame_names(self, seqname): 152 | return os.listdir(os.path.join(self._root_path, "JPEGImages", "480p", seqname)) 153 | 154 | def get_image_frame_ids(self, seqname): 155 | return sorted([self._frame_name_to_idx(fname) for fname in 
self.get_image_frame_names(seqname)]) 156 | 157 | def get_frame_ids(self, seqname): 158 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 159 | 160 | def get_nonempty_frame_ids(self, seqname): 161 | return self._nonempty_frame_ids[seqname] 162 | 163 | def _full_image_path(self, seqname, image): 164 | if isinstance(image, int): 165 | image = self._frame_idx_to_image_fname(image) 166 | return os.path.join(self._root_path, 'JPEGImages', "480p", seqname, image) 167 | 168 | def _full_anno_path(self, seqname, anno): 169 | if isinstance(anno, int): 170 | anno = self._frame_idx_to_anno_fname(anno) 171 | return os.path.join(self._root_path, 'Annotations', "480p", seqname, anno) 172 | 173 | def _select_frame_ids(self, frame_ids, viable_starting_frame_ids): 174 | if self._start_frame == 'first': 175 | frame_idxidx = frame_ids.index(viable_starting_frame_ids[0]) 176 | elif self._start_frame == 'random': 177 | frame_idxidx = frame_ids.index(random.choice(viable_starting_frame_ids)) 178 | 179 | if self._load_all: 180 | return frame_ids 181 | 182 | if self._max_skip is None: 183 | return frame_ids[frame_idxidx : frame_idxidx + self._seqlen] 184 | else: 185 | frame_ids_select = [] 186 | skip = random.randint(1, self._max_skip) 187 | sum_skip = skip * (self._seqlen - 1) 188 | if (sum_skip + self._seqlen) > (len(frame_ids) - frame_idxidx): 189 | skip = int((len(frame_ids) - frame_idxidx - self._seqlen) / (self._seqlen - 1)) 190 | idx_offset = 0 191 | for i in range(self._seqlen): 192 | frame_ids_select.append(frame_ids[frame_idxidx + idx_offset]) 193 | idx_offset = idx_offset + skip + 1 194 | return frame_ids_select 195 | 196 | def _select_object_ids(self, labels): 197 | assert labels.min() > -1 and labels.max() < 256, "{}".format(utils.print_tensor_statistics(labels)) 198 | possible_obj_ids = (labels[0].view(-1).bincount() > 25).nonzero().view(-1).tolist() 199 | if 0 in possible_obj_ids: possible_obj_ids.remove(0) 200 | if 255 in possible_obj_ids: possible_obj_ids.remove(255) 201 | #assert len(possible_obj_ids) > 0 202 | 203 | obj_ids = self._obj_selection(possible_obj_ids) 204 | bg_ids = (labels.view(-1).bincount() > 0).nonzero().view(-1).tolist() 205 | if 0 in bg_ids: bg_ids.remove(0) 206 | if 255 in bg_ids: bg_ids.remove(255) 207 | for idx in obj_ids: 208 | bg_ids.remove(idx) 209 | 210 | for idx in bg_ids: 211 | labels[labels == idx] = 0 212 | obj_ids.sort() 213 | for new_idx, old_idx in zip(range(1,len(obj_ids)+1), obj_ids): 214 | labels[labels == old_idx] = new_idx 215 | return labels 216 | 217 | def __getitem__(self, idx): 218 | """ 219 | returns: 220 | dict (Tensors): contains 'images', 'given_segmentations', 'labels' 221 | """ 222 | # assert self._version == '2017', "Only the 2017 version is supported for training as of now" 223 | seqname = self.get_viable_seqnames()[idx] 224 | 225 | # We require to begin with a nonempty frame, and will consider all objects in that frame to be tracked. 226 | # A starting frame is valid if it is followed by seqlen-1 frames with corresp images 227 | frame_ids = self.get_frame_ids(seqname) 228 | viable_starting_frame_ids = [idx for idx in self.get_nonempty_frame_ids(seqname) 229 | if idx <= frame_ids[-self._seqlen]] 230 | 231 | frame_ids = self._select_frame_ids(frame_ids, viable_starting_frame_ids) 232 | if self.transform is not None: 233 | images = [] 234 | segannos = [] 235 | for idx in frame_ids: 236 | my_image = np.array(Image.open(self._full_image_path(seqname, idx)).convert('RGB'))/255. 
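                # DAVIS annotations are palette-indexed PNGs whose pixel values are object
                # ids (255 marks don't-care regions). When self.transform is set, it is
                # expected to return a dict with 'crop_image' and 'crop_gt' entries, as
                # used below.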
237 | #self._image_read(self._full_image_path(seqname, idx)) 238 | my_gt = np.array(Image.open(self._full_anno_path(seqname, idx))) 239 | #print('davis range:', np.max(my_image), np.min(my_image), np.max(my_gt), np.min(my_gt)) 240 | #self._anno_read(self._full_anno_path(seqname, idx)) 241 | sample_copy = {'image': np.copy(my_image), 'gt': np.copy(my_gt)} 242 | sample_transformed = self.transform(sample_copy) 243 | images.append(sample_transformed['crop_image']) 244 | segannos.append(sample_transformed['crop_gt']) 245 | images = torch.stack(images, dim=0).float().clamp(0, 1) 246 | segannos = torch.stack(segannos, dim=0).long() 247 | 248 | if self._joint_transform is not None: 249 | 250 | #images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 251 | # for idx in frame_ids]) 252 | #print('davis range:', torch.max(images), torch.min(images)) 253 | #segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 254 | # for idx in frame_ids]) 255 | images = [] 256 | segannos = [] 257 | for idx in frame_ids: 258 | temp1,temp2 = self._image_anno_read(self._full_image_path(seqname, idx),self._full_anno_path(seqname, idx)) 259 | images.append(temp1) 260 | segannos.append(temp2) 261 | images = torch.stack(images, dim=0).float().clamp(0, 1) 262 | segannos = torch.stack(segannos, dim=0) 263 | 264 | if self._version == '2017': 265 | segannos = self._select_object_ids(segannos) 266 | elif self._version == '2016': 267 | segannos = (segannos > 0).long() 268 | else: 269 | raise ValueError("Version is not 2016 or 2017, got {}".format(self._version)) 270 | if self._joint_transform is not None: 271 | images, segannos = self._joint_transform(images, segannos) 272 | segannos[segannos == 255] = 0 273 | given_seganno = segannos[0] 274 | provides_seganno = torch.empty((self._seqlen), dtype=torch.uint8).fill_(True) 275 | num_objects = int(segannos.max()) 276 | #print('davis:',num_objects) 277 | 278 | return {'images':images, 'provides_seganno': provides_seganno, 'given_seganno':given_seganno, 279 | 'segannos':segannos, 'seqname':seqname, 'num_objects':num_objects} 280 | 281 | def _get_snippet(self, seqname, frame_ids): 282 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 283 | for idx in frame_ids]).unsqueeze(0) 284 | if self._image_set in ('test-dev', 'test-challenge'): 285 | segannos = None 286 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 287 | if idx in anno_frame_ids else None for idx in frame_ids] 288 | else: 289 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 290 | for idx in frame_ids]).squeeze().unsqueeze(0) 291 | if self._version == '2016': 292 | segannos = (segannos != 0).long() 293 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 294 | if idx == self.get_anno_frame_ids(seqname)[0] else None for idx in frame_ids] 295 | for i in range(len(given_segannos)): # Remove dont-care from given segannos 296 | if given_segannos[i] is not None: 297 | given_segannos[i][given_segannos[i] == 255] = 0 298 | if self._version == '2016': 299 | given_segannos[i] = (given_segannos[i] != 0).long() 300 | 301 | fnames = [self._frame_idx_to_anno_fname(idx) for idx in frame_ids] 302 | return {'images':images, 'given_segannos': given_segannos, 'segannos':segannos, 'fnames':fnames} 303 | 304 | def _get_video(self, seqname): 305 | seq_frame_ids = self.get_frame_ids(seqname) 306 | partitioned_frame_ids = [seq_frame_ids[start_idx : start_idx + self._seqlen] 307 | for 
start_idx in range(0, len(seq_frame_ids), self._seqlen)] 308 | for frame_ids in partitioned_frame_ids: 309 | yield self._get_snippet(seqname, frame_ids) 310 | 311 | def get_video_generator(self): 312 | for seqname in self.get_all_seqnames(): 313 | yield (seqname, self._get_video(seqname)) 314 | 315 | -------------------------------------------------------------------------------- /dataset_loaders/davis17_v2_all.py: -------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | 7 | from PIL import Image 8 | import torch 9 | import torchvision as tv 10 | 11 | from dataset_loaders import dataset_utils 12 | import utils 13 | 14 | 15 | def get_sample_bernoulli(p): 16 | return (lambda lst: [elem for elem in lst if random.random() < p]) 17 | def get_sample_all(): 18 | return (lambda lst: lst) 19 | def get_sample_k_random(k): 20 | return (lambda lst: sorted(random.sample(lst, min(k,len(lst))))) 21 | 22 | def get_anno_ids(anno_path, pic_to_tensor_function, threshold): 23 | pic = Image.open(anno_path) 24 | tensor = pic_to_tensor_function(pic) 25 | values = (tensor.view(-1).bincount() > threshold).nonzero().view(-1).tolist() 26 | if 0 in values: values.remove(0) 27 | if 255 in values: values.remove(255) 28 | return values 29 | 30 | def get_default_image_read(size=(240,432)): 31 | def image_read(path): 32 | pic = Image.open(path) 33 | transform = tv.transforms.Compose( 34 | [tv.transforms.Resize(size, interpolation=Image.BILINEAR), 35 | tv.transforms.ToTensor(), 36 | tv.transforms.Normalize(mean=dataset_utils.IMAGENET_MEAN, std=dataset_utils.IMAGENET_STD)]) 37 | return transform(pic) 38 | return image_read 39 | def get_default_anno_read(size=(240,432)): 40 | def label_read(path): 41 | if os.path.exists(path): 42 | pic = Image.open(path) 43 | transform = tv.transforms.Compose( 44 | [tv.transforms.Resize(size, interpolation=Image.NEAREST), 45 | dataset_utils.LabelToLongTensor()]) 46 | label = transform(pic) 47 | else: 48 | label = torch.LongTensor(1,*size).fill_(255) # Put label that will be ignored 49 | return label 50 | return label_read 51 | 52 | class DAVIS17V2_all(torch.utils.data.Dataset): 53 | def __init__(self, root_path, version, image_set, image_read=get_default_image_read(), 54 | anno_read=get_default_anno_read(), 55 | joint_transform=None, samplelen=4, obj_selection=get_sample_all(), min_num_obj=1, start_frame='random', max_skip=None, load_all=False): 56 | self._min_num_objects = min_num_obj 57 | self._root_path = root_path 58 | self._version = version 59 | self._image_set = image_set 60 | self._image_read = image_read 61 | self._anno_read = anno_read 62 | self._joint_transform = joint_transform 63 | self._seqlen = samplelen 64 | self._obj_selection = obj_selection 65 | self._start_frame = start_frame 66 | assert version in ('2016', '2017') 67 | assert image_set in ('train', 'val', 'test-dev', 'test-challenge') 68 | # assert samplelen > 1, "samplelen must be at least 2" 69 | assert start_frame in ('random','first') 70 | self._init_data() 71 | self._max_skip = max_skip 72 | self._load_all = load_all 73 | 74 | def _init_data(self): 75 | """ Store some metadata that needs to be known during training. In order to sample, the viable sequences 76 | must be known. Sequences are viable if a snippet of given sample length can be selected, starting with 77 | an annotated frame and containing at least one more annotated frame. 
78 | """ 79 | print("-- DAVIS17 dataset initialization started.") 80 | framework_path = os.path.join(os.path.dirname(__file__), '..') 81 | cache_path = os.path.join(framework_path, 'cache', 'davis17_v2_visible_objects_100px_threshold.json') 82 | 83 | # First find visible objects in all annotated frames 84 | if os.path.exists(cache_path): 85 | with open(cache_path, 'r') as f: 86 | self._visible_objects = json.load(f) 87 | self._visible_objects = {seqname: OrderedDict((int(idx), objlst) for idx, objlst in val.items()) 88 | for seqname, val in self._visible_objects.items()} 89 | print("Datafile {} loaded, describing {} sequences.".format(cache_path, len(self._visible_objects))) 90 | else: 91 | # Grab all sequences in dataset 92 | seqnames = os.listdir(os.path.join(self._root_path, 'JPEGImages/', '480p')) 93 | 94 | # Construct meta-info 95 | self._visible_objects = {} 96 | for seqname in seqnames: 97 | anno_paths = sorted(glob.glob(self._full_anno_path(seqname, '*.png'))) 98 | self._visible_objects[seqname] = OrderedDict( 99 | (self._frame_name_to_idx(os.path.basename(path)), 100 | get_anno_ids(path, dataset_utils.LabelToLongTensor(), 100)) 101 | for path in anno_paths) 102 | 103 | if not os.path.exists(os.path.dirname(cache_path)): 104 | os.makedirs(os.path.dirname(cache_path)) 105 | with open(cache_path, 'w') as f: 106 | json.dump(self._visible_objects, f) 107 | print("Datafile {} was not found, creating it with {} sequences.".format(cache_path, len(self._visible_objects))) 108 | 109 | # Find sequences in the requested image_set 110 | with open(os.path.join(self._root_path, 'ImageSets', self._version, self._image_set + '.txt'), 'r') as f: 111 | self._all_seqs = f.read().splitlines() 112 | print("{} sequences found in image set \"{}\"".format(len(self._all_seqs), self._image_set)) 113 | 114 | # Filter out sequences that are too short from first frame with object, to last annotation 115 | self._nonempty_frame_ids = {seq: [frame_idx for frame_idx, obj_ids in lst.items() if len(obj_ids) >= self._min_num_objects] 116 | for seq, lst in self._visible_objects.items()} 117 | self._viable_seqs = [seq for seq in self._all_seqs if 118 | len(self._nonempty_frame_ids[seq]) > 0 119 | and len(self.get_image_frame_ids(seq)[min(self._nonempty_frame_ids[seq]) : 120 | max(self._visible_objects[seq].keys()) + 1]) 121 | >= self._seqlen] 122 | print("{} sequences remaining after filtering on length (from first anno obj appearance to last anno frame.".format(len(self._viable_seqs))) 123 | 124 | def __len__(self): 125 | return len(self._viable_seqs) 126 | 127 | def _frame_idx_to_image_fname(self, idx): 128 | return "{:05d}.jpg".format(idx) 129 | 130 | def _frame_idx_to_anno_fname(self, idx): 131 | return "{:05d}.png".format(idx) 132 | 133 | def _frame_name_to_idx(self, fname): 134 | return int(os.path.splitext(fname)[0]) 135 | 136 | def get_viable_seqnames(self): 137 | return self._viable_seqs 138 | 139 | def get_all_seqnames(self): 140 | return self._all_seqs 141 | 142 | def get_anno_frame_names(self, seqname): 143 | return os.listdir(os.path.join(self._root_path, "Annotations", "480p", seqname)) 144 | 145 | def get_anno_frame_ids(self, seqname): 146 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_anno_frame_names(seqname)]) 147 | 148 | def get_image_frame_names(self, seqname): 149 | return os.listdir(os.path.join(self._root_path, "JPEGImages", "480p", seqname)) 150 | 151 | def get_image_frame_ids(self, seqname): 152 | return sorted([self._frame_name_to_idx(fname) for fname in 
self.get_image_frame_names(seqname)]) 153 | 154 | def get_frame_ids(self, seqname): 155 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 156 | 157 | def get_nonempty_frame_ids(self, seqname): 158 | return self._nonempty_frame_ids[seqname] 159 | 160 | def _full_image_path(self, seqname, image): 161 | if isinstance(image, int): 162 | image = self._frame_idx_to_image_fname(image) 163 | return os.path.join(self._root_path, 'JPEGImages', "480p", seqname, image) 164 | 165 | def _full_anno_path(self, seqname, anno): 166 | if isinstance(anno, int): 167 | anno = self._frame_idx_to_anno_fname(anno) 168 | return os.path.join(self._root_path, 'Annotations', "480p", seqname, anno) 169 | 170 | def _select_frame_ids(self, frame_ids, viable_starting_frame_ids): 171 | if self._start_frame == 'first': 172 | frame_idxidx = frame_ids.index(viable_starting_frame_ids[0]) 173 | elif self._start_frame == 'random': 174 | frame_idxidx = frame_ids.index(random.choice(viable_starting_frame_ids)) 175 | 176 | if self._load_all: 177 | return frame_ids 178 | 179 | if self._max_skip is None: 180 | return frame_ids[frame_idxidx : frame_idxidx + self._seqlen] 181 | else: 182 | frame_ids_select = [] 183 | skip = random.randint(1, self._max_skip) 184 | sum_skip = skip * (self._seqlen - 1) 185 | if (sum_skip + self._seqlen) > (len(frame_ids) - frame_idxidx): 186 | skip = int((len(frame_ids) - frame_idxidx - self._seqlen) / (self._seqlen - 1)) 187 | idx_offset = 0 188 | for i in range(self._seqlen): 189 | frame_ids_select.append(frame_ids[frame_idxidx + idx_offset]) 190 | idx_offset = idx_offset + skip + 1 191 | return frame_ids_select 192 | 193 | def _select_object_ids(self, labels): 194 | assert labels.min() > -1 and labels.max() < 256, "{}".format(utils.print_tensor_statistics(labels)) 195 | possible_obj_ids = (labels[0].view(-1).bincount() > 25).nonzero().view(-1).tolist() 196 | if 0 in possible_obj_ids: possible_obj_ids.remove(0) 197 | if 255 in possible_obj_ids: possible_obj_ids.remove(255) 198 | assert len(possible_obj_ids) > 0 199 | 200 | obj_ids = self._obj_selection(possible_obj_ids) 201 | bg_ids = (labels.view(-1).bincount() > 0).nonzero().view(-1).tolist() 202 | if 0 in bg_ids: bg_ids.remove(0) 203 | if 255 in bg_ids: bg_ids.remove(255) 204 | for idx in obj_ids: 205 | bg_ids.remove(idx) 206 | 207 | for idx in bg_ids: 208 | labels[labels == idx] = 0 209 | obj_ids.sort() 210 | for new_idx, old_idx in zip(range(1,len(obj_ids)+1), obj_ids): 211 | labels[labels == old_idx] = new_idx 212 | return labels 213 | 214 | def __getitem__(self, idx): 215 | """ 216 | returns: 217 | dict (Tensors): contains 'images', 'given_segmentations', 'labels' 218 | """ 219 | # assert self._version == '2017', "Only the 2017 version is supported for training as of now" 220 | seqname = self.get_viable_seqnames()[idx] 221 | 222 | # We require to begin with a nonempty frame, and will consider all objects in that frame to be tracked. 
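        # Added note (not in the original source): in this *_all variant the starting-frame
        # selection below is commented out, so every frame of the sequence is loaded and read
        # with the default image_read / anno_read (resize + ImageNet normalization).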
223 | # A starting frame is valid if it is followed by seqlen-1 frames with corresp images 224 | frame_ids = self.get_frame_ids(seqname) 225 | #viable_starting_frame_ids = [idx for idx in self.get_nonempty_frame_ids(seqname) 226 | # if idx <= frame_ids[-self._seqlen]] 227 | 228 | #frame_ids = self._select_frame_ids(frame_ids, viable_starting_frame_ids) 229 | 230 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 231 | for idx in frame_ids]) 232 | #print('video name:', seqname, images.size()) 233 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 234 | for idx in frame_ids]) 235 | 236 | if self._version == '2017': 237 | segannos = self._select_object_ids(segannos) 238 | elif self._version == '2016': 239 | segannos = (segannos > 0).long() 240 | else: 241 | raise ValueError("Version is not 2016 or 2017, got {}".format(self._version)) 242 | if self._joint_transform is not None: 243 | images, segannos = self._joint_transform(images, segannos) 244 | segannos[segannos == 255] = 0 245 | given_seganno = segannos[0] 246 | provides_seganno = torch.empty((self._seqlen), dtype=torch.uint8).fill_(True) 247 | num_objects = int(segannos.max()) 248 | 249 | return {'images':images, 'provides_seganno': provides_seganno, 'given_seganno':given_seganno, 250 | 'segannos':segannos, 'seqname':seqname, 'num_objects':num_objects} 251 | 252 | def _get_snippet(self, seqname, frame_ids): 253 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 254 | for idx in frame_ids]).unsqueeze(0) 255 | if self._image_set in ('test-dev', 'test-challenge'): 256 | segannos = None 257 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 258 | if idx in anno_frame_ids else None for idx in frame_ids] 259 | else: 260 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 261 | for idx in frame_ids]).squeeze().unsqueeze(0) 262 | if self._version == '2016': 263 | segannos = (segannos != 0).long() 264 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 265 | if idx == self.get_anno_frame_ids(seqname)[0] else None for idx in frame_ids] 266 | for i in range(len(given_segannos)): # Remove dont-care from given segannos 267 | if given_segannos[i] is not None: 268 | given_segannos[i][given_segannos[i] == 255] = 0 269 | if self._version == '2016': 270 | given_segannos[i] = (given_segannos[i] != 0).long() 271 | 272 | fnames = [self._frame_idx_to_anno_fname(idx) for idx in frame_ids] 273 | return {'images':images, 'given_segannos': given_segannos, 'segannos':segannos, 'fnames':fnames} 274 | 275 | def _get_video(self, seqname): 276 | seq_frame_ids = self.get_frame_ids(seqname) 277 | partitioned_frame_ids = [seq_frame_ids[start_idx : start_idx + self._seqlen] 278 | for start_idx in range(0, len(seq_frame_ids), self._seqlen)] 279 | for frame_ids in partitioned_frame_ids: 280 | yield self._get_snippet(seqname, frame_ids) 281 | 282 | def get_video_generator(self): 283 | for seqname in self.get_all_seqnames(): 284 | yield (seqname, self._get_video(seqname)) 285 | 286 | -------------------------------------------------------------------------------- /dataset_loaders/davis17_v2_org.py: -------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | 7 | from PIL import Image 8 | import torch 9 | import torchvision as tv 10 | 11 | from dataset_loaders import dataset_utils 
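# Usage sketch added for clarity (hypothetical dataset path; in the original repo this
# class is driven from the runfiles instead):
#
#   dataset = DAVIS17V2('/path/to/DAVIS', '2017', 'val', samplelen=4, start_frame='first')
#   for seqname, video in dataset.get_video_generator():
#       for snippet in video:
#           images = snippet['images']          # (1, T, 3, H, W), ImageNet-normalized
#           given = snippet['given_segannos']   # list over frames; annotation tensor for the
#                                               # first annotated frame, None elsewhere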
12 | import utils 13 | 14 | 15 | def get_sample_bernoulli(p): 16 | return (lambda lst: [elem for elem in lst if random.random() < p]) 17 | def get_sample_all(): 18 | return (lambda lst: lst) 19 | def get_sample_k_random(k): 20 | return (lambda lst: sorted(random.sample(lst, min(k,len(lst))))) 21 | 22 | def get_anno_ids(anno_path, pic_to_tensor_function, threshold): 23 | pic = Image.open(anno_path) 24 | tensor = pic_to_tensor_function(pic) 25 | values = (tensor.view(-1).bincount() > threshold).nonzero().view(-1).tolist() 26 | if 0 in values: values.remove(0) 27 | if 255 in values: values.remove(255) 28 | return values 29 | 30 | def get_default_image_read(size=(240,432)): 31 | def image_read(path): 32 | pic = Image.open(path) 33 | transform = tv.transforms.Compose( 34 | [tv.transforms.Resize(size, interpolation=Image.BILINEAR), 35 | tv.transforms.ToTensor(), 36 | tv.transforms.Normalize(mean=dataset_utils.IMAGENET_MEAN, std=dataset_utils.IMAGENET_STD)]) 37 | return transform(pic) 38 | return image_read 39 | def get_default_anno_read(size=(240,432)): 40 | def label_read(path): 41 | if os.path.exists(path): 42 | pic = Image.open(path) 43 | transform = tv.transforms.Compose( 44 | [tv.transforms.Resize(size, interpolation=Image.NEAREST), 45 | dataset_utils.LabelToLongTensor()]) 46 | label = transform(pic) 47 | else: 48 | label = torch.LongTensor(1,*size).fill_(255) # Put label that will be ignored 49 | return label 50 | return label_read 51 | 52 | class DAVIS17V2(torch.utils.data.Dataset): 53 | def __init__(self, root_path, version, image_set, image_read=get_default_image_read(), 54 | anno_read=get_default_anno_read(), 55 | joint_transform=None, samplelen=4, obj_selection=get_sample_all(), min_num_obj=1, start_frame='random', max_skip=None, load_all=False): 56 | self._min_num_objects = min_num_obj 57 | self._root_path = root_path 58 | self._version = version 59 | self._image_set = image_set 60 | self._image_read = image_read 61 | self._anno_read = anno_read 62 | self._joint_transform = joint_transform 63 | self._seqlen = samplelen 64 | self._obj_selection = obj_selection 65 | self._start_frame = start_frame 66 | assert version in ('2016', '2017') 67 | assert image_set in ('train', 'val', 'test-dev', 'test-challenge') 68 | # assert samplelen > 1, "samplelen must be at least 2" 69 | assert start_frame in ('random','first') 70 | self._init_data() 71 | self._max_skip = max_skip 72 | self._load_all = load_all 73 | 74 | def _init_data(self): 75 | """ Store some metadata that needs to be known during training. In order to sample, the viable sequences 76 | must be known. Sequences are viable if a snippet of given sample length can be selected, starting with 77 | an annotated frame and containing at least one more annotated frame. 
78 | """ 79 | print("-- DAVIS17 dataset initialization started.") 80 | framework_path = os.path.join(os.path.dirname(__file__), '..') 81 | cache_path = os.path.join(framework_path, 'cache', 'davis17_v2_visible_objects_100px_threshold.json') 82 | 83 | # First find visible objects in all annotated frames 84 | if os.path.exists(cache_path): 85 | with open(cache_path, 'r') as f: 86 | self._visible_objects = json.load(f) 87 | self._visible_objects = {seqname: OrderedDict((int(idx), objlst) for idx, objlst in val.items()) 88 | for seqname, val in self._visible_objects.items()} 89 | print("Datafile {} loaded, describing {} sequences.".format(cache_path, len(self._visible_objects))) 90 | else: 91 | # Grab all sequences in dataset 92 | seqnames = os.listdir(os.path.join(self._root_path, 'JPEGImages/', '480p')) 93 | 94 | # Construct meta-info 95 | self._visible_objects = {} 96 | for seqname in seqnames: 97 | anno_paths = sorted(glob.glob(self._full_anno_path(seqname, '*.png'))) 98 | self._visible_objects[seqname] = OrderedDict( 99 | (self._frame_name_to_idx(os.path.basename(path)), 100 | get_anno_ids(path, dataset_utils.LabelToLongTensor(), 100)) 101 | for path in anno_paths) 102 | 103 | if not os.path.exists(os.path.dirname(cache_path)): 104 | os.makedirs(os.path.dirname(cache_path)) 105 | with open(cache_path, 'w') as f: 106 | json.dump(self._visible_objects, f) 107 | print("Datafile {} was not found, creating it with {} sequences.".format(cache_path, len(self._visible_objects))) 108 | 109 | # Find sequences in the requested image_set 110 | with open(os.path.join(self._root_path, 'ImageSets', self._version, self._image_set + '.txt'), 'r') as f: 111 | self._all_seqs = f.read().splitlines() 112 | print("{} sequences found in image set \"{}\"".format(len(self._all_seqs), self._image_set)) 113 | 114 | # Filter out sequences that are too short from first frame with object, to last annotation 115 | self._nonempty_frame_ids = {seq: [frame_idx for frame_idx, obj_ids in lst.items() if len(obj_ids) >= self._min_num_objects] 116 | for seq, lst in self._visible_objects.items()} 117 | self._viable_seqs = [seq for seq in self._all_seqs if 118 | len(self._nonempty_frame_ids[seq]) > 0 119 | and len(self.get_image_frame_ids(seq)[min(self._nonempty_frame_ids[seq]) : 120 | max(self._visible_objects[seq].keys()) + 1]) 121 | >= self._seqlen] 122 | print("{} sequences remaining after filtering on length (from first anno obj appearance to last anno frame.".format(len(self._viable_seqs))) 123 | 124 | def __len__(self): 125 | return len(self._viable_seqs) 126 | 127 | def _frame_idx_to_image_fname(self, idx): 128 | return "{:05d}.jpg".format(idx) 129 | 130 | def _frame_idx_to_anno_fname(self, idx): 131 | return "{:05d}.png".format(idx) 132 | 133 | def _frame_name_to_idx(self, fname): 134 | return int(os.path.splitext(fname)[0]) 135 | 136 | def get_viable_seqnames(self): 137 | return self._viable_seqs 138 | 139 | def get_all_seqnames(self): 140 | return self._all_seqs 141 | 142 | def get_anno_frame_names(self, seqname): 143 | return os.listdir(os.path.join(self._root_path, "Annotations", "480p", seqname)) 144 | 145 | def get_anno_frame_ids(self, seqname): 146 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_anno_frame_names(seqname)]) 147 | 148 | def get_image_frame_names(self, seqname): 149 | return os.listdir(os.path.join(self._root_path, "JPEGImages", "480p", seqname)) 150 | 151 | def get_image_frame_ids(self, seqname): 152 | return sorted([self._frame_name_to_idx(fname) for fname in 
self.get_image_frame_names(seqname)]) 153 | 154 | def get_frame_ids(self, seqname): 155 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 156 | 157 | def get_nonempty_frame_ids(self, seqname): 158 | return self._nonempty_frame_ids[seqname] 159 | 160 | def _full_image_path(self, seqname, image): 161 | if isinstance(image, int): 162 | image = self._frame_idx_to_image_fname(image) 163 | return os.path.join(self._root_path, 'JPEGImages', "480p", seqname, image) 164 | 165 | def _full_anno_path(self, seqname, anno): 166 | if isinstance(anno, int): 167 | anno = self._frame_idx_to_anno_fname(anno) 168 | return os.path.join(self._root_path, 'Annotations', "480p", seqname, anno) 169 | 170 | def _select_frame_ids(self, frame_ids, viable_starting_frame_ids): 171 | if self._start_frame == 'first': 172 | frame_idxidx = frame_ids.index(viable_starting_frame_ids[0]) 173 | elif self._start_frame == 'random': 174 | frame_idxidx = frame_ids.index(random.choice(viable_starting_frame_ids)) 175 | 176 | if self._load_all: 177 | return frame_ids 178 | 179 | if self._max_skip is None: 180 | return frame_ids[frame_idxidx : frame_idxidx + self._seqlen] 181 | else: 182 | frame_ids_select = [] 183 | skip = random.randint(1, self._max_skip) 184 | sum_skip = skip * (self._seqlen - 1) 185 | if (sum_skip + self._seqlen) > (len(frame_ids) - frame_idxidx): 186 | skip = int((len(frame_ids) - frame_idxidx - self._seqlen) / (self._seqlen - 1)) 187 | idx_offset = 0 188 | for i in range(self._seqlen): 189 | frame_ids_select.append(frame_ids[frame_idxidx + idx_offset]) 190 | idx_offset = idx_offset + skip + 1 191 | return frame_ids_select 192 | 193 | def _select_object_ids(self, labels): 194 | assert labels.min() > -1 and labels.max() < 256, "{}".format(utils.print_tensor_statistics(labels)) 195 | possible_obj_ids = (labels[0].view(-1).bincount() > 25).nonzero().view(-1).tolist() 196 | if 0 in possible_obj_ids: possible_obj_ids.remove(0) 197 | if 255 in possible_obj_ids: possible_obj_ids.remove(255) 198 | assert len(possible_obj_ids) > 0 199 | 200 | obj_ids = self._obj_selection(possible_obj_ids) 201 | bg_ids = (labels.view(-1).bincount() > 0).nonzero().view(-1).tolist() 202 | if 0 in bg_ids: bg_ids.remove(0) 203 | if 255 in bg_ids: bg_ids.remove(255) 204 | for idx in obj_ids: 205 | bg_ids.remove(idx) 206 | 207 | for idx in bg_ids: 208 | labels[labels == idx] = 0 209 | for new_idx, old_idx in zip(range(1,len(obj_ids)+1), obj_ids): 210 | labels[labels == old_idx] = new_idx 211 | return labels 212 | 213 | def __getitem__(self, idx): 214 | """ 215 | returns: 216 | dict (Tensors): contains 'images', 'given_segmentations', 'labels' 217 | """ 218 | # assert self._version == '2017', "Only the 2017 version is supported for training as of now" 219 | seqname = self.get_viable_seqnames()[idx] 220 | 221 | # We require to begin with a nonempty frame, and will consider all objects in that frame to be tracked. 
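        # Added illustration (hypothetical numbers, not from the original code): with
        # seqlen=4, frame_ids=[0..9] and nonempty annotated frames {0, 3, 8},
        # frame_ids[-seqlen] is 6, so viable_starting_frame_ids becomes [0, 3]; frame 8 is
        # rejected because fewer than seqlen frames remain from it to the end of the clip.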
222 | # A starting frame is valid if it is followed by seqlen-1 frames with corresp images 223 | frame_ids = self.get_frame_ids(seqname) 224 | viable_starting_frame_ids = [idx for idx in self.get_nonempty_frame_ids(seqname) 225 | if idx <= frame_ids[-self._seqlen]] 226 | 227 | frame_ids = self._select_frame_ids(frame_ids, viable_starting_frame_ids) 228 | 229 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 230 | for idx in frame_ids]) 231 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 232 | for idx in frame_ids]) 233 | 234 | if self._version == '2017': 235 | segannos = self._select_object_ids(segannos) 236 | elif self._version == '2016': 237 | segannos = (segannos > 0).long() 238 | else: 239 | raise ValueError("Version is not 2016 or 2017, got {}".format(self._version)) 240 | if self._joint_transform is not None: 241 | images, segannos = self._joint_transform(images, segannos) 242 | segannos[segannos == 255] = 0 243 | given_seganno = segannos[0] 244 | provides_seganno = torch.empty((self._seqlen), dtype=torch.uint8).fill_(True) 245 | num_objects = int(segannos.max()) 246 | 247 | return {'images':images, 'provides_seganno': provides_seganno, 'given_seganno':given_seganno, 248 | 'segannos':segannos, 'seqname':seqname, 'num_objects':num_objects} 249 | 250 | def _get_snippet(self, seqname, frame_ids): 251 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 252 | for idx in frame_ids]).unsqueeze(0) 253 | if self._image_set in ('test-dev', 'test-challenge'): 254 | segannos = None 255 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 256 | if idx in anno_frame_ids else None for idx in frame_ids] 257 | else: 258 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 259 | for idx in frame_ids]).squeeze().unsqueeze(0) 260 | if self._version == '2016': 261 | segannos = (segannos != 0).long() 262 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 263 | if idx == self.get_anno_frame_ids(seqname)[0] else None for idx in frame_ids] 264 | for i in range(len(given_segannos)): # Remove dont-care from given segannos 265 | if given_segannos[i] is not None: 266 | given_segannos[i][given_segannos[i] == 255] = 0 267 | if self._version == '2016': 268 | given_segannos[i] = (given_segannos[i] != 0).long() 269 | 270 | fnames = [self._frame_idx_to_anno_fname(idx) for idx in frame_ids] 271 | return {'images':images, 'given_segannos': given_segannos, 'segannos':segannos, 'fnames':fnames} 272 | 273 | def _get_video(self, seqname): 274 | seq_frame_ids = self.get_frame_ids(seqname) 275 | partitioned_frame_ids = [seq_frame_ids[start_idx : start_idx + self._seqlen] 276 | for start_idx in range(0, len(seq_frame_ids), self._seqlen)] 277 | for frame_ids in partitioned_frame_ids: 278 | yield self._get_snippet(seqname, frame_ids) 279 | 280 | def get_video_generator(self): 281 | for seqname in self.get_all_seqnames(): 282 | yield (seqname, self._get_video(seqname)) 283 | 284 | -------------------------------------------------------------------------------- /dataset_loaders/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch, cv2 4 | import random 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def tens2image(im): 10 | if im.size()[0] == 1: 11 | tmp = np.squeeze(im.numpy(), axis=0) 12 | else: 13 | tmp = im.numpy() 14 | if tmp.ndim == 2: 15 | return tmp 16 | else: 17 | 
return tmp.transpose((1, 2, 0)) 18 | 19 | 20 | def crop2fullmask(crop_mask, bbox, im=None, im_size=None, zero_pad=False, relax=0, mask_relax=True, 21 | interpolation=cv2.INTER_CUBIC, scikit=False): 22 | if scikit: 23 | from skimage.transform import resize as sk_resize 24 | assert (not (im is None and im_size is None)), 'You have to provide an image or the image size' 25 | if im is None: 26 | im_si = im_size 27 | else: 28 | im_si = im.shape 29 | # Borers of image 30 | bounds = (0, 0, im_si[1] - 1, im_si[0] - 1) 31 | 32 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 33 | bbox_valid = (max(bbox[0], bounds[0]), 34 | max(bbox[1], bounds[1]), 35 | min(bbox[2], bounds[2]), 36 | min(bbox[3], bounds[3])) 37 | 38 | # Bounding box of initial mask 39 | bbox_init = (bbox[0] + relax, 40 | bbox[1] + relax, 41 | bbox[2] - relax, 42 | bbox[3] - relax) 43 | 44 | if zero_pad: 45 | # Offsets for x and y 46 | offsets = (-bbox[0], -bbox[1]) 47 | else: 48 | assert ((bbox == bbox_valid).all()) 49 | offsets = (-bbox_valid[0], -bbox_valid[1]) 50 | 51 | # Simple per element addition in the tuple 52 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 53 | 54 | if scikit: 55 | crop_mask = sk_resize(crop_mask, (bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), order=0, 56 | mode='constant').astype(crop_mask.dtype) 57 | else: 58 | crop_mask = cv2.resize(crop_mask, (bbox[2] - bbox[0] + 1, bbox[3] - bbox[1] + 1), interpolation=interpolation) 59 | result_ = np.zeros(im_si) 60 | result_[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] = \ 61 | crop_mask[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] 62 | 63 | result = np.zeros(im_si) 64 | if mask_relax: 65 | result[bbox_init[1]:bbox_init[3] + 1, bbox_init[0]:bbox_init[2] + 1] = \ 66 | result_[bbox_init[1]:bbox_init[3] + 1, bbox_init[0]:bbox_init[2] + 1] 67 | else: 68 | result = result_ 69 | 70 | return result 71 | 72 | 73 | def overlay_mask(im, ma, colors=None, alpha=0.5): 74 | assert np.max(im) <= 1.0 75 | if colors is None: 76 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy')) / 255. 77 | else: 78 | colors = np.append([[0., 0., 0.]], colors, axis=0); 79 | 80 | if ma.ndim == 3: 81 | assert len(colors) >= ma.shape[0], 'Not enough colors' 82 | ma = ma.astype(np.bool) 83 | im = im.astype(np.float32) 84 | 85 | if ma.ndim == 2: 86 | fg = im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1, :3] # np.array([0,0,255])/255.0 87 | else: 88 | fg = [] 89 | for n in range(ma.ndim): 90 | fg.append(im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1 + n, :3]) 91 | # Whiten background 92 | bg = im.copy() 93 | if ma.ndim == 2: 94 | bg[ma == 0] = im[ma == 0] 95 | bg[ma == 1] = fg[ma == 1] 96 | total_ma = ma 97 | else: 98 | total_ma = np.zeros([ma.shape[1], ma.shape[2]]) 99 | for n in range(ma.shape[0]): 100 | tmp_ma = ma[n, :, :] 101 | total_ma = np.logical_or(tmp_ma, total_ma) 102 | tmp_fg = fg[n] 103 | bg[tmp_ma == 1] = tmp_fg[tmp_ma == 1] 104 | bg[total_ma == 0] = im[total_ma == 0] 105 | 106 | # [-2:] is s trick to be compatible both with opencv 2 and 3 107 | contours = cv2.findContours(total_ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 108 | cv2.drawContours(bg, contours[0], -1, (0.0, 0.0, 0.0), 1) 109 | 110 | return bg 111 | 112 | 113 | def overlay_masks(im, masks, alpha=0.5): 114 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy')) / 255. 
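    # Added note (not in the original source): 'im' is expected as a float RGB image and
    # 'masks' as one binary mask per object; each mask is alpha-blended with a colour from
    # the PASCAL palette loaded above, and object contours are drawn on top of the result.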
115 | 116 | if isinstance(masks, np.ndarray): 117 | masks = [masks] 118 | 119 | assert len(colors) >= len(masks), 'Not enough colors' 120 | 121 | ov = im.copy() 122 | im = im.astype(np.float32) 123 | total_ma = np.zeros([im.shape[0], im.shape[1]]) 124 | i = 1 125 | for ma in masks: 126 | ma = ma.astype(np.bool) 127 | fg = im * alpha + np.ones(im.shape) * (1 - alpha) * colors[i, :3] # np.array([0,0,255])/255.0 128 | i = i + 1 129 | ov[ma == 1] = fg[ma == 1] 130 | total_ma += ma 131 | 132 | # [-2:] is s trick to be compatible both with opencv 2 and 3 133 | contours = cv2.findContours(ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 134 | cv2.drawContours(ov, contours[0], -1, (0.0, 0.0, 0.0), 1) 135 | ov[total_ma == 0] = im[total_ma == 0] 136 | 137 | return ov 138 | 139 | 140 | def extreme_points(mask, pert): 141 | def find_point(id_x, id_y, ids): 142 | sel_id = ids[0][random.randint(0, len(ids[0]) - 1)] 143 | return [id_x[sel_id], id_y[sel_id]] 144 | 145 | # List of coordinates of the mask 146 | inds_y, inds_x = np.where(mask > 0.5) 147 | 148 | # Find extreme points 149 | return np.array([find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x) + pert)), # left 150 | find_point(inds_x, inds_y, np.where(inds_x >= np.max(inds_x) - pert)), # right 151 | find_point(inds_x, inds_y, np.where(inds_y <= np.min(inds_y) + pert)), # top 152 | find_point(inds_x, inds_y, np.where(inds_y >= np.max(inds_y) - pert)) # bottom 153 | ]) 154 | 155 | 156 | def get_bbox(mask, points=None, pad=0, zero_pad=False): 157 | if points is not None: 158 | inds = np.flip(points.transpose(), axis=0) 159 | else: 160 | inds = np.where(mask > 0) 161 | 162 | if inds[0].shape[0] == 0: 163 | return None 164 | 165 | if zero_pad: 166 | x_min_bound = -np.inf 167 | y_min_bound = -np.inf 168 | x_max_bound = np.inf 169 | y_max_bound = np.inf 170 | else: 171 | x_min_bound = 0 172 | y_min_bound = 0 173 | x_max_bound = mask.shape[1] - 1 174 | y_max_bound = mask.shape[0] - 1 175 | 176 | x_min = max(inds[1].min() - pad, x_min_bound) 177 | y_min = max(inds[0].min() - pad, y_min_bound) 178 | x_max = min(inds[1].max() + pad, x_max_bound) 179 | y_max = min(inds[0].max() + pad, y_max_bound) 180 | 181 | return x_min, y_min, x_max, y_max 182 | 183 | 184 | def crop_from_bbox(img, bbox, zero_pad=False): 185 | # Borders of image 186 | bounds = (0, 0, img.shape[1] - 1, img.shape[0] - 1) 187 | 188 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 189 | bbox_valid = (max(bbox[0], bounds[0]), 190 | max(bbox[1], bounds[1]), 191 | min(bbox[2], bounds[2]), 192 | min(bbox[3], bounds[3])) 193 | 194 | if zero_pad: 195 | # Initialize crop size (first 2 dimensions) 196 | crop = np.zeros((bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), dtype=img.dtype) 197 | 198 | # Offsets for x and y 199 | offsets = (-bbox[0], -bbox[1]) 200 | 201 | else: 202 | assert (bbox == bbox_valid) 203 | crop = np.zeros((bbox_valid[3] - bbox_valid[1] + 1, bbox_valid[2] - bbox_valid[0] + 1), dtype=img.dtype) 204 | offsets = (-bbox_valid[0], -bbox_valid[1]) 205 | 206 | # Simple per element addition in the tuple 207 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 208 | 209 | img = np.squeeze(img) 210 | if img.ndim == 2: 211 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] = \ 212 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] 213 | else: 214 | crop = np.tile(crop[:, :, np.newaxis], [1, 1, 3]) # Add 3 RGB Channels 215 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1, :] = \ 216 | img[bbox_valid[1]:bbox_valid[3] + 
1, bbox_valid[0]:bbox_valid[2] + 1, :] 217 | 218 | return crop 219 | 220 | 221 | def fixed_resize(sample, resolution, flagval=None): 222 | if flagval is None: 223 | if ((sample == 0) | (sample == 1)).all(): 224 | flagval = cv2.INTER_NEAREST 225 | else: 226 | flagval = cv2.INTER_CUBIC 227 | 228 | if isinstance(resolution, int): 229 | tmp = [resolution, resolution] 230 | tmp[np.argmax(sample.shape[:2])] = int( 231 | round(float(resolution) / np.min(sample.shape[:2]) * np.max(sample.shape[:2]))) 232 | resolution = tuple(tmp) 233 | 234 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3): 235 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval) 236 | else: 237 | tmp = sample 238 | sample = np.zeros(np.append(resolution, tmp.shape[2]), dtype=np.float32) 239 | for ii in range(sample.shape[2]): 240 | sample[:, :, ii] = cv2.resize(tmp[:, :, ii], resolution[::-1], interpolation=flagval) 241 | return sample 242 | 243 | 244 | def crop_from_mask(img, mask, relax=0, zero_pad=False): 245 | if mask.shape[:2] != img.shape[:2]: 246 | mask = cv2.resize(mask, dsize=tuple(reversed(img.shape[:2])), interpolation=cv2.INTER_NEAREST) 247 | 248 | assert (mask.shape[:2] == img.shape[:2]) 249 | 250 | bbox = get_bbox(mask, pad=relax, zero_pad=zero_pad) 251 | 252 | if bbox is None: 253 | return None 254 | 255 | crop = crop_from_bbox(img, bbox, zero_pad) 256 | 257 | return crop 258 | 259 | 260 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64): 261 | """ Make a square gaussian kernel. 262 | size: is the dimensions of the output gaussian 263 | sigma: is full-width-half-maximum, which 264 | can be thought of as an effective radius. 265 | """ 266 | 267 | x = np.arange(0, size[1], 1, float) 268 | y = np.arange(0, size[0], 1, float) 269 | y = y[:, np.newaxis] 270 | 271 | if center is None: 272 | x0 = y0 = size[0] // 2 273 | else: 274 | x0 = center[0] 275 | y0 = center[1] 276 | 277 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type) 278 | 279 | 280 | def make_gt(img, labels, sigma=10, one_mask_per_point=False): 281 | """ Make the ground-truth for landmark. 282 | img: the original color image 283 | labels: label with the Gaussian center(s) [[x0, y0],[x1, y1],...] 284 | sigma: sigma of the Gaussian. 285 | one_mask_per_point: masks for each point in different channels? 
286 | """ 287 | h, w = img.shape[:2] 288 | if labels is None: 289 | gt = make_gaussian((h, w), center=(h // 2, w // 2), sigma=sigma) 290 | else: 291 | labels = np.array(labels) 292 | if labels.ndim == 1: 293 | labels = labels[np.newaxis] 294 | if one_mask_per_point: 295 | gt = np.zeros(shape=(h, w, labels.shape[0])) 296 | for ii in range(labels.shape[0]): 297 | gt[:, :, ii] = make_gaussian((h, w), center=labels[ii, :], sigma=sigma) 298 | else: 299 | gt = np.zeros(shape=(h, w), dtype=np.float64) 300 | for ii in range(labels.shape[0]): 301 | gt = np.maximum(gt, make_gaussian((h, w), center=labels[ii, :], sigma=sigma)) 302 | 303 | gt = gt.astype(dtype=img.dtype) 304 | 305 | return gt 306 | 307 | 308 | def cstm_normalize(im, max_value): 309 | """ 310 | Normalize image to range 0 - max_value 311 | """ 312 | imn = max_value * (im - im.min()) / max((im.max() - im.min()), 1e-8) 313 | return imn 314 | 315 | 316 | def generate_param_report(logfile, param): 317 | log_file = open(logfile, 'w') 318 | for key, val in param.items(): 319 | log_file.write(key + ':' + str(val) + '\n') 320 | log_file.close() 321 | 322 | 323 | def color_map(N=256, normalized=False): 324 | def bitget(byteval, idx): 325 | return ((byteval & (1 << idx)) != 0) 326 | 327 | dtype = 'float32' if normalized else 'uint8' 328 | cmap = np.zeros((N, 3), dtype=dtype) 329 | for i in range(N): 330 | r = g = b = 0 331 | c = i 332 | for j in range(8): 333 | r = r | (bitget(c, 0) << 7 - j) 334 | g = g | (bitget(c, 1) << 7 - j) 335 | b = b | (bitget(c, 2) << 7 - j) 336 | c = c >> 3 337 | 338 | cmap[i] = np.array([r, g, b]) 339 | 340 | cmap = cmap / 255 if normalized else cmap 341 | return cmap 342 | 343 | 344 | def save_mask(results, mask_path): 345 | mask = np.zeros(results[0].shape) 346 | for ii, r in enumerate(results): 347 | mask[r] = ii + 1 348 | result = Image.fromarray(mask.astype(np.uint8)) 349 | result.putpalette(list(color_map(80).flatten())) 350 | result.save(mask_path) 351 | -------------------------------------------------------------------------------- /dataset_loaders/s: -------------------------------------------------------------------------------- 1 | s 2 | -------------------------------------------------------------------------------- /dataset_loaders/ytvos_v2.py: -------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | import tqdm 7 | import numpy as np 8 | from PIL import Image 9 | import torch 10 | import torchvision as tv 11 | from dataset_loaders import custom_transforms as tr 12 | from dataset_loaders import dataset_utils 13 | import utils 14 | 15 | 16 | def get_sample_bernoulli(p): 17 | return (lambda lst: [elem for elem in lst if random.random() < p]) 18 | def get_sample_all(): 19 | return (lambda lst: lst) 20 | def get_sample_k_random(k): 21 | return (lambda lst: sorted(random.sample(lst, min(k,len(lst))))) 22 | 23 | def get_anno_ids(anno_path, pic_to_tensor_function, threshold): 24 | pic = Image.open(anno_path) 25 | tensor = pic_to_tensor_function(pic) 26 | values = (tensor.view(-1).bincount() > threshold).nonzero().view(-1).tolist() 27 | if 0 in values: values.remove(0) 28 | if 255 in values: values.remove(255) 29 | return values 30 | 31 | def get_default_image_read(size=(240,432)): 32 | def image_read(path): 33 | pic = Image.open(path) 34 | transform = tv.transforms.Compose( 35 | [tv.transforms.Resize(size, interpolation=Image.BILINEAR), 36 | tv.transforms.ToTensor(), 37 | 
tv.transforms.Normalize(mean=dataset_utils.IMAGENET_MEAN, std=dataset_utils.IMAGENET_STD)]) 38 | return transform(pic) 39 | return image_read 40 | def get_default_anno_read(size=(240,432)): 41 | def label_read(path): 42 | if os.path.exists(path): 43 | pic = Image.open(path) 44 | transform = tv.transforms.Compose( 45 | [tv.transforms.Resize(size, interpolation=Image.NEAREST), 46 | dataset_utils.LabelToLongTensor()]) 47 | label = transform(pic) 48 | else: 49 | label = torch.LongTensor(1,*size).fill_(255) # Put label that will be ignored 50 | return label 51 | return label_read 52 | 53 | class YTVOSV2(torch.utils.data.Dataset): 54 | def __init__(self, root_path, split, image_set, impath='JPEGImages', image_read=get_default_image_read(), 55 | anno_read=get_default_anno_read(),image_label_read=None, joint_transform=None, 56 | samplelen=4, obj_selection=get_sample_all(), min_num_obj=1, start_frame='random', max_skip=None): 57 | self._min_num_objects = min_num_obj 58 | self._root_path = root_path 59 | self._split = split 60 | self._image_set = image_set 61 | self._impath = impath 62 | self._image_read = image_read 63 | self._anno_read = anno_read 64 | self._image_anno_read = image_label_read 65 | self._joint_transform = joint_transform 66 | self.transform = None 67 | self._seqlen = samplelen 68 | self._obj_selection = obj_selection 69 | self._start_frame = start_frame 70 | assert ((image_set in ('train', 'train_joakim', 'val_joakim') and split == 'train') 71 | or (image_set is None and split == 'valid')) 72 | assert impath in ('JPEGImages', 'JPEGImages_all_frames') 73 | # assert samplelen > 1, "samplelen must be at least 2" 74 | assert start_frame in ('random','first') 75 | 76 | self._init_data() 77 | self._max_skip = max_skip 78 | 79 | def _init_data(self): 80 | """ Store some metadata that needs to be known during training. In order to sample, the viable sequences 81 | must be known. Sequences are viable if a snippet of given sample length can be selected, starting with 82 | an annotated frame and containing at least one more annotated frame. 
83 | """ 84 | print("-- YTVOS dataset initialization started.") 85 | framework_path = os.path.join(os.path.dirname(__file__), '..') 86 | cache_path = os.path.join(framework_path, 'cache', 'ytvos_v2_{}_100px_threshold.json'.format(self._split)) 87 | 88 | # First find visible objects in all annotated frames 89 | if os.path.exists(cache_path): 90 | with open(cache_path, 'r') as f: 91 | self._visible_objects, self._resolutions = json.load(f) 92 | self._visible_objects = {seqname: OrderedDict((int(idx), objlst) for idx, objlst in val.items()) 93 | for seqname, val in self._visible_objects.items()} 94 | assert len(self._visible_objects) == len(self._resolutions) 95 | print("Datafile {} loaded, describing {} sequences.".format(cache_path, len(self._visible_objects))) 96 | else: 97 | # Grab all sequences in dataset 98 | seqnames = os.listdir(os.path.join(self._root_path, self._split, self._impath)) 99 | 100 | # Construct meta-info 101 | self._visible_objects = {} 102 | self._resolutions = {} 103 | for seqname in tqdm.tqdm(seqnames): 104 | anno_paths = sorted(glob.glob(self._full_anno_path(seqname, "*.png"))) 105 | self._visible_objects[seqname] = OrderedDict( 106 | (self._frame_name_to_idx(os.path.basename(path)), 107 | get_anno_ids(path, dataset_utils.LabelToLongTensor(), 100)) 108 | for path in anno_paths) 109 | self._resolutions[seqname] = Image.open(anno_paths[0]).size[::-1] 110 | 111 | # Save meta-info 112 | if not os.path.exists(os.path.dirname(cache_path)): 113 | os.makedirs(os.path.dirname(cache_path)) 114 | with open(cache_path, 'w') as f: 115 | json.dump((self._visible_objects, self._resolutions), f) 116 | print("Datafile {} was not found, creating it with {} sequences.".format( 117 | cache_path, len(self._visible_objects))) 118 | 119 | # Find sequences in the requested image_set 120 | if self._split == 'train': 121 | with open(os.path.join(self._root_path, "ImageSets", self._image_set + '.txt'), 'r') as f: 122 | self._all_seqs = f.read().splitlines() 123 | print("{} sequences found in image set \"{}\"".format(len(self._all_seqs), self._image_set)) 124 | else: 125 | self._all_seqs = os.listdir(os.path.join(self._root_path, self._split, "Annotations")) 126 | print("{} sequences found in the Annotations directory.".format(len(self._all_seqs))) 127 | 128 | # Filter out sequences that are too short from first frame with object, to last annotation 129 | self._nonempty_frame_ids = {seq: [frame_idx for frame_idx, obj_ids in lst.items() if len(obj_ids) >= self._min_num_objects] 130 | for seq, lst in self._visible_objects.items()} 131 | #self._viable_seqs = [seq for seq in self._all_seqs if 132 | # len(self._nonempty_frame_ids[seq]) > 0 133 | # and len(self.get_image_frame_ids(seq)[min(self._nonempty_frame_ids[seq]) : 134 | # max(self._visible_objects[seq].keys()) + 1]) 135 | # >= self._seqlen] 136 | self._viable_seqs = [seq for seq in self._all_seqs if 137 | len(self._nonempty_frame_ids[seq]) > 0 138 | and len(self.get_image_frame_ids(seq)[ 139 | self.get_image_frame_ids(seq).index(min(self._nonempty_frame_ids[seq])): 140 | self.get_image_frame_ids(seq).index(max(self._visible_objects[seq].keys())) + 1]) 141 | >= self._seqlen] 142 | print("{} sequences remaining after filtering on length (from first anno obj appearance to last anno frame.".format(len(self._viable_seqs))) 143 | 144 | # Filter out sequences with wrong resolution 145 | self._viable_seqs = [seq for seq in self._viable_seqs if tuple(self._resolutions[seq]) == (720,1280)] 146 | print("{} sequences remaining after filtering out 
sequences that are not in 720p.".format( 147 | len(self._viable_seqs))) 148 | 149 | def __len__(self): 150 | return len(self._viable_seqs) 151 | 152 | def _frame_idx_to_image_fname(self, idx): 153 | return "{:05d}.jpg".format(idx) 154 | 155 | def _frame_idx_to_anno_fname(self, idx): 156 | return "{:05d}.png".format(idx) 157 | 158 | def _frame_name_to_idx(self, fname): 159 | return int(os.path.splitext(fname)[0]) 160 | 161 | def get_viable_seqnames(self): 162 | return self._viable_seqs 163 | 164 | def get_all_seqnames(self): 165 | return self._all_seqs 166 | 167 | def get_anno_frame_names(self, seqname): 168 | return os.listdir(os.path.join(self._root_path, self._split, "Annotations", seqname)) 169 | 170 | def get_anno_frame_ids(self, seqname): 171 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_anno_frame_names(seqname)]) 172 | 173 | def get_image_frame_names(self, seqname): 174 | return os.listdir(os.path.join(self._root_path, self._split, self._impath, seqname)) 175 | 176 | def get_image_frame_ids(self, seqname): 177 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 178 | 179 | def get_frame_ids(self, seqname): 180 | """ Returns ids of all images that have idx higher than or equal to the first annotated frame""" 181 | all_frame_ids = sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 182 | min_anno_idx = min(self.get_anno_frame_ids(seqname)) 183 | frame_ids = [idx for idx in all_frame_ids if idx >= min_anno_idx] 184 | return frame_ids 185 | 186 | def get_nonempty_frame_ids(self, seqname): 187 | return self._nonempty_frame_ids[seqname] 188 | 189 | def _full_image_path(self, seqname, image): 190 | if isinstance(image, int): 191 | image = self._frame_idx_to_image_fname(image) 192 | return os.path.join(self._root_path, self._split, self._impath, seqname, image) 193 | 194 | def _full_anno_path(self, seqname, anno): 195 | if isinstance(anno, int): 196 | anno = self._frame_idx_to_anno_fname(anno) 197 | return os.path.join(self._root_path, self._split, "Annotations", seqname, anno) 198 | 199 | def _select_frame_ids(self, frame_ids, viable_starting_frame_ids): 200 | if self._start_frame == 'first': 201 | frame_idxidx = frame_ids.index(viable_starting_frame_ids[0]) 202 | elif self._start_frame == 'random': 203 | frame_idxidx = frame_ids.index(random.choice(viable_starting_frame_ids)) 204 | 205 | if self._max_skip is None: 206 | return frame_ids[frame_idxidx: frame_idxidx + self._seqlen] 207 | else: 208 | frame_ids_select = [] 209 | skip = random.randint(1, self._max_skip) 210 | sum_skip = skip * (self._seqlen - 1) 211 | if (sum_skip + self._seqlen) > (len(frame_ids) - frame_idxidx): 212 | skip = int((len(frame_ids) - frame_idxidx - self._seqlen) / (self._seqlen - 1)) 213 | idx_offset = 0 214 | for i in range(self._seqlen): 215 | frame_ids_select.append(frame_ids[frame_idxidx + idx_offset]) 216 | idx_offset = idx_offset + skip + 1 217 | return frame_ids_select 218 | 219 | def _select_object_ids(self, labels): 220 | assert labels.min() > -1 and labels.max() < 256, "{}".format(utils.print_tensor_statistics(labels)) 221 | possible_obj_ids = (labels[0].view(-1).bincount() > 10).nonzero().view(-1).tolist() 222 | if 0 in possible_obj_ids: possible_obj_ids.remove(0) 223 | if 255 in possible_obj_ids: possible_obj_ids.remove(255) 224 | #assert len(possible_obj_ids) > 0, "{}".format(labels[0].view(-1).bincount()) 225 | 226 | obj_ids = self._obj_selection(possible_obj_ids) 227 | bg_ids = 
(labels.view(-1).bincount() > 0).nonzero().view(-1).tolist() 228 | if 0 in bg_ids: bg_ids.remove(0) 229 | if 255 in bg_ids: bg_ids.remove(255) 230 | for idx in obj_ids: 231 | bg_ids.remove(idx) 232 | 233 | obj_ids.sort() 234 | for idx in bg_ids: 235 | labels[labels == idx] = 0 236 | for new_idx, old_idx in zip(range(1,len(obj_ids)+1), obj_ids): 237 | labels[labels == old_idx] = new_idx 238 | 239 | return labels 240 | 241 | def __getitem__(self, idx): 242 | """ 243 | returns: 244 | dict (Tensors): contains 'images', 'given_segmentations', 'labels' 245 | """ 246 | seqname = self.get_viable_seqnames()[idx] 247 | 248 | # We require to begin with a nonempty frame, and will consider all objects in that frame to be tracked. 249 | # A starting frame is valid if it is followed by seqlen-1 frames with corresp images 250 | frame_ids = self.get_frame_ids(seqname) 251 | viable_starting_frame_ids = [idx for idx in self.get_nonempty_frame_ids(seqname) 252 | if idx <= frame_ids[-self._seqlen]] 253 | 254 | frame_ids = self._select_frame_ids(frame_ids, viable_starting_frame_ids) 255 | if self.transform is not None: 256 | images = [] 257 | segannos = [] 258 | 259 | for idx in frame_ids: 260 | my_image = np.array(Image.open(self._full_image_path(seqname, idx)).convert('RGB')) / 255. 261 | # self._image_read(self._full_image_path(seqname, idx)) 262 | my_gt = np.array(Image.open(self._full_anno_path(seqname, idx))) 263 | #print('youtube range:', np.max(my_image), np.min(my_image), np.max(my_gt), np.min(my_gt)) 264 | #my_image = self._image_read(self._full_image_path(seqname, idx)) 265 | #my_gt = self._anno_read(self._full_anno_path(seqname, idx)) 266 | sample_copy = {'image': np.copy(my_image), 'gt': np.copy(my_gt)} 267 | sample_transformed = self.transform(sample_copy) 268 | images.append(sample_transformed['crop_image']) 269 | segannos.append(sample_transformed['crop_gt']) 270 | images = torch.stack(images, dim=0).float().clamp(0, 1) 271 | segannos = torch.stack(segannos, dim=0).long() 272 | 273 | if self._joint_transform is not None: 274 | #images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 275 | # for idx in frame_ids]) 276 | #segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 277 | # for idx in frame_ids]) 278 | images = [] 279 | segannos = [] 280 | #print('training samples:', seqname, frame_ids) 281 | for idx in frame_ids: 282 | temp1, temp2 = self._image_anno_read(self._full_image_path(seqname, idx), 283 | self._full_anno_path(seqname, idx)) 284 | images.append(temp1) 285 | segannos.append(temp2) 286 | images = torch.stack(images, dim=0).float().clamp(0, 1) 287 | segannos = torch.stack(segannos, dim=0) 288 | 289 | try: 290 | segannos = self._select_object_ids(segannos) 291 | except: 292 | print(seqname) 293 | print("frame ids ", self.get_frame_ids(seqname)) 294 | print("frame ids post filtering ", frame_ids) 295 | print("viable starting frame ids", viable_starting_frame_ids) 296 | print("visible objects", self._visible_objects[seqname]) 297 | raise 298 | if self._joint_transform is not None: 299 | images, segannos = self._joint_transform(images, segannos) 300 | segannos[segannos == 255] = 0 301 | given_seganno = segannos[0] 302 | provides_seganno = torch.empty((self._seqlen),dtype=torch.uint8).fill_(True) 303 | num_objects = int(segannos.max()) 304 | #print('youtube:', num_objects) 305 | 306 | return {'images':images, 'provides_seganno': provides_seganno, 'given_seganno':given_seganno, 'segannos':segannos, 'seqname':seqname, 'num_objects':num_objects} 307 
| 308 | def _get_snippet(self, seqname, frame_ids): 309 | images = torch.stack( 310 | [self._image_read(self._full_image_path(seqname, idx)) for idx in frame_ids]).unsqueeze(0) 311 | if self._split == 'valid': 312 | segannos = None 313 | anno_frame_ids = self.get_anno_frame_ids(seqname) 314 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 315 | if idx in anno_frame_ids else None for idx in frame_ids] 316 | else: 317 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 318 | for idx in frame_ids]).squeeze().unsqueeze(0) 319 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 320 | if idx == self.get_anno_frame_ids(seqname)[0] else None for idx in frame_ids] 321 | for i in range(len(given_segannos)): # Remove dont-care from given segannos 322 | if given_segannos[i] is not None: 323 | given_segannos[i][given_segannos[i] == 255] = 0 324 | 325 | fnames = [self._frame_idx_to_anno_fname(idx) for idx in frame_ids] 326 | return {'images':images, 'given_segannos': given_segannos, 'segannos':segannos, 'fnames':fnames} 327 | 328 | def _get_video(self, seqname): 329 | seq_frame_ids = self.get_frame_ids(seqname) 330 | partitioned_frame_ids = [seq_frame_ids[start_idx : start_idx + self._seqlen] 331 | for start_idx in range(0, len(seq_frame_ids), self._seqlen)] 332 | for frame_ids in partitioned_frame_ids: 333 | yield self._get_snippet(seqname, frame_ids) 334 | 335 | def get_video_generator(self, low=0, high=2**31): 336 | """Returns a video generator. The video generator is used to obtain parts of a sequence. Some assumptions are made, depending on whether the train or valid splits are used. For the train split, the first annotated frame is given. No other annotation is used. For the validation split, each annotation found is given. 
337 | """ 338 | sequences = self.get_all_seqnames()[low:high] 339 | # NO LONGER NEEDED, now only frame ids coming after an annotated frame are utilized 340 | if self._split == 'train': # These sequences are Empty in the first frame 341 | sequences.remove('d6917db4be') 342 | sequences.remove('d0c65e9e95') 343 | sequences.remove('c130c3fc0c') 344 | for seqname in sequences: 345 | yield (seqname, self._get_video(seqname)) 346 | 347 | 348 | -------------------------------------------------------------------------------- /dataset_loaders/ytvos_v2_org.py: -------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | import tqdm 7 | 8 | from PIL import Image 9 | import torch 10 | import torchvision as tv 11 | 12 | from dataset_loaders import dataset_utils 13 | import utils 14 | 15 | 16 | def get_sample_bernoulli(p): 17 | return (lambda lst: [elem for elem in lst if random.random() < p]) 18 | def get_sample_all(): 19 | return (lambda lst: lst) 20 | def get_sample_k_random(k): 21 | return (lambda lst: sorted(random.sample(lst, min(k,len(lst))))) 22 | 23 | def get_anno_ids(anno_path, pic_to_tensor_function, threshold): 24 | pic = Image.open(anno_path) 25 | tensor = pic_to_tensor_function(pic) 26 | values = (tensor.view(-1).bincount() > threshold).nonzero().view(-1).tolist() 27 | if 0 in values: values.remove(0) 28 | if 255 in values: values.remove(255) 29 | return values 30 | 31 | def get_default_image_read(size=(240,432)): 32 | def image_read(path): 33 | pic = Image.open(path) 34 | transform = tv.transforms.Compose( 35 | [tv.transforms.Resize(size, interpolation=Image.BILINEAR), 36 | tv.transforms.ToTensor(), 37 | tv.transforms.Normalize(mean=dataset_utils.IMAGENET_MEAN, std=dataset_utils.IMAGENET_STD)]) 38 | return transform(pic) 39 | return image_read 40 | def get_default_anno_read(size=(240,432)): 41 | def label_read(path): 42 | if os.path.exists(path): 43 | pic = Image.open(path) 44 | transform = tv.transforms.Compose( 45 | [tv.transforms.Resize(size, interpolation=Image.NEAREST), 46 | dataset_utils.LabelToLongTensor()]) 47 | label = transform(pic) 48 | else: 49 | label = torch.LongTensor(1,*size).fill_(255) # Put label that will be ignored 50 | return label 51 | return label_read 52 | 53 | class YTVOSV2(torch.utils.data.Dataset): 54 | def __init__(self, root_path, split, image_set, impath='JPEGImages', image_read=get_default_image_read(), 55 | anno_read=get_default_anno_read(), joint_transform=None, 56 | samplelen=4, obj_selection=get_sample_all(), min_num_obj=1, start_frame='random', max_skip=None): 57 | self._min_num_objects = min_num_obj 58 | self._root_path = root_path 59 | self._split = split 60 | self._image_set = image_set 61 | self._impath = impath 62 | self._image_read = image_read 63 | self._anno_read = anno_read 64 | self._joint_transform = joint_transform 65 | self._seqlen = samplelen 66 | self._obj_selection = obj_selection 67 | self._start_frame = start_frame 68 | assert ((image_set in ('train', 'train_joakim', 'val_joakim') and split == 'train') 69 | or (image_set is None and split == 'valid')) 70 | assert impath in ('JPEGImages', 'JPEGImages_all_frames') 71 | # assert samplelen > 1, "samplelen must be at least 2" 72 | assert start_frame in ('random','first') 73 | 74 | self._init_data() 75 | self._max_skip = max_skip 76 | 77 | def _init_data(self): 78 | """ Store some metadata that needs to be known during training. 
In order to sample, the viable sequences 79 | must be known. Sequences are viable if a snippet of given sample length can be selected, starting with 80 | an annotated frame and containing at least one more annotated frame. 81 | """ 82 | print("-- YTVOS dataset initialization started.") 83 | framework_path = os.path.join(os.path.dirname(__file__), '..') 84 | cache_path = os.path.join(framework_path, 'cache', 'ytvos_v2_{}_100px_threshold.json'.format(self._split)) 85 | 86 | # First find visible objects in all annotated frames 87 | if os.path.exists(cache_path): 88 | with open(cache_path, 'r') as f: 89 | self._visible_objects, self._resolutions = json.load(f) 90 | self._visible_objects = {seqname: OrderedDict((int(idx), objlst) for idx, objlst in val.items()) 91 | for seqname, val in self._visible_objects.items()} 92 | assert len(self._visible_objects) == len(self._resolutions) 93 | print("Datafile {} loaded, describing {} sequences.".format(cache_path, len(self._visible_objects))) 94 | else: 95 | # Grab all sequences in dataset 96 | seqnames = os.listdir(os.path.join(self._root_path, self._split, self._impath)) 97 | 98 | # Construct meta-info 99 | self._visible_objects = {} 100 | self._resolutions = {} 101 | for seqname in tqdm.tqdm(seqnames): 102 | anno_paths = sorted(glob.glob(self._full_anno_path(seqname, "*.png"))) 103 | self._visible_objects[seqname] = OrderedDict( 104 | (self._frame_name_to_idx(os.path.basename(path)), 105 | get_anno_ids(path, dataset_utils.LabelToLongTensor(), 100)) 106 | for path in anno_paths) 107 | self._resolutions[seqname] = Image.open(anno_paths[0]).size[::-1] 108 | 109 | # Save meta-info 110 | if not os.path.exists(os.path.dirname(cache_path)): 111 | os.makedirs(os.path.dirname(cache_path)) 112 | with open(cache_path, 'w') as f: 113 | json.dump((self._visible_objects, self._resolutions), f) 114 | print("Datafile {} was not found, creating it with {} sequences.".format( 115 | cache_path, len(self._visible_objects))) 116 | 117 | # Find sequences in the requested image_set 118 | if self._split == 'train': 119 | with open(os.path.join(self._root_path, "ImageSets", self._image_set + '.txt'), 'r') as f: 120 | self._all_seqs = f.read().splitlines() 121 | print("{} sequences found in image set \"{}\"".format(len(self._all_seqs), self._image_set)) 122 | else: 123 | self._all_seqs = os.listdir(os.path.join(self._root_path, self._split, "Annotations")) 124 | print("{} sequences found in the Annotations directory.".format(len(self._all_seqs))) 125 | 126 | # Filter out sequences that are too short from first frame with object, to last annotation 127 | self._nonempty_frame_ids = {seq: [frame_idx for frame_idx, obj_ids in lst.items() if len(obj_ids) >= self._min_num_objects] 128 | for seq, lst in self._visible_objects.items()} 129 | self._viable_seqs = [seq for seq in self._all_seqs if 130 | len(self._nonempty_frame_ids[seq]) > 0 131 | and len(self.get_image_frame_ids(seq)[min(self._nonempty_frame_ids[seq]) : 132 | max(self._visible_objects[seq].keys()) + 1]) 133 | >= self._seqlen] 134 | print("{} sequences remaining after filtering on length (from first anno obj appearance to last anno frame.".format(len(self._viable_seqs))) 135 | 136 | # Filter out sequences with wrong resolution 137 | self._viable_seqs = [seq for seq in self._viable_seqs if tuple(self._resolutions[seq]) == (720,1280)] 138 | print("{} sequences remaining after filtering out sequences that are not in 720p.".format( 139 | len(self._viable_seqs))) 140 | 141 | def __len__(self): 142 | return 
len(self._viable_seqs) 143 | 144 | def _frame_idx_to_image_fname(self, idx): 145 | return "{:05d}.jpg".format(idx) 146 | 147 | def _frame_idx_to_anno_fname(self, idx): 148 | return "{:05d}.png".format(idx) 149 | 150 | def _frame_name_to_idx(self, fname): 151 | return int(os.path.splitext(fname)[0]) 152 | 153 | def get_viable_seqnames(self): 154 | return self._viable_seqs 155 | 156 | def get_all_seqnames(self): 157 | return self._all_seqs 158 | 159 | def get_anno_frame_names(self, seqname): 160 | return os.listdir(os.path.join(self._root_path, self._split, "Annotations", seqname)) 161 | 162 | def get_anno_frame_ids(self, seqname): 163 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_anno_frame_names(seqname)]) 164 | 165 | def get_image_frame_names(self, seqname): 166 | return os.listdir(os.path.join(self._root_path, self._split, self._impath, seqname)) 167 | 168 | def get_image_frame_ids(self, seqname): 169 | return sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 170 | 171 | def get_frame_ids(self, seqname): 172 | """ Returns ids of all images that have idx higher than or equal to the first annotated frame""" 173 | all_frame_ids = sorted([self._frame_name_to_idx(fname) for fname in self.get_image_frame_names(seqname)]) 174 | min_anno_idx = min(self.get_anno_frame_ids(seqname)) 175 | frame_ids = [idx for idx in all_frame_ids if idx >= min_anno_idx] 176 | return frame_ids 177 | 178 | def get_nonempty_frame_ids(self, seqname): 179 | return self._nonempty_frame_ids[seqname] 180 | 181 | def _full_image_path(self, seqname, image): 182 | if isinstance(image, int): 183 | image = self._frame_idx_to_image_fname(image) 184 | return os.path.join(self._root_path, self._split, self._impath, seqname, image) 185 | 186 | def _full_anno_path(self, seqname, anno): 187 | if isinstance(anno, int): 188 | anno = self._frame_idx_to_anno_fname(anno) 189 | return os.path.join(self._root_path, self._split, "Annotations", seqname, anno) 190 | 191 | def _select_frame_ids(self, frame_ids, viable_starting_frame_ids): 192 | if self._start_frame == 'first': 193 | frame_idxidx = frame_ids.index(viable_starting_frame_ids[0]) 194 | elif self._start_frame == 'random': 195 | frame_idxidx = frame_ids.index(random.choice(viable_starting_frame_ids)) 196 | 197 | if self._max_skip is None: 198 | return frame_ids[frame_idxidx: frame_idxidx + self._seqlen] 199 | else: 200 | frame_ids_select = [] 201 | skip = random.randint(1, self._max_skip) 202 | sum_skip = skip * (self._seqlen - 1) 203 | if (sum_skip + self._seqlen) > (len(frame_ids) - frame_idxidx): 204 | skip = int((len(frame_ids) - frame_idxidx - self._seqlen) / (self._seqlen - 1)) 205 | idx_offset = 0 206 | for i in range(self._seqlen): 207 | frame_ids_select.append(frame_ids[frame_idxidx + idx_offset]) 208 | idx_offset = idx_offset + skip + 1 209 | return frame_ids_select 210 | 211 | def _select_object_ids(self, labels): 212 | assert labels.min() > -1 and labels.max() < 256, "{}".format(utils.print_tensor_statistics(labels)) 213 | possible_obj_ids = (labels[0].view(-1).bincount() > 10).nonzero().view(-1).tolist() 214 | if 0 in possible_obj_ids: possible_obj_ids.remove(0) 215 | if 255 in possible_obj_ids: possible_obj_ids.remove(255) 216 | assert len(possible_obj_ids) > 0, "{}".format(labels[0].view(-1).bincount()) 217 | 218 | obj_ids = self._obj_selection(possible_obj_ids) 219 | bg_ids = (labels.view(-1).bincount() > 0).nonzero().view(-1).tolist() 220 | if 0 in bg_ids: bg_ids.remove(0) 221 | if 255 in bg_ids: 
bg_ids.remove(255) 222 | for idx in obj_ids: 223 | bg_ids.remove(idx) 224 | 225 | for idx in bg_ids: 226 | labels[labels == idx] = 0 227 | for new_idx, old_idx in zip(range(1,len(obj_ids)+1), obj_ids): 228 | labels[labels == old_idx] = new_idx 229 | 230 | return labels 231 | 232 | def __getitem__(self, idx): 233 | """ 234 | returns: 235 | dict (Tensors): contains 'images', 'given_segmentations', 'labels' 236 | """ 237 | seqname = self.get_viable_seqnames()[idx] 238 | 239 | # We require to begin with a nonempty frame, and will consider all objects in that frame to be tracked. 240 | # A starting frame is valid if it is followed by seqlen-1 frames with corresp images 241 | frame_ids = self.get_frame_ids(seqname) 242 | viable_starting_frame_ids = [idx for idx in self.get_nonempty_frame_ids(seqname) 243 | if idx <= frame_ids[-self._seqlen]] 244 | 245 | frame_ids = self._select_frame_ids(frame_ids, viable_starting_frame_ids) 246 | 247 | images = torch.stack([self._image_read(self._full_image_path(seqname, idx)) 248 | for idx in frame_ids]) 249 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 250 | for idx in frame_ids]) 251 | 252 | try: 253 | segannos = self._select_object_ids(segannos) 254 | except: 255 | print(seqname) 256 | print("frame ids ", self.get_frame_ids(seqname)) 257 | print("frame ids post filtering ", frame_ids) 258 | print("viable starting frame ids", viable_starting_frame_ids) 259 | print("visible objects", self._visible_objects[seqname]) 260 | raise 261 | if self._joint_transform is not None: 262 | images, segannos = self._joint_transform(images, segannos) 263 | segannos[segannos == 255] = 0 264 | given_seganno = segannos[0] 265 | provides_seganno = torch.empty((self._seqlen),dtype=torch.uint8).fill_(True) 266 | num_objects = int(segannos.max()) 267 | 268 | return {'images':images, 'provides_seganno': provides_seganno, 'given_seganno':given_seganno, 'segannos':segannos, 'seqname':seqname, 'num_objects':num_objects} 269 | 270 | def _get_snippet(self, seqname, frame_ids): 271 | images = torch.stack( 272 | [self._image_read(self._full_image_path(seqname, idx)) for idx in frame_ids]).unsqueeze(0) 273 | if self._split == 'valid': 274 | segannos = None 275 | anno_frame_ids = self.get_anno_frame_ids(seqname) 276 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 277 | if idx in anno_frame_ids else None for idx in frame_ids] 278 | else: 279 | segannos = torch.stack([self._anno_read(self._full_anno_path(seqname, idx)) 280 | for idx in frame_ids]).squeeze().unsqueeze(0) 281 | given_segannos = [self._anno_read(self._full_anno_path(seqname, idx)).unsqueeze(0) 282 | if idx == self.get_anno_frame_ids(seqname)[0] else None for idx in frame_ids] 283 | for i in range(len(given_segannos)): # Remove dont-care from given segannos 284 | if given_segannos[i] is not None: 285 | given_segannos[i][given_segannos[i] == 255] = 0 286 | 287 | fnames = [self._frame_idx_to_anno_fname(idx) for idx in frame_ids] 288 | return {'images':images, 'given_segannos': given_segannos, 'segannos':segannos, 'fnames':fnames} 289 | 290 | def _get_video(self, seqname): 291 | seq_frame_ids = self.get_frame_ids(seqname) 292 | partitioned_frame_ids = [seq_frame_ids[start_idx : start_idx + self._seqlen] 293 | for start_idx in range(0, len(seq_frame_ids), self._seqlen)] 294 | for frame_ids in partitioned_frame_ids: 295 | yield self._get_snippet(seqname, frame_ids) 296 | 297 | def get_video_generator(self, low=0, high=2**31): 298 | """Returns a video generator. 
The video generator is used to obtain parts of a sequence. Some assumptions are made, depending on whether the train or valid splits are used. For the train split, the first annotated frame is given. No other annotation is used. For the validation split, each annotation found is given. 299 | """ 300 | sequences = self.get_all_seqnames()[low:high] 301 | # NO LONGER NEEDED, now only frame ids coming after an annotated frame are utilized 302 | if self._split == 'train': # These sequences are Empty in the first frame 303 | sequences.remove('d6917db4be') 304 | sequences.remove('d0c65e9e95') 305 | sequences.remove('c130c3fc0c') 306 | for seqname in sequences: 307 | yield (seqname, self._get_video(seqname)) 308 | 309 | 310 | -------------------------------------------------------------------------------- /eccv-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carrierlxk/GraphMemVOS/a3487bcf67ce479c3774890403362170f275095f/eccv-framework.png -------------------------------------------------------------------------------- /local_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Put these paths to point to your datasets 4 | config = { 5 | 'workspace_path' : "/home/ubuntu/xiankai/STM_train_V1/", 6 | 'davis_path' : "/home/ubuntu/xiankai/dataset/DAVIS-2016/", 7 | 'davis16_path' : "/fastdata/davis/2016/", 8 | 'davis17_path' : "/raid/DAVIS/DAVIS-2017/DAVIS-train-val/", 9 | 'ytvos_path' : "/raid/Youtube-VOS-2019/", 10 | 'output_path' : "/raid/STM_train_V1/output/", 11 | 'nn_weights_path' : "/home/cgv841/gwb/Code/agame-vos-master/" 12 | } 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | from .graph_memory import graph_memory 3 | 4 | -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | from local_config import config 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 11 | 'resnet101_block34', 'resnet101_block14', 'resnet101s16', 'resnet50s16', 'resnet101s16v2'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, 
self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000): 103 | self.inplanes = 64 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | self.avgpool = nn.AvgPool2d(7, stride=1) 115 | self.fc = nn.Linear(512 * block.expansion, num_classes) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 120 | nn.init.kaiming_normal(m.weight, mode='fan_out') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | # nn.init.constant_(m.weight, 1) 123 | # nn.init.constant_(m.bias, 0) 124 | nn.init.constant(m.weight, 1) 125 | nn.init.constant(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.Conv2d(self.inplanes, planes * block.expansion, 132 | kernel_size=1, stride=stride, bias=False), 133 | nn.BatchNorm2d(planes * block.expansion), 134 | ) 135 | 136 | layers = [] 137 | layers.append(block(self.inplanes, planes, stride, downsample)) 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append(block(self.inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def get_features(self, 
x): 145 | feats = [] 146 | x = self.conv1(x) 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | feats += [x] 150 | x = self.maxpool(x) 151 | 152 | x = self.layer1(x) 153 | feats += [x] 154 | x = self.layer2(x) 155 | feats += [x] 156 | x = self.layer3(x) 157 | feats += [x] 158 | x = self.layer4(x) 159 | feats += [x] 160 | 161 | return feats 162 | 163 | 164 | def forward(self, x): 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | x = self.relu(x) 168 | x = self.maxpool(x) 169 | 170 | x = self.layer1(x) 171 | x = self.layer2(x) 172 | x = self.layer3(x) 173 | x = self.layer4(x) 174 | 175 | x = self.avgpool(x) 176 | x = x.view(x.size(0), -1) 177 | x = self.fc(x) 178 | 179 | return x 180 | 181 | class ResNet101Block34(ResNet): 182 | def get_features(self, x): 183 | feats = [] 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | feats.append(x) 188 | x = self.maxpool(x) 189 | 190 | x = self.layer1(x) 191 | feats.append(x) 192 | x = self.layer2(x) 193 | feats.append(x) 194 | x = self.layer3(x) 195 | feats.append(x) 196 | x = self.layer4(x) 197 | feats.append(torch.cat([x, feats[-1]], dim=-3)) 198 | 199 | return feats 200 | 201 | class ResNet101Block14(ResNet): 202 | def get_features(self, x): 203 | feats = [] 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | feats.append(x) 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | feats.append(x) 212 | x = self.layer2(x) 213 | feats.append(x) 214 | x = self.layer3(x) 215 | feats.append(x) 216 | x = self.layer4(x) 217 | feats.append(torch.cat([x, F.avg_pool2d(feats[1], 4)], dim=-3)) 218 | 219 | return feats 220 | 221 | class ResNetS16(ResNet): 222 | def __init__(self, finetune_layers, s16_feats, s8_feats, s4_feats, block, layers, num_classes=1000): 223 | super().__init__(block, layers, num_classes) 224 | self.finetune_layers = finetune_layers 225 | self.s16_feats = s16_feats 226 | self.s8_feats = s8_feats 227 | self.s4_feats = s4_feats 228 | 229 | # Set strided convolutions, deeplab-style 230 | self.layer4[0].downsample[0].stride = (1,1) 231 | self.layer4[0].conv2.stride = (1,1) 232 | for layer in self.layer4[1:]: 233 | layer.conv2.dilation = (2,2) 234 | layer.conv2.padding = (2,2) 235 | 236 | # Make only part of the feature extractor trainable 237 | self.requires_grad = False 238 | for param in self.parameters(): 239 | param.requires_grad = False 240 | for module_name in finetune_layers: 241 | getattr(self, module_name).train(True) 242 | getattr(self, module_name).requires_grad = True 243 | for param in getattr(self, module_name).parameters(): 244 | param.requires_grad = True 245 | 246 | def get_return_values(self, feats): 247 | return {'s16': torch.cat([feats[name] for name in self.s16_feats], dim=-1), 248 | 's8': torch.cat([feats[name] for name in self.s8_feats], dim=-1), 249 | 's4': torch.cat([feats[name] for name in self.s4_feats], dim=-1)} 250 | 251 | def get_features(self, x): 252 | feats = {} 253 | x = self.conv1(x) 254 | x = self.bn1(x) 255 | x = self.relu(x) 256 | feats['conv1'] = x 257 | x = self.maxpool(x) 258 | x = self.layer1(x) 259 | feats['layer1'] = x 260 | x = self.layer2(x) 261 | feats['layer2'] = x 262 | x = self.layer3(x) 263 | feats['layer3'] = x 264 | x = self.layer4(x) 265 | feats['layer4'] = x 266 | return self.get_return_values(feats) 267 | 268 | def train(self, mode): 269 | for name, module in self.named_children(): 270 | if name in self.finetune_layers: 271 | module.train(mode) 272 | else: 273 | module.train(False) # Frozen layers are never to be trained 274 | def 
eval(self): 275 | self.train(False) 276 | 277 | class ResNetS16V2(ResNetS16): 278 | def get_return_values(self, feats): 279 | return {'s16': feats['layer4'], 'layer3': feats['layer3'], 'layer2': feats['layer2'], 'layer1': feats['layer1'], 'conv1': feats['conv1']} 280 | 281 | def resnet18(pretrained=False, **kwargs): 282 | """Constructs a ResNet-18 model. 283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | """ 287 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 288 | if pretrained: 289 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir=config['nn_weights_path'])) 290 | return model 291 | 292 | 293 | def resnet34(pretrained=False, **kwargs): 294 | """Constructs a ResNet-34 model. 295 | 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | """ 299 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 300 | if pretrained: 301 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir=config['nn_weights_path'])) 302 | return model 303 | 304 | 305 | def resnet50(pretrained=False, **kwargs): 306 | """Constructs a ResNet-50 model. 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 312 | if pretrained: 313 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir=config['nn_weights_path'])) 314 | return model 315 | 316 | 317 | def resnet101(pretrained=False, **kwargs): 318 | """Constructs a ResNet-101 model. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | """ 323 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 324 | if pretrained: 325 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=config['nn_weights_path'])) 326 | return model 327 | 328 | 329 | def resnet101_block34(pretrained=False, **kwargs): 330 | """Constructs a ResNet-101 model. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | """ 335 | model = ResNet101Block34(Bottleneck, [3, 4, 23, 3], **kwargs) 336 | if pretrained: 337 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=config['nn_weights_path'])) 338 | return model 339 | 340 | 341 | def resnet101_block14(pretrained=False, **kwargs): 342 | """Constructs a ResNet-101 model. 343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | """ 347 | model = ResNet101Block14(Bottleneck, [3, 4, 23, 3], **kwargs) 348 | if pretrained: 349 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=config['nn_weights_path'])) 350 | return model 351 | 352 | def resnet101s16(pretrained=False, finetune_layers=(), s16_feats=('layer4',), s8_feats=('layer2',), 353 | s4_feats=('layer1',), **kwargs): 354 | """Constructs a ResNet-101 model. 
355 | 356 | Args: 357 | pretrained (bool): If True, returns a model pre-trained on ImageNet 358 | """ 359 | model = ResNetS16(finetune_layers, s16_feats, s8_feats, s4_feats, Bottleneck, [3, 4, 23, 3], **kwargs) 360 | if pretrained: 361 | model.load_state_dict(torch.load('/home/cgv841/gwb/Models/resnet101-5d3b4d8f.pth')) 362 | return model 363 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=config['nn_weights_path'])) 364 | return model 365 | 366 | def resnet101s16v2(pretrained=False, finetune_layers=(), s16_feats=('layer4',), s8_feats=('layer2',), 367 | s4_feats=('layer1',), **kwargs): 368 | """Constructs a ResNet-101 model. 369 | 370 | Args: 371 | pretrained (bool): If True, returns a model pre-trained on ImageNet 372 | """ 373 | model = ResNetS16V2(finetune_layers, s16_feats, s8_feats, s4_feats, Bottleneck, [3, 4, 23, 3], **kwargs) 374 | if pretrained: 375 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=config['nn_weights_path'])) 376 | return model 377 | 378 | def resnet50s16(pretrained=False, finetune_layers=(), s16_feats=('layer4',), s8_feats=('layer2',), 379 | s4_feats=('layer1',), **kwargs): 380 | """Constructs a ResNet-101 model. 381 | 382 | Args: 383 | pretrained (bool): If True, returns a model pre-trained on ImageNet 384 | """ 385 | model = ResNetS16(finetune_layers, s16_feats, s8_feats, s4_feats, Bottleneck, [3, 4, 6, 3], **kwargs) 386 | if pretrained: 387 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir=config['nn_weights_path'])) 388 | return model 389 | 390 | def resnet152(pretrained=False, **kwargs): 391 | """Constructs a ResNet-152 model. 392 | 393 | Args: 394 | pretrained (bool): If True, returns a model pre-trained on ImageNet 395 | """ 396 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 397 | if pretrained: 398 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir=config['nn_weights_path'])) 399 | return model 400 | -------------------------------------------------------------------------------- /models/backbones/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/graph_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import torch.utils.model_zoo as model_zoo 7 | from torchvision import models 8 | 9 | # general libs 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | from PIL import Image 13 | import numpy as np 14 | import math 15 | import time 16 | import tqdm 17 | import os 18 | import argparse 19 | import copy 20 | import sys 21 | 22 | from .helpers import * 23 | 24 | print('Space-time Memory Networks: initialized.') 25 | import sys 26 | sys.path.append('/home/ubuntu/xiankai/meta_VOS/models') 27 | import units.ConvGRU2 as ConvGRU 28 | 29 | class ASPP(nn.Module): 30 | def __init__(self, dilation_series, padding_series, depth): 31 | super(ASPP, self).__init__() 32 | self.mean = nn.AdaptiveAvgPool2d((1, 1)) 33 | self.conv = nn.Conv2d(1024, depth, 1, 1) 34 | self.bn_x = nn.BatchNorm2d(depth) 35 | self.conv2d_0 = nn.Conv2d(1024, depth, kernel_size=1, stride=1) 36 | self.bn_0 = nn.BatchNorm2d(depth) 37 | self.conv2d_1 = nn.Conv2d(1024, depth, kernel_size=3, stride=1, padding=padding_series[0], 38 | dilation=dilation_series[0]) 39 | 
self.bn_1 = nn.BatchNorm2d(depth) 40 | self.conv2d_2 = nn.Conv2d(1024, depth, kernel_size=3, stride=1, padding=padding_series[1], 41 | dilation=dilation_series[1]) 42 | self.bn_2 = nn.BatchNorm2d(depth) 43 | self.conv2d_3 = nn.Conv2d(1024, depth, kernel_size=3, stride=1, padding=padding_series[2], 44 | dilation=dilation_series[2]) 45 | self.bn_3 = nn.BatchNorm2d(depth) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.bottleneck = nn.Conv2d(depth * 5, depth, kernel_size=3, padding=1) # 512 1x1Conv 48 | #self.bn = nn.BatchNorm2d(depth) 49 | #self.prelu = nn.PReLU() 50 | # for m in self.conv2d_list: 51 | # m.weight.data.normal_(0, 0.01) 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 55 | m.weight.data.normal_(0, 0.01) 56 | elif isinstance(m, nn.BatchNorm2d): 57 | m.weight.data.fill_(1) 58 | m.bias.data.zero_() 59 | 60 | def _make_stage_(self, dilation1, padding1): 61 | Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True) # classes 62 | Bn = nn.BatchNorm2d(256) 63 | Relu = nn.ReLU(inplace=True) 64 | return nn.Sequential(Conv, Bn, Relu) 65 | 66 | def forward(self, x): 67 | # out = self.conv2d_list[0](x) 68 | # mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list] 69 | size = x.shape[2:] 70 | image_features = self.mean(x) 71 | image_features = self.conv(image_features) 72 | image_features = self.bn_x(image_features) 73 | image_features = self.relu(image_features) 74 | image_features = F.upsample(image_features, size=size, mode='bilinear', align_corners=True) 75 | out_0 = self.conv2d_0(x) 76 | out_0 = self.bn_0(out_0) 77 | out_0 = self.relu(out_0) 78 | out_1 = self.conv2d_1(x) 79 | out_1 = self.bn_1(out_1) 80 | out_1 = self.relu(out_1) 81 | out_2 = self.conv2d_2(x) 82 | out_2 = self.bn_2(out_2) 83 | out_2 = self.relu(out_2) 84 | out_3 = self.conv2d_3(x) 85 | out_3 = self.bn_3(out_3) 86 | out_3 = self.relu(out_3) 87 | out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1) 88 | out = self.bottleneck(out) 89 | # out = self.bn(out) 90 | # out = self.prelu(out) 91 | # for i in range(len(self.conv2d_list) - 1): 92 | # out += self.conv2d_list[i + 1](x) 93 | 94 | return out 95 | 96 | class ResBlock(nn.Module): 97 | def __init__(self, indim, outdim=None, stride=1): 98 | super(ResBlock, self).__init__() 99 | if outdim == None: 100 | outdim = indim 101 | if indim == outdim and stride==1: 102 | self.downsample = None 103 | else: 104 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) 105 | 106 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) 107 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) 108 | 109 | 110 | def forward(self, x): 111 | r = self.conv1(F.relu(x)) 112 | r = self.conv2(F.relu(r)) 113 | 114 | if self.downsample is not None: 115 | x = self.downsample(x) 116 | 117 | return x + r 118 | 119 | class Encoder_M(nn.Module): 120 | def __init__(self): 121 | super(Encoder_M, self).__init__() 122 | self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 123 | self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 124 | 125 | resnet = models.resnet50(pretrained=True) 126 | self.conv1 = resnet.conv1 127 | self.bn1 = resnet.bn1 128 | self.relu = resnet.relu # 1/2, 64 129 | self.maxpool = resnet.maxpool 130 | 131 | self.res2 = resnet.layer1 # 1/4, 256 132 | self.res3 = resnet.layer2 # 1/8, 512 133 | self.res4 = resnet.layer3 # 1/8, 1024 
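Besides the RGB stem, `Encoder_M` adds two single-channel stems (`conv1_m`, `conv1_o`) so the target mask and the union of the other objects' masks can be injected at the first layer; the three are summed in `forward` below before the shared ResNet-50 stages. A quick standalone smoke test, not part of the original file, with random placeholder inputs whose spatial size is an arbitrary multiple of 16 (as produced by `pad_divide_by`):

```python
import torch

# Sketch only: exercises the memory encoder defined in this file in isolation.
enc = Encoder_M().eval()
frame = torch.rand(1, 3, 480, 864)   # RGB frame
mask = torch.rand(1, 480, 864)       # soft mask of the tracked object
others = torch.rand(1, 480, 864)     # union of all other objects' masks
with torch.no_grad():
    r4, r3, r2, c1, f = enc(frame, mask, others)
print(r4.shape, r3.shape, r2.shape)  # only r4 is consumed by KV_M_r4 in memorize() below
```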
134 | 135 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 136 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 137 | 138 | def forward(self, in_f, in_m, in_o): 139 | #print('type:',type(self.mean),type(in_f)) 140 | 141 | f = (in_f - self.mean) / self.std 142 | m = torch.unsqueeze(in_m, dim=1).float() # add channel dim 143 | o = torch.unsqueeze(in_o, dim=1).float() # add channel dim 144 | 145 | x = self.conv1(f) + self.conv1_m(m) + self.conv1_o(o) 146 | x = self.bn1(x) 147 | c1 = self.relu(x) # 1/2, 64 148 | x = self.maxpool(c1) # 1/4, 64 149 | r2 = self.res2(x) # 1/4, 256 150 | r3 = self.res3(r2) # 1/8, 512 151 | r4 = self.res4(r3) # 1/8, 1024 152 | return r4, r3, r2, c1, f 153 | 154 | class Encoder_Q(nn.Module): 155 | def __init__(self): 156 | super(Encoder_Q, self).__init__() 157 | resnet = models.resnet50(pretrained=True) 158 | self.conv1 = resnet.conv1 159 | self.bn1 = resnet.bn1 160 | self.relu = resnet.relu # 1/2, 64 161 | self.maxpool = resnet.maxpool 162 | 163 | self.res2 = resnet.layer1 # 1/4, 256 164 | self.res3 = resnet.layer2 # 1/8, 512 165 | self.res4 = resnet.layer3 # 1/8, 1024 166 | 167 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 168 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 169 | 170 | def forward(self, in_f): 171 | f = (in_f - self.mean) / self.std 172 | 173 | x = self.conv1(f) 174 | x = self.bn1(x) 175 | c1 = self.relu(x) # 1/2, 64 176 | x = self.maxpool(c1) # 1/4, 64 177 | r2 = self.res2(x) # 1/4, 256 178 | r3 = self.res3(r2) # 1/8, 512 179 | r4 = self.res4(r3) # 1/8, 1024 180 | return r4, r3, r2, c1, f 181 | 182 | 183 | class Refine(nn.Module): 184 | def __init__(self, inplanes, planes, scale_factor=2): 185 | super(Refine, self).__init__() 186 | self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3,3), padding=(1,1), stride=1) 187 | self.ResFS = ResBlock(planes, planes) 188 | self.ResMM = ResBlock(planes, planes) 189 | self.scale_factor = scale_factor 190 | 191 | def forward(self, f, pm): 192 | s = self.ResFS(self.convFS(f)) 193 | #print('fea size:',s.size(),pm.size()) 194 | m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 195 | m = self.ResMM(m) 196 | return m 197 | 198 | class Decoder(nn.Module): 199 | def __init__(self, mdim): 200 | super(Decoder, self).__init__() 201 | self.layer5 = self._make_pred_layer(ASPP, [2, 4, 8], [2, 4, 8], mdim) 202 | #self.convFM = nn.Conv2d(1024, mdim, kernel_size=(3,3), padding=(1,1), stride=1) 203 | self.ResMM = ResBlock(mdim, mdim) 204 | self.RF3 = Refine(512, mdim) # 1/8 -> 1/4 205 | self.RF2 = Refine(256, mdim) # 1/4 -> 1 206 | 207 | self.pred2 = nn.Conv2d(mdim, 2, kernel_size=(3,3), padding=(1,1), stride=1) 208 | 209 | def forward(self, r4, r3, r2): 210 | m4 = self.ResMM(self.layer5(r4)) 211 | m3 = self.RF3(r3, m4) # out: 1/8, 256 212 | m2 = self.RF2(r2, m3) # out: 1/4, 256 213 | 214 | p2 = self.pred2(F.relu(m2)) 215 | 216 | p = F.interpolate(p2, scale_factor=4, mode='bilinear', align_corners=False) 217 | return p #, p2, p3, p4 218 | 219 | def _make_pred_layer(self, block, dilation_series, padding_series, num_classes): 220 | return block(dilation_series, padding_series, num_classes) 221 | 222 | 223 | class Memory(nn.Module): 224 | def __init__(self): 225 | super(Memory, self).__init__() 226 | self.propagate_layers = 5 227 | self.conv_fusion = nn.Conv2d(512, 512, kernel_size=1, bias=True) 228 | self.ConvGRU_h = 
ConvGRU.ConvGRUCell(512, 512, all_dim=128, kernel_size=1) 229 | self.ConvGRU_m = ConvGRU.ConvGRUCell(512, 512, all_dim=128, kernel_size=1) 230 | self.linear_e = nn.Linear(128, 128, bias=False) 231 | 232 | def forward(self, m_in, m_out, q_in, q_out): # m_in: o,c,t,h,w 233 | B, D_e, T, H, W = m_in.size() # T is the memory size 234 | _, D_o, _, _, _ = m_out.size() 235 | q_out0 = q_out.clone() 236 | for kk in range(0, self.propagate_layers): 237 | B, D_e, T, H, W = m_in.size() 238 | _, D_o, _, _, _ = m_out.size() 239 | # print('fea size:',m_in.size(),q_in.size()) 240 | mi = m_in.view(B, D_e, T * H * W) 241 | mi = torch.transpose(mi, 1, 2) # b, THW, emb 242 | 243 | qi = q_in.view(B, D_e, H * W) # b, emb, HW 244 | 245 | p = torch.bmm(mi, qi) # b, THW, HW 246 | p = p / math.sqrt(D_e) 247 | p = F.softmax(p, dim=1) # b, THW, HW 248 | mo = m_out.view(B, D_o, T * H * W) 249 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW 250 | mem_mean = mem.view(B, D_o, H, W) 251 | 252 | if T <2: 253 | 254 | mem_out = torch.cat([mem_mean, q_out0], dim=1) 255 | return mem_out,10 256 | if T>1: 257 | m_in_all = torch.cat((m_in, q_in.unsqueeze(2)),dim=2).contiguous() # B, D_e, T+1, H, W 258 | m_out_all = torch.cat((m_out, q_out.unsqueeze(2)),dim=2).contiguous() # B, D_o, T+1, H, W 259 | #print('memory size:',m_in_all.size(),m_out_all.size()) 260 | edge_featuress = [] 261 | for x in range(0,T+1): #for each node 262 | edge_set = [] #compute edge feature with other nodes 263 | for y in range(0,T+1): 264 | edge_feature = self.generate_edge(m_in_all[:,:,y,:,:].clone(), m_out_all[:,:,y,:,:].clone(), m_in_all[:,:,x,:,:].clone()) 265 | edge_set.append(edge_feature) 266 | edge_set.pop(x) #remove self connection 267 | edge_features = self.conv_fusion(torch.sum(torch.stack(edge_set,dim=1),dim=1))#self.conv_fusion(torch.cat(edge_set,dim=1)) 268 | edge_featuress.append(edge_features) 269 | for x in range(0,T): 270 | #only update memory node 271 | #print('feature dim:',torch.cat(edge_set,dim=1).size(),m_out_all[:,:,x,:,:].size()) 272 | hiden_state = self.ConvGRU_m(edge_featuress[x], m_out_all[:,:,x,:,:].clone()) 273 | #hiden_state = self.batch_norm_m(hiden_state) 274 | m_out_all[:, :, x, :, :] = hiden_state 275 | 276 | q_out_h = self.ConvGRU_h(q_out, mem_mean) 277 | #q_out_h = self.batch_norm_h(q_out_h) 278 | q_out = q_out_h.clone() 279 | 280 | mem_out = torch.cat([q_out, q_out0], dim=1) 281 | return mem_out,10 282 | 283 | def generate_edge(self, m_in, m_out, q_in): 284 | B, D_e, H, W = m_in.size() # during training T is 1 or 2 285 | _, D_o, _, _ = m_out.size() 286 | mi = m_in.view(B, D_e, H * W) 287 | mi = torch.transpose(mi, 1, 2) # b, THW, emb 288 | 289 | qi = q_in.view(B, D_e, H * W) # b, emb, HW 290 | mi = self.linear_e(mi) 291 | p = torch.bmm(mi, qi) # b, THW, HW 292 | #p = p / math.sqrt(D_e) 293 | p = F.softmax(p, dim=1) # b, THW, HW 294 | 295 | mo = m_out.view(B, D_o, H * W) 296 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW 297 | mem = mem.view(B, D_o, H, W) 298 | return mem 299 | 300 | class KeyValue(nn.Module): 301 | # Not using location 302 | def __init__(self, indim, keydim, valdim): 303 | super(KeyValue, self).__init__() 304 | self.Key = nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1) 305 | self.Value = nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1) 306 | 307 | def forward(self, x): 308 | return self.Key(x), self.Value(x) 309 | 310 | 311 | 312 | 313 | class graph_memory(nn.Module): 314 | def __init__(self): 315 | super(graph_memory, self).__init__() 316 | self.Encoder_M 
= Encoder_M() 317 | self.Encoder_Q = Encoder_Q() 318 | 319 | self.KV_M_r4 = KeyValue(1024, keydim=128, valdim=512) 320 | self.KV_Q_r4 = KeyValue(1024, keydim=128, valdim=512) 321 | 322 | self.Memory = Memory() 323 | self.Decoder = Decoder(256) 324 | 325 | def Pad_memory(self, mems, num_objects, K): 326 | pad_mems = [] 327 | for mem in mems: 328 | pad_mem = ToCuda(torch.zeros(1, K, mem.size()[1], 1, mem.size()[2], mem.size()[3])) 329 | pad_mem[0,1:num_objects+1,:,0] = mem 330 | pad_mems.append(pad_mem) 331 | return pad_mems 332 | 333 | def memorize(self, frame, masks, num_objects): 334 | B_f_batch, B_m_batch, B_o_batch = [], [], [] 335 | frame_batch, masks_batch, num_objects_batch = frame, masks, num_objects 336 | for i in range(num_objects_batch.shape[0]): 337 | # memorize a frame 338 | frame, masks = frame_batch[i].unsqueeze(0), masks_batch[i].unsqueeze(0) 339 | num_objects = num_objects_batch[i] 340 | _, K, H, W = masks.shape # B = 1 341 | 342 | (frame, masks), pad = pad_divide_by([frame, masks], 16, (frame.size()[2], frame.size()[3])) 343 | 344 | # make batch arg list 345 | B_list = {'f':[], 'm':[], 'o':[]} 346 | for o in range(1, num_objects+1): # 1 - no 347 | B_list['f'].append(frame) 348 | B_list['m'].append(masks[:,o]) 349 | B_list['o'].append( (torch.sum(masks[:,1:o], dim=1) + \ 350 | torch.sum(masks[:,o+1:num_objects+1], dim=1)).clamp(0,1) ) 351 | 352 | # make Batch 353 | B_ = {} 354 | for arg in B_list.keys(): 355 | B_[arg] = torch.cat(B_list[arg], dim=0) 356 | B_f_batch.append(B_['f']), B_m_batch.append(B_['m']), B_o_batch.append(B_['o']) 357 | 358 | B_f, B_m, B_o = torch.cat(B_f_batch, dim=0), torch.cat(B_m_batch, dim=0), torch.cat(B_o_batch, dim=0) 359 | r4, _, _, _, _ = self.Encoder_M(B_f, B_m, B_o) 360 | k4, v4 = self.KV_M_r4(r4) # num_objects, 128 and 512, H/16, W/16 361 | k4, v4 = self.Pad_memory([k4, v4], num_objects=torch.sum(num_objects_batch), K=K*num_objects_batch.shape[0]) 362 | return k4, v4 363 | 364 | def Soft_aggregation(self, ps, K): 365 | num_objects, H, W = ps.shape 366 | em = ToCuda(torch.zeros(1, K, H, W)) 367 | em[0,0] = torch.prod(1-ps, dim=0) # bg prob 368 | em[0,1:num_objects+1] = ps # obj prob 369 | em = torch.clamp(em, 1e-7, 1-1e-7) 370 | logit = torch.log((em /(1-em))) 371 | return logit 372 | 373 | def segment(self, frame, keys, values, num_objects): 374 | k4e_batch, v4e_batch = [], [] 375 | r3e_batch, r2e_batch = [], [] 376 | frame_batch, num_objects_batch = frame, num_objects 377 | _, K, keydim, T, H, W = keys.shape # B = 1 378 | # pad 379 | 380 | 381 | for i in range(num_objects_batch.shape[0]): 382 | frame = frame_batch[i].unsqueeze(0) 383 | [frame], pad = pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3])) 384 | num_objects = num_objects_batch[i] 385 | r4, r3, r2, _, _ = self.Encoder_Q(frame) 386 | k4, v4 = self.KV_Q_r4(r4) # 1, dim, H/16, W/16 387 | 388 | # expand to --- no, c, h, w 389 | k4e, v4e = k4.expand(num_objects,-1,-1,-1), v4.expand(num_objects,-1,-1,-1) 390 | r3e, r2e = r3.expand(num_objects,-1,-1,-1), r2.expand(num_objects,-1,-1,-1) 391 | k4e_batch.append(k4e), v4e_batch.append(v4e) 392 | r3e_batch.append(r3e), r2e_batch.append(r2e) 393 | 394 | k4e, v4e = torch.cat(k4e_batch, dim=0), torch.cat(v4e_batch, dim=0) 395 | r3e, r2e = torch.cat(r3e_batch, dim=0), torch.cat(r2e_batch, dim=0) 396 | # memory select kv:(1, K, C, T, H, W) 397 | m4, viz = self.Memory(keys[0,1:torch.sum(num_objects_batch)+1], values[0,1:torch.sum(num_objects_batch)+1], k4e, v4e) 398 | logits_batch = self.Decoder(m4, r3e, r2e) 399 | logits_batch_out = [] 
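The loop below splits `logits_batch` back into per-sequence groups and fuses each group with `Soft_aggregation` above, which models the background probability as the product of every object's complement before converting back to logits. A small self-contained check of that aggregation rule, not part of the original file (the probability values are illustrative only):

```python
import torch

def soft_aggregation_check(ps, K=11):
    # Mirrors Soft_aggregation above, without the CUDA transfer.
    num_objects, H, W = ps.shape
    em = torch.zeros(1, K, H, W)
    em[0, 0] = torch.prod(1 - ps, dim=0)          # background: no object claims the pixel
    em[0, 1:num_objects + 1] = ps                 # per-object foreground probabilities
    em = torch.clamp(em, 1e-7, 1 - 1e-7)
    return torch.log(em / (1 - em))               # logits, later renormalized by a softmax over K

ps = torch.tensor([[[0.9]], [[0.2]]])             # two objects, a single pixel
logits = soft_aggregation_check(ps)
print(torch.softmax(logits, dim=1)[0, :3, 0, 0])  # background vs. object 1 vs. object 2
```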
400 | begin = 0 401 | for i in range(num_objects_batch.shape[0]): 402 | ps = F.softmax(logits_batch[begin:begin + num_objects_batch[i]], dim=1)[:,1] # no, h, w 403 | #ps = indipendant possibility to belong to each object 404 | 405 | logit = self.Soft_aggregation(ps, int(K/num_objects_batch.shape[0])) # 1, K, H, W 406 | logits_batch_out.append(logit) 407 | begin += num_objects_batch[i] 408 | 409 | logit = torch.cat(logits_batch_out, dim=0) 410 | if pad[2]+pad[3] > 0: 411 | logit = logit[:,:,pad[2]:-pad[3],:] 412 | if pad[0]+pad[1] > 0: 413 | logit = logit[:,:,:,pad[0]:-pad[1]] 414 | 415 | return logit 416 | 417 | def forward(self, *args, **kwargs): 418 | if args[1].dim() > 4: # keys 419 | return self.segment(*args, **kwargs) 420 | else: 421 | return self.memorize(*args, **kwargs) 422 | 423 | 424 | -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | #torch 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils import data 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.init as init 10 | import torch.utils.model_zoo as model_zoo 11 | from torchvision import models 12 | 13 | # general libs 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | from PIL import Image 17 | import numpy as np 18 | import time 19 | import os 20 | import copy 21 | 22 | 23 | def ToCuda(xs): 24 | if torch.cuda.is_available(): 25 | if isinstance(xs, list) or isinstance(xs, tuple): 26 | return [x.cuda() for x in xs] 27 | else: 28 | return xs.cuda() 29 | else: 30 | return xs 31 | 32 | 33 | def pad_divide_by(in_list, d, in_size): 34 | out_list = [] 35 | h, w = in_size 36 | if h % d > 0: 37 | new_h = h + d - h % d 38 | else: 39 | new_h = h 40 | if w % d > 0: 41 | new_w = w + d - w % d 42 | else: 43 | new_w = w 44 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 45 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 46 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 47 | for inp in in_list: 48 | out_list.append(F.pad(inp, pad_array)) 49 | return out_list, pad_array 50 | 51 | 52 | 53 | def overlay_davis(image,mask,colors=[255,0,0],cscale=2,alpha=0.4): 54 | """ Overlay segmentation on top of RGB image. 
from davis official""" 55 | # import skimage 56 | from scipy.ndimage.morphology import binary_erosion, binary_dilation 57 | 58 | colors = np.reshape(colors, (-1, 3)) 59 | colors = np.atleast_2d(colors) * cscale 60 | 61 | im_overlay = image.copy() 62 | object_ids = np.unique(mask) 63 | 64 | for object_id in object_ids[1:]: 65 | # Overlay color on binary mask 66 | foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) 67 | binary_mask = mask == object_id 68 | 69 | # Compose image 70 | im_overlay[binary_mask] = foreground[binary_mask] 71 | 72 | # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask 73 | countours = binary_dilation(binary_mask) ^ binary_mask 74 | # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask 75 | im_overlay[countours,:] = 0 76 | 77 | return im_overlay.astype(image.dtype) 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /models/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/units/ConvGRU2.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # Nicolo Savioli, 2017 -- Conv-GRU pytorch v 1.0 # 3 | ################################################### 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.nn import init 9 | class ConvGRUCell(nn.Module): 10 | 11 | def __init__(self, input_size, hidden_size, all_dim, kernel_size): 12 | super(ConvGRUCell,self).__init__() 13 | self.input_size = input_size 14 | self.cuda_flag = True 15 | self.input_size = input_size 16 | self.hidden_size = hidden_size 17 | self.padding = int((kernel_size - 1) / 2) 18 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=self.padding) 19 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=self.padding) 20 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=self.padding) 21 | 22 | init.orthogonal(self.reset_gate.weight) 23 | init.orthogonal(self.update_gate.weight) 24 | init.orthogonal(self.out_gate.weight) 25 | init.constant(self.reset_gate.bias, 0.) 26 | init.constant(self.update_gate.bias, 0.) 27 | init.constant(self.out_gate.bias, 0.) 
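`forward` below performs a single gated update; the `Memory` module in `models/graph_memory.py` instantiates two of these cells with 512-channel inputs and hidden states and a 1x1 kernel to update the memory-node and query states. A minimal usage sketch, not part of the original file, with placeholder feature maps:

```python
import torch

# Sketch only: channel count matches the Memory module, spatial size is arbitrary.
cell = ConvGRUCell(input_size=512, hidden_size=512, all_dim=128, kernel_size=1)
x = torch.rand(1, 512, 30, 54)   # aggregated edge features for one node
h = torch.rand(1, 512, 30, 54)   # previous node state (passing None starts from zeros)
h_new = cell(x, h)               # gated update: h_new = (1 - update) * h + update * candidate
print(h_new.shape)               # torch.Size([1, 512, 30, 54])
```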
28 | 29 | def forward(self, input_, prev_state): 30 | 31 | # get batch and spatial sizes 32 | batch_size = input_.data.size()[0] 33 | spatial_size = input_.data.size()[2:] 34 | 35 | # generate empty prev_state, if None is provided 36 | if prev_state is None: 37 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 38 | if torch.cuda.is_available(): 39 | prev_state = Variable(torch.zeros(state_size)).cuda() 40 | else: 41 | prev_state = Variable(torch.zeros(state_size)) 42 | 43 | # data size is [batch, channel, height, width] 44 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 45 | update = F.sigmoid(self.update_gate(stacked_inputs)) 46 | reset = F.sigmoid(self.reset_gate(stacked_inputs)) 47 | out_inputs = F.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 48 | new_state = prev_state * (1 - update) + out_inputs * update 49 | 50 | return new_state#, out_inputs 51 | -------------------------------------------------------------------------------- /models/units/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/media/xiankai/Data/segmentation/OSVOS/ECCV_graph_mem/models/units') 3 | #from LinkFunction import LinkFunction 4 | from MessageFunction import MessageFunction 5 | from UpdateFunction import UpdateFunction 6 | #from ReadoutFunction import ReadoutFunction 7 | #'LinkFunction', , 'ReadoutFunction' 8 | __all__ = ('MessageFunction', 'UpdateFunction') 9 | -------------------------------------------------------------------------------- /models/units/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | torchvision==0.2.2.post3 3 | tqdm 4 | tensorboardX 5 | scipy 6 | matplotlib 7 | imgaug 8 | scikit-image 9 | -------------------------------------------------------------------------------- /run_graph_memory_test.sh: -------------------------------------------------------------------------------- 1 | #python eval_DAVIS_graph_memory.py -g '4' -s val -y 17 -D /raid/DAVIS/DAVIS-2017/DAVIS-train-val 2 | #python eval_DAVIS.py -g '0' -s val -y 16 -D /media/xiankai/Data/segmentation/DAVIS-2016 3 | python runfiles/eval_DAVIS_graph_memory.py -c './workspace_STM_alpha/main_runfile_graph_memory.pth.tar' 4 | -------------------------------------------------------------------------------- /runfiles/eval_DAVIS_graph_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import sys 3 | sys.path.append('/media/xiankai/Data/segmentation/OSVOS/ECCV_graph_memory') 4 | import models 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.utils import data 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init as init 12 | import torch.utils.model_zoo as model_zoo 13 | #from torchvision import models 14 | from numpy.random import randint 15 | # general libs 16 | import cv2 17 | import matplotlib.pyplot as plt 18 | from PIL import Image 19 | import numpy as np 20 | import math 21 | import time 22 | import tqdm 23 | import os 24 | import argparse 25 | import copy 26 | import csv 27 | 28 | ### My libs 29 | 30 | from dataset import DAVIS_MO_Test 31 | 32 | 33 | torch.set_grad_enabled(False) # Volatile 34 | def pad_divide_by(in_list, d, 
in_size): 35 | out_list = [] 36 | h, w = in_size #input size 37 | if h % d > 0: 38 | new_h = h + d - h % d 39 | else: 40 | new_h = h 41 | if w % d > 0: 42 | new_w = w + d - w % d 43 | else: 44 | new_w = w 45 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 46 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 47 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 48 | for inp in in_list: 49 | out_list.append(F.pad(inp, pad_array)) 50 | return out_list, pad_array 51 | 52 | def get_arguments(): 53 | parser = argparse.ArgumentParser(description="SST") 54 | parser.add_argument("-g", type=str, help="0; 0,1; 0,3; etc", default='0') 55 | parser.add_argument("-c", type=str, help="checkpoint", default=' ') 56 | parser.add_argument("-s", type=str, help="set", default="val") 57 | parser.add_argument("-y", type=int, help="year", default="17") 58 | parser.add_argument("-viz", help="Save visualization", action="store_true") 59 | parser.add_argument("-D", type=str, help="path to data", default='/media/xiankai/Data/segmentation/DAVIS-2017/DAVIS-train-val') 60 | return parser.parse_args() 61 | 62 | args = get_arguments() 63 | 64 | GPU = args.g 65 | YEAR = args.y 66 | SET = args.s 67 | VIZ = args.viz 68 | DATA_ROOT = args.D 69 | 70 | # Model and version 71 | MODEL = 'Graph-memory' 72 | print(MODEL, ': Testing on DAVIS') 73 | 74 | os.environ['CUDA_VISIBLE_DEVICES'] = GPU 75 | if torch.cuda.is_available(): 76 | print('using Cuda devices, num:', torch.cuda.device_count()) 77 | 78 | if VIZ: 79 | print('--- Produce mask overaid video outputs. Evaluation will run slow.') 80 | print('--- Require FFMPEG for encoding, Check folder ./viz') 81 | 82 | palette = Image.open('/media/xiankai/Data/segmentation/DAVIS-2017/DAVIS-train-val/Annotations/480p/bear/00000.png').getpalette() 83 | 84 | 85 | class VideoRecord(object): 86 | pass 87 | 88 | 89 | def _sample_pair_indices(record): 90 | """ 91 | :param record: VideoRecord 92 | :return: list 93 | """ 94 | new_length = 1 95 | 96 | average_duration = (record.num_frames - new_length + 1) // record.num_segment 97 | if average_duration > 0: 98 | # offsets = np.multiply(list(range(record.num_segment)), average_duration) + randint(average_duration,size=record.num_segment) 99 | offsets = np.multiply(list(range(record.num_segment)), average_duration) + [average_duration//2]*record.num_segment # no random 100 | elif record.num_frames > record.num_segment: 101 | offsets = randint(record.num_frames - 102 | new_length + 1, size=record.num_segment) 103 | else: 104 | offsets = np.zeros((record.num_segment,)) 105 | return offsets 106 | 107 | 108 | def Run_video(Fs, Ms, num_frames, num_objects, Mem_every=None, Mem_number=None): 109 | # initialize storage tensors 110 | num_first_memory = 1 111 | if Mem_every: 112 | to_memorize = [int(i) for i in np.arange(0, num_frames, step=Mem_every)] 113 | elif Mem_number: 114 | to_memorize = [int(round(i)) for i in 115 | np.linspace(0, num_frames, num=Mem_number + 2)[:-1]] # [0, 5, 10, 15, 20, 25] 116 | else: 117 | raise NotImplementedError 118 | 119 | # print('memory size:', len(to_memorize)) 120 | Es = torch.zeros_like(Ms) # mask 121 | Es[:, :, 0] = Ms[:, :, 0] 122 | record = VideoRecord() 123 | 124 | record.num_segment = 4 125 | for t in tqdm.tqdm(range(1, num_frames)): 126 | # memorize 127 | with torch.no_grad(): 128 | prev_key, prev_value = model(Fs[:, :, t - 1], Es[:, :, t - 1], torch.tensor([num_objects])) 129 | 130 | if t - 1 == 0: # 131 | this_keys, this_values = prev_key, prev_value # only prev memory 132 | elif t <= 
record.num_segment: 133 | this_keys = torch.cat([keys, prev_key], dim=3) 134 | this_values = torch.cat([values, prev_value], dim=3) 135 | # segment 136 | with torch.no_grad(): 137 | # print('input size1:', this_keys.size(), this_values.size())# torch.Size([1, 11, 128, 1, 30, 57]) torch.Size([1, 11, 512, 1, 30, 57]) # one hot label vector with length 11 138 | record.num_frames = t 139 | select_keys = [] 140 | select_values = [] 141 | # print('key size:',t,this_keys.size(), this_values.size())#[1, 11, 128, 4, 30, 57] 142 | if t > record.num_segment: 143 | Index = _sample_pair_indices(record) if record.num_segment else [] 144 | 145 | #Index[-1]=t-1 146 | # print('index', t, Index, type(Index)) 147 | for add_0 in range(num_first_memory): 148 | select_keys.append(this_keys[:, :, :, 0, :, :].unsqueeze(dim=3)) 149 | select_values.append(this_values[:, :, :, 0, :, :].unsqueeze(dim=3)) 150 | # print('index0:', this_keys[:, :, :, 0, :, :].size(),this_values[:, :, :, 0, :, :].unsqueeze(dim=3).size()) 151 | for ii in Index: 152 | prev_key1, prev_value1 = model(Fs[:, :, ii], Es[:, :, ii], torch.tensor([num_objects])) 153 | 154 | # print('index1:', prev_key.size()) #1, 11, 128, 1, 30, 57 155 | select_keys.append(prev_key1) # (this_keys[:, :, :, t-1, :, :]) 156 | select_values.append(prev_value1) # (this_values[:, :, :, t-1, :, :]) 157 | select_keys.append(prev_key) # (this_keys[:, :, :, t-1, :, :]) 158 | select_values.append(prev_value) 159 | # print('index2:', select_keys[0].size()) 160 | select_keys = torch.cat(select_keys, dim=3) 161 | select_values = torch.cat(select_values, dim=3) 162 | # print('key size:', prev_key.size(), select_keys.size(), select_values.size()) 163 | else: 164 | select_keys = this_keys 165 | select_values = this_values 166 | logit = model(Fs[:, :, t], select_keys, select_values, torch.tensor([num_objects])) 167 | 168 | Es[:, :, t] = F.softmax(logit, dim=1) 169 | # print('output size:', torch.max(Es[:,:,t]), torch.min(Es[:,:,t])) 170 | # update 171 | if t - 1 in to_memorize: 172 | keys, values = this_keys, this_values 173 | 174 | pred = np.argmax(Es[0].cpu().numpy(), axis=0).astype(np.uint8) 175 | return pred, Es 176 | 177 | 178 | Testset = DAVIS_MO_Test(DATA_ROOT, resolution='480p', imset='20{}/{}.txt'.format(YEAR, SET), single_object=(YEAR == 16)) 179 | Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) 180 | 181 | model = nn.DataParallel(models.graph_memory()) 182 | for param in model.parameters(): 183 | param.requires_grad = False 184 | if torch.cuda.is_available(): 185 | model.cuda() 186 | model.eval() # turn-off BN 187 | 188 | pth_path = args.c 189 | print('Loading weights:', pth_path) 190 | 191 | checkpoint = torch.load(pth_path) 192 | model.module.load_state_dict(checkpoint['net']) 193 | 194 | try: 195 | print('epoch:', checkpoint['epoch']) 196 | except: 197 | print('dont know epoch') 198 | 199 | for seq, V in enumerate(Testloader): 200 | Fs, Ms, num_objects, info = V 201 | seq_name = info['name'][0] 202 | num_frames = info['num_frames'][0].item() 203 | print('[{}]: num_frames: {}, num_objects: {}'.format(seq_name, num_frames, num_objects[0][0])) 204 | 205 | B, K, N, H, W = Fs.shape 206 | H_1, W_1 = 480, int(480.0 * W / H) 207 | # resize_sizes = [(int(0.75 * H_1), int(0.75 * W_1)), (H_1, W_1), (int(1.25 * H_1), int(1.25 * W_1))] 208 | resize_sizes = [(H_1, W_1)] 209 | use_flip = True 210 | resize_Fs = [] 211 | resize_Ms = [] 212 | # ms 213 | for size in resize_sizes: 214 | 
resize_Fs.append(F.interpolate(input=Fs.squeeze(0).permute(1, 0, 2, 3), size=size, mode='bilinear', 215 | align_corners=True).permute(1, 0, 2, 3).unsqueeze(0)) 216 | resize_Ms.append(F.interpolate(input=Ms.squeeze(0).permute(1, 0, 2, 3), size=size, mode='nearest' 217 | ).permute(1, 0, 2, 3).unsqueeze(0)) 218 | # flip 219 | if use_flip: 220 | for i in range(len(resize_Fs)): 221 | resize_Fs.append(torch.flip(resize_Fs[i], [-1])) 222 | resize_Ms.append(torch.flip(resize_Ms[i], [-1])) 223 | 224 | Es_list = [] 225 | for i in range(len(resize_Fs)): 226 | pred, Es = Run_video(resize_Fs[i], resize_Ms[i], num_frames, num_objects, Mem_every=5, Mem_number=None) 227 | Es = F.interpolate(input=Es.squeeze(0).permute(1, 0, 2, 3), size=(H, W), mode='bilinear', 228 | align_corners=True).permute(1, 0, 2, 3).unsqueeze(0) 229 | if use_flip: 230 | if i >= (len(resize_Fs) / 2): 231 | Es = torch.flip(Es, [-1]) 232 | Es_list.append(Es) 233 | Es = torch.stack(Es_list).mean(dim=0) 234 | pred = np.argmax(Es[0].numpy(), axis=0).astype(np.uint8) # different than ytvos here 235 | # pred, Es = Run_video(Fs, Ms, num_frames, num_objects, Mem_every=1, Mem_number=None, 236 | # num_first_memory=num_first_memory, num_middle_memory=num_middle_memory) 237 | 238 | # Save results for quantitative eval ###################### 239 | test_path = os.path.join('./test', code_name, seq_name) 240 | if not os.path.exists(test_path): 241 | os.makedirs(test_path) 242 | for f in range(num_frames): 243 | img_E = Image.fromarray(pred[f]) 244 | # print('image size:',type(YEAR),np.max(pred[f]),np.min(pred[f])) 245 | if YEAR == 16: 246 | # print('ok!') 247 | img_E = (pred[f].squeeze() * 255).astype(np.uint8) 248 | img_E = Image.fromarray(img_E) 249 | img_E = img_E.convert('RGB') 250 | else: 251 | img_E.putpalette(palette) 252 | img_E.save(os.path.join(test_path, '{:05d}.png'.format(f))) 253 | 254 | if VIZ: 255 | from tools.helpers import overlay_davis 256 | 257 | # visualize results ####################### 258 | viz_path = os.path.join('./viz/', code_name, seq_name) 259 | if not os.path.exists(viz_path): 260 | os.makedirs(viz_path) 261 | 262 | for f in range(num_frames): 263 | pF = (Fs[0, :, f].permute(1, 2, 0).numpy() * 255.).astype(np.uint8) 264 | pE = pred[f] 265 | canvas = overlay_davis(pF, pE, palette) 266 | canvas = Image.fromarray(canvas) 267 | canvas.save(os.path.join(viz_path, 'f{}.jpg'.format(f))) 268 | 269 | vid_path = os.path.join('./viz/', code_name, '{}.mp4'.format(seq_name)) 270 | frame_path = os.path.join('./viz/', code_name, seq_name, 'f%d.jpg') 271 | 272 | 273 | -------------------------------------------------------------------------------- /runfiles/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .debugging import * 2 | from .stats import * 3 | from .tensor_utils import * 4 | from .readsaveimage import * 5 | -------------------------------------------------------------------------------- /utils/debugging.py: -------------------------------------------------------------------------------- 1 | 2 | def get_tensor_statistics_str(tensor, name="", formatting="standard"): 3 | """ Returns string of formatted tensor statistics, contains min, max, mean, and std""" 4 | if isinstance(tensor, (torch.FloatTensor, torch.cuda.FloatTensor)): 5 | if formatting == "standard": 6 | string = "elem in [{:6.3f}, 
{:6.3f}] mean: {:6.3f} std: {:6.3f} size: {}".format(tensor.min().item(), tensor.max().item(), tensor.mean().item(), tensor.std().item(), tuple(tensor.size())) 7 | elif formatting == "short": 8 | string = "[{:6.3f}, {:6.3f}] mu: {:6.3f} std: {:6.3f} {!s:17} {: 6.1f}MB".format(tensor.min().item(), tensor.max().item(), tensor.mean().item(), tensor.std().item(), tuple(tensor.size()), 4e-6 * prod(tensor.size())) 9 | elif isinstance(tensor, (torch.LongTensor, torch.ByteTensor, torch.cuda.LongTensor, torch.cuda.ByteTensor)): 10 | tensor = tensor.to('cpu') 11 | string = "elem in [{:6.3f}, {:6.3f}] size: {} HIST BELOW:\n{}".format(tensor.min().item(), tensor.max().item(), tuple(tensor.size()), torch.stack([torch.arange(0, tensor.max()+1), tensor.view(-1).bincount()], dim=0)) 12 | else: 13 | raise NotImplementedError("A type of tensor not yet supported was input. Expected torch.FloatTensor or torch.LongTensor, got: {}".format(tensor.type())) 14 | string = string + " " + name 15 | return string 16 | 17 | def print_tensor_statistics(tensor, name="", formatting="standard"): 18 | print(get_tensor_statistics_str(tensor, name, formatting)) 19 | 20 | def get_weight_statistics_str(layer, name="", formatting="standard"): 21 | return get_tensor_statistics_str(layer.weight, name, formatting) 22 | 23 | def get_model_size_str(model): 24 | nelem = 0 25 | for module in model.modules(): 26 | if hasattr(module, 'weight'): 27 | nelem += module.weight.numel() 28 | if hasattr(module, 'bias'): 29 | nelem += module.weight.numel() 30 | size_str = "{:.2f} MB".format(nelem * 4 * 1e-6) 31 | return size_str 32 | -------------------------------------------------------------------------------- /utils/readsaveimage.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import queue 3 | import time 4 | import os 5 | import png 6 | import numpy 7 | import math 8 | 9 | # Palette based on YTVOS color palette 10 | BASE_PALETTE_4BIT = [[ 0, 0, 0], 11 | [236, 94, 102], 12 | [249, 144, 87], 13 | [250, 199, 98], 14 | [153, 199, 148], 15 | [ 97, 179, 177], 16 | [102, 153, 204], 17 | [196, 148, 196], 18 | [171, 120, 102], 19 | [255, 255, 255], 20 | [101, 115, 125], 21 | [ 10, 10, 10], 22 | [ 12, 12, 12], 23 | [ 13, 13, 13], 24 | [ 13, 13, 13], 25 | [ 14, 14, 14]] 26 | 27 | DAVIS_PALETTE_4BIT = [[ 0, 0, 0], 28 | [128, 0, 0], 29 | [ 0, 128, 0], 30 | [128, 128, 0], 31 | [ 0, 0, 128], 32 | [128, 0, 128], 33 | [ 0, 128, 128], 34 | [128, 128, 128], 35 | [ 64, 0, 0], 36 | [191, 0, 0], 37 | [ 64, 128, 0], 38 | [191, 128, 0], 39 | [ 64, 0, 128], 40 | [191, 0, 128], 41 | [ 64, 128, 128], 42 | [191, 128, 128]] 43 | 44 | # Implementations of Save Methods 45 | class ReadSaveImage(object): 46 | def __init__(self): 47 | super(ReadSaveImage, self).__init__() 48 | 49 | def check_path(self, fullpath): 50 | path, filename = os.path.split(fullpath) 51 | if not os.path.exists(path): 52 | os.makedirs(path) 53 | 54 | class PngMono(ReadSaveImage): 55 | def __init__(self): 56 | super(PngMono, self).__init__() 57 | 58 | def save(self, image, path, bitdepth): 59 | self.check_path(path) 60 | 61 | # Expect a numpy array 62 | height, width = image.shape 63 | file = open(path, 'wb') 64 | writer = png.Writer(width, height, greyscale=True, bitdepth=bitdepth) 65 | writer.write(file,image) 66 | 67 | class ReadSaveYTVOSChallengeLabels(ReadSaveImage): 68 | def __init__(self, bpalette=BASE_PALETTE_4BIT, palette=None): 69 | super(ReadSaveYTVOSChallengeLabels, self).__init__() 70 | self._palette = palette 71 | 
self._bpalette = bpalette 72 | @property 73 | def palette(self): 74 | return self._palette 75 | 76 | def save(self, image, path): 77 | self.check_path(path) 78 | 79 | # Set palette 80 | if self._palette is None: 81 | palette = self._bpalette 82 | else: 83 | palette = self._palette 84 | 85 | bitdepth = int(math.log(len(palette))/math.log(2)) 86 | 87 | # Expect a numpy array 88 | height, width = image.shape 89 | file = open(path, 'wb') 90 | writer = png.Writer(width, height, palette=palette, bitdepth=bitdepth) 91 | writer.write(file,image) 92 | 93 | def read(self, path): 94 | reader = png.Reader(path) 95 | width, height, data, meta = reader.read() 96 | if self._palette is None: 97 | self._palette = meta['palette'] 98 | image = numpy.vstack(data) 99 | return image 100 | 101 | class ReadSaveDAVISChallengeLabels(ReadSaveImage): 102 | def __init__(self, bpalette=DAVIS_PALETTE_4BIT, palette=None): 103 | super(ReadSaveDAVISChallengeLabels, self).__init__() 104 | self._palette = palette 105 | self._bpalette = bpalette 106 | self._width = 0 107 | self._height = 0 108 | 109 | @property 110 | def palette(self): 111 | return self._palette 112 | 113 | def save(self, image, path): 114 | self.check_path(path) 115 | 116 | # Set palette 117 | if self._palette is None: 118 | palette = self._bpalette 119 | else: 120 | palette = self._palette 121 | 122 | bitdepth = int(math.log(len(palette))/math.log(2)) 123 | 124 | # Expect a numpy array 125 | height, width = image.shape 126 | file = open(path, 'wb') 127 | writer = png.Writer(width, height, palette=palette, bitdepth=bitdepth) 128 | writer.write(file,image) 129 | 130 | def read(self, path): 131 | try: 132 | reader = png.Reader(path) 133 | width, height, data, meta = reader.read() 134 | if self._palette is None: 135 | self._palette = meta['palette'] 136 | image = numpy.vstack(data) 137 | self._height, self._width = image.shape 138 | except png.FormatError: 139 | image = numpy.zeros((self._height, self._width)) 140 | self.save(image,path) 141 | 142 | return image 143 | 144 | 145 | class ImageSaveHelper(threading.Thread): 146 | """ImageSaveHelper. 
Expects that (numpy array, path, method) is queued """ 147 | 148 | def __init__(self, queueSize=100000): 149 | super(ImageSaveHelper, self).__init__() 150 | self._alive = True 151 | self._queue = queue.Queue(queueSize) 152 | self.start() 153 | 154 | @property 155 | def alive(self): 156 | return self._alive 157 | 158 | @alive.setter 159 | def alive(self, alive): 160 | self._alive = alive 161 | 162 | @property 163 | def queue(self): 164 | return self._queue 165 | 166 | def kill(self): 167 | self._alive = False 168 | 169 | def enqueue(self, datatuple): 170 | ret = True 171 | try: 172 | self._queue.put(datatuple, block=False) 173 | except queue.Full: 174 | print("ImageSaveHelper - enqueue full") 175 | ret = False 176 | return ret 177 | 178 | def run(self): 179 | while True: 180 | while not self._queue.empty(): 181 | # Get an image from the queue 182 | args, method = self._queue.get(block=False, timeout=2) 183 | #print(fullpath) 184 | 185 | # Save image 186 | method.save(*args) 187 | 188 | self._queue.task_done() 189 | 190 | if not self._alive and self._queue.empty(): 191 | break 192 | 193 | time.sleep(0.001) 194 | -------------------------------------------------------------------------------- /utils/s: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/stats.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self): 5 | self.clear() 6 | 7 | def reset(self): 8 | self.avg = 0 9 | self.val = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def clear(self): 14 | self.reset() 15 | self.history = [] 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | if self.count > 0: 22 | self.avg = self.sum / self.count 23 | else: 24 | self.avg = 'nan' 25 | 26 | def new_epoch(self): 27 | self.history.append(self.avg) 28 | self.reset() 29 | 30 | -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_intersection_over_union(predictions, gt): 5 | """ Calculates the class intersections over unions of two tensors 6 | Args: 7 | predictions (Tensor): of size (nsamples,nclasses,H,W) 8 | gt (Tensor): Ground truth segmentations, of size 9 | (nsamples,H,W) 10 | Returns: 11 | Tensor: of size (nsamples,nclasses) with error for each class 12 | """ 13 | nsamples,nclasses,height,width = predictions.size() 14 | assert gt.size(0) == nsamples, "gt size: {}, predictions size: {}".format(gt.size(), predictions.size()) 15 | assert gt.size(1) == height, "gt size: {}, predictions size: {}".format(gt.size(), predictions.size()) 16 | assert gt.size(2) == width, "gt size: {}, predictions size: {}".format(gt.size(), predictions.size()) 17 | prediction_max, prediction_argmax = predictions.max(-3) 18 | prediction_argmax = prediction_argmax.long() 19 | classes = gt.new_tensor([c for c in range(nclasses)]).view(1, nclasses, 1, 1) # [1,K,1,1] 20 | pred_bin = (prediction_argmax.view(nsamples, 1, height, width) == classes) # [N,K,H,W] 21 | gt_bin = (gt.view(nsamples, 1, height, width) == classes) # [N,K,H,W] 22 | intersection = (pred_bin * gt_bin).float().sum(dim=-2).sum(dim=-1) # [N,K] 23 | union = ((pred_bin + gt_bin) > 
0).float().sum(dim=-2).sum(dim=-1) # [N,K] 24 | assert (intersection > union).sum() == 0 25 | return (intersection + 1e-8) / (union + 1e-8) # [N,K] 26 | 27 | 28 | 29 | def db_eval_boundary(foreground_mask,gt_mask,bound_th=0.008): 30 | """ 31 | Compute mean,recall and decay from per-frame evaluation. 32 | Calculates precision/recall for boundaries between foreground_mask and 33 | gt_mask using morphological operators to speed it up. 34 | 35 | Arguments: 36 | foreground_mask (ndarray): binary segmentation image. 37 | gt_mask (ndarray): binary annotated image. 38 | 39 | Returns: 40 | F (float): boundaries F-measure 41 | P (float): boundaries precision 42 | R (float): boundaries recall 43 | """ 44 | assert np.atleast_3d(foreground_mask).shape[2] == 1 45 | 46 | bound_pix = bound_th if bound_th >= 1 else \ 47 | np.ceil(bound_th*np.linalg.norm(foreground_mask.shape)) 48 | 49 | # Get the pixel boundaries of both masks 50 | fg_boundary = seg2bmap(foreground_mask); 51 | gt_boundary = seg2bmap(gt_mask); 52 | 53 | from skimage.morphology import binary_dilation,disk 54 | 55 | fg_dil = binary_dilation(fg_boundary,disk(bound_pix)) 56 | gt_dil = binary_dilation(gt_boundary,disk(bound_pix)) 57 | 58 | # Get the intersection 59 | gt_match = gt_boundary * fg_dil 60 | fg_match = fg_boundary * gt_dil 61 | 62 | # Area of the intersection 63 | n_fg = np.sum(fg_boundary) 64 | n_gt = np.sum(gt_boundary) 65 | 66 | #% Compute precision and recall 67 | if n_fg == 0 and n_gt > 0: 68 | precision = 1 69 | recall = 0 70 | elif n_fg > 0 and n_gt == 0: 71 | precision = 0 72 | recall = 1 73 | elif n_fg == 0 and n_gt == 0: 74 | precision = 1 75 | recall = 1 76 | else: 77 | precision = np.sum(fg_match)/float(n_fg) 78 | recall = np.sum(gt_match)/float(n_gt) 79 | 80 | # Compute F measure 81 | if precision + recall == 0: 82 | F = 0 83 | else: 84 | F = 2*precision*recall/(precision+recall); 85 | 86 | return F 87 | 88 | def seg2bmap(seg,width=None,height=None): 89 | """ 90 | From a segmentation, compute a binary boundary map with 1 pixel wide 91 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 92 | origin from the actual segment boundary. 93 | 94 | Arguments: 95 | seg : Segments labeled from 1..k. 96 | width : Width of desired bmap <= seg.shape[1] 97 | height : Height of desired bmap <= seg.shape[0] 98 | 99 | Returns: 100 | bmap (ndarray): Binary boundary map. 
101 | 102 | David Martin 103 | January 2003 104 | """ 105 | 106 | seg = seg.astype(bool) 107 | seg[seg>0] = 1 108 | 109 | assert np.atleast_3d(seg).shape[2] == 1 110 | 111 | width = seg.shape[1] if width is None else width 112 | height = seg.shape[0] if height is None else height 113 | 114 | h,w = seg.shape[:2] 115 | 116 | ar1 = float(width) / float(height) 117 | ar2 = float(w) / float(h) 118 |
119 | assert not (width > w or height > h or abs(ar1-ar2) > 0.01),\ 120 | 'Cannot convert %dx%d seg to %dx%d bmap.'%(w,h,width,height) 121 | 122 | e = np.zeros_like(seg) 123 | s = np.zeros_like(seg) 124 | se = np.zeros_like(seg) 125 | 126 | e[:,:-1] = seg[:,1:] 127 | s[:-1,:] = seg[1:,:] 128 | se[:-1,:-1] = seg[1:,1:] 129 | 130 | b = seg^e | seg^s | seg^se 131 | b[-1,:] = seg[-1,:]^e[-1,:] 132 | b[:,-1] = seg[:,-1]^s[:,-1] 133 | b[-1,-1] = 0 134 |
135 | if w == width and h == height: 136 | bmap = b 137 | else: 138 | bmap = np.zeros((height,width)) 139 | for x in range(w): 140 | for y in range(h): 141 | if b[y,x]: 142 | j = int(np.floor(y * float(height) / h)) 143 | i = int(np.floor(x * float(width) / w)) 144 | bmap[j,i] = 1 145 | 146 | return bmap 147 | 148 | def db_eval_iou(annotation,segmentation): 149 | 150 | """ Compute region similarity as the Jaccard Index. 151 | Arguments: 152 | annotation (ndarray): binary annotation map. 153 | segmentation (ndarray): binary segmentation map. 154 | Return: 155 | jaccard (float): region similarity 156 | """ 157 | 158 | annotation = annotation.astype(bool) 159 | segmentation = segmentation.astype(bool) 160 | 161 | if np.isclose(np.sum(annotation),0) and np.isclose(np.sum(segmentation),0): 162 | return 1 163 | else: 164 | return np.sum((annotation & segmentation)) / \ 165 | np.sum((annotation | segmentation),dtype=np.float32) --------------------------------------------------------------------------------
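For readers stepping through `runfiles/eval_DAVIS_graph_memory.py`, the episodic-memory update inside `Run_video` reduces to a frame-index selection rule: keep the first (annotated) frame, spread `record.num_segment` intermediate frames uniformly over the frames processed so far, and always keep the immediately preceding frame. The helper below is a sketch, not part of the repository; `select_memory_frames` is a hypothetical name, and it assumes the deterministic (non-random) branch of `_sample_pair_indices` together with the script's defaults (`num_segment=4`, `num_first_memory=1`, `Mem_every=5`).

```
import numpy as np

def select_memory_frames(t, num_segment=4, num_first_memory=1):
    # Sketch of the frame indices Run_video encodes into memory at step t (t >= 1),
    # assuming the deterministic branch of _sample_pair_indices (centre of each chunk).
    if t == 1:
        return [0]                        # only the annotated first frame is in memory
    if t <= num_segment:
        return [0, t - 1]                 # accumulated first-frame key plus the previous frame
    average_duration = t // num_segment   # record.num_frames = t, new_length = 1
    offsets = np.arange(num_segment) * average_duration + average_duration // 2
    indices = [0] * num_first_memory      # always keep the first (annotated) frame
    indices += [int(i) for i in offsets]  # uniformly spread intermediate frames
    indices += [t - 1]                    # always keep the previous frame
    return indices

print(select_memory_frames(20))           # -> [0, 2, 7, 12, 17, 19]
```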
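The evaluation loop's test-time augmentation (resize to 480p variants, horizontal flip, run the model, undo the flip, average the soft masks, take the arg-max) can be summarised independently of the network. Below is a minimal sketch with a generic `run_fn` callable standing in for `Run_video`; the `(1, K, T, H, W)` tensor layout follows the script, but the helper itself is illustrative and not part of the codebase.

```
import torch
import torch.nn.functional as F

def tta_average(run_fn, Fs, Ms, sizes, use_flip=True):
    # Fs: frames (1, 3, T, H, W); Ms: one-hot masks (1, K, T, H, W).
    # run_fn(frames, masks) must return soft masks shaped like its mask input.
    _, _, _, H, W = Fs.shape
    variants = []
    for size in sizes:
        f = F.interpolate(Fs.squeeze(0).permute(1, 0, 2, 3), size=size,
                          mode='bilinear', align_corners=True).permute(1, 0, 2, 3).unsqueeze(0)
        m = F.interpolate(Ms.squeeze(0).float().permute(1, 0, 2, 3), size=size,
                          mode='nearest').permute(1, 0, 2, 3).unsqueeze(0)
        variants.append((f, m, False))
        if use_flip:
            variants.append((torch.flip(f, [-1]), torch.flip(m, [-1]), True))
    outputs = []
    for f, m, flipped in variants:
        Es = run_fn(f, m)                                     # soft masks at the working size
        Es = F.interpolate(Es.squeeze(0).permute(1, 0, 2, 3), size=(H, W),
                           mode='bilinear', align_corners=True).permute(1, 0, 2, 3).unsqueeze(0)
        if flipped:
            Es = torch.flip(Es, [-1])                         # undo the flip before averaging
        outputs.append(Es)
    Es = torch.stack(outputs).mean(dim=0)                     # (1, K, T, H, W)
    return Es.argmax(dim=1)                                   # (1, T, H, W) label map
```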
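A note on `utils/readsaveimage.py`: the `ImageSaveHelper` docstring says a `(numpy array, path, method)` tuple is queued, but `run()` unpacks each item as `args, method` and calls `method.save(*args)`, so the item that actually works is `((image, path), method)`. A minimal usage sketch, assuming `utils.readsaveimage` is importable and `pypng` is installed; the output path is only for illustration.

```
import numpy as np
from utils.readsaveimage import ImageSaveHelper, ReadSaveDAVISChallengeLabels

saver = ReadSaveDAVISChallengeLabels()        # 4-bit DAVIS palette by default
helper = ImageSaveHelper()                    # background thread starts immediately

mask = np.zeros((480, 854), dtype=np.uint8)   # label map with values in [0, 15]
mask[100:200, 100:300] = 1

# note the nesting: the positional arguments of save() first, then the saver object
helper.enqueue(((mask, './test/example/00000.png'), saver))

helper.kill()                                 # let the thread drain the queue and exit
helper.join()
```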
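Finally, a small sanity check for the two per-frame DAVIS metrics in `utils/tensor_utils.py`: region similarity J (`db_eval_iou`) and boundary accuracy F (`db_eval_boundary`, which needs `scikit-image` and uses a default `bound_th` of 0.008, i.e. 0.8% of the image diagonal). The masks below are synthetic, and the snippet assumes `utils.tensor_utils` is importable from the repository root.

```
import numpy as np
from utils.tensor_utils import db_eval_iou, db_eval_boundary

gt = np.zeros((100, 100), dtype=np.uint8)
gt[20:60, 20:60] = 1                 # 40x40 ground-truth square

pred = np.zeros((100, 100), dtype=np.uint8)
pred[25:65, 25:65] = 1               # the same square shifted by 5 pixels

J = db_eval_iou(gt, pred)            # 35*35 intersection over 1975-pixel union, ~0.620
F = db_eval_boundary(pred, gt)       # boundary F-measure within a ~2-pixel tolerance

print('J = {:.3f}, F = {:.3f}'.format(J, F))
```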