├── .gitignore
├── README.md
├── dataset.py
├── eval_DAVIS.py
├── helpers.py
├── img
│   └── main.jpg
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Visual Studio Code
.vscode/*
#!.vscode/settings.json
#!.vscode/tasks.json
#!.vscode/launch.json
#!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Hierarchical Memory Matching Network for Video Object Segmentation
### Hongje Seong, Seoung Wug Oh, Joon-Young Lee, Seongwon Lee, Suhyeon Lee, Euntai Kim
### ICCV 2021

![no_image](img/main.jpg)

This is the implementation of HMMN.<br>
This code is based on STM (ICCV 2019): [[link](https://github.com/seoungwugoh/STM)].<br>
Please see our paper for details: [[paper]](https://arxiv.org/abs/2109.11404)

[![Hierarchical Memory Matching Network for Video Object Segmentation (ICCV 2021)](https://img.youtube.com/vi/zSofRzPImQY/0.jpg)](https://www.youtube.com/watch?v=zSofRzPImQY "Hierarchical Memory Matching Network for Video Object Segmentation (ICCV 2021)")

## Dependencies
* Python 3.8
* PyTorch 1.8.1
* numpy, opencv, pillow

## Trained model
* Download the pre-trained weights into the same folder as the demo scripts.<br>
Link: [[weights](https://drive.google.com/file/d/16vOMm7hIdmC6yL4FlBO4p_miTeUoskt7/view?usp=sharing)]
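
A minimal sketch for checking that the downloaded checkpoint loads, mirroring how `eval_DAVIS.py` does it (the `nn.DataParallel` wrapper is kept because the evaluation script loads the weights into the wrapped model):

```python
import torch
import torch.nn as nn

from model import HMMN

model = nn.DataParallel(HMMN())
model.load_state_dict(torch.load('HMMN_weights.pth', map_location='cpu'))
model.eval()  # inference only
```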

## Code
* DAVIS-2016 validation set (Single-object)
```bash
python eval_DAVIS.py -g '0' -s val -y 16 -D [path/to/DAVIS]
```
* DAVIS-2017 validation set (Multi-object)
```bash
python eval_DAVIS.py -g '0' -s val -y 17 -D [path/to/DAVIS]
```

## Pre-computed Results
We also provide pre-computed results for the benchmark sets.
* [[DAVIS-16-val]](https://drive.google.com/file/d/1SqzBktU0DrSd5_vC7TVmPtXJBn3bqAfG/view?usp=sharing)
* [[DAVIS-17-val]](https://drive.google.com/file/d/1uDx8rPo91qEnoE_nBCYF-A-G8noRq6G8/view?usp=sharing)
* [[DAVIS-17-test-dev]](https://drive.google.com/file/d/18-p2ihxfHZisOghWiSpvlfbhpy1j1oSl/view?usp=sharing)
* [[YouTube-VOS-18-valid]](https://drive.google.com/file/d/1cE9rtqdafXGm7V3rP2ZSXq56atoXVUg3/view?usp=sharing)
* [[YouTube-VOS-19-valid]](https://drive.google.com/file/d/1bA2iv2KhjYGlw5i25dLPf24byt-5On2r/view?usp=sharing)

## Bibtex
```
@inproceedings{seong2021hierarchical,
  title={Hierarchical Memory Matching Network for Video Object Segmentation},
  author={Seong, Hongje and Oh, Seoung Wug and Lee, Joon-Young and Lee, Seongwon and Lee, Suhyeon and Kim, Euntai},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2021}
}
```

## Terms of Use
This software is for non-commercial use only.<br>
The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) Licence
(see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details).

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
from PIL import Image

import torch
import torchvision
from torch.utils import data

import glob

class DAVIS_MO_Test(data.Dataset):
    # test-time loader for DAVIS: returns every frame (and mask) of one video per item

    def __init__(self, root, imset='2017/train.txt', resolution='480p', single_object=False):
        self.root = root
        self.mask_dir = os.path.join(root, 'Annotations', resolution)
        self.mask480_dir = os.path.join(root, 'Annotations', '480p')
        self.image_dir = os.path.join(root, 'JPEGImages', resolution)
        _imset_dir = os.path.join(root, 'ImageSets')
        _imset_f = os.path.join(_imset_dir, imset)

        self.videos = []
        self.num_frames = {}
        self.num_objects = {}
        self.shape = {}
        self.size_480p = {}
        with open(_imset_f, "r") as lines:
            for line in lines:
                _video = line.rstrip('\n')
                self.videos.append(_video)
                self.num_frames[_video] = len(glob.glob(os.path.join(self.image_dir, _video, '*.jpg')))
                _mask = np.array(Image.open(os.path.join(self.mask_dir, _video, '00000.png')).convert("P"))
                self.num_objects[_video] = np.max(_mask)
                self.shape[_video] = np.shape(_mask)
                _mask480 = np.array(Image.open(os.path.join(self.mask480_dir, _video, '00000.png')).convert("P"))
                self.size_480p[_video] = np.shape(_mask480)

        self.K = -1
        self.single_object = single_object

    def __len__(self):
        return len(self.videos)

    def To_onehot(self, mask):
        M = np.zeros((self.K, mask.shape[0], mask.shape[1]), dtype=np.uint8)
        for k in range(self.K):
            M[k] = (mask == k).astype(np.uint8)
        return M

    def All_to_onehot(self, masks):
        Ms = np.zeros((self.K, masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
        for n in range(masks.shape[0]):
            Ms[:, n] = self.To_onehot(masks[n])
        return Ms
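
    # __getitem__ returns all frames of one video:
    #   Fs:          float tensor (3, T, H, W), RGB frames scaled to [0, 1]
    #   Ms:          float tensor (K, T, H, W), one-hot masks (K = background + objects)
    #   num_objects: LongTensor([N]), the number of annotated objects
    #   info:        dict with the video name, frame count, and 480p frame size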
    def __getitem__(self, index):
        video = self.videos[index]
        info = {}
        info['name'] = video
        info['num_frames'] = self.num_frames[video]
        info['size_480p'] = self.size_480p[video]

        N_frames = np.empty((self.num_frames[video],)+self.shape[video]+(3,), dtype=np.float32)
        N_masks = np.empty((self.num_frames[video],)+self.shape[video], dtype=np.uint8)
        for f in range(self.num_frames[video]):
            img_file = os.path.join(self.image_dir, video, '{:05d}.jpg'.format(f))
            N_frames[f] = np.array(Image.open(img_file).convert('RGB'))/255.
            try:
                mask_file = os.path.join(self.mask_dir, video, '{:05d}.png'.format(f))
                N_masks[f] = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            except FileNotFoundError:
                # no ground-truth mask for this frame (e.g., test-dev splits): mark it as unknown
                N_masks[f] = 255

        Fs = torch.from_numpy(np.transpose(N_frames.copy(), (3, 0, 1, 2)).copy()).float()
        if self.single_object:
            N_masks = (N_masks > 0.5).astype(np.uint8) * (N_masks < 255).astype(np.uint8)
            self.K = 2  # background + a single object
            Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float()
            num_objects = torch.LongTensor([1])
            return Fs, Ms, num_objects, info
        else:
            self.K = int(self.num_objects[video]+1)
            Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float()
            num_objects = torch.LongTensor([int(self.num_objects[video])])
            return Fs, Ms, num_objects, info


if __name__ == '__main__':
    pass

--------------------------------------------------------------------------------
/eval_DAVIS.py:
--------------------------------------------------------------------------------
from __future__ import division
import torch
from torch.autograd import Variable
from torch.utils import data

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchvision import models

# general libs
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import math
import time
import tqdm
import os
import argparse
import copy


### My libs
from dataset import DAVIS_MO_Test
from model import HMMN


torch.set_grad_enabled(False)  # disable gradients globally: evaluation only

def get_arguments():
    parser = argparse.ArgumentParser(description="HMMN")
    parser.add_argument("-g", type=str, help="0; 0,1; 0,3; etc", required=True)
    parser.add_argument("-s", type=str, help="set", required=True)
    parser.add_argument("-y", type=int, help="year", required=True)
    parser.add_argument("-viz", help="Save visualization", action="store_true")
    parser.add_argument("-D", type=str, help="path to data", default='/local/DATA')
    return parser.parse_args()

args = get_arguments()

GPU = args.g
YEAR = args.y
SET = args.s
VIZ = args.viz
DATA_ROOT = args.D

# Model and version
MODEL = 'HMMN'
print(MODEL, ': Testing on DAVIS')

os.environ['CUDA_VISIBLE_DEVICES'] = GPU
if torch.cuda.is_available():
    print('using Cuda devices, num:', torch.cuda.device_count())

if VIZ:
    print('--- Produce mask overlaid video outputs. Evaluation will run slowly.')
    print('--- Requires ffmpeg for encoding; check the ./viz folder.')


palette = Image.open(DATA_ROOT + '/Annotations/480p/blackswan/00000.png').getpalette()
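
# Run_video maintains a per-scale memory bank: the coarsest keys/values (1/16
# scale) are appended every Mem_every-th frame, while the finer scales (1/8 and
# 1/4) keep only the first (ground-truth) frame, so fine-scale matching always
# uses the first and the previous frame.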
def Run_video(Fs, Ms, num_frames, num_objects, Mem_every=None, Mem_number=None):
    # initialize memory schedule
    if Mem_every:
        to_memorize = [int(i) for i in np.arange(0, num_frames, step=Mem_every)]
    elif Mem_number:
        to_memorize = [int(round(i)) for i in np.linspace(0, num_frames, num=Mem_number+2)[:-1]]
    else:
        raise NotImplementedError

    Es = torch.zeros_like(Ms)
    Es[:,:,0] = Ms[:,:,0]

    for t in tqdm.tqdm(range(1, num_frames)):
        # memorize the previous frame at all three scales
        with torch.no_grad():
            prev_key4, prev_value4, prev_key3, prev_value3, prev_key2, prev_value2 = model(Fs[:,:,t-1], Es[:,:,t-1], torch.tensor([num_objects]))

        if t-1 == 0: # only prev memory
            this_keys4, this_values4 = prev_key4, prev_value4
            this_keys3, this_values3 = prev_key3, prev_value3
            this_keys2, this_values2 = prev_key2, prev_value2
        else:
            this_keys4 = torch.cat([keys4, prev_key4], dim=2)
            this_values4 = torch.cat([values4, prev_value4], dim=2)
            this_keys3 = torch.cat([keys3, prev_key3], dim=2)
            this_values3 = torch.cat([values3, prev_value3], dim=2)
            this_keys2 = torch.cat([keys2, prev_key2], dim=2)
            this_values2 = torch.cat([values2, prev_value2], dim=2)

        # segment the current frame
        with torch.no_grad():
            logit = model(Fs[:,:,t], this_keys4, this_values4, this_keys3, this_values3, this_keys2, this_values2, torch.tensor([num_objects]))
        Es[:,:,t] = F.softmax(logit, dim=1)

        # update
        if t-1 in to_memorize:
            keys4, values4 = this_keys4, this_values4
        if t-1 == 0: # update only the first frame memory
            keys3, values3 = this_keys3, this_values3
            keys2, values2 = this_keys2, this_values2

    pred = np.argmax(Es[0].cpu().numpy(), axis=0).astype(np.uint8)
    return pred, Es



Testset = DAVIS_MO_Test(DATA_ROOT, resolution='480p', imset='20{}/{}.txt'.format(YEAR,SET), single_object=(YEAR==16))
Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

model = nn.DataParallel(HMMN())
if torch.cuda.is_available():
    model.cuda()
model.eval() # turn-off BN

pth_path = 'HMMN_weights.pth'
print('Loading weights:', pth_path)
model.load_state_dict(torch.load(pth_path, map_location='cpu'))

code_name = '{}_DAVIS_{}{}'.format(MODEL,YEAR,SET)
print('Start Testing:', code_name)


for seq, V in enumerate(Testloader):
    Fs, Ms, num_objects, info = V
    seq_name = info['name'][0]
    num_frames = info['num_frames'][0].item()
    print('[{}]: num_frames: {}, num_objects: {}'.format(seq_name, num_frames, num_objects[0][0]))

    pred, Es = Run_video(Fs, Ms, num_frames, num_objects, Mem_every=5, Mem_number=None)

    # Save results for quantitative eval ######################
    test_path = os.path.join('./test', code_name, seq_name)
    if not os.path.exists(test_path):
        os.makedirs(test_path)
    for f in range(num_frames):
        img_E = Image.fromarray(pred[f])
        img_E.putpalette(palette)
        img_E.save(os.path.join(test_path, '{:05d}.png'.format(f)))

    if VIZ:
        from helpers import overlay_davis
        # visualize results #######################
        viz_path = os.path.join('./viz/', code_name, seq_name)
        if not os.path.exists(viz_path):
            os.makedirs(viz_path)

        for f in range(num_frames):
            pF = (Fs[0,:,f].permute(1,2,0).numpy() * 255.).astype(np.uint8)
            pE = pred[f]
            canvas = overlay_davis(pF, pE, palette)
            canvas = Image.fromarray(canvas)
            canvas.save(os.path.join(viz_path, 'f{}.jpg'.format(f)))

        vid_path = os.path.join('./viz/', code_name, '{}.mp4'.format(seq_name))
        frame_path = os.path.join('./viz/', code_name, seq_name, 'f%d.jpg')
        # output options must precede the output file, otherwise ffmpeg ignores them
        os.system('ffmpeg -framerate 10 -i {} -vcodec libx264 -crf 10 -pix_fmt yuv420p -nostats -loglevel 0 -y {}'.format(frame_path, vid_path))

--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
from __future__ import division
#torch
import torch
from torch.autograd import Variable
from torch.utils import data

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchvision import models

# general libs
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import time
import os
import copy


def ToCuda(xs):
    if torch.cuda.is_available():
        if isinstance(xs, list) or isinstance(xs, tuple):
            return [x.cuda(non_blocking=True) for x in xs]
        else:
            return xs.cuda(non_blocking=True)
    else:
        return xs


def pad_divide_by(in_list, d, in_size):
    out_list = []
    h, w = in_size
    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
    lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    for inp in in_list:
        out_list.append(F.pad(inp, pad_array))
    return out_list, pad_array


def overlay_davis(image, mask, colors=[255,0,0], cscale=2, alpha=0.4):
    """Overlay segmentation on top of an RGB image.
    Adapted from the official DAVIS toolkit."""
    # import skimage
    from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated

    colors = np.reshape(colors, (-1, 3))
    colors = np.atleast_2d(colors) * cscale

    im_overlay = image.copy()
    object_ids = np.unique(mask)

    for object_id in object_ids[1:]:
        # Overlay color on binary mask
        foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
        binary_mask = mask == object_id

        # Compose image
        im_overlay[binary_mask] = foreground[binary_mask]

        # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
        contours = binary_dilation(binary_mask) ^ binary_mask
        # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
        im_overlay[contours,:] = 0

    return im_overlay.astype(image.dtype)

--------------------------------------------------------------------------------
/img/main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/HMMN/a138cd805c0e30effb910f1a282e27c969e4fcac/img/main.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchvision import models

# general libs
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import math
import time
import tqdm
import os
import argparse
import copy
import sys

from helpers import *

print('Hierarchical Memory Matching Network: initialized.')

class ResBlock(nn.Module):
    def __init__(self, indim, outdim=None, stride=1):
        super(ResBlock, self).__init__()
        if outdim is None:
            outdim = indim
        if indim == outdim and stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)

        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)

    def forward(self, x):
        r = self.conv1(F.relu(x))
        r = self.conv2(F.relu(r))

        if self.downsample is not None:
            x = self.downsample(x)

        return x + r

class Encoder_M(nn.Module):
    def __init__(self):
        super(Encoder_M, self).__init__()
        self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        resnet = models.resnet50(pretrained=True)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu  # 1/2, 64
        self.maxpool = resnet.maxpool

        self.res2 = resnet.layer1  # 1/4, 256
        self.res3 = resnet.layer2  # 1/8, 512
        self.res4 = resnet.layer3  # 1/16, 1024

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, in_f, in_m, in_o):
        f = (in_f - self.mean) / self.std
        m = torch.unsqueeze(in_m, dim=1).float()  # add channel dim
        o = torch.unsqueeze(in_o, dim=1).float()  # add channel dim

        x = self.conv1(f) + self.conv1_m(m) + self.conv1_o(o)
        x = self.bn1(x)
        c1 = self.relu(x)     # 1/2, 64
        x = self.maxpool(c1)  # 1/4, 64
        r2 = self.res2(x)     # 1/4, 256
        r3 = self.res3(r2)    # 1/8, 512
        r4 = self.res4(r3)    # 1/16, 1024
        return r4, r3, r2, c1, f

class Encoder_Q(nn.Module):
    def __init__(self):
        super(Encoder_Q, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu  # 1/2, 64
        self.maxpool = resnet.maxpool

        self.res2 = resnet.layer1  # 1/4, 256
        self.res3 = resnet.layer2  # 1/8, 512
        self.res4 = resnet.layer3  # 1/16, 1024

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, in_f):
        f = (in_f - self.mean) / self.std

        x = self.conv1(f)
        x = self.bn1(x)
        c1 = self.relu(x)     # 1/2, 64
        x = self.maxpool(c1)  # 1/4, 64
        r2 = self.res2(x)     # 1/4, 256
        r3 = self.res3(r2)    # 1/8, 512
        r4 = self.res4(r3)    # 1/16, 1024
        return r4, r3, r2, c1, f


class Refine(nn.Module):
    def __init__(self, inplanes, planes, scale_factor=2):
        super(Refine, self).__init__()
        self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3,3), padding=(1,1), stride=1)
        self.convMemory = nn.Conv2d(inplanes//2, planes, kernel_size=(3,3), padding=(1,1), stride=1, bias=False)
        self.ResFS = ResBlock(planes, planes)
        self.ResMM = ResBlock(planes, planes)
        self.scale_factor = scale_factor

    def forward(self, f, pm, mem):
        s = self.convFS(f) + self.convMemory(mem)
        s = self.ResFS(s)
        m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        m = self.ResMM(m)
        return m

class Decoder(nn.Module):
    def __init__(self, mdim):
        super(Decoder, self).__init__()
        self.convFM = nn.Conv2d(1024, mdim, kernel_size=(3,3), padding=(1,1), stride=1)
        self.ResMM = ResBlock(mdim, mdim)
        self.RF3 = Refine(512, mdim)  # 1/8 -> 1/4
        self.RF2 = Refine(256, mdim)  # 1/4 -> 1

        self.pred2 = nn.Conv2d(mdim, 2, kernel_size=(3,3), padding=(1,1), stride=1)

    def forward(self, r4, r3, r2, mem3, mem2):
        m4 = self.ResMM(self.convFM(r4))
        m3 = self.RF3(r3, m4, mem3)  # out: 1/8, 256
        m2 = self.RF2(r2, m3, mem2)  # out: 1/4, 256

        p2 = self.pred2(F.relu(m2))

        p = F.interpolate(p2, scale_factor=4, mode='bilinear', align_corners=False)
        return p


class Memory(nn.Module):
    def __init__(self, gaussian_kernel, gaussian_kernel_flow_window):
        super(Memory, self).__init__()
        self.gaussian_kernel = gaussian_kernel
        self.gaussian_kernel_flow_window = gaussian_kernel_flow_window
        if self.gaussian_kernel != -1:
            self.feature_H = -1
            self.feature_W = -1
        if self.gaussian_kernel_flow_window != -1:
            self.H_flow = -1
            self.W_flow = -1
            self.T_flow = 1e+7
            self.B_flow = -1
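
    # apply_gaussian_kernel: for every memory position, find its best-matching
    # query position (the argmax of the correlation) and build a Gaussian map
    # over the query grid centered there; multiplying the affinity by this map
    # keeps each memory feature from influencing spatially distant query pixels.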
    def apply_gaussian_kernel(self, corr, h, w, sigma_factor=1.):
        b, hwt, hw = corr.size()

        idx = corr.max(dim=2)[1]  # b x h2 x w2
        idx_y = (idx // w).view(b, hwt, 1, 1).float()
        idx_x = (idx % w).view(b, hwt, 1, 1).float()

        if h != self.feature_H:
            self.feature_H = h
            y_tmp = np.linspace(0,h-1,h)
            self.y = ToCuda(torch.FloatTensor(y_tmp))
        y = self.y.view(1,1,h,1).expand(b, hwt, h, 1)

        if w != self.feature_W:
            self.feature_W = w
            x_tmp = np.linspace(0,w-1,w)
            self.x = ToCuda(torch.FloatTensor(x_tmp))
        x = self.x.view(1,1,1,w).expand(b, hwt, 1, w)

        gauss_kernel = torch.exp(-((x-idx_x)**2 + (y-idx_y)**2) / (2 * (self.gaussian_kernel*sigma_factor)**2))
        gauss_kernel = gauss_kernel.view(b, hwt, hw)

        return gauss_kernel, idx

    def forward(self, m_in, m_out, q_in, q_out):  # m_in: o,c,t,h,w
        B, D_e, T, H, W = m_in.size()
        _, D_o, _, _, _ = m_out.size()

        mi = m_in.view(B, D_e, T*H*W)
        mi = torch.transpose(mi, 1, 2)  # b, THW, emb

        qi = q_in.view(B, D_e, H*W)  # b, emb, HW

        p = torch.bmm(mi, qi)  # b, THW, HW
        p = p / math.sqrt(D_e)

        if self.gaussian_kernel != -1:
            if self.gaussian_kernel_flow_window != -1:
                p_tmp = p[:,int(-H*W):].clone()
                if (self.B_flow != B) or (self.T_flow != T) or (self.H_flow != H) or (self.W_flow != W):
                    hide_non_local_qk_map_tmp = torch.ones(B,1,H,W,H,W).bool()
                    window_size_half = (self.gaussian_kernel_flow_window-1) // 2
                    for h_idx1 in range(H):
                        for w_idx1 in range(W):
                            h_left = max(h_idx1-window_size_half, 0)
                            h_right = h_idx1+window_size_half+1
                            w_left = max(w_idx1-window_size_half, 0)
                            w_right = w_idx1+window_size_half+1
                            hide_non_local_qk_map_tmp[:,0,h_idx1,w_idx1,h_left:h_right,w_left:w_right] = False
                    hide_non_local_qk_map_tmp = hide_non_local_qk_map_tmp.view(B,H*W,H*W)
                    self.hide_non_local_qk_map_flow = ToCuda(hide_non_local_qk_map_tmp)
                if (self.B_flow != B) or (self.T_flow > T) or (T==1) or (self.H_flow != H) or (self.W_flow != W):
                    self.max_idx_stacked = None
                p_tmp.masked_fill_(self.hide_non_local_qk_map_flow, float('-inf'))
                gauss_kernel_map, max_idx = self.apply_gaussian_kernel(p_tmp, h=H, w=W)
                if self.max_idx_stacked is None:
                    self.max_idx_stacked = max_idx
                else:
                    if self.T_flow == T:
                        self.max_idx_stacked = self.max_idx_stacked[:,:int(-H*W)]
                    self.max_idx_stacked = torch.gather(max_idx, dim=1, index=self.max_idx_stacked)
                    for t_ in range(1, T):
                        gauss_kernel_map_tmp, _ = self.apply_gaussian_kernel(p_tmp, h=H, w=W, sigma_factor=(t_*0.5)+1)
                        gauss_kernel_map_tmp = torch.gather(gauss_kernel_map_tmp, dim=1, index=self.max_idx_stacked[:,int((T-t_-1)*H*W):int((T-t_)*H*W)].unsqueeze(-1).expand(-1,-1,int(H*W)))
                        gauss_kernel_map = torch.cat((gauss_kernel_map_tmp, gauss_kernel_map), dim=1)
                    self.max_idx_stacked = torch.cat((self.max_idx_stacked, max_idx), dim=1)
                self.T_flow = T
                self.H_flow = H
                self.W_flow = W
                self.B_flow = B
            else:
                gauss_kernel_map, _ = self.apply_gaussian_kernel(p, h=H, w=W)

        p = F.softmax(p, dim=1)  # b, THW, HW

        if self.gaussian_kernel != -1:
            p.mul_(gauss_kernel_map)
            p.div_(p.sum(dim=1, keepdim=True))

        mo = m_out.view(B, D_o, T*H*W)
        mem = torch.bmm(mo, p)  # Weighted-sum B, D_o, HW
        mem = mem.view(B, D_o, H, W)

        mem_out = torch.cat([mem, q_out], dim=1)

        return mem_out, p

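# Memory_topk implements the top-k guided memory read at the finer scales:
# instead of a dense softmax over all memory locations, each coarse query
# position attends only within the k memory patches ranked highest by the
# 1/16-scale affinity (qk_ref); a quarter of the indices are passed down to
# guide the next, finer scale.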
class Memory_topk(nn.Module):
    def __init__(self, topk_guided_num):
        super(Memory_topk, self).__init__()
        self.topk_guided_num = topk_guided_num

    def forward(self, m_in, m_out, q_in, qk_ref, qk_ref_topk_indices=None, qk_ref_topk_val=None, mem_dropout=None):  # m_in: o,c,t,h,w
        B_ori, D_e, T, H, W = m_in.size()
        _, D_o, _, _, _ = m_out.size()

        _, THW_ref, HW_ref = qk_ref.size()
        resolution_ref = int(math.sqrt((H*W) // HW_ref))
        H_ref = H // resolution_ref
        W_ref = W // resolution_ref

        size = resolution_ref

        if qk_ref_topk_indices is None:
            qk_ref_topk_val, qk_ref_topk_indices = torch.topk(qk_ref.transpose(1,2), k=self.topk_guided_num, dim=2, sorted=True)
            topk_guided_num = self.topk_guided_num
        else:
            topk_guided_num = qk_ref_topk_indices.shape[2]

        B = B_ori
        qk_ref_selected = qk_ref
        qk_ref_topk_indices_selected = qk_ref_topk_indices
        m_in_selected = m_in
        m_out_selected = m_out
        q_in_selected = q_in

        ref = torch.zeros_like(qk_ref_selected.transpose(1,2))
        ref.scatter_(2, qk_ref_topk_indices_selected, 1.)
        ref = ref.view(B, H_ref, W_ref, T, H_ref, W_ref)

        idx_all = torch.nonzero(ref)
        idx = idx_all[:, 0], idx_all[:, 1], idx_all[:, 2], idx_all[:, 3], idx_all[:, 4], idx_all[:, 5]
        m_in_selected = m_in_selected.view(B,D_e,T,H_ref,size,W_ref,size).permute(0,2,3,5,4,6,1)[idx[0], idx[3], idx[4], idx[5]]  # B*H/2*W/2*k, 2, 2, Cin
        m_in_selected = m_in_selected.reshape(B, H_ref, W_ref, topk_guided_num*size*size, D_e)  # B, H/2, W/2, k*size*size, Cin
        q_in_selected = q_in_selected.view(B,D_e,H_ref,size,W_ref,size)  # B, Cin, H/2, 2, W/2, 2
        q_in_selected = q_in_selected.permute(0,2,4,1,3,5)  # B, H/2, W/2, Cin, 2, 2
        m_out_selected = m_out_selected.view(B,D_o,T,H_ref,size,W_ref,size).permute(0,2,3,5,4,6,1)[idx[0], idx[3], idx[4], idx[5]]  # B*H/2*W/2*k, 2, 2, Cout
        m_out_selected = m_out_selected.reshape(B, H_ref, W_ref, topk_guided_num*size*size, D_o)

        p = torch.einsum('bhwnc,bhwcij->bhwijn', m_in_selected, q_in_selected)
        p = p / math.sqrt(D_e)
        p = F.softmax(p, dim=-1)

        mem_out = torch.einsum('bhwnc,bhwijn->bchiwj', m_out_selected, p)
        mem_out = mem_out.reshape(B, D_o, H, W)

        mem_out_pad = mem_out

        return mem_out_pad, qk_ref_topk_indices[:,:,:max(topk_guided_num//4,1)], qk_ref_topk_val[:,:,:max(topk_guided_num//4,1)]


class KeyValue(nn.Module):
    def __init__(self, indim, keydim, valdim, only_key=False):
        super(KeyValue, self).__init__()
        self.Key = nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)
        self.only_key = only_key
        if not self.only_key:
            self.Value = nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)

    def forward(self, x):
        k = self.Key(x)
        v = self.Value(x) if not self.only_key else None
        return k, v

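# HMMN ties the pieces together: memory/query encoders produce features at
# 1/16 (r4), 1/8 (r3), and 1/4 (r2); the 1/16 keys are matched densely by the
# kernel-guided Memory module, and its affinity selects the top-32 (1/8) and
# top-8 (1/4) candidates for the fine-scale reads that feed the decoder.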
class HMMN(nn.Module):
    def __init__(self):
        super(HMMN, self).__init__()
        self.Encoder_M = Encoder_M()
        self.Encoder_Q = Encoder_Q()

        self.KV_M_r4 = KeyValue(1024, keydim=128, valdim=512)
        self.KV_Q_r4 = KeyValue(1024, keydim=128, valdim=512)
        self.KV_M_r3 = KeyValue(512, keydim=128, valdim=256)
        self.KV_Q_r3 = KeyValue(512, keydim=128, valdim=-1, only_key=True)
        self.KV_M_r2 = KeyValue(256, keydim=64, valdim=128)
        self.KV_Q_r2 = KeyValue(256, keydim=64, valdim=-1, only_key=True)

        self.Memory = Memory(gaussian_kernel=3, gaussian_kernel_flow_window=7)
        self.Memory_topk3 = Memory_topk(topk_guided_num=32)
        self.Memory_topk2 = Memory_topk(topk_guided_num=32//4)

        self.Decoder = Decoder(256)

    def Pad_memory(self, mems, num_objects, K):
        pad_mems = []
        for mem in mems:
            # pad_mem = ToCuda(torch.zeros(1, K, mem.size()[1], 1, mem.size()[2], mem.size()[3]))
            # pad_mem[0,1:num_objects+1,:,0] = mem
            pad_mem = mem.unsqueeze(2)
            pad_mems.append(pad_mem)
        return pad_mems

    def memorize(self, frame, masks, num_objects):
        # memorize a frame
        num_objects = num_objects[0].item()
        _, K, H, W = masks.shape  # B = 1

        (frame, masks), pad = pad_divide_by([frame, masks], 16, (frame.size()[2], frame.size()[3]))

        # make batch arg list
        B_list = {'f':[], 'm':[], 'o':[]}
        for o in range(1, num_objects+1):  # 1 - num_objects
            B_list['f'].append(frame)
            B_list['m'].append(masks[:,o])
            B_list['o'].append( (torch.sum(masks[:,1:o], dim=1) + \
                torch.sum(masks[:,o+1:num_objects+1], dim=1)).clamp(0,1) )

        # make Batch
        B_ = {}
        for arg in B_list.keys():
            B_[arg] = torch.cat(B_list[arg], dim=0)

        r4, r3, r2, _, _ = self.Encoder_M(B_['f'], B_['m'], B_['o'])
        k4, v4 = self.KV_M_r4(r4)  # num_objects, 128 and 512, H/16, W/16
        k3, v3 = self.KV_M_r3(r3)
        k2, v2 = self.KV_M_r2(r2)
        k4, v4 = self.Pad_memory([k4, v4], num_objects=num_objects, K=K)
        k3, v3 = self.Pad_memory([k3, v3], num_objects=num_objects, K=K)
        k2, v2 = self.Pad_memory([k2, v2], num_objects=num_objects, K=K)
        return k4, v4, k3, v3, k2, v2

    def Soft_aggregation(self, ps, K):
        num_objects, H, W = ps.shape
        em = ToCuda(torch.zeros(1, num_objects+1, H, W))
        em[0,0] = torch.prod(1-ps, dim=0)  # bg prob
        em[0,1:num_objects+1] = ps         # obj prob
        em = torch.clamp(em, 1e-7, 1-1e-7)
        logit = torch.log((em /(1-em)))
        return logit

    def segment(self, frame, keys4, values4, keys3, values3, keys2, values2, num_objects):
        num_objects = num_objects[0].item()
        K, keydim, T, H, W = keys4.shape  # B = 1
        # pad
        [frame], pad = pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))

        r4, r3, r2, _, _ = self.Encoder_Q(frame)
        k4, v4 = self.KV_Q_r4(r4)  # 1, dim, H/16, W/16
        k3, _ = self.KV_Q_r3(r3)
        k2, _ = self.KV_Q_r2(r2)

        # expand to --- no, c, h, w
        k4e, v4e = k4.expand(num_objects,-1,-1,-1), v4.expand(num_objects,-1,-1,-1)
        k3e = k3.expand(num_objects,-1,-1,-1)
        k2e = k2.expand(num_objects,-1,-1,-1)
        r3e, r2e = r3.expand(num_objects,-1,-1,-1), r2.expand(num_objects,-1,-1,-1)

        # memory select kv:(1, K, C, T, H, W)
        # m4, pm4 = self.Memory(keys4[0,1:num_objects+1], values4[0,1:num_objects+1], k4e, v4e)
        m4, pm4 = self.Memory(keys4, values4, k4e, v4e)
        B, THW_ref, HW_ref = pm4.size()
        if THW_ref > (HW_ref):
            pm4_for_topk = torch.cat((pm4[:,:HW_ref], pm4[:,-HW_ref:]), dim=1)  # First and Prev
        else:
            pm4_for_topk = pm4

        # m3, next_topk_indices, next_topk_val = self.Memory_topk3(keys3[0,1:num_objects+1], values3[0,1:num_objects+1], k3e, pm4_for_topk)
        # m2, _, _ = self.Memory_topk2(keys2[0,1:num_objects+1], values2[0,1:num_objects+1], k2e, pm4_for_topk, next_topk_indices, next_topk_val)
        m3, next_topk_indices, next_topk_val = self.Memory_topk3(keys3, values3, k3e, pm4_for_topk)
        m2, _, _ = self.Memory_topk2(keys2, values2, k2e, pm4_for_topk, next_topk_indices, next_topk_val)

        logits = self.Decoder(m4, r3e, r2e, m3, m2)
        ps = F.softmax(logits, dim=1)[:,1]  # no, h, w
        # ps: independent probability of belonging to each object

        logit = self.Soft_aggregation(ps, K)  # 1, K, H, W

        if pad[2]+pad[3] > 0:
            logit = logit[:,:,pad[2]:-pad[3],:]
        if pad[0]+pad[1] > 0:
            logit = logit[:,:,:,pad[0]:-pad[1]]

        return logit

    def forward(self, *args, **kwargs):
        if args[1].dim() > 4:  # a 5-D second argument means memory keys -> segment
            return self.segment(*args, **kwargs)
        else:
            return self.memorize(*args, **kwargs)

--------------------------------------------------------------------------------
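
For orientation, below is a minimal smoke-test sketch of the two entry points dispatched by `HMMN.forward` (`memorize` for a 4-D mask tensor, `segment` for 5-D memory keys). The frame size is illustrative (any size divisible by 16 avoids padding), and it assumes torchvision can download the ImageNet-pretrained ResNet-50 backbone:

```python
import torch

from model import HMMN

model = HMMN().eval()

# one 240x432 frame (divisible by 16, so no padding kicks in) with one object
frame = torch.rand(1, 3, 240, 432)
masks = torch.zeros(1, 2, 240, 432)  # channel 0: background, channel 1: object
masks[:, 1, 60:120, 100:200] = 1.
masks[:, 0] = 1. - masks[:, 1]
num_objects = torch.LongTensor([1])

with torch.no_grad():
    # 4-D second argument -> forward dispatches to memorize()
    k4, v4, k3, v3, k2, v2 = model(frame, masks, num_objects)
    # 5-D second argument (memory keys) -> forward dispatches to segment()
    logit = model(frame, k4, v4, k3, v3, k2, v2, num_objects)

print(logit.shape)  # torch.Size([1, 2, 240, 432]): background/object logits
```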