├── .DS_Store
├── LICENSE
├── README.md
├── dataset
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── davis_test_dataset.cpython-36.pyc
│ ├── range_transform.cpython-36.pyc
│ ├── reseed.cpython-36.pyc
│ ├── static_dataset.cpython-36.pyc
│ ├── tps.cpython-36.pyc
│ ├── util.cpython-36.pyc
│ ├── vos_dataset.cpython-36.pyc
│ └── yv_test_dataset.cpython-36.pyc
├── davis_test_dataset.py
├── generic_test_dataset.py
├── range_transform.py
├── reseed.py
├── static_dataset.py
├── tps.py
├── util.py
├── vos_dataset.py
└── yv_test_dataset.py
├── download_bl30k.py
├── download_datasets.py
├── environment.yaml
├── inference_core.py
├── model
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── aggregate.cpython-36.pyc
│ ├── cbam.cpython-36.pyc
│ ├── eval_network.cpython-36.pyc
│ ├── losses.cpython-36.pyc
│ ├── mod_resnet.cpython-36.pyc
│ ├── model.cpython-36.pyc
│ ├── model.cpython-36.pyc.140569625763192
│ ├── model.cpython-36.pyc.140666179276152
│ ├── model_temp.cpython-36.pyc
│ ├── model_temp_temp.cpython-36.pyc
│ ├── models_vit.cpython-36.pyc
│ ├── models_vit.cpython-36.pyc.139742317369424
│ ├── models_vit.cpython-36.pyc.139824390907984
│ ├── models_vit.cpython-36.pyc.139838520842320
│ ├── models_vit.cpython-36.pyc.140109349144656
│ ├── models_vit.cpython-36.pyc.140142142203984
│ ├── models_vit.cpython-36.pyc.140277441316944
│ ├── modules.cpython-36.pyc
│ └── network.cpython-36.pyc
├── aggregate.py
├── cbam.py
├── eval_network.py
├── losses.py
├── mod_resnet.py
├── model.py
├── model_temp.py
├── model_temp_temp.py
├── models_vit.py
├── modules.py
└── network.py
├── pretrained_models
└── .DS_Store
├── scripts
├── 00180.jpg
├── 00180_original.jpg
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __init__.cpython-38.pyc
│ ├── resize_youtube.cpython-36.pyc
│ └── resize_youtube.cpython-38.pyc
├── resize_length.py
└── resize_youtube.py
├── submit_eval_davis_ours_all.py
├── train_simvos.py
└── util
├── __init__.py
├── __pycache__
├── __init__.cpython-36.pyc
├── hyper_para.cpython-36.pyc
├── image_saver.cpython-36.pyc
├── load_subset.cpython-36.pyc
├── log_integrator.cpython-36.pyc
├── logger.cpython-36.pyc
└── tensor_util.cpython-36.pyc
├── davis_subset.txt
├── hyper_para.py
├── image_saver.py
├── load_subset.py
├── log_integrator.py
├── logger.py
├── tensor_util.py
└── yv_subset.txt

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SimVOS
 2 | 
 3 | The code for the ICCV 2023 paper 'Scalable Video Object Segmentation with Simplified Framework'.
 4 | 
 5 | ## :sunny: Highlights
 6 | 
 7 | #### * Our goal: to provide a simple and scalable VOS baseline for exploring the effect of self-supervised pre-training.
 8 | 
 9 | #### * Our SimVOS relies only on video sequences for one-stage training and achieves favorable performance on the DAVIS and YouTube-VOS datasets.
10 | 
11 | #### * Our project is built upon the [_STCN_](https://github.com/hkchengrex/STCN) library. Thanks for their contribution.
12 | 
13 | ## Install the environment
14 | We use Anaconda to create the Python environment, mainly following the installation in [_STCN_](https://github.com/hkchengrex/STCN). The environment we use for result reproduction is Python 3.6 with CUDA 11.0 and cuDNN 8.1.
15 | The required installation packages can be found in `environment.yaml`.
16 | 
17 | 
18 | ## Data preparation
19 | We follow the same data preparation steps used in [_STCN_](https://github.com/hkchengrex/STCN). Download both the DAVIS and YouTube-VOS 2019 datasets and organize them as follows:
20 | ```bash
21 | ├── DAVIS
22 | │ ├── 2016
23 | │ │ ├── Annotations
24 | │ │ └── ...
25 | │ └── 2017
26 | │ ├── test-dev
27 | │ │ ├── Annotations
28 | │ │ └── ...
29 | │ └── trainval
30 | │ ├── Annotations
31 | │ └── ...
32 | ├── YouTube
33 | │ ├── all_frames
34 | │ │ └── valid_all_frames
35 | │ ├── train
36 | │ ├── train_480p
37 | │ └── valid
38 | ```
39 | 
40 | ### Pre-trained model download
41 | Please download the pre-trained weights (e.g., MAE: ViT-Base or ViT-Large) and put them in the `./pretrained_models` folder.
42 | 
43 | ### Training command
44 | To train a SimVOS model (ViT-Base with MAE Init.) w/ token refinement (e.g., the default setting with 384/384 foreground/background prototypes and `layer_index=4` for prototype generation):
45 | ```
46 | python -m torch.distributed.launch --master_port 9842 --nproc_per_node=8 train_simvos.py --id retrain_s03 --stage 3
47 | ```
48 | If you want to train a SimVOS-B model w/o token refinement:
49 | ```
50 | python -m torch.distributed.launch --master_port 9842 --nproc_per_node=8 train_simvos.py --id retrain_s03 --stage 3 --layer_index 0 --use_token_learner False
51 | ```
52 | or a SimVOS-L model:
53 | ```
54 | python -m torch.distributed.launch --master_port 9842 --nproc_per_node=8 train_simvos.py --id retrain_s03 --stage 3 --layer_index 0 --use_token_learner False --backbone_type vit_large
55 | ```
56 | 
57 | ### Evaluation command
58 | Download the SimVOS models [SimVOS-BS(384/384-layer_index=4-vitbase)](https://drive.google.com/file/d/1v1FdDc5oFFUOBZ_Oc2yhPxDZpbHpTYsY/view?usp=drive_link), [SimVOS-B(vitbase)](https://drive.google.com/file/d/1uSobYg2JQzpR-Lwb81YsUoyjEr1jaTwJ/view?usp=drive_link), and [SimVOS-L(vitlarge)](https://drive.google.com/file/d/1bh2FyaoRlTdupvCHRiJc9O9vnRhSkcE8/view?usp=drive_link). Put the models in the `test_checkpoints` folder. After that, run the evaluation w/ the following commands. All evaluations are done at 480p resolution.
59 | ```
60 | #SimVOS-BS
61 | python submit_eval_davis_ours_all.py --model_path ./test_checkpoints --davis_path ./Data/DAVIS/2017 --output ./results --split val --layer_index 4 --use_token_learner --backbone_type vit_base
62 | #SimVOS-B
63 | python submit_eval_davis_ours_all.py --model_path ./test_checkpoints --davis_path ./Data/DAVIS/2017 --output ./results --split val --layer_index 0 --backbone_type vit_base
64 | #SimVOS-L
65 | python submit_eval_davis_ours_all.py --model_path ./test_checkpoints --davis_path ./Data/DAVIS/2017 --output ./results --split val --layer_index 0 --backbone_type vit_large
66 | ```
67 | 
68 | After running the above evaluation, the qualitative results are saved in the root project directory. You can use the offline evaluation toolkit (https://github.com/davisvideochallenge/davis2017-evaluation) to get the validation performance on DAVIS 2016/2017. For `test-dev` on DAVIS 2017, use the online evaluation server instead.
69 | 
70 | ------
71 | 
72 | If you find our work useful in your research, please consider citing:
73 | 
74 | ```
75 | @inproceedings{wu2023,
76 | title={Scalable Video Object Segmentation with Simplified Framework},
77 | author={Qiangqiang Wu and Tianyu Yang and Wei Wu and Antoni B.
Chan}, 78 | booktitle={ICCV}, 79 | year={2023} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/davis_test_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/davis_test_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/range_transform.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/range_transform.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/reseed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/reseed.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/static_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/static_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/tps.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/tps.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/vos_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/vos_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/yv_test_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/dataset/__pycache__/yv_test_dataset.cpython-36.pyc 
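The evaluation commands in the README above assume the DAVIS layout shown in the Data preparation section, and the file below (`dataset/davis_test_dataset.py`) is the loader that consumes it. The following is a minimal usage sketch, not repository code; the `./Data/DAVIS/2017/trainval` root is an assumption chosen to match the `--davis_path` used in the evaluation commands.

```python
# Hypothetical sketch: load the DAVIS 2017 validation split with DAVISTestDataset
# (defined in dataset/davis_test_dataset.py below) and inspect one video clip.
from dataset.davis_test_dataset import DAVISTestDataset

davis_root = './Data/DAVIS/2017/trainval'   # assumed path, per the README's Data preparation layout
dataset = DAVISTestDataset(davis_root, imset='2017/val.txt', resolution=480)

data = dataset[0]
rgb = data['rgb']     # (T, 3, H, W) frames, ImageNet-normalized
gt = data['gt']       # per-object one-hot masks for the same T frames
info = data['info']   # video name, frame names, object labels, 480p size
print(info['name'], rgb.shape, gt.shape, info['labels'])
```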
-------------------------------------------------------------------------------- /dataset/davis_test_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/seoungwugoh/STM/blob/master/dataset.py 3 | """ 4 | 5 | import os 6 | from os import path 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import torch 11 | from torchvision import transforms 12 | from torchvision.transforms import InterpolationMode 13 | from torch.utils.data.dataset import Dataset 14 | from dataset.range_transform import im_normalization 15 | from dataset.util import all_to_onehot 16 | 17 | 18 | class DAVISTestDataset(Dataset): 19 | def __init__(self, root, imset='2017/val.txt', resolution=480, single_object=False, target_name=None): 20 | self.root = root 21 | if resolution == 480: 22 | res_tag = '480p' 23 | else: 24 | res_tag = 'Full-Resolution' 25 | self.mask_dir = path.join(root, 'Annotations', res_tag) 26 | self.mask480_dir = path.join(root, 'Annotations', '480p') 27 | self.image_dir = path.join(root, 'JPEGImages', res_tag) 28 | self.resolution = resolution 29 | _imset_dir = path.join(root, 'ImageSets') 30 | _imset_f = path.join(_imset_dir, imset) 31 | 32 | self.videos = [] 33 | self.num_frames = {} 34 | self.num_objects = {} 35 | self.shape = {} 36 | self.size_480p = {} 37 | with open(path.join(_imset_f), "r") as lines: 38 | for line in lines: 39 | _video = line.rstrip('\n') 40 | if target_name is not None and target_name != _video: 41 | continue 42 | self.videos.append(_video) 43 | self.num_frames[_video] = len(os.listdir(path.join(self.image_dir, _video))) 44 | _mask = np.array(Image.open(path.join(self.mask_dir, _video, '00000.png')).convert("P")) 45 | self.num_objects[_video] = np.max(_mask) 46 | self.shape[_video] = np.shape(_mask) 47 | _mask480 = np.array(Image.open(path.join(self.mask480_dir, _video, '00000.png')).convert("P")) 48 | self.size_480p[_video] = np.shape(_mask480) 49 | 50 | self.single_object = single_object 51 | 52 | if resolution == 480: 53 | self.im_transform = transforms.Compose([ 54 | transforms.ToTensor(), 55 | im_normalization, 56 | ]) 57 | else: 58 | self.im_transform = transforms.Compose([ 59 | transforms.ToTensor(), 60 | im_normalization, 61 | transforms.Resize(resolution, interpolation=InterpolationMode.BICUBIC), 62 | ]) 63 | self.mask_transform = transforms.Compose([ 64 | transforms.Resize(resolution, interpolation=InterpolationMode.NEAREST), 65 | ]) 66 | 67 | def __len__(self): 68 | return len(self.videos) 69 | 70 | def __getitem__(self, index): 71 | video = self.videos[index] 72 | info = {} 73 | info['name'] = video 74 | info['frames'] = [] 75 | info['num_frames'] = self.num_frames[video] 76 | info['size_480p'] = self.size_480p[video] 77 | 78 | images = [] 79 | masks = [] 80 | for f in range(self.num_frames[video]): 81 | img_file = path.join(self.image_dir, video, '{:05d}.jpg'.format(f)) 82 | images.append(self.im_transform(Image.open(img_file).convert('RGB'))) 83 | info['frames'].append('{:05d}.jpg'.format(f)) 84 | 85 | mask_file = path.join(self.mask_dir, video, '{:05d}.png'.format(f)) 86 | if path.exists(mask_file): 87 | masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)) 88 | else: 89 | # Test-set maybe? 
90 | masks.append(np.zeros_like(masks[0])) 91 | 92 | images = torch.stack(images, 0) 93 | masks = np.stack(masks, 0) 94 | 95 | if self.single_object: 96 | labels = [1] 97 | masks = (masks > 0.5).astype(np.uint8) 98 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 99 | else: 100 | labels = np.unique(masks[0]) 101 | labels = labels[labels!=0] 102 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 103 | 104 | if self.resolution != 480: 105 | masks = self.mask_transform(masks) 106 | masks = masks.unsqueeze(2) 107 | 108 | info['labels'] = labels 109 | 110 | data = { 111 | 'rgb': images, 112 | 'gt': masks, 113 | 'info': info, 114 | } 115 | 116 | return data 117 | 118 | -------------------------------------------------------------------------------- /dataset/generic_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision import transforms 7 | from torchvision.transforms import InterpolationMode 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from dataset.range_transform import im_normalization 12 | from dataset.util import all_to_onehot 13 | 14 | 15 | class GenericTestDataset(Dataset): 16 | def __init__(self, data_root, res=480): 17 | self.image_dir = path.join(data_root, 'JPEGImages') 18 | self.mask_dir = path.join(data_root, 'Annotations') 19 | 20 | self.videos = [] 21 | self.shape = {} 22 | self.frames = {} 23 | 24 | vid_list = sorted(os.listdir(self.image_dir)) 25 | # Pre-reading 26 | for vid in vid_list: 27 | frames = sorted(os.listdir(os.path.join(self.image_dir, vid))) 28 | self.frames[vid] = frames 29 | 30 | self.videos.append(vid) 31 | first_mask = os.listdir(path.join(self.mask_dir, vid))[0] 32 | _mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert("P")) 33 | self.shape[vid] = np.shape(_mask) 34 | 35 | if res != -1: 36 | self.im_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | im_normalization, 39 | transforms.Resize(res, interpolation=InterpolationMode.BICUBIC), 40 | ]) 41 | 42 | self.mask_transform = transforms.Compose([ 43 | transforms.Resize(res, interpolation=InterpolationMode.NEAREST), 44 | ]) 45 | else: 46 | self.im_transform = transforms.Compose([ 47 | transforms.ToTensor(), 48 | im_normalization, 49 | ]) 50 | 51 | self.mask_transform = transforms.Compose([ 52 | ]) 53 | 54 | def __getitem__(self, idx): 55 | video = self.videos[idx] 56 | info = {} 57 | info['name'] = video 58 | info['frames'] = self.frames[video] 59 | info['size'] = self.shape[video] # Real sizes 60 | info['gt_obj'] = {} # Frames with labelled objects 61 | 62 | vid_im_path = path.join(self.image_dir, video) 63 | vid_gt_path = path.join(self.mask_dir, video) 64 | 65 | frames = self.frames[video] 66 | 67 | images = [] 68 | masks = [] 69 | for i, f in enumerate(frames): 70 | img = Image.open(path.join(vid_im_path, f)).convert('RGB') 71 | images.append(self.im_transform(img)) 72 | 73 | mask_file = path.join(vid_gt_path, f.replace('.jpg','.png')) 74 | if path.exists(mask_file): 75 | mask = Image.open(mask_file).convert('P') 76 | palette = mask.getpalette() 77 | masks.append(np.array(mask, dtype=np.uint8)) 78 | this_labels = np.unique(masks[-1]) 79 | this_labels = this_labels[this_labels!=0] 80 | info['gt_obj'][i] = this_labels 81 | else: 82 | # Mask not exists -> nothing in it 83 | masks.append(np.zeros(self.shape[video])) 84 | 85 | images = torch.stack(images, 0) 86 
| masks = np.stack(masks, 0) 87 | 88 | # Construct the forward and backward mapping table for labels 89 | # this is because YouTubeVOS's labels are sometimes not continuous 90 | # while we want continuous ones (for one-hot) 91 | # so we need to maintain a backward mapping table 92 | labels = np.unique(masks).astype(np.uint8) 93 | labels = labels[labels!=0] 94 | info['label_convert'] = {} 95 | info['label_backward'] = {} 96 | idx = 1 97 | for l in labels: 98 | info['label_convert'][l] = idx 99 | info['label_backward'][idx] = l 100 | idx += 1 101 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 102 | 103 | # Resize to 480p 104 | masks = self.mask_transform(masks) 105 | masks = masks.unsqueeze(2) 106 | 107 | info['labels'] = labels 108 | 109 | data = { 110 | 'rgb': images, 111 | 'gt': masks, 112 | 'info': info, 113 | 'palette': np.array(palette), 114 | } 115 | 116 | return data 117 | 118 | def __len__(self): 119 | return len(self.videos) -------------------------------------------------------------------------------- /dataset/range_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | im_mean = (124, 116, 104) 4 | 5 | im_normalization = transforms.Normalize( 6 | mean=[0.485, 0.456, 0.406], 7 | std=[0.229, 0.224, 0.225] 8 | ) 9 | 10 | inv_im_trans = transforms.Normalize( 11 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 12 | std=[1/0.229, 1/0.224, 1/0.225]) 13 | -------------------------------------------------------------------------------- /dataset/reseed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def reseed(seed): 5 | random.seed(seed) 6 | torch.manual_seed(seed) -------------------------------------------------------------------------------- /dataset/static_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision import transforms 7 | from torchvision.transforms import InterpolationMode 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from dataset.range_transform import im_normalization, im_mean 12 | from dataset.tps import random_tps_warp 13 | from dataset.reseed import reseed 14 | 15 | 16 | class StaticTransformDataset(Dataset): 17 | """ 18 | Generate pseudo VOS data by applying random transforms on static images. 19 | Single-object only. 
20 | 21 | Method 0 - FSS style (class/1.jpg class/1.png) 22 | Method 1 - Others style (XXX.jpg XXX.png) 23 | """ 24 | def __init__(self, root, method=0): 25 | self.root = root 26 | self.method = method 27 | 28 | if method == 0: 29 | # Get images 30 | self.im_list = [] 31 | classes = os.listdir(self.root) 32 | for c in classes: 33 | imgs = os.listdir(path.join(root, c)) 34 | jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()] 35 | 36 | joint_list = [path.join(root, c, im) for im in jpg_list] 37 | self.im_list.extend(joint_list) 38 | 39 | elif method == 1: 40 | self.im_list = [path.join(self.root, im) for im in os.listdir(self.root) if '.jpg' in im] 41 | 42 | print('%d images found in %s' % (len(self.im_list), root)) 43 | 44 | # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames 45 | self.pair_im_lone_transform = transforms.Compose([ 46 | transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic 47 | ]) 48 | 49 | self.pair_im_dual_transform = transforms.Compose([ 50 | transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean), 51 | transforms.Resize(384, InterpolationMode.BICUBIC), 52 | transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean), 53 | ]) 54 | 55 | self.pair_gt_dual_transform = transforms.Compose([ 56 | transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0), 57 | transforms.Resize(384, InterpolationMode.NEAREST), 58 | transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0), 59 | ]) 60 | 61 | # These transform are the same for all pairs in the sampled sequence 62 | self.all_im_lone_transform = transforms.Compose([ 63 | transforms.ColorJitter(0.1, 0.05, 0.05, 0.05), 64 | transforms.RandomGrayscale(0.05), 65 | ]) 66 | 67 | self.all_im_dual_transform = transforms.Compose([ 68 | transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean), 69 | transforms.RandomHorizontalFlip(), 70 | ]) 71 | 72 | self.all_gt_dual_transform = transforms.Compose([ 73 | transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0), 74 | transforms.RandomHorizontalFlip(), 75 | ]) 76 | 77 | # Final transform without randomness 78 | self.final_im_transform = transforms.Compose([ 79 | transforms.ToTensor(), 80 | im_normalization, 81 | ]) 82 | 83 | self.final_gt_transform = transforms.Compose([ 84 | transforms.ToTensor(), 85 | ]) 86 | 87 | def __getitem__(self, idx): 88 | im = Image.open(self.im_list[idx]).convert('RGB') 89 | 90 | if self.method == 0: 91 | gt = Image.open(self.im_list[idx][:-3]+'png').convert('L') 92 | else: 93 | gt = Image.open(self.im_list[idx].replace('.jpg','.png')).convert('L') 94 | 95 | sequence_seed = np.random.randint(2147483647) 96 | 97 | images = [] 98 | masks = [] 99 | for _ in range(3): 100 | reseed(sequence_seed) 101 | this_im = self.all_im_dual_transform(im) 102 | this_im = self.all_im_lone_transform(this_im) 103 | reseed(sequence_seed) 104 | this_gt = self.all_gt_dual_transform(gt) 105 | 106 | pairwise_seed = np.random.randint(2147483647) 107 | reseed(pairwise_seed) 108 | this_im = self.pair_im_dual_transform(this_im) 109 | this_im = self.pair_im_lone_transform(this_im) 110 | reseed(pairwise_seed) 111 | this_gt = self.pair_gt_dual_transform(this_gt) 112 | 113 | # Use TPS only some of the times 114 | # Not because TPS is bad -- just that it is too slow and I need to speed up data loading 115 | if np.random.rand() < 0.33: 116 | this_im, this_gt = 
random_tps_warp(this_im, this_gt, scale=0.02) 117 | 118 | this_im = self.final_im_transform(this_im) 119 | this_gt = self.final_gt_transform(this_gt) 120 | 121 | images.append(this_im) 122 | masks.append(this_gt) 123 | 124 | images = torch.stack(images, 0) 125 | masks = torch.stack(masks, 0) 126 | 127 | info = {} 128 | info['name'] = self.im_list[idx] 129 | 130 | cls_gt = np.zeros((3, 384, 384), dtype=np.int) 131 | cls_gt[masks[:,0] > 0.5] = 1 132 | 133 | data = { 134 | 'rgb': images, 135 | 'gt': masks, 136 | 'cls_gt': cls_gt, 137 | 'info': info 138 | } 139 | 140 | return data 141 | 142 | 143 | def __len__(self): 144 | return len(self.im_list) 145 | -------------------------------------------------------------------------------- /dataset/tps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import thinplate as tps 5 | 6 | cv2.setNumThreads(0) 7 | 8 | def pick_random_points(h, w, n_samples): 9 | y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False) 10 | x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False) 11 | return y_idx/h, x_idx/w 12 | 13 | 14 | def warp_dual_cv(img, mask, c_src, c_dst): 15 | dshape = img.shape 16 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 17 | grid = tps.tps_grid(theta, c_dst, dshape) 18 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 19 | return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST) 20 | 21 | 22 | def random_tps_warp(img, mask, scale, n_ctrl_pts=12): 23 | """ 24 | Apply a random TPS warp of the input image and mask 25 | Uses randomness from numpy 26 | """ 27 | img = np.asarray(img) 28 | mask = np.asarray(mask) 29 | 30 | h, w = mask.shape 31 | points = pick_random_points(h, w, n_ctrl_pts) 32 | c_src = np.stack(points, 1) 33 | c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape) 34 | warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst) 35 | 36 | return Image.fromarray(warp_im), Image.fromarray(warp_gt) 37 | 38 | -------------------------------------------------------------------------------- /dataset/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def all_to_onehot(masks, labels): 5 | if len(masks.shape) == 3: 6 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 7 | else: 8 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 9 | 10 | for k, l in enumerate(labels): 11 | Ms[k] = (masks == l).astype(np.uint8) 12 | 13 | return Ms 14 | -------------------------------------------------------------------------------- /dataset/yv_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision import transforms 7 | from torchvision.transforms import InterpolationMode 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from dataset.range_transform import im_normalization 12 | from dataset.util import all_to_onehot 13 | 14 | 15 | class YouTubeVOSTestDataset(Dataset): 16 | def __init__(self, data_root, split, res=480): 17 | self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages') 18 | self.mask_dir = path.join(data_root, split, 'Annotations') 19 | 20 | self.videos = [] 21 | self.shape = {} 22 | self.frames = {} 
23 | 24 | vid_list = sorted(os.listdir(self.image_dir)) 25 | # Pre-reading 26 | for vid in vid_list: 27 | frames = sorted(os.listdir(os.path.join(self.image_dir, vid))) 28 | self.frames[vid] = frames 29 | 30 | self.videos.append(vid) 31 | first_mask = os.listdir(path.join(self.mask_dir, vid))[0] 32 | _mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert("P")) 33 | self.shape[vid] = np.shape(_mask) 34 | 35 | if res != -1: 36 | self.im_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | im_normalization, 39 | transforms.Resize(res, interpolation=InterpolationMode.BICUBIC), 40 | ]) 41 | 42 | self.mask_transform = transforms.Compose([ 43 | transforms.Resize(res, interpolation=InterpolationMode.NEAREST), 44 | ]) 45 | else: 46 | self.im_transform = transforms.Compose([ 47 | transforms.ToTensor(), 48 | im_normalization, 49 | ]) 50 | 51 | self.mask_transform = transforms.Compose([ 52 | ]) 53 | 54 | def __getitem__(self, idx): 55 | video = self.videos[idx] 56 | info = {} 57 | info['name'] = video 58 | info['frames'] = self.frames[video] 59 | info['size'] = self.shape[video] # Real sizes 60 | info['gt_obj'] = {} # Frames with labelled objects 61 | 62 | vid_im_path = path.join(self.image_dir, video) 63 | vid_gt_path = path.join(self.mask_dir, video) 64 | 65 | frames = self.frames[video] 66 | 67 | images = [] 68 | masks = [] 69 | for i, f in enumerate(frames): 70 | img = Image.open(path.join(vid_im_path, f)).convert('RGB') 71 | images.append(self.im_transform(img)) 72 | 73 | mask_file = path.join(vid_gt_path, f.replace('.jpg','.png')) 74 | if path.exists(mask_file): 75 | masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)) 76 | this_labels = np.unique(masks[-1]) 77 | this_labels = this_labels[this_labels!=0] 78 | info['gt_obj'][i] = this_labels 79 | else: 80 | # Mask not exists -> nothing in it 81 | masks.append(np.zeros(self.shape[video])) 82 | 83 | images = torch.stack(images, 0) 84 | masks = np.stack(masks, 0) 85 | 86 | # Construct the forward and backward mapping table for labels 87 | # this is because YouTubeVOS's labels are sometimes not continuous 88 | # while we want continuous ones (for one-hot) 89 | # so we need to maintain a backward mapping table 90 | labels = np.unique(masks).astype(np.uint8) 91 | labels = labels[labels!=0] 92 | info['label_convert'] = {} 93 | info['label_backward'] = {} 94 | idx = 1 95 | for l in labels: 96 | info['label_convert'][l] = idx 97 | info['label_backward'][idx] = l 98 | idx += 1 99 | masks = torch.from_numpy(all_to_onehot(masks, labels)).float() 100 | 101 | # Resize to 480p 102 | masks = self.mask_transform(masks) 103 | masks = masks.unsqueeze(2) 104 | 105 | info['labels'] = labels 106 | 107 | data = { 108 | 'rgb': images, 109 | 'gt': masks, 110 | 'info': info, 111 | } 112 | 113 | return data 114 | 115 | def __len__(self): 116 | return len(self.videos) -------------------------------------------------------------------------------- /download_bl30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import tarfile 4 | 5 | 6 | LICENSE = """ 7 | This dataset is a derivative of ShapeNet. 8 | Please read and respect their licenses and terms before use. 9 | Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag. 10 | Do not use this dataset for commercial purposes. 11 | You should cite both ShapeNet and our paper if you use this dataset. 
12 | """ 13 | 14 | print(LICENSE) 15 | print('Datasets will be downloaded and extracted to ../BL30K') 16 | print('The script will download and extract the segment one by one') 17 | print('You are going to need ~1TB of free disk space') 18 | reply = input('[y] to confirm, others to exit: ') 19 | if reply != 'y': 20 | exit() 21 | 22 | links = [ 23 | 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc', 24 | 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ', 25 | 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q', 26 | 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-', 27 | 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE', 28 | 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r', 29 | ] 30 | 31 | names = [ 32 | 'BL30K_a.tar', 33 | 'BL30K_b.tar', 34 | 'BL30K_c.tar', 35 | 'BL30K_d.tar', 36 | 'BL30K_e.tar', 37 | 'BL30K_f.tar', 38 | ] 39 | 40 | for i, link in enumerate(links): 41 | print('Downloading segment %d/%d ...' % (i, len(links))) 42 | gdown.download(link, output='../%s' % names[i], quiet=False) 43 | print('Extracting...') 44 | with tarfile.open('../%s' % names[i], 'r') as tar_file: 45 | tar_file.extractall('../%s' % names[i]) 46 | print('Cleaning up...') 47 | os.remove('../%s' % names[i]) 48 | 49 | 50 | print('Done.') -------------------------------------------------------------------------------- /download_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import zipfile 4 | from scripts import resize_youtube 5 | 6 | 7 | LICENSE = """ 8 | These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets. 9 | Please read and respect their licenses and terms before use. 10 | You should cite the original papers if you use any of the datasets. 
11 | 12 | For BL30K, see download_bl30k.py 13 | 14 | Links: 15 | DUTS: http://saliencydetection.net/duts 16 | HRSOD: https://github.com/yi94code/HRSOD 17 | FSS: https://github.com/HKUSTCV/FSS-1000 18 | ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html 19 | BIG: https://github.com/hkchengrex/CascadePSP 20 | 21 | YouTubeVOS: https://youtube-vos.org 22 | DAVIS: https://davischallenge.org/ 23 | BL30K: https://github.com/hkchengrex/MiVOS 24 | """ 25 | 26 | print(LICENSE) 27 | print('Datasets will be downloaded and extracted to /home/user/HDD_Data/YouTube, /home/user/HDD_Data/YouTube2018, /home/user/HDD_Data/static, /home/user/HDD_Data/DAVIS') 28 | reply = input('[y] to confirm, others to exit: ') 29 | if reply != 'y': 30 | exit() 31 | 32 | 33 | # Static data 34 | os.makedirs('/home/user/HDD_Data/static', exist_ok=True) 35 | print('Downloading static datasets...') 36 | gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', output='/home/user/HDD_Data/static_data.zip', quiet=False) 37 | print('Extracting static datasets...') 38 | with zipfile.ZipFile('/home/user/HDD_Data/static/static_data.zip', 'r') as zip_file: 39 | zip_file.extractall('/home/user/HDD_Data/static/') 40 | print('Cleaning up static datasets...') 41 | os.remove('/home/user/HDD_Data/static/static_data.zip') 42 | 43 | 44 | # # DAVIS 45 | # # Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing 46 | # os.makedirs('../DAVIS/2017', exist_ok=True) 47 | 48 | # print('Downloading DAVIS 2016...') 49 | # gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', output='../DAVIS/DAVIS-data.zip', quiet=False) 50 | 51 | # print('Downloading DAVIS 2017 trainval...') 52 | # gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', quiet=False) 53 | 54 | # print('Downloading DAVIS 2017 testdev...') 55 | # gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', quiet=False) 56 | 57 | # print('Downloading DAVIS 2017 scribbles...') 58 | # gdown.download('https://drive.google.com/uc?id=1JzIQSu36h7dVM8q0VoE4oZJwBXvrZlkl', output='../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', quiet=False) 59 | 60 | # print('Extracting DAVIS datasets...') 61 | # with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file: 62 | # zip_file.extractall('../DAVIS/') 63 | # os.rename('../DAVIS/DAVIS', '../DAVIS/2016') 64 | 65 | # with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file: 66 | # zip_file.extractall('../DAVIS/2017/') 67 | # with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', 'r') as zip_file: 68 | # zip_file.extractall('../DAVIS/2017/') 69 | # os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval') 70 | 71 | # with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file: 72 | # zip_file.extractall('../DAVIS/2017/') 73 | # os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev') 74 | 75 | # print('Cleaning up DAVIS datasets...') 76 | # os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip') 77 | # os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip') 78 | # os.remove('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip') 79 | # os.remove('../DAVIS/DAVIS-data.zip') 80 | 81 | 82 | # YouTubeVOS 83 | # os.makedirs('/home/user/HDD_Data/YouTube19', exist_ok=True) 84 | # 
os.makedirs('/home/user/HDD_Data/YouTube/all_frames', exist_ok=True) 85 | 86 | # print('Downloading YouTubeVOS train...') 87 | # gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', output='/home/user/HDD_Data/YouTube/train.zip', quiet=False) 88 | # print('Downloading YouTubeVOS val...') 89 | # gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', output='/home/user/HDD_Data/YouTube/valid.zip', quiet=False) 90 | # print('Downloading YouTubeVOS all frames valid...') 91 | # gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', output='/home/user/HDD_Data/YouTube/all_frames/valid.zip', quiet=False) 92 | 93 | # print('Extracting YouTube datasets...') 94 | # with zipfile.ZipFile('/home/user/HDD_Data/YouTube19/train.zip', 'r') as zip_file: 95 | # zip_file.extractall('/home/user/HDD_Data/YouTube19/') 96 | # with zipfile.ZipFile('/home/user/HDD_Data/YouTube19/valid.zip', 'r') as zip_file: 97 | # zip_file.extractall('/home/user/HDD_Data/YouTube19/') 98 | # with zipfile.ZipFile('/home/user/HDD_Data/YouTube19/all_frames/valid_all_frames.zip', 'r') as zip_file: 99 | # zip_file.extractall('/home/user/HDD_Data/YouTube19/all_frames') 100 | 101 | # print('Cleaning up YouTubeVOS datasets...') 102 | # os.remove('/home/user/HDD_Data/YouTube/train.zip') 103 | # os.remove('/home/user/HDD_Data/YouTube/valid.zip') 104 | # os.remove('/home/user/HDD_Data/YouTube/all_frames/valid.zip') 105 | 106 | # print('Resizing YouTubeVOS to 480p...') 107 | # resize_youtube.resize_all('/home/user/HDD_Data/YouTube19/train', '/home/user/HDD_Data/YouTube19/train_480p') 108 | 109 | # # YouTubeVOS 2018 110 | # os.makedirs('/home/user/HDD_Data/YouTube2018', exist_ok=True) 111 | # os.makedirs('/home/user/HDD_Data/YouTube2018/all_frames', exist_ok=True) 112 | 113 | # print('Downloading YouTubeVOS2018 val...') 114 | # gdown.download('https://drive.google.com/uc?id=1-QrceIl5sUNTKz7Iq0UsWC6NLZq7girr', output='/home/user/HDD_Data/YouTube2018/valid.zip', quiet=False) 115 | # print('Downloading YouTubeVOS2018 all frames valid...') 116 | # gdown.download('https://drive.google.com/uc?id=1yVoHM6zgdcL348cFpolFcEl4IC1gorbV', output='/home/user/HDD_Data/YouTube2018/all_frames/valid.zip', quiet=False) 117 | 118 | # print('Extracting YouTube2018 datasets...') 119 | # with zipfile.ZipFile('/home/user/HDD_Data/YouTube2018/valid.zip', 'r') as zip_file: 120 | # zip_file.extractall('/home/user/HDD_Data/YouTube2018/') 121 | # with zipfile.ZipFile('/home/user/HDD_Data/YouTube2018/all_frames/valid.zip', 'r') as zip_file: 122 | # zip_file.extractall('/home/user/HDD_Data/YouTube2018/all_frames') 123 | 124 | # print('Cleaning up YouTubeVOS2018 datasets...') 125 | # os.remove('/home/user/HDD_Data/YouTube2018/valid.zip') 126 | # os.remove('/home/user/HDD_Data/YouTube2018/all_frames/valid.zip') 127 | 128 | # print('Done.') -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: mae 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - astroid=2.6.6=py36h06a4308_0 8 | - ca-certificates=2023.01.10=h06a4308_0 9 | - certifi=2021.5.30=py36h06a4308_0 10 | - colorama=0.4.4=pyhd3eb1b0_0 11 | - isort=5.9.3=pyhd3eb1b0_0 12 | - lazy-object-proxy=1.6.0=py36h27cfd23_0 13 | - ld_impl_linux-64=2.35.1=h7274673_9 14 | - libffi=3.3=he6710b0_2 15 | - libgcc-ng=9.3.0=h5101ec6_17 16 | - 
libgomp=9.3.0=h5101ec6_17 17 | - libstdcxx-ng=9.3.0=hd4cf53a_17 18 | - mccabe=0.6.1=py36_1 19 | - ncurses=6.3=h7f8727e_2 20 | - openssl=1.1.1t=h7f8727e_0 21 | - pip=21.2.2=py36h06a4308_0 22 | - pylint=2.9.6=py36h06a4308_1 23 | - python=3.6.13=h12debd9_1 24 | - readline=8.1.2=h7f8727e_1 25 | - setuptools=58.0.4=py36h06a4308_0 26 | - sqlite=3.38.2=hc218d9a_0 27 | - tk=8.6.11=h1ccaba5_0 28 | - toml=0.10.2=pyhd3eb1b0_0 29 | - tqdm=4.63.0=pyhd3eb1b0_0 30 | - typing-extensions=4.1.1=hd3eb1b0_0 31 | - typing_extensions=4.1.1=pyh06a4308_0 32 | - wheel=0.37.1=pyhd3eb1b0_0 33 | - wrapt=1.12.1=py36h7b6447c_1 34 | - xz=5.2.5=h7b6447c_0 35 | - zlib=1.2.12=h7f8727e_1 36 | - pip: 37 | - absl-py==1.0.0 38 | - antlr4-python3-runtime==4.9.3 39 | - appdirs==1.4.4 40 | - attributee==0.1.7 41 | - attrs==21.4.0 42 | - beautifulsoup4==4.11.1 43 | - bidict==0.21.4 44 | - black==21.4b2 45 | - cachetools==4.2.4 46 | - cffi==1.15.0 47 | - charset-normalizer==2.0.12 48 | - click==8.0.4 49 | - cloudpickle==2.0.0 50 | - cycler==0.11.0 51 | - cython==0.29.28 52 | - dataclasses==0.8 53 | - davis==0.3 54 | - decorator==4.4.2 55 | - decord==0.6.0 56 | - docker-pycreds==0.4.0 57 | - dominate==2.6.0 58 | - easydict==1.9 59 | - filelock==3.4.1 60 | - future==0.18.2 61 | - fvcore==0.1.5.post20220512 62 | - gdown==4.4.0 63 | - gitdb==4.0.9 64 | - gitpython==3.1.20 65 | - google-auth==2.6.6 66 | - google-auth-oauthlib==0.4.6 67 | - grpcio==1.44.0 68 | - hydra-core==1.2.0 69 | - idna==3.3 70 | - imageio==2.15.0 71 | - imgaug==0.4.0 72 | - importlib-metadata==4.8.3 73 | - importlib-resources==5.4.0 74 | - iopath==0.1.9 75 | - jpeg4py==0.1.4 76 | - jsonpatch==1.32 77 | - jsonpointer==2.3 78 | - jsonschema==3.2.0 79 | - kiwisolver==1.3.1 80 | - llvmlite==0.36.0 81 | - lmdb==1.3.0 82 | - markdown==3.3.6 83 | - matplotlib==3.3.4 84 | - mypy-extensions==0.4.3 85 | - networkx==2.5.1 86 | - nose==1.3.7 87 | - numba==0.53.1 88 | - numpy==1.19.5 89 | - oauthlib==3.2.0 90 | - omegaconf==2.2.2 91 | - onnx==1.11.0 92 | - onnxruntime-gpu==1.6.0 93 | - opencv-python==4.5.5.64 94 | - ordered-set==4.0.2 95 | - packaging==21.3 96 | - pandas==1.1.5 97 | - pathspec==0.9.0 98 | - pathtools==0.1.2 99 | - phx-class-registry==3.0.5 100 | - pillow==8.4.0 101 | - pillow-simd==7.0.0.post3 102 | - portalocker==2.5.1 103 | - prettytable==2.5.0 104 | - progressbar==2.5 105 | - progressbar2==3.55.0 106 | - promise==2.3 107 | - protobuf==3.19.4 108 | - psutil==5.9.2 109 | - pyasn1==0.4.8 110 | - pyasn1-modules==0.2.8 111 | - pycocotools==2.0.4 112 | - pycparser==2.21 113 | - pydot==1.4.2 114 | - pylatex==1.4.1 115 | - pyparsing==3.0.8 116 | - pyrsistent==0.18.0 117 | - pysocks==1.7.1 118 | - python-dateutil==2.8.2 119 | - python-utils==3.1.0 120 | - pytz==2022.1 121 | - pywavelets==1.1.1 122 | - pyyaml==6.0 123 | - pyzmq==22.3.0 124 | - regex==2022.7.9 125 | - requests==2.27.1 126 | - requests-oauthlib==1.3.1 127 | - resnest==0.0.5 128 | - rsa==4.8 129 | - scikit-image==0.17.2 130 | - scipy==1.5.4 131 | - sentry-sdk==1.9.9 132 | - setproctitle==1.2.3 133 | - shapely==1.8.1.post1 134 | - shortuuid==1.0.9 135 | - six==1.16.0 136 | - smmap==5.0.0 137 | - soupsieve==2.3.2.post1 138 | - spatial-correlation-sampler==0.4.0 139 | - submitit==1.4.2 140 | - tabulate==0.8.10 141 | - tb-nightly==2.9.0a20220421 142 | - tensorboard==2.9.1 143 | - tensorboard-data-server==0.6.1 144 | - tensorboard-plugin-wit==1.8.1 145 | - tensorboardx==2.5 146 | - termcolor==1.1.0 147 | - thinplate==1.0.0 148 | - thop==0.0.5-2204221051 149 | - tifffile==2020.9.3 150 | - tikzplotlib==0.9.12 
151 | - timm==0.3.2 152 | - torch==1.8.1+cu111 153 | - torch-scatter==2.0.6 154 | - torchaudio==0.8.1 155 | - torchfile==0.1.0 156 | - torchvision==0.9.1+cu111 157 | - tornado==6.1 158 | - typed-ast==1.5.4 159 | - urllib3==1.26.12 160 | - visdom==0.1.8.9 161 | - vot-toolkit==0.5.3 162 | - vot-trax==3.0.3 163 | - wandb==0.13.3 164 | - wcwidth==0.2.5 165 | - websocket-client==1.3.1 166 | - werkzeug==2.0.3 167 | - wget==3.2 168 | - yacs==0.1.8 169 | - zipp==3.6.0 170 | prefix: /home/user/anaconda3/envs/mae 171 | -------------------------------------------------------------------------------- /inference_core.py: -------------------------------------------------------------------------------- 1 | from os import name 2 | import torch 3 | 4 | from model.eval_network import STCN 5 | from model.aggregate import aggregate 6 | 7 | from util.tensor_util import pad_divide_by 8 | from torch.nn import functional as F 9 | from matplotlib import pyplot as plt 10 | import cv2 11 | import numpy as np 12 | 13 | import time 14 | import os 15 | 16 | from model import models_vit 17 | import timm 18 | assert timm.__version__ == "0.3.2" # version check 19 | 20 | 21 | def show_cam_on_image(img: np.ndarray, 22 | mask: np.ndarray, 23 | use_rgb: bool = False, 24 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 25 | """ This function overlays the cam mask on the image as an heatmap. 26 | By default the heatmap is in BGR format. 27 | :param img: The base image in RGB or BGR format. 28 | :param mask: The cam mask. 29 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. 30 | :param colormap: The OpenCV colormap to be used. 31 | :returns: The default image with the cam overlay. 32 | """ 33 | # print('max') 34 | # print(colormap.max()) 35 | heatmap = cv2.applyColorMap(np.uint8(255 * (mask/np.max(mask))), colormap) 36 | img = img/np.max(img) 37 | if use_rgb: 38 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 39 | heatmap = np.float32(heatmap) / 255 40 | 41 | if np.max(img) > 1: 42 | raise Exception( 43 | "The input image should np.float32 in the range [0, 1]") 44 | 45 | cam = heatmap + img 46 | cam = cam / np.max(cam) 47 | return np.uint8(255 * cam) 48 | 49 | 50 | def interpolate_pos_embed_2D(pos_embed, kh, kw): 51 | 52 | num_extra_tokens = 1 53 | model_pos_tokens = pos_embed[:, num_extra_tokens:, :] 54 | model_token_size = int((model_pos_tokens.shape[1])**0.5) 55 | # pos_embed = net.pos_embed 56 | model_pos_tokens = pos_embed[:, num_extra_tokens:(model_token_size*model_token_size + 1), :] # bs, N, C 57 | extra_pos_tokens = pos_embed[:, :num_extra_tokens] 58 | 59 | embedding_size = extra_pos_tokens.shape[-1] 60 | 61 | if kh != model_token_size or kw != model_token_size: # do interpolation 62 | model_pos_tokens_temp = model_pos_tokens.reshape(-1, model_token_size, model_token_size, embedding_size).contiguous().permute(0, 3, 1, 2) # bs, c, h, w 63 | search_pos_tokens = torch.nn.functional.interpolate( 64 | model_pos_tokens_temp, size=(kh, kw), mode='bicubic', align_corners=False) 65 | search_pos_tokens = search_pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2) 66 | else: 67 | search_pos_tokens = model_pos_tokens 68 | new_pos_embed = torch.cat((extra_pos_tokens, search_pos_tokens), dim=1) 69 | return new_pos_embed 70 | 71 | class InferenceCore_ViT: 72 | def __init__(self, prop_net:models_vit, images, num_objects, pos_embed_new, video_name=None): 73 | self.prop_net = prop_net 74 | 75 | # True dimensions 76 | t = images.shape[1] 77 | h, w = images.shape[-2:] 
78 | 79 | # Pad each side to multiple of 16 80 | images, self.pad = pad_divide_by(images, 16) 81 | # Padded dimensions 82 | nh, nw = images.shape[-2:] 83 | 84 | self.images = images 85 | self.device = 'cuda' 86 | 87 | self.k = num_objects 88 | 89 | # Background included, not always consistent (i.e. sum up to 1) 90 | self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.device) 91 | self.prob[0] = 1e-7 # for the background 92 | 93 | self.t, self.h, self.w = t, h, w 94 | self.nh, self.nw = nh, nw 95 | self.kh = self.nh//16 96 | self.kw = self.nw//16 97 | 98 | pos_embed_new_temp = interpolate_pos_embed_2D(pos_embed_new, self.kh, self.kw) 99 | self.prop_net.pos_embed_new = torch.nn.Parameter(pos_embed_new_temp) 100 | 101 | print('after interpolation:') 102 | print(self.prop_net.pos_embed_new.shape) 103 | 104 | # self.mem_bank = MemoryBank(k=self.k, top_k=top_k) 105 | self.video_name = video_name 106 | print('init inference_core') 107 | 108 | def visualize(self, att_weights, video_name, mode, frame_index): 109 | base_path = '/apdcephfs/share_1290939/qiangqwu/VOS/DAVIS/2017/trainval/JPEGImages/480p/'+video_name+'/{:05d}.jpg'.format(frame_index) 110 | image = cv2.imread(base_path) 111 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 112 | # W, H = 320, 320 113 | # image = cv2.resize(image, (W, H)) 114 | H, W, C = image.shape 115 | # here we create the canvas 116 | fig = plt.figure(constrained_layout=True, figsize=(25 * 0.7, 25 * 0.7)) 117 | # and we add one plot per reference point 118 | grid_len = 8 119 | # gs = fig.add_gridspec(grid_len+1, grid_len) 120 | gs = fig.add_gridspec(1, 1) 121 | axs = [] 122 | axs.append(fig.add_subplot(gs[0, 0])) 123 | # axs.append(fig.add_subplot(gs[0, 1])) 124 | # for i in range(1, grid_len): 125 | # for j in range(grid_len): 126 | # axs.append(fig.add_subplot(gs[i, j])) 127 | # axs[0].imshow(image) 128 | # axs[0].axis('off') 129 | temp_att = cv2.resize(np.mean(att_weights, axis=0), (W, H)) 130 | template_attention_map = show_cam_on_image(image, temp_att, use_rgb=True) 131 | axs[0].imshow(template_attention_map, cmap='cividis', interpolation='nearest') 132 | axs[0].axis('off') 133 | 134 | # axs = axs[2:] 135 | # token_index = 0 136 | # for ax in axs: 137 | # temp_att = cv2.resize(att_weights[token_index], (W, H)) 138 | # template_attention_map = show_cam_on_image(image, temp_att, use_rgb=True) 139 | # ax.imshow(template_attention_map, cmap='cividis', interpolation='nearest') 140 | # ax.axis('off') 141 | # # show_cam_on_image(cv2.resize(start_frame_img_rgb, (320, 320)), cv2.resize(dec_attn_weights_template, (320, 320)), use_rgb=True) 142 | # # ax.imshow(cv2.resize(z_patch, (320, 320))) 143 | # # ax.axis('off') 144 | # token_index += 1 145 | if not os.path.exists("./visualization/"+self.video_name): 146 | os.makedirs("./visualization/"+self.video_name) 147 | plt.savefig("./visualization/"+self.video_name+"/"+mode+'_{:05d}.png'.format(frame_index)) 148 | 149 | def do_pass(self, target_initial_mask, idx, end_idx, layer_index, use_local_window, use_token_learner): 150 | ''' 151 | target_initial_mask: N, 1, H, W 152 | ''' 153 | # self.mem_bank.add_memory(key_k, key_v) 154 | print('layer index: %d' %(layer_index)) 155 | closest_ti = end_idx 156 | 157 | # Note that we never reach closest_ti, just the frame before it 158 | this_range = range(idx+1, closest_ti) 159 | end = closest_ti - 1 160 | 161 | target_models = [] 162 | 163 | if layer_index > 0: 164 | # # tokenlearner w/ weights 165 | f1_v = 
self.prop_net(frames=self.images[:,idx].repeat(self.k, 1, 1, 1), mode='extract_feat_w_mask', layer_index=layer_index, mask_frames=target_initial_mask, use_window_partition=use_local_window, use_token_learner=use_token_learner) # N, dim, h, w 166 | if use_token_learner: 167 | bg_mask = torch.ones_like(target_initial_mask[0].unsqueeze(0)) - torch.sum(target_initial_mask.permute(1, 0, 2, 3), dim=1, keepdim=True) 168 | bg_mask[bg_mask != 1] = 0 169 | init_mask = torch.cat((bg_mask, target_initial_mask.permute(1, 0, 2, 3)), dim=1) # B, N+1, H, W 170 | h, w = f1_v.shape[-2:] 171 | props = F.interpolate(init_mask, size=(h, w), mode='bilinear', align_corners=False) # B, N+1, H, W 172 | hard_probs = torch.zeros_like(props) 173 | max_indices = torch.max(props, dim=1, keepdim=True)[1] 174 | hard_probs.scatter_(dim=1, index=max_indices, value=1.0) 175 | hard_probs = hard_probs[:, 1:] # B, N, H, W; The binary segmentation mask 176 | masks = torch.stack([props[:, 1:] * hard_probs, (1 - props[:, 1:]) * (1 - hard_probs)], dim=2) # B, N, 2, H, W; 177 | for i in range(self.k): 178 | if use_token_learner: 179 | # for visualizaiton 180 | # rf_tokens, for_weight, back_weight, h_mask, w_mask = self.prop_net(qk16=f1_v[i].unsqueeze(0), mode='tokenlearner_w_masks', mask = masks[:,i]) 181 | # target_models.append(rf_tokens) 182 | 183 | # for_weight = for_weight.view(1, -1, h_mask, w_mask) 184 | # back_weight = back_weight.view(1, -1, h_mask, w_mask) 185 | 186 | # self.visualize(for_weight.squeeze(0).cpu().data.numpy(), self.video_name, 'foreground_'+str(i), idx) 187 | # self.visualize(back_weight.squeeze(0).cpu().data.numpy(), self.video_name, 'background_'+str(i), idx) 188 | 189 | target_models.append(self.prop_net(qk16=f1_v[i].unsqueeze(0), mode='tokenlearner_w_masks', mask = masks[:,i])[0]) # B, num_token, c 190 | 191 | else: 192 | # print(f1_v[i].shape) 193 | target_models.append(f1_v[i].unsqueeze(0)) # 1, C, H, W 194 | 195 | for ti in this_range: 196 | mask_list = [] 197 | # s1 = time.time() 198 | 199 | if layer_index > 0: 200 | m16_f_ti = self.prop_net(frames=self.images[:,ti], mode='extract_feat_wo_mask', layer_index=layer_index, use_window_partition=use_local_window) # 1, hw, c 201 | _, L, _ = m16_f_ti.shape 202 | 203 | for obj_index in range(self.k): # for different objects. here we can also use previous frames 204 | # tokenlearner w/ weights; updating 205 | if ti == 1: #target_models[obj_index] 206 | m16_f2_v_index, m8_f2_v_index, m4_f2_v_index = self.prop_net(template=target_models[obj_index], search=m16_f_ti, mode='forward_together', layer_index=layer_index, H=self.images[:,ti].shape[-2], W=self.images[:,ti].shape[-1], L=L) 207 | out_mask = self.prop_net(m16=m16_f2_v_index, m8 = m8_f2_v_index, m4 = m4_f2_v_index, mode='segmentation_single_onject') 208 | mask_list.append(out_mask) 209 | 210 | else: 211 | m16_f2_v_index, m8_f2_v_index, m4_f2_v_index = self.prop_net(template=torch.cat((target_models[obj_index], target_models_latest[obj_index]), dim=1), search=m16_f_ti, mode='forward_together', layer_index=layer_index, H=self.images[:,ti].shape[-2], W=self.images[:,ti].shape[-1], L=L) 212 | out_mask = self.prop_net(m16=m16_f2_v_index, m8 = m8_f2_v_index, m4 = m4_f2_v_index, mode='segmentation_single_onject') 213 | mask_list.append(out_mask) 214 | else: 215 | for obj_index in range(self.k): # for different objects. 
here we can also use previous frames 216 | target_mask = target_initial_mask[obj_index].unsqueeze(0).unsqueeze(0) # 1,1,1, H, W 217 | 218 | if ti == 1: 219 | m16_f1_v1, m8_f1_v1, m4_f1_v1 = self.prop_net(memory_frames=self.images[:,idx].unsqueeze(1), mask_frames=target_mask, query_frame=self.images[:,ti], mode='backbone_full') 220 | out_mask = self.prop_net(m16=m16_f1_v1, m8 = m8_f1_v1, m4 = m4_f1_v1, mode='segmentation_single_onject') 221 | mask_list.append(out_mask) 222 | else: 223 | m16_f1_v1, m8_f1_v1, m4_f1_v1 = self.prop_net(memory_frames=torch.cat((self.images[:,idx].unsqueeze(1), self.images[:,ti-1].unsqueeze(1)), dim=1), mask_frames=torch.cat((target_mask, self.prob[:,ti-1][obj_index+1].unsqueeze(0).unsqueeze(0)), dim=1), query_frame=self.images[:,ti], mode='backbone_full') 224 | out_mask = self.prop_net(m16=m16_f1_v1, m8 = m8_f1_v1, m4 = m4_f1_v1, mode='segmentation_single_onject') 225 | mask_list.append(out_mask) 226 | # m16_f1_v1, m8_f1_v1, m4_f1_v1 = self.prop_net(memory_frames=torch.cat((self.images[:,idx].unsqueeze(1), self.images[:,ti-1].unsqueeze(1)), dim=1), mask_frames=torch.cat((target_mask, (torch.argmax(self.prob[:,ti-1], dim=0)==(obj_index+1)).float().unsqueeze(0).unsqueeze(0)), dim=1), query_frame=self.images[:,ti], mode='backbone') 227 | 228 | out_mask = torch.stack(mask_list, dim=0).flatten(0, 1) #num, 1, 1, h, w 229 | 230 | out_mask = aggregate(out_mask, keep_bg=True) 231 | self.prob[:,ti] = out_mask # N+1, 1, H, W 232 | 233 | 234 | if layer_index > 0: 235 | 236 | # tokenlearner w/ latest frame 237 | if ti < end: 238 | f_ti_v = self.prop_net(frames=self.images[:,ti].repeat(self.k, 1, 1, 1), mode='extract_feat_w_mask', layer_index=layer_index, mask_frames=out_mask[1:], use_window_partition=use_local_window, use_token_learner=use_token_learner) # N, dim, h, w 239 | target_models_latest = [] 240 | 241 | if use_token_learner: 242 | bg_mask = torch.ones_like(target_initial_mask[0].unsqueeze(0)) - torch.sum(out_mask[1:].permute(1, 0, 2, 3), dim=1, keepdim=True) 243 | bg_mask[bg_mask != 1] = 0 244 | init_mask = torch.cat((bg_mask, out_mask[1:].permute(1, 0, 2, 3)), dim=1) # B, N+1, H, W 245 | h, w = f_ti_v.shape[-2:] 246 | props = F.interpolate(init_mask, size=(h, w), mode='bilinear', align_corners=False) # B, N+1, H, W 247 | hard_probs = torch.zeros_like(props) 248 | max_indices = torch.max(props, dim=1, keepdim=True)[1] 249 | hard_probs.scatter_(dim=1, index=max_indices, value=1.0) 250 | hard_probs = hard_probs[:, 1:] # B, N, H, W; The binary segmentation mask 251 | masks = torch.stack([props[:, 1:] * hard_probs, (1 - props[:, 1:]) * (1 - hard_probs)], dim=2) # B, N, 2, H, W; 252 | for i in range(self.k): 253 | if use_token_learner: 254 | # rf_tokens, for_weight, back_weight, h_mask, w_mask = self.prop_net(qk16=f_ti_v[i].unsqueeze(0), mode='tokenlearner_w_masks', mask = masks[:,i]) 255 | # target_models_latest.append(rf_tokens) 256 | 257 | # for_weight = for_weight.view(1, -1, h_mask, w_mask) 258 | # back_weight = back_weight.view(1, -1, h_mask, w_mask) 259 | 260 | # self.visualize(for_weight.squeeze(0).cpu().data.numpy(), self.video_name, 'foreground_'+str(i), ti) 261 | # self.visualize(back_weight.squeeze(0).cpu().data.numpy(), self.video_name, 'background_'+str(i), ti) 262 | 263 | target_models_latest.append(self.prop_net(qk16=f_ti_v[i].unsqueeze(0), mode='tokenlearner_w_masks', mask = masks[:,i])[0]) # B, num_token, c 264 | else: 265 | target_models_latest.append(f_ti_v[i].unsqueeze(0)) #.permute(0, 2, 3, 1).view(1, -1, 768) 266 | 267 | return closest_ti 268 | 269 
| def interact(self, mask, frame_idx, end_idx, layer_index, use_local_window, use_token_learner): 270 | mask, _ = pad_divide_by(mask.cuda(), 16) # 2, 1, 480, 912 271 | 272 | self.prob[:, frame_idx] = aggregate(mask, keep_bg=True) # the 1st frame 273 | 274 | # Propagate 275 | self.do_pass(mask, frame_idx, end_idx, layer_index, use_local_window, use_token_learner) 276 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__init__.py -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/aggregate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/aggregate.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/cbam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/cbam.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/eval_network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/eval_network.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/mod_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/mod_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc.140569625763192: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/model.cpython-36.pyc.140569625763192 -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc.140666179276152: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/model.cpython-36.pyc.140666179276152 -------------------------------------------------------------------------------- /model/__pycache__/model_temp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/model_temp.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_temp_temp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/model_temp_temp.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.139742317369424: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.139742317369424 -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.139824390907984: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.139824390907984 -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.139838520842320: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.139838520842320 -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.140109349144656: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.140109349144656 -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.140142142203984: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.140142142203984 -------------------------------------------------------------------------------- /model/__pycache__/models_vit.cpython-36.pyc.140277441316944: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/models_vit.cpython-36.pyc.140277441316944 
-------------------------------------------------------------------------------- /model/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/model/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /model/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # Soft aggregation from STM 6 | def aggregate(prob, keep_bg=False): 7 | k = prob.shape 8 | new_prob = torch.cat([ 9 | torch.prod(1-prob, dim=0, keepdim=True), 10 | prob 11 | ], 0).clamp(1e-7, 1-1e-7) 12 | logits = torch.log((new_prob /(1-new_prob))) 13 | 14 | if keep_bg: 15 | return F.softmax(logits, dim=0) 16 | else: 17 | return F.softmax(logits, dim=0)[1:] -------------------------------------------------------------------------------- /model/cbam.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class BasicConv(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 9 | super(BasicConv, self).__init__() 10 | self.out_channels = out_planes 11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | return x 16 | 17 | class Flatten(nn.Module): 18 | def forward(self, x): 19 | return x.view(x.size(0), -1) 20 | 21 | class ChannelGate(nn.Module): 22 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 23 | super(ChannelGate, self).__init__() 24 | self.gate_channels = gate_channels 25 | self.mlp = nn.Sequential( 26 | Flatten(), 27 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 28 | nn.ReLU(), 29 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 30 | ) 31 | self.pool_types = pool_types 32 | def forward(self, x): 33 | channel_att_sum = None 34 | for pool_type in self.pool_types: 35 | if pool_type=='avg': 36 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 37 | channel_att_raw = self.mlp( avg_pool ) 38 | elif pool_type=='max': 39 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 40 | channel_att_raw = self.mlp( max_pool ) 41 | 42 | if channel_att_sum is None: 43 | channel_att_sum = channel_att_raw 44 | else: 45 | channel_att_sum = channel_att_sum + channel_att_raw 46 | 47 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 48 | return x * scale 49 | 50 | class ChannelPool(nn.Module): 51 | def forward(self, x): 52 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 53 | 54 | class SpatialGate(nn.Module): 55 | def __init__(self): 56 | super(SpatialGate, self).__init__() 
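# ChannelPool (below) concatenates the per-pixel max and mean over channels into 2 maps; the 7x7 conv then produces a single-channel map that, after sigmoid, spatially rescales the input.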
57 | kernel_size = 7 58 | self.compress = ChannelPool() 59 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) 60 | def forward(self, x): 61 | x_compress = self.compress(x) 62 | x_out = self.spatial(x_compress) 63 | scale = torch.sigmoid(x_out) # broadcasting 64 | return x * scale 65 | 66 | class CBAM(nn.Module): 67 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 68 | super(CBAM, self).__init__() 69 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 70 | self.no_spatial=no_spatial 71 | if not no_spatial: 72 | self.SpatialGate = SpatialGate() 73 | def forward(self, x): 74 | x_out = self.ChannelGate(x) 75 | if not self.no_spatial: 76 | x_out = self.SpatialGate(x_out) 77 | return x_out 78 | -------------------------------------------------------------------------------- /model/eval_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | eval_network.py - Evaluation version of the network 3 | The logic is basically the same 4 | but with top-k and some implementation optimization 5 | 6 | The trailing number of a variable usually denote the stride 7 | e.g. f16 -> encoded features with stride 16 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from model.modules import * 15 | from model.network import Decoder 16 | 17 | from model import models_vit 18 | import timm 19 | assert timm.__version__ == "0.3.2" # version check 20 | 21 | 22 | class STCN(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | self.key_encoder = KeyEncoder() 26 | self.value_encoder = ValueEncoder() 27 | 28 | # Projection from f16 feature space to key space 29 | self.key_proj = KeyProjection(1024, keydim=64) 30 | 31 | # Compress f16 a bit to use in decoding later on 32 | self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1) 33 | 34 | self.decoder = Decoder() 35 | 36 | def encode_value(self, frame, kf16, masks): 37 | k, _, h, w = masks.shape 38 | 39 | # Extract memory key/value for a frame with multiple masks 40 | frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1) 41 | # Compute the "others" mask 42 | if k != 1: 43 | others = torch.cat([ 44 | torch.sum( 45 | masks[[j for j in range(k) if i!=j]] 46 | , dim=0, keepdim=True) 47 | for i in range(k)], 0) 48 | else: 49 | others = torch.zeros_like(masks) 50 | 51 | f16 = self.value_encoder(frame, kf16.repeat(k,1,1,1), masks, others) # 2, 512, 30, 57 52 | 53 | return f16.unsqueeze(2) # 2, 512, 1, 30, 57 54 | 55 | def encode_key(self, frame): 56 | f16, f8, f4 = self.key_encoder(frame) 57 | k16 = self.key_proj(f16) 58 | f16_thin = self.key_comp(f16) 59 | 60 | return k16, f16_thin, f16, f8, f4 61 | 62 | def segment_with_query(self, mem_bank, qf8, qf4, qk16, qv16): 63 | k = mem_bank.num_objects 64 | 65 | readout_mem = mem_bank.match_memory(qk16) 66 | qv16 = qv16.expand(k, -1, -1, -1) 67 | qv16 = torch.cat([readout_mem, qv16], 1) 68 | 69 | return torch.sigmoid(self.decoder(qv16, qf8, qf4)) 70 | 71 | 72 | 73 | class ViT_STCN(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | vit_model = models_vit.__dict__['vit_base_patch16']( 77 | num_classes=1000, 78 | drop_path_rate=0.1, 79 | global_pool=True, 80 | single_object = False) 81 | 82 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | from util.tensor_util import compute_tensor_iu 5 | 6 | from collections import defaultdict 7 | 8 | 9 | def get_iou_hook(values): 10 | return 'iou/iou', (values['hide_iou/i']+1)/(values['hide_iou/u']+1) 11 | 12 | def get_sec_iou_hook(values): 13 | return 'iou/sec_iou', (values['hide_iou/sec_i']+1)/(values['hide_iou/sec_u']+1) 14 | 15 | iou_hooks_so = [ 16 | get_iou_hook, 17 | ] 18 | 19 | iou_hooks_mo = [ 20 | get_iou_hook, 21 | get_sec_iou_hook, 22 | ] 23 | 24 | 25 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch 26 | class BootstrappedCE(nn.Module): 27 | def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15): 28 | # def __init__(self, start_warm=40000, end_warm=70000, top_p=0.15): 29 | super().__init__() 30 | 31 | self.start_warm = start_warm 32 | self.end_warm = end_warm 33 | self.top_p = top_p 34 | 35 | def forward(self, input, target, it): 36 | if it < self.start_warm: 37 | return F.cross_entropy(input, target), 1.0 38 | 39 | raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) 40 | num_pixels = raw_loss.numel() 41 | 42 | if it > self.end_warm: 43 | this_p = self.top_p 44 | else: 45 | this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) 46 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) 47 | return loss.mean(), this_p 48 | # return F.cross_entropy(input, target), 1.0 49 | 50 | 51 | class LossComputer: 52 | def __init__(self, para): 53 | super().__init__() 54 | self.para = para 55 | self.bce = BootstrappedCE() 56 | 57 | def compute(self, data, it): 58 | losses = defaultdict(int) 59 | 60 | b, s, _, _, _ = data['gt'].shape 61 | selector = data.get('selector', None) 62 | 63 | # we only use two frames 64 | for i in range(1, s): 65 | # Have to do it in a for-loop like this since not every entry has the second object 66 | # Well it's not a lot of iterations anyway 67 | # data['cls_gt']: 4, 3, 384, 384; 3: three frames 68 | for j in range(b): 69 | if selector is not None and selector[j][1] > 0.5: # we have the 2nd target 70 | loss, p = self.bce(data['logits_%d'%i][j:j+1], data['cls_gt'][j:j+1,i], it) # loss on the 1 (0,1,2) frame 71 | else: 72 | loss, p = self.bce(data['logits_%d'%i][j:j+1,:2], data['cls_gt'][j:j+1,i], it) 73 | 74 | losses['loss_%d'%i] += loss / b 75 | losses['p'] += p / b / (s-1) 76 | 77 | losses['total_loss'] += losses['loss_%d'%i] 78 | 79 | new_total_i, new_total_u = compute_tensor_iu(data['mask_%d'%i]>0.5, data['gt'][:,i]>0.5) 80 | losses['hide_iou/i'] += new_total_i 81 | losses['hide_iou/u'] += new_total_u 82 | 83 | if selector is not None: 84 | new_total_i, new_total_u = compute_tensor_iu(data['sec_mask_%d'%i]>0.5, data['sec_gt'][:,i]>0.5) 85 | losses['hide_iou/sec_i'] += new_total_i 86 | losses['hide_iou/sec_u'] += new_total_u 87 | 88 | return losses 89 | -------------------------------------------------------------------------------- /model/mod_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | mod_resnet.py - A modified ResNet structure 3 | We append extra channels to the first conv by some network surgery 4 | """ 5 | 6 | from collections import OrderedDict 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import model_zoo 12 | 13 | 14 | def load_weights_sequential(target, source_state, extra_chan=1): 15 | 16 | new_dict = OrderedDict() 17 | 18 | for k1, v1 in 
target.state_dict().items(): 19 | if not 'num_batches_tracked' in k1: 20 | if k1 in source_state: 21 | tar_v = source_state[k1] 22 | 23 | if v1.shape != tar_v.shape: 24 | # Init the new segmentation channel with zeros 25 | # print(v1.shape, tar_v.shape) 26 | c, _, w, h = v1.shape 27 | pads = torch.zeros((c,extra_chan,w,h), device=tar_v.device) 28 | nn.init.orthogonal_(pads) 29 | tar_v = torch.cat([tar_v, pads], 1) 30 | 31 | new_dict[k1] = tar_v 32 | 33 | target.load_state_dict(new_dict, strict=False) 34 | 35 | 36 | model_urls = { 37 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 38 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 39 | } 40 | 41 | 42 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 43 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 44 | padding=dilation, dilation=dilation) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | residual = self.downsample(x) 72 | 73 | out += residual 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 83 | super(Bottleneck, self).__init__() 84 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) 85 | self.bn1 = nn.BatchNorm2d(planes) 86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, 87 | padding=dilation) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1) 90 | self.bn3 = nn.BatchNorm2d(planes * 4) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1): 120 | self.inplanes = 64 121 | super(ResNet, self).__init__() 122 | self.conv1 = nn.Conv2d(3+extra_chan, 64, kernel_size=7, stride=2, padding=3) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 128 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 129 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 130 | 131 | for m in self.modules(): 132 | if 
isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. / n)) 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 141 | downsample = None 142 | if stride != 1 or self.inplanes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv2d(self.inplanes, planes * block.expansion, 145 | kernel_size=1, stride=stride), 146 | nn.BatchNorm2d(planes * block.expansion), 147 | ) 148 | 149 | layers = [block(self.inplanes, planes, stride, downsample)] 150 | self.inplanes = planes * block.expansion 151 | for i in range(1, blocks): 152 | layers.append(block(self.inplanes, planes, dilation=dilation)) 153 | 154 | return nn.Sequential(*layers) 155 | 156 | def resnet18(pretrained=True, extra_chan=0): 157 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_chan) 158 | if pretrained: 159 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']), extra_chan) 160 | return model 161 | 162 | def resnet50(pretrained=True, extra_chan=0): 163 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan) 164 | if pretrained: 165 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']), extra_chan) 166 | return model 167 | 168 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | model.py - wrapper and utility functions for network training 3 | Compute loss, back-prop, update parameters, logging, etc. 4 | """ 5 | 6 | 7 | import os 8 | import time 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | 13 | from model.network import STCN 14 | from model.losses import LossComputer, iou_hooks_mo, iou_hooks_so 15 | from util.log_integrator import Integrator 16 | from util.image_saver import pool_pairs 17 | 18 | from torch.nn import functional as F 19 | 20 | from model import models_vit 21 | import timm 22 | assert timm.__version__ == "0.3.2" # version check 23 | 24 | 25 | def interpolate_pos_embed(pos_embed, search_size): 26 | 27 | num_extra_tokens = 1 28 | # pos_embed = net.pos_embed 29 | model_pos_tokens = pos_embed[:, num_extra_tokens:, :] # bs, N, C 30 | model_token_size = int(model_pos_tokens.shape[1]**0.5) 31 | extra_pos_tokens = pos_embed[:, :num_extra_tokens] 32 | 33 | embedding_size = extra_pos_tokens.shape[-1] 34 | 35 | if search_size != model_token_size: # do interpolation 36 | model_pos_tokens_temp = model_pos_tokens.reshape(-1, model_token_size, model_token_size, embedding_size).contiguous().permute(0, 3, 1, 2) # bs, c, h, w 37 | search_pos_tokens = torch.nn.functional.interpolate( 38 | model_pos_tokens_temp, size=(search_size, search_size), mode='bicubic', align_corners=False) 39 | search_pos_tokens = search_pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2) 40 | else: 41 | search_pos_tokens = model_pos_tokens 42 | new_pos_embed = torch.cat((extra_pos_tokens, search_pos_tokens), dim=1) 43 | return new_pos_embed 44 | 45 | 46 | class ViTSTCNModel: 47 | def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1): 48 | # we use the ViT model to perform the joint feature extraction and interaction 49 | self.para = para 50 | self.single_object = para['single_object'] # default: False, multiple objects per frame; 51 | print('drop_path_rate: %f'
%(para['drop_path_rate'])) 52 | if para['backbone_type'] == 'vit_base': 53 | # the same parameters w/ the mae fine-tuning 54 | vit_model = models_vit.__dict__['vit_base_patch16']( 55 | num_classes=1000, 56 | drop_path_rate=para['drop_path_rate'], 57 | global_pool=True, 58 | single_object = self.single_object, 59 | num_bases_foreground = para['num_bases_foreground'], 60 | num_bases_background = para['num_bases_background'], 61 | img_size=para['img_size'], 62 | vit_dim=768) 63 | checkpoint = torch.load('./pretrained_models/mae_pretrain_vit_base.pth', map_location='cpu') 64 | print("Load pre-trained checkpoint from") 65 | checkpoint_model = checkpoint['model'] 66 | state_dict = vit_model.state_dict() 67 | for k in ['head.weight', 'head.bias']: 68 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 69 | print(f"Removing key {k} from pretrained checkpoint") 70 | del checkpoint_model[k] 71 | elif para['backbone_type'] == 'vit_large': 72 | vit_model = models_vit.__dict__['vit_large_patch16']( 73 | num_classes=1000, 74 | drop_path_rate=para['drop_path_rate'], 75 | global_pool=True, 76 | single_object = self.single_object, 77 | num_bases_foreground = para['num_bases_foreground'], 78 | num_bases_background = para['num_bases_background'], 79 | img_size=para['img_size'], 80 | vit_dim=1024) 81 | checkpoint = torch.load('./pretrained_models/mae_pretrain_vit_large.pth', map_location='cpu') 82 | print("Load pre-trained checkpoint from") 83 | checkpoint_model = checkpoint['model'] 84 | state_dict = vit_model.state_dict() 85 | for k in ['head.weight', 'head.bias']: 86 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 87 | print(f"Removing key {k} from pretrained checkpoint") 88 | del checkpoint_model[k] 89 | 90 | # load pre-trained model 91 | msg = vit_model.load_state_dict(checkpoint_model, strict=False) 92 | print(msg) 93 | # interpolate position embedding 94 | pos_embed_new = interpolate_pos_embed(vit_model.pos_embed, int(para['img_size']//16)) 95 | vit_model.pos_embed_new = torch.nn.Parameter(pos_embed_new, requires_grad=False) 96 | 97 | ################### Original Setting used in STCN ############################## 98 | 99 | self.local_rank = local_rank 100 | 101 | self.STCN = nn.parallel.DistributedDataParallel( 102 | vit_model.cuda(), 103 | device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, find_unused_parameters=True) 104 | 105 | # Setup logger when local_rank=0 106 | self.logger = logger 107 | self.save_path = save_path 108 | if logger is not None: 109 | self.last_time = time.time() 110 | self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) 111 | if self.single_object: 112 | self.train_integrator.add_hook(iou_hooks_so) 113 | else: 114 | self.train_integrator.add_hook(iou_hooks_mo) 115 | self.loss_computer = LossComputer(para) 116 | 117 | self.train() 118 | self.optimizer = optim.Adam(filter( 119 | lambda p: p.requires_grad, self.STCN.parameters()), lr=para['lr'], weight_decay=1e-7) 120 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma']) 121 | 122 | 123 | if para['amp']: 124 | self.scaler = torch.cuda.amp.GradScaler() 125 | 126 | # Logging info 127 | self.report_interval = 100 128 | self.save_im_interval = 2000 #800 129 | self.save_model_interval = 10000 #50000 130 | if para['debug']: 131 | self.report_interval = self.save_im_interval = 1 132 | self.num_bases_foreground = para['num_bases_foreground'] 133 | 
self.num_bases_background = para['num_bases_background'] 134 | self.img_size = para['img_size'] 135 | 136 | 137 | def do_pass(self, data, it=0): 138 | # No need to store the gradient outside training 139 | torch.set_grad_enabled(self._is_train) 140 | 141 | for k, v in data.items(): 142 | if type(v) != list and type(v) != dict and type(v) != int: 143 | data[k] = v.cuda(non_blocking=True) 144 | 145 | out = {} 146 | Fs = data['rgb'] # bs, 3, 3, 384, 384 147 | Ms = data['gt'] # bs, 3, 1, 384, 384 148 | 149 | with torch.cuda.amp.autocast(enabled=self.para['amp']): 150 | 151 | if not self.single_object: 152 | 153 | sec_Ms = data['sec_gt'] # 4, 3, 1, 384, 384 154 | selector = data['selector'] # 4, 2 155 | 156 | if self.para['layer_index'] > 0: 157 | 158 | f1_v1 = self.STCN(frames=Fs[:,0], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=Ms[:,0], use_window_partition=self.para['use_window_partition'], use_token_learner=self.para['use_token_learner']) #bs, dim, h, w 159 | f1_v2 = self.STCN(frames=Fs[:,0], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=sec_Ms[:,0], use_window_partition=self.para['use_window_partition'], use_token_learner=self.para['use_token_learner']) #bs, dim, h, w 160 | 161 | if self.para['use_token_learner']: 162 | first_target_mask = Ms[:, 0] # 1st target in the first frame # B, 1, H, W 163 | second_target_mask = sec_Ms[:, 0] # 2nd target in the second frame # B, 1, H, W 164 | bg_mask = torch.ones_like(first_target_mask) - torch.sum(torch.cat((first_target_mask, second_target_mask), dim=1), dim=1, keepdim=True) 165 | bg_mask[bg_mask != 1] = 0 166 | init_mask = torch.cat((bg_mask, first_target_mask, second_target_mask), dim=1) # B, N+1, H, W 167 | h, w = f1_v1.shape[-2:] 168 | props = F.interpolate(init_mask, size=(h, w), mode='bilinear', align_corners=False) # B, N+1, H, W 169 | hard_probs = torch.zeros_like(props) 170 | max_indices = torch.max(props, dim=1, keepdim=True)[1] 171 | hard_probs.scatter_(dim=1, index=max_indices, value=1.0) 172 | hard_probs = hard_probs[:, 1:] # B, N, H, W; The binary segmentation mask 173 | masks = torch.stack([props[:, 1:] * hard_probs, (1 - props[:, 1:]) * (1 - hard_probs)], dim=2) # B, N, 2, H, W; 174 | bases_fixed_v1, _, _, _, _ = self.STCN(qk16=f1_v1, mode='tokenlearner_w_masks', mask = masks[:,0]) # B, num_token, c 175 | bases_fixed_v2, _, _, _, _ = self.STCN(qk16=f1_v2, mode='tokenlearner_w_masks', mask = masks[:,1]) # B, num_token, c 176 | 177 | # _, L, _ = bases_fixed_v1.shape 178 | m16_f2 = self.STCN(frames=Fs[:,1], mode='extract_feat_wo_mask', layer_index=self.para['layer_index'], use_window_partition=self.para['use_window_partition']) #bs, hw, dim 179 | 180 | _, L, _ = m16_f2.shape 181 | if self.para['use_token_learner']: 182 | m16_f2_v1, m8_f2_v1, m4_f2_v1 = self.STCN(template=bases_fixed_v1, mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 183 | m16_f2_v2, m8_f2_v2, m4_f2_v2 = self.STCN(template=bases_fixed_v2, mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 184 | else: 185 | m16_f2_v1, m8_f2_v1, m4_f2_v1 = self.STCN(template=f1_v1, mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 186 | m16_f2_v2, m8_f2_v2, m4_f2_v2 = self.STCN(template=f1_v2, mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], 
W=Fs[:,1].shape[-1], L=L) 187 | 188 | 189 | # Segment frame 1 with frame 0, prev_mask_all: B, N+1, H, W 190 | prev_logits, prev_mask, prev_mask_all = self.STCN(m16=torch.cat((m16_f2_v1.unsqueeze(1), m16_f2_v2.unsqueeze(1)), dim=1), 191 | m8=torch.cat((m8_f2_v1.unsqueeze(1), m8_f2_v2.unsqueeze(1)), dim=1), 192 | m4=torch.cat((m4_f2_v1.unsqueeze(1), m4_f2_v2.unsqueeze(1)), dim=1), selector=selector, mode='segmentation') 193 | else: 194 | m16_f2_v1, m8_f2_v1, m4_f2_v1 = self.STCN(memory_frames=Fs[:,0].unsqueeze(1), mask_frames=Ms[:,0].unsqueeze(1), query_frame=Fs[:,1], mode='backbone_full') 195 | m16_f2_v2, m8_f2_v2, m4_f2_v2 = self.STCN(memory_frames=Fs[:,0].unsqueeze(1), mask_frames=sec_Ms[:,0].unsqueeze(1), query_frame=Fs[:,1], mode='backbone_full') 196 | prev_logits, prev_mask, prev_mask_all = self.STCN(m16=torch.cat((m16_f2_v1.unsqueeze(1), m16_f2_v2.unsqueeze(1)), dim=1), 197 | m8=torch.cat((m8_f2_v1.unsqueeze(1), m8_f2_v2.unsqueeze(1)), dim=1), 198 | m4=torch.cat((m4_f2_v1.unsqueeze(1), m4_f2_v2.unsqueeze(1)), dim=1), selector=selector, mode='segmentation') 199 | 200 | 201 | out['mask_1'] = prev_mask[:,0:1] # frame-1-taregt-1 202 | out['sec_mask_1'] = prev_mask[:,1:2] # frame-1-target-2 203 | 204 | out['logits_1'] = prev_logits # frame-1 205 | 206 | if self._do_log or self._is_train: 207 | losses = self.loss_computer.compute({**data, **out}, it) 208 | 209 | # Logging 210 | if self._do_log: 211 | self.integrator.add_dict(losses) 212 | if self._is_train: 213 | if it % self.save_im_interval == 0 and it != 0: #save_im_interval = 800 214 | if self.logger is not None: 215 | images = {**data, **out} 216 | size = (384, 384) 217 | self.logger.log_cv2('train/pairs', pool_pairs(images, size, self.single_object), it) 218 | 219 | if self._is_train: 220 | if (it) % self.report_interval == 0 and it != 0: 221 | if self.logger is not None: 222 | self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) 223 | self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.report_interval, it) 224 | self.last_time = time.time() 225 | self.train_integrator.finalize('train', it) 226 | self.train_integrator.reset_except_hooks() 227 | 228 | if it % self.save_model_interval == 0 and it != 0: # self.save_model_interval = 50000 229 | if self.logger is not None: 230 | self.save(it) 231 | 232 | # Backward pass 233 | # This should be done outside autocast 234 | # but I trained it like this and it worked fine 235 | # so I am keeping it this way for reference 236 | self.optimizer.zero_grad(set_to_none=True) 237 | if self.para['amp']: 238 | self.scaler.scale(losses['total_loss']).backward() 239 | self.scaler.step(self.optimizer) 240 | self.scaler.update() 241 | else: 242 | losses['total_loss'].backward() 243 | self.optimizer.step() 244 | self.scheduler.step() 245 | 246 | def save(self, it): 247 | if self.save_path is None: 248 | print('Saving has been disabled.') 249 | return 250 | 251 | os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 252 | model_path = self.save_path + ('_%s.pth' % it) 253 | torch.save(self.STCN.module.state_dict(), model_path) 254 | print('Model saved to %s.' 
% model_path) 255 | 256 | self.save_checkpoint(it) 257 | 258 | def save_checkpoint(self, it): 259 | if self.save_path is None: 260 | print('Saving has been disabled.') 261 | return 262 | 263 | os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 264 | checkpoint_path = self.save_path + '_checkpoint.pth' 265 | checkpoint = { 266 | 'it': it, 267 | 'network': self.STCN.module.state_dict(), 268 | 'optimizer': self.optimizer.state_dict(), 269 | 'scheduler': self.scheduler.state_dict()} 270 | torch.save(checkpoint, checkpoint_path) 271 | 272 | print('Checkpoint saved to %s.' % checkpoint_path) 273 | 274 | def load_model(self, path): 275 | # This method loads everything and should be used to resume training 276 | map_location = 'cuda:%d' % self.local_rank 277 | checkpoint = torch.load(path, map_location={'cuda:0': map_location}) 278 | 279 | it = checkpoint['it'] 280 | network = checkpoint['network'] 281 | optimizer = checkpoint['optimizer'] 282 | scheduler = checkpoint['scheduler'] 283 | 284 | map_location = 'cuda:%d' % self.local_rank 285 | self.STCN.module.load_state_dict(network) 286 | self.optimizer.load_state_dict(optimizer) 287 | self.scheduler.load_state_dict(scheduler) 288 | 289 | print('Model loaded.') 290 | 291 | return it 292 | 293 | def load_network(self, path): 294 | # This method loads only the network weight and should be used to load a pretrained model 295 | map_location = 'cuda:%d' % self.local_rank 296 | src_dict = torch.load(path, map_location={'cuda:0': map_location}) 297 | 298 | # Maps SO weight (without other_mask) to MO weight (with other_mask) 299 | for k in list(src_dict.keys()): 300 | if k == 'value_encoder.conv1.weight': 301 | if src_dict[k].shape[1] == 4: 302 | pads = torch.zeros((64,1,7,7), device=src_dict[k].device) 303 | nn.init.orthogonal_(pads) 304 | src_dict[k] = torch.cat([src_dict[k], pads], 1) 305 | 306 | self.STCN.module.load_state_dict(src_dict) 307 | print('Network weight loaded:', path) 308 | 309 | def train(self): 310 | self._is_train = True 311 | self._do_log = True 312 | self.integrator = self.train_integrator 313 | # Shall be in eval() mode to freeze BN parameters # do we need to freeze all BN weights for MAE? 314 | self.STCN.train() 315 | return self 316 | 317 | def val(self): 318 | self._is_train = False 319 | self._do_log = True 320 | self.STCN.eval() 321 | return self 322 | 323 | def test(self): 324 | self._is_train = False 325 | self._do_log = False 326 | self.STCN.eval() 327 | return self 328 | 329 | -------------------------------------------------------------------------------- /model/model_temp_temp.py: -------------------------------------------------------------------------------- 1 | """ 2 | model.py - warpper and utility functions for network training 3 | Compute loss, back-prop, update parameters, logging, etc. 
4 | """ 5 | 6 | 7 | import os 8 | import time 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | 13 | from model.network import STCN 14 | from model.losses import LossComputer, iou_hooks_mo, iou_hooks_so 15 | from util.log_integrator import Integrator 16 | from util.image_saver import pool_pairs 17 | 18 | from model import models_vit 19 | import timm 20 | assert timm.__version__ == "0.3.2" # version check 21 | 22 | 23 | class STCNModel: 24 | def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1): 25 | self.para = para 26 | self.single_object = para['single_object'] # default: False, multiple objects per frame; 27 | self.local_rank = local_rank 28 | 29 | self.STCN = nn.parallel.DistributedDataParallel( 30 | STCN(self.single_object).cuda(), 31 | device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False) 32 | 33 | # Setup logger when local_rank=0 34 | self.logger = logger 35 | self.save_path = save_path 36 | if logger is not None: 37 | self.last_time = time.time() 38 | self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) 39 | if self.single_object: 40 | self.train_integrator.add_hook(iou_hooks_so) 41 | else: 42 | self.train_integrator.add_hook(iou_hooks_mo) 43 | self.loss_computer = LossComputer(para) 44 | 45 | self.train() 46 | self.optimizer = optim.Adam(filter( 47 | lambda p: p.requires_grad, self.STCN.parameters()), lr=para['lr'], weight_decay=1e-7) 48 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma']) 49 | if para['amp']: 50 | self.scaler = torch.cuda.amp.GradScaler() 51 | 52 | # Logging info 53 | self.report_interval = 100 54 | self.save_im_interval = 800 55 | self.save_model_interval = 50000 56 | if para['debug']: 57 | self.report_interval = self.save_im_interval = 1 58 | 59 | def do_pass(self, data, it=0): 60 | # No need to store the gradient outside training 61 | torch.set_grad_enabled(self._is_train) 62 | 63 | for k, v in data.items(): 64 | if type(v) != list and type(v) != dict and type(v) != int: 65 | data[k] = v.cuda(non_blocking=True) 66 | 67 | out = {} 68 | Fs = data['rgb'] # bs, 3, 3, 384, 384 69 | Ms = data['gt'] # bs, 3, 1, 384, 384 70 | 71 | with torch.cuda.amp.autocast(enabled=self.para['amp']): 72 | # key features never change, compute once 73 | # k16: key: 64-D, kf16: key:512-D; kf16, kf8, kf4, which are all output from the backbone 74 | k16, kf16_thin, kf16, kf8, kf4 = self.STCN('encode_key', Fs) 75 | 76 | if self.single_object: 77 | ref_v = self.STCN('encode_value', Fs[:,0], kf16[:,0], Ms[:,0]) 78 | 79 | # Segment frame 1 with frame 0 80 | prev_logits, prev_mask = self.STCN('segment', 81 | k16[:,:,1], kf16_thin[:,1], kf8[:,1], kf4[:,1], 82 | k16[:,:,0:1], ref_v) 83 | prev_v = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask) 84 | 85 | values = torch.cat([ref_v, prev_v], 2) 86 | 87 | del ref_v 88 | 89 | # Segment frame 2 with frame 0 and 1 90 | this_logits, this_mask = self.STCN('segment', 91 | k16[:,:,2], kf16_thin[:,2], kf8[:,2], kf4[:,2], 92 | k16[:,:,0:2], values) 93 | 94 | out['mask_1'] = prev_mask 95 | out['mask_2'] = this_mask 96 | out['logits_1'] = prev_logits 97 | out['logits_2'] = this_logits 98 | else: 99 | sec_Ms = data['sec_gt'] # 4, 3, 1, 384, 384 100 | selector = data['selector'] # 4, 2 101 | 102 | 103 | # Why did you only consider up to 2 objects during training since one image might contain more than 5 objects? 
104 | # Also, in the 'encode_value', other foreground objects mask is also engaged to generate the mask feature for one object. Do you think it could help the model to filter out background information and contribute to the models' performance? 105 | # Probably. But that just follows from STM and we did not specifically control it. 106 | 107 | # two targets in the frame 0 108 | ref_v1 = self.STCN('encode_value', Fs[:,0], kf16[:,0], Ms[:,0], sec_Ms[:,0]) #Fs: 4, 3, 3, 384, 384; ref_v1: 4, 512, 1, 24, 24 109 | ref_v2 = self.STCN('encode_value', Fs[:,0], kf16[:,0], sec_Ms[:,0], Ms[:,0]) #ref_v2: 4, 512, 1, 24, 24 110 | ref_v = torch.stack([ref_v1, ref_v2], 1) # 4, 2, 512, 1, 24, 24 111 | 112 | # Segment frame 1 with frame 0 113 | prev_logits, prev_mask = self.STCN('segment', 114 | k16[:,:,1], kf16_thin[:,1], kf8[:,1], kf4[:,1], 115 | k16[:,:,0:1], ref_v, selector) 116 | 117 | # two targets in the frame 1 118 | prev_v1 = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask[:,0:1], prev_mask[:,1:2]) 119 | prev_v2 = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask[:,1:2], prev_mask[:,0:1]) 120 | prev_v = torch.stack([prev_v1, prev_v2], 1) 121 | values = torch.cat([ref_v, prev_v], 3) 122 | 123 | del ref_v 124 | 125 | # Segment frame 2 with frame 0 and 1 126 | this_logits, this_mask = self.STCN('segment', 127 | k16[:,:,2], kf16_thin[:,2], kf8[:,2], kf4[:,2], 128 | k16[:,:,0:2], values, selector) 129 | 130 | out['mask_1'] = prev_mask[:,0:1] # frame-1-taregt-1 shape: 4, 1, 384, 384 131 | out['mask_2'] = this_mask[:,0:1] # frame-2-target-1 132 | out['sec_mask_1'] = prev_mask[:,1:2] # frame-1-target-2 133 | out['sec_mask_2'] = this_mask[:,1:2] # frame-2-target-2 134 | 135 | out['logits_1'] = prev_logits # frame-1 prev_logits: 4, 3, 384, 384: background, fir_target, sec_target 136 | out['logits_2'] = this_logits # frame-2: 4, 3, 384, 384: background, fir_target, sec_target 137 | 138 | if self._do_log or self._is_train: 139 | losses = self.loss_computer.compute({**data, **out}, it) 140 | 141 | # Logging 142 | if self._do_log: 143 | self.integrator.add_dict(losses) 144 | if self._is_train: 145 | if it % self.save_im_interval == 0 and it != 0: 146 | if self.logger is not None: 147 | images = {**data, **out} 148 | size = (384, 384) 149 | self.logger.log_cv2('train/pairs', pool_pairs(images, size, self.single_object), it) 150 | 151 | if self._is_train: 152 | if (it) % self.report_interval == 0 and it != 0: 153 | if self.logger is not None: 154 | self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) 155 | self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.report_interval, it) 156 | self.last_time = time.time() 157 | self.train_integrator.finalize('train', it) 158 | self.train_integrator.reset_except_hooks() 159 | 160 | if it % self.save_model_interval == 0 and it != 0: 161 | if self.logger is not None: 162 | self.save(it) 163 | 164 | # Backward pass 165 | # This should be done outside autocast 166 | # but I trained it like this and it worked fine 167 | # so I am keeping it this way for reference 168 | self.optimizer.zero_grad(set_to_none=True) 169 | if self.para['amp']: 170 | self.scaler.scale(losses['total_loss']).backward() 171 | self.scaler.step(self.optimizer) 172 | self.scaler.update() 173 | else: 174 | losses['total_loss'].backward() 175 | self.optimizer.step() 176 | self.scheduler.step() 177 | 178 | def save(self, it): 179 | if self.save_path is None: 180 | print('Saving has been disabled.') 181 | return 182 | 183 | 
os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 184 | model_path = self.save_path + ('_%s.pth' % it) 185 | torch.save(self.STCN.module.state_dict(), model_path) 186 | print('Model saved to %s.' % model_path) 187 | 188 | self.save_checkpoint(it) 189 | 190 | def save_checkpoint(self, it): 191 | if self.save_path is None: 192 | print('Saving has been disabled.') 193 | return 194 | 195 | os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 196 | checkpoint_path = self.save_path + '_checkpoint.pth' 197 | checkpoint = { 198 | 'it': it, 199 | 'network': self.STCN.module.state_dict(), 200 | 'optimizer': self.optimizer.state_dict(), 201 | 'scheduler': self.scheduler.state_dict()} 202 | torch.save(checkpoint, checkpoint_path) 203 | 204 | print('Checkpoint saved to %s.' % checkpoint_path) 205 | 206 | def load_model(self, path): 207 | # This method loads everything and should be used to resume training 208 | map_location = 'cuda:%d' % self.local_rank 209 | checkpoint = torch.load(path, map_location={'cuda:0': map_location}) 210 | 211 | it = checkpoint['it'] 212 | network = checkpoint['network'] 213 | optimizer = checkpoint['optimizer'] 214 | scheduler = checkpoint['scheduler'] 215 | 216 | map_location = 'cuda:%d' % self.local_rank 217 | self.STCN.module.load_state_dict(network) 218 | self.optimizer.load_state_dict(optimizer) 219 | self.scheduler.load_state_dict(scheduler) 220 | 221 | print('Model loaded.') 222 | 223 | return it 224 | 225 | def load_network(self, path): 226 | # This method loads only the network weight and should be used to load a pretrained model 227 | map_location = 'cuda:%d' % self.local_rank 228 | src_dict = torch.load(path, map_location={'cuda:0': map_location}) 229 | 230 | # Maps SO weight (without other_mask) to MO weight (with other_mask) 231 | for k in list(src_dict.keys()): 232 | if k == 'value_encoder.conv1.weight': 233 | if src_dict[k].shape[1] == 4: 234 | pads = torch.zeros((64,1,7,7), device=src_dict[k].device) 235 | nn.init.orthogonal_(pads) 236 | src_dict[k] = torch.cat([src_dict[k], pads], 1) 237 | 238 | self.STCN.module.load_state_dict(src_dict) 239 | print('Network weight loaded:', path) 240 | 241 | def train(self): 242 | self._is_train = True 243 | self._do_log = True 244 | self.integrator = self.train_integrator 245 | # Shall be in eval() mode to freeze BN parameters 246 | self.STCN.eval() 247 | return self 248 | 249 | def val(self): 250 | self._is_train = False 251 | self._do_log = True 252 | self.STCN.eval() 253 | return self 254 | 255 | def test(self): 256 | self._is_train = False 257 | self._do_log = False 258 | self.STCN.eval() 259 | return self 260 | 261 | 262 | def interpolate_pos_embed(pos_embed, search_size): 263 | 264 | num_extra_tokens = 1 265 | # pos_embed = net.pos_embed 266 | model_pos_tokens = pos_embed[:, num_extra_tokens:, :] # bs, N, C 267 | model_token_size = int(model_pos_tokens.shape[1]**0.5) 268 | extra_pos_tokens = pos_embed[:, :num_extra_tokens] 269 | 270 | embedding_size = extra_pos_tokens.shape[-1] 271 | 272 | if search_size != model_token_size: # do interpolation 273 | model_pos_tokens_temp = model_pos_tokens.reshape(-1, model_token_size, model_token_size, embedding_size).contiguous().permute(0, 3, 1, 2) # bs, c, h, w 274 | search_pos_tokens = torch.nn.functional.interpolate( 275 | model_pos_tokens_temp, size=(search_size, search_size), mode='bicubic', align_corners=False) 276 | search_pos_tokens = search_pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2) 277 | else: 278 | search_pos_tokens = model_pos_tokens 279 
| new_pos_embed = torch.cat((extra_pos_tokens, search_pos_tokens), dim=1) 280 | return new_pos_embed 281 | 282 | 283 | import numpy as np 284 | from typing import Tuple 285 | # patch embeding for CLIP 286 | class PatchEmbed2D(nn.Module): 287 | 288 | def __init__( 289 | self, 290 | patch_size: Tuple[int, int] = (16, 16), 291 | in_channels: int = 3, 292 | embed_dim: int = 768, 293 | ): 294 | super().__init__() 295 | 296 | self.patch_size = patch_size 297 | self.in_channels = in_channels 298 | 299 | self.proj = nn.Linear(np.prod(patch_size) * in_channels, embed_dim) 300 | 301 | 302 | def _initialize_weights(self, x): 303 | nn.init.kaiming_normal_(self.proj.weight, 0.) 304 | nn.init.constant_(self.proj.bias, 0.) 305 | 306 | 307 | def forward(self, x: torch.Tensor, is_template=True): 308 | B, C, H, W = x.size() 309 | pH, pW = self.patch_size 310 | 311 | assert C == self.in_channels and H % pH == 0 and W % pW == 0 312 | 313 | x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2) 314 | x = self.proj(x) 315 | 316 | return x 317 | 318 | 319 | 320 | class ViTSTCNModel_triplet: 321 | def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1): 322 | # we use the ViT model to perform the joint feature exctraction and interaction 323 | self.para = para 324 | print(self.para) 325 | self.single_object = para['single_object'] # default: False, multiple objects per frame; 326 | vit_model = models_vit.__dict__['vit_base_patch16']( 327 | num_classes=1000, 328 | drop_path_rate=para['droppath_rate'], #0.1 329 | global_pool=True, 330 | single_object = self.single_object, 331 | deep_low_map = para['deep_low_map'], 332 | use_tape = para['use_tape'], 333 | use_pos_emd = para['use_pos_emd'], 334 | valdim = para['valdim'], 335 | num_iters = para['num_iters'], 336 | num_bases = para['num_bases'], 337 | tau_value = para['tau'], 338 | num_bases_foreground = para['num_bases_foreground'], 339 | num_bases_background = para['num_bases_background'], 340 | img_size=para['img_size']) 341 | if para['pretrain_weights'] == 'mae': 342 | print('load mae weights!!!') 343 | checkpoint = torch.load('/apdcephfs/private_qiangqwu/Projects/STCN/pretrain_models/mae_pretrain_vit_base.pth', map_location='cpu') 344 | checkpoint_model = checkpoint['model'] 345 | elif para['pretrain_weights'] == 'clip': 346 | print('CLIP weights: need to be implemented.') 347 | # For CLIP 348 | checkpoint_model = torch.load('/apdcephfs/private_qiangqwu/Projects/STCN/ft_local/ViT-B-16-CLIP.pth', map_location='cpu') 349 | vit_model.patch_embed = PatchEmbed2D() 350 | 351 | # print('load imagenet weights!!!') 352 | # checkpoint = torch.load('/apdcephfs/private_qiangqwu/Projects/vit_ostrack/pretrain_models/jx_vit_base_p16_224-80ecf9dd.pth', map_location='cpu') 353 | # print("Load pre-trained checkpoint from") 354 | # checkpoint_model = checkpoint #['model'] 355 | 356 | state_dict = vit_model.state_dict() 357 | for k in ['head.weight', 'head.bias']: 358 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 359 | print(f"Removing key {k} from pretrained checkpoint") 360 | del checkpoint_model[k] 361 | 362 | # interpolate position embedding 363 | # load pre-trained model 364 | msg = vit_model.load_state_dict(checkpoint_model, strict=False) 365 | print(msg) 366 | # interpolate the position embedding 367 | print('interpolation after the pretrained weight loaded!!!!!!') 368 | pos_embed_new = interpolate_pos_embed(vit_model.pos_embed, int(para['img_size']//16)) 369 | 
vit_model.pos_embed_new = torch.nn.Parameter(pos_embed_new, requires_grad=False) 370 | print(vit_model.pos_embed_new) 371 | ################### Original Setting used in STCN ############################## 372 | 373 | self.local_rank = local_rank 374 | 375 | self.STCN = nn.parallel.DistributedDataParallel( 376 | vit_model.cuda(), 377 | device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, find_unused_parameters=True) 378 | 379 | # Setup logger when local_rank=0 380 | self.logger = logger 381 | self.save_path = save_path 382 | if logger is not None: 383 | self.last_time = time.time() 384 | self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) 385 | if self.single_object: 386 | self.train_integrator.add_hook(iou_hooks_so) 387 | else: 388 | self.train_integrator.add_hook(iou_hooks_mo) 389 | self.loss_computer = LossComputer(para) 390 | 391 | self.train() 392 | self.optimizer = optim.Adam(filter( 393 | lambda p: p.requires_grad, self.STCN.parameters()), lr=para['lr'], weight_decay=para['weight_decay']) #original: 1e-7 394 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma']) 395 | 396 | # para['amp'] = False 397 | 398 | if para['amp']: 399 | self.scaler = torch.cuda.amp.GradScaler() 400 | 401 | # Logging info 402 | self.report_interval = 100 403 | self.save_im_interval = 2000 404 | self.save_model_interval = 10000 #50000 405 | self.img_size = para['img_size'] 406 | if para['debug']: 407 | self.report_interval = self.save_im_interval = 1 408 | 409 | 410 | def do_pass(self, data, it=0): 411 | # No need to store the gradient outside training 412 | torch.set_grad_enabled(self._is_train) 413 | 414 | for k, v in data.items(): 415 | if type(v) != list and type(v) != dict and type(v) != int: 416 | data[k] = v.cuda(non_blocking=True) 417 | 418 | out = {} 419 | Fs = data['rgb'] # bs, 3, 3, 384, 384 420 | Ms = data['gt'] # bs, 3, 1, 384, 384 421 | 422 | with torch.cuda.amp.autocast(enabled=self.para['amp']): 423 | # key features never change, compute once 424 | # k16: key: 64-D, kf16: key:512-D; kf16, kf8, kf4, which are all output from the backbone 425 | 426 | # k16, kf16_thin, kf16, kf8, kf4 = self.STCN('encode_key', Fs) 427 | # k16 = None 428 | # kf16_thin = None 429 | # kf16 = None 430 | # kf8 = None 431 | # kf4 = None 432 | 433 | if not self.single_object: 434 | sec_Ms = data['sec_gt'] # 4, 3, 1, 384, 384 435 | selector = data['selector'] # 4, 2 436 | 437 | f1_v1 = self.STCN(frames=Fs[:,0], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=Ms[:,0]) #bs, dim, h, w 438 | f1_v2 = self.STCN(frames=Fs[:,0], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=sec_Ms[:,0]) #bs, dim, h, w 439 | 440 | m16_f2 = self.STCN(frames=Fs[:,1], mode='extract_feat_wo_mask', layer_index=self.para['layer_index']) #bs, hw, dim 441 | _, L, _ = m16_f2.shape 442 | bs, c, h, w = f1_v1.shape 443 | m16_f2_v1, m8_f2_v1, m4_f2_v1 = self.STCN(template=f1_v1.permute(0, 2, 3, 1).view(bs, -1, c), mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 444 | m16_f2_v2, m8_f2_v2, m4_f2_v2 = self.STCN(template=f1_v2.permute(0, 2, 3, 1).view(bs, -1, c), mode='forward_together', search=m16_f2, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 445 | 446 | # Segment frame 1 with frame 0 447 | # prev_mask: bs, num_obj, h, w 448 | prev_logits, prev_mask, prev_mask_all 
= self.STCN(m16=torch.cat((m16_f2_v1.unsqueeze(1), m16_f2_v2.unsqueeze(1)), dim=1), 449 | m8=torch.cat((m8_f2_v1.unsqueeze(1), m8_f2_v2.unsqueeze(1)), dim=1), 450 | m4=torch.cat((m4_f2_v1.unsqueeze(1), m4_f2_v2.unsqueeze(1)), dim=1), selector=selector, mode='segmentation') 451 | 452 | # del m16_f1_v1 453 | # del m8_f1_v1 454 | # del m4_f1_v1 455 | # del m16_f1_v2 456 | # del m8_f1_v2 457 | # del m4_f1_v2 458 | 459 | f2_v1 = self.STCN(frames=Fs[:,1], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=prev_mask.detach()[:,0:1]) #bs, dim, h, w 460 | f2_v2 = self.STCN(frames=Fs[:,1], mode='extract_feat_w_mask', layer_index=self.para['layer_index'], mask_frames=prev_mask.detach()[:,1:2]) #bs, dim, h, w 461 | 462 | m16_f3 = self.STCN(frames=Fs[:,2], mode='extract_feat_wo_mask', layer_index=self.para['layer_index']) #bs, hw, dim 463 | _, L, _ = m16_f3.shape 464 | m16_f3_v1, m8_f3_v1, m4_f3_v1 = self.STCN(template=torch.cat((f1_v1.permute(0, 2, 3, 1).view(bs, -1, c), f2_v1.permute(0, 2, 3, 1).view(bs, -1, c)), dim=1), mode='forward_together', search=m16_f3, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 465 | m16_f3_v2, m8_f3_v2, m4_f3_v2 = self.STCN(template=torch.cat((f1_v2.permute(0, 2, 3, 1).view(bs, -1, c), f2_v2.permute(0, 2, 3, 1).view(bs, -1, c)), dim=1), mode='forward_together', search=m16_f3, layer_index=self.para['layer_index'], H=Fs[:,1].shape[-2], W=Fs[:,1].shape[-1], L=L) 466 | 467 | # # # Segment frame 2 with frames 0 and 1 468 | this_logits, this_mask, _ = self.STCN(m16=torch.cat((m16_f3_v1.unsqueeze(1), m16_f3_v2.unsqueeze(1)), dim=1), 469 | m8=torch.cat((m8_f3_v1.unsqueeze(1), m8_f3_v2.unsqueeze(1)), dim=1), 470 | m4=torch.cat((m4_f3_v1.unsqueeze(1), m4_f3_v2.unsqueeze(1)), dim=1), selector=selector, mode='segmentation') 471 | 472 | out['mask_1'] = prev_mask[:,0:1] # frame-1-taregt-1 473 | out['mask_2'] = this_mask[:,0:1] # frame-2-target-1 474 | out['sec_mask_1'] = prev_mask[:,1:2] # frame-1-target-2 475 | out['sec_mask_2'] = this_mask[:,1:2] # frame-2-target-2 476 | 477 | out['logits_1'] = prev_logits # frame-1 478 | out['logits_2'] = this_logits # frame-2 479 | 480 | if self._do_log or self._is_train: 481 | losses = self.loss_computer.compute({**data, **out}, it) 482 | 483 | # Logging 484 | if self._do_log: 485 | self.integrator.add_dict(losses) 486 | if self._is_train: 487 | if it % self.save_im_interval == 0 and it != 0: #save_im_interval = 800 488 | if self.logger is not None: 489 | images = {**data, **out} 490 | size = (self.img_size, self.img_size) 491 | self.logger.log_cv2('train/pairs', pool_pairs(images, size, self.single_object), it) 492 | 493 | if self._is_train: 494 | if (it) % self.report_interval == 0 and it != 0: 495 | if self.logger is not None: 496 | self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) 497 | self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.report_interval, it) 498 | self.last_time = time.time() 499 | self.train_integrator.finalize('train', it) 500 | self.train_integrator.reset_except_hooks() 501 | 502 | if it % self.save_model_interval == 0 and it != 0: # self.save_model_interval = 50000 503 | if self.logger is not None: 504 | self.save(it) 505 | 506 | # Backward pass 507 | # This should be done outside autocast 508 | # but I trained it like this and it worked fine 509 | # so I am keeping it this way for reference 510 | self.optimizer.zero_grad(set_to_none=True) 511 | if self.para['amp']: 512 | 
self.scaler.scale(losses['total_loss']).backward() 513 | self.scaler.step(self.optimizer) 514 | self.scaler.update() 515 | else: 516 | losses['total_loss'].backward() 517 | self.optimizer.step() 518 | self.scheduler.step() 519 | 520 | def save(self, it): 521 | if self.save_path is None: 522 | print('Saving has been disabled.') 523 | return 524 | 525 | os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 526 | model_path = self.save_path + ('_%s.pth' % it) 527 | torch.save(self.STCN.module.state_dict(), model_path) 528 | print('Model saved to %s.' % model_path) 529 | 530 | self.save_checkpoint(it) 531 | 532 | def save_checkpoint(self, it): 533 | if self.save_path is None: 534 | print('Saving has been disabled.') 535 | return 536 | 537 | os.makedirs(os.path.dirname(self.save_path), exist_ok=True) 538 | checkpoint_path = self.save_path + '_checkpoint.pth' 539 | checkpoint = { 540 | 'it': it, 541 | 'network': self.STCN.module.state_dict(), 542 | 'optimizer': self.optimizer.state_dict(), 543 | 'scheduler': self.scheduler.state_dict()} 544 | torch.save(checkpoint, checkpoint_path) 545 | 546 | print('Checkpoint saved to %s.' % checkpoint_path) 547 | 548 | def load_model(self, path): 549 | # This method loads everything and should be used to resume training 550 | map_location = 'cuda:%d' % self.local_rank 551 | checkpoint = torch.load(path, map_location={'cuda:0': map_location}) 552 | 553 | it = checkpoint['it'] 554 | network = checkpoint['network'] 555 | optimizer = checkpoint['optimizer'] 556 | scheduler = checkpoint['scheduler'] 557 | 558 | map_location = 'cuda:%d' % self.local_rank 559 | self.STCN.module.load_state_dict(network) 560 | self.optimizer.load_state_dict(optimizer) 561 | self.scheduler.load_state_dict(scheduler) 562 | 563 | print('Model loaded.') 564 | 565 | return it 566 | 567 | def load_network(self, path, backbone_only): 568 | # This method loads only the network weight and should be used to load a pretrained model 569 | map_location = 'cuda:%d' % self.local_rank 570 | src_dict = torch.load(path, map_location={'cuda:0': map_location}) 571 | 572 | key_names = list(src_dict.keys()) 573 | 574 | # here we only use the backbone, do not include the mask head and the multi-scale fpn 575 | if backbone_only: 576 | for k in key_names: 577 | if 'fpn' in k or 'stcn_decoder' in k: 578 | print('del ' + k) 579 | del src_dict[k] 580 | # here we use the transformer backbone with the multi-scale fpn 581 | else: 582 | for k in key_names: 583 | if 'stcn_decoder' in k: 584 | print('del ' + k) 585 | del src_dict[k] 586 | 587 | 588 | 589 | 590 | # # Maps SO weight (without other_mask) to MO weight (with other_mask) 591 | # for k in list(src_dict.keys()): 592 | # if k == 'value_encoder.conv1.weight': 593 | # if src_dict[k].shape[1] == 4: 594 | # pads = torch.zeros((64,1,7,7), device=src_dict[k].device) 595 | # nn.init.orthogonal_(pads) 596 | # src_dict[k] = torch.cat([src_dict[k], pads], 1) 597 | 598 | msc = self.STCN.module.load_state_dict(src_dict, strict=False) 599 | print(msc) 600 | print('Network weight loaded:', path) 601 | 602 | print('check grads for pos embed:') 603 | print(self.STCN.module.pos_embed_two_frame.requires_grad) 604 | print(self.STCN.module.pos_embed_three_frame.requires_grad) 605 | 606 | 607 | 608 | def train(self): 609 | self._is_train = True 610 | self._do_log = True 611 | self.integrator = self.train_integrator 612 | # Shall be in eval() mode to freeze BN parameters 613 | self.STCN.train() 614 | return self 615 | 616 | def val(self): 617 | self._is_train = False 618 | 
self._do_log = True 619 | self.STCN.eval() 620 | return self 621 | 622 | def test(self): 623 | self._is_train = False 624 | self._do_log = False 625 | self.STCN.eval() 626 | return self -------------------------------------------------------------------------------- /model/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | from tkinter.messagebox import NO 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | import timm.models.vision_transformer 19 | from timm.models.layers.helpers import * 20 | from timm.models.vision_transformer import PatchEmbed 21 | from model.network import Decoder 22 | import torch.nn.functional as F 23 | import time 24 | import math 25 | import torch.nn.functional as NF 26 | 27 | 28 | def l2norm(inp, dim): 29 | norm = torch.linalg.norm(inp, dim=dim, keepdim=True) + 1e-6 30 | return inp/norm 31 | 32 | 33 | def window_partition(x, window_size): 34 | """ 35 | Partition into non-overlapping windows with padding if needed. 36 | Args: 37 | x (tensor): input tokens with [B, H, W, C]. 38 | window_size (int): window size. 39 | 40 | Returns: 41 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 42 | (Hp, Wp): padded height and width before partition 43 | """ 44 | B, H, W, C = x.shape 45 | 46 | pad_h = (window_size - H % window_size) % window_size 47 | pad_w = (window_size - W % window_size) % window_size 48 | if pad_h > 0 or pad_w > 0: 49 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 50 | Hp, Wp = H + pad_h, W + pad_w 51 | 52 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 53 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 54 | return windows, (Hp, Wp) 55 | 56 | 57 | def window_unpartition(windows, window_size, pad_hw, hw): 58 | """ 59 | Window unpartition into original sequences and removing padding. 60 | Args: 61 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 62 | window_size (int): window size. 63 | pad_hw (Tuple): padded height and width (Hp, Wp). 64 | hw (Tuple): original height and width (H, W) before padding. 65 | 66 | Returns: 67 | x: unpartitioned sequences with [B, H, W, C]. 68 | """ 69 | Hp, Wp = pad_hw 70 | H, W = hw 71 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 72 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 73 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 74 | 75 | if Hp > H or Wp > W: 76 | x = x[:, :H, :W, :].contiguous() 77 | return x 78 | 79 | 80 | 81 | # class TokenLearnerModuleV11(nn.Module): 82 | # """TokenLearner module Version 1.1, using slightly different conv. layers. 83 | # Instead of using 4 conv. layers with small channels to implement spatial 84 | # attention, this version uses 2 grouped conv. layers with more channels. It 85 | # also uses softmax instead of sigmoid. 
We confirmed that this version works 86 | # better when having limited training data, such as training with ImageNet1K 87 | # from scratch. 88 | # Attributes: 89 | # num_tokens: Number of tokens. 90 | # dropout_rate: Dropout rate. 91 | # """ 92 | 93 | # def __init__(self, in_channels, num_tokens): 94 | # """Applies learnable tokenization to the 2D inputs. 95 | # Args: 96 | # inputs: Inputs of shape `[bs, h, w, c]`. 97 | # Returns: 98 | # Output of shape `[bs, n_token, c]`. 99 | # """ 100 | # super(TokenLearnerModuleV11, self).__init__() 101 | # self.in_channels = in_channels 102 | # self.num_tokens = num_tokens 103 | # self.norm = nn.LayerNorm(self.in_channels) # Operates on the last axis (c) of the input data. 104 | 105 | # # use a patch-to-cluster architecture 106 | # self.patch_to_cluster_atten = nn.Sequential( 107 | # nn.Conv2d(self.in_channels, self.in_channels // 4, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False), 108 | # nn.GELU() 109 | # ) 110 | # self.patch_to_cluster_linear = nn.Linear(in_features=self.in_channels // 4, out_features=self.num_tokens) 111 | 112 | 113 | # # def forward(self, inputs, mask, pos, offset_input=None): 114 | # def forward(self, inputs, weights): 115 | # ''' 116 | # inputs: bs, h, w, c 117 | # weights: bs, 2, h, w 118 | # ''' 119 | 120 | # feature_shape = inputs.shape # Shape: [bs, h, w, c] 121 | 122 | # # use the appearance features for attention map generation 123 | # selected = inputs 124 | # selected = selected.permute(0, 3, 1, 2) # Shape: [bs, c, h, w] 125 | # selected = self.patch_to_cluster_atten(selected) # Shape: [bs, c_dim, h, w]. 126 | # selected = selected.permute(0, 2, 3, 1) # Shape: [bs, h, w, c_dim]. 127 | # selected = selected.contiguous().view(feature_shape[0], feature_shape[1] * feature_shape[2], -1) # Shape: [bs, h*w, c_dim]. 128 | # selected = self.patch_to_cluster_linear(selected) # Shape: [bs, h*w, n_token]. 129 | 130 | # selected = selected.permute(0, 2, 1) # Shape: [bs, n_token, h*w]. 131 | # # selected = F.softmax(selected, dim=-1) # bs, n_token, hw 132 | 133 | # num_for_tokens = self.num_tokens // 2 134 | # num_back_tokens = self.num_tokens // 2 135 | 136 | # for_weights = weights.flatten(2, 3)[:, 0].unsqueeze(1) # bs, 1, hw 137 | # back_weights = weights.flatten(2, 3)[:, 1].unsqueeze(1) # bs, 1, hw 138 | 139 | # combined_weights = torch.cat((for_weights.repeat(1, num_for_tokens, 1), back_weights.repeat(1, num_back_tokens, 1)), dim=1) # bs, n_token, hw 140 | # selected[combined_weights == 0] = -float('inf') 141 | # selected = F.softmax(selected, dim=-1) # bs, n_token, hw 142 | # selected = torch.nan_to_num(selected) # replace nan to 0.0, especially for selector = [1, 0], i.e., only has one object, sec_obj is empty 143 | 144 | # # selected = selected * torch.cat((for_weights.repeat(1, num_for_tokens, 1), back_weights.repeat(1, num_back_tokens, 1)), dim=1) # bs, n_token, hw 145 | # # sum_weights = torch.sum(selected, dim=-1).unsqueeze(-1) # bs, n_token, 1 146 | 147 | # feat = inputs #bs, h, w, c 148 | # feat = feat.contiguous().view(feature_shape[0], feature_shape[1] * feature_shape[2], -1) # Shape: [bs, h*w, c]. 149 | 150 | # # Produced the attended inputs. 151 | # outputs = torch.einsum("...si,...id->...sd", selected, feat) # (B, n_token, c) 152 | # # outputs = outputs / (sum_weights + 1e-20) 153 | # outputs = self.norm(outputs) 154 | 155 | # return outputs 156 | 157 | 158 | 159 | class TokenLearnerModuleV11_w_Mask(nn.Module): 160 | """TokenLearner module Version 1.1, using slightly different conv. layers. 
161 | Instead of using 4 conv. layers with small channels to implement spatial 162 | attention, this version uses 2 grouped conv. layers with more channels. It 163 | also uses softmax instead of sigmoid. We confirmed that this version works 164 | better when having limited training data, such as training with ImageNet1K 165 | from scratch. 166 | Attributes: 167 | num_tokens: Number of tokens. 168 | dropout_rate: Dropout rate. 169 | """ 170 | 171 | def __init__(self, in_channels, num_tokens): 172 | """Applies learnable tokenization to the 2D inputs. 173 | Args: 174 | inputs: Inputs of shape `[bs, h, w, c]`. 175 | Returns: 176 | Output of shape `[bs, n_token, c]`. 177 | """ 178 | super(TokenLearnerModuleV11_w_Mask, self).__init__() 179 | self.in_channels = in_channels 180 | self.num_tokens = num_tokens 181 | self.norm = nn.LayerNorm(self.in_channels) # Operates on the last axis (c) of the input data. 182 | self.input_norm = nn.LayerNorm(self.in_channels + 1) 183 | 184 | # Input: appearance features + 1-channel mask 185 | self.patch_to_cluster_atten = nn.Sequential( 186 | nn.Conv2d(self.in_channels + 1, self.in_channels // 4, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False), 187 | nn.GELU() 188 | ) 189 | self.patch_to_cluster_linear = nn.Linear(in_features=self.in_channels // 4, out_features=self.num_tokens) 190 | 191 | 192 | def forward(self, inputs, weights): 193 | ''' 194 | inputs: bs, h, w, c 195 | weights: bs, 1, h, w 196 | ''' 197 | 198 | feature_shape = inputs.shape # Shape: [bs, h, w, c] 199 | 200 | # use the appearance features for attention map generation 201 | selected = torch.cat((inputs, weights.permute(0, 2, 3, 1)), dim=-1) # bs, h, w, c+1 202 | selected = self.input_norm(selected) # need this norm here? shape: bs, h, w, c+1 203 | selected = selected.permute(0, 3, 1, 2) # Shape: [bs, c+1, h, w] 204 | selected = self.patch_to_cluster_atten(selected) # Shape: [bs, c_dim, h, w]. 205 | selected = selected.permute(0, 2, 3, 1) # Shape: [bs, h, w, c_dim]. 206 | selected = selected.contiguous().view(feature_shape[0], feature_shape[1] * feature_shape[2], -1) # Shape: [bs, h*w, c_dim]. 207 | selected = self.patch_to_cluster_linear(selected) # Shape: [bs, h*w, n_token//2]. 208 | 209 | selected = selected.permute(0, 2, 1) # Shape: [bs, n_token//2, h*w]. 210 | # selected = F.softmax(selected, dim=-1) # bs, n_token, hw 211 | 212 | num_for_tokens = self.num_tokens 213 | 214 | for_weights = weights.flatten(2, 3) # bs, 1, hw 215 | 216 | # # mask map guided suppresion 217 | combined_weights = for_weights.repeat(1, num_for_tokens, 1) # bs, n_token//2, hw 218 | selected[combined_weights == 0] = -float('inf') 219 | 220 | selected = F.softmax(selected, dim=-1) # bs, n_token, hw 221 | selected = torch.nan_to_num(selected) # replace nan to 0.0, especially for selector = [1, 0], i.e., only has one object, sec_obj is empty 222 | 223 | # selected = selected * torch.cat((for_weights.repeat(1, num_for_tokens, 1), back_weights.repeat(1, num_back_tokens, 1)), dim=1) # bs, n_token, hw 224 | # sum_weights = torch.sum(selected, dim=-1).unsqueeze(-1) # bs, n_token, 1 225 | 226 | feat = inputs #bs, h, w, c 227 | feat = feat.contiguous().view(feature_shape[0], feature_shape[1] * feature_shape[2], -1) # Shape: [bs, h*w, c]. 228 | 229 | # Produced the attended inputs. 
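        # `selected` holds (bs, n_token, h*w) softmax attention weights and `feat` holds
        # (bs, h*w, c) patch features, so the einsum below is a batched matrix product
        # giving (bs, n_token, c): each learned token is a weighted average of the patch
        # features, restricted by the -inf masking above to locations inside the mask.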
230 | outputs = torch.einsum("...si,...id->...sd", selected, feat) # (B, n_token//2, c) 231 | outputs = self.norm(outputs) 232 | 233 | return outputs, selected 234 | 235 | # borrowed from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmdet/models/backbones/vitae.py 236 | class Norm2d(nn.Module): 237 | def __init__(self, embed_dim): 238 | super().__init__() 239 | self.ln = nn.LayerNorm(embed_dim, eps=1e-6) 240 | def forward(self, x): 241 | x = x.permute(0, 2, 3, 1) 242 | x = self.ln(x) 243 | x = x.permute(0, 3, 1, 2).contiguous() 244 | return x 245 | 246 | 247 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 248 | """ Vision Transformer with support for global average pooling 249 | """ 250 | def __init__(self, global_pool=False, single_object=False, num_bases_foreground=None, num_bases_background=None, img_size=None, vit_dim=None, **kwargs): 251 | super(VisionTransformer, self).__init__(**kwargs) 252 | 253 | self.global_pool = global_pool 254 | if self.global_pool: 255 | norm_layer = kwargs['norm_layer'] 256 | embed_dim = kwargs['embed_dim'] 257 | self.fc_norm = norm_layer(embed_dim) 258 | self.single_object = single_object 259 | 260 | self.vit_dim = vit_dim 261 | 262 | 263 | self.mask_patch_embed = PatchEmbed( 264 | img_size=img_size, patch_size=16, in_chans=1, embed_dim=vit_dim) # !!! to check whether it has grads 265 | 266 | # borrowed from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmdet/models/backbones/vitae.py 267 | self.fpn1 = nn.Sequential( # 1/4 268 | nn.ConvTranspose2d(vit_dim, vit_dim, kernel_size=2, stride=2), 269 | Norm2d(vit_dim), 270 | nn.GELU(), 271 | nn.ConvTranspose2d(vit_dim, 256, kernel_size=2, stride=2), 272 | ) 273 | 274 | self.fpn2 = nn.Sequential( # 1/8 275 | nn.ConvTranspose2d(vit_dim, 512, kernel_size=2, stride=2), 276 | ) 277 | 278 | self.stcn_decoder = Decoder(vit_dim=vit_dim) 279 | 280 | self.num_bases_foreground = num_bases_foreground 281 | self.num_bases_background = num_bases_background 282 | 283 | print('num_bases_foreground: %d' %(self.num_bases_foreground)) 284 | print('num_bases_background: %d' %(self.num_bases_background)) 285 | 286 | self.tlearner = TokenLearnerModuleV11_w_Mask(in_channels=vit_dim, num_tokens=self.num_bases_foreground) 287 | self.tlearner_back = TokenLearnerModuleV11_w_Mask(in_channels=vit_dim, num_tokens=self.num_bases_background) 288 | 289 | 290 | def aggregate(self, prob): 291 | new_prob = torch.cat([ 292 | torch.prod(1-prob, dim=1, keepdim=True), # get the background region 293 | prob 294 | ], 1).clamp(1e-7, 1-1e-7) 295 | logits = torch.log((new_prob /(1-new_prob))) # bs, 3, 384, 384 296 | return logits 297 | 298 | 299 | def forward(self, mode=None, **kwargs): #memory_frames=None, mask_frames=None, query_frame=None, mode=None, selector=None): 300 | ''' 301 | memory_frames: bs, T, 3, 384, 384 302 | mask_frames: bs, T, 3, 384, 384 303 | query_frame: bs, 3, 384, 384 304 | ''' 305 | if mode == 'extract_feat_w_mask': 306 | frames = kwargs['frames'] #B, C=3, H, W 307 | mask_frames = kwargs['mask_frames'] #B, 1, H, W 308 | layer_index = kwargs['layer_index'] 309 | use_window_partition = kwargs['use_window_partition'] 310 | use_token_learner = kwargs['use_token_learner'] 311 | # local attention for feature extraction 312 | x = self.patch_embed(frames) # bs*T, (H//16 * W//16), 768 313 | B, _, H, W = frames.shape 314 | _, _, C = x.shape 315 | mask_tokens = self.mask_patch_embed(mask_frames) 316 | x = x + self.pos_embed_new[:, 1:, :] + mask_tokens 317 | x = self.pos_drop(x) #bs, (T+1)*hw, C 318 | 
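            # Note: only a single frame is embedded in this branch, so x has shape
            # (bs, H/16 * W/16, C) at this point: patch tokens plus the positional
            # embedding (cls slot dropped), with the mask patch embedding added on top
            # as a target-aware cue.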
319 | bs = frames.shape[0] 320 | 321 | # if use_window_partition, perform the window_partition 322 | if use_window_partition: 323 | # local-in-local attention 324 | H = frames.shape[-2] // 16 325 | W = frames.shape[-1] // 16 326 | dim = x.shape[-1] 327 | x = x.view(bs, H, W, dim) 328 | window_size = H // 2 329 | x, pad_hw = window_partition(x, window_size) # x: bs*N, window_size, window_size, c 330 | x = x.view(-1, window_size*window_size, dim) 331 | # token interaction in early layers of ViT 332 | for blk in self.blocks[0:layer_index]: 333 | x = blk(x) 334 | 335 | # if use window_partition 336 | if use_window_partition: 337 | # local-in-local attention recover 338 | x = window_unpartition(x, window_size, pad_hw, (H, W)) #x: bs, H, W, c 339 | if use_token_learner: 340 | return x.permute(0, 3, 1, 2) #.view(B, C, int(H//16), int(W//16)) 341 | else: 342 | return x.view(bs, -1, dim) #(b, hw, c) 343 | else: 344 | if use_token_learner: 345 | return x.view(bs, H//16, W//16, x.shape[-1]).permute(0, 3, 1, 2) # (b, c, h, w) 346 | else: 347 | return x # (b, hw, c) 348 | elif mode == 'backbone_full': 349 | memory_frames = kwargs['memory_frames'] 350 | mask_frames = kwargs['mask_frames'] 351 | query_frame = kwargs['query_frame'] 352 | B, T, C, H, W = memory_frames.shape 353 | 354 | memory_frames = memory_frames.flatten(0, 1) 355 | mask_frames = mask_frames.flatten(0, 1) 356 | memory_tokens = self.patch_embed(memory_frames) # bs*T, (H//16 * W//16), 768 357 | mask_tokens = self.mask_patch_embed(mask_frames) # bs*T, (H//16 * W//16), 768 358 | # add the target-aware positional encoding 359 | memory_tokens = memory_tokens + mask_tokens 360 | query_tokens = self.patch_embed(query_frame) # bs, (H//16 * W//16), 768 361 | 362 | if T > 1: # multiple memory frames 363 | memory_tokens = memory_tokens.view(B, T, -1, memory_tokens.size()[-1]).contiguous() #bs ,T, num, C 364 | # use all the memory frames 365 | memory_tokens = memory_tokens.flatten(1, 2) # bs ,total_num, C 366 | 367 | x = torch.cat((memory_tokens, query_tokens), dim=1) 368 | if T > 1: 369 | single_size = int((self.pos_embed_new[:, 1:, :].shape[1])) 370 | x = x + self.pos_embed_new[:, 1:(single_size+1), :].repeat(1, T+1, 1) 371 | else: 372 | x = x + self.pos_embed_new[:, 1:, :].repeat(1, 2, 1) # 2 frames 373 | x = self.pos_drop(x) 374 | for blk in self.blocks: 375 | x = blk(x) 376 | 377 | # maybe we need the norm(x), improves the results! 
378 | x = self.norm(x) 379 | 380 | num_query_tokens = query_tokens.shape[1] 381 | updated_query_tokens = x[:, -num_query_tokens:, :] 382 | updated_query_tokens = updated_query_tokens.permute(0, 2, 1).contiguous().view(B, self.vit_dim, int(H//16), int(W//16)) 383 | m16 = updated_query_tokens # bs, 768, 24, 24 384 | m8 = self.fpn2(updated_query_tokens) # bs, 512, 48, 48 385 | m4 = self.fpn1(updated_query_tokens) # bs, 256, 96, 96 386 | 387 | return m16, m8, m4 388 | elif mode == 'extract_feat_wo_mask': 389 | frames = kwargs['frames'] 390 | layer_index = kwargs['layer_index'] 391 | use_window_partition = kwargs['use_window_partition'] 392 | x = self.patch_embed(frames) 393 | x = x + self.pos_embed_new[:, 1:, :] 394 | x = self.pos_drop(x) 395 | 396 | if use_window_partition: 397 | # local-in-local attention 398 | H = frames.shape[-2] // 16 399 | W = frames.shape[-1] // 16 400 | bs = frames.shape[0] 401 | dim = x.shape[-1] 402 | x = x.view(bs, H, W, dim) 403 | window_size = H // 2 404 | x, pad_hw = window_partition(x, window_size) # x: bs*N, window_size, window_size, c 405 | x = x.view(-1, window_size*window_size, dim) 406 | for blk in self.blocks[0:layer_index]: 407 | x = blk(x) 408 | 409 | # local-in-local attention recover 410 | if use_window_partition: 411 | x = window_unpartition(x, window_size, pad_hw, (H, W)) #x: bs, H, W, c 412 | x = x.view(bs, -1, dim) 413 | return x # bs, hw, c 414 | else: 415 | return x 416 | elif mode == 'forward_together': 417 | template = kwargs['template'] 418 | search = kwargs['search'] 419 | layer_index = kwargs['layer_index'] 420 | H = kwargs['H'] 421 | W = kwargs['W'] 422 | L = kwargs['L'] 423 | x = torch.cat((template, search), dim=1) 424 | bs, _, _ = x.shape 425 | for blk in self.blocks[layer_index:]: 426 | x = blk(x) 427 | # do the normalization for the output 428 | x = self.norm(x) 429 | updated_query_tokens = x[:, (-L):, :] 430 | updated_query_tokens = updated_query_tokens.permute(0, 2, 1).contiguous().view(bs, self.vit_dim, int(H//16), int(W//16)) 431 | m16 = updated_query_tokens # bs, 768, 24, 24 432 | m8 = self.fpn2(updated_query_tokens) # bs, 512, 48, 48 433 | m4 = self.fpn1(updated_query_tokens) # bs, 256, 96, 96 434 | return m16, m8, m4 #, att_list 435 | elif mode == 'extract_feat_in_later_layer': 436 | x = kwargs['x'] 437 | H = kwargs['H'] 438 | W = kwargs['W'] 439 | L = kwargs['L'] 440 | bs, _, _ = x.shape 441 | 442 | iden_embed = torch.cat((self.pos_iden.repeat(1, self.num_bases_foreground, 1), self.neg_iden.repeat(1, self.num_bases_background, 1)), dim=1).repeat(bs, 1, 1) 443 | x[:, 0:L, :] = x[:, 0:L, :] + iden_embed 444 | layer_index = kwargs['layer_index'] 445 | for blk in self.blocks[layer_index:]: 446 | x = blk(x) 447 | 448 | updated_query_tokens = x[:, L:, :] 449 | updated_query_tokens = updated_query_tokens.permute(0, 2, 1).contiguous().view(bs, self.vit_dim, int(H//16), int(W//16)) 450 | m16 = updated_query_tokens # bs, 768, 24, 24 451 | m8 = self.fpn2(updated_query_tokens) # bs, 512, 48, 48 452 | m4 = self.fpn1(updated_query_tokens) # bs, 256, 96, 96 453 | return m16, m8, m4 #, att_list 454 | elif mode == 'extract_feat_in_later_layer_test': # for inference 455 | x = kwargs['x'] 456 | H = kwargs['H'] 457 | W = kwargs['W'] 458 | L = kwargs['L'] 459 | # att_list = [] 460 | bs, _, _ = x.shape 461 | 462 | iden_embed = torch.cat((self.pos_iden.repeat(1, self.num_bases_foreground, 1), self.neg_iden.repeat(1, self.num_bases_background, 1)), dim=1).repeat(bs, 1, 1) 463 | x[:, 0:L, :] = x[:, 0:L, :] + iden_embed 464 | x[:, L:(2*L), :] = x[:, 
L:(2*L), :] + iden_embed 465 | layer_index = kwargs['layer_index'] 466 | for blk in self.blocks[layer_index:]: 467 | x, attn = blk(x) 468 | # att_list.append(attn) 469 | updated_query_tokens = x[:, (2*L):, :] 470 | updated_query_tokens = updated_query_tokens.permute(0, 2, 1).contiguous().view(bs, self.vit_dim, int(H//16), int(W//16)) 471 | m16 = updated_query_tokens # bs, 768, 24, 24 472 | m8 = self.fpn2(updated_query_tokens) # bs, 512, 48, 48 473 | m4 = self.fpn1(updated_query_tokens) # bs, 256, 96, 96 474 | return m16, m8, m4 #, att_list 475 | elif mode == 'extract_feat_in_later_layers_w_memory_bank': 476 | x1 = kwargs['x1'] # first frame 1, 512, c 477 | x2 = kwargs['x2'] # pos mem bank 1, 2048, c 478 | x3 = kwargs['x3'] # neg mem bank 1, 2048, c 479 | x4 = kwargs['x4'] # context features 480 | H = kwargs['H'] 481 | W = kwargs['W'] 482 | L = kwargs['L'] 483 | att_list = [] 484 | bs, num_pos_mem, _ = x2.shape 485 | bs, num_neg_mem, _ = x3.shape 486 | 487 | x2 = x2 + self.pos_iden.repeat(1, num_pos_mem, 1) 488 | x3 = x3 + self.neg_iden.repeat(1, num_neg_mem, 1) 489 | if x1 is not None: 490 | x1 = x1 + torch.cat((self.pos_iden.repeat(1, self.num_bases_foreground, 1), self.neg_iden.repeat(1, self.num_bases_background, 1)), dim=1) 491 | x = torch.cat((x1, x2, x3, x4), dim=1) 492 | else: 493 | x = torch.cat((x2, x3, x4), dim=1) 494 | layer_index = kwargs['layer_index'] 495 | for blk in self.blocks[layer_index:]: 496 | x, attn = blk(x) 497 | if x1 is not None: 498 | att_list.append(torch.max(torch.mean(attn,dim=1)[:,-(x4.shape[-2]):][:,:,x1.shape[1]:(x1.shape[1]+num_pos_mem+num_neg_mem)],dim=1).values) 499 | else: 500 | att_list.append(torch.max(torch.mean(attn,dim=1)[:,-(x4.shape[-2]):][:,:,0:(num_pos_mem+num_neg_mem)],dim=1).values) 501 | updated_query_tokens = x[:, -(x4.shape[-2]):, :] 502 | updated_query_tokens = updated_query_tokens.permute(0, 2, 1).contiguous().view(bs, self.vit_dim, int(H//16), int(W//16)) 503 | m16 = updated_query_tokens # bs, 768, 24, 24 504 | m8 = self.fpn2(updated_query_tokens) # bs, 512, 48, 48 505 | m4 = self.fpn1(updated_query_tokens) # bs, 256, 96, 96 506 | return m16, m8, m4, att_list 507 | elif mode == 'tokenlearner_w_masks': 508 | # qk16 = kwargs['qk16'] #bs, c, h, w 509 | # qk16 = qk16.permute(0, 2, 3, 1) #bs, h, w, c 510 | # mask = kwargs['mask'] #bs, 2, h, w 511 | # qk16 = qk16.permute(0, 2, 3, 1) 512 | 513 | qk16 = kwargs['qk16'] #bs, c, h, w 514 | mask = kwargs['mask'] #bs, 2, h, w 515 | # qk16 = torch.cat((qk16, mask), dim=1) #bs, c+2, h, w 516 | qk16 = qk16.permute(0, 2, 3, 1) #bs, h, w, c 517 | 518 | # rf_tokens = self.tlearner(qk16, mask) 519 | fore_tokens, fore_att = self.tlearner(qk16, mask[:, 0].unsqueeze(1)) 520 | back_tokens, back_att = self.tlearner_back(qk16, mask[:, 1].unsqueeze(1)) 521 | # fore_tokens = self.tlearner(qk16, mask[:, 0].unsqueeze(1)) 522 | # back_tokens = self.tlearner_back(qk16, mask[:, 1].unsqueeze(1)) 523 | rf_tokens = torch.cat((fore_tokens, back_tokens), dim=1) #bs, num_token, c 524 | return rf_tokens, fore_att, back_att, mask.shape[-2], mask.shape[-1] 525 | elif mode == 'segmentation': 526 | # print('decoder for segmentation') 527 | m16 = kwargs['m16'] 528 | m8 = kwargs['m8'] 529 | m4 = kwargs['m4'] 530 | selector = kwargs['selector'] 531 | 532 | # m16=m16, m8=m8, m4=m4, selector=selector 533 | if self.single_object: 534 | logits = self.decoder(m16, m8, m4) 535 | prob = torch.sigmoid(logits) 536 | else: 537 | 538 | #self.memory.readout(affinity, mv16[:,0], qv16): 4, 1024, 24, 24 539 | # qf8: 4, 512, 48, 48; qf4: 4, 256, 
96, 96; 540 | logits = torch.cat([ 541 | self.stcn_decoder(m16[:,0], m8[:,0], m4[:,0]), 542 | self.stcn_decoder(m16[:,1], m8[:,1], m4[:,1]), 543 | ], 1) # 4, 2, 384, 384 544 | 545 | prob = torch.sigmoid(logits) # 4, 2, 384, 384; 2: two targets 546 | prob = prob * selector.unsqueeze(2).unsqueeze(2) # 4, 2, 384, 384 547 | 548 | logits = self.aggregate(prob) 549 | prob = F.softmax(logits, dim=1)[:, 1:] 550 | return logits, prob, F.softmax(logits, dim=1) # for memorize 551 | elif mode == 'segmentation_single_onject': 552 | m16 = kwargs['m16'] 553 | m8 = kwargs['m8'] 554 | m4 = kwargs['m4'] 555 | logits = self.stcn_decoder(m16, m8, m4) 556 | return torch.sigmoid(logits) 557 | 558 | def forward_patch_embedding(self, x): 559 | return self.patch_embed(x) 560 | 561 | def forward_features_testing(self, x, z): 562 | B = x.shape[0] 563 | # cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 564 | # x = torch.cat((cls_tokens, x, z), dim=1) 565 | x = torch.cat((x, z), dim=1) 566 | x = x + self.pos_embed 567 | x = self.pos_drop(x) 568 | 569 | for blk in self.blocks: 570 | x = blk(x) 571 | 572 | # if self.global_pool: # mae use the global_pool instead of cls token for classficaition 573 | # x = x[:, 1:, :].mean(dim=1) # global pool without cls token 574 | # outcome = self.fc_norm(x) 575 | # else: 576 | # x = self.norm(x) 577 | # outcome = x[:, 0] 578 | 579 | search_tokens = z.shape[1] 580 | 581 | return x[:, -search_tokens:, :] 582 | 583 | 584 | def vit_base_patch16(**kwargs): 585 | model = VisionTransformer( 586 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 587 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 588 | return model 589 | 590 | 591 | def vit_large_patch16(**kwargs): 592 | model = VisionTransformer( 593 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 594 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 595 | return model 596 | 597 | 598 | def vit_huge_patch14(**kwargs): 599 | model = VisionTransformer( 600 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 601 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 602 | return model 603 | 604 | 605 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | modules.py - This file stores the rathering boring network blocks. 
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | 10 | from model import mod_resnet 11 | from model import cbam 12 | 13 | 14 | class ResBlock(nn.Module): 15 | def __init__(self, indim, outdim=None): 16 | super(ResBlock, self).__init__() 17 | if outdim == None: 18 | outdim = indim 19 | if indim == outdim: 20 | self.downsample = None 21 | else: 22 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) 23 | 24 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) 25 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) 26 | 27 | def forward(self, x): 28 | r = self.conv1(F.relu(x)) 29 | r = self.conv2(F.relu(r)) 30 | 31 | if self.downsample is not None: 32 | x = self.downsample(x) 33 | 34 | return x + r 35 | 36 | 37 | class FeatureFusionBlock(nn.Module): 38 | def __init__(self, indim, outdim): 39 | super().__init__() 40 | 41 | self.block1 = ResBlock(indim, outdim) 42 | self.attention = cbam.CBAM(outdim) 43 | self.block2 = ResBlock(outdim, outdim) 44 | 45 | def forward(self, x, f16): 46 | x = torch.cat([x, f16], 1) 47 | x = self.block1(x) 48 | r = self.attention(x) 49 | x = self.block2(x + r) 50 | 51 | return x 52 | 53 | 54 | # Single object version, used only in static image pretraining 55 | # This will be loaded and modified into the multiple objects version later (in stage 1/2/3) 56 | # See model.py (load_network) for the modification procedure 57 | class ValueEncoderSO(nn.Module): 58 | def __init__(self): 59 | super().__init__() 60 | 61 | resnet = mod_resnet.resnet18(pretrained=True, extra_chan=1) 62 | self.conv1 = resnet.conv1 63 | self.bn1 = resnet.bn1 64 | self.relu = resnet.relu # 1/2, 64 65 | self.maxpool = resnet.maxpool 66 | 67 | self.layer1 = resnet.layer1 # 1/4, 64 68 | self.layer2 = resnet.layer2 # 1/8, 128 69 | self.layer3 = resnet.layer3 # 1/16, 256 70 | 71 | self.fuser = FeatureFusionBlock(1024 + 256, 512) 72 | 73 | def forward(self, image, key_f16, mask): 74 | # key_f16 is the feature from the key encoder 75 | 76 | f = torch.cat([image, mask], 1) 77 | 78 | x = self.conv1(f) 79 | x = self.bn1(x) 80 | x = self.relu(x) # 1/2, 64 81 | x = self.maxpool(x) # 1/4, 64 82 | x = self.layer1(x) # 1/4, 64 83 | x = self.layer2(x) # 1/8, 128 84 | x = self.layer3(x) # 1/16, 256 85 | 86 | x = self.fuser(x, key_f16) 87 | 88 | return x 89 | 90 | 91 | # Multiple objects version, used in other times 92 | class ValueEncoder(nn.Module): 93 | def __init__(self): 94 | super().__init__() 95 | 96 | resnet = mod_resnet.resnet18(pretrained=True, extra_chan=2) 97 | self.conv1 = resnet.conv1 98 | self.bn1 = resnet.bn1 99 | self.relu = resnet.relu # 1/2, 64 100 | self.maxpool = resnet.maxpool 101 | 102 | self.layer1 = resnet.layer1 # 1/4, 64 103 | self.layer2 = resnet.layer2 # 1/8, 128 104 | self.layer3 = resnet.layer3 # 1/16, 256 105 | 106 | self.fuser = FeatureFusionBlock(1024 + 256, 512) 107 | 108 | def forward(self, image, key_f16, mask, other_masks): 109 | # key_f16 is the feature from the key encoder 110 | 111 | f = torch.cat([image, mask, other_masks], 1) 112 | 113 | x = self.conv1(f) 114 | x = self.bn1(x) 115 | x = self.relu(x) # 1/2, 64 116 | x = self.maxpool(x) # 1/4, 64 117 | x = self.layer1(x) # 1/4, 64 118 | x = self.layer2(x) # 1/8, 128 119 | x = self.layer3(x) # 1/16, 256 120 | 121 | x = self.fuser(x, key_f16) 122 | 123 | return x 124 | 125 | 126 | class KeyEncoder(nn.Module): 127 | def __init__(self): 128 | super().__init__() 129 | resnet = models.resnet50(pretrained=True) 
130 | self.conv1 = resnet.conv1 131 | self.bn1 = resnet.bn1 132 | self.relu = resnet.relu # 1/2, 64 133 | self.maxpool = resnet.maxpool 134 | 135 | self.res2 = resnet.layer1 # 1/4, 256 136 | self.layer2 = resnet.layer2 # 1/8, 512 137 | self.layer3 = resnet.layer3 # 1/16, 1024 138 | 139 | def forward(self, f): 140 | x = self.conv1(f) 141 | x = self.bn1(x) 142 | x = self.relu(x) # 1/2, 64 143 | x = self.maxpool(x) # 1/4, 64 144 | f4 = self.res2(x) # 1/4, 256 145 | f8 = self.layer2(f4) # 1/8, 512 146 | f16 = self.layer3(f8) # 1/16, 1024 147 | 148 | return f16, f8, f4 149 | 150 | 151 | class UpsampleBlock(nn.Module): 152 | def __init__(self, skip_c, up_c, out_c, scale_factor=2): 153 | super().__init__() 154 | self.skip_conv = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1) 155 | self.out_conv = ResBlock(up_c, out_c) 156 | self.scale_factor = scale_factor 157 | 158 | def forward(self, skip_f, up_f): 159 | x = self.skip_conv(skip_f) 160 | x = x + F.interpolate(up_f, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 161 | x = self.out_conv(x) 162 | return x 163 | 164 | 165 | class KeyProjection(nn.Module): 166 | def __init__(self, indim, keydim): 167 | super().__init__() 168 | self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1) 169 | 170 | nn.init.orthogonal_(self.key_proj.weight.data) 171 | nn.init.zeros_(self.key_proj.bias.data) 172 | 173 | def forward(self, x): 174 | return self.key_proj(x) 175 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | network.py - The core of the neural network 3 | Defines the structure and memory operations 4 | Modifed from STM: https://github.com/seoungwugoh/STM 5 | 6 | The trailing number of a variable usually denote the stride 7 | e.g. 
f16 -> encoded features with stride 16 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from model.modules import * 17 | 18 | 19 | # from model import models_vit 20 | # import timm 21 | # assert timm.__version__ == "0.3.2" # version check 22 | 23 | 24 | class Decoder(nn.Module): 25 | def __init__(self, vit_dim): 26 | super().__init__() 27 | # self.compress = ResBlock(1024, 512) 28 | self.compress = ResBlock(vit_dim, 512) 29 | self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 30 | self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 31 | 32 | self.pred = nn.Conv2d(256, 1, kernel_size=(3,3), padding=(1,1), stride=1) 33 | 34 | def forward(self, f16, f8, f4): 35 | x = self.compress(f16) 36 | x = self.up_16_8(f8, x) 37 | x = self.up_8_4(f4, x) 38 | 39 | x = self.pred(F.relu(x)) 40 | 41 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False) 42 | return x 43 | 44 | 45 | class MemoryReader(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | 49 | def get_affinity(self, mk, qk): 50 | B, CK, T, H, W = mk.shape 51 | mk = mk.flatten(start_dim=2) 52 | qk = qk.flatten(start_dim=2) 53 | 54 | # See supplementary material 55 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 56 | ab = mk.transpose(1, 2) @ qk 57 | 58 | affinity = (2*ab-a_sq) / math.sqrt(CK) # B, THW, HW 59 | 60 | # softmax operation; aligned the evaluation style 61 | maxes = torch.max(affinity, dim=1, keepdim=True)[0] 62 | x_exp = torch.exp(affinity - maxes) 63 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 64 | affinity = x_exp / x_exp_sum 65 | 66 | return affinity 67 | 68 | def readout(self, affinity, mv, qv): 69 | B, CV, T, H, W = mv.shape 70 | 71 | mo = mv.view(B, CV, T*H*W) 72 | mem = torch.bmm(mo, affinity) # Weighted-sum B, CV, HW 73 | mem = mem.view(B, CV, H, W) 74 | 75 | mem_out = torch.cat([mem, qv], dim=1) 76 | 77 | return mem_out 78 | 79 | 80 | class STCN(nn.Module): 81 | def __init__(self, single_object): 82 | super().__init__() 83 | self.single_object = single_object 84 | 85 | self.key_encoder = KeyEncoder() # ResNet50--layer3 86 | if single_object: 87 | self.value_encoder = ValueEncoderSO() 88 | else: 89 | self.value_encoder = ValueEncoder() # ResNet18--layer3 90 | 91 | # Projection from f16 feature space to key space 92 | self.key_proj = KeyProjection(1024, keydim=64) # one conv 93 | 94 | # Compress f16 a bit to use in decoding later on 95 | self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1) 96 | 97 | self.memory = MemoryReader() 98 | self.decoder = Decoder() 99 | 100 | def aggregate(self, prob): 101 | new_prob = torch.cat([ 102 | torch.prod(1-prob, dim=1, keepdim=True), # get the background region 103 | prob 104 | ], 1).clamp(1e-7, 1-1e-7) 105 | logits = torch.log((new_prob /(1-new_prob))) # bs, 3, 384, 384 106 | return logits 107 | 108 | def encode_key(self, frame): 109 | # input: b*t*c*h*w 110 | b, t = frame.shape[:2] 111 | 112 | f16, f8, f4 = self.key_encoder(frame.flatten(start_dim=0, end_dim=1)) 113 | k16 = self.key_proj(f16) 114 | f16_thin = self.key_comp(f16) 115 | 116 | # B*C*T*H*W 117 | k16 = k16.view(b, t, *k16.shape[-3:]).transpose(1, 2).contiguous() 118 | 119 | # B*T*C*H*W 120 | f16_thin = f16_thin.view(b, t, *f16_thin.shape[-3:]) 121 | f16 = f16.view(b, t, *f16.shape[-3:]) 122 | f8 = f8.view(b, t, *f8.shape[-3:]) 123 | f4 = f4.view(b, t, *f4.shape[-3:]) 124 | 125 | return k16, f16_thin, f16, f8, f4 126 | 127 | def encode_value(self, frame, kf16, mask, other_mask=None): 128 | # Extract 
memory key/value for a frame 129 | if self.single_object: 130 | f16 = self.value_encoder(frame, kf16, mask) 131 | else: 132 | f16 = self.value_encoder(frame, kf16, mask, other_mask) 133 | return f16.unsqueeze(2) # B*512*T*H*W 134 | 135 | def segment(self, qk16, qv16, qf8, qf4, mk16, mv16, selector=None): 136 | # q - query, m - memory 137 | # qv16 is f16_thin above 138 | # segment a specific frame 139 | affinity = self.memory.get_affinity(mk16, qk16) # mk16: bs, 64, T, 24, 24; qk16: bs, 64, 24, 24; affinity: bs, THW, THW 140 | 141 | if self.single_object: 142 | logits = self.decoder(self.memory.readout(affinity, mv16, qv16), qf8, qf4) 143 | prob = torch.sigmoid(logits) 144 | else: 145 | #self.memory.readout(affinity, mv16[:,0], qv16): 4, 1024, 24, 24 146 | # qf8: 4, 512, 48, 48; qf4: 4, 256, 96, 96; 147 | logits = torch.cat([ 148 | self.decoder(self.memory.readout(affinity, mv16[:,0], qv16), qf8, qf4), 149 | self.decoder(self.memory.readout(affinity, mv16[:,1], qv16), qf8, qf4), 150 | ], 1) # 4, 2, 384, 384 151 | 152 | prob = torch.sigmoid(logits) # 4, 2, 384, 384; 2: two targets 153 | prob = prob * selector.unsqueeze(2).unsqueeze(2) # 4, 2, 384, 384 154 | 155 | logits = self.aggregate(prob) 156 | prob = F.softmax(logits, dim=1)[:, 1:] 157 | 158 | return logits, prob 159 | 160 | def forward(self, mode, *args, **kwargs): 161 | if mode == 'encode_key': 162 | return self.encode_key(*args, **kwargs) 163 | elif mode == 'encode_value': 164 | return self.encode_value(*args, **kwargs) 165 | elif mode == 'segment': 166 | return self.segment(*args, **kwargs) 167 | else: 168 | raise NotImplementedError 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /pretrained_models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/pretrained_models/.DS_Store -------------------------------------------------------------------------------- /scripts/00180.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/00180.jpg -------------------------------------------------------------------------------- /scripts/00180_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/00180_original.jpg -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/resize_youtube.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/__pycache__/resize_youtube.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/resize_youtube.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/scripts/__pycache__/resize_youtube.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/resize_length.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import cv2 4 | 5 | from progressbar import progressbar 6 | 7 | input_dir = sys.argv[1] 8 | output_dir = sys.argv[2] 9 | 10 | # max_length = 500 11 | min_length = 384 12 | 13 | def process_fun(): 14 | 15 | for f in progressbar(os.listdir(input_dir)): 16 | img = cv2.imread(os.path.join(input_dir, f)) 17 | h, w, _ = img.shape 18 | 19 | # scale = max(h, w) / max_length 20 | scale = min(h, w) / min_length 21 | 22 | img = cv2.resize(img, (int(w/scale), int(h/scale)), interpolation=cv2.INTER_AREA) 23 | cv2.imwrite(os.path.join(output_dir, os.path.basename(f)), img) 24 | 25 | if __name__ == '__main__': 26 | 27 | os.makedirs(output_dir, exist_ok=True) 28 | process_fun() 29 | 30 | print('All done.') -------------------------------------------------------------------------------- /scripts/resize_youtube.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import path 4 | 5 | from PIL import Image 6 | import numpy as np 7 | from progressbar import progressbar 8 | from multiprocessing import Pool 9 | 10 | from PIL import ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | new_min_size = 480 14 | 15 | def resize_vid_jpeg(inputs): 16 | vid_name, folder_path, out_path = inputs 17 | 18 | vid_path = path.join(folder_path, vid_name) 19 | vid_out_path = path.join(out_path, 'JPEGImages', vid_name) 20 | os.makedirs(vid_out_path, exist_ok=True) 21 | 22 | for im_name in os.listdir(vid_path): 23 | hr_im = Image.open(path.join(vid_path, im_name)) 24 | w, h = hr_im.size 25 | 26 | ratio = new_min_size / min(w, h) 27 | 28 | # print(os.path.join(vid_path, im_name)) 29 | # try: 30 | # lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC) 31 | # lr_im.save(path.join(vid_out_path, im_name)) 32 | # except Exception as e: 33 | # print('exception!!!') 34 | # print(os.path.join(vid_path, im_name)) 35 | 36 | lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC) 37 | lr_im.save(path.join(vid_out_path, im_name)) 38 | 39 | 40 | def resize_vid_anno(inputs): 41 | vid_name, folder_path, out_path = inputs 42 | 43 | vid_path = path.join(folder_path, vid_name) 44 | vid_out_path = path.join(out_path, 'Annotations', vid_name) 45 | os.makedirs(vid_out_path, exist_ok=True) 46 | 47 | for im_name in os.listdir(vid_path): 48 | hr_im = Image.open(path.join(vid_path, im_name)).convert('P') 49 | w, h = hr_im.size 50 | 51 | ratio = new_min_size / min(w, h) 52 | 53 | lr_im = 
hr_im.resize((int(w*ratio), int(h*ratio)), Image.NEAREST) 54 | lr_im.save(path.join(vid_out_path, im_name)) 55 | 56 | 57 | def resize_all(in_path, out_path): 58 | for folder in os.listdir(in_path): 59 | 60 | if folder not in ['JPEGImages', 'Annotations']: 61 | continue 62 | folder_path = path.join(in_path, folder) 63 | videos = os.listdir(folder_path) 64 | 65 | videos = [(v, folder_path, out_path) for v in videos] 66 | 67 | if folder == 'JPEGImages': 68 | print('Processing images') 69 | os.makedirs(path.join(out_path, 'JPEGImages'), exist_ok=True) 70 | 71 | pool = Pool(processes=8) 72 | # pool = Pool(processes=1) 73 | for _ in progressbar(pool.imap_unordered(resize_vid_jpeg, videos), max_value=len(videos)): 74 | pass 75 | else: 76 | print('Processing annotations') 77 | os.makedirs(path.join(out_path, 'Annotations'), exist_ok=True) 78 | 79 | pool = Pool(processes=8) 80 | for _ in progressbar(pool.imap_unordered(resize_vid_anno, videos), max_value=len(videos)): 81 | pass 82 | 83 | 84 | if __name__ == '__main__': 85 | in_path = sys.argv[1] 86 | out_path = sys.argv[2] 87 | 88 | resize_all(in_path, out_path) 89 | 90 | print('Done.') -------------------------------------------------------------------------------- /submit_eval_davis_ours_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import time 4 | from argparse import ArgumentParser 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | from PIL import Image 11 | 12 | from model.eval_network import STCN, ViT_STCN 13 | from dataset.davis_test_dataset import DAVISTestDataset 14 | from util.tensor_util import unpad 15 | from inference_core import InferenceCore_ViT 16 | 17 | from progressbar import progressbar 18 | import cv2 19 | 20 | 21 | from model import models_vit 22 | import timm 23 | assert timm.__version__ == "0.3.2" # version check 24 | 25 | 26 | def interpolate_pos_embed(pos_embed, search_size): 27 | 28 | num_extra_tokens = 1 29 | # pos_embed = net.pos_embed 30 | model_pos_tokens = pos_embed[:, num_extra_tokens:, :] # bs, N, C 31 | model_token_size = int(model_pos_tokens.shape[1]**0.5) 32 | extra_pos_tokens = pos_embed[:, :num_extra_tokens] 33 | 34 | embedding_size = extra_pos_tokens.shape[-1] 35 | 36 | if search_size != model_token_size: # do interpolation 37 | model_pos_tokens_temp = model_pos_tokens.reshape(-1, model_token_size, model_token_size, embedding_size).contiguous().permute(0, 3, 1, 2) # bs, c, h, w 38 | search_pos_tokens = torch.nn.functional.interpolate( 39 | model_pos_tokens_temp, size=(search_size, search_size), mode='bicubic', align_corners=False) 40 | search_pos_tokens = search_pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2) 41 | else: 42 | search_pos_tokens = model_pos_tokens 43 | new_pos_embed = torch.cat((extra_pos_tokens, search_pos_tokens), dim=1) 44 | return new_pos_embed 45 | 46 | 47 | """ 48 | Arguments loading 49 | """ 50 | parser = ArgumentParser() 51 | parser.add_argument('--model_path', default='./') 52 | parser.add_argument('--davis_path', default='/apdcephfs/share_1290939/qiangqwu/VOS/DAVIS/2017') 53 | parser.add_argument('--output', default='/apdcephfs/share_1290939/qiangqwu/STCN_clustering_iccv19_evaluation') 54 | parser.add_argument('--split', help='val/testdev', default='val') 55 | parser.add_argument('--top', type=int, default=20) 56 | parser.add_argument('--amp', action='store_true') 57 | parser.add_argument('--num_bases_foreground', 
type=int, default=384) 58 | parser.add_argument('--num_bases_background', type=int, default=384) 59 | parser.add_argument('--layer_index', type=int, default=4) 60 | parser.add_argument('--use_local_window', action='store_true') 61 | parser.add_argument('--use_token_learner', action='store_true') 62 | parser.add_argument('--backbone_type', default='vit_base') 63 | 64 | # python submit_eval_davis_ours_all.py --model_path /home/user/Project/STCN_updating_clustering/test_checkpoints --davis_path /home/user/Data/DAVIS/2017 --output ./results --split val --layer_index 4 --use_token_learner --use_local_window 65 | 66 | args = parser.parse_args() 67 | 68 | print(args) 69 | 70 | def overlay_davis(image,mask,colors=[255,0,0],cscale=2,alpha=0.4): 71 | """ Overlay segmentation on top of RGB image. from davis official""" 72 | # import skimage 73 | from scipy.ndimage.morphology import binary_erosion, binary_dilation 74 | 75 | colors = np.reshape(colors, (-1, 3)) 76 | colors = np.atleast_2d(colors) * cscale 77 | 78 | im_overlay = image.copy() 79 | object_ids = np.unique(mask) 80 | 81 | for object_id in object_ids[1:]: 82 | # Overlay color on binary mask 83 | foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) 84 | binary_mask = mask == object_id 85 | 86 | # Compose image 87 | im_overlay[binary_mask] = foreground[binary_mask] 88 | 89 | # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask 90 | contours = binary_dilation(binary_mask) ^ binary_mask 91 | # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask 92 | im_overlay[contours,:] = 0 93 | 94 | return im_overlay.astype(image.dtype) 95 | 96 | 97 | checkpoint_list = os.listdir(args.model_path) 98 | for model_name in checkpoint_list: 99 | out_path = os.path.join(args.output, model_name[0:-4]) 100 | davis_path = args.davis_path 101 | 102 | # Simple setup 103 | os.makedirs(out_path, exist_ok=True) 104 | palette = Image.open(path.expanduser(davis_path + '/trainval/Annotations/480p/blackswan/00000.png')).getpalette() 105 | 106 | torch.autograd.set_grad_enabled(False) 107 | 108 | # Setup Dataset 109 | if args.split == 'val': 110 | test_dataset = DAVISTestDataset(davis_path+'/trainval', imset='2017/val.txt') 111 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4) 112 | elif args.split == 'testdev': 113 | test_dataset = DAVISTestDataset(davis_path+'/test-dev', imset='2017/test-dev.txt') 114 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4) 115 | else: 116 | raise NotImplementedError 117 | 118 | # Load our checkpoint 119 | top_k = args.top 120 | 121 | 122 | print('layer index: %d' %(args.layer_index)) 123 | 124 | 125 | if args.backbone_type == 'vit_base': 126 | vit_model = models_vit.__dict__['vit_base_patch16']( 127 | num_classes=1000, 128 | drop_path_rate=0.0, 129 | global_pool=True, 130 | single_object = False, 131 | num_bases_foreground = args.num_bases_foreground, 132 | num_bases_background = args.num_bases_background, 133 | img_size = 384, 134 | vit_dim=768) 135 | pos_embed_new = interpolate_pos_embed(vit_model.pos_embed, int(384//16)) 136 | vit_model.pos_embed_new = torch.nn.Parameter(pos_embed_new, requires_grad=False) 137 | checkpoint = torch.load(os.path.join(args.model_path, model_name), map_location='cpu') 138 | print("Load well-trained checkpoint from %s" %(model_name)) 139 | state_dict = checkpoint 140 | elif args.backbone_type == 'vit_large': 141 | vit_model = 
models_vit.__dict__['vit_large_patch16']( 142 | num_classes=1000, 143 | drop_path_rate=0.0, 144 | global_pool=True, 145 | single_object = False, 146 | num_bases_foreground = args.num_bases_foreground, 147 | num_bases_background = args.num_bases_background, 148 | img_size = 384, 149 | vit_dim=1024) 150 | pos_embed_new = interpolate_pos_embed(vit_model.pos_embed, int(384//16)) 151 | vit_model.pos_embed_new = torch.nn.Parameter(pos_embed_new, requires_grad=False) 152 | checkpoint = torch.load(os.path.join(args.model_path, model_name), map_location='cpu') 153 | print("Load well-trained checkpoint from %s" %(checkpoint)) 154 | state_dict = checkpoint 155 | 156 | 157 | # interpolate position embedding 158 | # load pre-trained model 159 | msg = vit_model.load_state_dict(state_dict, strict=False) 160 | print(msg) 161 | vit_model = vit_model.cuda().eval() 162 | 163 | total_process_time = 0 164 | total_frames = 0 165 | 166 | pos_embed_new = vit_model.pos_embed_new.detach().clone() 167 | 168 | # Start eval 169 | for data in progressbar(test_loader, max_value=len(test_loader), redirect_stdout=True): 170 | 171 | with torch.cuda.amp.autocast(enabled=args.amp): # per sequence here 172 | rgb = data['rgb'].cuda() # 1, 69, 3, 480, 910 173 | msk = data['gt'][0].cuda() # 2, 69, 1, 480, 910 174 | info = data['info'] 175 | name = info['name'][0] 176 | k = len(info['labels'][0]) # num. of objects 177 | size = info['size_480p'] 178 | 179 | torch.cuda.synchronize() 180 | process_begin = time.time() 181 | 182 | print('before_interpolation:') 183 | print(pos_embed_new.shape) 184 | 185 | processor = InferenceCore_ViT(vit_model, rgb, k, pos_embed_new, video_name=name) 186 | processor.interact(msk[:,0], 0, rgb.shape[1], args.layer_index, args.use_local_window, args.use_token_learner) 187 | 188 | # Do unpad -> upsample to original size 189 | out_masks = torch.zeros((processor.t, 1, *size), dtype=torch.uint8, device='cuda') 190 | for ti in range(processor.t): 191 | prob = unpad(processor.prob[:,ti], processor.pad) 192 | prob = F.interpolate(prob, size, mode='bilinear', align_corners=False) 193 | out_masks[ti] = torch.argmax(prob, dim=0) 194 | 195 | out_masks = (out_masks.detach().cpu().numpy()[:,0]).astype(np.uint8) 196 | 197 | torch.cuda.synchronize() 198 | total_process_time += time.time() - process_begin 199 | total_frames += out_masks.shape[0] 200 | 201 | # Save the results 202 | this_out_path = path.join(out_path, name) 203 | os.makedirs(this_out_path, exist_ok=True) 204 | for f in range(out_masks.shape[0]): 205 | img_E = Image.fromarray(out_masks[f]) 206 | img_E.putpalette(palette) 207 | img_E.save(os.path.join(this_out_path, '{:05d}.png'.format(f))) 208 | 209 | # # Optional: save the overlay images 210 | # video_name = data['info']['name'][0] 211 | # base_path = os.path.join('./qualitative_results_final_results', video_name) 212 | # os.makedirs(base_path, exist_ok=True) 213 | # for ti in range(processor.t): 214 | # pF = cv2.imread(os.path.join('/apdcephfs/share_1290939/qiangqwu/VOS/DAVIS/2017/trainval/JPEGImages/480p', video_name, data['info']['frames'][ti][0])) 215 | # pF = cv2.cvtColor(pF, cv2.COLOR_BGR2RGB) 216 | # canvas = overlay_davis(pF, out_masks[ti], palette) 217 | # canvas = Image.fromarray(canvas) 218 | # canvas.save(os.path.join(base_path, data['info']['frames'][ti][0])) 219 | 220 | del rgb 221 | del msk 222 | del processor 223 | 224 | print('Total processing time: ', total_process_time) 225 | print('Total processed frames: ', total_frames) 226 | print('FPS: ', total_frames / total_process_time) 
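# -----------------------------------------------------------------------------------
# Optional sanity check (illustrative sketch, not part of the original pipeline): the
# loop above writes one palette PNG per frame to <output>/<checkpoint>/<sequence>/.
# `quick_jaccard` below is a hypothetical helper that estimates the mean Jaccard
# (region IoU) of those PNGs against the DAVIS ground-truth annotations, assuming the
# predictions and annotations share the same resolution and file names; use the
# official DAVIS evaluation toolkit for reported numbers.
def quick_jaccard(pred_root, gt_root):
    scores = []
    for seq in sorted(os.listdir(pred_root)):
        frames = sorted(os.listdir(os.path.join(pred_root, seq)))
        # roughly follow the official convention of skipping the first and last frames
        for frame in frames[1:-1]:
            pred = np.array(Image.open(os.path.join(pred_root, seq, frame)))
            gt = np.array(Image.open(os.path.join(gt_root, seq, frame)))
            for obj_id in np.unique(gt)[1:]:  # skip the background label 0
                p, g = pred == obj_id, gt == obj_id
                union = np.logical_or(p, g).sum()
                if union > 0:
                    scores.append(float(np.logical_and(p, g).sum()) / float(union))
    return float(np.mean(scores))

# Example call (hypothetical paths, matching the layout produced by the loop above):
# print(quick_jaccard(out_path, davis_path + '/trainval/Annotations/480p'))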
-------------------------------------------------------------------------------- /train_simvos.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from os import path 3 | import math 4 | 5 | import random 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader, ConcatDataset 9 | import torch.distributed as distributed 10 | 11 | from model.model import ViTSTCNModel 12 | from dataset.static_dataset import StaticTransformDataset 13 | from dataset.vos_dataset import VOSDataset 14 | 15 | from util.logger import TensorboardLogger 16 | from util.hyper_para import HyperParameters 17 | from util.load_subset import load_sub_davis, load_sub_yv 18 | 19 | 20 | """ 21 | Initial setup 22 | """ 23 | # Init distributed environment 24 | distributed.init_process_group(backend="nccl") 25 | # Set seed to ensure the same initialization 26 | torch.manual_seed(14159265) 27 | np.random.seed(14159265) 28 | random.seed(14159265) 29 | 30 | print('CUDA Device count: ', torch.cuda.device_count()) 31 | 32 | # Parse command line arguments 33 | para = HyperParameters() 34 | para.parse() 35 | 36 | if para['benchmark']: 37 | torch.backends.cudnn.benchmark = True 38 | 39 | local_rank = torch.distributed.get_rank() 40 | world_size = torch.distributed.get_world_size() 41 | torch.cuda.set_device(local_rank) 42 | 43 | print('I am rank %d in this world of size %d!' % (local_rank, world_size)) 44 | 45 | """ 46 | Model related 47 | """ 48 | if local_rank == 0: 49 | # Logging 50 | if para['id'].lower() != 'null': 51 | print('I will take the role of logging!') 52 | long_id = '%s_%s' % (datetime.datetime.now().strftime('%b%d_%H.%M.%S'), para['id']) 53 | else: 54 | long_id = None 55 | logger = TensorboardLogger(para['id'], long_id) 56 | logger.log_string('hyperpara', str(para)) 57 | 58 | # Construct the rank 0 model 59 | model = ViTSTCNModel(para, logger=logger, 60 | save_path=path.join('./saved_checkpoints', long_id, long_id) if long_id is not None else None, 61 | local_rank=local_rank, world_size=world_size).train() 62 | else: 63 | # Construct model for other ranks 64 | model = ViTSTCNModel(para, local_rank=local_rank, world_size=world_size).train() 65 | 66 | # init iter 67 | total_iter = 0 68 | 69 | """ 70 | Dataloader related 71 | """ 72 | 73 | # To re-seed the randomness everytime we start a worker 74 | def worker_init_fn(worker_id): 75 | return np.random.seed(torch.initial_seed()%(2**31) + worker_id + local_rank*100) 76 | 77 | def construct_loader(dataset): 78 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, rank=local_rank, shuffle=True) 79 | train_loader = DataLoader(dataset, para['batch_size'], sampler=train_sampler, num_workers=para['num_workers'], 80 | worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True) 81 | return train_sampler, train_loader 82 | 83 | def renew_vos_loader(max_skip): 84 | # //5 because we only have annotation for every five frames 85 | yv_dataset = VOSDataset(path.join(yv_root, 'JPEGImages'), 86 | path.join(yv_root, 'Annotations'), max_skip//5, is_bl=False, subset=load_sub_yv(), img_size=para['img_size']) 87 | davis_dataset = VOSDataset(path.join(davis_root, 'JPEGImages', '480p'), 88 | path.join(davis_root, 'Annotations', '480p'), max_skip, is_bl=False, subset=load_sub_davis(), img_size=para['img_size']) 89 | train_dataset = ConcatDataset([davis_dataset]*5 + [yv_dataset]) 90 | 91 | print('YouTube dataset size: ', len(yv_dataset)) 92 | print('DAVIS dataset size: ', len(davis_dataset)) 93 
| print('Concat dataset size: ', len(train_dataset)) 94 | print('Renewed with skip: ', max_skip) 95 | 96 | return construct_loader(train_dataset) 97 | 98 | def renew_bl_loader(max_skip): 99 | train_dataset = VOSDataset(path.join(bl_root, 'JPEGImages'), 100 | path.join(bl_root, 'Annotations'), max_skip, is_bl=True) 101 | 102 | print('Blender dataset size: ', len(train_dataset)) 103 | print('Renewed with skip: ', max_skip) 104 | 105 | return construct_loader(train_dataset) 106 | 107 | """ 108 | Dataset related 109 | """ 110 | 111 | """ 112 | These define the training schedule of the distance between frames 113 | We will switch to skip_values[i] once we pass the percentage specified by increase_skip_fraction[i] 114 | Not effective for stage 0 training 115 | """ 116 | 117 | max_interval = 10 118 | print('max_interval: %d' %(max_interval)) 119 | print(para) 120 | 121 | # stage 0 is not used in our work for simplicity 122 | if para['stage'] == 0: 123 | static_root = path.expanduser(para['static_root']) 124 | fss_dataset = StaticTransformDataset(path.join(static_root, 'fss'), method=0) 125 | duts_tr_dataset = StaticTransformDataset(path.join(static_root, 'DUTS-TR'), method=1) 126 | duts_te_dataset = StaticTransformDataset(path.join(static_root, 'DUTS-TE'), method=1) 127 | ecssd_dataset = StaticTransformDataset(path.join(static_root, 'ecssd'), method=1) 128 | 129 | big_dataset = StaticTransformDataset(path.join(static_root, 'BIG_small'), method=1) 130 | hrsod_dataset = StaticTransformDataset(path.join(static_root, 'HRSOD_small'), method=1) 131 | 132 | # BIG and HRSOD have higher quality, use them more 133 | train_dataset = ConcatDataset([fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset] 134 | + [big_dataset, hrsod_dataset]*5) 135 | train_sampler, train_loader = construct_loader(train_dataset) 136 | 137 | print('Static dataset size: ', len(train_dataset)) 138 | elif para['stage'] == 1: 139 | increase_skip_fraction = [0.1, 0.2, 0.3, 0.4, 0.8, 1.0] 140 | bl_root = path.join(path.expanduser(para['bl_root'])) 141 | 142 | train_sampler, train_loader = renew_bl_loader(5) 143 | renew_loader = renew_bl_loader 144 | else: 145 | # stage 2 or 3 146 | # VOS dataset, 480p is used for both datasets 147 | yv_root = path.join(path.expanduser(para['yv_root']), 'train_480p') 148 | davis_root = path.join(path.expanduser(para['davis_root']), '2017', 'trainval') 149 | 150 | train_sampler, train_loader = renew_vos_loader(max_interval) 151 | renew_loader = renew_vos_loader 152 | 153 | 154 | """ 155 | Determine current/max epoch 156 | """ 157 | total_epoch = math.ceil(para['iterations']/len(train_loader)) 158 | current_epoch = total_iter // len(train_loader) 159 | print('Number of training epochs (the last epoch might not complete): ', total_epoch) 160 | 161 | """ 162 | Starts training 163 | """ 164 | # Need this to select random bases in different workers 165 | np.random.seed(np.random.randint(2**30-1) + local_rank*100) 166 | try: 167 | for e in range(current_epoch, total_epoch): 168 | print('Epoch %d/%d' % (e, total_epoch)) 169 | train_sampler.set_epoch(e) 170 | 171 | # Train loop 172 | model.train() 173 | for data in train_loader: 174 | model.do_pass(data, total_iter) # 4, 3, 3, 384, 384 175 | total_iter += 1 176 | 177 | if total_iter >= para['iterations']: 178 | break 179 | 180 | finally: 181 | if not para['debug'] and model.logger is not None and total_iter>90000: 182 | model.save(total_iter) 183 | # Clean up 184 | distributed.destroy_process_group() 185 | 
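Note on the frame-skip schedule: the comment block in train_simvos.py above describes switching to skip_values[i] once training passes increase_skip_fraction[i], but the script as released keeps max_interval fixed at 10 and never renews the loader inside the training loop. Below is a minimal sketch of how such a curriculum could be hooked in, reusing renew_loader, para, train_sampler and train_loader as defined above; maybe_renew_loader and the skip_values list are illustrative assumptions, not values taken from the released code.
```python
# A minimal sketch, assuming `renew_loader`, `para`, `train_sampler` and
# `train_loader` as defined in train_simvos.py. `skip_values` and
# `maybe_renew_loader` are hypothetical, for illustration only.
skip_values = [10, 15, 20, 25, 5]
increase_skip_fraction = [0.1, 0.2, 0.3, 0.4, 0.8]
cur_skip_idx = -1

def maybe_renew_loader(total_iter, train_sampler, train_loader):
    """Switch to skip_values[i] once training passes increase_skip_fraction[i]."""
    global cur_skip_idx
    progress = total_iter / para['iterations']
    # Index of the last schedule fraction we have already passed (-1 if none)
    target_idx = sum(progress >= f for f in increase_skip_fraction) - 1
    if target_idx > cur_skip_idx:
        cur_skip_idx = target_idx
        return renew_loader(skip_values[target_idx])
    return train_sampler, train_loader

# Inside the epoch loop one would then call, e.g.:
# train_sampler, train_loader = maybe_renew_loader(total_iter, train_sampler, train_loader)
```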
-------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jimmy-dq/SimVOS/ef94709b9e8a4bda43276751be67f98ae1fc12e8/util/__init__.py
-------------------------------------------------------------------------------- /util/davis_subset.txt: -------------------------------------------------------------------------------- 1 | bear 2 | bmx-bumps 3 | boat 4 | boxing-fisheye 5 | breakdance-flare 6 | bus 7 | car-turn 8 | cat-girl 9 | classic-car 10 | color-run 11 | crossing 12 | dance-jump 13 | dancing 14 | disc-jockey 15 | dog-agility 16 | dog-gooses 17 | dogs-scale 18 | drift-turn 19 | drone 20 | elephant 21 | flamingo 22 | hike 23 | hockey 24 | horsejump-low 25 | kid-football 26 | kite-walk 27 | koala 28 | lady-running 29 | lindy-hop 30 | longboard 31 | lucia 32 | mallard-fly 33 | mallard-water 34 | miami-surf 35 | motocross-bumps 36 | motorbike 37 | night-race 38 | paragliding 39 | planes-water 40 | rallye 41 | rhino 42 | rollerblade 43 | schoolgirls 44 | scooter-board 45 | scooter-gray 46 | sheep 
47 | skate-park 48 | snowboard 49 | soccerball 50 | stroller 51 | stunt 52 | surf 53 | swing 54 | tennis 55 | tractor-sand 56 | train 57 | tuk-tuk 58 | upside-down 59 | varanus-cage 60 | walking -------------------------------------------------------------------------------- /util/hyper_para.py: --------------------------------------------------------------------------------
1 | from argparse import ArgumentParser 2 | 3 | 4 | def none_or_default(x, default): 5 | return x if x is not None else default 6 | 7 | class HyperParameters(): 8 | def parse(self, unknown_arg_ok=False): 9 | parser = ArgumentParser() 10 | 11 | # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment 12 | parser.add_argument('--benchmark', action='store_true') 13 | parser.add_argument('--no_amp', action='store_true') 14 | 15 | parser.add_argument('--use_pos_emd', default=True) # note: the boolean-style options here have no type/action, so any value passed on the command line is stored as a string 16 | parser.add_argument('--layer_index', help='the selected layer for global attention', type=int, default=4) 17 | # parameters for SWEM 18 | parser.add_argument('--valdim', help='feat. dim for ViT', type=int, default=768) 19 | parser.add_argument('--num_iters', help='iterations of SWEM', type=int, default=4) 20 | parser.add_argument('--num_bases', help='num of bases', type=int, default=512) # 256 before 21 | parser.add_argument('--tau', help='tau used in softmax', type=float, default=5) #0.05 22 | parser.add_argument('--use_window_partition', default=False) 23 | parser.add_argument('--use_token_learner', default=True) 24 | 25 | parser.add_argument('--backbone_type', help='backbone type (vit_base or vit_large)', default='vit_base') 26 | parser.add_argument('--drop_path_rate', help='drop_path_rate', default=0.1, type=float) 27 | 28 | # parameters for tokenlearner 29 | parser.add_argument('--num_bases_foreground', help='num of bases', type=int, default=384) #256 30 | parser.add_argument('--num_bases_background', help='num of bases', type=int, default=384) #256 31 | parser.add_argument('--img_size', help='image size', type=int, default=384) 32 | 33 | # Data parameters 34 | parser.add_argument('--static_root', help='Static training data root', default='../static') 35 | parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K') 36 | parser.add_argument('--yv_root', help='YouTubeVOS data root', default='/home/user/HDD_Data/YouTube19') 37 | parser.add_argument('--davis_root', help='DAVIS data root', default='/home/user/Data/DAVIS') 38 | 39 | parser.add_argument('--stage', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS (300K), 3-DAVIS+YouTubeVOS (150K))', type=int, default=0) 40 | parser.add_argument('--num_workers', help='Number of dataloader workers per process', type=int, default=8) #8 41 | 42 | # Generic learning parameters 43 | parser.add_argument('-b', '--batch_size', help='Default is dependent on the training stage, see below', default=None, type=int) 44 | parser.add_argument('-i', '--iterations', help='Default is dependent on the training stage, see below', default=None, type=int) 45 | parser.add_argument('--steps', help='Default is dependent on the training stage, see below', nargs="*", default=None, type=int) 46 | 47 | parser.add_argument('--lr', help='Initial learning rate', type=float) 48 | parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float) 49 | 50 | # Loading 51 | parser.add_argument('--load_network', help='Path to pretrained network weight only') 52 | parser.add_argument('--load_model', help='Path to the model file, 
including network, optimizer and such') 53 | 54 | # Logging information 55 | parser.add_argument('--id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL') 56 | parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true') 57 | 58 | # Multiprocessing parameters, not set by users 59 | parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process') 60 | 61 | if unknown_arg_ok: 62 | args, _ = parser.parse_known_args() 63 | self.args = vars(args) 64 | else: 65 | self.args = vars(parser.parse_args()) 66 | 67 | self.args['amp'] = not self.args['no_amp'] 68 | 69 | # Stage-dependent hyperparameters 70 | # Assign default if not given 71 | if self.args['stage'] == 0: 72 | # Static image pretraining 73 | self.args['lr'] = none_or_default(self.args['lr'], 1e-5) 74 | self.args['batch_size'] = none_or_default(self.args['batch_size'], 8) 75 | self.args['iterations'] = none_or_default(self.args['iterations'], 300000) 76 | self.args['steps'] = none_or_default(self.args['steps'], [150000]) 77 | self.args['single_object'] = True 78 | elif self.args['stage'] == 1: 79 | # BL30K pretraining 80 | self.args['lr'] = none_or_default(self.args['lr'], 1e-5) 81 | self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) 82 | self.args['iterations'] = none_or_default(self.args['iterations'], 500000) 83 | self.args['steps'] = none_or_default(self.args['steps'], [400000]) 84 | self.args['single_object'] = False 85 | elif self.args['stage'] == 2: 86 | # 300K main training for after BL30K 87 | self.args['lr'] = none_or_default(self.args['lr'], 1e-5) 88 | self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) 89 | self.args['iterations'] = none_or_default(self.args['iterations'], 300000) 90 | self.args['steps'] = none_or_default(self.args['steps'], [250000]) 91 | self.args['single_object'] = False 92 | elif self.args['stage'] == 3: 93 | # 150K main training for after static image pretraining 94 | self.args['lr'] = none_or_default(self.args['lr'], 2e-5) 95 | self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) 96 | self.args['iterations'] = none_or_default(self.args['iterations'], 210000) 97 | self.args['steps'] = none_or_default(self.args['steps'], [125000]) 98 | self.args['single_object'] = False 99 | 100 | 101 | else: 102 | raise NotImplementedError 103 | 104 | def __getitem__(self, key): 105 | return self.args[key] 106 | 107 | def __setitem__(self, key, value): 108 | self.args[key] = value 109 | 110 | def __str__(self): 111 | return str(self.args) 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /util/image_saver.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from dataset.range_transform import inv_im_trans 6 | from collections import defaultdict 7 | 8 | def tensor_to_numpy(image): 9 | image_np = (image.numpy() * 255).astype('uint8') 10 | return image_np 11 | 12 | def tensor_to_np_float(image): 13 | image_np = image.numpy().astype('float32') 14 | return image_np 15 | 16 | def detach_to_cpu(x): 17 | return x.detach().cpu() 18 | 19 | def transpose_np(x): 20 | return np.transpose(x, [1,2,0]) 21 | 22 | def tensor_to_gray_im(x): 23 | x = detach_to_cpu(x) 24 | x = tensor_to_numpy(x) 25 | x = transpose_np(x) 26 | return x 27 | 28 | def tensor_to_im(x): 29 | x = detach_to_cpu(x) 30 | x = inv_im_trans(x).clamp(0, 
1) 31 | x = tensor_to_numpy(x) 32 | x = transpose_np(x) 33 | return x 34 | 35 | def tensor_to_seg(x): 36 | x = detach_to_cpu(x) 37 | x = inv_seg_trans(x).clamp(0, 1) 38 | x = tensor_to_numpy(x) 39 | x = transpose_np(x) 40 | return x 41 | 42 | # Predefined key <-> caption dict 43 | key_captions = { 44 | 'im': 'Image', 45 | 'gt': 'GT', 46 | } 47 | 48 | """ 49 | Return an image array with captions 50 | keys in dictionary will be used as caption if not provided 51 | values should contain lists of cv2 images 52 | """ 53 | def get_image_array(images, grid_shape, captions={}): 54 | h, w = grid_shape 55 | cate_counts = len(images) 56 | rows_counts = len(next(iter(images.values()))) 57 | 58 | font = cv2.FONT_HERSHEY_SIMPLEX 59 | 60 | output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8) 61 | col_cnt = 0 62 | for k, v in images.items(): 63 | 64 | # Default as key value itself 65 | caption = captions.get(k, k) 66 | 67 | # Handles new line character 68 | dy = 40 69 | for i, line in enumerate(caption.split('\n')): 70 | cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy), 71 | font, 0.8, (255,255,255), 2, cv2.LINE_AA) 72 | 73 | # Put images 74 | for row_cnt, img in enumerate(v): 75 | im_shape = img.shape 76 | if len(im_shape) == 2: 77 | img = img[..., np.newaxis] 78 | 79 | img = (img * 255).astype('uint8') 80 | 81 | output_image[(col_cnt+0)*w:(col_cnt+1)*w, 82 | (row_cnt+1)*h:(row_cnt+2)*h, :] = img 83 | 84 | col_cnt += 1 85 | 86 | return output_image 87 | 88 | def base_transform(im, size): 89 | im = tensor_to_np_float(im) 90 | if len(im.shape) == 3: 91 | im = im.transpose((1, 2, 0)) 92 | else: 93 | im = im[:, :, None] 94 | 95 | # Resize 96 | if im.shape[1] != size: 97 | im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) 98 | 99 | return im.clip(0, 1) 100 | 101 | def im_transform(im, size): 102 | return base_transform(inv_im_trans(detach_to_cpu(im)), size=size) 103 | 104 | def mask_transform(mask, size): 105 | return base_transform(detach_to_cpu(mask), size=size) 106 | 107 | def out_transform(mask, size): 108 | return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) 109 | 110 | def pool_pairs(images, size, so): 111 | req_images = defaultdict(list) 112 | 113 | b, s, _, _, _ = images['gt'].shape 114 | 115 | # limit number of images to save disk space 116 | if b >= 2: 117 | b = max(2, b) # assume that batch size is larger than 2 118 | else: 119 | b = 1 120 | 121 | GT_name = 'GT' 122 | for b_idx in range(b): 123 | GT_name += ' %s\n' % images['info']['name'][b_idx] 124 | 125 | for b_idx in range(b): 126 | for s_idx in range(s): 127 | req_images['RGB'].append(im_transform(images['rgb'][b_idx,s_idx], size)) 128 | if s_idx == 0: 129 | req_images['Mask'].append(np.zeros((size[1], size[0], 3))) 130 | if not so: 131 | req_images['Mask 2'].append(np.zeros((size[1], size[0], 3))) 132 | else: 133 | req_images['Mask'].append(mask_transform(images['mask_%d'%s_idx][b_idx], size)) 134 | if not so: 135 | req_images['Mask 2'].append(mask_transform(images['sec_mask_%d'%s_idx][b_idx], size)) 136 | req_images[GT_name].append(mask_transform(images['gt'][b_idx,s_idx], size)) 137 | if not so: 138 | req_images[GT_name + '_2'].append(mask_transform(images['sec_gt'][b_idx,s_idx], size)) 139 | 140 | return get_image_array(req_images, size, key_captions) -------------------------------------------------------------------------------- /util/load_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | load_subset.py - Presents a subset of 
data 3 | DAVIS - only the training set 4 | YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all 5 | """ 6 | 7 | 8 | def load_sub_davis(path='util/davis_subset.txt'): 9 | with open(path, mode='r') as f: 10 | subset = set(f.read().splitlines()) 11 | return subset 12 | 13 | def load_sub_yv(path='util/yv_subset.txt'): 14 | with open(path, mode='r') as f: 15 | subset = set(f.read().splitlines()) 16 | return subset 17 | -------------------------------------------------------------------------------- /util/log_integrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integrate numerical values for some iterations 3 | Typically used for loss computation / logging to tensorboard 4 | Call finalize and create a new Integrator when you want to display/log 5 | """ 6 | 7 | import torch 8 | 9 | 10 | class Integrator: 11 | def __init__(self, logger, distributed=True, local_rank=0, world_size=1): 12 | self.values = {} 13 | self.counts = {} 14 | self.hooks = [] # List is used here to maintain insertion order 15 | 16 | self.logger = logger 17 | 18 | self.distributed = distributed 19 | self.local_rank = local_rank 20 | self.world_size = world_size 21 | 22 | def add_tensor(self, key, tensor): 23 | if key not in self.values: 24 | self.counts[key] = 1 25 | if type(tensor) == float or type(tensor) == int: 26 | self.values[key] = tensor 27 | else: 28 | self.values[key] = tensor.mean().item() 29 | else: 30 | self.counts[key] += 1 31 | if type(tensor) == float or type(tensor) == int: 32 | self.values[key] += tensor 33 | else: 34 | self.values[key] += tensor.mean().item() 35 | 36 | def add_dict(self, tensor_dict): 37 | for k, v in tensor_dict.items(): 38 | self.add_tensor(k, v) 39 | 40 | def add_hook(self, hook): 41 | """ 42 | Adds a custom hook, i.e. compute new metrics using values in the dict 43 | The hook takes the dict as argument, and returns a (k, v) tuple 44 | e.g. 
for computing IoU 45 | """ 46 | if type(hook) == list: 47 | self.hooks.extend(hook) 48 | else: 49 | self.hooks.append(hook) 50 | 51 | def reset_except_hooks(self): 52 | self.values = {} 53 | self.counts = {} 54 | 55 | # Average and output the metrics 56 | def finalize(self, prefix, it, f=None): 57 | 58 | for hook in self.hooks: 59 | k, v = hook(self.values) 60 | self.add_tensor(k, v) 61 | 62 | for k, v in self.values.items(): 63 | 64 | if k[:4] == 'hide': 65 | continue 66 | 67 | avg = v / self.counts[k] 68 | 69 | if self.distributed: 70 | # Inplace operation 71 | avg = torch.tensor(avg).cuda() 72 | torch.distributed.reduce(avg, dst=0) 73 | 74 | if self.local_rank == 0: 75 | avg = (avg/self.world_size).cpu().item() 76 | self.logger.log_metrics(prefix, k, avg, it, f) 77 | else: 78 | # Simple does it 79 | self.logger.log_metrics(prefix, k, avg, it, f) 80 | 81 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import warnings 7 | import git 8 | 9 | import torchvision.transforms as transforms 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | 13 | def tensor_to_numpy(image): 14 | image_np = (image.numpy() * 255).astype('uint8') 15 | return image_np 16 | 17 | def detach_to_cpu(x): 18 | return x.detach().cpu() 19 | 20 | def fix_width_trunc(x): 21 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 22 | 23 | class TensorboardLogger: 24 | def __init__(self, short_id, id): 25 | self.short_id = short_id 26 | if self.short_id == 'NULL': 27 | self.short_id = 'DEBUG' 28 | 29 | if id is None: 30 | self.no_log = True 31 | warnings.warn('Logging has been disbaled.') 32 | else: 33 | self.no_log = False 34 | 35 | self.inv_im_trans = transforms.Normalize( 36 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 37 | std=[1/0.229, 1/0.224, 1/0.225]) 38 | 39 | self.inv_seg_trans = transforms.Normalize( 40 | mean=[-0.5/0.5], 41 | std=[1/0.5]) 42 | 43 | log_path = os.path.join('.', 'log', '%s' % id) 44 | self.logger = SummaryWriter(log_path) 45 | 46 | repo = git.Repo(".") 47 | self.log_string('git', str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)) 48 | 49 | def log_scalar(self, tag, x, step): 50 | if self.no_log: 51 | warnings.warn('Logging has been disabled.') 52 | return 53 | self.logger.add_scalar(tag, x, step) 54 | 55 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None): 56 | tag = l1_tag + '/' + l2_tag 57 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) 58 | print(text) 59 | if f is not None: 60 | f.write(text + '\n') 61 | f.flush() 62 | self.log_scalar(tag, val, step) 63 | 64 | def log_im(self, tag, x, step): 65 | if self.no_log: 66 | warnings.warn('Logging has been disabled.') 67 | return 68 | x = detach_to_cpu(x) 69 | x = self.inv_im_trans(x) 70 | x = tensor_to_numpy(x) 71 | self.logger.add_image(tag, x, step) 72 | 73 | def log_cv2(self, tag, x, step): 74 | if self.no_log: 75 | warnings.warn('Logging has been disabled.') 76 | return 77 | x = x.transpose((2, 0, 1)) 78 | self.logger.add_image(tag, x, step) 79 | 80 | def log_seg(self, tag, x, step): 81 | if self.no_log: 82 | warnings.warn('Logging has been disabled.') 83 | return 84 | x = detach_to_cpu(x) 85 | x = self.inv_seg_trans(x) 86 | x = tensor_to_numpy(x) 87 | self.logger.add_image(tag, x, step) 88 | 89 | def log_gray(self, tag, x, step): 90 | 
if self.no_log: 91 | warnings.warn('Logging has been disabled.') 92 | return 93 | x = detach_to_cpu(x) 94 | x = tensor_to_numpy(x) 95 | self.logger.add_image(tag, x, step) 96 | 97 | def log_string(self, tag, x): 98 | print(tag, x) 99 | if self.no_log: 100 | warnings.warn('Logging has been disabled.') 101 | return 102 | self.logger.add_text(tag, x) 103 | -------------------------------------------------------------------------------- /util/tensor_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | def compute_tensor_iu(seg, gt): 4 | intersection = (seg & gt).float().sum() 5 | union = (seg | gt).float().sum() 6 | 7 | return intersection, union 8 | 9 | def compute_tensor_iou(seg, gt): 10 | intersection, union = compute_tensor_iu(seg, gt) 11 | iou = (intersection + 1e-6) / (union + 1e-6) 12 | 13 | return iou 14 | 15 | # STM 16 | def pad_divide_by(in_img, d, in_size=None): 17 | if in_size is None: 18 | h, w = in_img.shape[-2:] 19 | else: 20 | h, w = in_size 21 | 22 | if h % d > 0: 23 | new_h = h + d - h % d 24 | else: 25 | new_h = h 26 | if w % d > 0: 27 | new_w = w + d - w % d 28 | else: 29 | new_w = w 30 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 31 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 32 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 33 | out = F.pad(in_img, pad_array) 34 | return out, pad_array 35 | 36 | def unpad(img, pad): 37 | if pad[2]+pad[3] > 0: 38 | img = img[:,:,pad[2]:-pad[3],:] 39 | if pad[0]+pad[1] > 0: 40 | img = img[:,:,:,pad[0]:-pad[1]] 41 | return img --------------------------------------------------------------------------------
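A minimal usage sketch for pad_divide_by and unpad from util/tensor_util.py: the input is padded symmetrically so that both spatial dimensions become divisible by d, and the returned pad tuple is later used to crop a prediction back to the original size. The divisor 16 below is an illustrative choice (matching the ViT patch size of the backbones used in this repo), and the identity assignment stands in for any network that preserves spatial resolution.
```python
import torch
from util.tensor_util import pad_divide_by, unpad

# Illustrative only: a 480x854 frame is not divisible by 16 in width.
im = torch.randn(1, 3, 480, 854)
im_padded, pad = pad_divide_by(im, 16)   # pad = (lw, uw, lh, uh); shape 1x3x480x864
out = im_padded                          # stand-in for a network output with the same H, W
out = unpad(out, pad)                    # cropped back to 1x3x480x854
assert out.shape == im.shape
```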