├── .gitignore
├── README.md
├── dataset.py
├── eval_DAVIS.py
├── helpers.py
├── img
│   └── main.jpg
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Visual Studio Code
132 | .vscode/*
133 | #!.vscode/settings.json
134 | #!.vscode/tasks.json
135 | #!.vscode/launch.json
136 | #!.vscode/extensions.json
137 | *.code-workspace
138 |
139 | # Local History for Visual Studio Code
140 | .history/
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Hierarchical Memory Matching Network for Video Object Segmentation
2 | ### Hongje Seong, Seoung Wug Oh, Joon-Young Lee, Seongwon Lee, Suhyeon Lee, Euntai Kim
3 | ### ICCV 2021
4 |
5 |
6 |
7 | This is the implementation of HMMN.
8 | This code is based on STM (ICCV 2019): [[link](https://github.com/seoungwugoh/STM)].
9 | Please see our paper for details: [[paper]](https://arxiv.org/abs/2109.11404).
10 |
11 | [](https://www.youtube.com/watch?v=zSofRzPImQY "Hierarchical Memory Matching Network for Video Object Segmentation (ICCV 2021)")
12 |
13 | ## Dependencies
14 | * Python 3.8
15 | * PyTorch 1.8.1
16 | * numpy, opencv, pillow
17 |
18 | ## Trained model
19 | * Download the pre-trained weights into the same folder as the demo scripts (the evaluation script loads `HMMN_weights.pth` from the working directory): [[weights](https://drive.google.com/file/d/16vOMm7hIdmC6yL4FlBO4p_miTeUoskt7/view?usp=sharing)]
20 |
21 |
22 | ## Code
23 | * DAVIS-2016 validation set (Single-object)
24 | ```bash
25 | python eval_DAVIS.py -g '0' -s val -y 16 -D [path/to/DAVIS]
26 | ```
27 | * DAVIS-2017 validation set (Multi-object)
28 | ```bash
29 | python eval_DAVIS.py -g '0' -s val -y 17 -D [path/to/DAVIS]
30 | ```
31 |
32 | ## Pre-computed Results
33 | We also provide pre-computed results for benchmark sets.
34 | * [[DAVIS-16-val]](https://drive.google.com/file/d/1SqzBktU0DrSd5_vC7TVmPtXJBn3bqAfG/view?usp=sharing)
35 | * [[DAVIS-17-val]](https://drive.google.com/file/d/1uDx8rPo91qEnoE_nBCYF-A-G8noRq6G8/view?usp=sharing)
36 | * [[DAVIS-17-test-dev]](https://drive.google.com/file/d/18-p2ihxfHZisOghWiSpvlfbhpy1j1oSl/view?usp=sharing)
37 | * [[YouTube-VOS-18-valid]](https://drive.google.com/file/d/1cE9rtqdafXGm7V3rP2ZSXq56atoXVUg3/view?usp=sharing)
38 | * [[YouTube-VOS-19-valid]](https://drive.google.com/file/d/1bA2iv2KhjYGlw5i25dLPf24byt-5On2r/view?usp=sharing)
39 |
40 |
41 | ## BibTeX
42 | ```
43 | @inproceedings{seong2021hierarchical,
44 | title={Hierarchical Memory Matching Network for Video Object Segmentation},
45 | author={Seong, Hongje and Oh, Seoung Wug and Lee, Joon-Young and Lee, Seongwon and Lee, Suhyeon and Kim, Euntai},
46 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
47 | year={2021}
48 | }
49 | ```
50 |
51 |
52 | ## Terms of Use
53 | This software is for non-commercial use only.
54 | The source code is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 (CC BY-NC-SA 4.0) License
55 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details).
56 |
--------------------------------------------------------------------------------
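Note: the `-D` argument in the evaluation commands above should point at a standard DAVIS unpacking. The layout below is inferred from the paths read by `dataset.py` and `eval_DAVIS.py` (frame and mask file names are 5-digit, zero-padded):

```
[path/to/DAVIS]
├── JPEGImages/480p/<sequence>/00000.jpg, 00001.jpg, ...
├── Annotations/480p/<sequence>/00000.png      # palette-mode first-frame masks
└── ImageSets
    ├── 2016/val.txt                           # sequence names, one per line (-y 16)
    └── 2017/val.txt                           # (-y 17)
```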
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch
7 | import torchvision
8 | from torch.utils import data
9 |
10 | import glob
11 |
12 | class DAVIS_MO_Test(data.Dataset):
13 |     # multi-object DAVIS dataset for evaluation: one item per video sequence
14 |
15 | def __init__(self, root, imset='2017/train.txt', resolution='480p', single_object=False):
16 | self.root = root
17 | self.mask_dir = os.path.join(root, 'Annotations', resolution)
18 | self.mask480_dir = os.path.join(root, 'Annotations', '480p')
19 | self.image_dir = os.path.join(root, 'JPEGImages', resolution)
20 | _imset_dir = os.path.join(root, 'ImageSets')
21 | _imset_f = os.path.join(_imset_dir, imset)
22 |
23 | self.videos = []
24 | self.num_frames = {}
25 | self.num_objects = {}
26 | self.shape = {}
27 | self.size_480p = {}
28 |         with open(_imset_f, "r") as lines:
29 | for line in lines:
30 | _video = line.rstrip('\n')
31 | self.videos.append(_video)
32 | self.num_frames[_video] = len(glob.glob(os.path.join(self.image_dir, _video, '*.jpg')))
33 | _mask = np.array(Image.open(os.path.join(self.mask_dir, _video, '00000.png')).convert("P"))
34 | self.num_objects[_video] = np.max(_mask)
35 | self.shape[_video] = np.shape(_mask)
36 | _mask480 = np.array(Image.open(os.path.join(self.mask480_dir, _video, '00000.png')).convert("P"))
37 | self.size_480p[_video] = np.shape(_mask480)
38 |
39 | self.K = -1
40 | self.single_object = single_object
41 |
42 | def __len__(self):
43 | return len(self.videos)
44 |
45 |
46 | def To_onehot(self, mask):
47 | M = np.zeros((self.K, mask.shape[0], mask.shape[1]), dtype=np.uint8)
48 | for k in range(self.K):
49 | M[k] = (mask == k).astype(np.uint8)
50 | return M
51 |
52 | def All_to_onehot(self, masks):
53 | Ms = np.zeros((self.K, masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
54 | for n in range(masks.shape[0]):
55 | Ms[:,n] = self.To_onehot(masks[n])
56 | return Ms
57 |
58 | def __getitem__(self, index):
59 | video = self.videos[index]
60 | info = {}
61 | info['name'] = video
62 | info['num_frames'] = self.num_frames[video]
63 | info['size_480p'] = self.size_480p[video]
64 |
65 | N_frames = np.empty((self.num_frames[video],)+self.shape[video]+(3,), dtype=np.float32)
66 | N_masks = np.empty((self.num_frames[video],)+self.shape[video], dtype=np.uint8)
67 | for f in range(self.num_frames[video]):
68 | img_file = os.path.join(self.image_dir, video, '{:05d}.jpg'.format(f))
69 | N_frames[f] = np.array(Image.open(img_file).convert('RGB'))/255.
70 |             try:
71 |                 mask_file = os.path.join(self.mask_dir, video, '{:05d}.png'.format(f))
72 |                 N_masks[f] = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
73 |             except (FileNotFoundError, OSError):
74 |                 # no annotation for this frame (e.g., test-dev): mark it as unlabeled
75 |                 N_masks[f] = 255
76 |
77 | Fs = torch.from_numpy(np.transpose(N_frames.copy(), (3, 0, 1, 2)).copy()).float()
78 | if self.single_object:
79 | N_masks = (N_masks > 0.5).astype(np.uint8) * (N_masks < 255).astype(np.uint8)
80 |             self.K = 2  # background + single foreground object
81 | Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float()
82 | num_objects = torch.LongTensor([int(1)])
83 | return Fs, Ms, num_objects, info
84 | else:
85 | self.K = int(self.num_objects[video]+1)
86 | Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float()
87 | num_objects = torch.LongTensor([int(self.num_objects[video])])
88 | return Fs, Ms, num_objects, info
89 |
90 |
91 |
92 | if __name__ == '__main__':
93 | pass
94 |
--------------------------------------------------------------------------------
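A minimal usage sketch for `DAVIS_MO_Test`, mirroring how `eval_DAVIS.py` consumes it (the DAVIS root path is a placeholder):

```python
from torch.utils import data
from dataset import DAVIS_MO_Test

# single_object=True collapses all labels to {0, 1}, as used for DAVIS-2016
testset = DAVIS_MO_Test('/path/to/DAVIS', imset='2017/val.txt',
                        resolution='480p', single_object=False)
# num_workers=0: the dataset mutates self.K inside __getitem__
testloader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=0)

Fs, Ms, num_objects, info = next(iter(testloader))
# Fs: frames in [0, 1], shape (1, 3, T, H, W)
# Ms: one-hot masks, shape (1, K, T, H, W), K = num_objects + 1 (index 0 = background)
print(info['name'][0], Fs.shape, Ms.shape, num_objects[0][0].item())
```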
/eval_DAVIS.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | from torch.autograd import Variable
4 | from torch.utils import data
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.nn.init as init
9 | import torch.utils.model_zoo as model_zoo
10 | from torchvision import models
11 |
12 | # general libs
13 | import cv2
14 | import matplotlib.pyplot as plt
15 | from PIL import Image
16 | import numpy as np
17 | import math
18 | import time
19 | import tqdm
20 | import os
21 | import argparse
22 | import copy
23 |
24 |
25 | ### My libs
26 | from dataset import DAVIS_MO_Test
27 | from model import HMMN
28 |
29 |
30 | torch.set_grad_enabled(False) # Volatile
31 |
32 | def get_arguments():
33 |     parser = argparse.ArgumentParser(description="HMMN evaluation on DAVIS")
34 |     parser.add_argument("-g", type=str, help="GPU id(s): 0; 0,1; 0,3; etc", required=True)
35 |     parser.add_argument("-s", type=str, help="set (e.g. val)", required=True)
36 |     parser.add_argument("-y", type=int, help="year (16 or 17)", required=True)
37 |     parser.add_argument("-viz", help="Save visualization", action="store_true")
38 |     parser.add_argument("-D", type=str, help="path to data", default='/local/DATA')
39 |     return parser.parse_args()
40 |
41 | args = get_arguments()
42 |
43 | GPU = args.g
44 | YEAR = args.y
45 | SET = args.s
46 | VIZ = args.viz
47 | DATA_ROOT = args.D
48 |
49 | # Model and version
50 | MODEL = 'HMMN'
51 | print(MODEL, ': Testing on DAVIS')
52 |
53 | os.environ['CUDA_VISIBLE_DEVICES'] = GPU
54 | if torch.cuda.is_available():
55 | print('using Cuda devices, num:', torch.cuda.device_count())
56 |
57 | if VIZ:
58 |     print('--- Producing mask-overlaid video outputs. Evaluation will run slowly.')
59 |     print('--- Requires FFMPEG for encoding; check the ./viz folder')
60 |
61 |
62 | palette = Image.open(DATA_ROOT + '/Annotations/480p/blackswan/00000.png').getpalette()
63 |
64 | def Run_video(Fs, Ms, num_frames, num_objects, Mem_every=None, Mem_number=None):
65 | # initialize storage tensors
66 | if Mem_every:
67 | to_memorize = [int(i) for i in np.arange(0, num_frames, step=Mem_every)]
68 | elif Mem_number:
69 | to_memorize = [int(round(i)) for i in np.linspace(0, num_frames, num=Mem_number+2)[:-1]]
70 | else:
71 | raise NotImplementedError
72 |
73 | Es = torch.zeros_like(Ms)
74 | Es[:,:,0] = Ms[:,:,0]
75 |
76 | for t in tqdm.tqdm(range(1, num_frames)):
77 | # memorize
78 | with torch.no_grad():
79 | prev_key4, prev_value4, prev_key3, prev_value3, prev_key2, prev_value2 = model(Fs[:,:,t-1], Es[:,:,t-1], torch.tensor([num_objects]))
80 |
81 | if t-1 == 0: # only prev memory
82 | this_keys4, this_values4 = prev_key4, prev_value4
83 | this_keys3, this_values3 = prev_key3, prev_value3
84 | this_keys2, this_values2 = prev_key2, prev_value2
85 | else:
86 | this_keys4 = torch.cat([keys4, prev_key4], dim=2)
87 | this_values4 = torch.cat([values4, prev_value4], dim=2)
88 | this_keys3 = torch.cat([keys3, prev_key3], dim=2)
89 | this_values3 = torch.cat([values3, prev_value3], dim=2)
90 | this_keys2 = torch.cat([keys2, prev_key2], dim=2)
91 | this_values2 = torch.cat([values2, prev_value2], dim=2)
92 |
93 | # segment
94 | with torch.no_grad():
95 | logit = model(Fs[:,:,t], this_keys4, this_values4, this_keys3, this_values3, this_keys2, this_values2, torch.tensor([num_objects]))
96 | Es[:,:,t] = F.softmax(logit, dim=1)
97 |
98 |         # update memory: frame t-1 joins the coarse (r4) memory every Mem_every frames
99 |         if t-1 in to_memorize:
100 |             keys4, values4 = this_keys4, this_values4
101 |         if t-1 == 0: # the fine-scale (r3, r2) memories keep only the first frame
102 |             keys3, values3 = this_keys3, this_values3
103 |             keys2, values2 = this_keys2, this_values2
104 |
105 | pred = np.argmax(Es[0].cpu().numpy(), axis=0).astype(np.uint8)
106 | return pred, Es
107 |
108 |
109 |
110 | Testset = DAVIS_MO_Test(DATA_ROOT, resolution='480p', imset='20{}/{}.txt'.format(YEAR,SET), single_object=(YEAR==16))
111 | Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
112 |
113 | model = nn.DataParallel(HMMN())
114 | if torch.cuda.is_available():
115 | model.cuda()
116 | model.eval() # turn-off BN
117 |
118 | pth_path = 'HMMN_weights.pth'
119 | print('Loading weights:', pth_path)
120 | model.load_state_dict(torch.load(pth_path, map_location='cpu'))
121 |
122 | code_name = '{}_DAVIS_{}{}'.format(MODEL,YEAR,SET)
123 | print('Start Testing:', code_name)
124 |
125 |
126 | for seq, V in enumerate(Testloader):
127 | Fs, Ms, num_objects, info = V
128 | seq_name = info['name'][0]
129 | num_frames = info['num_frames'][0].item()
130 | print('[{}]: num_frames: {}, num_objects: {}'.format(seq_name, num_frames, num_objects[0][0]))
131 |
132 | pred, Es = Run_video(Fs, Ms, num_frames, num_objects, Mem_every=5, Mem_number=None)
133 |
134 | # Save results for quantitative eval ######################
135 | test_path = os.path.join('./test', code_name, seq_name)
136 | if not os.path.exists(test_path):
137 | os.makedirs(test_path)
138 | for f in range(num_frames):
139 | img_E = Image.fromarray(pred[f])
140 | img_E.putpalette(palette)
141 | img_E.save(os.path.join(test_path, '{:05d}.png'.format(f)))
142 |
143 | if VIZ:
144 | from helpers import overlay_davis
145 | # visualize results #######################
146 | viz_path = os.path.join('./viz/', code_name, seq_name)
147 | if not os.path.exists(viz_path):
148 | os.makedirs(viz_path)
149 |
150 | for f in range(num_frames):
151 | pF = (Fs[0,:,f].permute(1,2,0).numpy() * 255.).astype(np.uint8)
152 | pE = pred[f]
153 | canvas = overlay_davis(pF, pE, palette)
154 | canvas = Image.fromarray(canvas)
155 | canvas.save(os.path.join(viz_path, 'f{}.jpg'.format(f)))
156 |
157 | vid_path = os.path.join('./viz/', code_name, '{}.mp4'.format(seq_name))
158 | frame_path = os.path.join('./viz/', code_name, seq_name, 'f%d.jpg')
159 |         os.system('ffmpeg -framerate 10 -i {} -vcodec libx264 -crf 10 -pix_fmt yuv420p -nostats -loglevel 0 -y {}'.format(frame_path, vid_path))
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
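To make the memory schedule in `Run_video` concrete, a small sketch with the `Mem_every=5` setting used in the main loop (the 20-frame clip length is illustrative):

```python
import numpy as np

# which past frames end up in the long-term memory for a 20-frame clip
num_frames, Mem_every = 20, 5
to_memorize = [int(i) for i in np.arange(0, num_frames, step=Mem_every)]
print(to_memorize)  # [0, 5, 10, 15]
# In Run_video, frame t-1 is always embedded as the "previous" memory; it is kept
# permanently in the coarse (r4) memory only if t-1 is in to_memorize, while the
# fine (r3/r2) memories keep just frame 0 (the t-1 == 0 branch).
```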
/helpers.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | #torch
3 | import torch
4 | from torch.autograd import Variable
5 | from torch.utils import data
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.nn.init as init
10 | import torch.utils.model_zoo as model_zoo
11 | from torchvision import models
12 |
13 | # general libs
14 | import cv2
15 | import matplotlib.pyplot as plt
16 | from PIL import Image
17 | import numpy as np
18 | import time
19 | import os
20 | import copy
21 |
22 |
23 | def ToCuda(xs):
24 | if torch.cuda.is_available():
25 | if isinstance(xs, list) or isinstance(xs, tuple):
26 | return [x.cuda(non_blocking=True) for x in xs]
27 | else:
28 | return xs.cuda(non_blocking=True)
29 | else:
30 | return xs
31 |
32 |
33 | def pad_divide_by(in_list, d, in_size):
34 | out_list = []
35 | h, w = in_size
36 | if h % d > 0:
37 | new_h = h + d - h % d
38 | else:
39 | new_h = h
40 | if w % d > 0:
41 | new_w = w + d - w % d
42 | else:
43 | new_w = w
44 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
45 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
46 | pad_array = (int(lw), int(uw), int(lh), int(uh))
47 | for inp in in_list:
48 | out_list.append(F.pad(inp, pad_array))
49 | return out_list, pad_array
50 |
51 |
52 |
53 | def overlay_davis(image, mask, colors=(255,0,0), cscale=2, alpha=0.4):
54 |     """Overlay a segmentation mask on top of an RGB image (adapted from the official DAVIS toolkit)."""
55 |     # imported lazily: scipy is only needed when the -viz option is used
56 |     from scipy.ndimage import binary_dilation
57 |
58 | colors = np.reshape(colors, (-1, 3))
59 | colors = np.atleast_2d(colors) * cscale
60 |
61 | im_overlay = image.copy()
62 | object_ids = np.unique(mask)
63 |
64 | for object_id in object_ids[1:]:
65 | # Overlay color on binary mask
66 | foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
67 | binary_mask = mask == object_id
68 |
69 | # Compose image
70 | im_overlay[binary_mask] = foreground[binary_mask]
71 |
72 |         # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
73 |         contours = binary_dilation(binary_mask) ^ binary_mask
74 |         # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
75 |         im_overlay[contours,:] = 0  # draw a black contour around each object
76 |
77 | return im_overlay.astype(image.dtype)
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
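A quick worked example of `pad_divide_by` above, which pads spatial dimensions up to multiples of `d` and returns the padding so it can be cropped off later (480x854 is the standard DAVIS 480p resolution):

```python
import torch
from helpers import pad_divide_by

frame = torch.zeros(1, 3, 480, 854)                  # a 480p DAVIS frame
[padded], pad = pad_divide_by([frame], 16, (480, 854))
print(padded.shape)  # torch.Size([1, 3, 480, 864]): 854 -> 864, next multiple of 16
print(pad)           # (5, 5, 0, 0) = (left, right, top, bottom)
# model.segment() crops the padded logits back, e.g. logit[:, :, :, pad[0]:-pad[1]]
```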
/img/main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/HMMN/a138cd805c0e30effb910f1a282e27c969e4fcac/img/main.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | import torch.utils.model_zoo as model_zoo
7 | from torchvision import models
8 |
9 | # general libs
10 | import cv2
11 | import matplotlib.pyplot as plt
12 | from PIL import Image
13 | import numpy as np
14 | import math
15 | import time
16 | import tqdm
17 | import os
18 | import argparse
19 | import copy
20 | import sys
21 |
22 | from helpers import *
23 |
24 | print('Hierarchical Memory Matching Network: initialized.')
25 |
26 | class ResBlock(nn.Module):
27 | def __init__(self, indim, outdim=None, stride=1):
28 | super(ResBlock, self).__init__()
29 |         if outdim is None:
30 | outdim = indim
31 | if indim == outdim and stride==1:
32 | self.downsample = None
33 | else:
34 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
35 |
36 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
37 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
38 |
39 |
40 | def forward(self, x):
41 | r = self.conv1(F.relu(x))
42 | r = self.conv2(F.relu(r))
43 |
44 | if self.downsample is not None:
45 | x = self.downsample(x)
46 |
47 | return x + r
48 |
49 | class Encoder_M(nn.Module):
50 | def __init__(self):
51 | super(Encoder_M, self).__init__()
52 | self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
53 | self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
54 |
55 | resnet = models.resnet50(pretrained=True)
56 | self.conv1 = resnet.conv1
57 | self.bn1 = resnet.bn1
58 | self.relu = resnet.relu # 1/2, 64
59 | self.maxpool = resnet.maxpool
60 |
61 | self.res2 = resnet.layer1 # 1/4, 256
62 | self.res3 = resnet.layer2 # 1/8, 512
63 | self.res4 = resnet.layer3 # 1/16, 1024
64 |
65 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
66 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
67 |
68 | def forward(self, in_f, in_m, in_o):
69 | f = (in_f - self.mean) / self.std
70 | m = torch.unsqueeze(in_m, dim=1).float() # add channel dim
71 | o = torch.unsqueeze(in_o, dim=1).float() # add channel dim
72 |
73 | x = self.conv1(f) + self.conv1_m(m) + self.conv1_o(o)
74 | x = self.bn1(x)
75 | c1 = self.relu(x) # 1/2, 64
76 | x = self.maxpool(c1) # 1/4, 64
77 | r2 = self.res2(x) # 1/4, 256
78 | r3 = self.res3(r2) # 1/8, 512
79 | r4 = self.res4(r3) # 1/16, 1024
80 | return r4, r3, r2, c1, f
81 |
82 | class Encoder_Q(nn.Module):
83 | def __init__(self):
84 | super(Encoder_Q, self).__init__()
85 | resnet = models.resnet50(pretrained=True)
86 | self.conv1 = resnet.conv1
87 | self.bn1 = resnet.bn1
88 | self.relu = resnet.relu # 1/2, 64
89 | self.maxpool = resnet.maxpool
90 |
91 | self.res2 = resnet.layer1 # 1/4, 256
92 | self.res3 = resnet.layer2 # 1/8, 512
93 | self.res4 = resnet.layer3 # 1/16, 1024
94 |
95 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
96 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
97 |
98 | def forward(self, in_f):
99 | f = (in_f - self.mean) / self.std
100 |
101 | x = self.conv1(f)
102 | x = self.bn1(x)
103 | c1 = self.relu(x) # 1/2, 64
104 | x = self.maxpool(c1) # 1/4, 64
105 | r2 = self.res2(x) # 1/4, 256
106 | r3 = self.res3(r2) # 1/8, 512
107 | r4 = self.res4(r3) # 1/16, 1024
108 | return r4, r3, r2, c1, f
109 |
110 |
111 | class Refine(nn.Module):
112 | def __init__(self, inplanes, planes, scale_factor=2):
113 | super(Refine, self).__init__()
114 | self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3,3), padding=(1,1), stride=1)
115 | self.convMemory = nn.Conv2d(inplanes//2, planes, kernel_size=(3,3), padding=(1,1), stride=1, bias=False)
116 | self.ResFS = ResBlock(planes, planes)
117 | self.ResMM = ResBlock(planes, planes)
118 | self.scale_factor = scale_factor
119 |
120 | def forward(self, f, pm, mem):
121 | s = self.convFS(f) + self.convMemory(mem)
122 | s = self.ResFS(s)
123 | m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
124 | m = self.ResMM(m)
125 | return m
126 |
127 | class Decoder(nn.Module):
128 | def __init__(self, mdim):
129 | super(Decoder, self).__init__()
130 | self.convFM = nn.Conv2d(1024, mdim, kernel_size=(3,3), padding=(1,1), stride=1)
131 | self.ResMM = ResBlock(mdim, mdim)
132 | self.RF3 = Refine(512, mdim) # 1/8 -> 1/4
133 | self.RF2 = Refine(256, mdim) # 1/4 -> 1
134 |
135 | self.pred2 = nn.Conv2d(mdim, 2, kernel_size=(3,3), padding=(1,1), stride=1)
136 |
137 | def forward(self, r4, r3, r2, mem3, mem2):
138 | m4 = self.ResMM(self.convFM(r4))
139 | m3 = self.RF3(r3, m4, mem3) # out: 1/8, 256
140 | m2 = self.RF2(r2, m3, mem2) # out: 1/4, 256
141 |
142 | p2 = self.pred2(F.relu(m2))
143 |
144 | p = F.interpolate(p2, scale_factor=4, mode='bilinear', align_corners=False)
145 | return p
146 |
147 |
148 |
149 | class Memory(nn.Module):
150 | def __init__(self, gaussian_kernel, gaussian_kernel_flow_window):
151 | super(Memory, self).__init__()
152 | self.gaussian_kernel = gaussian_kernel
153 | self.gaussian_kernel_flow_window = gaussian_kernel_flow_window
154 | if self.gaussian_kernel != -1:
155 | self.feature_H = -1
156 | self.feature_W = -1
157 | if self.gaussian_kernel_flow_window != -1:
158 | self.H_flow = -1
159 | self.W_flow = -1
160 | self.T_flow = 1e+7
161 | self.B_flow = -1
162 |
163 | def apply_gaussian_kernel(self, corr, h, w, sigma_factor=1.):
164 | b, hwt, hw = corr.size()
165 |
166 | idx = corr.max(dim=2)[1] # b x h2 x w2
167 | idx_y = (idx // w).view(b, hwt, 1, 1).float()
168 | idx_x = (idx % w).view(b, hwt, 1, 1).float()
169 |
170 | if h != self.feature_H:
171 | self.feature_H = h
172 | y_tmp = np.linspace(0,h-1,h)
173 | self.y = ToCuda(torch.FloatTensor(y_tmp))
174 | y = self.y.view(1,1,h,1).expand(b, hwt, h, 1)
175 |
176 | if w != self.feature_W:
177 | self.feature_W = w
178 | x_tmp = np.linspace(0,w-1,w)
179 | self.x = ToCuda(torch.FloatTensor(x_tmp))
180 | x = self.x.view(1,1,1,w).expand(b, hwt, 1, w)
181 |
182 | gauss_kernel = torch.exp(-((x-idx_x)**2 + (y-idx_y)**2) / (2 * (self.gaussian_kernel*sigma_factor)**2))
183 | gauss_kernel = gauss_kernel.view(b, hwt, hw)
184 |
185 | return gauss_kernel, idx
186 |
187 | def forward(self, m_in, m_out, q_in, q_out): # m_in: o,c,t,h,w
188 | B, D_e, T, H, W = m_in.size()
189 | _, D_o, _, _, _ = m_out.size()
190 |
191 | mi = m_in.view(B, D_e, T*H*W)
192 | mi = torch.transpose(mi, 1, 2) # b, THW, emb
193 |
194 | qi = q_in.view(B, D_e, H*W) # b, emb, HW
195 |
196 | p = torch.bmm(mi, qi) # b, THW, HW
197 | p = p / math.sqrt(D_e)
198 |
199 | if self.gaussian_kernel != -1:
200 | if self.gaussian_kernel_flow_window != -1:
201 | p_tmp = p[:,int(-H*W):].clone()
202 | if (self.B_flow != B) or (self.T_flow != T) or (self.H_flow != H) or (self.W_flow != W):
203 | hide_non_local_qk_map_tmp = torch.ones(B,1,H,W,H,W).bool()
204 | window_size_half = (self.gaussian_kernel_flow_window-1) // 2
205 | for h_idx1 in range(H):
206 | for w_idx1 in range(W):
207 | h_left = max(h_idx1-window_size_half, 0)
208 | h_right = h_idx1+window_size_half+1
209 | w_left = max(w_idx1-window_size_half, 0)
210 | w_right = w_idx1+window_size_half+1
211 | hide_non_local_qk_map_tmp[:,0,h_idx1,w_idx1,h_left:h_right,w_left:w_right] = False
212 | hide_non_local_qk_map_tmp = hide_non_local_qk_map_tmp.view(B,H*W,H*W)
213 | self.hide_non_local_qk_map_flow = ToCuda(hide_non_local_qk_map_tmp)
214 | if (self.B_flow != B) or (self.T_flow > T) or (T==1) or (self.H_flow != H) or (self.W_flow != W):
215 | self.max_idx_stacked = None
216 | p_tmp.masked_fill_(self.hide_non_local_qk_map_flow, float('-inf'))
217 | gauss_kernel_map, max_idx = self.apply_gaussian_kernel(p_tmp, h=H, w=W)
218 | if self.max_idx_stacked is None:
219 | self.max_idx_stacked = max_idx
220 | else:
221 | if self.T_flow == T:
222 | self.max_idx_stacked = self.max_idx_stacked[:,:int(-H*W)]
223 | self.max_idx_stacked = torch.gather(max_idx, dim=1, index=self.max_idx_stacked)
224 | for t_ in range(1, T):
225 | gauss_kernel_map_tmp, _ = self.apply_gaussian_kernel(p_tmp, h=H, w=W, sigma_factor=(t_*0.5)+1)
226 | gauss_kernel_map_tmp = torch.gather(gauss_kernel_map_tmp, dim=1, index=self.max_idx_stacked[:,int((T-t_-1)*H*W):int((T-t_)*H*W)].unsqueeze(-1).expand(-1,-1,int(H*W)))
227 | gauss_kernel_map = torch.cat((gauss_kernel_map_tmp, gauss_kernel_map), dim=1)
228 | self.max_idx_stacked = torch.cat((self.max_idx_stacked, max_idx), dim=1)
229 | self.T_flow = T
230 | self.H_flow = H
231 | self.W_flow = W
232 | self.B_flow = B
233 | else:
234 | gauss_kernel_map, _ = self.apply_gaussian_kernel(p, h=H, w=W)
235 |
236 | p = F.softmax(p, dim=1) # b, THW, HW
237 |
238 | if self.gaussian_kernel != -1:
239 | p.mul_(gauss_kernel_map)
240 | p.div_(p.sum(dim=1, keepdim=True))
241 |
242 | mo = m_out.view(B, D_o, T*H*W)
243 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW
244 | mem = mem.view(B, D_o, H, W)
245 |
246 | mem_out = torch.cat([mem, q_out], dim=1)
247 |
248 | return mem_out, p
249 |
250 |
251 | class Memory_topk(nn.Module):
252 | def __init__(self, topk_guided_num):
253 | super(Memory_topk, self).__init__()
254 | self.topk_guided_num = topk_guided_num
255 |
256 | def forward(self, m_in, m_out, q_in, qk_ref, qk_ref_topk_indices=None, qk_ref_topk_val=None, mem_dropout=None): # m_in: o,c,t,h,w
257 | B_ori, D_e, T, H, W = m_in.size()
258 | _, D_o, _, _, _ = m_out.size()
259 |
260 | _, THW_ref, HW_ref = qk_ref.size()
261 | resolution_ref = int(math.sqrt((H*W) // HW_ref))
262 | H_ref = H // resolution_ref
263 | W_ref = W // resolution_ref
264 |
265 | size = resolution_ref
266 |
267 | if qk_ref_topk_indices is None:
268 | qk_ref_topk_val, qk_ref_topk_indices = torch.topk(qk_ref.transpose(1,2), k=self.topk_guided_num, dim=2, sorted=True)
269 | topk_guided_num = self.topk_guided_num
270 | else:
271 | topk_guided_num = qk_ref_topk_indices.shape[2]
272 |
273 | B = B_ori
274 | qk_ref_selected = qk_ref
275 | qk_ref_topk_indices_selected = qk_ref_topk_indices
276 | m_in_selected = m_in
277 | m_out_selected = m_out
278 | q_in_selected = q_in
279 |
280 | ref = torch.zeros_like(qk_ref_selected.transpose(1,2))
281 | ref.scatter_(2, qk_ref_topk_indices_selected, 1.)
282 | ref = ref.view(B, H_ref, W_ref, T, H_ref, W_ref)
283 |
284 | idx_all = torch.nonzero(ref)
285 | idx = idx_all[:, 0], idx_all[:, 1], idx_all[:, 2], idx_all[:, 3], idx_all[:, 4], idx_all[:, 5]
286 | m_in_selected = m_in_selected.view(B,D_e,T,H_ref,size,W_ref,size).permute(0,2,3,5,4,6,1)[idx[0], idx[3], idx[4], idx[5]] # B*H/2*W/2*k, 2, 2, Cin
287 | m_in_selected = m_in_selected.reshape(B, H_ref, W_ref, topk_guided_num*size*size, D_e) # B, H/2, W/2, k*size*size, Cin
288 | q_in_selected = q_in_selected.view(B,D_e,H_ref,size,W_ref,size) # B, Cin, H/2, 2, W/2, 2
289 | q_in_selected = q_in_selected.permute(0,2,4,1,3,5) # B, H/2, W/2, Cin, 2, 2
290 | m_out_selected = m_out_selected.view(B,D_o,T,H_ref,size,W_ref,size).permute(0,2,3,5,4,6,1)[idx[0], idx[3], idx[4], idx[5]] # B*H/2*W/2*k, 2, 2, Cout
291 | m_out_selected = m_out_selected.reshape(B, H_ref, W_ref, topk_guided_num*size*size, D_o)
292 |
293 | p = torch.einsum('bhwnc,bhwcij->bhwijn', m_in_selected, q_in_selected)
294 | p = p / math.sqrt(D_e)
295 | p = F.softmax(p, dim=-1)
296 |
297 | mem_out = torch.einsum('bhwnc,bhwijn->bchiwj', m_out_selected, p)
298 | mem_out = mem_out.reshape(B, D_o, H, W)
299 |
300 | mem_out_pad = mem_out
301 |
302 | return mem_out_pad, qk_ref_topk_indices[:,:,:max(topk_guided_num//4,1)], qk_ref_topk_val[:,:,:max(topk_guided_num//4,1)]
303 |
304 |
305 | class KeyValue(nn.Module):
306 | def __init__(self, indim, keydim, valdim, only_key=False):
307 | super(KeyValue, self).__init__()
308 | self.Key = nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)
309 | self.only_key = only_key
310 | if not self.only_key:
311 | self.Value = nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)
312 |
313 | def forward(self, x):
314 | k = self.Key(x)
315 | v = self.Value(x) if not self.only_key else None
316 | return k, v
317 |
318 |
319 |
320 |
321 | class HMMN(nn.Module):
322 | def __init__(self):
323 | super(HMMN, self).__init__()
324 | self.Encoder_M = Encoder_M()
325 | self.Encoder_Q = Encoder_Q()
326 |
327 | self.KV_M_r4 = KeyValue(1024, keydim=128, valdim=512)
328 | self.KV_Q_r4 = KeyValue(1024, keydim=128, valdim=512)
329 | self.KV_M_r3 = KeyValue(512, keydim=128, valdim=256)
330 | self.KV_Q_r3 = KeyValue(512, keydim=128, valdim=-1, only_key=True)
331 | self.KV_M_r2 = KeyValue(256, keydim=64, valdim=128)
332 | self.KV_Q_r2 = KeyValue(256, keydim=64, valdim=-1, only_key=True)
333 |
334 | self.Memory = Memory(gaussian_kernel=3, gaussian_kernel_flow_window=7)
335 | self.Memory_topk3 = Memory_topk(topk_guided_num=32)
336 | self.Memory_topk2 = Memory_topk(topk_guided_num=32//4)
337 |
338 | self.Decoder = Decoder(256)
339 |
340 | def Pad_memory(self, mems, num_objects, K):
341 | pad_mems = []
342 | for mem in mems:
343 | # pad_mem = ToCuda(torch.zeros(1, K, mem.size()[1], 1, mem.size()[2], mem.size()[3]))
344 | # pad_mem[0,1:num_objects+1,:,0] = mem
345 | pad_mem = mem.unsqueeze(2)
346 | pad_mems.append(pad_mem)
347 | return pad_mems
348 |
349 | def memorize(self, frame, masks, num_objects):
350 | # memorize a frame
351 | num_objects = num_objects[0].item()
352 | _, K, H, W = masks.shape # B = 1
353 |
354 | (frame, masks), pad = pad_divide_by([frame, masks], 16, (frame.size()[2], frame.size()[3]))
355 |
356 | # make batch arg list
357 | B_list = {'f':[], 'm':[], 'o':[]}
358 | for o in range(1, num_objects+1): # 1 - no
359 | B_list['f'].append(frame)
360 | B_list['m'].append(masks[:,o])
361 | B_list['o'].append( (torch.sum(masks[:,1:o], dim=1) + \
362 | torch.sum(masks[:,o+1:num_objects+1], dim=1)).clamp(0,1) )
363 |
364 | # make Batch
365 | B_ = {}
366 | for arg in B_list.keys():
367 | B_[arg] = torch.cat(B_list[arg], dim=0)
368 |
369 | r4, r3, r2, _, _ = self.Encoder_M(B_['f'], B_['m'], B_['o'])
370 | k4, v4 = self.KV_M_r4(r4) # num_objects, 128 and 512, H/16, W/16
371 | k3, v3 = self.KV_M_r3(r3)
372 | k2, v2 = self.KV_M_r2(r2)
373 | k4, v4 = self.Pad_memory([k4, v4], num_objects=num_objects, K=K)
374 | k3, v3 = self.Pad_memory([k3, v3], num_objects=num_objects, K=K)
375 | k2, v2 = self.Pad_memory([k2, v2], num_objects=num_objects, K=K)
376 | return k4, v4, k3, v3, k2, v2
377 |
378 | def Soft_aggregation(self, ps, K):
379 | num_objects, H, W = ps.shape
380 | em = ToCuda(torch.zeros(1, num_objects+1, H, W))
381 | em[0,0] = torch.prod(1-ps, dim=0) # bg prob
382 | em[0,1:num_objects+1] = ps # obj prob
383 | em = torch.clamp(em, 1e-7, 1-1e-7)
384 | logit = torch.log((em /(1-em)))
385 | return logit
386 |
387 | def segment(self, frame, keys4, values4, keys3, values3, keys2, values2, num_objects):
388 | num_objects = num_objects[0].item()
389 | K, keydim, T, H, W = keys4.shape # B = 1
390 | # pad
391 | [frame], pad = pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))
392 |
393 | r4, r3, r2, _, _ = self.Encoder_Q(frame)
394 | k4, v4 = self.KV_Q_r4(r4) # 1, dim, H/16, W/16
395 | k3, _ = self.KV_Q_r3(r3)
396 | k2, _ = self.KV_Q_r2(r2)
397 |
398 | # expand to --- no, c, h, w
399 | k4e, v4e = k4.expand(num_objects,-1,-1,-1), v4.expand(num_objects,-1,-1,-1)
400 | k3e = k3.expand(num_objects,-1,-1,-1)
401 | k2e = k2.expand(num_objects,-1,-1,-1)
402 | r3e, r2e = r3.expand(num_objects,-1,-1,-1), r2.expand(num_objects,-1,-1,-1)
403 |
404 | # memory select kv:(1, K, C, T, H, W)
405 | # m4, pm4 = self.Memory(keys4[0,1:num_objects+1], values4[0,1:num_objects+1], k4e, v4e)
406 | m4, pm4 = self.Memory(keys4, values4, k4e, v4e)
407 | B, THW_ref, HW_ref = pm4.size()
408 |         if THW_ref > HW_ref:
409 | pm4_for_topk = torch.cat((pm4[:,:HW_ref], pm4[:,-HW_ref:]), dim=1) # First and Prev
410 | else:
411 | pm4_for_topk = pm4
412 |
413 | # m3, next_topk_indices, next_topk_val = self.Memory_topk3(keys3[0,1:num_objects+1], values3[0,1:num_objects+1], k3e, pm4_for_topk)
414 | # m2, _, _ = self.Memory_topk2(keys2[0,1:num_objects+1], values2[0,1:num_objects+1], k2e, pm4_for_topk, next_topk_indices, next_topk_val)
415 | m3, next_topk_indices, next_topk_val = self.Memory_topk3(keys3, values3, k3e, pm4_for_topk)
416 | m2, _, _ = self.Memory_topk2(keys2, values2, k2e, pm4_for_topk, next_topk_indices, next_topk_val)
417 |
418 | logits = self.Decoder(m4, r3e, r2e, m3, m2)
419 |         ps = F.softmax(logits, dim=1)[:,1] # no, h, w
420 |         # ps: independent probability that each pixel belongs to each object
421 |
422 | logit = self.Soft_aggregation(ps, K) # 1, K, H, W
423 |
424 | if pad[2]+pad[3] > 0:
425 | logit = logit[:,:,pad[2]:-pad[3],:]
426 | if pad[0]+pad[1] > 0:
427 | logit = logit[:,:,:,pad[0]:-pad[1]]
428 |
429 | return logit
430 |
431 |     def forward(self, *args, **kwargs):
432 |         if args[1].dim() > 4: # 5-D memory keys -> segment; 4-D mask -> memorize
433 | return self.segment(*args, **kwargs)
434 | else:
435 | return self.memorize(*args, **kwargs)
436 |
437 |
438 |
--------------------------------------------------------------------------------
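For clarity, a minimal standalone restatement of the `Soft_aggregation` step above: independent per-object foreground probabilities are merged with a background term and converted back to logits, so the softmax applied in `Run_video` yields a proper distribution over background and objects. Shapes here are illustrative:

```python
import torch

def soft_aggregate(ps, eps=1e-7):
    """Standalone restatement of HMMN.Soft_aggregation, for illustration only."""
    num_objects, H, W = ps.shape                  # per-object foreground probabilities
    em = torch.zeros(1, num_objects + 1, H, W)
    em[0, 0] = torch.prod(1 - ps, dim=0)          # background: no object is present
    em[0, 1:num_objects + 1] = ps                 # each object keeps its own probability
    em = em.clamp(eps, 1 - eps)
    return torch.log(em / (1 - em))               # convert probabilities back to logits

ps = torch.rand(2, 4, 4)                          # two objects on a 4x4 grid
final = torch.softmax(soft_aggregate(ps), dim=1)  # what Run_video applies per frame
print(final.shape, float(final.sum(dim=1).mean()))  # (1, 3, 4, 4), sums to 1 per pixel
```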