├── model_ECCV2020 ├── networks │ ├── __init__.py │ ├── deeplab │ │ ├── __init__.py │ │ ├── sync_batchnorm │ │ │ ├── __init__.py │ │ │ ├── unittest.py │ │ │ ├── replicate.py │ │ │ ├── comm.py │ │ │ └── batchnorm.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ └── drn.py │ │ ├── decoder.py │ │ ├── deeplab.py │ │ └── aspp.py │ ├── correlation_package.zip │ ├── ltm_transfer.py │ └── network.py └── model.py ├── etc ├── 00000.png ├── png_demo.png ├── fonts │ ├── myriad.ttf │ ├── helvetica.ttf │ ├── verdana.ttf │ └── times-new-roman.ttf └── explain_qwerty.png ├── model_CVPR2021 ├── networks │ └── deeplab │ │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── unittest.py │ │ ├── replicate.py │ │ ├── comm.py │ │ └── batchnorm.py │ │ ├── backbone │ │ ├── __init__.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ ├── xception.py │ │ └── drn.py │ │ ├── decoder.py │ │ ├── deeplab.py │ │ └── aspp.py └── model.py ├── IVOS_demo_customvideo.py ├── libs ├── helpers.py ├── utils_torch.py ├── davis_interactive_evaluator_mo.py └── utils_custom.py ├── LICENSE ├── eval_IVOS.py ├── eval_GIS_RS4.py ├── eval_GIS_RS1.py └── README.md /model_ECCV2020/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /model_ECCV2020/networks/deeplab/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /etc/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/00000.png -------------------------------------------------------------------------------- /etc/png_demo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/png_demo.png -------------------------------------------------------------------------------- /etc/fonts/myriad.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/fonts/myriad.ttf -------------------------------------------------------------------------------- /etc/explain_qwerty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/explain_qwerty.png -------------------------------------------------------------------------------- /etc/fonts/helvetica.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/fonts/helvetica.ttf -------------------------------------------------------------------------------- /etc/fonts/verdana.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/fonts/verdana.ttf -------------------------------------------------------------------------------- /etc/fonts/times-new-roman.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/etc/fonts/times-new-roman.ttf -------------------------------------------------------------------------------- /model_ECCV2020/networks/correlation_package.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/GUI-IVOS/HEAD/model_ECCV2020/networks/correlation_package.zip -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : 
def build_backbone(backbone, output_stride, BatchNorm):
    """Instantiate the backbone network selected by name.

    Args:
        backbone: One of 'resnet', 'xception', 'drn' or 'mobilenet'.
        output_stride: Forwarded to every constructor except the DRN one,
            which is called without it.
        BatchNorm: Normalisation layer class forwarded to the constructor.

    Returns:
        The object returned by the selected backbone constructor.

    Raises:
        NotImplementedError: If ``backbone`` matches none of the names above.
    """
    # Ordered (name, factory) pairs; the factories are lazy so that only the
    # selected backbone's constructor is ever looked up and called.
    factories = (
        ('resnet', lambda: resnet.ResNet101(output_stride, BatchNorm)),
        ('xception', lambda: xception.AlignedXception(output_stride, BatchNorm)),
        ('drn', lambda: drn.drn_d_54(BatchNorm)),
        ('mobilenet', lambda: mobilenet.MobileNetV2(output_stride, BatchNorm)),
    )
    for known_name, make in factories:
        if backbone == known_name:
            return make()
    raise NotImplementedError
def build_backbone(backbone, output_stride, BatchNorm):
    """Instantiate the backbone network selected by name.

    Args:
        backbone: One of 'resnet', 'xception', 'drn' or 'mobilenet'.
        output_stride: Forwarded to every constructor except the DRN one,
            which is called without it.
        BatchNorm: Normalisation layer class forwarded to the constructor.

    Returns:
        The object returned by the selected backbone constructor.

    Raises:
        NotImplementedError: If ``backbone`` matches none of the names above.
    """
    # Ordered (name, factory) pairs; the factories are lazy so that only the
    # selected backbone's constructor is ever looked up and called.
    factories = (
        ('resnet', lambda: resnet.ResNet101(output_stride, BatchNorm)),
        ('xception', lambda: xception.AlignedXception(output_stride, BatchNorm)),
        ('drn', lambda: drn.drn_d_54(BatchNorm)),
        ('mobilenet', lambda: mobilenet.MobileNetV2(output_stride, BatchNorm)),
    )
    for known_name, make in factories:
        if backbone == known_name:
            return make()
    raise NotImplementedError
def as_numpy(v):
    """Return ``v`` as a NumPy array.

    A ``Variable`` is unwrapped through ``.data`` first; the value is then
    moved to the CPU and converted with ``.numpy()``.
    """
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tensor closeness assertion."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Fail unless ``a`` and ``b`` are element-wise close.

        Args:
            a: Tensor (or ``Variable``) under test.
            b: Reference tensor (or ``Variable``).
            atol: Absolute tolerance handed to ``np.allclose``.
            rtol: Relative tolerance handed to ``np.allclose``.

        Raises:
            AssertionError: If the values differ by more than the tolerances;
                the message reports the largest absolute and relative gaps.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # Both tolerances are forwarded; rtol used to be accepted but ignored,
        # so NumPy's own default relative tolerance was applied instead.
        if np.allclose(npa, npb, rtol=rtol, atol=atol):
            return

        # The diagnostics are computed only on failure, so matching empty
        # tensors no longer trip over ``.max()`` of an empty array.
        abs_diff = np.abs(npa - npb)
        # Magnitude of the reference, floored to avoid division by zero.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.fail(
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, abs_diff.max(), rel_diff.max()
            )
        )
def as_numpy(v):
    """Return ``v`` as a NumPy array.

    A ``Variable`` is unwrapped through ``.data`` first; the value is then
    moved to the CPU and converted with ``.numpy()``.
    """
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tensor closeness assertion."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Fail unless ``a`` and ``b`` are element-wise close.

        Args:
            a: Tensor (or ``Variable``) under test.
            b: Reference tensor (or ``Variable``).
            atol: Absolute tolerance handed to ``np.allclose``.
            rtol: Relative tolerance handed to ``np.allclose``.

        Raises:
            AssertionError: If the values differ by more than the tolerances;
                the message reports the largest absolute and relative gaps.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # Both tolerances are forwarded; rtol used to be accepted but ignored,
        # so NumPy's own default relative tolerance was applied instead.
        if np.allclose(npa, npb, rtol=rtol, atol=atol):
            return

        # The diagnostics are computed only on failure, so matching empty
        # tensors no longer trip over ``.max()`` of an empty array.
        abs_diff = np.abs(npa - npb)
        # Magnitude of the reference, floored to avoid division by zero.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.fail(
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, abs_diff.max(), rel_diff.max()
            )
        )
import matplotlib.pyplot as plt 4 | 5 | 6 | def scrimg_postprocess(scr, dilation=7, nocare_area=21, blur = False, blursize=(5, 5), var = 6.0, custom_blur = None): 7 | 8 | # Compute foreground 9 | if scr.max() == 1: 10 | kernel_fg = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation, dilation)) 11 | fg = cv2.dilate(scr.astype(np.uint8), kernel=kernel_fg).astype(scr.dtype) 12 | else: 13 | fg = scr 14 | 15 | # Compute nocare area 16 | if nocare_area is None: 17 | nocare = None 18 | else: 19 | kernel_nc = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (nocare_area, nocare_area)) 20 | nocare = cv2.dilate(fg, kernel=kernel_nc) - fg 21 | if blur: 22 | fg = cv2.GaussianBlur(fg,ksize=blursize,sigmaX=var) 23 | elif custom_blur: 24 | c_kernel = np.array([[1,2,3,2,1],[2,4,9,4,2],[3,9,64,9,3],[2,4,9,4,2],[1,2,3,2,1]]) 25 | c_kernel = c_kernel/np.sum(c_kernel) 26 | fg = cv2.filter2D(fg,ddepth=-1,kernel = c_kernel) 27 | 28 | return fg, nocare -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuk Heo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
def combine_masks_with_batch(masks, n_obj, th=0.5):
    """Fuse per-object score maps into one label map per batch item.

    A pixel receives label ``k + 1`` when channel ``k`` holds the highest
    value at that pixel, that value is strictly greater than ``th`` and
    ``k < n_obj``.  Every other pixel keeps label 0.

    Args:
        masks: Tensor of shape [B, nobj, H, W].
        n_obj: Number of leading channels that may be assigned a label.
        th: Threshold the winning value has to exceed.

    Returns:
        Tensor of shape [B, 1, H, W] with the same dtype and device as
        ``masks``, holding labels in ``0..n_obj``.

    Raises:
        NotImplementedError: If selecting the pixels of an object fails (for
            example when ``n_obj`` exceeds the number of channels); the
            underlying error is attached as ``__cause__``.
    """
    # Index of the winning channel per pixel, kept as [B, 1, H, W].
    marker = torch.argmax(masks, dim=1, keepdim=True)
    # Zero label map shaped like a single channel of the input.
    out_mask = torch.zeros_like(masks[:, :1])
    for obj_id in range(n_obj):
        try:
            selected = (marker == obj_id) * (masks[:, obj_id].unsqueeze(1) > th)
        except Exception as err:
            # Same exception type as before, but the cause is preserved and
            # KeyboardInterrupt / SystemExit are no longer intercepted.
            raise NotImplementedError(
                'cannot combine masks for object index {}'.format(obj_id)
            ) from err
        out_mask[selected] = obj_id + 1

    return out_mask


def cuda(xs):
    """Move a tensor, or every tensor of a list/tuple, to the GPU.

    When CUDA is not available the input is handed back on its current
    device instead of ``None``, so callers can use the result either way.

    Args:
        xs: A single tensor or a list/tuple of tensors.

    Returns:
        The moved tensor, or a list of moved tensors for list/tuple input.
    """
    is_sequence = isinstance(xs, (list, tuple))
    if not torch.cuda.is_available():
        # Keep the return type identical to the CUDA path (always a list).
        return list(xs) if is_sequence else xs
    if is_sequence:
        return [x.cuda() for x in xs]
    return xs.cuda()
| app = QApplication(sys.argv) 35 | 36 | for val_idx in range(0,30): 37 | ex = App_CVPR2021(DIE, model, root, video_indices=val_idx, save_imgs=True) 38 | app.exec_() 39 | ex = None 40 | 41 | -------------------------------------------------------------------------------- /eval_GIS_RS4.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from PyQt5.QtWidgets import QApplication 3 | 4 | from apps.multi_object_gui_eval_RS4 import App 5 | from libs.davis_interactive_evaluator_mo import Davis_Interactive_Evaluator as DIE 6 | 7 | import os 8 | 9 | 10 | class App_CVPR2021(App): 11 | def __init__(self, DIE, model, root, video_indices, save_imgs=False): 12 | super().__init__(DIE, model, root, video_indices, save_imgs=save_imgs) 13 | 14 | class Davis_Interactive_Evaluator(DIE): 15 | def __init__(self, root, algorithm_name, user_name, imset='2017/val.txt', resolution='480p'): 16 | super().__init__(root, algorithm_name, user_name, imset=imset, resolution=resolution) 17 | 18 | 19 | if __name__ == '__main__': 20 | 21 | ##################### Configs ######################## 22 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 23 | os.environ["CUDA_VISIBLE_DEVICES"] = str(1) 24 | root = '/home/yuk/data_ssd/datasets/DAVIS' 25 | user_name = 'A' 26 | ##################### Configs ######################## 27 | 28 | 29 | from model_CVPR2021.model import model as model 30 | DIE = Davis_Interactive_Evaluator(root,algorithm_name='RAmap_RS4',user_name=user_name) 31 | DIE.write_info() 32 | model = model() 33 | 34 | app = QApplication(sys.argv) 35 | 36 | for val_idx in range(0,30): 37 | ex = App_CVPR2021(DIE, model, root, video_indices=val_idx, save_imgs=False) 38 | app.exec_() 39 | ex = None -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import 
class Decoder(nn.Module):
    """Decoder head that merges encoder output with a low-level feature map.

    The low-level features are projected to 48 channels, the encoder output is
    resized to their spatial size, both are concatenated and passed through
    the final convolution stack that emits ``num_classes`` channels.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        """Build the layers.

        Args:
            num_classes: Output channels of the last convolution.
            backbone: Backbone name; selects the low-level channel count.
            BatchNorm: Normalisation layer class used throughout.

        Raises:
            NotImplementedError: If ``backbone`` is not a known name.
        """
        super().__init__()

        # Channel count of the low-level feature map for each backbone.
        channel_table = (('resnet', 256), ('drn', 256), ('xception', 128), ('mobilenet', 24))
        for known_name, channels in channel_table:
            if backbone == known_name:
                low_level_inplanes = channels
                break
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()

        # 304 input channels: the 48 projected low-level channels plus the
        # channels of ``x`` (presumably 256 from the ASPP module).
        head_layers = [
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*head_layers)
        self._init_weight()

    def forward(self, x, low_level_feat):
        """Return the decoder output at the low-level feature resolution."""
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = F.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=True)
        return self.last_conv(torch.cat((upsampled, skip), dim=1))

    def _init_weight(self):
        """Kaiming-initialise convolutions and reset norm layers to identity."""
        norm_types = (SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, norm_types):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()


def build_decoder(num_classes, backbone, BatchNorm):
    """Factory wrapper that returns a :class:`Decoder` for the given settings."""
    decoder = Decoder(num_classes, backbone, BatchNorm)
    return decoder
class Decoder(nn.Module):
    """Decoder head that merges encoder output with a low-level feature map.

    The low-level features are projected to 48 channels, the encoder output is
    resized to their spatial size, both are concatenated and passed through
    the final convolution stack that emits ``num_classes`` channels.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        """Build the layers.

        Args:
            num_classes: Output channels of the last convolution.
            backbone: Backbone name; selects the low-level channel count.
            BatchNorm: Normalisation layer class used throughout.

        Raises:
            NotImplementedError: If ``backbone`` is not a known name.
        """
        super().__init__()

        # Channel count of the low-level feature map for each backbone.
        channel_table = (('resnet', 256), ('drn', 256), ('xception', 128), ('mobilenet', 24))
        for known_name, channels in channel_table:
            if backbone == known_name:
                low_level_inplanes = channels
                break
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()

        # 304 input channels: the 48 projected low-level channels plus the
        # channels of ``x`` (presumably 256 from the ASPP module).
        head_layers = [
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*head_layers)
        self._init_weight()

    def forward(self, x, low_level_feat):
        """Return the decoder output at the low-level feature resolution."""
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = F.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=True)
        return self.last_conv(torch.cat((upsampled, skip), dim=1))

    def _init_weight(self):
        """Kaiming-initialise convolutions and reset norm layers to identity."""
        norm_types = (SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, norm_types):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()


def build_decoder(num_classes, backbone, BatchNorm):
    """Factory wrapper that returns a :class:`Decoder` for the given settings."""
    decoder = Decoder(num_classes, backbone, BatchNorm)
    return decoder
class Davis_Interactive_Evaluator():
    """Bookkeeping for an interactive DAVIS evaluation session.

    Reads the sequence names of an image set, creates a timestamped result
    directory and appends per-object results to a CSV file inside it.
    """

    def __init__(self, root, algorithm_name, user_name, imset='2017/val.txt', resolution='480p'):
        """Load the sequence list and prepare the result location.

        Args:
            root: Dataset root containing ``Annotations``, ``JPEGImages`` and
                ``ImageSets``.
            algorithm_name: Name embedded in the result directory.
            user_name: Name embedded in the CSV file name.
            imset: Image-set file, relative to ``<root>/ImageSets``.
            resolution: Sub-directory of the annotation and image folders.
        """
        self.root = root
        self.mask_dir = os.path.join(root, 'Annotations', resolution)
        self.image_dir = os.path.join(root, 'JPEGImages', resolution)
        imset_path = os.path.join(root, 'ImageSets', imset)

        with open(imset_path, "r") as imset_file:
            # Strip line endings (including '\r') and drop blank lines so
            # they do not turn into empty sequence names.
            names = (line.strip() for line in imset_file)
            self.videos = sorted(name for name in names if name)

        self.current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.save_root = 'results/Alg[{}]_{}'.format(algorithm_name, self.current_time)
        self.algorithm_name = algorithm_name
        utils_custom.mkdir(self.save_root)

        self.savefname_csv = self.save_root + '/result_{}.csv'.format(user_name)

    def write_info(self):
        """Append the header row to the result CSV."""
        # newline='' lets the csv module control line endings itself.
        with open(self.savefname_csv, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow(['sequence', 'obj_id', 'N_rounds', 'final_J', 'final_F', 'scribble_time', 'operation_time', 'finding_time', 'total_time'])

    def write_in_csv(self, sequence, n_obj, final_J, final_F, scribble_timesteps, operate_timesteps, finding_timesteps):
        """Append one row per object with scores and accumulated phase times.

        Each round consists of three consecutive phases (scribble, operate,
        find); the ``*_timesteps`` lists hold the timestamp at which the
        corresponding phase of each round ended.

        Args:
            sequence: Sequence name written in the first column.
            n_obj: Number of objects; rows use object ids ``1..n_obj``.
            final_J: Per-object J scores, indexed by ``obj_id - 1``.
            final_F: Per-object F scores, indexed by ``obj_id - 1``.
            scribble_timesteps: End timestamps of the scribble phases.
            operate_timesteps: End timestamps of the operation phases.
            finding_timesteps: End timestamps of the finding phases (a list);
                its last entry is reported as the total time.
        """
        n_rounds = len(operate_timesteps)
        totaltime = finding_timesteps[-1]

        scribble_end = np.array(scribble_timesteps)
        operate_end = np.array(operate_timesteps)
        finding_end = np.array(finding_timesteps)
        # A round starts where the previous round's finding phase ended; the
        # first round starts at time 0.
        round_start = np.array([0] + finding_timesteps[:-1])

        scribble_time = np.sum(scribble_end - round_start)
        operation_time = np.sum(operate_end - scribble_end)
        finding_time = np.sum(finding_end - operate_end)

        with open(self.savefname_csv, mode='a', newline='') as csv_file:
            writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for obj_id in range(1, n_obj + 1):
                writer.writerow([sequence, obj_id, n_rounds, final_J[obj_id-1], final_F[obj_id-1], scribble_time, operation_time, finding_time, totaltime])
class DeepLab(nn.Module):
    """Segmentation network chaining a backbone, an ASPP module and a decoder."""

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Assemble the three stages.

        Args:
            backbone: Backbone name forwarded to the builders.
            output_stride: Stride forwarded to the builders; replaced by 8
                when ``backbone`` is 'drn'.
            num_classes: Number of output channels of the decoder.
            sync_bn: Selects ``SynchronizedBatchNorm2d`` when equal to True,
                ``nn.BatchNorm2d`` otherwise.
            freeze_bn: When truthy, norm layers are put in eval mode at once.
        """
        super().__init__()
        output_stride = 8 if backbone == 'drn' else output_stride
        # The explicit comparison with True is kept so that the selection
        # rule stays exactly as it was.
        BatchNorm = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d  # noqa: E712

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, input):
        """Return the decoder output resized to the spatial size of ``input``."""
        features, low_level_feat = self.backbone(input)
        logits = self.decoder(self.aspp(features), low_level_feat)
        return F.interpolate(logits, size=input.size()[2:], mode='bilinear', align_corners=True)

    def freeze_bn(self):
        """Switch every batch-norm layer of the network to eval mode."""
        for layer in self.modules():
            if isinstance(layer, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                layer.eval()

    def get_1x_lr_params(self):
        """Yield the trainable conv / norm parameters of the backbone."""
        yield from self._trainable_params(self.backbone)

    def get_10x_lr_params(self):
        """Yield the trainable conv / norm parameters of ASPP and decoder."""
        yield from self._trainable_params(self.aspp, self.decoder)

    @staticmethod
    def _trainable_params(*roots):
        """Yield parameters with ``requires_grad`` from conv / norm layers under ``roots``."""
        wanted = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for root in roots:
            for _, layer in root.named_modules():
                if not isinstance(layer, wanted):
                    continue
                for param in layer.parameters():
                    if param.requires_grad:
                        yield param
12 | 13 | class App_CVPR2021(App): 14 | def __init__(self, DIE, model, root, video_indices, save_imgs=False): 15 | super().__init__(DIE, model, root, video_indices, save_imgs=save_imgs) 16 | 17 | def on_run(self): 18 | if len(self.scribbles['scribbles'][self.cursur])>=1: 19 | self.scribble_timesteps.append(time.time()-self.time_init) 20 | self.VOS_once_executed_bool = True 21 | self.model.Run_propagation(self.cursur) 22 | self.current_mask = self.model.Get_mask() 23 | 24 | self.current_round +=1 25 | 26 | print('[Overlaying segmentations...]') 27 | for fr in range(self.num_frames): 28 | self.vis_frames[fr] = overlay_mask(self.frames[fr], self.current_mask[fr], alpha=0.5, contour_thickness=2) 29 | print('[Overlaying Done.] \n') 30 | 31 | 32 | # clear scribble and reset 33 | self.cursur = np.argmin(self.model.scores_nf) 34 | self.show_current() 35 | self.clear_strokes() 36 | self.reset_scribbles() 37 | self.lcd2.setText('Currunt round : {:2d}'.format(self.current_round + 1)) 38 | 39 | self.operate_timesteps.append(time.time() - self.time_init) 40 | self.finding_timesteps.append(time.time() - self.time_init) 41 | self.slider.setDisabled(True) 42 | self.text_print += 'Providing scribble...\n' 43 | self.lcd3.setText(self.text_print) 44 | 45 | def on_select(self): 46 | a=1 47 | 48 | class Davis_Interactive_Evaluator(DIE): 49 | def __init__(self, root, algorithm_name, user_name, imset='2017/val.txt', resolution='480p'): 50 | super().__init__(root, algorithm_name, user_name, imset=imset, resolution=resolution) 51 | 52 | 53 | if __name__ == '__main__': 54 | 55 | 56 | ##################### Configs ######################## 57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 58 | os.environ["CUDA_VISIBLE_DEVICES"] = str(2) 59 | 60 | root = '/home/yuk/data_ssd/datasets/DAVIS' 61 | user_name = 'A' 62 | ##################### Configs ######################## 63 | 64 | from model_CVPR2021.model import model as model 65 | DIE = 
Davis_Interactive_Evaluator(root,algorithm_name='RAmap_RS1',user_name=user_name) 66 | DIE.write_info() 67 | model = model() 68 | 69 | app = QApplication(sys.argv) 70 | 71 | for val_idx in range(0,30): 72 | ex = App_CVPR2021(DIE, model, root, video_indices=val_idx, save_imgs=False) 73 | app.exec_() 74 | ex = None 75 | 76 | -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so that the replication
    callbacks are executed after every `replicate` call.

    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    # Keep a handle on the bound method before it is shadowed on the instance.
    original = data_parallel.replicate

    @functools.wraps(original)
    def replicate_then_notify(module, device_ids):
        replicas = original(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    data_parallel.replicate = replicate_then_notify
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` that runs the replication callbacks.

    After the stock `replicate` has created the per-device copies, the
    `__data_parallel_replicate__(ctx, copy_id)` callback of every sub-module
    that defines one is invoked.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
class LTM_transfer(nn.Module):
    """Transfer a single-channel map with a local similarity volume.

    For every integer displacement ``(u, v)`` with ``u, v`` in
    ``[-md*stride, md*stride]`` the input map is shifted by that displacement;
    the shifted copies are then blended per pixel with the (optionally
    soft-maxed) similarity weights.

    The module holds no learnable parameters.  All helper tensors follow the
    device of the input, so it runs on CPU as well as on any GPU.
    """

    def __init__(self, md=4, stride=1):
        """
        :param md: maximum displacement in pixels (default 4)
        :param stride: spacing between neighbouring displacements
        """
        super(LTM_transfer, self).__init__()
        self.md = md
        self.range = (md * 2 + 1) ** 2  # number of displacements, 81 by default
        self.grid = None  # cached pixel-coordinate grid, built lazily in forward()
        self.Channelwise_sum = None  # not used in this class; kept for compatibility

        side = 2 * self.md + 1
        offsets = torch.linspace(-self.md * stride, self.md * stride, side)
        d_u = offsets.view(1, -1).repeat((side, 1)).view(self.range, 1)  # (range, 1)
        d_v = offsets.view(-1, 1).repeat((1, side)).view(self.range, 1)  # (range, 1)
        # Created on CPU; forward() moves it to the device of its input.
        self.d = torch.cat((d_u, d_v), dim=1)  # (range, 2), columns are (x, y)

    def L2normalize(self, x, d=1):
        """Scale ``x`` to (approximately) unit L2 norm along dimension ``d``."""
        eps = 1e-6  # keeps the division finite for all-zero vectors
        norm = (x ** 2).sum(dim=d, keepdim=True) + eps
        return x / norm ** 0.5

    def UniformGrid(self, Input):
        '''
        Make a grid of pixel coordinates on the same device as ``Input``.
        :param Input: tensor(N,C,H,W)
        :return grid: float tensor (1,2,H,W); channel 0 is x, channel 1 is y
        '''
        _, _, H, W = Input.size()
        xx = torch.arange(0, W, device=Input.device).view(1, 1, 1, W).expand(1, 1, H, W)
        yy = torch.arange(0, H, device=Input.device).view(1, 1, H, 1).expand(1, 1, H, W)
        return torch.cat((xx, yy), 1).float()

    def warp(self, x, BM_d):
        '''
        Sample ``x`` at ``self.grid + BM_d`` and zero out-of-frame samples.
        :param x: tensor (R,C,H,W), one copy of the map per displacement
        :param BM_d: displacement volume (R,2,H,W), in pixels
        :return: warped tensor (R,C,H,W)
        Requires ``self.grid`` to be initialised (done by ``forward``).
        '''
        vgrid = self.grid + BM_d  # (R,2,H,W); a fresh tensor, safe to edit in place
        # scale grid to [-1,1]
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(x.size(3) - 1, 1) - 1.0
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(x.size(2) - 1, 1) - 1.0
        vgrid = vgrid.permute(0, 2, 3, 1)  # (R,H,W,2) as expected by grid_sample

        # NOTE(review): `align_corners` is left at the library default on
        # purpose so results match the existing behaviour; confirm which
        # setting the released checkpoints were trained with before pinning it.
        output = nn.functional.grid_sample(x, vgrid, mode='bilinear', padding_mode='border')

        # Warping an all-ones map with zero padding reveals the samples that
        # fell (partly) outside the frame: their value drops below 1.
        mask = torch.ones(x.size(), dtype=x.dtype, device=x.device)
        mask = nn.functional.grid_sample(mask, vgrid)
        mask = (mask >= 0.999).to(output.dtype)

        return output * mask

    def forward(self, sim_feature, f_map, apply_softmax_on_simfeature=True):
        '''
        Blend displaced copies of ``f_map`` with the similarity weights.
        :param sim_feature: similarity volume (N,(2*md+1)^2,H,W)
        :param f_map: map to transfer, e.g. a previous-frame mask (N,1,H,W)
        :param apply_softmax_on_simfeature: soft-max ``sim_feature`` over the
            displacement dimension before blending
        :return: transferred map (N,1,H,W)
        '''
        _, _, H_size, W_size = f_map.size()
        device = f_map.device

        # (Re)build the cached grid when it is missing, has the wrong spatial
        # size, or lives on another device than the input.
        if (self.grid is None
                or H_size != self.grid.size(2)
                or W_size != self.grid.size(3)
                or self.grid.device != device):
            self.grid = self.UniformGrid(f_map)

        if self.d.device != device:
            self.d = self.d.to(device)

        # Displacement volume: one constant (x, y) offset plane per displacement.
        D_vol = self.d.view(self.range, 2, 1, 1).expand(-1, -1, H_size, W_size)  # (R,2,H,W)

        if apply_softmax_on_simfeature:
            sim_feature = F.softmax(sim_feature, dim=1)  # (N,R,H,W)

        # Displacements take the batch slot of grid_sample, the real batch the
        # channel slot: (N,1,H,W) -> (1,N,H,W) -> (R,N,H,W) -> warp -> (N,R,H,W)
        shifted = self.warp(f_map.transpose(0, 1).expand(self.range, -1, -1, -1), D_vol).transpose(0, 1)

        return torch.sum(torch.mul(sim_feature, shifted), dim=1, keepdim=True)  # (N,1,H,W)
pretrained, inplanes=2048, outplanes = 256): 40 | super(ASPP, self).__init__() 41 | if output_stride == 16: 42 | dilations = [1, 6, 12, 18] 43 | elif output_stride == 8: 44 | dilations = [1, 12, 24, 36] 45 | else: 46 | raise NotImplementedError 47 | 48 | self.aspp1 = _ASPPModule(inplanes, outplanes, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm, pretrained=pretrained) 49 | self.aspp2 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm, pretrained=pretrained) 50 | self.aspp3 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm, pretrained=pretrained) 51 | self.aspp4 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm, pretrained=pretrained) 52 | 53 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 54 | nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False), 55 | BatchNorm(outplanes), 56 | nn.ReLU()) 57 | self.conv1 = nn.Conv2d(outplanes*5, outplanes, 1, bias=False) 58 | self.bn1 = BatchNorm(outplanes) 59 | self.relu = nn.ReLU() 60 | self.dropout = nn.Dropout(0.5) 61 | self._init_weight(pretrained) 62 | 63 | def forward(self, x): 64 | x1 = self.aspp1(x) 65 | x2 = self.aspp2(x) 66 | x3 = self.aspp3(x) 67 | x4 = self.aspp4(x) 68 | x5 = self.global_avg_pool(x) 69 | # if type(x4.size()[2]) != int: 70 | # tmpsize = (x4.size()[2].item(),x4.size()[3].item()) 71 | # else: 72 | # tmpsize = (x4.size()[2],x4.size()[3]) 73 | # x5 = F.interpolate(x5, size=(14,14), mode='bilinear', align_corners=True) 74 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 75 | 76 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 77 | 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.relu(x) 81 | 82 | return self.dropout(x) 83 | 84 | def _init_weight(self,pretrained): 85 | for m in self.modules(): 86 | if pretrained: 87 | break 88 | else: 89 | if isinstance(m, nn.Conv2d): 90 | # n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 91 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 92 | torch.nn.init.kaiming_normal_(m.weight) 93 | elif isinstance(m, SynchronizedBatchNorm2d): 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | elif isinstance(m, nn.BatchNorm2d): 97 | m.weight.data.fill_(1) 98 | m.bias.data.zero_() 99 | 100 | 101 | def build_aspp(backbone, output_stride, BatchNorm,pretrained): 102 | return ASPP(backbone, output_stride, BatchNorm, pretrained) -------------------------------------------------------------------------------- /model_ECCV2020/networks/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ..networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm, pretrained): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight(pretrained) 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self,pretrained): 24 | for m in self.modules(): 25 | if pretrained: 26 | break 27 | else: 28 | 29 | if isinstance(m, nn.Conv2d): 30 | torch.nn.init.kaiming_normal_(m.weight) 31 | elif isinstance(m, SynchronizedBatchNorm2d): 32 | m.weight.data.fill_(1) 33 | m.bias.data.zero_() 34 | elif isinstance(m, nn.BatchNorm2d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | 38 | class ASPP(nn.Module): 39 | def __init__(self, backbone, output_stride, BatchNorm, pretrained): 40 | super(ASPP, self).__init__() 41 | if backbone == 'drn': 42 | inplanes = 512 43 | elif 
backbone == 'mobilenet': 44 | inplanes = 320 45 | else: 46 | inplanes = 2048 47 | if output_stride == 16: 48 | dilations = [1, 6, 12, 18] 49 | elif output_stride == 8: 50 | dilations = [1, 12, 24, 36] 51 | else: 52 | raise NotImplementedError 53 | 54 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm, pretrained=pretrained) 55 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm, pretrained=pretrained) 56 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm, pretrained=pretrained) 57 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm, pretrained=pretrained) 58 | 59 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 60 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 61 | BatchNorm(256), 62 | nn.ReLU()) 63 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 64 | self.bn1 = BatchNorm(256) 65 | self.relu = nn.ReLU() 66 | self.dropout = nn.Dropout(0.5) 67 | self._init_weight(pretrained) 68 | 69 | def forward(self, x): 70 | x1 = self.aspp1(x) 71 | x2 = self.aspp2(x) 72 | x3 = self.aspp3(x) 73 | x4 = self.aspp4(x) 74 | x5 = self.global_avg_pool(x) 75 | # if type(x4.size()[2]) != int: 76 | # tmpsize = (x4.size()[2].item(),x4.size()[3].item()) 77 | # else: 78 | # tmpsize = (x4.size()[2],x4.size()[3]) 79 | # x5 = F.interpolate(x5, size=(14,14), mode='bilinear', align_corners=True) 80 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 81 | 82 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 83 | 84 | x = self.conv1(x) 85 | x = self.bn1(x) 86 | x = self.relu(x) 87 | 88 | return self.dropout(x) 89 | 90 | def _init_weight(self,pretrained): 91 | for m in self.modules(): 92 | if pretrained: 93 | break 94 | else: 95 | if isinstance(m, nn.Conv2d): 96 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | # 
m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | torch.nn.init.kaiming_normal_(m.weight) 99 | elif isinstance(m, SynchronizedBatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | 107 | def build_aspp(backbone, output_stride, BatchNorm,pretrained): 108 | return ASPP(backbone, output_stride, BatchNorm, pretrained) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | # GUI for IVOS (Interactive VOS) and GIS (Guided IVOS) 3 | ![explain_qwerty](etc/png_demo.png) 4 | GUI implementation of 5 | 6 | CVPR2021 paper "Guided Interactive Video Object Segmentation Using Reliability-Based Attention Maps" 7 | 8 | ECCV2020 paper "Interactive Video Object Segmentation Using Global and Local Transfer Modules" 9 | 10 | GitHub repositories: 11 | [CVPR2021](https://github.com/yuk6heo/GIS-RAmap) / 12 | [ECCV2020](https://github.com/yuk6heo/IVOS-ATNet) 13 | 14 | Project Pages: 15 | [CVPR2021](http://mcl.korea.ac.kr/yukheo_cvpr2021/) / 16 | [ECCV2020](http://mcl.korea.ac.kr/yukheo_eccv2020/) 17 | 18 | Code in this repository: 19 | 20 | 1. Real-world GUI evaluation on DAVIS2017 based on the [DAVIS framework](https://interactive.davischallenge.org/) 21 | 2. GUI for other videos 22 | 23 | ## Prerequisites 24 | - cuda 11.0 25 | - python 3.6 26 | - pytorch 1.6.0 27 | - [davisinteractive 1.0.4](https://github.com/albertomontesg/davis-interactive) 28 | - numpy, cv2, PyQt5, and other general Python 3 libraries 29 | 30 | ## Directory Structure 31 | 32 | * `root/apps`: QWidget apps. 33 | 34 | * `root/checkpoints`: save our checkpoints (pth extensions) here. 35 | 36 | * `root/dataset_torch`: pytorch datasets. 37 | 38 | * `root/libs`: library of utility files. 
39 | 40 | * `root/model_CVPR2021` : networks and GUI models for CVPR2021 41 | - detailed explanations on [[Github:CVPR2021]](https://github.com/yuk6heo/GIS-RAmap) 42 | * `root/model_ECCV2020` : networks and GUI models for ECCV2020 43 | - detailed explanations (building correlation package) on [[Github:ECCV2020]](https://github.com/yuk6heo/IVOS-ATNet) 44 | 45 | * `root/eval_GIS_RS1.py` : DAVIS2017 evaluation based on the [DAVIS framework](https://interactive.davischallenge.org/). 46 | * `root/eval_GIS_RS4.py` : DAVIS2017 evaluation based on the [DAVIS framework](https://interactive.davischallenge.org/). 47 | * `root/eval_IVOS.py` : DAVIS2017 evaluation based on the [DAVIS framework](https://interactive.davischallenge.org/). 48 | * `root/IVOS_demo_customvideo.py` : GUI for custom videos 49 | 50 | ## Instructions 51 | 52 | ### To run 53 | 1. Edit `eval_GIS_RS1.py`, `eval_GIS_RS4.py`, `eval_IVOS.py`, and `IVOS_demo_customvideo.py` to set the directory of your DAVIS2017 dataset and other configurations. 54 | 2. Download our parameters and place the file as `root/checkpoints/GIS-ckpt_standard.pth`. 55 | - For CVPR2021 evaluation [[Google-Drive]](https://drive.google.com/file/d/1dkgXJJ2gPYDtPE9yTtlP4Th0iNX5ZG6a/view?usp=sharing) 56 | - For ECCV2020 evaluation [[Google-Drive]](https://drive.google.com/file/d/1t1VO2zy3pLBXCWqme9h63Def86Y4ECIH/view?usp=sharing) 57 | 3. Run `eval_GIS_RS1.py`, `eval_GIS_RS4.py`, or `eval_IVOS.py` for real-world GUI evaluation on DAVIS2017, or 58 | 4. Run `IVOS_demo_customvideo.py` to apply our method to other videos 59 | 60 | ### To use 61 | ![explain_qwerty](etc/explain_qwerty.png) 62 | 63 | Left click for the target object and right click for the background. 64 | 1. Select any frame to interact with by dragging the slider under the main image 65 | 2. Give interaction 66 | 3. Run VOS 67 | 4. Find the worst frame and re-interact. - For GIS, a candidate frame (RS1) or candidate frames (RS4) are given 68 | 5. Iterate until you are satisfied with the VOS results. 69 | 6. 
By selecting satisfied button, your evaluation result (consumed time and frames) will be recorded on `root/results`. 70 | 71 | ## Reference 72 | 73 | Please cite our paper if the implementations are useful in your work: 74 | ``` 75 | @Inproceedings{ 76 | Yuk2021GIS, 77 | title={Guided Interactive Video Object Segmentation Using Reliability-Based Attention Maps}, 78 | author={Yuk Heo and Yeong Jun Koh and Chang-Su Kim}, 79 | booktitle={CVPR}, 80 | year={2021}, 81 | url={https://openaccess.thecvf.com/content/CVPR2021/papers/Heo_Guided_Interactive_Video_Object_Segmentation_Using_Reliability-Based_Attention_Maps_CVPR_2021_paper.pdf} 82 | } 83 | ``` 84 | 85 | ``` 86 | @Inproceedings{ 87 | Yuk2020IVOS, 88 | title={Interactive Video Object Segmentation Using Global and Local Transfer Modules}, 89 | author={Yuk Heo and Yeong Jun Koh and Chang-Su Kim}, 90 | booktitle={ECCV}, 91 | year={2020}, 92 | url={https://openreview.net/forum?id=bo_lWt_aA} 93 | } 94 | ``` 95 | 96 | 97 | Our real-world evaluation demo is based on the GUI of [IPNet](https://github.com/seoungwugoh/ivs-demo): 98 | ``` 99 | @Inproceedings{ 100 | Oh2019IVOS, 101 | title={Fast User-Guided Video Object Segmentation by Interaction-and-Propagation Networks}, 102 | author={Seoung Wug Oh and Joon-Young Lee and Seon Joo Kim}, 103 | booktitle={CVPR}, 104 | year={2019}, 105 | url={https://openaccess.thecvf.com/content_ICCV_2019/papers/Oh_Video_Object_Segmentation_Using_Space-Time_Memory_Networks_ICCV_2019_paper.pdf} 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
class FutureResult(object):
    """A thread-safe, single-slot future. Used only as one-to-one pipe.

    One thread publishes a value with :meth:`put`, another consumes it with
    :meth:`get`.  ``get`` empties the slot, so the same object can be reused
    for the next exchange.
    """

    def __init__(self):
        self._result = None
        # Slot state is tracked separately from the payload so that a wake-up
        # can be re-checked and ``None`` is a legal payload.
        self._ready = False
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake up a waiting ``get``.

        Raises:
            AssertionError: if a previously stored (non-``None``) result has
                not been fetched yet.
        """
        with self._cond:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._ready = True
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return it and empty the slot."""
        with self._cond:
            # Condition.wait() may return without a matching put() (spurious
            # wake-up or an unrelated notify), so the predicate is re-checked
            # in a loop instead of being tested once.
            while not self._ready:
                self._cond.wait()

            res = self._result
            self._result = None
            self._ready = False
            return res
64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 
107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /model_ECCV2020/networks/deeplab/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
def fixed_padding(inputs, kernel_size, dilation):
    """Zero-pad the last two dimensions for a dilated convolution.

    The total padding per dimension is the effective kernel extent minus one;
    when that total is odd, the extra pixel goes to the end (right / bottom).
    """
    effective_extent = kernel_size + (kernel_size - 1) * (dilation - 1)
    lead, remainder = divmod(effective_extent - 1, 2)
    trail = lead + remainder
    return F.pad(inputs, (lead, trail, lead, trail))
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor that returns high- and low-level features.

    The network is a stride-2 ``conv_bn`` stem followed by the stacks of
    ``InvertedResidual`` blocks listed in ``interverted_residual_setting``.
    Once the accumulated stride reaches ``output_stride``, the remaining
    stages keep stride 1 and use dilation instead of further downsampling.

    Args:
        output_stride: accumulated stride at which downsampling stops.
        BatchNorm: normalisation layer factory, called as ``BatchNorm(channels)``
            by ``conv_bn`` and ``InvertedResidual``.  The default ``None`` is
            not callable, so callers must pass one.
        width_mult: multiplier applied to every channel count.
        pretrained: when true, weights are fetched and loaded by
            ``_load_pretrained_model``.
    """

    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        current_stride = 1
        rate = 1
        # One row per stage:
        #   t = expansion ratio, c = output channels (before width_mult),
        #   n = number of blocks in the stage, s = stride of the stage's first block
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        # Temporarily a plain Python list; wrapped in nn.Sequential below.
        self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
        current_stride *= 2
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            if current_stride == output_stride:
                # Target stride reached: stop downsampling and fold the
                # stage's nominal stride into the dilation rate instead.
                stride = 1
                dilation = rate
                rate *= s
            else:
                stride = s
                dilation = 1
                current_stride *= s
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage may change the resolution.
                if i == 0:
                    self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
                else:
                    self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
                input_channel = output_channel
        self.features = nn.Sequential(*self.features)
        self._initialize_weights()

        if pretrained:
            self._load_pretrained_model()

        # The two slices are created after the (optional) checkpoint load, so
        # at load time the state dict only contains ``features.*`` keys.
        # features[0:4] = stem + the single c=16 block + the two c=24 blocks.
        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        """Return ``(high_level_output, low_level_feat)`` for input ``x``."""
        low_level_feat = self.low_level_features(x)
        x = self.high_level_features(low_level_feat)
        return x, low_level_feat

    def _load_pretrained_model(self):
        """Download the checkpoint and copy every entry whose key exists here.

        Entries whose key is not in this model's state dict are ignored;
        keys are matched by name only.
        """
        pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        """Kaiming-normal init for convolutions; weight=1, bias=0 for batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block with explicit input padding.

    The depthwise 3x3 convolution is built with ``padding=0``; ``forward``
    pads the input with ``fixed_padding`` according to the dilation before
    running the layer stack.  A skip connection is added when the block keeps
    both resolution (``stride == 1``) and channel count (``inp == oup``).

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution, 1 or 2.
        dilation: dilation of the depthwise convolution.
        expand_ratio: hidden width is ``round(inp * expand_ratio)``; a ratio
            of 1 omits the leading pointwise expansion.
        BatchNorm: normalisation layer factory, called as ``BatchNorm(channels)``.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        expanded = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.kernel_size = 3
        self.dilation = dilation

        stages = []
        if expand_ratio != 1:
            # pw: pointwise expansion to the hidden width
            stages += [
                nn.Conv2d(inp, expanded, 1, 1, 0, 1, bias=False),
                BatchNorm(expanded),
                nn.ReLU6(inplace=True),
            ]
        stages += [
            # dw: depthwise 3x3, unpadded (padding is applied in forward)
            nn.Conv2d(expanded, expanded, 3, stride, 0, dilation, groups=expanded, bias=False),
            BatchNorm(expanded),
            nn.ReLU6(inplace=True),
            # pw-linear: pointwise projection, no activation
            nn.Conv2d(expanded, oup, 1, 1, 0, 1, bias=False),
            BatchNorm(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Pad ``x``, run the block, and add the skip connection if enabled."""
        branch = self.conv(fixed_padding(x, self.kernel_size, dilation=self.dilation))
        if not self.use_res_connect:
            return branch
        return x + branch
current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, 
def mkdir(paths):
    """Create one or more directories, including any missing parents.

    Args:
        paths: a single path, or a list/tuple of paths.

    Directories that already exist are left untouched.  ``exist_ok=True`` is
    used instead of an ``isdir`` check followed by ``makedirs`` so that a
    directory created by another process between the check and the call does
    not raise.

    Raises:
        OSError: if a path cannot be created, e.g. it exists as a regular file.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        os.makedirs(path, exist_ok=True)
def get_prop_list(annotated_frames, annotated_now, num_frames, proportion = 1.0, get_close_anno_frames = False):
    """Build the list of frames to propagate to from the frame annotated now.

    Propagation runs from ``annotated_now`` towards the nearest other
    annotated frame on each side, or to the first / last frame of the video
    when there is none.

    Args:
        annotated_frames: frame indices annotated so far, in annotation order
            (may contain repeats and ``annotated_now`` itself).  The list is
            not modified.
        annotated_now: index of the frame annotated in this round.
        num_frames: total number of frames in the video.
        proportion: fraction of the gap towards a neighbouring annotated
            frame that is covered (rounded half up).  Sides bounded by the
            video start/end are not shortened.
        get_close_anno_frames: when true, also return the positions of the
            neighbouring annotated frames within ``annotated_frames``.

    Returns:
        ``prop_list``: ``annotated_now`` down to the left bound, followed by
        ``annotated_now`` up to the right bound, so ``annotated_now`` appears
        twice.  When ``get_close_anno_frames`` is true, a tuple
        ``(prop_list, close_frames_round)`` where ``close_frames_round`` maps
        ``"left"`` / ``"right"`` to the index of the last occurrence of the
        looked-up frame in ``annotated_frames``; a key is absent when the
        lookup fails.
    """
    # Other annotated frames in ascending order, with every copy of
    # annotated_now dropped.
    others = sorted(f for f in annotated_frames if f != annotated_now)

    start_frame, end_frame = 0, num_frames - 1
    for frame in others:
        if frame > annotated_now:
            end_frame = frame - 1
            break
    for frame in reversed(others):
        if frame < annotated_now:
            start_frame = frame + 1
            break

    if get_close_anno_frames:
        close_frames_round = dict()  # side -> interaction index of the close frame
        # Search a reversed *copy*: the caller's list used to be reversed in
        # place here, flipping its order on every call.
        newest_first = list(annotated_frames)[::-1]
        total = len(newest_first)
        try:
            close_frames_round["left"] = total - newest_first.index(start_frame - 1) - 1
        except ValueError:
            print('No left annotated fr')
        # NOTE(review): the left lookup uses the annotated frame itself
        # (start_frame - 1), but this one uses end_frame rather than
        # end_frame + 1, so it looks like it can rarely succeed -- confirm
        # the intent with the callers before changing it.
        try:
            close_frames_round["right"] = total - newest_first.index(end_frame) - 1
        except ValueError:
            print('No right annotated fr')

    if proportion != 1.0:
        if start_frame != 0:
            start_frame = annotated_now - int((annotated_now - start_frame) * proportion + 0.5)
        if end_frame != num_frames - 1:
            end_frame = annotated_now + int((end_frame - annotated_now) * proportion + 0.5)

    prop_list = list(range(annotated_now, start_frame - 1, -1)) + list(range(annotated_now, end_frame + 1))
    if not prop_list:
        prop_list = [annotated_now]

    if get_close_anno_frames:
        return prop_list, close_frames_round
    return prop_list
def load_gts(path, size=None, num_frames=None):
    """Load the ``*.png`` masks found in ``path`` into one stacked uint8 array.

    Files are read in sorted filename order.

    Args:
        path: directory containing the mask images.
        size: optional pair; when given, each mask is resized to
            ``(size[0], size[1])`` and stored as uint8 without binarisation.
            When omitted, each mask is binarised (non-zero -> 1).
        num_frames: optional cutoff on the number of files read (see the
            note inside the loop).

    Returns:
        ``np.stack`` of the per-file arrays along a new leading axis.

    Raises:
        ValueError: from ``np.stack`` when no ``*.png`` file is found.
    """
    fnames = glob.glob(os.path.join(path, '*.png'))
    fnames.sort()
    frame_list = []
    for i, fname in enumerate(fnames):
        if size:
            # NOTE(review): bicubic resampling interpolates label values;
            # nearest-neighbour is the usual choice for masks -- confirm.
            frame_list.append(np.array(Image.open(fname).resize((size[0], size[1]), Image.BICUBIC), dtype=np.uint8))
        else:
            # Builtin ``bool`` instead of ``np.bool``: the alias was removed
            # in NumPy 1.24 and raises AttributeError there.
            frame_list.append(np.array(Image.open(fname)).astype(bool).astype(np.uint8))
        # NOTE(review): the check runs after the append and uses ``>``, so up
        # to num_frames + 2 files are loaded.  The sibling loaders in this
        # file apply the same rule, so it is left unchanged here to keep
        # their frame counts aligned.
        if num_frames and i > num_frames:
            break
    segs = np.stack(frame_list, axis=0)
    return segs
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, (dilated) 3x3, 1x1 expand, plus shortcut.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        dilation: dilation of the 3x3 convolution (also used as its padding).
        downsample: optional module applied to the input to form the shortcut;
            the input itself is used when it is ``None``.
        BatchNorm: normalisation layer factory, called as ``BatchNorm(channels)``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(Bottleneck, self).__init__()
        widened = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)

        # padding == dilation keeps the spatial size for a 3x3 kernel at stride 1
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)

        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """Dilated ResNet backbone returning high- and low-level feature maps.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, dilation, downsample,
            BatchNorm)``.
        layers: number of blocks in layer1..layer3 (``layers[0..2]``).
            ``layers[3]`` is not used: layer4 is built by ``_make_MG_unit``
            with the fixed multi-grid ``[1, 2, 4]``.
        output_stride: 16 or 8; selects the per-layer strides and dilations.
        BatchNorm: normalisation layer factory, called as ``BatchNorm(channels)``.
        pretrained: when true, weights are fetched and loaded by
            ``_load_pretrained_model``.
        modelname: one of ``'res101'``, ``'res50'``, ``'SEres50'``; selects
            the checkpoint URL.

    Raises:
        NotImplementedError: for any other ``output_stride``, or (when
            ``pretrained``) any other ``modelname``.
    """

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, modelname = 'res101'):
        self.inplanes = 64
        self.modelname = modelname
        super(ResNet, self).__init__()
        # Multi-grid dilation multipliers for the three blocks of layer4.
        blocks = [1, 2, 4]
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()
        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Stack ``blocks`` residual blocks that all share one dilation.

        Only the first block receives ``stride`` and the projection shortcut;
        ``self.inplanes`` is updated to the layer's output width.
        """
        downsample = None
        # A projection shortcut is needed whenever the block changes the
        # resolution or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Stack ``len(blocks)`` residual blocks with per-block dilation.

        Here ``blocks`` is a list of multipliers: block ``i`` uses dilation
        ``blocks[i] * dilation``.  Only the first block receives ``stride``
        and the projection shortcut.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
                            downsample=downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(layer4_output, layer1_output)`` for ``input``."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # The trailing numbers below are the original author's shape notes
        # (channels, height, width) for one particular configuration.
        x = self.layer1(x) #256 128 128
        low_level_feat = x
        x = self.layer2(x) #512 64 64
        x = self.layer3(x) #1024 32 32
        x = self.layer4(x) #2048 32 32
        return x, low_level_feat

    def _init_weight(self):
        """Normal init with std ``sqrt(2 / (kh * kw * out_channels))`` for
        convolutions; weight=1, bias=0 for batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        """Download the checkpoint for ``self.modelname`` and copy every entry
        whose key exists in this model.

        ``'SEres50'`` loads the same ResNet-50 checkpoint as ``'res50'``.
        Checkpoint keys that are absent from this model are ignored.
        """
        if self.modelname =='res101':
            pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        elif self.modelname == 'res50':
            pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
        elif self.modelname == 'SEres50':
            pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
        else: raise NotImplementedError
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(SEBottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained, modelname='SEres50') 231 | return model 232 | 233 | if __name__ == "__main__": 234 | import torch 235 | model = ResNet50(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 236 | input = torch.rand(1, 3, 512, 512) 237 | output, low_level_feat = model(input) 238 | print(output.size()) 239 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * self.expansion) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += 
class SEBottleneck(nn.Module):
    """ResNet bottleneck block with a squeeze-and-excitation gate.

    The main branch (1x1 reduce, dilated 3x3, 1x1 expand) is rescaled
    per channel by a sigmoid gate computed from its globally pooled output,
    then added to the shortcut.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        dilation: dilation of the 3x3 convolution (also used as its padding).
        downsample: optional module applied to the input to form the shortcut;
            the input itself is used when it is ``None``.
        BatchNorm: normalisation layer factory, called as ``BatchNorm(channels)``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(SEBottleneck, self).__init__()
        widened = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)

        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(widened)
        self.relu = nn.ReLU(inplace=True)

        # SE: squeeze planes*4 -> planes//4 channels, then excite back.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Conv2d(planes * 4, planes // 4, kernel_size=1, bias=False)
        self.conv_up = nn.Conv2d(planes // 4, planes * 4, kernel_size=1, bias=False)
        self.sig = nn.Sigmoid()

        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Return ``relu(gate * main_branch(x) + shortcut(x))``."""
        feat = self.relu(self.bn1(self.conv1(x)))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.bn3(self.conv3(feat))

        # Per-channel gate in (0, 1) from the globally pooled features.
        gate = self.global_pool(feat)
        gate = self.relu(self.conv_down(gate))
        gate = self.sig(self.conv_up(gate))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(gate * feat + shortcut)
dilations = [1, 1, 1, 2] 109 | elif output_stride == 8: 110 | strides = [1, 2, 1, 1] 111 | dilations = [1, 1, 2, 4] 112 | else: 113 | raise NotImplementedError 114 | 115 | # Modules 116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 117 | self.bn1 = BatchNorm(64) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | 121 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 122 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 124 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 125 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 126 | self._init_weight() 127 | if pretrained: 128 | self._load_pretrained_model() 129 | 130 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 131 | downsample = None 132 | if stride != 1 or self.inplanes != planes * block.expansion: 133 | downsample = nn.Sequential( 134 | nn.Conv2d(self.inplanes, planes * block.expansion, 135 | kernel_size=1, stride=stride, bias=False), 136 | BatchNorm(planes * block.expansion), 137 | ) 138 | 139 | layers = [] 140 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 148 | downsample = None 149 | if stride != 1 or self.inplanes != planes * 
block.expansion: 150 | downsample = nn.Sequential( 151 | nn.Conv2d(self.inplanes, planes * block.expansion, 152 | kernel_size=1, stride=stride, bias=False), 153 | BatchNorm(planes * block.expansion), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 158 | downsample=downsample, BatchNorm=BatchNorm)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, len(blocks)): 161 | layers.append(block(self.inplanes, planes, stride=1, 162 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, input): 167 | x = self.conv1(input) 168 | x = self.bn1(x) 169 | x = self.relu(x) 170 | x = self.maxpool(x) 171 | 172 | x = self.layer1(x) #256 128 128 173 | low_level_feat = x 174 | x = self.layer2(x) #512 64 64 175 | x = self.layer3(x) #1024 32 32 176 | x = self.layer4(x) #2048 32 32 177 | return x, low_level_feat 178 | 179 | def _init_weight(self): 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d): 182 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | m.weight.data.normal_(0, math.sqrt(2. 
def ResNet101(output_stride, BatchNorm, pretrained=True,):
    """Build the ResNet-101 backbone (``Bottleneck`` blocks, layers [3, 4, 23, 3]).

    Args:
        output_stride: forwarded to ``ResNet``.
        BatchNorm: normalisation layer factory forwarded to ``ResNet``.
        pretrained (bool): forwarded to ``ResNet``; when true the model loads
            the checkpoint registered for ``modelname='res101'``.
    """
    return ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        output_stride,
        BatchNorm,
        pretrained=pretrained,
        modelname='res101',
    )
226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(SEBottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained, modelname='SEres50') 231 | return model 232 | 233 | if __name__ == "__main__": 234 | import torch 235 | model = ResNet50(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 236 | input = torch.rand(1, 3, 512, 512) 237 | output, low_level_feat = model(input) 238 | print(output.size()) 239 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /model_CVPR2021/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | # general libs 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | # my libs 9 | # from Algorithm_Heo.networks.network_CerBerusnet_v6 import CBnet 10 | from .networks.network import NET_GAmap 11 | from libs import utils_custom 12 | 13 | # davis 14 | 15 | 16 | class model(): 17 | def __init__(self): 18 | self.net = NET_GAmap() 19 | self.net.cuda() 20 | self.net.load_state_dict(torch.load('checkpoints/GIS-ckpt_standard.pth')) 21 | self.net.eval() 22 | for param in self.net.parameters(): param.requires_grad = False 23 | 24 | self.mean, self.var = torch.Tensor([0.485, 0.456, 0.406]), torch.Tensor([0.229, 0.224, 0.225]) 25 | self.mean, self.var = self.mean.view(1,1,3,1,1).cuda(), self.var.view(1,1,3,1,1).cuda() 26 | 27 | 28 | def init_with_new_video(self, frames, n_obj=1): 29 | self.n_objects = n_obj 30 | self.frames = frames.copy() # f h w 3 31 | self.num_frames, self.height, self.width = self.frames.shape[:3] 32 | self.scores_nf = np.zeros([self.num_frames]) 33 | self.annotated_frames=[] 34 | 35 | pad_info = utils_custom.apply_pad(self.frames[0])[1] 36 | self.hpad1, self.wpad1 = pad_info[0][0], pad_info[1][0] 37 | self.hpad2, self.wpad2 = pad_info[0][1], pad_info[1][1] 38 | self.padding = 
pad_info[1] + pad_info[0] 39 | self.prob_map_of_frames = torch.zeros((self.num_frames, self.n_objects+1, self.height + sum(pad_info[0]), self.width + sum(pad_info[1]))).requires_grad_(False).cuda() # f,1,p_h,p_w {cudatensor} 40 | 41 | self.all_F = torch.unsqueeze(torch.nn.ReflectionPad2d(self.padding)( 42 | torch.from_numpy(np.transpose(frames, (0,3,1,2))).float() / 255.), dim=1).requires_grad_(False).cuda() # fr,1,3,p_h,p_w {cudatensor} 43 | self.all_F = (self.all_F -self.mean) / self.var 44 | self.current_round_masks = np.zeros([self.num_frames, self.height, self.width]) # f,h,w {numpy} 45 | self.anno_6chEnc_r4_list = [] 46 | self.anno_3chEnc_r4_list = [] 47 | self.current_round = 1 48 | 49 | self.one_hot_outputs_anno = None 50 | self.current_round = 1 51 | 52 | def Run_propagation(self, annotated_now, ): 53 | 54 | print('[Propagation running...]') 55 | self.current_round +=1 56 | prop_list = utils_custom.get_prop_list(self.annotated_frames, annotated_now, self.num_frames, proportion=0.99) 57 | annotated_frames_np = np.array(self.annotated_frames) 58 | prop_fore = sorted(prop_list)[0] 59 | prop_rear = sorted(prop_list)[-1] 60 | 61 | flag = 0 # 1: propagating backward, 2: propagating forward 62 | for operating_frame in prop_list: 63 | if operating_frame == annotated_now: 64 | if flag == 0: 65 | flag += 1 66 | adjacent_to_anno = True 67 | continue 68 | elif flag == 1: 69 | flag += 1 70 | adjacent_to_anno = True 71 | continue 72 | else: 73 | raise NotImplementedError 74 | else: 75 | print('operating in : {:03d}'.format(operating_frame)) 76 | if adjacent_to_anno: 77 | r4_neighbor = self.r4_anno 78 | neighbor_pred_onehot = self.anno_onehot_prob 79 | adjacent_to_anno = False 80 | else: 81 | r4_neighbor = r4_que 82 | neighbor_pred_onehot = targ_onehot_prob 83 | 84 | output_logit, r4_que, score = self.net.forward_prop( 85 | self.anno_3chEnc_r4_list, self.all_F[operating_frame].repeat(self.n_objects,1,1,1), self.anno_6chEnc_r4_list, 86 | r4_neighbor, neighbor_pred_onehot, 
87 | anno_fr_list= annotated_frames_np, que_fr= operating_frame) # [nobj, 1, P_H, P_W] 88 | 89 | output_prob_tmp = F.softmax(output_logit, dim=1) # [nobj, 2, P_H, P_W] 90 | output_prob_tmp = output_prob_tmp[:, 1] # [nobj, P_H, P_W] 91 | one_hot_outputs_t = F.softmax(self.soft_aggregation(output_prob_tmp), dim=0) # [nobj+1, P_H, P_W] 92 | 93 | 94 | smallest_alpha = 0.5 95 | if flag==1: 96 | sorted_frames = annotated_frames_np[annotated_frames_np < annotated_now] 97 | if len(sorted_frames) ==0: 98 | alpha = 1 99 | else: 100 | closest_addianno_frame = np.max(sorted_frames) 101 | alpha = smallest_alpha+(1-smallest_alpha)*((operating_frame-closest_addianno_frame)/(annotated_now - closest_addianno_frame)) 102 | else: 103 | sorted_frames = annotated_frames_np[annotated_frames_np > annotated_now] 104 | if len(sorted_frames) == 0: 105 | alpha = 1 106 | else: 107 | closest_addianno_frame = np.min(sorted_frames) 108 | alpha = smallest_alpha+(1-smallest_alpha)*((closest_addianno_frame - operating_frame) / (closest_addianno_frame - annotated_now)) 109 | 110 | 111 | one_hot_outputs_t = (alpha * one_hot_outputs_t) + ((1 - alpha) * self.prob_map_of_frames[operating_frame]) 112 | self.prob_map_of_frames[operating_frame] = one_hot_outputs_t 113 | targ_onehot_prob = one_hot_outputs_t.clone()[1:].unsqueeze(1) # [nobj, 1, P_H, P_W] 114 | 115 | self.scores_nf[operating_frame] = score 116 | 117 | self.current_round_masks = torch.argmax(self.prob_map_of_frames,dim=1).cpu().numpy().astype(np.uint8)[:,self.hpad1:-self.hpad2, self.wpad1:-self.wpad2] 118 | 119 | print('[Propagation process is done.]') 120 | 121 | 122 | def Run_interaction(self, scribbles): 123 | 124 | print('[Interaction running...]') 125 | annotated_now = scribbles['annotated_frame'] 126 | scribbles_list = scribbles['scribbles'] 127 | 128 | pm_ps_ns_3ch_t=[] # n_obj,3,h,w 129 | if self.current_round == 1: 130 | for obj_id in range(1, self.n_objects + 1): 131 | pos_scrimg, neg_scrimg = 
utils_custom.scribble_to_image(scribbles_list, annotated_now, obj_id, 132 | prev_mask=self.current_round_masks[annotated_now], blur=True, 133 | singleimg=False, seperate_pos_neg=True) 134 | pm_ps_ns_3ch_t.append(np.stack([np.ones_like(pos_scrimg)/2, pos_scrimg, neg_scrimg], axis=0)) 135 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 136 | 137 | else: 138 | for obj_id in range(1, self.n_objects + 1): 139 | prev_round_input = (self.current_round_masks[annotated_now] == obj_id).astype(np.float32) # H,W 140 | pos_scrimg, neg_scrimg = utils_custom.scribble_to_image(scribbles_list, annotated_now, obj_id, 141 | prev_mask=self.current_round_masks[annotated_now], blur=True, 142 | singleimg=False, seperate_pos_neg=True) 143 | pm_ps_ns_3ch_t.append(np.stack([prev_round_input, pos_scrimg, neg_scrimg], axis=0)) 144 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 145 | 146 | batched_F = self.all_F[annotated_now].repeat(self.n_objects,1,1,1) 147 | 148 | pm_ps_ns_3ch_t = torch.from_numpy(pm_ps_ns_3ch_t).cuda() 149 | pm_ps_ns_3ch_t = torch.nn.ReflectionPad2d(self.padding)(pm_ps_ns_3ch_t) 150 | inputs = torch.cat([batched_F, pm_ps_ns_3ch_t], dim=1) 151 | 152 | anno_3chEnc_r4, r2_prev_fromanno = self.net.encoder_3ch.forward(batched_F) 153 | neighbor_pred_onehot_sal, anno_6chEnc_r4 = self.net.forward_obj_feature_extractor(inputs) # [nobj, 1, P_H, P_W], # [n_obj,2048,h/16,w/16] 154 | 155 | output_logit, self.r4_anno, score = self.net.forward_prop( 156 | [anno_3chEnc_r4], batched_F, [anno_6chEnc_r4], 157 | anno_3chEnc_r4, torch.sigmoid(neighbor_pred_onehot_sal)) # [nobj, 1, P_H, P_W] 158 | 159 | output_prob_tmp = F.softmax(output_logit, dim=1) # [nobj, 2, P_H, P_W] 160 | output_prob_tmp = output_prob_tmp[:, 1] # [nobj, P_H, P_W] 161 | one_hot_outputs_t = F.softmax(self.soft_aggregation(output_prob_tmp), dim=0) # [nobj+1, P_H, P_W] 162 | 163 | self.anno_onehot_prob = one_hot_outputs_t.clone()[1:].unsqueeze(1) # [nobj, 1, P_H, P_W] 164 | 
self.prob_map_of_frames[annotated_now] = one_hot_outputs_t 165 | self.current_round_masks[annotated_now] = \ 166 | torch.argmax(self.prob_map_of_frames[annotated_now],dim=0).cpu().numpy().astype(np.uint8)[self.hpad1:-self.hpad2, self.wpad1:-self.wpad2] 167 | self.scores_nf[annotated_now] = score 168 | 169 | 170 | if len(self.anno_6chEnc_r4_list) < self.current_round: 171 | self.anno_6chEnc_r4_list.append(anno_6chEnc_r4) 172 | self.anno_3chEnc_r4_list.append(anno_3chEnc_r4) 173 | self.annotated_frames.append(annotated_now) 174 | elif len(self.anno_6chEnc_r4_list) == self.current_round: 175 | self.anno_6chEnc_r4_list[self.current_round-1] = anno_6chEnc_r4 176 | self.anno_3chEnc_r4_list[self.current_round-1] = anno_3chEnc_r4 177 | else: 178 | raise NotImplementedError 179 | 180 | 181 | 182 | 183 | print('[Interaction process is done.]') 184 | 185 | def Get_mask(self): 186 | return self.current_round_masks 187 | 188 | 189 | def Get_mask_index(self, index): 190 | return self.current_round_masks[index] 191 | 192 | 193 | def soft_aggregation(self, ps): 194 | num_objects, H, W = ps.shape 195 | em = torch.zeros(num_objects +1, H, W).cuda() 196 | em[0] = torch.prod(1-ps, dim=0) # bg prob 197 | em[1:num_objects+1] = ps # obj prob 198 | em = torch.clamp(em, 1e-7, 1-1e-7) 199 | logit = torch.log((em /(1-em))) 200 | return logit 201 | -------------------------------------------------------------------------------- /model_ECCV2020/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | import torch.nn as nn 5 | 6 | # general libs 7 | from PIL import Image 8 | import numpy as np 9 | from torchvision import transforms 10 | 11 | # my libs 12 | # from Algorithm_Heo.networks.network_CerBerusnet_v6 import CBnet 13 | from model_ECCV2020.networks.network import ATnet 14 | from libs import utils_custom, utils_torch, helpers 15 | 16 | 17 | class model(): 18 | def __init__(self, 19 | 
iact_threshold = 0.8, 20 | prop_threshold = 0.8): 21 | self.net = ATnet() 22 | self.net.cuda() 23 | # self.net.load_state_dict(torch.load('/data/0codes/VOS_Scribble_interactive/Interactive_VOS_2020/results/train_result/CBNetv8_prop_videolearn_dualloss_SA50_20191109_145805/PTHbestIoU_0013.pth')) 24 | self.net.load_state_dict(torch.load('/home/yuk/codes_ssd/Interactive_VOS_2020/results/train_results_optimized/ATNet_v00_fromCVPR2020_20200214-110610_Frominital/e_0046_onlyparam.pth')) 25 | self.net.eval() 26 | for param in self.net.parameters(): param.requires_grad = False 27 | 28 | self.iact_threshold = iact_threshold 29 | self.prop_threshold = prop_threshold 30 | 31 | self.mean, self.var = torch.Tensor([0.485, 0.456, 0.406]), torch.Tensor([0.229, 0.224, 0.225]) 32 | self.mean, self.var = self.mean.view(1,1,3,1,1).cuda(), self.var.view(1,1,3,1,1).cuda() 33 | 34 | 35 | def init_with_new_video(self, frames, n_obj=1): 36 | self.n_objects = n_obj 37 | self.frames = frames.copy() # f h w 3 38 | self.num_frames, self.height, self.width = self.frames.shape[:3] 39 | self.scores_nf = np.zeros([self.num_frames]) 40 | self.annotated_frames=[] 41 | 42 | pad_info = utils_custom.apply_pad(self.frames[0])[1] 43 | self.hpad1, self.wpad1 = pad_info[0][0], pad_info[1][0] 44 | self.hpad2, self.wpad2 = pad_info[0][1], pad_info[1][1] 45 | self.padding = pad_info[1] + pad_info[0] 46 | self.prob_map_of_frames = torch.zeros((self.num_frames, self.n_objects, self.height + sum(pad_info[0]), self.width + sum(pad_info[1]))).requires_grad_(False).cuda() # f,1,p_h,p_w {cudatensor} 47 | 48 | self.all_F = torch.unsqueeze(torch.nn.ReflectionPad2d(self.padding)( 49 | torch.from_numpy(np.transpose(frames, (0,3,1,2))).float() / 255.), dim=1).requires_grad_(False).cuda() # fr,1,3,p_h,p_w {cudatensor} 50 | self.all_F = (self.all_F -self.mean) / self.var 51 | 52 | self.current_round_masks = np.zeros([self.num_frames, self.height, self.width]) # f,h,w {numpy} 53 | 54 | self.anno_6chEnc_r5_list = [] 55 | 
self.anno_3chEnc_r5_list = [] 56 | 57 | self.one_hot_outputs_anno = None 58 | self.r2_fromanno = None 59 | self.current_round = 1 60 | 61 | 62 | 63 | def Run_propagation(self, annotated_now, ): 64 | 65 | print('[T-Net running...]') 66 | self.current_round +=1 67 | prop_list = utils_custom.get_prop_list(self.annotated_frames, annotated_now, self.num_frames, proportion=0.99) 68 | annotated_frames_np = np.array(self.annotated_frames) 69 | prop_fore = sorted(prop_list)[0] 70 | prop_rear = sorted(prop_list)[-1] 71 | 72 | flag = 0 # 1: propagating backward, 2: propagating forward 73 | for operating_frame in prop_list: 74 | if operating_frame == annotated_now: 75 | if flag == 0: 76 | flag += 1 77 | adjacent_to_anno = True 78 | elif flag == 1: 79 | flag += 1 80 | adjacent_to_anno = True 81 | continue 82 | else: 83 | raise NotImplementedError 84 | else: 85 | print('operating in : {:03d}'.format(operating_frame)) 86 | if adjacent_to_anno: 87 | r2_prev = self.r2_fromanno 88 | predmask_prev = self.one_hot_outputs_anno 89 | adjacent_to_anno = False 90 | 91 | output_prob, r2_prev = self.net.forward_prop( 92 | self.anno_3chEnc_r5_list, self.all_F[operating_frame].repeat(self.n_objects,1,1,1), self.anno_6chEnc_r5_list, r2_prev, predmask_prev) # [nobj, 1, P_H, P_W] 93 | 94 | predmask_prev = torch.sigmoid(output_prob) # [nobj, 1, P_H, P_W] 95 | one_hot_outputs_t = predmask_prev[:, 0].detach() 96 | 97 | 98 | smallest_alpha = 0.5 99 | if flag==1: 100 | sorted_frames = annotated_frames_np[annotated_frames_np < annotated_now] 101 | if len(sorted_frames) ==0: 102 | alpha = 1 103 | else: 104 | closest_addianno_frame = np.max(sorted_frames) 105 | alpha = smallest_alpha+(1-smallest_alpha)*((operating_frame-closest_addianno_frame)/(annotated_now - closest_addianno_frame)) 106 | else: 107 | sorted_frames = annotated_frames_np[annotated_frames_np > annotated_now] 108 | if len(sorted_frames) == 0: 109 | alpha = 1 110 | else: 111 | closest_addianno_frame = np.min(sorted_frames) 112 | alpha = 
smallest_alpha+(1-smallest_alpha)*((closest_addianno_frame - operating_frame) / (closest_addianno_frame - annotated_now)) 113 | 114 | 115 | one_hot_outputs_t = (alpha * one_hot_outputs_t) + ((1 - alpha) * self.prob_map_of_frames[operating_frame]) 116 | self.prob_map_of_frames[operating_frame] = one_hot_outputs_t 117 | 118 | self.current_round_masks[prop_fore:prop_rear + 1] = \ 119 | utils_torch.combine_masks_with_batch(self.prob_map_of_frames[prop_fore:prop_rear + 1], 120 | n_obj=self.n_objects, 121 | th=self.prop_threshold 122 | )[:, 0, self.hpad1:-self.hpad2, self.wpad1:-self.wpad2].cpu().numpy() # f,h,w 123 | if self.iact_threshold != self.prop_threshold: 124 | utils_torch.combine_masks_with_batch(torch.unsqueeze(self.prob_map_of_frames[annotated_now], dim=0), 125 | n_obj=self.n_objects, 126 | th=self.iact_threshold 127 | )[0, 0, self.hpad1:-self.hpad2, self.wpad1:-self.wpad2].cpu().numpy() # f,h,w 128 | 129 | print('[T-Net process is done.]') 130 | 131 | 132 | def Run_interaction(self, scribbles): 133 | 134 | print('[A-Net running...]') 135 | annotated_now = scribbles['annotated_frame'] 136 | scribbles_list = scribbles['scribbles'] 137 | 138 | pm_ps_ns_3ch_t=[] # n_obj,3,h,w 139 | if self.current_round == 1: 140 | for obj_id in range(1, self.n_objects + 1): 141 | pos_scrimg, neg_scrimg = utils_custom.scribble_to_image(scribbles_list, annotated_now, obj_id, 142 | prev_mask=self.current_round_masks[annotated_now], blur=True, 143 | singleimg=False, seperate_pos_neg=True) 144 | pm_ps_ns_3ch_t.append(np.stack([np.ones_like(pos_scrimg)/2, pos_scrimg, neg_scrimg], axis=0)) 145 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 146 | 147 | else: 148 | for obj_id in range(1, self.n_objects + 1): 149 | prev_round_input = (self.current_round_masks[annotated_now] == obj_id).astype(np.float32) # H,W 150 | pos_scrimg, neg_scrimg = utils_custom.scribble_to_image(scribbles_list, annotated_now, obj_id, 151 | prev_mask=self.current_round_masks[annotated_now], 
blur=True, 152 | singleimg=False, seperate_pos_neg=True) 153 | pm_ps_ns_3ch_t.append(np.stack([prev_round_input, pos_scrimg, neg_scrimg], axis=0)) 154 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 155 | 156 | batched_F = self.all_F[annotated_now].repeat(self.n_objects,1,1,1) 157 | 158 | pm_ps_ns_3ch_t = torch.from_numpy(pm_ps_ns_3ch_t).cuda() 159 | pm_ps_ns_3ch_t = torch.nn.ReflectionPad2d(self.padding)(pm_ps_ns_3ch_t) 160 | inputs = torch.cat([batched_F, pm_ps_ns_3ch_t], dim=1) 161 | 162 | output_prob, anno_6chEnc_r5 = self.net.forward_iact(inputs) 163 | anno_3chEnc_r5, _, _, self.r2_fromanno = self.net.encoder_3ch.forward(batched_F) 164 | 165 | if len(self.anno_6chEnc_r5_list) < self.current_round: 166 | self.anno_6chEnc_r5_list.append(anno_6chEnc_r5) 167 | self.anno_3chEnc_r5_list.append(anno_3chEnc_r5) 168 | self.annotated_frames.append(annotated_now) 169 | elif len(self.anno_6chEnc_r5_list) == self.current_round: 170 | self.anno_6chEnc_r5_list[self.current_round-1] = anno_6chEnc_r5 171 | self.anno_3chEnc_r5_list[self.current_round-1] = anno_3chEnc_r5 172 | else: 173 | raise NotImplementedError 174 | 175 | 176 | 177 | self.one_hot_outputs_anno = torch.sigmoid(output_prob) 178 | self.prob_map_of_frames[annotated_now] = self.one_hot_outputs_anno[:, 0].detach() 179 | 180 | self.current_round_masks[annotated_now] = utils_torch.combine_masks_with_batch(torch.unsqueeze(self.prob_map_of_frames[annotated_now], dim=0), 181 | n_obj=self.n_objects, 182 | th=self.iact_threshold 183 | )[0, 0, self.hpad1:-self.hpad2, 184 | self.wpad1:-self.wpad2].cpu().numpy() # f,h,w 185 | 186 | print('[A-Net process is done.]') 187 | 188 | def Get_mask(self): 189 | return self.current_round_masks 190 | 191 | 192 | def Get_mask_index(self, index): 193 | return self.current_round_masks[index] 194 | 195 | -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/backbone/xception.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | 
rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Alighed Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 | if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | 
grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, 
start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = 
self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. / n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in model_dict: 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | 
model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """Sum over the first and last dimensions."""
    without_front = tensor.sum(dim=0)
    return without_front.sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    with_front = tensor.unsqueeze(0)
    return with_front.unsqueeze(-1)


# Statistics a replica contributes: sum, square-sum and the element count behind them.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Statistics sent back to each replica. `_data_parallel_master` fills the
# 'sum' field with the reduced mean and 'inv_std' with 1 / std.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    Outside replicated training (``_is_parallel`` unset, or evaluation mode)
    ``forward`` simply defers to ``F.batch_norm``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # Replication state; populated by `__data_parallel_replicate__`.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

        # The master-side reduction callback is bound to this module.
        self._sync_master = SyncMaster(self._data_parallel_master)

    def forward(self, input):
        """Normalize `input`, synchronizing statistics when replicated and training."""
        synchronize = self._is_parallel and self.training
        if not synchronize:
            # Not replicated, or evaluating: PyTorch's implementation is used as is.
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Work on a (B, C, -1) view and restore the caller's shape at the end.
        original_shape = input.size()
        flat = input.view(input.size(0), self.num_features, -1)

        # Per-channel sum and square-sum of this replica.
        element_count = flat.size(0) * flat.size(2)
        message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), element_count)

        # Replica 0 drives the reduction; every other replica goes through its pipe.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        centered = flat - _unsqueeze_ft(mean)
        if self.affine:
            # Scale and inverse std are multiplied first so the large tensor
            # is multiplied only once.
            output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = centered * _unsqueeze_ft(inv_std)

        return output.view(original_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Record this copy's replica id and hook it up to the shared sync master."""
        self._is_parallel = True
        self._parallel_id = copy_id

        if copy_id == 0:
            # Copy 0 is the master device: publish its sync master on the context.
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them.

        Each entry of `intermediates` is indexed as ``(identifier, _ChildMessage)``;
        the result pairs every identifier with its ``_MasterMessage``.
        """
        # A stable device order makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        ordered = sorted(intermediates, key=lambda entry: entry[1].sum.get_device())
        messages = [entry[1] for entry in ordered]
        target_gpus = [msg.sum.get_device() for msg in messages]

        # Flatten to [sum_0, ssum_0, sum_1, ssum_1, ...] for the coalesced reduce.
        to_reduce = []
        for msg in messages:
            to_reduce.extend(msg[:2])

        sum_size = sum(msg.sum_size for msg in messages)
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # `broadcasted` holds (mean, inv_std) pairs, one pair per target device.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        return [
            (entry[0], _MasterMessage(*broadcasted[2 * index:2 * index + 2]))
            for index, entry in enumerate(ordered)
        ]

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from sum and square-sum.

        Also updates the running mean/variance (the latter with the unbiased
        variance). Returns ``(mean, inv_std)`` where the inverse standard
        deviation uses the biased variance clamped from below at ``eps``.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        keep = 1 - self.momentum
        self.running_mean = keep * self.running_mean + self.momentum * mean.data
        self.running_var = keep * self.running_var + self.momentum * unbias_var.data

        inv_std = bias_var.clamp(self.eps) ** -0.5
        return mean, inv_std


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Synchronized batch normalization over a 2d or 3d input seen as a mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm1d, the mean and standard deviation
    are reduced across all devices during training. Under `nn.DataParallel`
    the built-in layer normalizes the tensor on each device with the
    statistics of that device only, which is fast and simple but may be
    inaccurate; here the statistics are computed over all training samples
    distributed on the devices.

    For the one-GPU or CPU-only case, and in evaluation mode, the built-in
    PyTorch implementation is used. In the synchronized path the biased
    variance is clamped from below at ``eps`` instead of having ``eps`` added.

    Statistics are computed per channel `C` over `(N, L)` slices (commonly
    called Temporal BatchNorm); gamma and beta are learnable vectors of
    size C. During training a running estimate of mean and variance is kept
    (default momentum 0.1) and is used for normalization during evaluation.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: value used for numerical stability of the denominator.
            Default: 1e-5
        momentum: value used for the running_mean and running_var
            computation. Default: 0.1
        affine: when ``True``, the layer has learnable affine parameters.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless `input` is 2D or 3D, then defer to the parent check."""
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Synchronized batch normalization over a 4d input seen as a mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm2d, the mean and standard deviation
    are reduced across all devices during training. Under `nn.DataParallel`
    the built-in layer normalizes the tensor on each device with the
    statistics of that device only, which is fast and simple but may be
    inaccurate; here the statistics are computed over all training samples
    distributed on the devices.

    For the one-GPU or CPU-only case, and in evaluation mode, the built-in
    PyTorch implementation is used. In the synchronized path the biased
    variance is clamped from below at ``eps`` instead of having ``eps`` added.

    Statistics are computed per channel `C` over `(N, H, W)` slices
    (commonly called Spatial BatchNorm); gamma and beta are learnable
    vectors of size C. During training a running estimate of mean and
    variance is kept (default momentum 0.1) and is used for normalization
    during evaluation.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: value used for numerical stability of the denominator.
            Default: 1e-5
        momentum: value used for the running_mean and running_var
            computation. Default: 0.1
        affine: when ``True``, the layer has learnable affine parameters.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless `input` is 4D, then defer to the parent check."""
        dims = input.dim()
        if dims != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(dims))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Synchronized batch normalization over a 5d input seen as a mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm3d, the mean and standard deviation
    are reduced across all devices during training. Under `nn.DataParallel`
    the built-in layer normalizes the tensor on each device with the
    statistics of that device only, which is fast and simple but may be
    inaccurate; here the statistics are computed over all training samples
    distributed on the devices.

    For the one-GPU or CPU-only case, and in evaluation mode, the built-in
    PyTorch implementation is used. In the synchronized path the biased
    variance is clamped from below at ``eps`` instead of having ``eps`` added.

    Statistics are computed per channel `C` over `(N, D, H, W)` slices
    (commonly called Volumetric BatchNorm or Spatio-temporal BatchNorm);
    gamma and beta are learnable vectors of size C. During training a
    running estimate of mean and variance is kept (default momentum 0.1)
    and is used for normalization during evaluation.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: value used for numerical stability of the denominator.
            Default: 1e-5
        momentum: value used for the running_mean and running_var
            computation. Default: 0.1
        affine: when ``True``, the layer has learnable affine parameters.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless `input` is 5D, then defer to the parent check."""
        dims = input.dim()
        if dims != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(dims))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


# --------------------------------------------------------------------------------
# /model_ECCV2020/networks/deeplab/sync_batchnorm/batchnorm.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 
65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 
135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 
199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 
239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /model_CVPR2021/networks/deeplab/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/' 7 | 8 | model_urls = { 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | 
super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, 
layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | 
self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, 
dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 251 | dilation=2, BatchNorm=BatchNorm) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 253 | dilation=4, BatchNorm=BatchNorm) 254 | 255 | self._init_weight() 256 | 257 | def _init_weight(self): 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 262 | elif isinstance(m, SynchronizedBatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | elif isinstance(m, nn.BatchNorm2d): 266 | m.weight.data.fill_(1) 267 | m.bias.data.zero_() 268 | 269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 270 | downsample = None 271 | if stride != 1 or self.inplanes != planes * block.expansion: 272 | downsample = nn.Sequential( 273 | nn.Conv2d(self.inplanes, planes * block.expansion, 274 | kernel_size=1, stride=stride, bias=False), 275 | BatchNorm(planes * block.expansion), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 280 | self.inplanes = planes * block.expansion 281 | for i in range(1, blocks): 282 | layers.append(block(self.inplanes, planes, 283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm)) 284 | 285 | return nn.Sequential(*layers) 286 | 287 | def forward(self, x): 288 | x = self.conv1(x) 289 | x = self.bn1(x) 290 | x = self.relu(x) 291 | x = self.maxpool(x) 292 | 293 | x = self.layer1(x) 294 | x = self.layer2(x) 295 | x = self.layer3(x) 296 | x = self.layer4(x) 297 | 298 | return x 299 | 300 | def drn_a_50(BatchNorm, pretrained=True): 301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 304 | return model 305 | 306 | 307 | def drn_c_26(BatchNorm, pretrained=True): 308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm) 309 | if pretrained: 310 | pretrained = model_zoo.load_url(model_urls['drn-c-26']) 311 | del pretrained['fc.weight'] 312 | del pretrained['fc.bias'] 313 | model.load_state_dict(pretrained) 314 | return model 315 | 316 | 317 | def drn_c_42(BatchNorm, pretrained=True): 318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 319 | if pretrained: 320 | pretrained = 
def drn_d_24(BatchNorm, pretrained=True):
    """Build a DRN-D-24 backbone (``DRN`` with ``BasicBlock``, arch ``'D'``).

    Args:
        BatchNorm: normalisation layer factory forwarded to ``DRN``.
        pretrained: when truthy, download and load the checkpoint registered
            under ``model_urls['drn-d-24']``.

    Returns:
        The constructed ``DRN`` model.

    Raises:
        KeyError: if ``pretrained`` is truthy but no ``'drn-d-24'`` URL is
            registered in ``model_urls``.
    """
    if pretrained:
        # NOTE(review): model_urls is defined outside this view; confirm that
        # it really registers 'drn-d-24'. Fail with an actionable message
        # instead of a bare KeyError when it does not.
        url = model_urls.get('drn-d-24')
        if url is None:
            raise KeyError(
                "no pretrained checkpoint URL is registered for 'drn-d-24'; "
                "add it to model_urls or call drn_d_24(..., pretrained=False)")

    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        # Keep the boolean flag and the downloaded weights in separate names.
        state_dict = model_zoo.load_url(url)
        # The classifier head is stripped before loading, as in the sibling
        # drn_* factories.
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        model.load_state_dict(state_dict)
    return model
if __name__ == "__main__":
    # Smoke test: build DRN-A-50 and push one random 512x512 RGB batch through it.
    import torch
    model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
    image = torch.rand(1, 3, 512, 512)  # renamed: `input` shadows the builtin
    # DRN_A.forward returns a single feature map, not an
    # (output, low_level_feat) pair, so there is nothing to unpack here.
    output = model(image)
    print(output.size())
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with a 4x channel expansion.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions; the block outputs
            ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            shortcut branch; the raw input is used when it is ``None``.
        dilation: pair of dilations; only ``dilation[1]`` is used (as both
            dilation and padding of the 3x3 convolution).
        residual: accepted for signature parity with ``BasicBlock``; this
            block never reads it and always adds the shortcut.
        BatchNorm: callable building a normalisation layer from a channel count.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True, BatchNorm=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        rate = dilation[1]

        # Channel reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        # Spatial convolution; padding equals the dilation rate.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=rate, bias=False, dilation=rate)
        self.bn2 = BatchNorm(planes)
        # Channel expansion.
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def 
def forward(self, x):
    """Run the backbone and return ``(features, low_level_feat)``.

    The stem depends on ``self.arch``: ``'C'`` applies ``conv1 -> bn1 -> relu``
    and ``'D'`` applies ``layer0``; any other value feeds ``x`` straight into
    ``layer1``. ``low_level_feat`` is the output of ``layer3``. ``layer6``,
    ``layer7`` and ``layer8`` are optional and skipped when they are ``None``.
    """
    arch = self.arch
    if arch == 'C':
        x = self.relu(self.bn1(self.conv1(x)))
    elif arch == 'D':
        x = self.layer0(x)

    # layer3's output is handed back to the caller as the low-level feature.
    low_level_feat = self.layer3(self.layer2(self.layer1(x)))
    x = self.layer5(self.layer4(low_level_feat))

    # Optional trailing stages, applied in order when present.
    for stage_name in ('layer6', 'layer7', 'layer8'):
        stage = getattr(self, stage_name)
        if stage is not None:
            x = stage(x)

    return x, low_level_feat
def drn_a_50(BatchNorm, pretrained=True):
    """Build a DRN-A-50 backbone (``DRN_A`` with ``Bottleneck`` blocks).

    Args:
        BatchNorm: normalisation layer factory forwarded to ``DRN_A``.
        pretrained: when truthy, download the checkpoint registered under
            ``model_urls['resnet50']`` and load it into the model.

    Returns:
        The constructed ``DRN_A`` model.
    """
    model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet50'])
        # DRN_A defines no `fc` classifier, so the checkpoint's classifier
        # weights must be dropped; a strict load_state_dict() otherwise fails
        # on the unexpected `fc.*` keys. The other drn_* factories in this
        # file strip the same two keys.
        state_dict.pop('fc.weight', None)
        state_dict.pop('fc.bias', None)
        model.load_state_dict(state_dict)
    return model
def drn_d_24(BatchNorm, pretrained=True):
    """Build a DRN-D-24 backbone (``DRN`` with ``BasicBlock``, arch ``'D'``).

    Args:
        BatchNorm: normalisation layer factory forwarded to ``DRN``.
        pretrained: when truthy, download and load the checkpoint registered
            under ``model_urls['drn-d-24']``.

    Returns:
        The constructed ``DRN`` model.

    Raises:
        KeyError: if ``pretrained`` is truthy but no ``'drn-d-24'`` URL is
            registered in ``model_urls``.
    """
    if pretrained:
        # The model_urls table in this file has no 'drn-d-24' entry, so the
        # plain subscript used to die with an unexplained KeyError. Keep the
        # exception type but say what is wrong and how to proceed.
        url = model_urls.get('drn-d-24')
        if url is None:
            raise KeyError(
                "no pretrained checkpoint URL is registered for 'drn-d-24'; "
                "add it to model_urls or call drn_d_24(..., pretrained=False)")

    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        # Keep the boolean flag and the downloaded weights in separate names.
        state_dict = model_zoo.load_url(url)
        # The classifier head is stripped before loading, as in the sibling
        # drn_* factories.
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        model.load_state_dict(state_dict)
    return model
if __name__ == "__main__":
    # Smoke test: build DRN-A-50 and push one random 512x512 RGB batch through it.
    import torch
    model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
    image = torch.rand(1, 3, 512, 512)  # renamed: `input` shadows the builtin
    # DRN_A.forward returns a single feature map, not an
    # (output, low_level_feat) pair, so there is nothing to unpack here.
    output = model(image)
    print(output.size())
def correlation_global_transfer(self, anno_feature_list, targ_feature, anno_indicator_feature_list):
    '''
    Transfer annotated-frame indicator features onto the target frame using a
    dense feature affinity.

    ``b, c, h, w`` are read from ``anno_indicator_feature_list[0]`` and every
    tensor is viewed as ``[b, c, h*w]``, so all inputs must be compatible with
    that shape. Only the first ``len(anno_feature_list)`` indicator entries
    are used.

    :param anno_feature_list: list of N tensors, each viewable as [B,C,H,W]
    :param targ_feature: tensor viewable as [B,C,H,W]
    :param anno_indicator_feature_list: list of N tensors, each [B,C,H,W]
    :return targ_mask_feature: [B,C,H,W]
    '''
    b, c, h, w = anno_indicator_feature_list[0].size()
    hw = h * w
    n_features = len(anno_feature_list)

    target = targ_feature.view(b, c, hw)  # [B, C, HW]

    # Annotated features, one row per annotated position: [B, N*HW', C]
    anno_rows = torch.cat(
        [anno_feature_list[i].view(b, c, hw).transpose(1, 2) for i in range(n_features)],
        dim=1)

    # Affinity [B, N*HW', HW]; softmax runs over the last (target) axis and the
    # result is divided by the number of annotated frames.
    affinity = F.softmax(torch.bmm(anno_rows, target), dim=2) / n_features

    # Indicator features concatenated along the position axis: [B, C, N*HW']
    indicators = torch.cat(
        [anno_indicator_feature_list[i].view(b, c, hw) for i in range(n_features)],
        dim=-1)

    # [B, C, N*HW'] x [B, N*HW', HW] -> [B, C, HW] -> [B, C, H, W]
    return torch.bmm(indicators, affinity).view(b, c, h, w)
class Encoder_3ch(nn.Module):
    """T-Net encoder: a 3-channel stem convolution followed by SEResNet50 stages.

    The stem ``conv0_3ch`` is created here; ``bn1``, ``relu``, ``maxpool`` and
    the four residual stages are taken from a freshly built ``SEResNet50``.
    Scale/channel notes in the comments are carried over from the original
    annotations.

    Args:
        resfix: when truthy, ``requires_grad`` is disabled on the parameters
            of every ``nn.BatchNorm2d`` inside this encoder.
    """

    def __init__(self, resfix):
        super(Encoder_3ch, self).__init__()

        self.conv0_3ch = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True)

        backbone = SEResNet50(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True)
        self.bn1 = backbone.bn1
        self.relu = backbone.relu        # 1/2, 64
        self.maxpool = backbone.maxpool

        self.res2 = backbone.layer1      # 1/4, 256
        self.res3 = backbone.layer2      # 1/8, 512
        self.res4 = backbone.layer3      # 1/16, 1024
        self.res5 = backbone.layer4      # 1/16, 2048

        if resfix:
            # Freeze the affine parameters of all batch-norm layers.
            batch_norms = [m for m in self.modules() if isinstance(m, nn.BatchNorm2d)]
            for bn in batch_norms:
                for param in bn.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return the stage outputs as ``(r5, r4, r3, r2)``."""
        r2 = self.forward_r2(x)   # 1/4, 256
        r3 = self.res3(r2)        # 1/8, 512
        r4 = self.res4(r3)        # 1/16, 1024
        r5 = self.res5(r4)        # 1/16, 2048
        return r5, r4, r3, r2

    def forward_r2(self, x):
        """Run only the stem and the first residual stage; return ``r2``."""
        stem = self.relu(self.bn1(self.conv0_3ch(x)))  # 1/2, 64
        pooled = self.maxpool(stem)                    # 1/4, 64
        return self.res2(pooled)                       # 1/4, 256
class Decoder(nn.Module):
    """A-Net decoder.

    Combines an ASPP branch on ``r5`` with a refinement path
    (``convG0..2`` -> ``RF3`` -> ``RF2``) and predicts a 1-channel map with
    ``lastconv``. Scale notes in the comments are carried over from the
    original annotations.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        mdim = 256

        self.aspp_decoder = ASPP(backbone='res', output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=1)
        self.convG0 = nn.Conv2d(2048, mdim, kernel_size=3, padding=1)
        self.convG1 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1)
        self.convG2 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1)

        self.RF3 = Refine(512, mdim)  # 1/16 -> 1/8
        self.RF2 = Refine(256, mdim)  # 1/8 -> 1/4

        # Prediction head: two conv/BN/ReLU/dropout groups, then a 1x1 conv
        # down to a single channel.
        head = [
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, 1, kernel_size=1, stride=1),
        ]
        self.lastconv = nn.Sequential(*head)

    def forward(self, r5, r3_targ, r2_targ, only_return_feature = False):
        """Decode encoder features.

        Returns the concatenated feature ``m2`` alone when
        ``only_return_feature`` is true, otherwise ``(x, m2)`` where ``x`` is
        the ``lastconv`` output upsampled 4x.
        """
        # ASPP branch, upsampled 4x to be concatenated with m2.
        aspp_out = F.interpolate(self.aspp_decoder(r5), scale_factor=4,
                                 mode='bilinear', align_corners=True)  # 1/4, mdim

        # Three ReLU -> conv steps on r5.
        m4 = r5
        for conv in (self.convG0, self.convG1, self.convG2):
            m4 = conv(F.relu(m4))  # 1/16, mdim

        m3 = self.RF3(r3_targ, m4)  # 1/8, mdim
        m2 = self.RF2(r2_targ, m3)  # 1/4, mdim
        m2 = torch.cat((m2, aspp_out), dim=1)  # 1/4, mdim*2

        if only_return_feature:
            return m2

        logits = self.lastconv(m2)
        logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=True)
        return logits, m2
class ConverterEncoder(nn.Module):
    """Downsample decoder feature ``m2`` twice and fuse it with ``r5``.

    Two stride-2 ``SEBottleneck`` blocks take 512 channels to 1024 and then
    to 2048; the result is concatenated with ``r5`` along the channel axis
    and reduced from 4096 back to 2048 channels by a 1x1 convolution.
    """

    def __init__(self):
        super(ConverterEncoder, self).__init__()
        # [1/4, 512] to [1/8, 1024]
        shortcut1 = self._strided_projection(512, 1024)
        self.block1 = SEBottleneck(512, 256, stride=2, downsample=shortcut1)
        # [1/8, 1024] to [1/16, 2048]
        shortcut2 = self._strided_projection(1024, 2048)
        self.block2 = SEBottleneck(1024, 512, stride=2, downsample=shortcut2)
        self.conv1x1 = nn.Conv2d(2048 * 2, 2048, kernel_size=1, padding=0)  # 1/16, 2048

    @staticmethod
    def _strided_projection(in_channels, out_channels):
        """Build the stride-2 1x1 conv + BatchNorm shortcut for one block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, r5, m2):
        '''
        :param r5: 1/16, 2048
        :param m2: 1/4, 512
        :return: fused feature produced by ``conv1x1`` (2048 channels)
        '''
        reduced = self.block2(self.block1(m2))
        fused = torch.cat((reduced, r5), dim=1)
        return self.conv1x1(fused)
class Refine(nn.Module):
    """Refinement module merging a skip feature with an upsampled coarse map.

    Args:
        inplanes: channels of the skip feature ``f``.
        planes: channels of the coarse map ``pm`` and of the output.
        scale_factor: bilinear upsampling factor applied to ``pm``.
    """

    def __init__(self, inplanes, planes, scale_factor=2):
        super(Refine, self).__init__()

        def conv3x3(in_channels):
            # Every convolution here is 3x3, padding 1, `planes` outputs.
            return nn.Conv2d(in_channels, planes, kernel_size=3, padding=1)

        self.convFS1 = conv3x3(inplanes)
        self.convFS2 = conv3x3(planes)
        self.convFS3 = conv3x3(planes)
        self.convMM1 = conv3x3(planes)
        self.convMM2 = conv3x3(planes)
        self.scale_factor = scale_factor

    @staticmethod
    def _residual(x, first, second):
        """Return ``x + second(relu(first(relu(x))))``."""
        return x + second(F.relu(first(F.relu(x))))

    def forward(self, f, pm):
        """Project ``f``, add the upsampled ``pm``, refining before and after."""
        skip = self._residual(self.convFS1(f), self.convFS2, self.convFS3)
        upsampled = F.interpolate(pm, scale_factor=self.scale_factor,
                                  mode='bilinear', align_corners=True)
        merged = skip + upsampled
        return self._residual(merged, self.convMM1, self.convMM2)