├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── bl_test_dataset.py ├── davis_test_dataset.py ├── fusion_dataset.py ├── onehot_util.py ├── range_transform.py ├── reseed.py └── yv_test_dataset.py ├── davis_processor.py ├── docs ├── css │ ├── materialize.css │ ├── materialize.min.css │ └── style.css ├── dataset.html ├── index.html ├── js │ ├── init.js │ ├── materialize.js │ └── materialize.min.js ├── supplementary.pdf └── video.html ├── download_bl30k.py ├── download_datasets.py ├── download_model.py ├── eval_interactive_davis.py ├── example └── example.mp4 ├── fbrs ├── LICENSE ├── __init__.py ├── controller.py ├── inference │ ├── __init__.py │ ├── clicker.py │ ├── evaluation.py │ ├── predictors │ │ ├── __init__.py │ │ ├── base.py │ │ ├── brs.py │ │ ├── brs_functors.py │ │ └── brs_losses.py │ ├── transforms │ │ ├── __init__.py │ │ ├── base.py │ │ ├── crops.py │ │ ├── flip.py │ │ ├── limit_longest_side.py │ │ └── zoom_in.py │ └── utils.py ├── model │ ├── __init__.py │ ├── initializer.py │ ├── is_deeplab_model.py │ ├── is_hrnet_model.py │ ├── losses.py │ ├── metrics.py │ ├── modeling │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ ├── deeplab_v3.py │ │ ├── hrnet_ocr.py │ │ ├── ocr.py │ │ ├── resnet.py │ │ └── resnetv1b.py │ ├── ops.py │ └── syncbn │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ └── modules │ │ ├── __init__.py │ │ ├── functional │ │ ├── __init__.py │ │ ├── _csrc.py │ │ ├── csrc │ │ │ ├── bn.h │ │ │ ├── cuda │ │ │ │ ├── bn_cuda.cu │ │ │ │ ├── common.h │ │ │ │ └── ext_lib.h │ │ │ └── ext_lib.cpp │ │ └── syncbn.py │ │ └── nn │ │ ├── __init__.py │ │ └── syncbn.py └── utils │ ├── __init__.py │ ├── cython │ ├── __init__.py │ ├── _get_dist_maps.pyx │ ├── _get_dist_maps.pyxbld │ └── dist_maps.py │ ├── misc.py │ └── vis.py ├── generate_fusion.py ├── generation ├── __init__.py ├── blender │ ├── __init__.py │ ├── clean_data.py │ ├── gen_utils.py │ ├── generate_yaml.py │ ├── resize_texture.py │ ├── skybox.json │ ├── texture.json │ └── texture_2.json ├── fusion_generator.py └── test │ ├── image.png │ ├── test_spline.py │ └── test_spline_continuous.py ├── imgs └── framework.jpg ├── inference_core.py ├── interact ├── __init__.py ├── fbrs_controller.py ├── interaction.py ├── interactive_utils.py ├── s2m_controller.py └── timer.py ├── interactive_gui.py ├── model ├── __init__.py ├── aggregate.py ├── attn_network.py ├── fusion_model.py ├── fusion_net.py ├── losses.py ├── propagation │ ├── __init__.py │ ├── mod_resnet.py │ ├── modules.py │ └── prop_net.py └── s2m │ ├── __init__.py │ ├── _deeplab.py │ ├── s2m_network.py │ ├── s2m_resnet.py │ └── utils.py ├── scripts ├── __init__.py ├── resize_length.py └── resize_youtube.py ├── train.py └── util ├── __init__.py ├── cv2palette.py ├── davis_subset.txt ├── hyper_para.py ├── image_saver.py ├── load_subset.py ├── log_integrator.py ├── logger.py ├── palette.py └── tensor_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | saves/ 2 | log/ 3 | output/ 4 | initmodel/ 5 | vis/ 6 | run 7 | .vscode/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # 
Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ho Kei Cheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/bl_test_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/seoungwugoh/STM/blob/master/dataset.py 3 | """ 4 | 5 | import os 6 | from os import path 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import torch 11 | from torchvision import transforms 12 | from torch.utils.data.dataset import Dataset 13 | from dataset.range_transform import im_normalization 14 | from dataset.onehot_util import all_to_onehot 15 | 16 | 17 | class BLTestDataset(Dataset): 18 | def __init__(self, root, subset=None, start=None, end=None): 19 | self.root = root 20 | self.mask_dir = path.join(root, 'Annotations') 21 | self.image_dir = path.join(root, 'JPEGImages') 22 | 23 | self.videos = [] 24 | self.num_frames = {} 25 | for _video in os.listdir(self.image_dir): 26 | if subset is not None and _video not in subset: 27 | continue 28 | 29 | self.videos.append(_video) 30 | self.num_frames[_video] = len(os.listdir(path.join(self.image_dir, _video))) 31 | 32 | self.im_transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | im_normalization, 35 | ]) 36 | 37 | self.videos = sorted(self.videos) 38 | print('Total amount of videos: ', len(self.videos)) 39 | if (start is not None) and (end is not None): 40 | print('Taking crop from %d to %d. ' % (start, end)) 41 | self.videos = self.videos[start:end+1] 42 | print('New size: ', len(self.videos)) 43 | 44 | def __len__(self): 45 | return len(self.videos) 46 | 47 | def __getitem__(self, index): 48 | video = self.videos[index] 49 | info = {} 50 | info['name'] = video 51 | info['num_frames'] = self.num_frames[video] 52 | 53 | images = [] 54 | masks = [] 55 | for f in range(self.num_frames[video]): 56 | img_file = path.join(self.image_dir, video, '{:05d}.jpg'.format(f)) 57 | images.append(self.im_transform(Image.open(img_file).convert('RGB'))) 58 | 59 | mask_file = path.join(self.mask_dir, video, '{:05d}.png'.format(f)) 60 | masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)) 61 | 62 | images = torch.stack(images, 0) 63 | masks = np.stack(masks, 0) 64 | 65 | labels = np.unique(masks) 66 | labels = labels[labels!=0] 67 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 68 | 69 | masks = masks.unsqueeze(2) 70 | 71 | info['labels'] = labels 72 | 73 | data = { 74 | 'rgb': images, 75 | 'gt': masks, 76 | 'info': info, 77 | } 78 | 79 | return data 80 | 81 | -------------------------------------------------------------------------------- /dataset/davis_test_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/seoungwugoh/STM/blob/master/dataset.py 3 | """ 4 | 5 | import os 6 | from os import path 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | from torch.utils.data.dataset import Dataset 14 | from dataset.range_transform import im_normalization 15 | from dataset.onehot_util import all_to_onehot 16 | 17 | 18 | class DAVISTestDataset(Dataset): 19 | def __init__(self, root, 
imset='2017/val.txt', resolution='480p', single_object=False, target_name=None): 20 | self.root = root 21 | self.mask_dir = path.join(root, 'Annotations', resolution) 22 | self.mask480_dir = path.join(root, 'Annotations', '480p') 23 | self.image_dir = path.join(root, 'JPEGImages', resolution) 24 | self.resolution = resolution 25 | _imset_dir = path.join(root, 'ImageSets') 26 | _imset_f = path.join(_imset_dir, imset) 27 | 28 | self.videos = [] 29 | self.num_frames = {} 30 | self.num_objects = {} 31 | self.shape = {} 32 | self.size_480p = {} 33 | with open(path.join(_imset_f), "r") as lines: 34 | for line in lines: 35 | _video = line.rstrip('\n') 36 | if target_name is not None and target_name != _video: 37 | continue 38 | self.videos.append(_video) 39 | self.num_frames[_video] = len(os.listdir(path.join(self.image_dir, _video))) 40 | _mask = np.array(Image.open(path.join(self.mask_dir, _video, '00000.png')).convert("P")) 41 | self.num_objects[_video] = np.max(_mask) 42 | self.shape[_video] = np.shape(_mask) 43 | _mask480 = np.array(Image.open(path.join(self.mask480_dir, _video, '00000.png')).convert("P")) 44 | self.size_480p[_video] = np.shape(_mask480) 45 | 46 | self.single_object = single_object 47 | 48 | if resolution == '480p': 49 | self.im_transform = transforms.Compose([ 50 | transforms.ToTensor(), 51 | im_normalization, 52 | ]) 53 | else: 54 | self.im_transform = transforms.Compose([ 55 | transforms.ToTensor(), 56 | im_normalization, 57 | transforms.Resize(600, interpolation=Image.BICUBIC), 58 | ]) 59 | self.mask_transform = transforms.Compose([ 60 | transforms.Resize(600, interpolation=Image.NEAREST), 61 | ]) 62 | 63 | def __len__(self): 64 | return len(self.videos) 65 | 66 | def __getitem__(self, index): 67 | video = self.videos[index] 68 | info = {} 69 | info['name'] = video 70 | info['num_frames'] = self.num_frames[video] 71 | info['size_480p'] = self.size_480p[video] 72 | 73 | images = [] 74 | masks = [] 75 | for f in range(self.num_frames[video]): 76 | img_file = path.join(self.image_dir, video, '{:05d}.jpg'.format(f)) 77 | images.append(self.im_transform(Image.open(img_file).convert('RGB'))) 78 | 79 | mask_file = path.join(self.mask_dir, video, '{:05d}.png'.format(f)) 80 | if path.exists(mask_file): 81 | masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)) 82 | else: 83 | # Test-set maybe? 
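# (frames without a ground-truth mask, e.g. on the test-dev split where only the first frame is annotated, fall back to an all-zero placeholder with the same shape as the first mask)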
84 | masks.append(np.zeros_like(masks[0])) 85 | 86 | images = torch.stack(images, 0) 87 | masks = np.stack(masks, 0) 88 | 89 | if self.single_object: 90 | labels = [1] 91 | masks = (masks > 0.5).astype(np.uint8) 92 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 93 | else: 94 | labels = np.unique(masks[0]) 95 | labels = labels[labels!=0] 96 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 97 | 98 | if self.resolution != '480p': 99 | masks = self.mask_transform(masks) 100 | masks = masks.unsqueeze(2) 101 | 102 | info['labels'] = labels 103 | 104 | data = { 105 | 'rgb': images, 106 | 'gt': masks, 107 | 'info': info, 108 | } 109 | 110 | return data 111 | 112 | -------------------------------------------------------------------------------- /dataset/onehot_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def all_to_onehot(masks, labels): 5 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 6 | for k, l in enumerate(labels): 7 | Ms[k] = (masks == l).astype(np.uint8) 8 | return Ms 9 | -------------------------------------------------------------------------------- /dataset/range_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | im_mean = (124, 116, 104) 4 | 5 | im_normalization = transforms.Normalize( 6 | mean=[0.485, 0.456, 0.406], 7 | std=[0.229, 0.224, 0.225] 8 | ) 9 | 10 | inv_im_trans = transforms.Normalize( 11 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 12 | std=[1/0.229, 1/0.224, 1/0.225]) 13 | -------------------------------------------------------------------------------- /dataset/reseed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def reseed(seed): 5 | random.seed(seed) 6 | torch.manual_seed(seed) -------------------------------------------------------------------------------- /dataset/yv_test_dataset.py: -------------------------------------------------------------------------------- 1 | # Partially taken from STM's dataloader 2 | 3 | import os 4 | from os import path 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data.dataset import Dataset 9 | from torchvision import transforms 10 | from PIL import Image 11 | import numpy as np 12 | import random 13 | 14 | from dataset.range_transform import im_normalization 15 | 16 | class YouTubeVOSTestDataset(Dataset): 17 | def __init__(self, data_root, split): 18 | self.image_dir = path.join(data_root, 'vos', 'all_frames', split, 'JPEGImages') 19 | self.mask_dir = path.join(data_root, 'vos', split, 'Annotations') 20 | 21 | self.videos = [] 22 | self.shape = {} 23 | self.frames = {} 24 | 25 | vid_list = sorted(os.listdir(self.image_dir)) 26 | # Pre-reading 27 | for vid in vid_list: 28 | frames = sorted(os.listdir(os.path.join(self.image_dir, vid))) 29 | self.frames[vid] = frames 30 | 31 | self.videos.append(vid) 32 | first_mask = os.listdir(path.join(self.mask_dir, vid))[0] 33 | _mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert("P")) 34 | self.shape[vid] = np.shape(_mask) 35 | 36 | self.im_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | im_normalization, 39 | ]) 40 | 41 | # From STM's code 42 | def To_onehot(self, mask, labels): 43 | M = np.zeros((len(labels), mask.shape[0], mask.shape[1]), dtype=np.uint8) 44 | for k, l in 
enumerate(labels): 45 | M[k] = (mask == l).astype(np.uint8) 46 | return M 47 | 48 | def All_to_onehot(self, masks, labels): 49 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 50 | for n in range(masks.shape[0]): 51 | Ms[:,n] = self.To_onehot(masks[n], labels) 52 | return Ms 53 | 54 | def __getitem__(self, idx): 55 | video = self.videos[idx] 56 | info = {} 57 | info['name'] = video 58 | info['num_objects'] = 0 59 | info['frames'] = self.frames[video] 60 | info['size'] = self.shape[video] # Real sizes 61 | info['gt_obj'] = {} # Frames with labelled objects 62 | 63 | vid_im_path = path.join(self.image_dir, video) 64 | vid_gt_path = path.join(self.mask_dir, video) 65 | 66 | frames = self.frames[video] 67 | 68 | images = [] 69 | masks = [] 70 | for i, f in enumerate(frames): 71 | img = Image.open(path.join(vid_im_path, f)).convert('RGB') 72 | images.append(self.im_transform(img)) 73 | 74 | mask_file = path.join(vid_gt_path, f.replace('.jpg','.png')) 75 | if path.exists(mask_file): 76 | masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)) 77 | this_labels = np.unique(masks[-1]) 78 | this_labels = this_labels[this_labels!=0] 79 | info['gt_obj'][i] = this_labels 80 | else: 81 | # Mask not exists -> nothing in it 82 | masks.append(np.zeros(self.shape[video])) 83 | 84 | images = torch.stack(images, 0) 85 | masks = np.stack(masks, 0) 86 | 87 | # Construct the forward and backward mapping table for labels 88 | labels = np.unique(masks).astype(np.uint8) 89 | labels = labels[labels!=0] 90 | info['label_convert'] = {} 91 | info['label_backward'] = {} 92 | idx = 1 93 | for l in labels: 94 | info['label_convert'][l] = idx 95 | info['label_backward'][idx] = l 96 | idx += 1 97 | masks = torch.from_numpy(self.All_to_onehot(masks, labels)).float() 98 | 99 | # images = images.unsqueeze(0) 100 | masks = masks.unsqueeze(2) 101 | 102 | # Resize to 480p 103 | h, w = masks.shape[-2:] 104 | if h > w: 105 | new_size = (h*480//w, 480) 106 | else: 107 | new_size = (480, w*480//h) 108 | images = F.interpolate(images, size=new_size, mode='bicubic', align_corners=False) 109 | masks = F.interpolate(masks, size=(1, *new_size), mode='nearest') 110 | 111 | info['labels'] = labels 112 | 113 | data = { 114 | 'rgb': images, 115 | 'gt': masks, 116 | 'info': info, 117 | } 118 | 119 | return data 120 | 121 | def __len__(self): 122 | return len(self.videos) -------------------------------------------------------------------------------- /davis_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | from davisinteractive.utils.scribbles import scribbles2mask 6 | from inference_core import InferenceCore 7 | 8 | from model.s2m.s2m_network import deeplabv3plus_resnet50 as S2M 9 | from util.tensor_util import pad_divide_by, compute_tensor_iou 10 | from model.aggregate import aggregate_sbg, aggregate_wbg 11 | 12 | class DAVISProcessor: 13 | """ 14 | Acts as the junction between DAVIS interactive track and our inference_core 15 | """ 16 | def __init__(self, prop_net, fuse_net, s2m_net, images, num_objects, device='cuda:0'): 17 | self.s2m_net = s2m_net.to(device, non_blocking=True) 18 | 19 | images, self.pad = pad_divide_by(images, 16, images.shape[-2:]) 20 | self.device = device 21 | 22 | # Padded dimensions 23 | nh, nw = images.shape[-2:] 24 | self.nh, self.nw = nh, nw 25 | 26 | # True dimensions 27 | t = images.shape[1] 28 | h, w = images.shape[-2:] 29 | 30 | self.k = 
num_objects 31 | self.t, self.h, self.w = t, h, w 32 | 33 | self.interacted_count = 0 34 | self.davis_schedule = [2, 5, 7] 35 | 36 | self.processor = InferenceCore(prop_net, fuse_net, images, num_objects, mem_profile=0, device=device) 37 | 38 | def to_mask(self, scribble): 39 | # First we select the only frame with scribble 40 | all_scr = scribble['scribbles'] 41 | # all_scr is a list. len(all_scr) == total number of frames 42 | for idx, s in enumerate(all_scr): 43 | # The only non-empty element in all_scr is the frame that has been interacted with 44 | if len(s) != 0: 45 | scribble['scribbles'] = [s] 46 | # since we break here, idx will remain at the interacted frame and can be used below 47 | break 48 | 49 | # Pass to DAVIS to change the path to an array 50 | scr_mask = scribbles2mask(scribble, (self.h, self.w))[0] 51 | 52 | # Run our S2M 53 | kernel = np.ones((3,3), np.uint8) 54 | mask = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device) 55 | for ki in range(1, self.k+1): 56 | p_srb = (scr_mask==ki).astype(np.uint8) 57 | p_srb = cv2.dilate(p_srb, kernel).astype(np.bool) 58 | 59 | n_srb = ((scr_mask!=ki) * (scr_mask!=-1)).astype(np.uint8) 60 | n_srb = cv2.dilate(n_srb, kernel).astype(np.bool) 61 | 62 | Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device) 63 | Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:]) 64 | 65 | # Use hard mask because we train S2M with such 66 | inputs = torch.cat([self.processor.get_image_buffered(idx), 67 | (self.processor.masks[idx]==ki).to(self.device).float().unsqueeze(0), Rs], 1) 68 | mask[ki-1] = torch.sigmoid(self.s2m_net(inputs)) 69 | mask = aggregate_wbg(mask, keep_bg=True, hard=True) 70 | return mask, idx 71 | 72 | def interact(self, scribble): 73 | mask, idx = self.to_mask(scribble) 74 | 75 | if self.interacted_count == self.davis_schedule[0]: 76 | # Finish the instant interaction loop for this frame 77 | self.davis_schedule = self.davis_schedule[1:] 78 | next_interact = None 79 | out_masks = self.processor.interact(mask, idx) 80 | else: 81 | next_interact = [idx] 82 | out_masks = self.processor.update_mask_only(mask, idx) 83 | 84 | self.interacted_count += 1 85 | 86 | # Trim paddings 87 | if self.pad[2]+self.pad[3] > 0: 88 | out_masks = out_masks[:,self.pad[2]:-self.pad[3],:] 89 | if self.pad[0]+self.pad[1] > 0: 90 | out_masks = out_masks[:,:,self.pad[0]:-self.pad[1]] 91 | 92 | return out_masks, next_interact, idx 93 | -------------------------------------------------------------------------------- /docs/css/style.css: -------------------------------------------------------------------------------- 1 | /* Custom Stylesheet */ 2 | /** 3 | * Use this file to override Materialize files so you can update 4 | * the core Materialize files in the future 5 | * 6 | * Made By MaterializeCSS.com 7 | */ 8 | 9 | .icon-block { 10 | padding: 0 15px; 11 | } 12 | .icon-block .material-icons { 13 | font-size: inherit; 14 | } 15 | 16 | .larger-text { 17 | font-size: x-large; 18 | } -------------------------------------------------------------------------------- /docs/js/init.js: -------------------------------------------------------------------------------- 1 | (function($){ 2 | $(function(){ 3 | 4 | $('.sidenav').sidenav(); 5 | 6 | }); // end of document ready 7 | })(jQuery); // end of jQuery name space 8 | -------------------------------------------------------------------------------- /docs/supplementary.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/docs/supplementary.pdf -------------------------------------------------------------------------------- /download_bl30k.py: -------------------------------------------------------------------------------- 1 | # NO LONGER AVAILABLE 2 | # Use https://doi.org/10.13012/B2IDB-1702934_V1 3 | 4 | # import os 5 | # import gdown 6 | # import tarfile 7 | 8 | 9 | # LICENSE = """ 10 | # This dataset is a derivative of ShapeNet. 11 | # Please read and respect their licenses and terms before use. 12 | # Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag. 13 | # Do not use this dataset for commercial purposes. 14 | # You should cite both ShapeNet and our paper if you use this dataset. 15 | # """ 16 | 17 | # print(LICENSE) 18 | # print('Datasets will be downloaded and extracted to ../BL30K') 19 | # print('The script will download and extract the segment one by one') 20 | # print('You are going to need ~1TB of free disk space') 21 | # reply = input('[y] to confirm, others to exit: ') 22 | # if reply != 'y': 23 | # exit() 24 | 25 | # links = [ 26 | # 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc', 27 | # 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ', 28 | # 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q', 29 | # 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-', 30 | # 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE', 31 | # 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r', 32 | # ] 33 | 34 | # names = [ 35 | # 'BL30K_a.tar', 36 | # 'BL30K_b.tar', 37 | # 'BL30K_c.tar', 38 | # 'BL30K_d.tar', 39 | # 'BL30K_e.tar', 40 | # 'BL30K_f.tar', 41 | # ] 42 | 43 | # for i, link in enumerate(links): 44 | # print('Downloading segment %d/%d ...' % (i, len(links))) 45 | # gdown.download(link, output='../%s' % names[i], quiet=False) 46 | # print('Extracting...') 47 | # with tarfile.open('../%s' % names[i], 'r') as tar_file: 48 | # tar_file.extractall('../%s' % names[i]) 49 | # print('Cleaning up...') 50 | # os.remove('../%s' % names[i]) 51 | 52 | 53 | # print('Done.') 54 | -------------------------------------------------------------------------------- /download_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import zipfile 4 | from scripts import resize_youtube 5 | 6 | 7 | LICENSE = """ 8 | These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets. 9 | Please read and respect their licenses and terms before use. 10 | You should cite the original papers if you use any of the datasets. 
11 | 12 | For BL30K, see download_bl30k.py 13 | 14 | Links: 15 | DUTS: http://saliencydetection.net/duts 16 | HRSOD: https://github.com/yi94code/HRSOD 17 | FSS: https://github.com/HKUSTCV/FSS-1000 18 | ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html 19 | BIG: https://github.com/hkchengrex/CascadePSP 20 | 21 | YouTubeVOS: https://youtube-vos.org 22 | DAVIS: https://davischallenge.org/ 23 | BL30K: https://github.com/hkchengrex/MiVOS 24 | """ 25 | 26 | print(LICENSE) 27 | print('Datasets will be downloaded and extracted to ../YouTube, ../static, ../DAVIS') 28 | reply = input('[y] to confirm, others to exit: ') 29 | if reply != 'y': 30 | exit() 31 | 32 | 33 | # Static data 34 | os.makedirs('../static', exist_ok=True) 35 | print('Downloading static datasets...') 36 | gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', output='../static/static_data.zip', quiet=False) 37 | print('Extracting static datasets...') 38 | with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file: 39 | zip_file.extractall('../static/') 40 | print('Cleaning up static datasets...') 41 | os.remove('../static/static_data.zip') 42 | 43 | 44 | # DAVIS 45 | # Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing 46 | os.makedirs('../DAVIS/2017', exist_ok=True) 47 | 48 | print('Downloading DAVIS 2016...') 49 | gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', output='../DAVIS/DAVIS-data.zip', quiet=False) 50 | 51 | print('Downloading DAVIS 2017 trainval...') 52 | gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', quiet=False) 53 | 54 | print('Downloading DAVIS 2017 testdev...') 55 | gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', quiet=False) 56 | 57 | print('Downloading DAVIS 2017 scribbles...') 58 | gdown.download('https://drive.google.com/uc?id=1JzIQSu36h7dVM8q0VoE4oZJwBXvrZlkl', output='../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', quiet=False) 59 | 60 | print('Extracting DAVIS datasets...') 61 | with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file: 62 | zip_file.extractall('../DAVIS/') 63 | os.rename('../DAVIS/DAVIS', '../DAVIS/2016') 64 | 65 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file: 66 | zip_file.extractall('../DAVIS/2017/') 67 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', 'r') as zip_file: 68 | zip_file.extractall('../DAVIS/2017/') 69 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval') 70 | 71 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file: 72 | zip_file.extractall('../DAVIS/2017/') 73 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev') 74 | 75 | print('Cleaning up DAVIS datasets...') 76 | os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip') 77 | os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip') 78 | os.remove('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip') 79 | os.remove('../DAVIS/DAVIS-data.zip') 80 | 81 | 82 | # YouTubeVOS 83 | os.makedirs('../YouTube', exist_ok=True) 84 | os.makedirs('../YouTube/all_frames', exist_ok=True) 85 | print('Downloading YouTubeVOS train...') 86 | gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', output='../YouTube/train.zip', quiet=False) 87 | print('Downloading YouTubeVOS val...') 88 
| gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', output='../YouTube/valid.zip', quiet=False) 89 | print('Downloading YouTubeVOS all frames valid...') 90 | gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', output='../YouTube/all_frames/valid.zip', quiet=False) 91 | 92 | print('Extracting YouTube datasets...') 93 | with zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file: 94 | zip_file.extractall('../YouTube/') 95 | with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file: 96 | zip_file.extractall('../YouTube/') 97 | with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file: 98 | zip_file.extractall('../YouTube/all_frames') 99 | 100 | print('Cleaning up YouTubeVOS datasets...') 101 | os.remove('../YouTube/train.zip') 102 | os.remove('../YouTube/valid.zip') 103 | os.remove('../YouTube/all_frames/valid.zip') 104 | 105 | print('Resizing YouTubeVOS to 480p...') 106 | resize_youtube.resize_all('../YouTube/train', '../YouTube/train_480p') 107 | 108 | print('Done.') -------------------------------------------------------------------------------- /download_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import urllib.request 4 | 5 | 6 | os.makedirs('saves', exist_ok=True) 7 | print('Downloading propagation model...') 8 | gdown.download('https://drive.google.com/uc?id=19dfbVDndFkboGLHESi8DGtuxF1B21Nm8', output='saves/propagation_model.pth', quiet=False) 9 | 10 | print('Downloading fusion model...') 11 | gdown.download('https://drive.google.com/uc?id=1Lc1lI5-ix4WsCRdipACXgvS3G-o0lMoz', output='saves/fusion.pth', quiet=False) 12 | 13 | print('Downloading s2m model...') 14 | gdown.download('https://drive.google.com/uc?id=1HKwklVey3P2jmmdmrACFlkXtcvNxbKMM', output='saves/s2m.pth', quiet=False) 15 | 16 | print('Downloading fbrs model...') 17 | urllib.request.urlretrieve('https://github.com/saic-vul/fbrs_interactive_segmentation/releases/download/v1.0/resnet50_dh128_lvis.pth', 'saves/fbrs.pth') 18 | 19 | print('Done.') -------------------------------------------------------------------------------- /eval_interactive_davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import numpy as np 8 | from PIL import Image 9 | import cv2 10 | 11 | from model.propagation.prop_net import PropagationNetwork 12 | from model.fusion_net import FusionNet 13 | from model.s2m.s2m_network import deeplabv3plus_resnet50 as S2M 14 | from dataset.davis_test_dataset import DAVISTestDataset 15 | from davis_processor import DAVISProcessor 16 | 17 | from davisinteractive.session.session import DavisInteractiveSession 18 | 19 | """ 20 | Arguments loading 21 | """ 22 | parser = ArgumentParser() 23 | parser.add_argument('--prop_model', default='saves/propagation_model.pth') 24 | parser.add_argument('--fusion_model', default='saves/fusion.pth') 25 | parser.add_argument('--s2m_model', default='saves/s2m.pth') 26 | parser.add_argument('--davis', default='../DAVIS/2017') 27 | parser.add_argument('--output') 28 | parser.add_argument('--save_mask', action='store_true') 29 | 30 | args = parser.parse_args() 31 | 32 | davis_path = args.davis 33 | out_path = args.output 34 | save_mask = args.save_mask 35 | 36 | # Simple setup 37 | os.makedirs(out_path, exist_ok=True) 38 | palette = 
Image.open(path.expanduser(davis_path + '/trainval/Annotations/480p/blackswan/00000.png')).getpalette() 39 | 40 | torch.autograd.set_grad_enabled(False) 41 | 42 | # Setup Dataset 43 | test_dataset = DAVISTestDataset(davis_path+'/trainval', imset='2017/val.txt') 44 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2) 45 | 46 | images = {} 47 | num_objects = {} 48 | # Loads all the images 49 | for data in test_loader: 50 | rgb = data['rgb'] 51 | k = len(data['info']['labels'][0]) 52 | name = data['info']['name'][0] 53 | images[name] = rgb 54 | num_objects[name] = k 55 | print('Finished loading %d sequences.' % len(images)) 56 | 57 | # Load our checkpoint 58 | prop_saved = torch.load(args.prop_model) 59 | prop_model = PropagationNetwork().cuda().eval() 60 | prop_model.load_state_dict(prop_saved) 61 | 62 | fusion_saved = torch.load(args.fusion_model) 63 | fusion_model = FusionNet().cuda().eval() 64 | fusion_model.load_state_dict(fusion_saved) 65 | 66 | s2m_saved = torch.load(args.s2m_model) 67 | s2m_model = S2M().cuda().eval() 68 | s2m_model.load_state_dict(s2m_saved) 69 | 70 | total_iter = 0 71 | user_iter = 0 72 | last_seq = None 73 | pred_masks = None 74 | with DavisInteractiveSession(davis_root=davis_path+'/trainval', report_save_dir='../output', max_nb_interactions=8, max_time=8*30) as sess: 75 | while sess.next(): 76 | sequence, scribbles, new_seq = sess.get_scribbles(only_last=True) 77 | 78 | if new_seq: 79 | if 'processor' in locals(): 80 | # Note that ALL pre-computed features are flushed in this step 81 | # We are not using pre-computed features for the same sequence with different user-id 82 | del processor # Should release some juicy mem 83 | processor = DAVISProcessor(prop_model, fusion_model, s2m_model, images[sequence], num_objects[sequence]) 84 | print(sequence) 85 | 86 | # Save last time 87 | if save_mask: 88 | if pred_masks is not None: 89 | seq_path = path.join(out_path, str(user_iter), last_seq) 90 | os.makedirs(seq_path, exist_ok=True) 91 | for i in range(len(pred_masks)): 92 | img_E = Image.fromarray(pred_masks[i]) 93 | img_E.putpalette(palette) 94 | img_E.save(os.path.join(seq_path, '{:05d}.png'.format(i))) 95 | 96 | if (last_seq is None) or (sequence != last_seq): 97 | last_seq = sequence 98 | user_iter = 0 99 | else: 100 | user_iter += 1 101 | 102 | pred_masks, next_masks, this_idx = processor.interact(scribbles) 103 | sess.submit_masks(pred_masks, next_masks) 104 | 105 | total_iter += 1 106 | 107 | report = sess.get_report() 108 | summary = sess.get_global_summary(save_file=path.join(out_path, 'summary.json')) 109 | -------------------------------------------------------------------------------- /example/example.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/example/example.mp4 -------------------------------------------------------------------------------- /fbrs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/__init__.py -------------------------------------------------------------------------------- /fbrs/controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | 5 | from fbrs.inference import clicker 6 | from fbrs.inference.predictors import 
get_predictor 7 | from fbrs.utils.vis import draw_with_blend_and_clicks 8 | 9 | 10 | class InteractiveController: 11 | def __init__(self, net, device, predictor_params, prob_thresh=0.5): 12 | self.net = net.to(device) 13 | self.prob_thresh = prob_thresh 14 | self.clicker = clicker.Clicker() 15 | self.states = [] 16 | self.probs_history = [] 17 | self.object_count = 0 18 | self._result_mask = None 19 | 20 | self.image = None 21 | self.predictor = None 22 | self.device = device 23 | self.predictor_params = predictor_params 24 | self.reset_predictor() 25 | 26 | def set_image(self, image): 27 | self.image = image 28 | self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8) 29 | self.object_count = 0 30 | self.reset_last_object() 31 | 32 | def add_click(self, x, y, is_positive): 33 | self.states.append({ 34 | 'clicker': self.clicker.get_state(), 35 | 'predictor': self.predictor.get_states() 36 | }) 37 | 38 | click = clicker.Click(is_positive=is_positive, coords=(y, x)) 39 | self.clicker.add_click(click) 40 | pred = self.predictor.get_prediction(self.clicker) 41 | torch.cuda.empty_cache() 42 | 43 | if self.probs_history: 44 | self.probs_history.append((self.probs_history[-1][0], pred)) 45 | else: 46 | self.probs_history.append((torch.zeros_like(pred), pred)) 47 | 48 | def undo_click(self): 49 | if not self.states: 50 | return 51 | 52 | prev_state = self.states.pop() 53 | self.clicker.set_state(prev_state['clicker']) 54 | self.predictor.set_states(prev_state['predictor']) 55 | self.probs_history.pop() 56 | 57 | def partially_finish_object(self): 58 | object_prob = self.current_object_prob 59 | if object_prob is None: 60 | return 61 | 62 | self.probs_history.append((object_prob, torch.zeros_like(object_prob))) 63 | self.states.append(self.states[-1]) 64 | 65 | self.clicker.reset_clicks() 66 | self.reset_predictor() 67 | 68 | def finish_object(self): 69 | object_prob = self.current_object_prob 70 | if object_prob is None: 71 | return 72 | 73 | self.object_count += 1 74 | object_mask = object_prob > self.prob_thresh 75 | self._result_mask[object_mask] = self.object_count 76 | self.reset_last_object() 77 | 78 | def reset_last_object(self): 79 | self.states = [] 80 | self.probs_history = [] 81 | self.clicker.reset_clicks() 82 | self.reset_predictor() 83 | 84 | def reset_predictor(self, predictor_params=None): 85 | if predictor_params is not None: 86 | self.predictor_params = predictor_params 87 | self.predictor = get_predictor(self.net, device=self.device, 88 | **self.predictor_params) 89 | if self.image is not None: 90 | self.predictor.set_input_image(self.image) 91 | 92 | @property 93 | def current_object_prob(self): 94 | if self.probs_history: 95 | current_prob_total, current_prob_additive = self.probs_history[-1] 96 | return torch.maximum(current_prob_total, current_prob_additive) 97 | else: 98 | return None 99 | 100 | @property 101 | def is_incomplete_mask(self): 102 | return len(self.probs_history) > 0 103 | 104 | @property 105 | def result_mask(self): 106 | return self._result_mask.clone() 107 | -------------------------------------------------------------------------------- /fbrs/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/inference/__init__.py -------------------------------------------------------------------------------- /fbrs/inference/clicker.py: 
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | from copy import deepcopy 5 | from scipy.ndimage import distance_transform_edt 6 | 7 | Click = namedtuple('Click', ['is_positive', 'coords']) 8 | 9 | 10 | class Clicker(object): 11 | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1): 12 | if gt_mask is not None: 13 | self.gt_mask = gt_mask == 1 14 | self.not_ignore_mask = gt_mask != ignore_label 15 | else: 16 | self.gt_mask = None 17 | 18 | self.reset_clicks() 19 | 20 | if init_clicks is not None: 21 | for click in init_clicks: 22 | self.add_click(click) 23 | 24 | def make_next_click(self, pred_mask): 25 | assert self.gt_mask is not None 26 | click = self._get_click(pred_mask) 27 | self.add_click(click) 28 | 29 | def get_clicks(self, clicks_limit=None): 30 | return self.clicks_list[:clicks_limit] 31 | 32 | def _get_click(self, pred_mask, padding=True): 33 | fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) 34 | fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) 35 | 36 | if padding: 37 | fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') 38 | fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') 39 | 40 | fn_mask_dt = distance_transform_edt(fn_mask) 41 | fp_mask_dt = distance_transform_edt(fp_mask) 42 | 43 | if padding: 44 | fn_mask_dt = fn_mask_dt[1:-1, 1:-1] 45 | fp_mask_dt = fp_mask_dt[1:-1, 1:-1] 46 | 47 | fn_mask_dt = fn_mask_dt * self.not_clicked_map 48 | fp_mask_dt = fp_mask_dt * self.not_clicked_map 49 | 50 | fn_max_dist = np.max(fn_mask_dt) 51 | fp_max_dist = np.max(fp_mask_dt) 52 | 53 | is_positive = fn_max_dist > fp_max_dist 54 | if is_positive: 55 | coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] 56 | else: 57 | coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] 58 | 59 | return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) 60 | 61 | def add_click(self, click): 62 | coords = click.coords 63 | 64 | if click.is_positive: 65 | self.num_pos_clicks += 1 66 | else: 67 | self.num_neg_clicks += 1 68 | 69 | self.clicks_list.append(click) 70 | if self.gt_mask is not None: 71 | self.not_clicked_map[coords[0], coords[1]] = False 72 | 73 | def _remove_last_click(self): 74 | click = self.clicks_list.pop() 75 | coords = click.coords 76 | 77 | if click.is_positive: 78 | self.num_pos_clicks -= 1 79 | else: 80 | self.num_neg_clicks -= 1 81 | 82 | if self.gt_mask is not None: 83 | self.not_clicked_map[coords[0], coords[1]] = True 84 | 85 | def reset_clicks(self): 86 | if self.gt_mask is not None: 87 | self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool) 88 | 89 | self.num_pos_clicks = 0 90 | self.num_neg_clicks = 0 91 | 92 | self.clicks_list = [] 93 | 94 | def get_state(self): 95 | return deepcopy(self.clicks_list) 96 | 97 | def set_state(self, state): 98 | self.reset_clicks() 99 | for click in state: 100 | self.add_click(click) 101 | 102 | def __len__(self): 103 | return len(self.clicks_list) 104 | -------------------------------------------------------------------------------- /fbrs/inference/evaluation.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from fbrs.inference import utils 7 | from fbrs.inference.clicker import Clicker 8 | 9 | try: 10 | get_ipython() 11 | from 
tqdm import tqdm_notebook as tqdm 12 | except NameError: 13 | from tqdm import tqdm 14 | 15 | 16 | def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs): 17 | all_ious = [] 18 | 19 | start_time = time() 20 | for index in tqdm(range(len(dataset)), leave=False): 21 | sample = dataset.get_sample(index) 22 | item = dataset[index] 23 | 24 | if oracle_eval: 25 | gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32) 26 | gt_mask = gt_mask.unsqueeze(0).unsqueeze(0) 27 | predictor.opt_functor.mask_loss.set_gt_mask(gt_mask) 28 | _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs) 29 | all_ious.append(sample_ious) 30 | end_time = time() 31 | elapsed_time = end_time - start_time 32 | 33 | return all_ious, elapsed_time 34 | 35 | 36 | def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr, 37 | pred_thr=0.49, max_clicks=20): 38 | clicker = Clicker(gt_mask=instances_mask) 39 | pred_mask = np.zeros_like(instances_mask) 40 | ious_list = [] 41 | 42 | with torch.no_grad(): 43 | predictor.set_input_image(image_nd) 44 | 45 | for click_number in range(max_clicks): 46 | clicker.make_next_click(pred_mask) 47 | pred_probs = predictor.get_prediction(clicker) 48 | pred_mask = pred_probs > pred_thr 49 | 50 | iou = utils.get_iou(instances_mask, pred_mask) 51 | ious_list.append(iou) 52 | 53 | if iou >= max_iou_thr: 54 | break 55 | 56 | return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs 57 | -------------------------------------------------------------------------------- /fbrs/inference/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BasePredictor 2 | from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor 3 | from .brs_functors import InputOptimizer, ScaleBiasOptimizer 4 | from fbrs.inference.transforms import ZoomIn 5 | from fbrs.model.is_hrnet_model import DistMapsHRNetModel 6 | 7 | 8 | def get_predictor(net, brs_mode, device, 9 | prob_thresh=0.49, 10 | with_flip=True, 11 | zoom_in_params=dict(), 12 | predictor_params=None, 13 | brs_opt_func_params=None, 14 | lbfgs_params=None): 15 | lbfgs_params_ = { 16 | 'm': 20, 17 | 'factr': 0, 18 | 'pgtol': 1e-8, 19 | 'maxfun': 20, 20 | } 21 | 22 | predictor_params_ = { 23 | 'optimize_after_n_clicks': 1 24 | } 25 | 26 | if zoom_in_params is not None: 27 | zoom_in = ZoomIn(**zoom_in_params) 28 | else: 29 | zoom_in = None 30 | 31 | if lbfgs_params is not None: 32 | lbfgs_params_.update(lbfgs_params) 33 | lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] 34 | 35 | if brs_opt_func_params is None: 36 | brs_opt_func_params = dict() 37 | 38 | if brs_mode == 'NoBRS': 39 | if predictor_params is not None: 40 | predictor_params_.update(predictor_params) 41 | predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) 42 | elif brs_mode.startswith('f-BRS'): 43 | predictor_params_.update({ 44 | 'net_clicks_limit': 8, 45 | }) 46 | if predictor_params is not None: 47 | predictor_params_.update(predictor_params) 48 | 49 | insertion_mode = { 50 | 'f-BRS-A': 'after_c4', 51 | 'f-BRS-B': 'after_aspp', 52 | 'f-BRS-C': 'after_deeplab' 53 | }[brs_mode] 54 | 55 | opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, 56 | with_flip=with_flip, 57 | optimizer_params=lbfgs_params_, 58 | **brs_opt_func_params) 59 | 60 | if isinstance(net, DistMapsHRNetModel): 61 | FeaturePredictor = HRNetFeatureBRSPredictor 62 | insertion_mode = 
{'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] 63 | else: 64 | FeaturePredictor = FeatureBRSPredictor 65 | 66 | predictor = FeaturePredictor(net, device, 67 | opt_functor=opt_functor, 68 | with_flip=with_flip, 69 | insertion_mode=insertion_mode, 70 | zoom_in=zoom_in, 71 | **predictor_params_) 72 | elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': 73 | use_dmaps = brs_mode == 'DistMap-BRS' 74 | 75 | predictor_params_.update({ 76 | 'net_clicks_limit': 5, 77 | }) 78 | if predictor_params is not None: 79 | predictor_params_.update(predictor_params) 80 | 81 | opt_functor = InputOptimizer(prob_thresh=prob_thresh, 82 | with_flip=with_flip, 83 | optimizer_params=lbfgs_params_, 84 | **brs_opt_func_params) 85 | 86 | predictor = InputBRSPredictor(net, device, 87 | optimize_target='dmaps' if use_dmaps else 'rgb', 88 | opt_functor=opt_functor, 89 | with_flip=with_flip, 90 | zoom_in=zoom_in, 91 | **predictor_params_) 92 | else: 93 | raise NotImplementedError 94 | 95 | return predictor 96 | -------------------------------------------------------------------------------- /fbrs/inference/predictors/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from fbrs.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide 5 | 6 | 7 | class BasePredictor(object): 8 | def __init__(self, net, device, 9 | net_clicks_limit=None, 10 | with_flip=False, 11 | zoom_in=None, 12 | max_size=None, 13 | **kwargs): 14 | self.net = net 15 | self.with_flip = with_flip 16 | self.net_clicks_limit = net_clicks_limit 17 | self.original_image = None 18 | self.device = device 19 | self.zoom_in = zoom_in 20 | 21 | self.transforms = [zoom_in] if zoom_in is not None else [] 22 | if max_size is not None: 23 | self.transforms.append(LimitLongestSide(max_size=max_size)) 24 | self.transforms.append(SigmoidForPred()) 25 | if with_flip: 26 | self.transforms.append(AddHorizontalFlip()) 27 | 28 | def set_input_image(self, image_nd): 29 | for transform in self.transforms: 30 | transform.reset() 31 | self.original_image = image_nd.to(self.device) 32 | if len(self.original_image.shape) == 3: 33 | self.original_image = self.original_image.unsqueeze(0) 34 | 35 | def get_prediction(self, clicker): 36 | clicks_list = clicker.get_clicks() 37 | 38 | image_nd, clicks_lists, is_image_changed = self.apply_transforms( 39 | self.original_image, [clicks_list] 40 | ) 41 | 42 | pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) 43 | prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, 44 | size=image_nd.size()[2:]) 45 | 46 | for t in reversed(self.transforms): 47 | prediction = t.inv_transform(prediction) 48 | 49 | if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): 50 | print('zooming') 51 | return self.get_prediction(clicker) 52 | 53 | # return prediction.cpu().numpy()[0, 0] 54 | return prediction 55 | 56 | def _get_prediction(self, image_nd, clicks_lists, is_image_changed): 57 | points_nd = self.get_points_nd(clicks_lists) 58 | return self.net(image_nd, points_nd)['instances'] 59 | 60 | def _get_transform_states(self): 61 | return [x.get_state() for x in self.transforms] 62 | 63 | def _set_transform_states(self, states): 64 | assert len(states) == len(self.transforms) 65 | for state, transform in zip(states, self.transforms): 66 | transform.set_state(state) 67 | 68 | def apply_transforms(self, image_nd, clicks_lists): 69 | 
is_image_changed = False 70 | for t in self.transforms: 71 | image_nd, clicks_lists = t.transform(image_nd, clicks_lists) 72 | is_image_changed |= t.image_changed 73 | 74 | return image_nd, clicks_lists, is_image_changed 75 | 76 | def get_points_nd(self, clicks_lists): 77 | total_clicks = [] 78 | num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] 79 | num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] 80 | num_max_points = max(num_pos_clicks + num_neg_clicks) 81 | if self.net_clicks_limit is not None: 82 | num_max_points = min(self.net_clicks_limit, num_max_points) 83 | num_max_points = max(1, num_max_points) 84 | 85 | for clicks_list in clicks_lists: 86 | clicks_list = clicks_list[:self.net_clicks_limit] 87 | pos_clicks = [click.coords for click in clicks_list if click.is_positive] 88 | pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)] 89 | 90 | neg_clicks = [click.coords for click in clicks_list if not click.is_positive] 91 | neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)] 92 | total_clicks.append(pos_clicks + neg_clicks) 93 | 94 | return torch.tensor(total_clicks, device=self.device) 95 | 96 | def get_states(self): 97 | return {'transform_states': self._get_transform_states()} 98 | 99 | def set_states(self, states): 100 | self._set_transform_states(states['transform_states']) 101 | -------------------------------------------------------------------------------- /fbrs/inference/predictors/brs_functors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from fbrs.model.metrics import _compute_iou 5 | from .brs_losses import BRSMaskLoss 6 | 7 | 8 | class BaseOptimizer: 9 | def __init__(self, optimizer_params, 10 | prob_thresh=0.49, 11 | reg_weight=1e-3, 12 | min_iou_diff=0.01, 13 | brs_loss=BRSMaskLoss(), 14 | with_flip=False, 15 | flip_average=False, 16 | **kwargs): 17 | self.brs_loss = brs_loss 18 | self.optimizer_params = optimizer_params 19 | self.prob_thresh = prob_thresh 20 | self.reg_weight = reg_weight 21 | self.min_iou_diff = min_iou_diff 22 | self.with_flip = with_flip 23 | self.flip_average = flip_average 24 | 25 | self.best_prediction = None 26 | self._get_prediction_logits = None 27 | self._opt_shape = None 28 | self._best_loss = None 29 | self._click_masks = None 30 | self._last_mask = None 31 | self.device = None 32 | 33 | def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): 34 | self.best_prediction = None 35 | self._get_prediction_logits = get_prediction_logits 36 | self._click_masks = (pos_mask, neg_mask) 37 | self._opt_shape = shape 38 | self._last_mask = None 39 | self.device = device 40 | 41 | def __call__(self, x): 42 | opt_params = torch.from_numpy(x).float().to(self.device) 43 | opt_params.requires_grad_(True) 44 | 45 | with torch.enable_grad(): 46 | opt_vars, reg_loss = self.unpack_opt_params(opt_params) 47 | result_before_sigmoid = self._get_prediction_logits(*opt_vars) 48 | result = torch.sigmoid(result_before_sigmoid) 49 | 50 | pos_mask, neg_mask = self._click_masks 51 | if self.with_flip and self.flip_average: 52 | result, result_flipped = torch.chunk(result, 2, dim=0) 53 | result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) 54 | pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] 55 | 56 | loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) 57 | loss = 
loss + reg_loss 58 | 59 | f_val = loss.detach().cpu().numpy() 60 | if self.best_prediction is None or f_val < self._best_loss: 61 | self.best_prediction = result_before_sigmoid.detach() 62 | self._best_loss = f_val 63 | 64 | if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: 65 | return [f_val, np.zeros_like(x)] 66 | 67 | current_mask = result > self.prob_thresh 68 | if self._last_mask is not None and self.min_iou_diff > 0: 69 | diff_iou = _compute_iou(current_mask, self._last_mask) 70 | if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: 71 | return [f_val, np.zeros_like(x)] 72 | self._last_mask = current_mask 73 | 74 | loss.backward() 75 | f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) 76 | 77 | return [f_val, f_grad] 78 | 79 | def unpack_opt_params(self, opt_params): 80 | raise NotImplementedError 81 | 82 | 83 | class InputOptimizer(BaseOptimizer): 84 | def unpack_opt_params(self, opt_params): 85 | opt_params = opt_params.view(self._opt_shape) 86 | if self.with_flip: 87 | opt_params_flipped = torch.flip(opt_params, dims=[3]) 88 | opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) 89 | reg_loss = self.reg_weight * torch.sum(opt_params**2) 90 | 91 | return (opt_params,), reg_loss 92 | 93 | 94 | class ScaleBiasOptimizer(BaseOptimizer): 95 | def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): 96 | super().__init__(*args, **kwargs) 97 | self.scale_act = scale_act 98 | self.reg_bias_weight = reg_bias_weight 99 | 100 | def unpack_opt_params(self, opt_params): 101 | scale, bias = torch.chunk(opt_params, 2, dim=0) 102 | reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) 103 | 104 | if self.scale_act == 'tanh': 105 | scale = torch.tanh(scale) 106 | elif self.scale_act == 'sin': 107 | scale = torch.sin(scale) 108 | 109 | return (1 + scale, bias), reg_loss 110 | -------------------------------------------------------------------------------- /fbrs/inference/predictors/brs_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fbrs.model.losses import SigmoidBinaryCrossEntropyLoss 4 | 5 | 6 | class BRSMaskLoss(torch.nn.Module): 7 | def __init__(self, eps=1e-5): 8 | super().__init__() 9 | self._eps = eps 10 | 11 | def forward(self, result, pos_mask, neg_mask): 12 | pos_diff = (1 - result) * pos_mask 13 | pos_target = torch.sum(pos_diff ** 2) 14 | pos_target = pos_target / (torch.sum(pos_mask) + self._eps) 15 | 16 | neg_diff = result * neg_mask 17 | neg_target = torch.sum(neg_diff ** 2) 18 | neg_target = neg_target / (torch.sum(neg_mask) + self._eps) 19 | 20 | loss = pos_target + neg_target 21 | 22 | with torch.no_grad(): 23 | f_max_pos = torch.max(torch.abs(pos_diff)).item() 24 | f_max_neg = torch.max(torch.abs(neg_diff)).item() 25 | 26 | return loss, f_max_pos, f_max_neg 27 | 28 | 29 | class OracleMaskLoss(torch.nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self.gt_mask = None 33 | self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) 34 | self.predictor = None 35 | self.history = [] 36 | 37 | def set_gt_mask(self, gt_mask): 38 | self.gt_mask = gt_mask 39 | self.history = [] 40 | 41 | def forward(self, result, pos_mask, neg_mask): 42 | gt_mask = self.gt_mask.to(result.device) 43 | if self.predictor.object_roi is not None: 44 | r1, r2, c1, c2 = self.predictor.object_roi[:4] 45 | gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] 46 | gt_mask = torch.nn.functional.interpolate(gt_mask, 
result.size()[2:], mode='bilinear', align_corners=True) 47 | 48 | if result.shape[0] == 2: 49 | gt_mask_flipped = torch.flip(gt_mask, dims=[3]) 50 | gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) 51 | 52 | loss = self.loss(result, gt_mask) 53 | self.history.append(loss.detach().cpu().numpy()[0]) 54 | 55 | if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: 56 | return 0, 0, 0 57 | 58 | return loss, 1.0, 1.0 59 | -------------------------------------------------------------------------------- /fbrs/inference/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SigmoidForPred 2 | from .flip import AddHorizontalFlip 3 | from .zoom_in import ZoomIn 4 | from .limit_longest_side import LimitLongestSide 5 | from .crops import Crops 6 | -------------------------------------------------------------------------------- /fbrs/inference/transforms/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseTransform(object): 5 | def __init__(self): 6 | self.image_changed = False 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | raise NotImplementedError 10 | 11 | def inv_transform(self, prob_map): 12 | raise NotImplementedError 13 | 14 | def reset(self): 15 | raise NotImplementedError 16 | 17 | def get_state(self): 18 | raise NotImplementedError 19 | 20 | def set_state(self, state): 21 | raise NotImplementedError 22 | 23 | 24 | class SigmoidForPred(BaseTransform): 25 | def transform(self, image_nd, clicks_lists): 26 | return image_nd, clicks_lists 27 | 28 | def inv_transform(self, prob_map): 29 | return torch.sigmoid(prob_map) 30 | 31 | def reset(self): 32 | pass 33 | 34 | def get_state(self): 35 | return None 36 | 37 | def set_state(self, state): 38 | pass 39 | -------------------------------------------------------------------------------- /fbrs/inference/transforms/crops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from fbrs.inference.clicker import Click 7 | from .base import BaseTransform 8 | 9 | 10 | class Crops(BaseTransform): 11 | def __init__(self, crop_size=(320, 480), min_overlap=0.2): 12 | super().__init__() 13 | self.crop_height, self.crop_width = crop_size 14 | self.min_overlap = min_overlap 15 | 16 | self.x_offsets = None 17 | self.y_offsets = None 18 | self._counts = None 19 | 20 | def transform(self, image_nd, clicks_lists): 21 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 22 | image_height, image_width = image_nd.shape[2:4] 23 | self._counts = None 24 | 25 | if image_height < self.crop_height or image_width < self.crop_width: 26 | return image_nd, clicks_lists 27 | 28 | self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) 29 | self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) 30 | self._counts = np.zeros((image_height, image_width)) 31 | 32 | image_crops = [] 33 | for dy in self.y_offsets: 34 | for dx in self.x_offsets: 35 | self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 36 | image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] 37 | image_crops.append(image_crop) 38 | image_crops = torch.cat(image_crops, dim=0) 39 | self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) 40 | 41 | clicks_list = clicks_lists[0] 42 | clicks_lists = [] 43 | for dy in 
self.y_offsets: 44 | for dx in self.x_offsets: 45 | crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx)) 46 | for x in clicks_list] 47 | clicks_lists.append(crop_clicks) 48 | 49 | return image_crops, clicks_lists 50 | 51 | def inv_transform(self, prob_map): 52 | if self._counts is None: 53 | return prob_map 54 | 55 | new_prob_map = torch.zeros((1, 1, *self._counts.shape), 56 | dtype=prob_map.dtype, device=prob_map.device) 57 | 58 | crop_indx = 0 59 | for dy in self.y_offsets: 60 | for dx in self.x_offsets: 61 | new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] 62 | crop_indx += 1 63 | new_prob_map = torch.div(new_prob_map, self._counts) 64 | 65 | return new_prob_map 66 | 67 | def get_state(self): 68 | return self.x_offsets, self.y_offsets, self._counts 69 | 70 | def set_state(self, state): 71 | self.x_offsets, self.y_offsets, self._counts = state 72 | 73 | def reset(self): 74 | self.x_offsets = None 75 | self.y_offsets = None 76 | self._counts = None 77 | 78 | 79 | def get_offsets(length, crop_size, min_overlap_ratio=0.2): 80 | if length == crop_size: 81 | return [0] 82 | 83 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) 84 | N = math.ceil(N) 85 | 86 | overlap_ratio = (N - length / crop_size) / (N - 1) 87 | overlap_width = int(crop_size * overlap_ratio) 88 | 89 | offsets = [0] 90 | for i in range(1, N): 91 | new_offset = offsets[-1] + crop_size - overlap_width 92 | if new_offset + crop_size > length: 93 | new_offset = length - crop_size 94 | 95 | offsets.append(new_offset) 96 | 97 | return offsets 98 | -------------------------------------------------------------------------------- /fbrs/inference/transforms/flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fbrs.inference.clicker import Click 4 | from .base import BaseTransform 5 | 6 | 7 | class AddHorizontalFlip(BaseTransform): 8 | def transform(self, image_nd, clicks_lists): 9 | assert len(image_nd.shape) == 4 10 | image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) 11 | 12 | image_width = image_nd.shape[3] 13 | clicks_lists_flipped = [] 14 | for clicks_list in clicks_lists: 15 | clicks_list_flipped = [Click(is_positive=click.is_positive, 16 | coords=(click.coords[0], image_width - click.coords[1] - 1)) 17 | for click in clicks_list] 18 | clicks_lists_flipped.append(clicks_list_flipped) 19 | clicks_lists = clicks_lists + clicks_lists_flipped 20 | 21 | return image_nd, clicks_lists 22 | 23 | def inv_transform(self, prob_map): 24 | assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 25 | num_maps = prob_map.shape[0] // 2 26 | prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] 27 | 28 | return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) 29 | 30 | def get_state(self): 31 | return None 32 | 33 | def set_state(self, state): 34 | pass 35 | 36 | def reset(self): 37 | pass 38 | -------------------------------------------------------------------------------- /fbrs/inference/transforms/limit_longest_side.py: -------------------------------------------------------------------------------- 1 | from .zoom_in import ZoomIn, get_roi_image_nd 2 | 3 | 4 | class LimitLongestSide(ZoomIn): 5 | def __init__(self, max_size=800): 6 | super().__init__(target_size=max_size, skip_clicks=0) 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 10 | image_max_size = 
max(image_nd.shape[2:4]) 11 | self.image_changed = False 12 | 13 | if image_max_size <= self.target_size: 14 | return image_nd, clicks_lists 15 | self._input_image = image_nd 16 | 17 | self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) 18 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) 19 | self.image_changed = True 20 | 21 | tclicks_lists = [self._transform_clicks(clicks_lists[0])] 22 | return self._roi_image, tclicks_lists 23 | -------------------------------------------------------------------------------- /fbrs/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/model/__init__.py -------------------------------------------------------------------------------- /fbrs/model/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Initializer(object): 7 | def __init__(self, local_init=True, gamma=None): 8 | self.local_init = local_init 9 | self.gamma = gamma 10 | 11 | def __call__(self, m): 12 | if getattr(m, '__initialized', False): 13 | return 14 | 15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: 18 | if m.weight is not None: 19 | self._init_gamma(m.weight.data) 20 | if m.bias is not None: 21 | self._init_beta(m.bias.data) 22 | else: 23 | if getattr(m, 'weight', None) is not None: 24 | self._init_weight(m.weight.data) 25 | if getattr(m, 'bias', None) is not None: 26 | self._init_bias(m.bias.data) 27 | 28 | if self.local_init: 29 | object.__setattr__(m, '__initialized', True) 30 | 31 | def _init_weight(self, data): 32 | nn.init.uniform_(data, -0.07, 0.07) 33 | 34 | def _init_bias(self, data): 35 | nn.init.constant_(data, 0) 36 | 37 | def _init_gamma(self, data): 38 | if self.gamma is None: 39 | nn.init.constant_(data, 1.0) 40 | else: 41 | nn.init.normal_(data, 1.0, self.gamma) 42 | 43 | def _init_beta(self, data): 44 | nn.init.constant_(data, 0) 45 | 46 | 47 | class Bilinear(Initializer): 48 | def __init__(self, scale, groups, in_channels, **kwargs): 49 | super().__init__(**kwargs) 50 | self.scale = scale 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | 54 | def _init_weight(self, data): 55 | """Reset the weight and bias.""" 56 | bilinear_kernel = self.get_bilinear_kernel(self.scale) 57 | weight = torch.zeros_like(data) 58 | for i in range(self.in_channels): 59 | if self.groups == 1: 60 | j = i 61 | else: 62 | j = 0 63 | weight[i, j] = bilinear_kernel 64 | data[:] = weight 65 | 66 | @staticmethod 67 | def get_bilinear_kernel(scale): 68 | """Generate a bilinear upsampling kernel.""" 69 | kernel_size = 2 * scale - scale % 2 70 | scale = (kernel_size + 1) // 2 71 | center = scale - 0.5 * (1 + kernel_size % 2) 72 | 73 | og = np.ogrid[:kernel_size, :kernel_size] 74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) 75 | 76 | return torch.tensor(kernel, dtype=torch.float32) 77 | 78 | 79 | class XavierGluon(Initializer): 80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): 81 | super().__init__(**kwargs) 82 | 83 | self.rnd_type = rnd_type 84 | self.factor_type = factor_type 85 | self.magnitude = float(magnitude) 86 | 87 | def 
_init_weight(self, arr): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) 89 | 90 | if self.factor_type == 'avg': 91 | factor = (fan_in + fan_out) / 2.0 92 | elif self.factor_type == 'in': 93 | factor = fan_in 94 | elif self.factor_type == 'out': 95 | factor = fan_out 96 | else: 97 | raise ValueError('Incorrect factor type') 98 | scale = np.sqrt(self.magnitude / factor) 99 | 100 | if self.rnd_type == 'uniform': 101 | nn.init.uniform_(arr, -scale, scale) 102 | elif self.rnd_type == 'gaussian': 103 | nn.init.normal_(arr, 0, scale) 104 | else: 105 | raise ValueError('Unknown random type') 106 | -------------------------------------------------------------------------------- /fbrs/model/is_deeplab_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fbrs.model.ops import DistMaps 5 | from .modeling.deeplab_v3 import DeepLabV3Plus 6 | from .modeling.basic_blocks import SepConvHead 7 | 8 | 9 | def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, 10 | norm_layer=nn.BatchNorm2d, backbone_norm_layer=None, 11 | use_rgb_conv=True, cpu_dist_maps=False, 12 | norm_radius=260): 13 | model = DistMapsModel( 14 | feature_extractor=DeepLabV3Plus(backbone=backbone, 15 | ch=deeplab_ch, 16 | project_dropout=aspp_dropout, 17 | norm_layer=norm_layer, 18 | backbone_norm_layer=backbone_norm_layer), 19 | head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, 20 | num_layers=2, norm_layer=norm_layer), 21 | use_rgb_conv=use_rgb_conv, 22 | norm_layer=norm_layer, 23 | norm_radius=norm_radius, 24 | cpu_dist_maps=cpu_dist_maps 25 | ) 26 | 27 | return model 28 | 29 | 30 | class DistMapsModel(nn.Module): 31 | def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True, 32 | cpu_dist_maps=False, norm_radius=260): 33 | super(DistMapsModel, self).__init__() 34 | 35 | if use_rgb_conv: 36 | self.rgb_conv = nn.Sequential( 37 | nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), 38 | nn.LeakyReLU(negative_slope=0.2), 39 | norm_layer(8), 40 | nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), 41 | ) 42 | else: 43 | self.rgb_conv = None 44 | 45 | self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, 46 | cpu_mode=cpu_dist_maps) 47 | self.feature_extractor = feature_extractor 48 | self.head = head 49 | 50 | def forward(self, image, points): 51 | coord_features = self.dist_maps(image, points) 52 | 53 | if self.rgb_conv is not None: 54 | x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) 55 | else: 56 | c1, c2 = torch.chunk(coord_features, 2, dim=1) 57 | c3 = torch.ones_like(c1) 58 | coord_features = torch.cat((c1, c2, c3), dim=1) 59 | x = 0.8 * image * coord_features + 0.2 * image 60 | 61 | backbone_features = self.feature_extractor(x) 62 | instance_out = self.head(backbone_features[0]) 63 | instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], 64 | mode='bilinear', align_corners=True) 65 | 66 | return {'instances': instance_out} 67 | 68 | def load_weights(self, path_to_weights): 69 | current_state_dict = self.state_dict() 70 | new_state_dict = torch.load(path_to_weights, map_location='cpu') 71 | current_state_dict.update(new_state_dict) 72 | self.load_state_dict(current_state_dict) 73 | 74 | def get_trainable_params(self): 75 | backbone_params = nn.ParameterList() 76 | other_params = nn.ParameterList() 77 | 78 | for name, param in self.named_parameters(): 79 | if param.requires_grad: 80 | if 
'backbone' in name: 81 | backbone_params.append(param) 82 | else: 83 | other_params.append(param) 84 | return backbone_params, other_params 85 | 86 | 87 | -------------------------------------------------------------------------------- /fbrs/model/is_hrnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fbrs.model.ops import DistMaps 5 | from .modeling.hrnet_ocr import HighResolutionNet 6 | 7 | 8 | def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260, 9 | use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False, 10 | norm_layer=nn.BatchNorm2d): 11 | model = DistMapsHRNetModel( 12 | feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small, 13 | num_classes=1, norm_layer=norm_layer), 14 | use_rgb_conv=use_rgb_conv, 15 | with_aux_output=with_aux_output, 16 | norm_layer=norm_layer, 17 | norm_radius=norm_radius, 18 | cpu_dist_maps=cpu_dist_maps 19 | ) 20 | 21 | return model 22 | 23 | 24 | class DistMapsHRNetModel(nn.Module): 25 | def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False, 26 | norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False): 27 | super(DistMapsHRNetModel, self).__init__() 28 | self.with_aux_output = with_aux_output 29 | 30 | if use_rgb_conv: 31 | self.rgb_conv = nn.Sequential( 32 | nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), 33 | nn.LeakyReLU(negative_slope=0.2), 34 | norm_layer(8), 35 | nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), 36 | ) 37 | else: 38 | self.rgb_conv = None 39 | 40 | self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps) 41 | self.feature_extractor = feature_extractor 42 | 43 | def forward(self, image, points): 44 | coord_features = self.dist_maps(image, points) 45 | 46 | if self.rgb_conv is not None: 47 | x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) 48 | else: 49 | c1, c2 = torch.chunk(coord_features, 2, dim=1) 50 | c3 = torch.ones_like(c1) 51 | coord_features = torch.cat((c1, c2, c3), dim=1) 52 | x = 0.8 * image * coord_features + 0.2 * image 53 | 54 | feature_extractor_out = self.feature_extractor(x) 55 | instance_out = feature_extractor_out[0] 56 | instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], 57 | mode='bilinear', align_corners=True) 58 | outputs = {'instances': instance_out} 59 | if self.with_aux_output: 60 | instance_aux_out = feature_extractor_out[1] 61 | instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:], 62 | mode='bilinear', align_corners=True) 63 | outputs['instances_aux'] = instance_aux_out 64 | 65 | return outputs 66 | 67 | def load_weights(self, path_to_weights): 68 | current_state_dict = self.state_dict() 69 | new_state_dict = torch.load(path_to_weights) 70 | current_state_dict.update(new_state_dict) 71 | self.load_state_dict(current_state_dict) 72 | 73 | def get_trainable_params(self): 74 | backbone_params = nn.ParameterList() 75 | other_params = nn.ParameterList() 76 | other_params_keys = [] 77 | nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head'] 78 | 79 | for name, param in self.named_parameters(): 80 | if param.requires_grad: 81 | if any(x in name for x in nonbackbone_keywords): 82 | other_params.append(param) 83 | other_params_keys.append(name) 84 | else: 85 | backbone_params.append(param) 86 | print('Nonbackbone params:', sorted(other_params_keys)) 87 | return 
backbone_params, other_params 88 | -------------------------------------------------------------------------------- /fbrs/model/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from fbrs.utils import misc 7 | 8 | 9 | class NormalizedFocalLossSigmoid(nn.Module): 10 | def __init__(self, axis=-1, alpha=0.25, gamma=2, 11 | from_logits=False, batch_axis=0, 12 | weight=None, size_average=True, detach_delimeter=True, 13 | eps=1e-12, scale=1.0, 14 | ignore_label=-1): 15 | super(NormalizedFocalLossSigmoid, self).__init__() 16 | self._axis = axis 17 | self._alpha = alpha 18 | self._gamma = gamma 19 | self._ignore_label = ignore_label 20 | self._weight = weight if weight is not None else 1.0 21 | self._batch_axis = batch_axis 22 | 23 | self._scale = scale 24 | self._from_logits = from_logits 25 | self._eps = eps 26 | self._size_average = size_average 27 | self._detach_delimeter = detach_delimeter 28 | self._k_sum = 0 29 | 30 | def forward(self, pred, label, sample_weight=None): 31 | one_hot = label > 0 32 | sample_weight = label != self._ignore_label 33 | 34 | if not self._from_logits: 35 | pred = torch.sigmoid(pred) 36 | 37 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) 38 | pt = torch.where(one_hot, pred, 1 - pred) 39 | pt = torch.where(sample_weight, pt, torch.ones_like(pt)) 40 | 41 | beta = (1 - pt) ** self._gamma 42 | 43 | sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) 44 | beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) 45 | mult = sw_sum / (beta_sum + self._eps) 46 | if self._detach_delimeter: 47 | mult = mult.detach() 48 | beta = beta * mult 49 | 50 | ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() 51 | sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() 52 | if np.any(ignore_area == 0): 53 | self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() 54 | 55 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) 56 | loss = self._weight * (loss * sample_weight) 57 | 58 | if self._size_average: 59 | bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) 60 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) 61 | else: 62 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 63 | 64 | return self._scale * loss 65 | 66 | def log_states(self, sw, name, global_step): 67 | sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) 68 | 69 | 70 | class FocalLoss(nn.Module): 71 | def __init__(self, axis=-1, alpha=0.25, gamma=2, 72 | from_logits=False, batch_axis=0, 73 | weight=None, num_class=None, 74 | eps=1e-9, size_average=True, scale=1.0): 75 | super(FocalLoss, self).__init__() 76 | self._axis = axis 77 | self._alpha = alpha 78 | self._gamma = gamma 79 | self._weight = weight if weight is not None else 1.0 80 | self._batch_axis = batch_axis 81 | 82 | self._scale = scale 83 | self._num_class = num_class 84 | self._from_logits = from_logits 85 | self._eps = eps 86 | self._size_average = size_average 87 | 88 | def forward(self, pred, label, sample_weight=None): 89 | if not self._from_logits: 90 | pred = F.sigmoid(pred) 91 | 92 | one_hot = label > 0 93 | pt = 
torch.where(one_hot, pred, 1 - pred) 94 | 95 | t = label != -1 96 | alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t) 97 | beta = (1 - pt) ** self._gamma 98 | 99 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) 100 | sample_weight = label != -1 101 | 102 | loss = self._weight * (loss * sample_weight) 103 | 104 | if self._size_average: 105 | tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) 106 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) 107 | else: 108 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 109 | 110 | return self._scale * loss 111 | 112 | 113 | class SigmoidBinaryCrossEntropyLoss(nn.Module): 114 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): 115 | super(SigmoidBinaryCrossEntropyLoss, self).__init__() 116 | self._from_sigmoid = from_sigmoid 117 | self._ignore_label = ignore_label 118 | self._weight = weight if weight is not None else 1.0 119 | self._batch_axis = batch_axis 120 | 121 | def forward(self, pred, label): 122 | label = label.view(pred.size()) 123 | sample_weight = label != self._ignore_label 124 | label = torch.where(sample_weight, label, torch.zeros_like(label)) 125 | 126 | if not self._from_sigmoid: 127 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) 128 | else: 129 | eps = 1e-12 130 | loss = -(torch.log(pred + eps) * label 131 | + torch.log(1. - pred + eps) * (1. - label)) 132 | 133 | loss = self._weight * (loss * sample_weight) 134 | return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 135 | -------------------------------------------------------------------------------- /fbrs/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from fbrs.utils import misc 5 | 6 | 7 | class TrainMetric(object): 8 | def __init__(self, pred_outputs, gt_outputs): 9 | self.pred_outputs = pred_outputs 10 | self.gt_outputs = gt_outputs 11 | 12 | def update(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | def get_epoch_value(self): 16 | raise NotImplementedError 17 | 18 | def reset_epoch_stats(self): 19 | raise NotImplementedError 20 | 21 | def log_states(self, sw, tag_prefix, global_step): 22 | pass 23 | 24 | @property 25 | def name(self): 26 | return type(self).__name__ 27 | 28 | 29 | class AdaptiveIoU(TrainMetric): 30 | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, 31 | ignore_label=-1, from_logits=True, 32 | pred_output='instances', gt_output='instances'): 33 | super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) 34 | self._ignore_label = ignore_label 35 | self._from_logits = from_logits 36 | self._iou_thresh = init_thresh 37 | self._thresh_step = thresh_step 38 | self._thresh_beta = thresh_beta 39 | self._iou_beta = iou_beta 40 | self._ema_iou = 0.0 41 | self._epoch_iou_sum = 0.0 42 | self._epoch_batch_count = 0 43 | 44 | def update(self, pred, gt): 45 | gt_mask = gt > 0 46 | if self._from_logits: 47 | pred = torch.sigmoid(pred) 48 | 49 | gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() 50 | if np.all(gt_mask_area == 0): 51 | return 52 | 53 | ignore_mask = gt == self._ignore_label 54 | max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() 55 | 
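# Probe one threshold step below and above the current adaptive threshold and keep
# whichever gives the higher mean IoU on this batch; the running threshold is then
# nudged toward the winner by an exponential moving average.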
best_thresh = self._iou_thresh 56 | for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: 57 | temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() 58 | if temp_iou > max_iou: 59 | max_iou = temp_iou 60 | best_thresh = t 61 | 62 | self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh 63 | self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou 64 | self._epoch_iou_sum += max_iou 65 | self._epoch_batch_count += 1 66 | 67 | def get_epoch_value(self): 68 | if self._epoch_batch_count > 0: 69 | return self._epoch_iou_sum / self._epoch_batch_count 70 | else: 71 | return 0.0 72 | 73 | def reset_epoch_stats(self): 74 | self._epoch_iou_sum = 0.0 75 | self._epoch_batch_count = 0 76 | 77 | def log_states(self, sw, tag_prefix, global_step): 78 | sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) 79 | sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) 80 | 81 | @property 82 | def iou_thresh(self): 83 | return self._iou_thresh 84 | 85 | 86 | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): 87 | if ignore_mask is not None: 88 | pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) 89 | 90 | reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) 91 | union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 92 | intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 93 | nonzero = union > 0 94 | 95 | iou = intersection[nonzero] / union[nonzero] 96 | if not keep_ignore: 97 | return iou 98 | else: 99 | result = np.full_like(intersection, -1) 100 | result[nonzero] = iou 101 | return result 102 | -------------------------------------------------------------------------------- /fbrs/model/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/model/modeling/__init__.py -------------------------------------------------------------------------------- /fbrs/model/modeling/basic_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fbrs.model import ops 4 | 5 | 6 | class ConvHead(nn.Module): 7 | def __init__(self, out_channels, in_channels=32, num_layers=1, 8 | kernel_size=3, padding=1, 9 | norm_layer=nn.BatchNorm2d): 10 | super(ConvHead, self).__init__() 11 | convhead = [] 12 | 13 | for i in range(num_layers): 14 | convhead.extend([ 15 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), 16 | nn.ReLU(), 17 | norm_layer(in_channels) if norm_layer is not None else nn.Identity() 18 | ]) 19 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) 20 | 21 | self.convhead = nn.Sequential(*convhead) 22 | 23 | def forward(self, *inputs): 24 | return self.convhead(inputs[0]) 25 | 26 | 27 | class SepConvHead(nn.Module): 28 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, 29 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, 30 | norm_layer=nn.BatchNorm2d): 31 | super(SepConvHead, self).__init__() 32 | 33 | sepconvhead = [] 34 | 35 | for i in range(num_layers): 36 | sepconvhead.append( 37 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, 38 | out_channels=mid_channels, 39 | dw_kernel=kernel_size, 
dw_padding=padding, 40 | norm_layer=norm_layer, activation='relu') 41 | ) 42 | if dropout_ratio > 0 and dropout_indx == i: 43 | sepconvhead.append(nn.Dropout(dropout_ratio)) 44 | 45 | sepconvhead.append( 46 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) 47 | ) 48 | 49 | self.layers = nn.Sequential(*sepconvhead) 50 | 51 | def forward(self, *inputs): 52 | x = inputs[0] 53 | 54 | return self.layers(x) 55 | 56 | 57 | class SeparableConv2d(nn.Module): 58 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, 59 | activation=None, use_bias=False, norm_layer=None): 60 | super(SeparableConv2d, self).__init__() 61 | _activation = ops.select_activation_function(activation) 62 | self.body = nn.Sequential( 63 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, 64 | padding=dw_padding, bias=use_bias, groups=in_channels), 65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), 66 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 67 | _activation() 68 | ) 69 | 70 | def forward(self, x): 71 | return self.body(x) 72 | -------------------------------------------------------------------------------- /fbrs/model/modeling/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s 3 | 4 | 5 | class ResNetBackbone(torch.nn.Module): 6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): 7 | super(ResNetBackbone, self).__init__() 8 | 9 | if backbone == 'resnet34': 10 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 11 | elif backbone == 'resnet50': 12 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 13 | elif backbone == 'resnet101': 14 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 15 | elif backbone == 'resnet152': 16 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 17 | else: 18 | raise RuntimeError(f'unknown backbone: {backbone}') 19 | 20 | self.conv1 = pretrained.conv1 21 | self.bn1 = pretrained.bn1 22 | self.relu = pretrained.relu 23 | self.maxpool = pretrained.maxpool 24 | self.layer1 = pretrained.layer1 25 | self.layer2 = pretrained.layer2 26 | self.layer3 = pretrained.layer3 27 | self.layer4 = pretrained.layer4 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.bn1(x) 32 | x = self.relu(x) 33 | x = self.maxpool(x) 34 | c1 = self.layer1(x) 35 | c2 = self.layer2(c1) 36 | c3 = self.layer3(c2) 37 | c4 = self.layer4(c3) 38 | 39 | return c1, c2, c3, c4 40 | -------------------------------------------------------------------------------- /fbrs/model/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | import numpy as np 4 | 5 | import fbrs.model.initializer as initializer 6 | from fbrs.utils.cython import get_dist_maps 7 | 8 | 9 | def select_activation_function(activation): 10 | if isinstance(activation, str): 11 | if activation.lower() == 'relu': 12 | return nn.ReLU 13 | elif activation.lower() == 'softplus': 14 | return nn.Softplus 15 | else: 16 | raise ValueError(f"Unknown activation type {activation}") 17 | elif isinstance(activation, nn.Module): 18 | return activation 19 | else: 20 | raise ValueError(f"Unknown activation type {activation}") 21 | 22 | 23 
| class BilinearConvTranspose2d(nn.ConvTranspose2d): 24 | def __init__(self, in_channels, out_channels, scale, groups=1): 25 | kernel_size = 2 * scale - scale % 2 26 | self.scale = scale 27 | 28 | super().__init__( 29 | in_channels, out_channels, 30 | kernel_size=kernel_size, 31 | stride=scale, 32 | padding=1, 33 | groups=groups, 34 | bias=False) 35 | 36 | self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) 37 | 38 | 39 | class DistMaps(nn.Module): 40 | def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False): 41 | super(DistMaps, self).__init__() 42 | self.spatial_scale = spatial_scale 43 | self.norm_radius = norm_radius 44 | self.cpu_mode = cpu_mode 45 | 46 | def get_coord_features(self, points, batchsize, rows, cols): 47 | if self.cpu_mode: 48 | coords = [] 49 | for i in range(batchsize): 50 | norm_delimeter = self.spatial_scale * self.norm_radius 51 | coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols, 52 | norm_delimeter)) 53 | coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() 54 | else: 55 | num_points = points.shape[1] // 2 56 | points = points.view(-1, 2) 57 | invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 58 | row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) 59 | col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) 60 | 61 | coord_rows, coord_cols = torch.meshgrid(row_array, col_array) 62 | coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) 63 | 64 | add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) 65 | coords.add_(-add_xy) 66 | coords.div_(self.norm_radius * self.spatial_scale) 67 | coords.mul_(coords) 68 | 69 | coords[:, 0] += coords[:, 1] 70 | coords = coords[:, :1] 71 | 72 | coords[invalid_points, :, :, :] = 1e6 73 | 74 | coords = coords.view(-1, num_points, 1, rows, cols) 75 | coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w 76 | coords = coords.view(-1, 2, rows, cols) 77 | 78 | coords.sqrt_().mul_(2).tanh_() 79 | 80 | return coords 81 | 82 | def forward(self, x, coords): 83 | return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) 84 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tamaki Kojima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-syncbn 2 | 3 | Tamaki Kojima(tamakoji@gmail.com) 4 | 5 | ## Announcement 6 | 7 | **Pytorch 1.0 support** 8 | 9 | ## Overview 10 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes the global statistics across GPUs instead of using locally computed ones. SyncBN matters when the input images are large and multiple GPUs must be used to reach a reasonable minibatch size for training. 11 | 12 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn) 13 | 14 | ## Remarks 15 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel` 16 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module, since it does not mark tensors for in-place operation 17 | - You can plug it into an arbitrary module written in PyTorch to enable Synchronized BatchNorm 18 | - Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d` 19 | 20 | ## Requirements 21 | For PyTorch, please refer to https://pytorch.org/ 22 | 23 | NOTE : The code is tested only with PyTorch v1.0.0, CUDA10/CuDNN7.4.2 on Ubuntu 18.04 24 | 25 | It utilizes the PyTorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use. 26 | 27 | ``` 28 | sudo apt-get install ninja-build 29 | ``` 30 | 31 | Also install all the Python dependencies. For pip, run: 32 | 33 | ``` 34 | pip install -U -r requirements.txt 35 | ``` 36 | 37 | ## Build 38 | 39 | There is no need to build: just run, and the JIT will take care of it. 40 | JIT and cpp extensions are supported after PyTorch 0.4; however, PyTorch > 1.0 is highly recommended due to the large design changes. 41 | 42 | ## Usage 43 | 44 | Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d` 45 | 46 | ``` 47 | import torch 48 | from torch import nn 49 | from modules import nn as NN 50 | num_gpu = torch.cuda.device_count() 51 | model = nn.Sequential( 52 | nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(3, 3, 1, 1, bias=False), 56 | NN.BatchNorm2d(3), 57 | ).cuda() 58 | model = nn.DataParallel(model, device_ids=range(num_gpu)) 59 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 60 | z = model(x) 61 | ``` 62 | 63 | ## Math 64 | 65 | ### Forward 66 | 1. compute the per-channel sums $\sum_i x_i$ and $\sum_i x_i^2$ on each GPU 67 | 2. gather the per-GPU sums from the workers at the master and compute the global statistics $\mu = \frac{1}{N}\sum_i x_i$ and $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$, where $N$ is the total number of elements per channel across all GPUs (the unbiased variance $\frac{N}{N-1}\sigma^2$ is kept for the running statistics); the global stats are then broadcast back to all GPUs, and running_mean and running_var are updated by a moving average using these global stats. 68 | 69 | 3. forward batchnorm using the global stats by $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$ and then $y_i = \gamma \hat{x}_i + \beta$, where $\gamma$ is the weight parameter and $\beta$ is the bias parameter. 70 | 71 | 4. save $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ for backward 72 |
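As a sanity check on the forward pass above, here is a minimal sketch (not part of this repository) showing that reducing per-GPU sums of $x$ and $x^2$ reproduces the statistics of the full batch; two CPU tensors stand in for the per-GPU shards:

```
import torch

# Two shards standing in for per-GPU minibatches of shape (N, C, H, W).
x1 = torch.randn(4, 3, 8, 8, dtype=torch.float64)
x2 = torch.randn(4, 3, 8, 8, dtype=torch.float64)
# Total number of elements per channel across both shards.
N = x1.numel() // x1.shape[1] + x2.numel() // x2.shape[1]

# Per-shard partial sums, reduced the way the master node would reduce them.
xsum = x1.sum(dim=(0, 2, 3)) + x2.sum(dim=(0, 2, 3))
xsqsum = (x1 ** 2).sum(dim=(0, 2, 3)) + (x2 ** 2).sum(dim=(0, 2, 3))
mean = xsum / N
var = xsqsum / N - mean ** 2  # biased variance, used for normalization

# The same statistics computed directly on the concatenated batch.
full = torch.cat([x1, x2], dim=0)
assert torch.allclose(mean, full.mean(dim=(0, 2, 3)))
assert torch.allclose(var, full.var(dim=(0, 2, 3), unbiased=False))
```

The unbiased variance used for the running statistics divides by $N - 1$ instead, exactly as `sumvar / (N - 1)` does in `modules/functional/syncbn.py`.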
73 | ### Backward 74 | 75 | 1. Restore the saved $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ from the forward pass. 76 | 77 | 2. Compute the sums $\sum_i \frac{\partial J}{\partial y_i}$ and $\sum_i \frac{\partial J}{\partial y_i} \hat{x}_i$ on each GPU, where $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$, then gather them at the master node to sum up globally and normalize with $N$, where $N$ is the total number of elements for each channel. The global sums are then shared among all GPUs. 78 | 79 | 3. compute the gradients using the global stats: $\frac{\partial J}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \left( \frac{\partial J}{\partial y_i} - \frac{1}{N} \sum_j \frac{\partial J}{\partial y_j} - \frac{\hat{x}_i}{N} \sum_j \frac{\partial J}{\partial y_j} \hat{x}_j \right)$, where the two sums are exactly the global sums from step (2), and finally $\frac{\partial J}{\partial \gamma} = \sum_i \frac{\partial J}{\partial y_i} \hat{x}_i$ and $\frac{\partial J}{\partial \beta} = \sum_i \frac{\partial J}{\partial y_i}$. 80 | 81 | Note that in the implementation the normalization with $N$ is performed at step (2), so the equations above and the implementation are not written identically, but they are mathematically the same. 82 | 83 | You can go deeper on the above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/) -------------------------------------------------------------------------------- /fbrs/model/syncbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/model/syncbn/__init__.py -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/model/syncbn/modules/__init__.py -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync 2 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/_csrc.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Extension module loader 5 | 6 | code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import glob 15 | import os.path 16 | 17 | import torch 18 | 19 | try: 20 | from torch.utils.cpp_extension import load 21 | from torch.utils.cpp_extension import CUDA_HOME 22 | except ImportError: 23 | raise ImportError( 24 | "The cpp layer extensions requires PyTorch 0.4 or higher") 25 | 26 | 27 | def _load_C_extensions(): 28 | this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | this_dir = os.path.join(this_dir, "csrc") 30 | 31 | main_file = glob.glob(os.path.join(this_dir, "*.cpp")) 32 | sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) 33 | sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) 34 | 35 | sources = main_file + sources_cpu 36 | 37 | extra_cflags = [] 38 | extra_cuda_cflags = [] 39 | if torch.cuda.is_available() and CUDA_HOME is not None: 40 | sources.extend(sources_cuda) 41 | extra_cflags = ["-O3", "-DWITH_CUDA"] 42 | extra_cuda_cflags = ["--expt-extended-lambda"] 43 | sources = [os.path.join(this_dir, s) for s in sources] 44 | extra_include_paths = [this_dir] 45 | return load( 46 | name="ext_lib", 47 | sources=sources, 48 | extra_cflags=extra_cflags, 49 | extra_include_paths=extra_include_paths, 50 | extra_cuda_cflags=extra_cuda_cflags, 51 | ) 52 | 53 | 54 | 
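# The extension is built lazily at import time via torch.utils.cpp_extension.load:
# the first import JIT-compiles the C++/CUDA sources with ninja (see the syncbn README),
# and later imports reuse the cached build.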
_backend = _load_C_extensions() 55 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/csrc/bn.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | SyncBN 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | 8 | #ifdef WITH_CUDA 9 | #include "cuda/ext_lib.h" 10 | #endif 11 | 12 | /// SyncBN 13 | 14 | std::vector syncbn_sum_sqsum(const at::Tensor& x) { 15 | if (x.is_cuda()) { 16 | #ifdef WITH_CUDA 17 | return syncbn_sum_sqsum_cuda(x); 18 | #else 19 | AT_ERROR("Not compiled with GPU support"); 20 | #endif 21 | } else { 22 | AT_ERROR("CPU implementation not supported"); 23 | } 24 | } 25 | 26 | at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight, 27 | const at::Tensor& bias, const at::Tensor& mean, 28 | const at::Tensor& var, bool affine, float eps) { 29 | if (x.is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps); 32 | #else 33 | AT_ERROR("Not compiled with GPU support"); 34 | #endif 35 | } else { 36 | AT_ERROR("CPU implementation not supported"); 37 | } 38 | } 39 | 40 | std::vector syncbn_backward_xhat(const at::Tensor& dz, 41 | const at::Tensor& x, 42 | const at::Tensor& mean, 43 | const at::Tensor& var, float eps) { 44 | if (dz.is_cuda()) { 45 | #ifdef WITH_CUDA 46 | return syncbn_backward_xhat_cuda(dz, x, mean, var, eps); 47 | #else 48 | AT_ERROR("Not compiled with GPU support"); 49 | #endif 50 | } else { 51 | AT_ERROR("CPU implementation not supported"); 52 | } 53 | } 54 | 55 | std::vector syncbn_backward( 56 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 57 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 58 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 59 | float eps) { 60 | if (dz.is_cuda()) { 61 | #ifdef WITH_CUDA 62 | return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz, 63 | sum_dz_xhat, affine, eps); 64 | #else 65 | AT_ERROR("Not compiled with GPU support"); 66 | #endif 67 | } else { 68 | AT_ERROR("CPU implementation not supported"); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/csrc/cuda/common.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA utility funcs 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #pragma once 9 | 10 | #include 11 | 12 | // Checks 13 | #ifndef AT_CHECK 14 | #define AT_CHECK AT_ASSERT 15 | #endif 16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 19 | 20 | /* 21 | * General settings 22 | */ 23 | const int WARP_SIZE = 32; 24 | const int MAX_BLOCK_SIZE = 512; 25 | 26 | template 27 | struct Pair { 28 | T v1, v2; 29 | __device__ Pair() {} 30 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 31 | __device__ Pair(T v) : v1(v), v2(v) {} 32 | __device__ Pair(int v) : v1(v), v2(v) {} 33 | __device__ Pair &operator+=(const Pair &a) { 34 | v1 += a.v1; 
35 | v2 += a.v2; 36 | return *this; 37 | } 38 | }; 39 | 40 | /* 41 | * Utility functions 42 | */ 43 | template 44 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, 45 | int width = warpSize, 46 | unsigned int mask = 0xffffffff) { 47 | #if CUDART_VERSION >= 9000 48 | return __shfl_xor_sync(mask, value, laneMask, width); 49 | #else 50 | return __shfl_xor(value, laneMask, width); 51 | #endif 52 | } 53 | 54 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 55 | 56 | static int getNumThreads(int nElem) { 57 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 58 | for (int i = 0; i != 5; ++i) { 59 | if (nElem <= threadSizes[i]) { 60 | return threadSizes[i]; 61 | } 62 | } 63 | return MAX_BLOCK_SIZE; 64 | } 65 | 66 | template 67 | static __device__ __forceinline__ T warpSum(T val) { 68 | #if __CUDA_ARCH__ >= 300 69 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 70 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 71 | } 72 | #else 73 | __shared__ T values[MAX_BLOCK_SIZE]; 74 | values[threadIdx.x] = val; 75 | __threadfence_block(); 76 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 77 | for (int i = 1; i < WARP_SIZE; i++) { 78 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 79 | } 80 | #endif 81 | return val; 82 | } 83 | 84 | template 85 | static __device__ __forceinline__ Pair warpSum(Pair value) { 86 | value.v1 = warpSum(value.v1); 87 | value.v2 = warpSum(value.v2); 88 | return value; 89 | } 90 | 91 | template 92 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 93 | T sum = (T)0; 94 | for (int batch = 0; batch < N; ++batch) { 95 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 96 | sum += op(batch, plane, x); 97 | } 98 | } 99 | 100 | // sum over NumThreads within a warp 101 | sum = warpSum(sum); 102 | 103 | // 'transpose', and reduce within warp again 104 | __shared__ T shared[32]; 105 | __syncthreads(); 106 | if (threadIdx.x % WARP_SIZE == 0) { 107 | shared[threadIdx.x / WARP_SIZE] = sum; 108 | } 109 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 110 | // zero out the other entries in shared 111 | shared[threadIdx.x] = (T)0; 112 | } 113 | __syncthreads(); 114 | if (threadIdx.x / WARP_SIZE == 0) { 115 | sum = warpSum(shared[threadIdx.x]); 116 | if (threadIdx.x == 0) { 117 | shared[0] = sum; 118 | } 119 | } 120 | __syncthreads(); 121 | 122 | // Everyone picks it up, should be broadcast into the whole gradInput 123 | return shared[0]; 124 | } -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | #include 8 | #include 9 | 10 | /// Sync-BN 11 | std::vector syncbn_sum_sqsum_cuda(const at::Tensor& x); 12 | at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight, 13 | const at::Tensor& bias, const at::Tensor& mean, 14 | const at::Tensor& var, bool affine, float eps); 15 | std::vector syncbn_backward_xhat_cuda(const at::Tensor& dz, 16 | const at::Tensor& x, 17 | const at::Tensor& mean, 18 | const at::Tensor& var, 19 | float eps); 20 | std::vector syncbn_backward_cuda( 21 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 22 | const at::Tensor& bias, const at::Tensor& mean, 
const at::Tensor& var, 23 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 24 | float eps); 25 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp: -------------------------------------------------------------------------------- 1 | #include "bn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation"); 5 | m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation"); 6 | m.def("syncbn_backward_xhat", &syncbn_backward_xhat, 7 | "First part of SyncBN backward computation"); 8 | m.def("syncbn_backward", &syncbn_backward, 9 | "Second part of SyncBN backward computation"); 10 | } -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/functional/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | code referenced from : https://github.com/mapillary/inplace_abn 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch.cuda.comm as comm 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from ._csrc import _backend 18 | 19 | 20 | def _count_samples(x): 21 | count = 1 22 | for i, s in enumerate(x.size()): 23 | if i != 1: 24 | count *= s 25 | return count 26 | 27 | 28 | class BatchNorm2dSyncFunc(Function): 29 | 30 | @staticmethod 31 | def forward(ctx, x, weight, bias, running_mean, running_var, 32 | extra, compute_stats=True, momentum=0.1, eps=1e-05): 33 | def _parse_extra(ctx, extra): 34 | ctx.is_master = extra["is_master"] 35 | if ctx.is_master: 36 | ctx.master_queue = extra["master_queue"] 37 | ctx.worker_queues = extra["worker_queues"] 38 | ctx.worker_ids = extra["worker_ids"] 39 | else: 40 | ctx.master_queue = extra["master_queue"] 41 | ctx.worker_queue = extra["worker_queue"] 42 | # Save context 43 | if extra is not None: 44 | _parse_extra(ctx, extra) 45 | ctx.compute_stats = compute_stats 46 | ctx.momentum = momentum 47 | ctx.eps = eps 48 | ctx.affine = weight is not None and bias is not None 49 | if ctx.compute_stats: 50 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1) 51 | assert N > 1 52 | # 1. 
compute sum(x) and sum(x^2) 53 | xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach()) 54 | if ctx.is_master: 55 | xsums, xsqsums = [xsum], [xsqsum] 56 | # master : gatther all sum(x) and sum(x^2) from slaves 57 | for _ in range(ctx.master_queue.maxsize): 58 | xsum_w, xsqsum_w = ctx.master_queue.get() 59 | ctx.master_queue.task_done() 60 | xsums.append(xsum_w) 61 | xsqsums.append(xsqsum_w) 62 | xsum = comm.reduce_add(xsums) 63 | xsqsum = comm.reduce_add(xsqsums) 64 | mean = xsum / N 65 | sumvar = xsqsum - xsum * mean 66 | var = sumvar / N 67 | uvar = sumvar / (N - 1) 68 | # master : broadcast global mean, variance to all slaves 69 | tensors = comm.broadcast_coalesced( 70 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) 71 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 72 | queue.put(ts) 73 | else: 74 | # slave : send sum(x) and sum(x^2) to master 75 | ctx.master_queue.put((xsum, xsqsum)) 76 | # slave : get global mean and variance 77 | mean, uvar, var = ctx.worker_queue.get() 78 | ctx.worker_queue.task_done() 79 | 80 | # Update running stats 81 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 82 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar) 83 | ctx.N = N 84 | ctx.save_for_backward(x, weight, bias, mean, var) 85 | else: 86 | mean, var = running_mean, running_var 87 | 88 | # do batch norm forward 89 | z = _backend.syncbn_forward(x, weight, bias, mean, var, 90 | ctx.affine, ctx.eps) 91 | return z 92 | 93 | @staticmethod 94 | @once_differentiable 95 | def backward(ctx, dz): 96 | x, weight, bias, mean, var = ctx.saved_tensors 97 | dz = dz.contiguous() 98 | 99 | # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) 100 | sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat( 101 | dz, x, mean, var, ctx.eps) 102 | if ctx.is_master: 103 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] 104 | # master : gatther from slaves 105 | for _ in range(ctx.master_queue.maxsize): 106 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() 107 | ctx.master_queue.task_done() 108 | sum_dzs.append(sum_dz_w) 109 | sum_dz_xhats.append(sum_dz_xhat_w) 110 | # master : compute global stats 111 | sum_dz = comm.reduce_add(sum_dzs) 112 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats) 113 | sum_dz /= ctx.N 114 | sum_dz_xhat /= ctx.N 115 | # master : broadcast global stats 116 | tensors = comm.broadcast_coalesced( 117 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) 118 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 119 | queue.put(ts) 120 | else: 121 | # slave : send to master 122 | ctx.master_queue.put((sum_dz, sum_dz_xhat)) 123 | # slave : get global stats 124 | sum_dz, sum_dz_xhat = ctx.worker_queue.get() 125 | ctx.worker_queue.task_done() 126 | 127 | # do batch norm backward 128 | dx, dweight, dbias = _backend.syncbn_backward( 129 | dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat, 130 | ctx.affine, ctx.eps) 131 | 132 | return dx, dweight, dbias, \ 133 | None, None, None, None, None, None 134 | 135 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply 136 | 137 | __all__ = ["batchnorm2d_sync"] 138 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import * 2 | -------------------------------------------------------------------------------- /fbrs/model/syncbn/modules/nn/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 
2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | try: 13 | # python 3 14 | from queue import Queue 15 | except ImportError: 16 | # python 2 17 | from Queue import Queue 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn import functional as F 22 | from torch.nn.parameter import Parameter 23 | from isegm.model.syncbn.modules.functional import batchnorm2d_sync 24 | 25 | 26 | class _BatchNorm(nn.Module): 27 | """ 28 | Customized BatchNorm from nn.BatchNorm 29 | >> added freeze attribute to enable bn freeze. 30 | """ 31 | 32 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 33 | track_running_stats=True): 34 | super(_BatchNorm, self).__init__() 35 | self.num_features = num_features 36 | self.eps = eps 37 | self.momentum = momentum 38 | self.affine = affine 39 | self.track_running_stats = track_running_stats 40 | self.freezed = False 41 | if self.affine: 42 | self.weight = Parameter(torch.Tensor(num_features)) 43 | self.bias = Parameter(torch.Tensor(num_features)) 44 | else: 45 | self.register_parameter('weight', None) 46 | self.register_parameter('bias', None) 47 | if self.track_running_stats: 48 | self.register_buffer('running_mean', torch.zeros(num_features)) 49 | self.register_buffer('running_var', torch.ones(num_features)) 50 | else: 51 | self.register_parameter('running_mean', None) 52 | self.register_parameter('running_var', None) 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | if self.track_running_stats: 57 | self.running_mean.zero_() 58 | self.running_var.fill_(1) 59 | if self.affine: 60 | self.weight.data.uniform_() 61 | self.bias.data.zero_() 62 | 63 | def _check_input_dim(self, input): 64 | return NotImplemented 65 | 66 | def forward(self, input): 67 | self._check_input_dim(input) 68 | 69 | compute_stats = not self.freezed and \ 70 | self.training and self.track_running_stats 71 | 72 | ret = F.batch_norm(input, self.running_mean, self.running_var, 73 | self.weight, self.bias, compute_stats, 74 | self.momentum, self.eps) 75 | return ret 76 | 77 | def extra_repr(self): 78 | return '{num_features}, eps={eps}, momentum={momentum}, '\ 79 | 'affine={affine}, ' \ 80 | 'track_running_stats={track_running_stats}'.format( 81 | **self.__dict__) 82 | 83 | 84 | class BatchNorm2dNoSync(_BatchNorm): 85 | """ 86 | Equivalent to nn.BatchNorm2d 87 | """ 88 | 89 | def _check_input_dim(self, input): 90 | if input.dim() != 4: 91 | raise ValueError('expected 4D input (got {}D input)' 92 | .format(input.dim())) 93 | 94 | 95 | class BatchNorm2dSync(BatchNorm2dNoSync): 96 | """ 97 | BatchNorm2d with automatic multi-GPU Sync 98 | """ 99 | 100 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 101 | track_running_stats=True): 102 | super(BatchNorm2dSync, self).__init__( 103 | num_features, eps=eps, momentum=momentum, affine=affine, 104 | track_running_stats=track_running_stats) 105 | self.sync_enabled = True 106 | self.devices = list(range(torch.cuda.device_count())) 107 | if len(self.devices) > 1: 108 | # Initialize queues 109 | self.worker_ids = self.devices[1:] 110 | self.master_queue = Queue(len(self.worker_ids)) 111 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 112 | 113 | def forward(self, x): 114 | 
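# Batch statistics are synchronized across GPUs only when they are actually being
# computed (training, stats not frozen) and more than one device is in use;
# otherwise this falls back to the standard BatchNorm2d forward of the parent class.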
compute_stats = not self.freezed and \ 115 | self.training and self.track_running_stats 116 | if self.sync_enabled and compute_stats and len(self.devices) > 1: 117 | if x.get_device() == self.devices[0]: 118 | # Master mode 119 | extra = { 120 | "is_master": True, 121 | "master_queue": self.master_queue, 122 | "worker_queues": self.worker_queues, 123 | "worker_ids": self.worker_ids 124 | } 125 | else: 126 | # Worker mode 127 | extra = { 128 | "is_master": False, 129 | "master_queue": self.master_queue, 130 | "worker_queue": self.worker_queues[ 131 | self.worker_ids.index(x.get_device())] 132 | } 133 | return batchnorm2d_sync(x, self.weight, self.bias, 134 | self.running_mean, self.running_var, 135 | extra, compute_stats, self.momentum, 136 | self.eps) 137 | return super(BatchNorm2dSync, self).forward(x) 138 | 139 | def __repr__(self): 140 | """repr""" 141 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 142 | 'affine={affine}, ' \ 143 | 'track_running_stats={track_running_stats},' \ 144 | 'devices={devices})' 145 | return rep.format(name=self.__class__.__name__, **self.__dict__) 146 | 147 | #BatchNorm2d = BatchNorm2dNoSync 148 | BatchNorm2d = BatchNorm2dSync 149 | -------------------------------------------------------------------------------- /fbrs/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/fbrs/utils/__init__.py -------------------------------------------------------------------------------- /fbrs/utils/cython/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .dist_maps import get_dist_maps -------------------------------------------------------------------------------- /fbrs/utils/cython/_get_dist_maps.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport cython 3 | cimport numpy as np 4 | from libc.stdlib cimport malloc, free 5 | 6 | ctypedef struct qnode: 7 | int row 8 | int col 9 | int layer 10 | int orig_row 11 | int orig_col 12 | 13 | @cython.infer_types(True) 14 | @cython.boundscheck(False) 15 | @cython.wraparound(False) 16 | @cython.nonecheck(False) 17 | def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, 18 | int height, int width, float norm_delimeter): 19 | cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ 20 | np.full((2, height, width), 1e6, dtype=np.float32, order="C") 21 | 22 | cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] 23 | cdef int i, j, x, y, dx, dy 24 | cdef qnode v 25 | cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) 26 | cdef int qhead = 0, qtail = -1 27 | cdef float ndist 28 | 29 | for i in range(points.shape[0]): 30 | x, y = round(points[i, 0]), round(points[i, 1]) 31 | if x >= 0: 32 | qtail += 1 33 | q[qtail].row = x 34 | q[qtail].col = y 35 | q[qtail].orig_row = x 36 | q[qtail].orig_col = y 37 | if i >= points.shape[0] / 2: 38 | q[qtail].layer = 1 39 | else: 40 | q[qtail].layer = 0 41 | dist_maps[q[qtail].layer, x, y] = 0 42 | 43 | while qtail - qhead + 1 > 0: 44 | v = q[qhead] 45 | qhead += 1 46 | 47 | for k in range(4): 48 | x = v.row + dxy[2 * k] 49 | y = v.col + dxy[2 * k + 1] 50 | 51 | ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 52 | if (x >= 0 and y >= 0 and x < height and y < width and 53 | dist_maps[v.layer, x, y] > ndist): 54 | qtail += 1 55 | 
q[qtail].orig_col = v.orig_col 56 | q[qtail].orig_row = v.orig_row 57 | q[qtail].layer = v.layer 58 | q[qtail].row = x 59 | q[qtail].col = y 60 | dist_maps[v.layer, x, y] = ndist 61 | 62 | free(q) 63 | return dist_maps 64 | -------------------------------------------------------------------------------- /fbrs/utils/cython/_get_dist_maps.pyxbld: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def make_ext(modname, pyxfilename): 4 | from distutils.extension import Extension 5 | return Extension(modname, [pyxfilename], 6 | include_dirs=[numpy.get_include()], 7 | extra_compile_args=['-O3'], language='c++') 8 | -------------------------------------------------------------------------------- /fbrs/utils/cython/dist_maps.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install(pyximport=True, language_level=3) 2 | # noinspection PyUnresolvedReferences 3 | from ._get_dist_maps import get_dist_maps -------------------------------------------------------------------------------- /fbrs/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def get_dims_with_exclusion(dim, exclude=None): 8 | dims = list(range(dim)) 9 | if exclude is not None: 10 | dims.remove(exclude) 11 | 12 | return dims 13 | 14 | 15 | def get_unique_labels(mask): 16 | return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1 17 | 18 | 19 | def get_bbox_from_mask(mask): 20 | rows = np.any(mask, axis=1) 21 | cols = np.any(mask, axis=0) 22 | rmin, rmax = np.where(rows)[0][[0, -1]] 23 | cmin, cmax = np.where(cols)[0][[0, -1]] 24 | 25 | return rmin, rmax, cmin, cmax 26 | 27 | 28 | def expand_bbox(bbox, expand_ratio, min_crop_size=None): 29 | rmin, rmax, cmin, cmax = bbox 30 | rcenter = 0.5 * (rmin + rmax) 31 | ccenter = 0.5 * (cmin + cmax) 32 | height = expand_ratio * (rmax - rmin + 1) 33 | width = expand_ratio * (cmax - cmin + 1) 34 | if min_crop_size is not None: 35 | height = max(height, min_crop_size) 36 | width = max(width, min_crop_size) 37 | 38 | rmin = int(round(rcenter - 0.5 * height)) 39 | rmax = int(round(rcenter + 0.5 * height)) 40 | cmin = int(round(ccenter - 0.5 * width)) 41 | cmax = int(round(ccenter + 0.5 * width)) 42 | 43 | return rmin, rmax, cmin, cmax 44 | 45 | 46 | def clamp_bbox(bbox, rmin, rmax, cmin, cmax): 47 | return (max(rmin, bbox[0]), min(rmax, bbox[1]), 48 | max(cmin, bbox[2]), min(cmax, bbox[3])) 49 | 50 | 51 | def get_bbox_iou(b1, b2): 52 | h_iou = get_segments_iou(b1[:2], b2[:2]) 53 | w_iou = get_segments_iou(b1[2:4], b2[2:4]) 54 | return h_iou * w_iou 55 | 56 | 57 | def get_segments_iou(s1, s2): 58 | a, b = s1 59 | c, d = s2 60 | intersection = max(0, min(b, d) - max(a, c) + 1) 61 | union = max(1e-6, max(b, d) - min(a, c) + 1) 62 | return intersection / union 63 | -------------------------------------------------------------------------------- /fbrs/utils/vis.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def visualize_instances(imask, bg_color=255, 8 | boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): 9 | num_objects = imask.max() + 1 10 | palette = get_palette(num_objects) 11 | if bg_color is not None: 12 | palette[0] = bg_color 13 | 14 | result = palette[imask].astype(np.uint8) 15 | if boundaries_color is 
not None: 16 | boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) 17 | tresult = result.astype(np.float32) 18 | tresult[boundaries_mask] = boundaries_color 19 | tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result 20 | result = tresult.astype(np.uint8) 21 | 22 | return result 23 | 24 | 25 | @lru_cache(maxsize=16) 26 | def get_palette(num_cls): 27 | palette = np.zeros(3 * num_cls, dtype=np.int32) 28 | 29 | for j in range(0, num_cls): 30 | lab = j 31 | i = 0 32 | 33 | while lab > 0: 34 | palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) 35 | palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) 36 | palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) 37 | i = i + 1 38 | lab >>= 3 39 | 40 | return palette.reshape((-1, 3)) 41 | 42 | 43 | def visualize_mask(mask, num_cls): 44 | palette = get_palette(num_cls) 45 | mask[mask == -1] = 0 46 | 47 | return palette[mask].astype(np.uint8) 48 | 49 | 50 | def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): 51 | proposal_map, colors, candidates = proposals_info 52 | 53 | proposal_map = draw_probmap(proposal_map) 54 | for x, y in candidates: 55 | proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) 56 | 57 | return proposal_map 58 | 59 | 60 | def draw_probmap(x): 61 | return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) 62 | 63 | 64 | def draw_points(image, points, color, radius=3): 65 | image = image.copy() 66 | for p in points: 67 | image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1) 68 | 69 | return image 70 | 71 | 72 | def draw_instance_map(x, palette=None): 73 | num_colors = x.max() + 1 74 | if palette is None: 75 | palette = get_palette(num_colors) 76 | 77 | return palette[x].astype(np.uint8) 78 | 79 | 80 | def blend_mask(image, mask, alpha=0.6): 81 | if mask.min() == -1: 82 | mask = mask.copy() + 1 83 | 84 | imap = draw_instance_map(mask) 85 | result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) 86 | return result 87 | 88 | 89 | def get_boundaries(instances_masks, boundaries_width=1): 90 | boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) 91 | 92 | for obj_id in np.unique(instances_masks.flatten()): 93 | if obj_id == 0: 94 | continue 95 | 96 | obj_mask = instances_masks == obj_id 97 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) 98 | inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) 99 | 100 | obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) 101 | boundaries = np.logical_or(boundaries, obj_boundary) 102 | return boundaries 103 | 104 | 105 | def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), 106 | neg_color=(255, 0, 0), radius=4): 107 | result = img.copy() 108 | 109 | if mask is not None: 110 | palette = get_palette(np.max(mask) + 1) 111 | rgb_mask = palette[mask.astype(np.uint8)] 112 | 113 | mask_region = (mask > 0).astype(np.uint8) 114 | result = result * (1 - mask_region[:, :, np.newaxis]) + \ 115 | (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ 116 | alpha * rgb_mask 117 | result = result.astype(np.uint8) 118 | 119 | # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) 120 | 121 | if clicks_list is not None and len(clicks_list) > 0: 122 | pos_points = [click.coords for click in clicks_list if click.is_positive] 123 | neg_points = [click.coords for click in clicks_list if not click.is_positive] 124 | 125 | 
result = draw_points(result, pos_points, pos_color, radius=radius) 126 | result = draw_points(result, neg_points, neg_color, radius=radius) 127 | 128 | return result 129 | 130 | -------------------------------------------------------------------------------- /generate_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate fusion data for the DAVIS dataset. 3 | """ 4 | 5 | import os 6 | from os import path 7 | from argparse import ArgumentParser 8 | 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | import numpy as np 12 | from PIL import Image 13 | import cv2 14 | 15 | from model.propagation.prop_net import PropagationNetwork 16 | from dataset.davis_test_dataset import DAVISTestDataset 17 | from dataset.bl_test_dataset import BLTestDataset 18 | from generation.fusion_generator import FusionGenerator 19 | 20 | from progressbar import progressbar 21 | 22 | 23 | """ 24 | Arguments loading 25 | """ 26 | parser = ArgumentParser() 27 | parser.add_argument('--model', default='saves/propagation_model.pth') 28 | parser.add_argument('--davis_root', default='../DAVIS/2017') 29 | parser.add_argument('--bl_root', default='../BL30K') 30 | parser.add_argument('--dataset', help='DAVIS/BL') 31 | parser.add_argument('--output') 32 | parser.add_argument('--separation', default=None, type=int) 33 | parser.add_argument('--range', default=None, type=int) 34 | parser.add_argument('--mem_freq', default=None, type=int) 35 | parser.add_argument('--start', default=None, type=int) 36 | parser.add_argument('--end', default=None, type=int) 37 | args = parser.parse_args() 38 | 39 | davis_path = args.davis_root 40 | bl_path = args.bl_root 41 | out_path = args.output 42 | dataset_option = args.dataset 43 | 44 | # Simple setup 45 | os.makedirs(out_path, exist_ok=True) 46 | palette = Image.open(path.expanduser(davis_path+'/trainval/Annotations/480p/blackswan/00000.png')).getpalette() 47 | 48 | torch.autograd.set_grad_enabled(False) 49 | 50 | # Setup Dataset 51 | if dataset_option == 'DAVIS': 52 | test_dataset = DAVISTestDataset(davis_path+'/trainval', imset='2017/train.txt') 53 | elif dataset_option == 'BL': 54 | test_dataset = BLTestDataset(bl_path, start=args.start, end=args.end) 55 | else: 56 | print('Use --dataset DAVIS or --dataset BL') 57 | raise NotImplementedError 58 | 59 | # test_dataset = BLTestDataset(args.bl, start=args.start, end=args.end, subset=load_sub_bl()) 60 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=False) 61 | 62 | # Load our checkpoint 63 | prop_saved = torch.load(args.model) 64 | prop_model = PropagationNetwork().cuda().eval() 65 | prop_model.load_state_dict(prop_saved) 66 | 67 | # Start evaluation 68 | for data in progressbar(test_loader, max_value=len(test_loader), redirect_stdout=True): 69 | 70 | rgb = data['rgb'].cuda() 71 | msk = data['gt'][0].cuda() 72 | info = data['info'] 73 | 74 | total_t = rgb.shape[1] 75 | processor = FusionGenerator(prop_model, rgb, args.mem_freq) 76 | 77 | for frame in range(0, total_t, args.separation): 78 | 79 | usable_keys = [] 80 | for k in range(msk.shape[0]): 81 | if (msk[k,frame] > 0.5).sum() > 10*10: 82 | usable_keys.append(k) 83 | if len(usable_keys) == 0: 84 | continue 85 | if len(usable_keys) > 5: 86 | # Memory limit 87 | usable_keys = usable_keys[:5] 88 | 89 | k = len(usable_keys) 90 | processor.reset(k) 91 | this_msk = msk[usable_keys] 92 | 93 | # Make this directory 94 | this_out_path = path.join(out_path, info['name'][0], 
'%05d'%frame) 95 | os.makedirs(this_out_path, exist_ok=True) 96 | 97 | # Propagate 98 | if dataset_option == 'DAVIS': 99 | left_limit = 0 100 | right_limit = total_t-1 101 | else: 102 | left_limit = max(0, frame-args.range) 103 | right_limit = min(total_t-1, frame+args.range) 104 | 105 | pred_range = range(left_limit, right_limit+1) 106 | out_probs = processor.interact_mask(this_msk[:,frame], frame, left_limit, right_limit) 107 | 108 | for kidx, obj_id in enumerate(usable_keys): 109 | obj_out_path = path.join(this_out_path, '%05d'%(obj_id+1)) 110 | os.makedirs(obj_out_path, exist_ok=True) 111 | prob_Es = (out_probs[kidx+1]*255).cpu().numpy().astype(np.uint8) 112 | 113 | for f in pred_range: 114 | img_E = Image.fromarray(prob_Es[f]) 115 | img_E.save(os.path.join(obj_out_path, '{:05d}.png'.format(f))) 116 | 117 | del out_probs 118 | 119 | print(info['name'][0]) 120 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/generation/__init__.py -------------------------------------------------------------------------------- /generation/blender/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/generation/blender/__init__.py -------------------------------------------------------------------------------- /generation/blender/clean_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import path 4 | import shutil 5 | 6 | """ 7 | Look at a "rendered" folder, move the rendered to the output path 8 | Keep the empty folders in place (don't delete since it might still be rendered right) 9 | Also copy the corresponding yaml in there 10 | """ 11 | 12 | input_path = sys.argv[1] 13 | output_path = sys.argv[2] 14 | yaml_path = sys.argv[3] 15 | 16 | # This overrides the softlink 17 | # os.makedirs(output_path, exist_ok=True) 18 | 19 | renders = os.listdir(input_path) 20 | is_rendered = [len(os.listdir(path.join(input_path, r, 'segmentation')))==160 for r in renders] 21 | 22 | updated = 0 23 | for i, r in enumerate(renders): 24 | if is_rendered[i]: 25 | if not path.exists(path.join(output_path, r)): 26 | shutil.move(path.join(input_path, r), output_path) 27 | prefix = r[:3] 28 | 29 | shutil.copy2(path.join(yaml_path, 'yaml_%s'%prefix, '%s.yaml'%r), path.join(output_path, r)) 30 | updated += 1 31 | else: 32 | print('path exist') 33 | else: 34 | # Nothing for now. 
Can do something later 35 | pass 36 | 37 | print('Number of completed renders: ', len(os.listdir(output_path))) 38 | print('Number of updated renders: ', updated) -------------------------------------------------------------------------------- /generation/blender/gen_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.polynomial.polynomial as poly 3 | from scipy import optimize 4 | 5 | 6 | class Sampler: 7 | def __init__(self, data_list): 8 | self.data_list = data_list 9 | self.idx = 0 10 | self.permute() 11 | 12 | def permute(self): 13 | self.data_list = np.random.permutation(self.data_list) 14 | 15 | def next(self): 16 | if self.idx == len(self.data_list): 17 | self.permute() 18 | self.idx = 0 19 | data = self.data_list[self.idx] 20 | self.idx += 1 21 | return data 22 | 23 | def step_back(self): 24 | self.idx -= 1 25 | if self.idx == -1: 26 | self.idx = len(self.data_list) - 1 27 | 28 | def test_path(prev_paths, path, tol=0.75): 29 | min_dist = float('inf') 30 | path = np.array(path) 31 | for p in prev_paths: 32 | p = np.array(p) 33 | # Find min distance as a constrained optimization problem 34 | poly_vals = p - path 35 | f = lambda x: np.linalg.norm(poly.polyval(x, poly_vals)) 36 | optim_x = optimize.minimize_scalar(f, bounds=(0, 1), method='bounded') 37 | if optim_x.fun < tol: 38 | # print('Fail') 39 | return False 40 | # print('Success') 41 | return True 42 | 43 | def pick_rand(min_v, max_v, shape=None): 44 | if shape is not None: 45 | return np.random.rand(shape)*(max_v-min_v) + min_v 46 | else: 47 | return np.random.rand()*(max_v-min_v) + min_v 48 | 49 | def pick_normal_rand(mean, std, shape=None): 50 | return np.random.normal(mean, std, shape) 51 | 52 | def pick_randint(min_v, max_v): 53 | return np.random.randint(min_v, max_v+1) 54 | 55 | def normalize(a): 56 | return a / np.linalg.norm(a) 57 | 58 | def get_2side_rand(max_delta, shape=1): 59 | return np.random.rand(shape)*2*max_delta-max_delta 60 | 61 | def get_vector_in_frustum(min_base, max_base, min_into, max_into, cam_min_into): 62 | y = pick_rand(min_into, max_into) 63 | 64 | f_min_base = min_base * ((y - cam_min_into) / (min_into - cam_min_into)) 65 | f_max_base = max_base * ((y - cam_min_into) / (min_into - cam_min_into)) 66 | 67 | x = pick_rand(f_min_base, f_max_base) 68 | z = pick_rand(f_min_base, f_max_base) 69 | return np.array((x, y, z)) 70 | 71 | def get_vector_in_block(min_base, max_base, min_into, max_into): 72 | x = pick_rand(min_base, max_base) 73 | y = pick_rand(min_into, max_into) 74 | z = pick_rand(min_base, max_base) 75 | return np.array((x, y, z)) 76 | 77 | def get_vector_on_sphere(radius): 78 | x1 = np.random.normal(0, 1) 79 | x2 = np.random.normal(0, 1) 80 | x3 = np.random.normal(0, 1) 81 | norm = (x1*x1 + x2*x2 + x3*x3)**(1/2) 82 | pt = radius*np.array((x1,x2,x3))/norm 83 | return pt 84 | 85 | def get_next_vector_in_block(curr_vec, max_delta, min_base, max_base, min_into, max_into): 86 | new_point = get_vector_in_block(min_base, max_base, min_into, max_into) 87 | 88 | max_delta = np.abs(max_delta) 89 | dist_vec = (new_point - curr_vec) 90 | for i in range(3): 91 | if dist_vec[i] > max_delta[i]: 92 | dist_vec[i] = np.sign(dist_vec[i]) * max_delta[i] 93 | 94 | new_point = curr_vec + dist_vec 95 | return new_point 96 | 97 | def get_next_vector_in_frustum(curr_vec, max_delta, min_base, max_base, min_into, max_into, cam_min_into): 98 | new_point = get_vector_in_frustum(min_base, max_base, min_into, max_into, cam_min_into) 99 | 100 | 
max_delta = np.abs(max_delta) 101 | dist_vec = (new_point - curr_vec) 102 | for i in range(3): 103 | if dist_vec[i] > max_delta[i]: 104 | dist_vec[i] = np.sign(dist_vec[i]) * max_delta[i] 105 | 106 | new_point = curr_vec + dist_vec 107 | return new_point -------------------------------------------------------------------------------- /generation/blender/resize_texture.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import cv2 4 | 5 | from multiprocessing import Pool 6 | from progressbar import progressbar 7 | 8 | input_dir = sys.argv[1] 9 | output_dir = sys.argv[2] 10 | 11 | min_size = 512 12 | 13 | def process_fun(sub_dir): 14 | this_in_dir = os.path.join(input_dir, sub_dir) 15 | this_out_dir = os.path.join(output_dir, sub_dir) 16 | os.makedirs(this_out_dir, exist_ok=True) 17 | 18 | for f in os.listdir(this_in_dir): 19 | img = cv2.imread(os.path.join(this_in_dir, f)) 20 | if img is None: 21 | continue 22 | h, w, _ = img.shape 23 | 24 | scale = min(h, w) / min_size 25 | 26 | img = cv2.resize(img, (int(w/scale), int(h/scale)), interpolation=cv2.INTER_AREA) 27 | if len(img.shape) == 3: 28 | img = img[0:min_size, 0:min_size, :] 29 | else: 30 | img = img[0:min_size, 0:min_size] 31 | cv2.imwrite(os.path.join(this_out_dir, os.path.basename(f)), img) 32 | 33 | if __name__ == '__main__': 34 | pool = Pool() 35 | chunksize = 1 36 | 37 | os.makedirs(output_dir, exist_ok=True) 38 | for _ in progressbar(pool.map(process_fun, os.listdir(input_dir)), chunksize): 39 | pass 40 | 41 | print('All done.') -------------------------------------------------------------------------------- /generation/blender/texture.json: -------------------------------------------------------------------------------- 1 | { 2 | "Records": [ 3 | { 4 | "keywords": "wood texture", 5 | "limit": 300, 6 | "usage_rights": "labeled-for-nocommercial-reuse", 7 | "print_urls": true, 8 | "output_directory": "texture" 9 | }, 10 | { 11 | "keywords": "metal texture", 12 | "limit": 300, 13 | "usage_rights": "labeled-for-nocommercial-reuse", 14 | "print_urls": true, 15 | "output_directory": "texture" 16 | }, 17 | { 18 | "keywords": "tree texture", 19 | "limit": 300, 20 | "usage_rights": "labeled-for-nocommercial-reuse", 21 | "print_urls": true, 22 | "output_directory": "texture" 23 | }, 24 | { 25 | "keywords": "glass texture", 26 | "limit": 300, 27 | "usage_rights": "labeled-for-nocommercial-reuse", 28 | "print_urls": true, 29 | "output_directory": "texture" 30 | }, 31 | { 32 | "keywords": "plastic texture", 33 | "limit": 300, 34 | "usage_rights": "labeled-for-nocommercial-reuse", 35 | "print_urls": true, 36 | "output_directory": "texture" 37 | }, 38 | { 39 | "keywords": "brick texture", 40 | "limit": 300, 41 | "usage_rights": "labeled-for-nocommercial-reuse", 42 | "print_urls": true, 43 | "output_directory": "texture" 44 | }, 45 | { 46 | "keywords": "wall texture", 47 | "limit": 300, 48 | "usage_rights": "labeled-for-nocommercial-reuse", 49 | "print_urls": true, 50 | "output_directory": "texture" 51 | }, 52 | { 53 | "keywords": "concrete texture", 54 | "limit": 300, 55 | "usage_rights": "labeled-for-nocommercial-reuse", 56 | "print_urls": true, 57 | "output_directory": "texture" 58 | }, 59 | { 60 | "keywords": "dirt texture", 61 | "limit": 300, 62 | "usage_rights": "labeled-for-nocommercial-reuse", 63 | "print_urls": true, 64 | "output_directory": "texture" 65 | }, 66 | { 67 | "keywords": "paper texture", 68 | "limit": 300, 69 | "usage_rights": 
"labeled-for-nocommercial-reuse", 70 | "print_urls": true, 71 | "output_directory": "texture" 72 | }, 73 | { 74 | "keywords": "skin texture", 75 | "limit": 300, 76 | "usage_rights": "labeled-for-nocommercial-reuse", 77 | "print_urls": true, 78 | "output_directory": "texture" 79 | }, 80 | { 81 | "keywords": "meat texture", 82 | "limit": 300, 83 | "usage_rights": "labeled-for-nocommercial-reuse", 84 | "print_urls": true, 85 | "output_directory": "texture" 86 | }, 87 | { 88 | "keywords": "hair texture", 89 | "limit": 300, 90 | "usage_rights": "labeled-for-nocommercial-reuse", 91 | "print_urls": true, 92 | "output_directory": "texture" 93 | }, 94 | { 95 | "keywords": "electronics texture", 96 | "limit": 300, 97 | "usage_rights": "labeled-for-nocommercial-reuse", 98 | "print_urls": true, 99 | "output_directory": "texture" 100 | }, 101 | { 102 | "keywords": "fur texture", 103 | "limit": 300, 104 | "usage_rights": "labeled-for-nocommercial-reuse", 105 | "print_urls": true, 106 | "output_directory": "texture" 107 | } 108 | ] 109 | } -------------------------------------------------------------------------------- /generation/blender/texture_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "Records": [ 3 | { 4 | "keywords": "stone texture", 5 | "limit": 300, 6 | "usage_rights": "labeled-for-nocommercial-reuse", 7 | "print_urls": true, 8 | "output_directory": "texture" 9 | }, 10 | { 11 | "keywords": "fabric texture", 12 | "limit": 300, 13 | "usage_rights": "labeled-for-nocommercial-reuse", 14 | "print_urls": true, 15 | "output_directory": "texture" 16 | }, 17 | { 18 | "keywords": "pebble texture", 19 | "limit": 300, 20 | "usage_rights": "labeled-for-nocommercial-reuse", 21 | "print_urls": true, 22 | "output_directory": "texture" 23 | }, 24 | { 25 | "keywords": "sand texture", 26 | "limit": 300, 27 | "usage_rights": "labeled-for-nocommercial-reuse", 28 | "print_urls": true, 29 | "output_directory": "texture" 30 | }, 31 | { 32 | "keywords": "beach texture", 33 | "limit": 300, 34 | "usage_rights": "labeled-for-nocommercial-reuse", 35 | "print_urls": true, 36 | "output_directory": "texture" 37 | } 38 | ] 39 | } -------------------------------------------------------------------------------- /generation/fusion_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | A helper class for generating data used to train the fusion module. 3 | This part is rather crude as it is used only once/twice. 
4 | """ 5 | import torch 6 | import numpy as np 7 | 8 | from model.propagation.prop_net import PropagationNetwork 9 | from model.aggregate import aggregate_sbg, aggregate_wbg 10 | 11 | from util.tensor_util import pad_divide_by 12 | 13 | class FusionGenerator: 14 | def __init__(self, prop_net:PropagationNetwork, images, mem_freq): 15 | self.prop_net = prop_net 16 | self.mem_freq = mem_freq 17 | 18 | # True dimensions 19 | t = images.shape[1] 20 | h, w = images.shape[-2:] 21 | 22 | # Pad each side to multiple of 16 23 | images, self.pad = pad_divide_by(images, 16, images.shape[-2:]) 24 | # Padded dimensions 25 | nh, nw = images.shape[-2:] 26 | 27 | self.images = images 28 | self.device = self.images.device 29 | 30 | self.t, self.h, self.w = t, h, w 31 | self.nh, self.nw = nh, nw 32 | 33 | def reset(self, k): 34 | self.k = k 35 | self.prob = torch.zeros((self.k+1, self.t, 1, self.nh, self.nw), dtype=torch.float32, device=self.device) 36 | 37 | def get_im(self, idx): 38 | return self.images[:,idx] 39 | 40 | def get_query_buf(self, idx): 41 | query = self.prop_net.get_query_values(self.get_im(idx)) 42 | return query 43 | 44 | def do_pass(self, key_k, key_v, idx, left_limit, right_limit, forward=True): 45 | keys = key_k 46 | values = key_v 47 | prev_k = prev_v = None 48 | last_ti = idx 49 | 50 | Es = self.prob 51 | 52 | if forward: 53 | this_range = range(idx+1, right_limit+1) 54 | step = +1 55 | end = right_limit 56 | else: 57 | this_range = range(idx-1, left_limit-1, -1) 58 | step = -1 59 | end = left_limit 60 | 61 | for ti in this_range: 62 | if prev_k is not None: 63 | this_k = torch.cat([keys, prev_k], 2) 64 | this_v = torch.cat([values, prev_v], 2) 65 | else: 66 | this_k = keys 67 | this_v = values 68 | query = self.get_query_buf(ti) 69 | out_mask = self.prop_net.segment_with_query(this_k, this_v, *query) 70 | out_mask = aggregate_wbg(out_mask, keep_bg=True) 71 | 72 | Es[:,ti] = out_mask 73 | 74 | if ti != end: 75 | prev_k, prev_v = self.prop_net.memorize(self.get_im(ti), out_mask[1:]) 76 | if abs(ti-last_ti) >= self.mem_freq: 77 | last_ti = ti 78 | keys = torch.cat([keys, prev_k], 2) 79 | values = torch.cat([values, prev_v], 2) 80 | prev_k = prev_v = None 81 | 82 | def interact_mask(self, mask, idx, left_limit, right_limit): 83 | 84 | mask, _ = pad_divide_by(mask, 16, mask.shape[-2:]) 85 | mask = aggregate_wbg(mask, keep_bg=True) 86 | 87 | self.prob[:, idx] = mask 88 | key_k, key_v = self.prop_net.memorize(self.get_im(idx), mask[1:]) 89 | 90 | self.do_pass(key_k, key_v, idx, left_limit, right_limit, True) 91 | self.do_pass(key_k, key_v, idx, left_limit, right_limit, False) 92 | 93 | # Prepare output 94 | out_prob = self.prob[:,:,0,:,:] 95 | 96 | if self.pad[2]+self.pad[3] > 0: 97 | out_prob = out_prob[:,:,self.pad[2]:-self.pad[3],:] 98 | if self.pad[0]+self.pad[1] > 0: 99 | out_prob = out_prob[:,:,:,self.pad[0]:-self.pad[1]] 100 | 101 | return out_prob 102 | -------------------------------------------------------------------------------- /generation/test/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/generation/test/image.png -------------------------------------------------------------------------------- /generation/test/test_spline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import cv2 8 | 
import thinplate as tps 9 | 10 | 11 | def show_warped(img, warped): 12 | fig, axs = plt.subplots(1, 2, figsize=(16,8)) 13 | axs[0].axis('off') 14 | axs[1].axis('off') 15 | axs[0].imshow(img[...,::-1], origin='upper') 16 | axs[0].scatter(c_src[:, 0]*img.shape[1], c_src[:, 1]*img.shape[0], marker='+', color='black') 17 | axs[1].imshow(warped[...,::-1], origin='upper') 18 | axs[1].scatter(c_dst[:, 0]*warped.shape[1], c_dst[:, 1]*warped.shape[0], marker='+', color='black') 19 | plt.show() 20 | 21 | def warp_image_cv(img, c_src, c_dst, dshape=None): 22 | dshape = dshape or img.shape 23 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 24 | grid = tps.tps_grid(theta, c_dst, dshape) 25 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 26 | return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC) 27 | 28 | img = cv2.imread('image.png') 29 | 30 | c_src = np.array([ 31 | [0.0, 0.0], 32 | [1., 0], 33 | [1, 1], 34 | [0, 1], 35 | [0.3, 0.3], 36 | [0.7, 0.7], 37 | ]) 38 | 39 | c_dst = np.array([ 40 | [-0.2, -0.2], 41 | [1., 0], 42 | [1, 1], 43 | [0, 1], 44 | [0.4, 0.4], 45 | [0.6, 0.6], 46 | ]) 47 | 48 | warped = warp_image_cv(img, c_src, c_dst) 49 | print(warped.shape) 50 | show_warped(img, warped) -------------------------------------------------------------------------------- /generation/test/test_spline_continuous.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import cv2 8 | import thinplate as tps 9 | 10 | 11 | def warp_image_cv(img, c_src, c_dst, dshape=None): 12 | dshape = dshape or img.shape 13 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 14 | grid = tps.tps_grid(theta, c_dst, dshape) 15 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 16 | return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC) 17 | 18 | img = cv2.imread('image.png') 19 | 20 | c_src = np.array([ 21 | [0.0, 0.0], 22 | [1., 0], 23 | [1, 1], 24 | [0, 1], 25 | [0.3, 0.3], 26 | [0.7, 0.7], 27 | ]) 28 | 29 | c_dst = np.array([ 30 | [-0.2, -0.2], 31 | [1., 0], 32 | [1, 1], 33 | [0, 1], 34 | [0.4, 0.4], 35 | [0.6, 0.6], 36 | ]) 37 | 38 | for i in range(5): 39 | c_this_dst = c_src*(5-i)/5 + c_dst*i/5 40 | warped = warp_image_cv(img, c_src, c_this_dst) 41 | cv2.imwrite('%d.png'%i, warped) -------------------------------------------------------------------------------- /imgs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/imgs/framework.jpg -------------------------------------------------------------------------------- /interact/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/interact/__init__.py -------------------------------------------------------------------------------- /interact/fbrs_controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fbrs.controller import InteractiveController 3 | from fbrs.inference import utils 4 | 5 | 6 | class FBRSController: 7 | def __init__(self, checkpoint_path, device='cuda:0', max_size=800): 8 | model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260) 9 | 10 | # Predictor params 11 | zoomin_params = { 12 | 'skip_clicks': 1, 13 | 
'target_size': 480, 14 | 'expansion_ratio': 1.4, 15 | } 16 | 17 | predictor_params = { 18 | 'brs_mode': 'f-BRS-B', 19 | 'prob_thresh': 0.5, 20 | 'zoom_in_params': zoomin_params, 21 | 'predictor_params': { 22 | 'net_clicks_limit': 8, 23 | 'max_size': 800, 24 | }, 25 | 'brs_opt_func_params': {'min_iou_diff': 1e-3}, 26 | 'lbfgs_params': {'maxfun': 20} 27 | } 28 | 29 | self.controller = InteractiveController(model, device, predictor_params) 30 | self.anchored = False 31 | self.device = device 32 | 33 | def unanchor(self): 34 | self.anchored = False 35 | 36 | def interact(self, image, x, y, is_positive): 37 | image = image.to(self.device, non_blocking=True) 38 | if not self.anchored: 39 | self.controller.set_image(image) 40 | self.controller.reset_predictor() 41 | self.anchored = True 42 | 43 | self.controller.add_click(x, y, is_positive) 44 | # return self.controller.result_mask 45 | # return self.controller.probs_history[-1][1] 46 | return (self.controller.probs_history[-1][1]>0.5).float() 47 | 48 | def undo(self): 49 | self.controller.undo_click() 50 | if len(self.controller.probs_history) == 0: 51 | return None 52 | else: 53 | return (self.controller.probs_history[-1][1]>0.5).float() -------------------------------------------------------------------------------- /interact/interactive_utils.py: -------------------------------------------------------------------------------- 1 | # Modifed from https://github.com/seoungwugoh/ivs-demo 2 | 3 | import numpy as np 4 | import os 5 | import copy 6 | import cv2 7 | import glob 8 | 9 | import matplotlib.pyplot as plt 10 | from scipy.ndimage.morphology import binary_erosion, binary_dilation 11 | from PIL import Image, ImageDraw, ImageFont 12 | 13 | import torch 14 | from torchvision import models 15 | from dataset.range_transform import im_normalization 16 | 17 | 18 | def images_to_torch(frames, device): 19 | frames = torch.from_numpy(frames.transpose(0, 3, 1, 2)).float().unsqueeze(0)/255 20 | b, t, c, h, w = frames.shape 21 | for ti in range(t): 22 | frames[0, ti] = im_normalization(frames[0, ti]) 23 | return frames.to(device) 24 | 25 | def load_images(path, min_side=None): 26 | fnames = sorted(glob.glob(os.path.join(path, '*.jpg'))) 27 | if len(fnames) == 0: 28 | fnames = sorted(glob.glob(os.path.join(path, '*.png'))) 29 | frame_list = [] 30 | for i, fname in enumerate(fnames): 31 | if min_side: 32 | image = Image.open(fname).convert('RGB') 33 | w, h = image.size 34 | new_w = (w*min_side//min(w, h)) 35 | new_h = (h*min_side//min(w, h)) 36 | frame_list.append(np.array(image.resize((new_w, new_h), Image.BICUBIC), dtype=np.uint8)) 37 | else: 38 | frame_list.append(np.array(Image.open(fname).convert('RGB'), dtype=np.uint8)) 39 | frames = np.stack(frame_list, axis=0) 40 | return frames 41 | 42 | def load_masks(path, min_side=None): 43 | fnames = sorted(glob.glob(os.path.join(path, '*.png'))) 44 | frame_list = [] 45 | 46 | first_frame = np.array(Image.open(fnames[0])) 47 | binary_mask = (first_frame.max() == 255) 48 | 49 | for i, fname in enumerate(fnames): 50 | if min_side: 51 | image = Image.open(fname) 52 | w, h = image.size 53 | new_w = (w*min_side//min(w, h)) 54 | new_h = (h*min_side//min(w, h)) 55 | frame_list.append(np.array(image.resize((new_w, new_h), Image.NEAREST), dtype=np.uint8)) 56 | else: 57 | frame_list.append(np.array(Image.open(fname), dtype=np.uint8)) 58 | 59 | frames = np.stack(frame_list, axis=0) 60 | if binary_mask: 61 | frames = (frames > 128).astype(np.uint8) 62 | return frames 63 | 64 | def load_video(path, min_side=None): 65 | 
frame_list = [] 66 | cap = cv2.VideoCapture(path) 67 | while(cap.isOpened()): 68 | _, frame = cap.read() 69 | if frame is None: 70 | break 71 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 72 | if min_side: 73 | h, w = frame.shape[:2] 74 | new_w = (w*min_side//min(w, h)) 75 | new_h = (h*min_side//min(w, h)) 76 | frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_CUBIC) 77 | frame_list.append(frame) 78 | frames = np.stack(frame_list, axis=0) 79 | return frames 80 | 81 | def _pascal_color_map(N=256, normalized=False): 82 | """ 83 | Python implementation of the color map function for the PASCAL VOC data set. 84 | Official Matlab version can be found in the PASCAL VOC devkit 85 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 86 | """ 87 | 88 | def bitget(byteval, idx): 89 | return (byteval & (1 << idx)) != 0 90 | 91 | dtype = 'float32' if normalized else 'uint8' 92 | cmap = np.zeros((N, 3), dtype=dtype) 93 | for i in range(N): 94 | r = g = b = 0 95 | c = i 96 | for j in range(8): 97 | r = r | (bitget(c, 0) << 7 - j) 98 | g = g | (bitget(c, 1) << 7 - j) 99 | b = b | (bitget(c, 2) << 7 - j) 100 | c = c >> 3 101 | 102 | cmap[i] = np.array([r, g, b]) 103 | 104 | cmap = cmap / 255 if normalized else cmap 105 | return cmap 106 | 107 | color_map = [ 108 | [0, 0, 0], 109 | [255, 50, 50], 110 | [50, 255, 50], 111 | [50, 50, 255], 112 | [255, 50, 255], 113 | [50, 255, 255], 114 | [255, 255, 50], 115 | ] 116 | 117 | color_map_np = np.array(color_map) 118 | 119 | def overlay_davis(image, mask, alpha=0.5): 120 | """ Overlay segmentation on top of RGB image. from davis official""" 121 | im_overlay = image.copy() 122 | 123 | colored_mask = color_map_np[mask] 124 | foreground = image*alpha + (1-alpha)*colored_mask 125 | binary_mask = (mask > 0) 126 | # Compose image 127 | im_overlay[binary_mask] = foreground[binary_mask] 128 | countours = binary_dilation(binary_mask) ^ binary_mask 129 | im_overlay[countours,:] = 0 130 | return im_overlay.astype(image.dtype) 131 | 132 | def overlay_davis_fade(image, mask, alpha=0.5): 133 | im_overlay = image.copy() 134 | 135 | colored_mask = color_map_np[mask] 136 | foreground = image*alpha + (1-alpha)*colored_mask 137 | binary_mask = (mask > 0) 138 | # Compose image 139 | im_overlay[binary_mask] = foreground[binary_mask] 140 | countours = binary_dilation(binary_mask) ^ binary_mask 141 | im_overlay[countours,:] = 0 142 | im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6 143 | return im_overlay.astype(image.dtype) -------------------------------------------------------------------------------- /interact/s2m_controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from model.s2m.s2m_network import deeplabv3plus_resnet50 as S2M 4 | 5 | from util.tensor_util import pad_divide_by 6 | 7 | 8 | class S2MController: 9 | """ 10 | A controller for Scribble-to-Mask (for user interaction, not for DAVIS) 11 | Takes the image, previous mask, and scribbles to produce a new mask 12 | ignore_class is usually 255 13 | 0 is NOT the ignore class -- it is the label for the background 14 | """ 15 | def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'): 16 | self.s2m_net = s2m_net 17 | self.num_objects = num_objects 18 | self.ignore_class = ignore_class 19 | self.device = device 20 | 21 | def interact(self, image, prev_mask, scr_mask): 22 | image = image.to(self.device, non_blocking=True) 23 | prev_mask = prev_mask.to(self.device, non_blocking=True) 24 | 
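        # For each object id ki in 1..num_objects, build a two-channel scribble map:
        # positive strokes are the pixels labelled ki, negative strokes are all other
        # labelled pixels except the ignore class. The S2M network is then run once per
        # object on [image, previous mask of ki, scribble channels], giving one
        # unaggregated (sigmoid) mask per object.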
25 | h, w = image.shape[-2:] 26 | unaggre_mask = torch.zeros((self.num_objects, 1, h, w), dtype=torch.float32, device=image.device) 27 | 28 | for ki in range(1, self.num_objects+1): 29 | p_srb = (scr_mask==ki).astype(np.uint8) 30 | n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8) 31 | 32 | Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device) 33 | Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:]) 34 | 35 | inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1) 36 | unaggre_mask[ki-1] = torch.sigmoid(self.s2m_net(inputs)) 37 | 38 | return unaggre_mask -------------------------------------------------------------------------------- /interact/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | def __init__(self): 5 | self._acc_time = 0 6 | self._paused = True 7 | 8 | def start(self): 9 | if self._paused: 10 | self.last_time = time.time() 11 | self._paused = False 12 | return self 13 | 14 | def pause(self): 15 | self.count() 16 | self._paused = True 17 | return self 18 | 19 | def count(self): 20 | if self._paused: 21 | return self._acc_time 22 | t = time.time() 23 | self._acc_time += t - self.last_time 24 | self.last_time = t 25 | return self._acc_time 26 | 27 | def format(self): 28 | # count = int(self.count()*100) 29 | # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100) 30 | return '%03.2f' % self.count() 31 | 32 | def __str__(self): 33 | return self.format() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/model/__init__.py -------------------------------------------------------------------------------- /model/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def aggregate_sbg(prob, keep_bg=False, hard=False): 5 | device = prob.device 6 | k, _, h, w = prob.shape 7 | ex_prob = torch.zeros((k+1, 1, h, w), device=device) 8 | ex_prob[0] = 0.5 9 | ex_prob[1:] = prob 10 | ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7) 11 | logits = torch.log((ex_prob /(1-ex_prob))) 12 | 13 | if hard: 14 | # Very low temperature o((⊙﹏⊙))o 🥶 15 | logits *= 1000 16 | 17 | if keep_bg: 18 | return F.softmax(logits, dim=0) 19 | else: 20 | return F.softmax(logits, dim=0)[1:] 21 | 22 | def aggregate_wbg(prob, keep_bg=False, hard=False): 23 | k, _, h, w = prob.shape 24 | new_prob = torch.cat([ 25 | torch.prod(1-prob, dim=0, keepdim=True), 26 | prob 27 | ], 0).clamp(1e-7, 1-1e-7) 28 | logits = torch.log((new_prob /(1-new_prob))) 29 | 30 | if hard: 31 | # Very low temperature o((⊙﹏⊙))o 🥶 32 | logits *= 1000 33 | 34 | if keep_bg: 35 | return F.softmax(logits, dim=0) 36 | else: 37 | return F.softmax(logits, dim=0)[1:] 38 | 39 | def aggregate_wbg_channel(prob, keep_bg=False, hard=False): 40 | new_prob = torch.cat([ 41 | torch.prod(1-prob, dim=1, keepdim=True), 42 | prob 43 | ], 1).clamp(1e-7, 1-1e-7) 44 | logits = torch.log((new_prob /(1-new_prob))) 45 | 46 | if hard: 47 | # Very low temperature o((⊙﹏⊙))o 🥶 48 | logits *= 1000 49 | 50 | if keep_bg: 51 | return logits, F.softmax(logits, dim=1) 52 | else: 53 | return logits, F.softmax(logits, dim=1)[:, 1:] -------------------------------------------------------------------------------- 
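A minimal usage sketch for the soft-aggregation helper above (not part of the repository;
it assumes the file is importable as model.aggregate): aggregate_wbg takes k independent
per-object probability maps, derives a background channel as the product of (1 - p) over
the objects, and returns a softmax over k+1 channels that sums to one at every pixel.

import torch
from model.aggregate import aggregate_wbg

k, h, w = 3, 64, 64
per_object_prob = torch.rand(k, 1, h, w)              # one independent sigmoid output per object
joint = aggregate_wbg(per_object_prob, keep_bg=True)  # channel 0 is the derived background
assert joint.shape == (k + 1, 1, h, w)
assert torch.allclose(joint.sum(dim=0), torch.ones(1, h, w), atol=1e-5)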
/model/attn_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modifed from the original STM code https://github.com/seoungwugoh/STM 3 | """ 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from model.propagation.modules import MaskRGBEncoder, RGBEncoder, KeyValue 11 | 12 | class AttentionMemory(nn.Module): 13 | def __init__(self, k=50): 14 | super().__init__() 15 | self.k = k 16 | 17 | def forward(self, mk, qk): 18 | B, CK, H, W = mk.shape 19 | 20 | mk = mk.view(B, CK, H*W) 21 | mk = torch.transpose(mk, 1, 2) # B * HW * CK 22 | 23 | qk = qk.view(B, CK, H*W).expand(B, -1, -1) / math.sqrt(CK) # B * CK * HW 24 | 25 | affinity = torch.bmm(mk, qk) # B * HW * HW 26 | affinity = F.softmax(affinity, dim=1) 27 | 28 | return affinity 29 | 30 | class AttentionReadNetwork(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.mask_rgb_encoder = MaskRGBEncoder() 34 | self.rgb_encoder = RGBEncoder() 35 | 36 | self.kv_m_f16 = KeyValue(1024, keydim=128, valdim=512) 37 | self.kv_q_f16 = KeyValue(1024, keydim=128, valdim=512) 38 | self.memory = AttentionMemory() 39 | 40 | for p in self.parameters(): 41 | p.requires_grad = False 42 | 43 | def get_segment(self, f16, qk): 44 | k16, _ = self.kv_m_f16(f16) 45 | p = self.memory(k16, qk) 46 | return p 47 | 48 | def forward(self, image, mask11, mask21, mask12, mask22, query_image): 49 | b, _, h, w = mask11.shape 50 | nh = h//16 51 | nw = w//16 52 | 53 | with torch.no_grad(): 54 | pos_mask1 = (mask21-mask11).clamp(0, 1) 55 | neg_mask1 = (mask11-mask21).clamp(0, 1) 56 | pos_mask2 = (mask22-mask12).clamp(0, 1) 57 | neg_mask2 = (mask12-mask22).clamp(0, 1) 58 | 59 | f16_1 = self.mask_rgb_encoder(image, mask21, mask22) 60 | f16_2 = self.mask_rgb_encoder(image, mask22, mask21) 61 | 62 | qf16, _, _ = self.rgb_encoder(query_image) 63 | qk16, _ = self.kv_q_f16(qf16) 64 | 65 | W1 = self.get_segment(f16_1, qk16) 66 | W2 = self.get_segment(f16_2, qk16) 67 | 68 | pos_map1 = (F.interpolate(pos_mask1, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W1) 69 | neg_map1 = (F.interpolate(neg_mask1, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W1) 70 | attn_map1 = torch.cat([pos_map1, neg_map1], 1) 71 | attn_map1 = attn_map1.reshape(b, 2, nh, nw) 72 | attn_map1 = F.interpolate(attn_map1, mode='bilinear', size=(h,w), align_corners=False) 73 | 74 | pos_map2 = (F.interpolate(pos_mask2, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W2) 75 | neg_map2 = (F.interpolate(neg_mask2, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W2) 76 | attn_map2 = torch.cat([pos_map2, neg_map2], 1) 77 | attn_map2 = attn_map2.reshape(b, 2, nh, nw) 78 | attn_map2 = F.interpolate(attn_map2, mode='bilinear', size=(h,w), align_corners=False) 79 | 80 | return attn_map1, attn_map2 81 | -------------------------------------------------------------------------------- /model/fusion_net.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FusionNet(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(9, 32, kernel_size=3, padding=1, stride=1), 14 | nn.ReLU(), 15 | ) 16 | 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 19 | nn.ReLU(), 20 | nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 21 | ) 22 | 23 | self.conv3 = nn.Sequential( 24 | 
nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 25 | nn.ReLU(), 26 | nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 27 | ) 28 | 29 | self.relu = nn.ReLU() 30 | self.final_conv = nn.Conv2d(32, 1, kernel_size=3, padding=1, stride=1) 31 | 32 | def forward(self, im, seg1, seg2, attn, time): 33 | h, w = im.shape[-2:] 34 | 35 | time = time.unsqueeze(2).unsqueeze(2) 36 | time = time.expand(-1, -1, h, w) 37 | 38 | x = torch.cat([im, seg1, seg2, attn, time], 1) 39 | 40 | x = self.conv1(x) 41 | 42 | r = self.conv2(x) 43 | x = self.relu(x + r) 44 | 45 | r = self.conv3(x) 46 | x = self.relu(x + r) 47 | 48 | x = self.final_conv(x) 49 | 50 | return x 51 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from util.tensor_util import compute_tensor_iu 5 | from collections import defaultdict 6 | 7 | 8 | def get_iou_hook(values): 9 | return 'iou/iou', (values['hide_iou/i']+1)/(values['hide_iou/u']+1) 10 | 11 | def get_sec_iou_hook(values): 12 | return 'iou/sec_iou', (values['hide_iou/sec_i']+1)/(values['hide_iou/sec_u']+1) 13 | 14 | iou_hooks = [ 15 | get_iou_hook, 16 | get_sec_iou_hook, 17 | ] 18 | 19 | 20 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch 21 | class BootstrappedCE(nn.Module): 22 | def __init__(self, start_warm, end_warm, top_p=0.15): 23 | super().__init__() 24 | 25 | self.start_warm = start_warm 26 | self.end_warm = end_warm 27 | self.top_p = top_p 28 | 29 | def forward(self, input, target, it): 30 | if it < self.start_warm: 31 | return F.cross_entropy(input, target), 1.0 32 | 33 | raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) 34 | num_pixels = raw_loss.numel() 35 | 36 | if it > self.end_warm: 37 | this_p = self.top_p 38 | else: 39 | this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) 40 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) 41 | return loss.mean(), this_p 42 | 43 | 44 | class LossComputer: 45 | def __init__(self, para): 46 | super().__init__() 47 | self.para = para 48 | self.bce = BootstrappedCE(start_warm=int(para['iterations']*0.2), end_warm=int(para['iterations']*0.5)) 49 | 50 | def compute(self, data, it): 51 | losses = defaultdict(int) 52 | 53 | b = data['gt'].shape[0] 54 | selector = data.get('selector', None) 55 | selector = data['selector'] 56 | 57 | for j in range(b): 58 | if selector[j][1] > 0.5: 59 | loss, p = self.bce(data['logits'][j:j+1], data['cls_gt'][j:j+1], it) 60 | else: 61 | loss, p = self.bce(data['logits'][j:j+1,:2], data['cls_gt'][j:j+1], it) 62 | 63 | losses['total_loss'] += loss / b 64 | losses['p'] += p / b 65 | 66 | new_total_i, new_total_u = compute_tensor_iu(data['mask'][:,1:2]>0.5, data['gt']>0.5) 67 | losses['hide_iou/i'] += new_total_i 68 | losses['hide_iou/u'] += new_total_u 69 | 70 | if selector is not None: 71 | new_total_i, new_total_u = compute_tensor_iu(data['mask'][:,2:3]>0.5, data['gt2']>0.5) 72 | losses['hide_iou/sec_i'] += new_total_i 73 | losses['hide_iou/sec_u'] += new_total_u 74 | 75 | return losses -------------------------------------------------------------------------------- /model/propagation/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/model/propagation/__init__.py -------------------------------------------------------------------------------- /model/propagation/mod_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils import model_zoo 8 | 9 | def load_weights_sequential(target, source_state, extra_chan=1): 10 | 11 | new_dict = OrderedDict() 12 | 13 | for k1, v1 in target.state_dict().items(): 14 | if not 'num_batches_tracked' in k1: 15 | if k1 in source_state: 16 | tar_v = source_state[k1] 17 | 18 | if v1.shape != tar_v.shape: 19 | # Init the new segmentation channel with zeros 20 | # print(v1.shape, tar_v.shape) 21 | c, _, w, h = v1.shape 22 | tar_v = torch.cat([ 23 | tar_v, 24 | torch.zeros((c,extra_chan,w,h)), 25 | ], 1) 26 | 27 | new_dict[k1] = tar_v 28 | elif 'bias' not in k1: 29 | print('Not OK', k1) 30 | 31 | target.load_state_dict(new_dict, strict=False) 32 | 33 | 34 | model_urls = { 35 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 36 | } 37 | 38 | 39 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 41 | padding=dilation, dilation=dilation) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 48 | super(BasicBlock, self).__init__() 49 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 53 | self.bn2 = nn.BatchNorm2d(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | residual = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | out += residual 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 80 | super(Bottleneck, self).__init__() 81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, 84 | padding=dilation) 85 | self.bn2 = nn.BatchNorm2d(planes) 86 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1) 87 | self.bn3 = nn.BatchNorm2d(planes * 4) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | residual = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | residual = self.downsample(x) 108 | 109 | out += residual 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class ResNet(nn.Module): 116 | def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1): 117 | self.inplanes = 64 118 | super(ResNet, self).__init__() 
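        # The first conv accepts 3 RGB channels plus `extra_chan` additional channels
        # (used for the mask inputs in MaskRGBEncoder). When ImageNet-pretrained weights
        # are loaded via load_weights_sequential() above, the filters for those extra
        # input channels are zero-initialised.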
119 | self.conv1 = nn.Conv2d(3+extra_chan, 64, kernel_size=7, stride=2, padding=3) 120 | self.bn1 = nn.BatchNorm2d(64) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, 64, layers[0]) 124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 131 | m.weight.data.normal_(0, math.sqrt(2. / n)) 132 | m.bias.data.zero_() 133 | elif isinstance(m, nn.BatchNorm2d): 134 | m.weight.data.fill_(1) 135 | m.bias.data.zero_() 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [block(self.inplanes, planes, stride, downsample)] 147 | self.inplanes = planes * block.expansion 148 | for i in range(1, blocks): 149 | layers.append(block(self.inplanes, planes, dilation=dilation)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def resnet50(pretrained=True, extra_chan=0): 154 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan) 155 | if pretrained: 156 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']), extra_chan) 157 | return model 158 | 159 | -------------------------------------------------------------------------------- /model/propagation/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modifed from the original STM code https://github.com/seoungwugoh/STM 3 | """ 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | from torchvision import models 11 | 12 | from model.propagation import mod_resnet 13 | 14 | 15 | class ResBlock(nn.Module): 16 | def __init__(self, indim, outdim=None): 17 | super(ResBlock, self).__init__() 18 | if outdim == None: 19 | outdim = indim 20 | if indim == outdim: 21 | self.downsample = None 22 | else: 23 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) 24 | 25 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) 26 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) 27 | 28 | def forward(self, x): 29 | r = self.conv1(F.relu(x)) 30 | r = self.conv2(F.relu(r)) 31 | 32 | if self.downsample is not None: 33 | x = self.downsample(x) 34 | 35 | return x + r 36 | 37 | 38 | class MaskRGBEncoder(nn.Module): 39 | def __init__(self): 40 | super().__init__() 41 | 42 | resnet = mod_resnet.resnet50(pretrained=True, extra_chan=2) 43 | self.conv1 = resnet.conv1 44 | self.bn1 = resnet.bn1 45 | self.relu = resnet.relu # 1/2, 64 46 | self.maxpool = resnet.maxpool 47 | 48 | self.layer1 = resnet.layer1 # 1/4, 256 49 | self.layer2 = resnet.layer2 # 1/8, 512 50 | self.layer3 = resnet.layer3 # 1/16, 1024 51 | 52 | def forward(self, f, m, o): 53 | 54 | f = torch.cat([f, m, o], 1) 55 | 56 | x = self.conv1(f) 57 | x = self.bn1(x) 58 | x = self.relu(x) # 1/2, 64 59 | x = self.maxpool(x) # 1/4, 64 60 | x = self.layer1(x) # 1/4, 256 61 | x = self.layer2(x) # 1/8, 512 62 | x = 
self.layer3(x) # 1/16, 1024 63 | 64 | return x 65 | 66 | 67 | class RGBEncoder(nn.Module): 68 | def __init__(self): 69 | super().__init__() 70 | resnet = models.resnet50(pretrained=True) 71 | self.conv1 = resnet.conv1 72 | self.bn1 = resnet.bn1 73 | self.relu = resnet.relu # 1/2, 64 74 | self.maxpool = resnet.maxpool 75 | 76 | self.res2 = resnet.layer1 # 1/4, 256 77 | self.layer2 = resnet.layer2 # 1/8, 512 78 | self.layer3 = resnet.layer3 # 1/16, 1024 79 | 80 | def forward(self, f): 81 | x = self.conv1(f) 82 | x = self.bn1(x) 83 | x = self.relu(x) # 1/2, 64 84 | x = self.maxpool(x) # 1/4, 64 85 | f4 = self.res2(x) # 1/4, 256 86 | f8 = self.layer2(f4) # 1/8, 512 87 | f16 = self.layer3(f8) # 1/16, 1024 88 | 89 | return f16, f8, f4 90 | 91 | 92 | class UpsampleBlock(nn.Module): 93 | def __init__(self, skip_c, up_c, out_c, scale_factor=2): 94 | super().__init__() 95 | self.skip_conv1 = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1) 96 | self.skip_conv2 = ResBlock(up_c, up_c) 97 | self.out_conv = ResBlock(up_c, out_c) 98 | self.scale_factor = scale_factor 99 | 100 | def forward(self, skip_f, up_f): 101 | x = self.skip_conv2(self.skip_conv1(skip_f)) 102 | x = x + F.interpolate(up_f, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 103 | x = self.out_conv(x) 104 | return x 105 | 106 | 107 | class KeyValue(nn.Module): 108 | def __init__(self, indim, keydim, valdim): 109 | super().__init__() 110 | self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1) 111 | self.val_proj = nn.Conv2d(indim, valdim, kernel_size=3, padding=1) 112 | 113 | def forward(self, x): 114 | return self.key_proj(x), self.val_proj(x) 115 | -------------------------------------------------------------------------------- /model/s2m/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/model/s2m/__init__.py -------------------------------------------------------------------------------- /model/s2m/s2m_network.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch 2 | 3 | from model.s2m.utils import IntermediateLayerGetter 4 | from model.s2m._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 5 | from model.s2m import s2m_resnet 6 | 7 | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone): 8 | 9 | if output_stride==8: 10 | replace_stride_with_dilation=[False, True, True] 11 | aspp_dilate = [12, 24, 36] 12 | else: 13 | replace_stride_with_dilation=[False, False, True] 14 | aspp_dilate = [6, 12, 18] 15 | 16 | backbone = s2m_resnet.__dict__[backbone_name]( 17 | pretrained=pretrained_backbone, 18 | replace_stride_with_dilation=replace_stride_with_dilation) 19 | 20 | inplanes = 2048 21 | low_level_planes = 256 22 | 23 | if name=='deeplabv3plus': 24 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} 25 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 26 | elif name=='deeplabv3': 27 | return_layers = {'layer4': 'out'} 28 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) 29 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 30 | 31 | model = DeepLabV3(backbone, classifier) 32 | return model 33 | 34 | def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone): 35 | 36 | if backbone.startswith('resnet'): 37 | model = 
_segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 38 | else: 39 | raise NotImplementedError 40 | return model 41 | 42 | 43 | # Deeplab v3 44 | def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): 45 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 46 | 47 | Args: 48 | num_classes (int): number of classes. 49 | output_stride (int): output stride for deeplab. 50 | pretrained_backbone (bool): If True, use the pretrained backbone. 51 | """ 52 | return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 53 | 54 | 55 | # Deeplab v3+ 56 | def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): 57 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 58 | 59 | Args: 60 | num_classes (int): number of classes. 61 | output_stride (int): output stride for deeplab. 62 | pretrained_backbone (bool): If True, use the pretrained backbone. 63 | """ 64 | return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 65 | 66 | -------------------------------------------------------------------------------- /model/s2m/utils.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from collections import OrderedDict 8 | 9 | class _SimpleSegmentationModel(nn.Module): 10 | def __init__(self, backbone, classifier): 11 | super(_SimpleSegmentationModel, self).__init__() 12 | self.backbone = backbone 13 | self.classifier = classifier 14 | 15 | def forward(self, x): 16 | input_shape = x.shape[-2:] 17 | features = self.backbone(x) 18 | x = self.classifier(features) 19 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 20 | return x 21 | 22 | 23 | class IntermediateLayerGetter(nn.ModuleDict): 24 | """ 25 | Module wrapper that returns intermediate layers from a model 26 | 27 | It has a strong assumption that the modules have been registered 28 | into the model in the same order as they are used. 29 | This means that one should **not** reuse the same nn.Module 30 | twice in the forward if you want this to work. 31 | 32 | Additionally, it is only able to query submodules that are directly 33 | assigned to the model. So if `model` is passed, `model.feature1` can 34 | be returned, but not `model.feature1.layer2`. 35 | 36 | Arguments: 37 | model (nn.Module): model on which we will extract the features 38 | return_layers (Dict[name, new_name]): a dict containing the names 39 | of the modules for which the activations will be returned as 40 | the key of the dict, and the value of the dict is the name 41 | of the returned activation (which the user can specify). 
42 | 43 | Examples:: 44 | 45 | >>> m = torchvision.models.resnet18(pretrained=True) 46 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 47 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 48 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 49 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 50 | >>> print([(k, v.shape) for k, v in out.items()]) 51 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 52 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 53 | """ 54 | def __init__(self, model, return_layers): 55 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 56 | raise ValueError("return_layers are not present in model") 57 | 58 | orig_return_layers = return_layers 59 | return_layers = {k: v for k, v in return_layers.items()} 60 | layers = OrderedDict() 61 | for name, module in model.named_children(): 62 | layers[name] = module 63 | if name in return_layers: 64 | del return_layers[name] 65 | if not return_layers: 66 | break 67 | 68 | super(IntermediateLayerGetter, self).__init__(layers) 69 | self.return_layers = orig_return_layers 70 | 71 | def forward(self, x): 72 | out = OrderedDict() 73 | for name, module in self.named_children(): 74 | x = module(x) 75 | if name in self.return_layers: 76 | out_name = self.return_layers[name] 77 | out[out_name] = x 78 | return out 79 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/resize_length.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import cv2 4 | 5 | from progressbar import progressbar 6 | 7 | input_dir = sys.argv[1] 8 | output_dir = sys.argv[2] 9 | 10 | # max_length = 500 11 | min_length = 384 12 | 13 | def process_fun(): 14 | 15 | for f in progressbar(os.listdir(input_dir)): 16 | img = cv2.imread(os.path.join(input_dir, f)) 17 | h, w, _ = img.shape 18 | 19 | # scale = max(h, w) / max_length 20 | scale = min(h, w) / min_length 21 | 22 | img = cv2.resize(img, (int(w/scale), int(h/scale)), interpolation=cv2.INTER_AREA) 23 | cv2.imwrite(os.path.join(output_dir, os.path.basename(f)), img) 24 | 25 | if __name__ == '__main__': 26 | 27 | os.makedirs(output_dir, exist_ok=True) 28 | process_fun() 29 | 30 | print('All done.') -------------------------------------------------------------------------------- /scripts/resize_youtube.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import path 4 | 5 | from PIL import Image 6 | import numpy as np 7 | from progressbar import progressbar 8 | from multiprocessing import Pool 9 | 10 | new_min_size = 480 11 | 12 | def resize_vid_jpeg(inputs): 13 | vid_name, folder_path, out_path = inputs 14 | 15 | vid_path = path.join(folder_path, vid_name) 16 | vid_out_path = path.join(out_path, 'JPEGImages', vid_name) 17 | os.makedirs(vid_out_path, exist_ok=True) 18 | 19 | for im_name in os.listdir(vid_path): 20 | hr_im = Image.open(path.join(vid_path, im_name)) 21 | w, h = hr_im.size 22 | 23 | ratio = new_min_size / min(w, h) 24 | 25 | lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC) 26 | lr_im.save(path.join(vid_out_path, im_name)) 27 | 28 | def resize_vid_anno(inputs): 29 | 
vid_name, folder_path, out_path = inputs 30 | 31 | vid_path = path.join(folder_path, vid_name) 32 | vid_out_path = path.join(out_path, 'Annotations', vid_name) 33 | os.makedirs(vid_out_path, exist_ok=True) 34 | 35 | for im_name in os.listdir(vid_path): 36 | hr_im = Image.open(path.join(vid_path, im_name)).convert('P') 37 | w, h = hr_im.size 38 | 39 | ratio = new_min_size / min(w, h) 40 | 41 | lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.NEAREST) 42 | lr_im.save(path.join(vid_out_path, im_name)) 43 | 44 | 45 | def resize_all(in_path, out_path): 46 | for folder in os.listdir(in_path): 47 | 48 | if folder not in ['JPEGImages', 'Annotations']: 49 | continue 50 | folder_path = path.join(in_path, folder) 51 | videos = os.listdir(folder_path) 52 | 53 | videos = [(v, folder_path, out_path) for v in videos] 54 | 55 | if folder == 'JPEGImages': 56 | print('Processing images') 57 | os.makedirs(path.join(out_path, 'JPEGImages'), exist_ok=True) 58 | 59 | pool = Pool(processes=8) 60 | for _ in progressbar(pool.imap_unordered(resize_vid_jpeg, videos), max_value=len(videos)): 61 | pass 62 | else: 63 | print('Processing annotations') 64 | os.makedirs(path.join(out_path, 'Annotations'), exist_ok=True) 65 | 66 | pool = Pool(processes=8) 67 | for _ in progressbar(pool.imap_unordered(resize_vid_anno, videos), max_value=len(videos)): 68 | pass 69 | 70 | 71 | if __name__ == '__main__': 72 | in_path = sys.argv[1] 73 | out_path = sys.argv[2] 74 | 75 | resize_all(in_path, out_path) 76 | 77 | print('Done.') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from os import path 3 | import datetime 4 | import math 5 | 6 | import random 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 10 | import torch.multiprocessing as mp 11 | import torch.distributed as distributed 12 | 13 | from model.fusion_model import FusionModel 14 | from dataset.fusion_dataset import FusionDataset 15 | 16 | from util.logger import TensorboardLogger 17 | from util.hyper_para import HyperParameters 18 | from util.load_subset import * 19 | 20 | 21 | """ 22 | Initial setup 23 | """ 24 | torch.backends.cudnn.benchmark = True 25 | 26 | # Init distributed environment 27 | distributed.init_process_group(backend="nccl") 28 | # Set seed to ensure the same initialization 29 | torch.manual_seed(14159265) 30 | np.random.seed(14159265) 31 | random.seed(14159265) 32 | 33 | print('CUDA Device count: ', torch.cuda.device_count()) 34 | 35 | # Parse command line arguments 36 | para = HyperParameters() 37 | para.parse() 38 | 39 | local_rank = torch.distributed.get_rank() 40 | world_size = torch.distributed.get_world_size() 41 | torch.cuda.set_device(local_rank) 42 | 43 | print('I am rank %d in the world of %d!' 
% (local_rank, world_size)) 44 | 45 | """ 46 | Model related 47 | """ 48 | if local_rank == 0: 49 | # Logging 50 | if para['id'].lower() != 'null': 51 | long_id = '%s_%s' % (datetime.datetime.now().strftime('%b%d_%H.%M.%S'), para['id']) 52 | else: 53 | long_id = None 54 | logger = TensorboardLogger(para['id'], long_id) 55 | logger.log_string('hyperpara', str(para)) 56 | 57 | # Construct rank 0 model 58 | model = FusionModel(para, logger=logger, 59 | save_path=path.join('saves', long_id, long_id) if long_id is not None else None, 60 | local_rank=local_rank, world_size=world_size).train() 61 | else: 62 | # Construct models of other ranks 63 | model = FusionModel(para, local_rank=local_rank, world_size=world_size).train() 64 | 65 | # Load pertrained model if needed 66 | if para['load_model'] is not None: 67 | total_iter = model.load_model(para['load_model']) 68 | else: 69 | total_iter = 0 70 | 71 | if para['load_network'] is not None: 72 | model.load_network(para['load_network']) 73 | 74 | """ 75 | Dataloader related 76 | """ 77 | if para['load_prop'] is None: 78 | print('Fusion module can only be trained with a pre-trained propagation module!') 79 | print('Use --load_prop [model_path]') 80 | raise NotImplementedError 81 | model.load_prop(para['load_prop']) 82 | torch.cuda.empty_cache() 83 | 84 | if para['stage'] == 0: 85 | data_root = path.join(path.expanduser(para['bl_root'])) 86 | train_dataset = FusionDataset(path.join(data_root, 'JPEGImages'), 87 | path.join(data_root, 'Annotations'), para['fusion_bl_root']) 88 | elif para['stage'] == 1: 89 | data_root = path.join(path.expanduser(para['davis_root']), '2017', 'trainval') 90 | train_dataset = FusionDataset(path.join(data_root, 'JPEGImages', '480p'), 91 | path.join(data_root, 'Annotations', '480p'), para['fusion_root']) 92 | 93 | 94 | def worker_init_fn(worker_id): 95 | return np.random.seed((torch.initial_seed()%2**31) + worker_id + local_rank*100) 96 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, rank=local_rank, shuffle=True) 97 | train_loader = DataLoader(train_dataset, para['batch_size'], sampler=train_sampler, num_workers=8, 98 | worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True) 99 | 100 | """ 101 | Determine current/max epoch 102 | """ 103 | start_epoch = total_iter//len(train_loader) 104 | total_epoch = math.ceil(para['iterations']/len(train_loader)) 105 | print('Actual training epoch: ', total_epoch) 106 | 107 | """ 108 | Starts training 109 | """ 110 | # Need this to select random bases in different workers 111 | np.random.seed(np.random.randint(2**30-1) + local_rank*100) 112 | try: 113 | for e in range(start_epoch, total_epoch): 114 | # Crucial for randomness! 
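        # DistributedSampler only reshuffles its indices when set_epoch() is called
        # with a new epoch number; without this call every epoch would iterate the
        # data in the same order on every process.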
115 | train_sampler.set_epoch(e) 116 | 117 | # Train loop 118 | model.train() 119 | for data in train_loader: 120 | model.do_pass(data, total_iter) 121 | total_iter += 1 122 | 123 | if total_iter >= para['iterations']: 124 | break 125 | finally: 126 | if not para['debug'] and model.logger is not None and total_iter > 5000: 127 | model.save(total_iter) 128 | # Clean up 129 | distributed.destroy_process_group() 130 | 131 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/MiVOS/f2600a6eea8709c7b9f1a7575adc725def680b81/util/__init__.py -------------------------------------------------------------------------------- /util/cv2palette.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from PIL import Image 3 | import cv2 4 | import os 5 | 6 | 7 | # A stupid trick to force palette into cv2 8 | def cv2palette(image, palette): 9 | img_E = Image.fromarray(image) 10 | img_E.putpalette(palette) 11 | with tempfile.TemporaryDirectory() as tmppath: 12 | full_path = os.path.join(tmppath, 'temp.png') 13 | img_E.save(full_path) 14 | image = cv2.imread(full_path) 15 | return image -------------------------------------------------------------------------------- /util/davis_subset.txt: -------------------------------------------------------------------------------- 1 | bear 2 | bmx-bumps 3 | boat 4 | boxing-fisheye 5 | breakdance-flare 6 | bus 7 | car-turn 8 | cat-girl 9 | classic-car 10 | color-run 11 | crossing 12 | dance-jump 13 | dancing 14 | disc-jockey 15 | dog-agility 16 | dog-gooses 17 | dogs-scale 18 | drift-turn 19 | drone 20 | elephant 21 | flamingo 22 | hike 23 | hockey 24 | horsejump-low 25 | kid-football 26 | kite-walk 27 | koala 28 | lady-running 29 | lindy-hop 30 | longboard 31 | lucia 32 | mallard-fly 33 | mallard-water 34 | miami-surf 35 | motocross-bumps 36 | motorbike 37 | night-race 38 | paragliding 39 | planes-water 40 | rallye 41 | rhino 42 | rollerblade 43 | schoolgirls 44 | scooter-board 45 | scooter-gray 46 | sheep 47 | skate-park 48 | snowboard 49 | soccerball 50 | stroller 51 | stunt 52 | surf 53 | swing 54 | tennis 55 | tractor-sand 56 | train 57 | tuk-tuk 58 | upside-down 59 | varanus-cage 60 | walking -------------------------------------------------------------------------------- /util/hyper_para.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def none_or_default(x, default): 5 | return x if x is not None else default 6 | 7 | class HyperParameters(): 8 | def parse(self, unknown_arg_ok=False): 9 | parser = ArgumentParser() 10 | 11 | # Data parameters 12 | parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube') 13 | parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS') 14 | parser.add_argument('--bl_root', default='../BL30K') 15 | 16 | parser.add_argument('--fusion_root', default='../fusion_data/davis') 17 | parser.add_argument('--fusion_bl_root', default='../fusion_data/bl') 18 | 19 | parser.add_argument('--stage', type=int, default=0) 20 | 21 | # Generic learning parameters 22 | parser.add_argument('-i', '--iterations', help='Number of training iterations', default=None, type=int) 23 | parser.add_argument('-b', '--batch_size', help='Batch size', default=12, type=int) 24 | parser.add_argument('--lr', 
help='Initial learning rate', default=1e-4, type=float) 25 | parser.add_argument('--steps', help='Iteration at which learning rate is decayed by gamma', default=None, type=int, nargs='*') 26 | parser.add_argument('--gamma', help='Gamma used in learning rate decay', default=0.1, type=float) 27 | 28 | # Loading 29 | parser.add_argument('--load_network', help='Path to pretrained network weight only') 30 | parser.add_argument('--load_prop', default='saves/propagation_model.pth') 31 | parser.add_argument('--load_model', help='Path to the model file, including network, optimizer and such') 32 | 33 | # Logging information 34 | parser.add_argument('--id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL') 35 | parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true') 36 | 37 | # Multiprocessing parameters 38 | parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process') 39 | 40 | if unknown_arg_ok: 41 | args, _ = parser.parse_known_args() 42 | self.args = vars(args) 43 | else: 44 | self.args = vars(parser.parse_args()) 45 | 46 | # Stage-dependent hyperparameters 47 | # Assign default if not given 48 | if self.args['stage'] == 0: 49 | self.args['iterations'] = none_or_default(self.args['iterations'], 30000) 50 | self.args['steps'] = none_or_default(self.args['steps'], [20000]) 51 | elif self.args['stage'] == 1: 52 | self.args['iterations'] = none_or_default(self.args['iterations'], 10000) 53 | self.args['steps'] = none_or_default(self.args['steps'], [7500]) 54 | 55 | def __getitem__(self, key): 56 | return self.args[key] 57 | 58 | def __str__(self): 59 | return str(self.args) 60 | -------------------------------------------------------------------------------- /util/image_saver.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | import torchvision.transforms as transforms 6 | from dataset.range_transform import inv_im_trans 7 | from collections import defaultdict 8 | 9 | def tensor_to_numpy(image): 10 | image_np = (image.numpy() * 255).astype('uint8') 11 | return image_np 12 | 13 | def tensor_to_np_float(image): 14 | image_np = image.numpy().astype('float32') 15 | return image_np 16 | 17 | def detach_to_cpu(x): 18 | return x.detach().cpu() 19 | 20 | def transpose_np(x): 21 | return np.transpose(x, [1,2,0]) 22 | 23 | def tensor_to_gray_im(x): 24 | x = detach_to_cpu(x) 25 | x = tensor_to_numpy(x) 26 | x = transpose_np(x) 27 | return x 28 | 29 | def tensor_to_im(x): 30 | x = detach_to_cpu(x) 31 | x = inv_im_trans(x).clamp(0, 1) 32 | x = tensor_to_numpy(x) 33 | x = transpose_np(x) 34 | return x 35 | 36 | # Predefined key <-> caption dict 37 | key_captions = { 38 | 'im': 'Image', 39 | 'gt': 'GT', 40 | } 41 | 42 | """ 43 | Return an image array with captions 44 | keys in dictionary will be used as caption if not provided 45 | values should contain lists of cv2 images 46 | """ 47 | def get_image_array(images, grid_shape, captions={}): 48 | h, w = grid_shape 49 | cate_counts = len(images) 50 | rows_counts = len(next(iter(images.values()))) 51 | 52 | font = cv2.FONT_HERSHEY_SIMPLEX 53 | 54 | output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8) 55 | col_cnt = 0 56 | for k, v in images.items(): 57 | 58 | # Default as key value itself 59 | caption = captions.get(k, k) 60 | 61 | # Handles new line character 62 | dy = 40 63 | for i, line in enumerate(caption.split('\n')): 64 | 
if h > 200: 65 | cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy), 66 | font, 0.8, (255,255,255), 2, cv2.LINE_AA) 67 | else: 68 | cv2.putText(output_image, line, (10, col_cnt*w+10+i*dy), 69 | font, 0.4, (255,255,255), 1, cv2.LINE_AA) 70 | 71 | # Put images 72 | for row_cnt, img in enumerate(v): 73 | im_shape = img.shape 74 | if len(im_shape) == 2: 75 | img = img[..., np.newaxis] 76 | 77 | img = (img * 255).astype('uint8') 78 | 79 | output_image[(col_cnt+0)*w:(col_cnt+1)*w, 80 | (row_cnt+1)*h:(row_cnt+2)*h, :] = img 81 | 82 | col_cnt += 1 83 | 84 | return output_image 85 | 86 | def base_transform(im, size): 87 | im = tensor_to_np_float(im) 88 | if len(im.shape) == 3: 89 | im = im.transpose((1, 2, 0)) 90 | else: 91 | im = im[:, :, None] 92 | 93 | # Resize 94 | if size is not None and im.shape[1] != size: 95 | im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) 96 | 97 | return im.clip(0, 1) 98 | 99 | def im_transform(im, size): 100 | return base_transform(inv_im_trans(detach_to_cpu(im)), size=size) 101 | 102 | def mask_transform(mask, size): 103 | return base_transform(detach_to_cpu(mask), size=size) 104 | 105 | def out_transform(mask, size): 106 | return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) 107 | 108 | def get_click_points(image, pos_map, neg_map): 109 | image[pos_map<0.02, :] = [0, 1, 0] 110 | image[neg_map<0.02, :] = [1, 0, 0] 111 | 112 | return image 113 | 114 | def get_clicked_torch(image, pos_map, neg_map): 115 | rgb = im_transform(image, None) 116 | pos_map = mask_transform(pos_map, None)[:, :, 0] 117 | neg_map = mask_transform(neg_map, None)[:, :, 0] 118 | 119 | rgb[pos_map<0.02, :] = [0, 1, 0] 120 | rgb[neg_map<0.02, :] = [1, 0, 0] 121 | 122 | return (rgb*255).astype(np.uint8) 123 | 124 | def pool_fusion(images, size): 125 | req_images = defaultdict(list) 126 | 127 | b = images['gt'].shape[0] 128 | 129 | # Save storage 130 | b = max(4, b) 131 | 132 | GT_name = 'GT' 133 | 134 | for b_idx in range(b): 135 | req_images['RGB'].append(im_transform(images['rgb'][b_idx], size)) 136 | req_images['S11'].append(mask_transform(images['seg1'][b_idx], size)) 137 | req_images['S21'].append(mask_transform(images['seg2'][b_idx], size)) 138 | req_images['S12'].append(mask_transform(images['seg12'][b_idx], size)) 139 | req_images['S22'].append(mask_transform(images['seg22'][b_idx], size)) 140 | req_images['Pos Attn1'].append(mask_transform(images['attn1'][b_idx,0:1], size)) 141 | req_images['Neg Attn1'].append(mask_transform(images['attn1'][b_idx,1:2], size)) 142 | req_images['Pos Attn2'].append(mask_transform(images['attn2'][b_idx,0:1], size)) 143 | req_images['Neg Attn2'].append(mask_transform(images['attn2'][b_idx,1:2], size)) 144 | 145 | req_images['MSK1'].append(mask_transform(images['mask'][b_idx,1:2], size)) 146 | req_images['MSK2'].append(mask_transform(images['mask'][b_idx,2:3], size)) 147 | 148 | req_images[GT_name+'1'].append(mask_transform(images['gt'][b_idx], size)) 149 | req_images[GT_name+'2'].append(mask_transform(images['gt2'][b_idx], size)) 150 | 151 | return get_image_array(req_images, size, key_captions) -------------------------------------------------------------------------------- /util/load_subset.py: -------------------------------------------------------------------------------- 1 | def load_sub_davis(path='util/davis_subset.txt'): 2 | with open(path, mode='r') as f: 3 | subset = set(f.read().splitlines()) 4 | return subset 5 | 6 | def load_sub_yv(path='util/yv_subset.txt'): 7 | with open(path, mode='r') as f: 8 | subset = 
set(f.read().splitlines()) 9 | return subset 10 | -------------------------------------------------------------------------------- /util/log_integrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integrate numerical values for some iterations 3 | Typically used for loss computation / logging to tensorboard 4 | Call finalize and create a new Integrator when you want to display/log 5 | """ 6 | 7 | import torch 8 | 9 | 10 | class Integrator: 11 | def __init__(self, logger, distributed=True, local_rank=0, world_size=1): 12 | self.values = {} 13 | self.counts = {} 14 | self.hooks = [] # List is used here to maintain insertion order 15 | 16 | self.logger = logger 17 | 18 | self.distributed = distributed 19 | self.local_rank = local_rank 20 | self.world_size = world_size 21 | 22 | def add_tensor(self, key, tensor): 23 | if key not in self.values: 24 | self.counts[key] = 1 25 | if type(tensor) == float or type(tensor) == int: 26 | self.values[key] = tensor 27 | else: 28 | self.values[key] = tensor.mean().item() 29 | else: 30 | self.counts[key] += 1 31 | if type(tensor) == float or type(tensor) == int: 32 | self.values[key] += tensor 33 | else: 34 | self.values[key] += tensor.mean().item() 35 | 36 | def add_dict(self, tensor_dict): 37 | for k, v in tensor_dict.items(): 38 | self.add_tensor(k, v) 39 | 40 | def add_hook(self, hook): 41 | """ 42 | Adds a custom hook, i.e. compute new metrics using values in the dict 43 | The hook takes the dict as argument, and returns a (k, v) tuple 44 | e.g. for computing IoU 45 | """ 46 | if type(hook) == list: 47 | self.hooks.extend(hook) 48 | else: 49 | self.hooks.append(hook) 50 | 51 | def reset_except_hooks(self): 52 | self.values = {} 53 | self.counts = {} 54 | 55 | # Average and output the metrics 56 | def finalize(self, prefix, it, f=None): 57 | 58 | for hook in self.hooks: 59 | k, v = hook(self.values) 60 | self.add_tensor(k, v) 61 | 62 | for k, v in self.values.items(): 63 | 64 | if k[:4] == 'hide': 65 | continue 66 | 67 | avg = v / self.counts[k] 68 | 69 | if self.distributed: 70 | # Inplace operation 71 | avg = torch.tensor(avg).cuda() 72 | torch.distributed.reduce(avg, dst=0) 73 | 74 | if self.local_rank == 0: 75 | avg = (avg/self.world_size).cpu().item() 76 | self.logger.log_metrics(prefix, k, avg, it, f) 77 | else: 78 | # Simple does it 79 | self.logger.log_metrics(prefix, k, avg, it, f) 80 | 81 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import warnings 7 | import git 8 | 9 | import torchvision.transforms as transforms 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | 13 | def tensor_to_numpy(image): 14 | image_np = (image.numpy() * 255).astype('uint8') 15 | return image_np 16 | 17 | def detach_to_cpu(x): 18 | return x.detach().cpu() 19 | 20 | def fix_width_trunc(x): 21 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 22 | 23 | class TensorboardLogger: 24 | def __init__(self, short_id, id): 25 | self.short_id = short_id 26 | if self.short_id == 'NULL': 27 | self.short_id = 'DEBUG' 28 | 29 | if id is None: 30 | self.no_log = True 31 | warnings.warn('Logging has been disbaled.') 32 | else: 33 | self.no_log = False 34 | 35 | self.inv_im_trans = transforms.Normalize( 36 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 37 | std=[1/0.229, 1/0.224, 1/0.225]) 38 | 39 | 
self.inv_seg_trans = transforms.Normalize( 40 | mean=[-0.5/0.5], 41 | std=[1/0.5]) 42 | 43 | log_path = os.path.join('.', 'log', '%s' % id) 44 | self.logger = SummaryWriter(log_path) 45 | 46 | repo = git.Repo(".") 47 | self.log_string('git', str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)) 48 | 49 | def log_scalar(self, tag, x, step): 50 | if self.no_log: 51 | warnings.warn('Logging has been disabled.') 52 | return 53 | self.logger.add_scalar(tag, x, step) 54 | 55 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None): 56 | tag = l1_tag + '/' + l2_tag 57 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) 58 | print(text) 59 | if f is not None: 60 | f.write(text + '\n') 61 | f.flush() 62 | self.log_scalar(tag, val, step) 63 | 64 | def log_im(self, tag, x, step): 65 | if self.no_log: 66 | warnings.warn('Logging has been disabled.') 67 | return 68 | x = detach_to_cpu(x) 69 | x = self.inv_im_trans(x) 70 | x = tensor_to_numpy(x) 71 | self.logger.add_image(tag, x, step) 72 | 73 | def log_cv2(self, tag, x, step): 74 | if self.no_log: 75 | warnings.warn('Logging has been disabled.') 76 | return 77 | x = x.transpose((2, 0, 1)) 78 | self.logger.add_image(tag, x, step) 79 | 80 | def log_seg(self, tag, x, step): 81 | if self.no_log: 82 | warnings.warn('Logging has been disabled.') 83 | return 84 | x = detach_to_cpu(x) 85 | x = self.inv_seg_trans(x) 86 | x = tensor_to_numpy(x) 87 | self.logger.add_image(tag, x, step) 88 | 89 | def log_gray(self, tag, x, step): 90 | if self.no_log: 91 | warnings.warn('Logging has been disabled.') 92 | return 93 | x = detach_to_cpu(x) 94 | x = tensor_to_numpy(x) 95 | self.logger.add_image(tag, x, step) 96 | 97 | def log_string(self, tag, x): 98 | print(tag, x) 99 | if self.no_log: 100 | warnings.warn('Logging has been disabled.') 101 | return 102 | self.logger.add_text(tag, x) 103 | -------------------------------------------------------------------------------- /util/palette.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_color_map(N=256): 4 | def bitget(byteval, idx): 5 | return ((byteval & (1 << idx)) != 0) 6 | 7 | cmap = np.zeros((N, 3), dtype=np.uint8) 8 | for i in range(N): 9 | r = g = b = 0 10 | c = i 11 | for j in range(8): 12 | r = r | (bitget(c, 0) << 7-j) 13 | g = g | (bitget(c, 1) << 7-j) 14 | b = b | (bitget(c, 2) << 7-j) 15 | c = c >> 3 16 | 17 | cmap[i] = np.array([r, g, b]) 18 | 19 | return cmap 20 | 21 | color_map = get_color_map() 22 | def pal_color_map(): 23 | return color_map -------------------------------------------------------------------------------- /util/tensor_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def compute_tensor_iu(seg, gt): 6 | intersection = (seg & gt).float().sum() 7 | union = (seg | gt).float().sum() 8 | 9 | return intersection, union 10 | 11 | def compute_np_iu(seg, gt): 12 | intersection = (seg & gt).astype(np.float32).sum() 13 | union = (seg | gt).astype(np.float32).sum() 14 | 15 | return intersection, union 16 | 17 | def compute_tensor_iou(seg, gt): 18 | intersection, union = compute_tensor_iu(seg, gt) 19 | iou = (intersection + 1e-6) / (union + 1e-6) 20 | 21 | return iou 22 | 23 | def compute_np_iou(seg, gt): 24 | intersection, union = compute_np_iu(seg, gt) 25 | iou = (intersection + 1e-6) / (union + 1e-6) 26 | 27 | return iou 28 | 29 
| def compute_multi_class_iou(seg, gt): 30 | # seg -> k*h*w 31 | # gt -> k*1*h*w 32 | num_classes = gt.shape[0] 33 | pred_idx = torch.argmax(seg, dim=0) 34 | iou_sum = 0 35 | for ki in range(num_classes): 36 | # seg includes BG class 37 | iou_sum += compute_tensor_iou(pred_idx==(ki+1), gt[ki,0]>0.5) 38 | 39 | return (iou_sum+1e-6)/(num_classes+1e-6) 40 | 41 | def compute_multi_class_iou_idx(seg, gt): 42 | # seg -> h*w 43 | # gt -> k*h*w 44 | num_classes = gt.shape[0] 45 | iou_sum = 0 46 | for ki in range(num_classes): 47 | # seg includes BG class 48 | iou_sum += compute_np_iou(seg==(ki+1), gt[ki]>0.5) 49 | 50 | return (iou_sum+1e-6)/(num_classes+1e-6) 51 | 52 | def compute_multi_class_iou_both_idx(seg, gt): 53 | # seg -> h*w 54 | # gt -> h*w 55 | num_classes = gt.max() 56 | iou_sum = 0 57 | for ki in range(1, num_classes+1): 58 | iou_sum += compute_np_iou(seg==ki, gt==ki) 59 | return (iou_sum+1e-6)/(num_classes+1e-6) 60 | 61 | # STM 62 | def pad_divide_by(in_img, d, in_size=None): 63 | if in_size is None: 64 | h, w = in_img.shape[-2:] 65 | else: 66 | h, w = in_size 67 | 68 | if h % d > 0: 69 | new_h = h + d - h % d 70 | else: 71 | new_h = h 72 | if w % d > 0: 73 | new_w = w + d - w % d 74 | else: 75 | new_w = w 76 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 77 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 78 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 79 | out = F.pad(in_img, pad_array) 80 | return out, pad_array 81 | 82 | def unpad(img, pad): 83 | if pad[2]+pad[3] > 0: 84 | img = img[:,:,pad[2]:-pad[3],:] 85 | if pad[0]+pad[1] > 0: 86 | img = img[:,:,:,pad[0]:-pad[1]] 87 | return img 88 | 89 | def unpad_3dim(img, pad): 90 | if pad[2]+pad[3] > 0: 91 | img = img[:,pad[2]:-pad[3],:] 92 | if pad[0]+pad[1] > 0: 93 | img = img[:,:,pad[0]:-pad[1]] 94 | return img --------------------------------------------------------------------------------
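Illustrative footnote (not a file in the repository): the padding helpers at the end of util/tensor_util.py are typically paired as sketched below, padding a frame so its height and width are multiples of the network stride and cropping the result back afterwards; the shapes used here are made up for the example.

import torch
from util.tensor_util import pad_divide_by, unpad

frame = torch.rand(1, 3, 480, 854)        # N, C, H, W
padded, pad = pad_divide_by(frame, 16)    # pad H/W up to multiples of 16
assert padded.shape[-2] % 16 == 0 and padded.shape[-1] % 16 == 0
restored = unpad(padded, pad)             # crop the padding back off
assert restored.shape == frame.shape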