├── LICENSE ├── Overall_Network.png ├── README.md ├── config.py ├── eval_davis-framework.py ├── eval_real-world.py ├── libs ├── __init__.py ├── analyze_report.py ├── custom_transforms.py ├── davis2017_torchdataset.py ├── utils.py └── utils_torch.py ├── networks ├── __init__.py ├── atnet.py ├── correlation_package.zip ├── deeplab │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── aspp.cpython-36.pyc │ ├── aspp.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── drn.cpython-36.pyc │ │ │ ├── mobilenet.cpython-36.pyc │ │ │ ├── resnet.cpython-36.pyc │ │ │ └── xception.cpython-36.pyc │ │ ├── drn.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ └── xception.py │ ├── decoder.py │ ├── deeplab.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── comm.cpython-36.pyc │ │ └── replicate.cpython-36.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py └── ltm_transfer.py └── results └── test_result_davisframework ├── IVOS-ATNet_JF_example ├── summary.json └── summary_graph_0.827.png └── IVOS-ATNet_J_example ├── summary.json └── summary_graph_0.790.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yuk Heo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Overall_Network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/Overall_Network.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | # Interactive Video Object Segmentation Using Global and Local Transfer Modules 3 | ### Yuk Heo, Yeong Jun Koh, Chang-Su Kim 4 | 5 | 6 | 7 | [[Project page]](http://mcl.korea.ac.kr/yukheo_eccv2020/) 8 | [[arXiv]](https://arxiv.org/abs/2007.08139) 9 | 10 | Implementation of ECCV2020 paper, "Interactive Video Object Segmentation Using Global and Local Transfer Modules" 11 | 12 | Codes in this github: 13 | 14 | 1. DAVIS2017 evaluation based on the [DAVIS framework](https://interactive.davischallenge.org/) 15 | 2. 
DAVIS2016 real-world evaluation GUI 16 | 17 | ## Prerequisites 18 | - CUDA 10.0 19 | - Python 3.6 20 | - PyTorch 1.2.0 21 | - [davisinteractive 1.0.4](https://github.com/albertomontesg/davis-interactive) 22 | - [correlation package](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package) of [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch) 23 | - numpy, cv2, PyQt5, and other common Python 3 libraries 24 | 25 | ## Directory Structure 26 | * `root/libs`: utility modules. 27 | 28 | * `root/networks` : network code. 29 | - `correlation_package.zip` : saves GPU memory by applying the correlation package of FlowNet2. 30 | - `deeplab`: applies the ASPP module in the decoders. [[original code]](https://github.com/jfzhang95/pytorch-deeplab-xception/tree/master/modeling) 31 | - `atnet.py`: contains the A-Net and the T-Net. 32 | - `ltm_transfer.py`: transfers the previous segmentation using the local affinity of the local transfer module. 33 | 34 | * `root/config.py` : configurations. 35 | 36 | * `root/eval_davis-framework.py` : DAVIS2017 evaluation based on the [DAVIS framework](https://interactive.davischallenge.org/). 37 | 38 | * `root/eval_real-world.py` : DAVIS2016 real-world evaluation GUI (to be released). 39 | 40 | ## Instructions 41 | 42 | ### DAVIS2017 evaluation based on the DAVIS framework 43 | 44 | 1. Edit `config.py` to set the directory of your DAVIS2017 dataset and the GPU ID. 45 | 2. Unzip and build the [correlation package](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package) by 46 | ``` 47 | cd ./networks 48 | unzip correlation_package.zip 49 | cd correlation_package 50 | rm -rf *_cuda.egg-info build dist __pycache__ 51 | python3 setup.py install --user 52 | ``` 53 |     If you have problems in this step, you can find more information in the [FlowNet2 repository](https://github.com/NVIDIA/flownet2-pytorch). 54 | 55 | 3. Download our [network parameters](https://drive.google.com/file/d/1t1VO2zy3pLBXCWqme9h63Def86Y4ECIH/view?usp=sharing) and place the file at `root/ATNet-checkpoint.pth`. 56 | 4. Run with `python3 eval_davis-framework.py`. 57 | 58 | ### DAVIS2016 real-world evaluation GUI 59 | 60 | The multi-object GUI (for DAVIS2017) is available on our GitHub page, [[GUI-IVOS]](https://github.com/yuk6heo/GUI-IVOS). 61 | 62 | ## Reference 63 | 64 | Please cite our paper if the implementations are useful in your work: 65 | ``` 66 | @Inproceedings{ 67 | Yuk2020IVOS, 68 | title={Interactive Video Object Segmentation Using Global and Local Transfer Modules}, 69 | author={Yuk Heo and Yeong Jun Koh and Chang-Su Kim}, 70 | booktitle={ECCV}, 71 | year={2020}, 72 | url={https://openreview.net/forum?id=bo_lWt_aA} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config(object): 4 | def __init__(self): 5 | 6 | ################################ C ################################## 7 | # DAVIS path 8 | self.davis_dataset_dir = '/home/yuk/data_ssd/datasets/DAVIS' 9 | self.test_gpu_id = 2 10 | self.test_metric_list = ['J', 'J_AND_F'] 11 | 12 | ################################ For test parameters ################################## 13 | self.test_host = 'localhost' # 'localhost' for subsets train and val.
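        # NOTE: the eval scripts feed these fields straight into the DAVIS interactive
        # framework, roughly as
        #     DavisInteractiveSession(host=self.test_host, user_key=self.test_userkey,
        #                             subset=self.test_subset, ...)
        # With host='localhost', the davisinteractive package evaluates the public
        # 'train' and 'val' subsets offline; other subsets need the remote server
        # together with a valid user key.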
14 | self.test_subset = 'val' 15 | self.test_userkey = None 16 | self.test_propagation_proportion = 0.99 17 | self.test_propth = 0.8 18 | self.test_min_nb_nodes = 2 19 | self.test_save_all_segs_option = True 20 | 21 | ############################### Other parameters ################################## 22 | self.mean, self.var = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 23 | self.scribble_dilation_param = 5 24 | 25 | # Rel path 26 | # project_path = os.path.dirname(__file__) 27 | # self.font_dir = project_path + '/fonts/' 28 | self.palette_dir = self.davis_dataset_dir + '/Annotations/480p/bear/00000.png' 29 | self.test_result_df_dir = 'results/test_result_davisframework' 30 | self.test_result_rw_dir = 'results/test_result_realworld' 31 | self.test_load_state_dir = 'ATNet-checkpoint.pth' # CKpath 32 | -------------------------------------------------------------------------------- /eval_davis-framework.py: -------------------------------------------------------------------------------- 1 | from davisinteractive.session import DavisInteractiveSession 2 | from davisinteractive import utils as interactive_utils 3 | from davisinteractive.dataset import Davis 4 | from davisinteractive.metrics import batched_jaccard 5 | 6 | from libs import custom_transforms as tr, davis2017_torchdataset 7 | import os 8 | 9 | import numpy as np 10 | from PIL import Image 11 | import csv 12 | from datetime import datetime 13 | 14 | import torch 15 | from torch.autograd import Variable 16 | from torchvision import transforms 17 | from torch.utils.data import DataLoader 18 | 19 | from libs import utils, utils_torch 20 | from libs.analyze_report import analyze_summary 21 | from config import Config 22 | from networks.atnet import ATnet 23 | 24 | 25 | class Main_tester(object): 26 | def __init__(self, config): 27 | self.config = config 28 | self.Davisclass = Davis(self.config.davis_dataset_dir) 29 | self.current_time = datetime.now().strftime('%Y%m%d-%H%M%S') 30 | self._palette = Image.open(self.config.palette_dir).getpalette() 31 | self.save_res_dir = str() 32 | self.save_log_dir = str() 33 | self.save_logger = None 34 | self.save_csvsummary_dir = str() 35 | 36 | self.net = ATnet() 37 | self.net.cuda() 38 | self.net.eval() 39 | self.net.load_state_dict(torch.load(self.config.test_load_state_dir)) 40 | 41 | # To implement ordered test 42 | self.scr_indices = [1, 2, 3] 43 | self.max_nb_interactions = 8 44 | self.max_time = self.max_nb_interactions * 30 45 | self.scr_samples = [] 46 | for v in sorted(self.Davisclass.sets[self.config.test_subset]): 47 | for idx in self.scr_indices: 48 | self.scr_samples.append((v, idx)) 49 | 50 | self.img_size, self.num_frames, self.n_objects, self.final_masks, self.tmpdict_siact = None, None, None, None, None 51 | self.pad_info, self.hpad1, self.wpad1, self.hpad2, self.wpad2 = None, None, None, None, None 52 | 53 | def run_for_diverse_metrics(self, ): 54 | 55 | with torch.no_grad(): 56 | for metric in self.config.test_metric_list: 57 | if metric == 'J': 58 | dir_name = 'IVOS-ATNet_J_' + self.current_time 59 | elif metric == 'J_AND_F': 60 | dir_name = 'IVOS-ATNet_JF_' + self.current_time 61 | else: 62 | dir_name = None 63 | print("Impossible metric is contained in config.test_metric_list!") 64 | raise NotImplementedError() 65 | self.save_res_dir = os.path.join(self.config.test_result_df_dir, dir_name) 66 | utils.mkdir(self.save_res_dir) 67 | self.save_csvsummary_dir = os.path.join(self.save_res_dir, 'summary_in_csv.csv') 68 | self.save_log_dir = os.path.join(self.save_res_dir, 
'test_logs.txt') 69 | self.save_logger = utils.logger(self.save_log_dir) 70 | self.save_logger.printNlog(dir_name) 71 | curr_path = os.path.dirname(os.path.abspath(__file__)) 72 | os.system('cp {}/config.py {}/config.py'.format(curr_path, self.save_res_dir)) 73 | 74 | 75 | 76 | self.run_IVOS(metric) 77 | 78 | def run_IVOS(self, metric): 79 | seen_seq = {} 80 | numseq, tmpseq = 0, '' 81 | output_dict = dict() 82 | output_dict['average_objs_iou'] = dict() 83 | output_dict['average_iact_iou'] = np.zeros(self.max_nb_interactions) 84 | output_dict['annotated_frames'] = dict() 85 | 86 | with open(self.save_csvsummary_dir, mode='a') as csv_file: 87 | writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 88 | writer.writerow(['sequence', 'obj_idx', 'scr_idx'] + ['round-' + str(i + 1) for i in range(self.max_nb_interactions)]) 89 | 90 | with DavisInteractiveSession(host=self.config.test_host, 91 | user_key=self.config.test_userkey, 92 | davis_root=self.config.davis_dataset_dir, 93 | subset=self.config.test_subset, 94 | report_save_dir=self.save_res_dir, 95 | max_nb_interactions=self.max_nb_interactions, 96 | max_time=self.max_time, 97 | metric_to_optimize=metric) as sess: 98 | 99 | sess.connector.service.robot.min_nb_nodes = self.config.test_min_nb_nodes 100 | sess.samples = self.scr_samples 101 | # sess.samples = [('dog', 3)] 102 | 103 | while sess.next(): 104 | # Get the current iteration scribbles 105 | self.sequence, scribbles, first_scribble = sess.get_scribbles(only_last=False) 106 | 107 | if first_scribble: 108 | anno_dict = {'frames': [], 'annotated_masks': [], 'masks_tobe_modified': []} 109 | n_interaction = 1 110 | info = Davis.dataset[self.sequence] 111 | self.img_size = info['image_size'][::-1] 112 | self.num_frames = info['num_frames'] 113 | self.n_objects = info['num_objects'] 114 | info = None 115 | seen_seq[self.sequence] = 1 if self.sequence not in seen_seq.keys() else seen_seq[self.sequence] + 1 116 | scr_id = seen_seq[self.sequence] 117 | self.final_masks = np.zeros([self.num_frames, self.img_size[0], self.img_size[1]]) 118 | self.pad_info = utils.apply_pad(self.final_masks[0])[1] 119 | self.hpad1, self.wpad1 = self.pad_info[0][0], self.pad_info[1][0] 120 | self.hpad2, self.wpad2 = self.pad_info[0][1], self.pad_info[1][1] 121 | self.h_ds, self.w_ds = int((self.img_size[0] + sum(self.pad_info[0])) / 4), int((self.img_size[1] + sum(self.pad_info[1])) / 4) 122 | self.anno_6chEnc_r5_list = [] 123 | self.anno_3chEnc_r5_list = [] 124 | self.prob_map_of_frames = torch.zeros((self.num_frames, self.n_objects, 4 * self.h_ds, 4 * self.w_ds)).cuda() 125 | self.gt_masks = self.Davisclass.load_annotations(self.sequence) 126 | 127 | IoU_over_eobj = [] 128 | 129 | else: 130 | n_interaction += 1 131 | 132 | self.save_logger.printNlog('\nRunning sequence {} in (scribble index: {}) (round: {})' 133 | .format(self.sequence, sess.samples[sess.sample_idx][1], n_interaction)) 134 | 135 | annotated_now = interactive_utils.scribbles.annotated_frames(sess.sample_last_scribble)[0] 136 | anno_dict['frames'].append(annotated_now) # Where we save annotated frames 137 | anno_dict['masks_tobe_modified'].append(self.final_masks[annotated_now]) # mask before modefied at the annotated frame 138 | 139 | # Get Predicted mask & Mask decision from pred_mask 140 | self.final_masks = self.run_VOS_singleiact(n_interaction, scribbles, anno_dict['frames']) # self.final_mask changes 141 | 142 | if self.config.test_save_all_segs_option: 143 | utils.mkdir( 144 | os.path.join(self.save_res_dir, 
'result_video', '{}-scr{:02d}/round{:02d}'.format(self.sequence, scr_id, n_interaction))) 145 | for fr in range(self.num_frames): 146 | savefname = os.path.join(self.save_res_dir, 'result_video', 147 | '{}-scr{:02d}/round{:02d}'.format(self.sequence, scr_id, n_interaction), 148 | '{:05d}.png'.format(fr)) 149 | tmpPIL = Image.fromarray(self.final_masks[fr].astype(np.uint8), 'P') 150 | tmpPIL.putpalette(self._palette) 151 | tmpPIL.save(savefname) 152 | 153 | # Submit your prediction 154 | sess.submit_masks(self.final_masks) # F, H, W 155 | 156 | # print sequence name 157 | if tmpseq != self.sequence: 158 | tmpseq, numseq = self.sequence, numseq + 1 159 | print(str(numseq) + ':' + str(self.sequence) + '-' + str(seen_seq[self.sequence]) + '\n') 160 | 161 | ## Visualizers and Saver 162 | # IoU estimation 163 | jaccard = batched_jaccard(self.gt_masks, 164 | self.final_masks, 165 | average_over_objects=False, 166 | nb_objects=self.n_objects 167 | ) # frames, objid 168 | 169 | IoU_over_eobj.append(jaccard) 170 | 171 | anno_dict['annotated_masks'].append(self.final_masks[annotated_now]) # mask after modefied at the annotated frame 172 | 173 | if self.max_nb_interactions == len(anno_dict['frames']): # After Lastround -> total 90 iter 174 | seq_scrid_name = self.sequence + str(scr_id) 175 | 176 | # IoU manager 177 | IoU_over_eobj = np.stack(IoU_over_eobj, axis=0) # niact,frames,n_obj 178 | IoUeveryround_perobj = np.mean(IoU_over_eobj, axis=1) # niact,n_obj 179 | output_dict['average_iact_iou'] += np.sum(IoU_over_eobj[list(range(n_interaction)), anno_dict['frames']], axis=-1) 180 | output_dict['annotated_frames'][seq_scrid_name] = anno_dict['frames'] 181 | 182 | # write csv 183 | for obj_idx in range(self.n_objects): 184 | with open(self.save_csvsummary_dir, mode='a') as csv_file: 185 | writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 186 | writer.writerow([self.sequence, str(obj_idx + 1), str(scr_id)] + list(IoUeveryround_perobj[:, obj_idx])) 187 | 188 | summary = sess.get_global_summary(save_file=self.save_res_dir + '/summary_' + sess.report_name[7:] + '.json') 189 | analyze_summary(self.save_res_dir + '/summary_' + sess.report_name[7:] + '.json', metric=metric) 190 | 191 | # final_IOU = summary['curve'][metric][-1] 192 | average_IoU_per_round = summary['curve'][metric][1:-1] 193 | 194 | torch.cuda.empty_cache() 195 | model = None 196 | return average_IoU_per_round 197 | 198 | def run_VOS_singleiact(self, n_interaction, scribbles_data, annotated_frames): 199 | 200 | annotated_frames_np = np.array(annotated_frames) 201 | num_workers = 4 202 | annotated_now = annotated_frames[-1] 203 | scribbles_list = scribbles_data['scribbles'] 204 | seq_name = scribbles_data['sequence'] 205 | 206 | output_masks = self.final_masks.copy().astype(np.float64) 207 | 208 | prop_list = utils.get_prop_list(annotated_frames, annotated_now, self.num_frames, proportion=self.config.test_propagation_proportion) 209 | prop_fore = sorted(prop_list)[0] 210 | prop_rear = sorted(prop_list)[-1] 211 | 212 | # Interaction settings 213 | pm_ps_ns_3ch_t = [] # n_obj,3,h,w 214 | if n_interaction == 1: 215 | for obj_id in range(1, self.n_objects + 1): 216 | pos_scrimg = utils.scribble_to_image(scribbles_list, annotated_now, obj_id, 217 | dilation=self.config.scribble_dilation_param, 218 | prev_mask=self.final_masks[annotated_now]) 219 | pm_ps_ns_3ch_t.append(np.stack([np.ones_like(pos_scrimg) / 2, pos_scrimg, np.zeros_like(pos_scrimg)], axis=0)) 220 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # 
n_obj,3,h,w 221 | # Image.fromarray((scr_img[:, :, 1] * 255).astype(np.uint8)).save('/home/six/Desktop/CVPRW_figure/judo_obj1_scr.png') 222 | 223 | else: 224 | for obj_id in range(1, self.n_objects + 1): 225 | prev_round_input = (self.final_masks[annotated_now] == obj_id).astype(np.float32) # H,W 226 | pos_scrimg, neg_scrimg = utils.scribble_to_image(scribbles_list, annotated_now, obj_id, 227 | dilation=self.config.scribble_dilation_param, 228 | prev_mask=self.final_masks[annotated_now], blur=True, 229 | singleimg=False, seperate_pos_neg=True) 230 | pm_ps_ns_3ch_t.append(np.stack([prev_round_input, pos_scrimg, neg_scrimg], axis=0)) 231 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 232 | pm_ps_ns_3ch_t = torch.from_numpy(pm_ps_ns_3ch_t).cuda() 233 | 234 | if (prop_list[0] != annotated_now) and (prop_list.count(annotated_now) != 2): 235 | print(str(prop_list)) 236 | raise NotImplementedError 237 | print(str(prop_list)) # we made our proplist first backward, and then forward 238 | 239 | composed_transforms = transforms.Compose([tr.Normalize_ApplymeanvarImage(self.config.mean, self.config.var), 240 | tr.ToTensor()]) 241 | db_test = davis2017_torchdataset.DAVIS2017(split='val', transform=composed_transforms, root=self.config.davis_dataset_dir, 242 | custom_frames=prop_list, seq_name=seq_name, rgb=True, 243 | obj_id=None, no_gt=True, retname=True, prev_round_masks=self.final_masks, ) 244 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True) 245 | 246 | flag = 0 # 1: propagating backward, 2: propagating forward 247 | print('[{:01d} round] processing...'.format(n_interaction)) 248 | 249 | for ii, batched in enumerate(testloader): 250 | # batched : image, scr_img, 0~fr, meta 251 | inpdict = dict() 252 | operating_frame = int(batched['meta']['frame_id'][0]) 253 | 254 | for inp in batched: 255 | if inp == 'meta': continue 256 | inpdict[inp] = Variable(batched[inp]).cuda() 257 | 258 | inpdict['image'] = inpdict['image'].expand(self.n_objects, -1, -1, -1) 259 | 260 | #################### Iaction ######################## 261 | if operating_frame == annotated_now: # Check the round is on interaction 262 | if flag == 0: 263 | flag += 1 264 | adjacent_to_anno = True 265 | elif flag == 1: 266 | flag += 1 267 | adjacent_to_anno = True 268 | continue 269 | else: 270 | raise NotImplementedError 271 | 272 | pm_ps_ns_3ch_t = torch.nn.ReflectionPad2d(self.pad_info[1] + self.pad_info[0])(pm_ps_ns_3ch_t) 273 | inputs = torch.cat([inpdict['image'], pm_ps_ns_3ch_t], dim=1) 274 | output_logit, anno_6chEnc_r5 = self.net.forward_ANet(inputs) # [nobj, 1, P_H, P_W], # [n_obj,2048,h/16,w/16] 275 | output_prob_anno = torch.sigmoid(output_logit) 276 | prob_onehot_t = output_prob_anno[:, 0].detach() 277 | 278 | anno_3chEnc_r5, _, _, r2_prev_fromanno = self.net.encoder_3ch.forward(inpdict['image']) 279 | self.anno_6chEnc_r5_list.append(anno_6chEnc_r5) 280 | self.anno_3chEnc_r5_list.append(anno_3chEnc_r5) 281 | 282 | if len(self.anno_6chEnc_r5_list) != len(annotated_frames): 283 | raise NotImplementedError 284 | 285 | 286 | 287 | #################### Propagation ######################## 288 | else: 289 | # Flag [1: propagating backward, 2: propagating forward] 290 | if adjacent_to_anno: 291 | r2_prev = r2_prev_fromanno 292 | predmask_prev = output_prob_anno 293 | else: 294 | predmask_prev = output_prob_prop 295 | adjacent_to_anno = False 296 | 297 | output_logit, r2_prev = self.net.forward_TNet( 298 | self.anno_3chEnc_r5_list, inpdict['image'], 
self.anno_6chEnc_r5_list, r2_prev, predmask_prev) # [nobj, 1, P_H, P_W] 299 | output_prob_prop = torch.sigmoid(output_logit) 300 | prob_onehot_t = output_prob_prop[:, 0].detach() 301 | 302 | smallest_alpha = 0.5 303 | if flag == 1: 304 | sorted_frames = annotated_frames_np[annotated_frames_np < annotated_now] 305 | if len(sorted_frames) == 0: 306 | alpha = 1 307 | else: 308 | closest_addianno_frame = np.max(sorted_frames) 309 | alpha = smallest_alpha + (1 - smallest_alpha) * ( 310 | (operating_frame - closest_addianno_frame) / (annotated_now - closest_addianno_frame)) 311 | else: 312 | sorted_frames = annotated_frames_np[annotated_frames_np > annotated_now] 313 | if len(sorted_frames) == 0: 314 | alpha = 1 315 | else: 316 | closest_addianno_frame = np.min(sorted_frames) 317 | alpha = smallest_alpha + (1 - smallest_alpha) * ( 318 | (closest_addianno_frame - operating_frame) / (closest_addianno_frame - annotated_now)) 319 | 320 | prob_onehot_t = (alpha * prob_onehot_t) + ((1 - alpha) * self.prob_map_of_frames[operating_frame]) 321 | 322 | # Final mask indexing 323 | self.prob_map_of_frames[operating_frame] = prob_onehot_t 324 | 325 | output_masks[prop_fore:prop_rear + 1] = \ 326 | utils_torch.combine_masks_with_batch(self.prob_map_of_frames[prop_fore:prop_rear + 1], 327 | n_obj=self.n_objects, th=self.config.test_propth 328 | )[:, 0, self.hpad1:-self.hpad2, self.wpad1:-self.wpad2].cpu().numpy().astype(np.float) # f,h,w 329 | 330 | torch.cuda.empty_cache() 331 | 332 | return output_masks 333 | 334 | 335 | if __name__ == '__main__': 336 | config = Config() 337 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 338 | os.environ["CUDA_VISIBLE_DEVICES"] = str(config.test_gpu_id) 339 | 340 | tester = Main_tester(config) 341 | tester.run_for_diverse_metrics() 342 | 343 | # try:main_val(model, 344 | # Config, 345 | # min_nb_nodes= min_nb_nodes, 346 | # simplyfied_testset= simplyfied_test,tr(config.test_gpu_id) 347 | # metric = metric) 348 | # except: continue 349 | -------------------------------------------------------------------------------- /eval_real-world.py: -------------------------------------------------------------------------------- 1 | from davisinteractive.session import DavisInteractiveSession 2 | from davisinteractive import utils as interactive_utils 3 | from davisinteractive.dataset import Davis 4 | from davisinteractive.metrics import batched_jaccard 5 | 6 | from libs import custom_transforms as tr, davis2017_torchdataset 7 | import os 8 | 9 | import numpy as np 10 | from PIL import Image 11 | import csv 12 | from datetime import datetime 13 | 14 | import torch 15 | from torch.autograd import Variable 16 | from torchvision import transforms 17 | from torch.utils.data import DataLoader 18 | 19 | from libs import utils, utils_torch 20 | from libs.analyze_report import analyze_summary 21 | from config import Config 22 | from networks.atnet import ATnet 23 | 24 | 25 | class Main_tester(object): 26 | def __init__(self, config): 27 | self.config = config 28 | self.Davisclass = Davis(self.config.davis_dataset_dir) 29 | self.current_time = datetime.now().strftime('%Y%m%d-%H%M%S') 30 | self._palette = Image.open(self.config.palette_dir).getpalette() 31 | self.save_res_dir = str() 32 | self.save_log_dir = str() 33 | self.save_logger = None 34 | self.save_csvsummary_dir = str() 35 | 36 | self.net = ATnet() 37 | self.net.cuda() 38 | self.net.eval() 39 | self.net.load_state_dict(torch.load(self.config.test_load_state_dir)) 40 | 41 | # To implement ordered test 42 | self.scr_indices = [1, 2, 3] 
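        # NOTE: scr_samples below pairs every sequence of the chosen subset with
        # scribble sets 1-3, so each sequence is evaluated three times, and max_time
        # budgets 30 seconds for each of the 8 interaction rounds.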
43 | self.max_nb_interactions = 8 44 | self.max_time = self.max_nb_interactions * 30 45 | self.scr_samples = [] 46 | for v in sorted(self.Davisclass.sets[self.config.test_subset]): 47 | for idx in self.scr_indices: 48 | self.scr_samples.append((v, idx)) 49 | 50 | self.img_size, self.num_frames, self.n_objects, self.final_masks, self.tmpdict_siact = None, None, None, None, None 51 | self.pad_info, self.hpad1, self.wpad1, self.hpad2, self.wpad2 = None, None, None, None, None 52 | 53 | def run_for_diverse_metrics(self, ): 54 | 55 | with torch.no_grad(): 56 | for metric in self.config.test_metric_list: 57 | if metric == 'J': 58 | dir_name = os.path.split(os.path.split(__file__)[0])[1] + '[J]_' + self.current_time 59 | elif metric == 'J_AND_F': 60 | dir_name = os.path.split(os.path.split(__file__)[0])[1] + '[JF]_' + self.current_time 61 | else: 62 | dir_name = None 63 | print("Impossible metric is contained in config.test_metric_list!") 64 | raise NotImplementedError() 65 | self.save_res_dir = os.path.join(self.config.test_result_dir, dir_name) 66 | utils.mkdir(self.save_res_dir) 67 | self.save_csvsummary_dir = os.path.join(self.save_res_dir, 'summary_in_csv.csv') 68 | self.save_log_dir = os.path.join(self.save_res_dir, 'test_logs.txt') 69 | self.save_logger = utils.logger(self.save_log_dir) 70 | self.save_logger.printNlog(dir_name) 71 | curr_path = os.path.dirname(os.path.abspath(__file__)) 72 | os.system('cp {}/config.py {}/config.py'.format(curr_path, self.save_res_dir)) 73 | 74 | 75 | 76 | self.run_IVOS(metric) 77 | 78 | def run_IVOS(self, metric): 79 | seen_seq = {} 80 | numseq, tmpseq = 0, '' 81 | output_dict = dict() 82 | output_dict['average_objs_iou'] = dict() 83 | output_dict['average_iact_iou'] = np.zeros(self.max_nb_interactions) 84 | output_dict['annotated_frames'] = dict() 85 | 86 | with open(self.save_csvsummary_dir, mode='a') as csv_file: 87 | writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 88 | writer.writerow(['sequence', 'obj_idx', 'scr_idx'] + ['round-' + str(i + 1) for i in range(self.max_nb_interactions)]) 89 | 90 | with DavisInteractiveSession(host=self.config.test_host, 91 | user_key=self.config.test_userkey, 92 | davis_root=self.config.davis_dataset_dir, 93 | subset=self.config.test_subset, 94 | report_save_dir=self.save_res_dir, 95 | max_nb_interactions=self.max_nb_interactions, 96 | max_time=self.max_time, 97 | metric_to_optimize=metric) as sess: 98 | 99 | sess.connector.service.robot.min_nb_nodes = self.config.test_min_nb_nodes 100 | sess.samples = self.scr_samples 101 | # sess.samples = [('dog', 3)] 102 | 103 | while sess.next(): 104 | # Get the current iteration scribbles 105 | self.sequence, scribbles, first_scribble = sess.get_scribbles(only_last=False) 106 | 107 | if first_scribble: 108 | anno_dict = {'frames': [], 'annotated_masks': [], 'masks_tobe_modified': []} 109 | n_interaction = 1 110 | info = Davis.dataset[self.sequence] 111 | self.img_size = info['image_size'][::-1] 112 | self.num_frames = info['num_frames'] 113 | self.n_objects = info['num_objects'] 114 | info = None 115 | seen_seq[self.sequence] = 1 if self.sequence not in seen_seq.keys() else seen_seq[self.sequence] + 1 116 | scr_id = seen_seq[self.sequence] 117 | self.final_masks = np.zeros([self.num_frames, self.img_size[0], self.img_size[1]]) 118 | self.pad_info = utils.apply_pad(self.final_masks[0])[1] 119 | self.hpad1, self.wpad1 = self.pad_info[0][0], self.pad_info[1][0] 120 | self.hpad2, self.wpad2 = self.pad_info[0][1], self.pad_info[1][1] 121 | 
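                    # NOTE: apply_pad (libs/utils.py) reflection-pads each frame so that
                    # its height and width become multiples of 32; h_ds and w_ds below
                    # are the corresponding stride-4 feature-map sizes used to allocate
                    # prob_map_of_frames.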
self.h_ds, self.w_ds = int((self.img_size[0] + sum(self.pad_info[0])) / 4), int((self.img_size[1] + sum(self.pad_info[1])) / 4) 122 | self.anno_6chEnc_r5_list = [] 123 | self.anno_3chEnc_r5_list = [] 124 | self.prob_map_of_frames = torch.zeros((self.num_frames, self.n_objects, 4 * self.h_ds, 4 * self.w_ds)).cuda() 125 | self.gt_masks = self.Davisclass.load_annotations(self.sequence) 126 | 127 | IoU_over_eobj = [] 128 | 129 | else: 130 | n_interaction += 1 131 | 132 | self.save_logger.printNlog('\nRunning sequence {} in (scribble index: {}) (round: {})' 133 | .format(self.sequence, sess.samples[sess.sample_idx][1], n_interaction)) 134 | 135 | annotated_now = interactive_utils.scribbles.annotated_frames(sess.sample_last_scribble)[0] 136 | anno_dict['frames'].append(annotated_now) # Where we save annotated frames 137 | anno_dict['masks_tobe_modified'].append(self.final_masks[annotated_now]) # mask before modefied at the annotated frame 138 | 139 | # Get Predicted mask & Mask decision from pred_mask 140 | self.final_masks = self.run_VOS_singleiact(n_interaction, scribbles, anno_dict['frames']) # self.final_mask changes 141 | 142 | if self.config.test_save_all_segs_option: 143 | utils.mkdir( 144 | os.path.join(self.save_res_dir, 'result_video', '{}-scr{:02d}/round{:02d}'.format(self.sequence, scr_id, n_interaction))) 145 | for fr in range(self.num_frames): 146 | savefname = os.path.join(self.save_res_dir, 'result_video', 147 | '{}-scr{:02d}/round{:02d}'.format(self.sequence, scr_id, n_interaction), 148 | '{:05d}.png'.format(fr)) 149 | tmpPIL = Image.fromarray(self.final_masks[fr].astype(np.uint8), 'P') 150 | tmpPIL.putpalette(self._palette) 151 | tmpPIL.save(savefname) 152 | 153 | # Submit your prediction 154 | sess.submit_masks(self.final_masks) # F, H, W 155 | 156 | # print sequence name 157 | if tmpseq != self.sequence: 158 | tmpseq, numseq = self.sequence, numseq + 1 159 | print(str(numseq) + ':' + str(self.sequence) + '-' + str(seen_seq[self.sequence]) + '\n') 160 | 161 | ## Visualizers and Saver 162 | # IoU estimation 163 | jaccard = batched_jaccard(self.gt_masks, 164 | self.final_masks, 165 | average_over_objects=False, 166 | nb_objects=self.n_objects 167 | ) # frames, objid 168 | 169 | IoU_over_eobj.append(jaccard) 170 | 171 | anno_dict['annotated_masks'].append(self.final_masks[annotated_now]) # mask after modefied at the annotated frame 172 | 173 | if self.max_nb_interactions == len(anno_dict['frames']): # After Lastround -> total 90 iter 174 | seq_scrid_name = self.sequence + str(scr_id) 175 | 176 | # IoU manager 177 | IoU_over_eobj = np.stack(IoU_over_eobj, axis=0) # niact,frames,n_obj 178 | IoUeveryround_perobj = np.mean(IoU_over_eobj, axis=1) # niact,n_obj 179 | output_dict['average_iact_iou'] += np.sum(IoU_over_eobj[list(range(n_interaction)), anno_dict['frames']], axis=-1) 180 | output_dict['annotated_frames'][seq_scrid_name] = anno_dict['frames'] 181 | 182 | # write csv 183 | for obj_idx in range(self.n_objects): 184 | with open(self.save_csvsummary_dir, mode='a') as csv_file: 185 | writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 186 | writer.writerow([self.sequence, str(obj_idx + 1), str(scr_id)] + list(IoUeveryround_perobj[:, obj_idx])) 187 | 188 | summary = sess.get_global_summary(save_file=self.save_res_dir + '/summary_' + sess.report_name[7:] + '.json') 189 | analyze_summary(self.save_res_dir + '/summary_' + sess.report_name[7:] + '.json', metric=metric) 190 | 191 | # final_IOU = summary['curve'][metric][-1] 192 | 
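            # NOTE: summary['curve'][metric] lists the metric at each evaluation point
            # of the report; the [1:-1] slice below drops the first and last entries so
            # that only the per-round averages are returned.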
average_IoU_per_round = summary['curve'][metric][1:-1] 193 | 194 | torch.cuda.empty_cache() 195 | model = None 196 | return average_IoU_per_round 197 | 198 | def run_VOS_singleiact(self, n_interaction, scribbles_data, annotated_frames): 199 | 200 | annotated_frames_np = np.array(annotated_frames) 201 | num_workers = 4 202 | annotated_now = annotated_frames[-1] 203 | scribbles_list = scribbles_data['scribbles'] 204 | seq_name = scribbles_data['sequence'] 205 | 206 | output_masks = self.final_masks.copy().astype(np.float64) 207 | 208 | prop_list = utils.get_prop_list(annotated_frames, annotated_now, self.num_frames, proportion=self.config.test_propagation_proportion) 209 | prop_fore = sorted(prop_list)[0] 210 | prop_rear = sorted(prop_list)[-1] 211 | 212 | # Interaction settings 213 | pm_ps_ns_3ch_t = [] # n_obj,3,h,w 214 | if n_interaction == 1: 215 | for obj_id in range(1, self.n_objects + 1): 216 | pos_scrimg = utils.scribble_to_image(scribbles_list, annotated_now, obj_id, 217 | dilation=self.config.scribble_dilation_param, 218 | prev_mask=self.final_masks[annotated_now]) 219 | pm_ps_ns_3ch_t.append(np.stack([np.ones_like(pos_scrimg) / 2, pos_scrimg, np.zeros_like(pos_scrimg)], axis=0)) 220 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 221 | # Image.fromarray((scr_img[:, :, 1] * 255).astype(np.uint8)).save('/home/six/Desktop/CVPRW_figure/judo_obj1_scr.png') 222 | 223 | else: 224 | for obj_id in range(1, self.n_objects + 1): 225 | prev_round_input = (self.final_masks[annotated_now] == obj_id).astype(np.float32) # H,W 226 | pos_scrimg, neg_scrimg = utils.scribble_to_image(scribbles_list, annotated_now, obj_id, 227 | dilation=self.config.scribble_dilation_param, 228 | prev_mask=self.final_masks[annotated_now], blur=True, 229 | singleimg=False, seperate_pos_neg=True) 230 | pm_ps_ns_3ch_t.append(np.stack([prev_round_input, pos_scrimg, neg_scrimg], axis=0)) 231 | pm_ps_ns_3ch_t = np.stack(pm_ps_ns_3ch_t, axis=0) # n_obj,3,h,w 232 | pm_ps_ns_3ch_t = torch.from_numpy(pm_ps_ns_3ch_t).cuda() 233 | 234 | if (prop_list[0] != annotated_now) and (prop_list.count(annotated_now) != 2): 235 | print(str(prop_list)) 236 | raise NotImplementedError 237 | print(str(prop_list)) # we made our proplist first backward, and then forward 238 | 239 | composed_transforms = transforms.Compose([tr.Normalize_ApplymeanvarImage(self.config.mean, self.config.var), 240 | tr.ToTensor()]) 241 | db_test = davis2017_torchdataset.DAVIS2017(split='val', transform=composed_transforms, root=self.config.davis_dataset_dir, 242 | custom_frames=prop_list, seq_name=seq_name, rgb=True, 243 | obj_id=None, no_gt=True, retname=True, prev_round_masks=self.final_masks, ) 244 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True) 245 | 246 | flag = 0 # 1: propagating backward, 2: propagating forward 247 | print('[{:01d} round] processing...'.format(n_interaction)) 248 | 249 | for ii, batched in enumerate(testloader): 250 | # batched : image, scr_img, 0~fr, meta 251 | inpdict = dict() 252 | operating_frame = int(batched['meta']['frame_id'][0]) 253 | 254 | for inp in batched: 255 | if inp == 'meta': continue 256 | inpdict[inp] = Variable(batched[inp]).cuda() 257 | 258 | inpdict['image'] = inpdict['image'].expand(self.n_objects, -1, -1, -1) 259 | 260 | #################### Iaction ######################## 261 | if operating_frame == annotated_now: # Check the round is on interaction 262 | if flag == 0: 263 | flag += 1 264 | adjacent_to_anno = True 265 | elif flag == 1: 266 | 
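                    # NOTE: prop_list visits the annotated frame twice, once to start the
                    # backward pass and once to start the forward pass, so this second
                    # visit only advances the flag and skips recomputation via the
                    # continue below.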
flag += 1 267 | adjacent_to_anno = True 268 | continue 269 | else: 270 | raise NotImplementedError 271 | 272 | pm_ps_ns_3ch_t = torch.nn.ReflectionPad2d(self.pad_info[1] + self.pad_info[0])(pm_ps_ns_3ch_t) 273 | inputs = torch.cat([inpdict['image'], pm_ps_ns_3ch_t], dim=1) 274 | output_logit, anno_6chEnc_r5 = self.net.forward_ANet(inputs) # [nobj, 1, P_H, P_W], # [n_obj,2048,h/16,w/16] 275 | output_prob_anno = torch.sigmoid(output_logit) 276 | prob_onehot_t = output_prob_anno[:, 0].detach() 277 | 278 | anno_3chEnc_r5, _, _, r2_prev_fromanno = self.net.encoder_3ch.forward(inpdict['image']) 279 | self.anno_6chEnc_r5_list.append(anno_6chEnc_r5) 280 | self.anno_3chEnc_r5_list.append(anno_3chEnc_r5) 281 | 282 | if len(self.anno_6chEnc_r5_list) != len(annotated_frames): 283 | raise NotImplementedError 284 | 285 | 286 | 287 | #################### Propagation ######################## 288 | else: 289 | # Flag [1: propagating backward, 2: propagating forward] 290 | if adjacent_to_anno: 291 | r2_prev = r2_prev_fromanno 292 | predmask_prev = output_prob_anno 293 | else: 294 | predmask_prev = output_prob_prop 295 | adjacent_to_anno = False 296 | 297 | output_logit, r2_prev = self.net.forward_TNet( 298 | self.anno_3chEnc_r5_list, inpdict['image'], self.anno_6chEnc_r5_list, r2_prev, predmask_prev) # [nobj, 1, P_H, P_W] 299 | output_prob_prop = torch.sigmoid(output_logit) 300 | prob_onehot_t = output_prob_prop[:, 0].detach() 301 | 302 | smallest_alpha = 0.5 303 | if flag == 1: 304 | sorted_frames = annotated_frames_np[annotated_frames_np < annotated_now] 305 | if len(sorted_frames) == 0: 306 | alpha = 1 307 | else: 308 | closest_addianno_frame = np.max(sorted_frames) 309 | alpha = smallest_alpha + (1 - smallest_alpha) * ( 310 | (operating_frame - closest_addianno_frame) / (annotated_now - closest_addianno_frame)) 311 | else: 312 | sorted_frames = annotated_frames_np[annotated_frames_np > annotated_now] 313 | if len(sorted_frames) == 0: 314 | alpha = 1 315 | else: 316 | closest_addianno_frame = np.min(sorted_frames) 317 | alpha = smallest_alpha + (1 - smallest_alpha) * ( 318 | (closest_addianno_frame - operating_frame) / (closest_addianno_frame - annotated_now)) 319 | 320 | prob_onehot_t = (alpha * prob_onehot_t) + ((1 - alpha) * self.prob_map_of_frames[operating_frame]) 321 | 322 | # Final mask indexing 323 | self.prob_map_of_frames[operating_frame] = prob_onehot_t 324 | 325 | output_masks[prop_fore:prop_rear + 1] = \ 326 | utils_torch.combine_masks_with_batch(self.prob_map_of_frames[prop_fore:prop_rear + 1], 327 | n_obj=self.n_objects, th=self.config.test_propth 328 | )[:, 0, self.hpad1:-self.hpad2, self.wpad1:-self.wpad2].cpu().numpy().astype(np.float) # f,h,w 329 | 330 | torch.cuda.empty_cache() 331 | 332 | return output_masks 333 | 334 | 335 | if __name__ == '__main__': 336 | config = Config() 337 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 338 | os.environ["CUDA_VISIBLE_DEVICES"] = str(config.test_gpu_id) 339 | 340 | tester = Main_tester(config) 341 | tester.run_for_diverse_metrics() 342 | 343 | # try:main_val(model, 344 | # Config, 345 | # min_nb_nodes= min_nb_nodes, 346 | # simplyfied_testset= simplyfied_test,tr(config.test_gpu_id) 347 | # metric = metric) 348 | # except: continue 349 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /libs/analyze_report.py: 
-------------------------------------------------------------------------------- 1 | """ Analyse Global Summary 2 | """ 3 | import os 4 | import json 5 | import matplotlib.pyplot as plt 6 | 7 | def analyze_summary(fname, metric = 'J_AND_F'): 8 | METRIC_TXT = {'J': 'J', 9 | 'F': 'F', 10 | 'J_AND_F': 'J&F',} 11 | 12 | with open(fname, 'r') as fp: 13 | summary = json.load(fp) 14 | 15 | print('AUC: \t{:.3f}'.format(summary['auc'])) 16 | th = summary['metric_at_threshold']['threshold'] 17 | met = summary['metric_at_threshold'][metric] 18 | print('{}@{}: \t{:.3f}'.format(METRIC_TXT[metric], th, met)) 19 | 20 | time = summary['curve']['time'] 21 | metric_res = summary['curve'][metric] 22 | iteration = list(range(len(time))) 23 | 24 | fig = plt.figure(figsize=(6, 8)) 25 | fig.suptitle('[AUC/t: {:.3f}] [{}@{}: {:.3f}]'.format(summary['auc'],METRIC_TXT[metric],th, met), fontsize=16) 26 | ax1 = fig.add_subplot(211) 27 | ax1.plot(time, metric_res) 28 | ax1.plot(time, metric_res,'b.') 29 | # ax1.set_title('[AUC/t: {:.3f}] [J@{}: {:.3f}]'.format(summary['auc'],th, jac) ) 30 | ax1.set_ylim([0, 1]) 31 | ax1.set_xlim([0, max(time)]) 32 | ax1.set_xlabel('Accumulated Time (s)') 33 | ax1.set_ylabel(r'$\mathcal{' + METRIC_TXT[metric] + '}$') 34 | ax1.axvline(th, c='r') 35 | ax1.yaxis.grid(True) 36 | 37 | 38 | ax2 = fig.add_subplot(212) 39 | ax2.plot(iteration, metric_res) 40 | ax2.plot(iteration, metric_res,'b.') 41 | ax2.set_ylim([0, 1]) 42 | ax2.set_xlim([0, len(time)-1]) 43 | ax2.set_xlabel('Interactions (n)') 44 | ax2.set_ylabel(r'$\mathcal{' + METRIC_TXT[metric] + '}$') 45 | ax2.yaxis.grid(True) 46 | 47 | 48 | save_dir = os.path.split(fname)[0]+'/summary_graph_{:.3f}.png'.format(metric_res[-1]) 49 | plt.savefig(save_dir) 50 | 51 | if __name__ == '__main__': 52 | analyze_summary('/home/yuk/Desktop/IPNet_summary_davis17_val.json', 'J_AND_F') -------------------------------------------------------------------------------- /libs/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Normalize_ApplymeanvarImage(object): 6 | def __init__(self, mean, var, change_channels=False): 7 | self.mean = mean 8 | self.var = var 9 | self.change_channels = change_channels 10 | 11 | def __call__(self, sample): 12 | for elem in sample.keys(): 13 | if 'image' in elem: 14 | if self.change_channels: 15 | sample[elem] = sample[elem][:, :, [2, 1, 0]] 16 | sample[elem] = sample[elem].astype(np.float32)/255.0 17 | sample[elem] = np.subtract(sample[elem], np.array(self.mean, dtype=np.float32))/np.array(self.var, dtype=np.float32) 18 | 19 | 20 | return sample 21 | 22 | def __str__(self): 23 | return 'SubtractMeanImage'+str(self.mean) 24 | 25 | 26 | class ToTensor(object): 27 | """Convert ndarrays in sample to Tensors.""" 28 | 29 | def __call__(self, sample): 30 | 31 | for elem in sample.keys(): 32 | if 'meta' in elem: 33 | continue 34 | tmp = sample[elem] 35 | 36 | if tmp.ndim == 2: 37 | tmp = tmp[:, :, np.newaxis] 38 | 39 | # swap color axis because 40 | # numpy image: H x W x C 41 | # torch image: C X H X W 42 | 43 | tmp = tmp.transpose((2, 0, 1)) 44 | sample[elem] = torch.from_numpy(tmp) 45 | 46 | return sample 47 | 48 | -------------------------------------------------------------------------------- /libs/davis2017_torchdataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | 4 | import os 5 | import numpy as np 6 | import cv2 7 | 8 | from libs 
import utils 9 | from torch.utils.data import Dataset 10 | import json 11 | from PIL import Image 12 | 13 | 14 | class DAVIS2017(Dataset): 15 | """DAVIS 2017 dataset constructed using the PyTorch built-in functionalities""" 16 | 17 | def __init__(self, 18 | split='val', 19 | root='', 20 | num_frames=None, 21 | custom_frames=None, 22 | transform=None, 23 | retname=False, 24 | seq_name=None, 25 | obj_id=None, 26 | gt_only_first_frame=False, 27 | no_gt=False, 28 | batch_gt=False, 29 | rgb=False, 30 | effective_batch=None, 31 | prev_round_masks = None,#f,h,w 32 | ): 33 | """Loads image to label pairs for tool pose estimation 34 | split: Split or list of splits of the dataset 35 | root: dataset directory with subfolders "JPEGImages" and "Annotations" 36 | num_frames: Select number of frames of the sequence (None for all frames) 37 | custom_frames: List or Tuple with the number of the frames to include 38 | transform: Data transformations 39 | retname: Retrieve meta data in the sample key 'meta' 40 | seq_name: Use a specific sequence 41 | obj_id: Use a specific object of a sequence (If None and sequence is specified, the batch_gt is True) 42 | gt_only_first_frame: Provide the GT only in the first frame 43 | no_gt: No GT is provided 44 | batch_gt: For every frame sequence batch all the different objects gt 45 | rgb: Use RGB channel order in the image 46 | """ 47 | if isinstance(split, str): 48 | self.split = [split] 49 | else: 50 | split.sort() 51 | self.split = split 52 | self.db_root_dir = root 53 | self.transform = transform 54 | self.seq_name = seq_name 55 | self.obj_id = obj_id 56 | self.num_frames = num_frames 57 | self.custom_frames = custom_frames 58 | self.retname = retname 59 | self.rgb = rgb 60 | if seq_name is not None and obj_id is None: 61 | batch_gt = True 62 | self.batch_gt = batch_gt 63 | self.all_seqs_list = [] 64 | 65 | self.seqs = [] 66 | for splt in self.split: 67 | with open(os.path.join(self.db_root_dir, 'ImageSets', '2017', splt + '.txt')) as f: 68 | seqs_tmp = f.readlines() 69 | seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp)) 70 | self.seqs.extend(seqs_tmp) 71 | self.seq_list_file = os.path.join(self.db_root_dir, 'ImageSets', '2017', 72 | '_'.join(self.split) + '_instances.txt') 73 | # Precompute the dictionary with the objects per sequence 74 | if not self._check_preprocess(): 75 | self._preprocess() 76 | 77 | if self.seq_name is None: 78 | img_list = [] 79 | labels = [] 80 | prevmask_list= [] 81 | for seq in self.seqs: 82 | images = np.sort(os.listdir(os.path.join(self.db_root_dir, 'JPEGImages/480p/', seq.strip()))) 83 | images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images)) 84 | lab = np.sort(os.listdir(os.path.join(self.db_root_dir, 'Annotations/480p/', seq.strip()))) 85 | lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab)) 86 | if num_frames is not None: 87 | seq_len = len(images_path) 88 | num_frames = min(num_frames, seq_len) 89 | frame_vector = np.arange(num_frames) 90 | frames_ids = list(np.round(frame_vector*seq_len/float(num_frames)).astype(np.int)) 91 | frames_ids[-1] = min(frames_ids[-1], seq_len) 92 | images_path = [images_path[x] for x in frames_ids] 93 | if no_gt: 94 | lab_path = [None] * len(images_path) 95 | else: 96 | lab_path = [lab_path[x] for x in frames_ids] 97 | elif isinstance(custom_frames, tuple) or isinstance(custom_frames, list): 98 | assert min(custom_frames) >= 0 and max(custom_frames) <= len(images_path) 99 | images_path = [images_path[x] for x in 
custom_frames] 100 | prevmask_list = [prev_round_masks[x] for x in custom_frames] 101 | if no_gt: 102 | lab_path = [None] * len(images_path) 103 | else: 104 | lab_path = [lab_path[x] for x in custom_frames] 105 | if gt_only_first_frame: 106 | lab_path = [lab_path[0]] 107 | lab_path.extend([None] * (len(images_path) - 1)) 108 | elif no_gt: 109 | lab_path = [None] * len(images_path) 110 | if self.batch_gt: 111 | obj = self.seq_dict[seq] 112 | if -1 in obj: 113 | obj.remove(-1) 114 | for ii in range(len(img_list), len(images_path)+len(img_list)): 115 | self.all_seqs_list.append([obj, ii]) 116 | else: 117 | for obj in self.seq_dict[seq]: 118 | if obj != -1: 119 | for ii in range(len(img_list), len(images_path)+len(img_list)): 120 | self.all_seqs_list.append([obj, ii]) 121 | 122 | img_list.extend(images_path) 123 | labels.extend(lab_path) 124 | else: 125 | # Initialize the per sequence images for online training 126 | assert self.seq_name in self.seq_dict.keys(), '{} not in {} set.'.format(self.seq_name, '_'.join(self.split)) 127 | names_img = np.sort(os.listdir(os.path.join(self.db_root_dir, 'JPEGImages/480p/', str(seq_name)))) 128 | img_list = list(map(lambda x: os.path.join('JPEGImages/480p/', str(seq_name), x), names_img)) 129 | name_label = np.sort(os.listdir(os.path.join(self.db_root_dir, 'Annotations/480p/', str(seq_name)))) 130 | labels = list(map(lambda x: os.path.join('Annotations/480p/', str(seq_name), x), name_label)) 131 | prevmask_list = [] 132 | if num_frames is not None: 133 | seq_len = len(img_list) 134 | num_frames = min(num_frames, seq_len) 135 | frame_vector = np.arange(num_frames) 136 | frames_ids = list(np.round(frame_vector * seq_len / float(num_frames)).astype(np.int)) 137 | frames_ids[-1] = min(frames_ids[-1], seq_len) 138 | img_list = [img_list[x] for x in frames_ids] 139 | if no_gt: 140 | labels = [None] * len(img_list) 141 | else: 142 | labels = [labels[x] for x in frames_ids] 143 | elif isinstance(custom_frames, tuple) or isinstance(custom_frames, list): 144 | assert min(custom_frames) >= 0 and max(custom_frames) <= len(img_list) 145 | img_list = [img_list[x] for x in custom_frames] 146 | prevmask_list = [prev_round_masks[x] for x in custom_frames] 147 | if no_gt: 148 | labels = [None] * len(img_list) 149 | else: 150 | labels = [labels[x] for x in custom_frames] 151 | if gt_only_first_frame: 152 | labels = [labels[0]] 153 | labels.extend([None]*(len(img_list)-1)) 154 | elif no_gt: 155 | labels = [None] * len(img_list) 156 | if obj_id is not None: 157 | assert obj_id in self.seq_dict[self.seq_name], \ 158 | "{} doesn't have this object id {}.".format(self.seq_name, str(obj_id)) 159 | if self.batch_gt: 160 | self.obj_id = self.seq_dict[self.seq_name] 161 | if -1 in self.obj_id: 162 | self.obj_id.remove(-1) 163 | self.obj_id = [0]+self.obj_id 164 | 165 | assert (len(labels) == len(img_list)) 166 | 167 | if effective_batch: 168 | self.img_list = img_list * effective_batch 169 | self.labels = labels * effective_batch 170 | else: 171 | self.img_list = img_list 172 | self.labels = labels 173 | self.prevmasks_list = prevmask_list 174 | 175 | # print('Done initializing DAVIS2017 '+'_'.join(self.split)+' Dataset') 176 | # print('Number of images: {}'.format(len(self.img_list))) 177 | # if self.seq_name is None: 178 | # print('Number of elements {}'.format(len(self.all_seqs_list))) 179 | 180 | def _check_preprocess(self): 181 | _seq_list_file = self.seq_list_file 182 | if not os.path.isfile(_seq_list_file): 183 | return False 184 | else: 185 | self.seq_dict = 
json.load(open(self.seq_list_file, 'r')) 186 | return True 187 | 188 | def _preprocess(self): 189 | self.seq_dict = {} 190 | for seq in self.seqs: 191 | # Read object masks and get number of objects 192 | name_label = np.sort(os.listdir(os.path.join(self.db_root_dir, 'Annotations/480p/', seq))) 193 | label_path = os.path.join(self.db_root_dir, 'Annotations/480p/', seq, name_label[0]) 194 | _mask = np.array(Image.open(label_path)) 195 | _mask_ids = np.unique(_mask) 196 | n_obj = _mask_ids[-1] 197 | 198 | self.seq_dict[seq] = list(range(1, n_obj+1)) 199 | 200 | with open(self.seq_list_file, 'w') as outfile: 201 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.seqs[0], json.dumps(self.seq_dict[self.seqs[0]]))) 202 | for ii in range(1, len(self.seqs)): 203 | outfile.write(',\n\t"{:s}": {:s}'.format(self.seqs[ii], json.dumps(self.seq_dict[self.seqs[ii]]))) 204 | outfile.write('\n}\n') 205 | 206 | print('Preprocessing finished') 207 | 208 | def __len__(self): 209 | if self.seq_name is None: 210 | return len(self.all_seqs_list) 211 | else: 212 | return len(self.img_list) 213 | 214 | def __getitem__(self, idx): 215 | # print(idx) 216 | img, gt, prev_round_mask = self.make_img_gt_mask_pair(idx) 217 | 218 | pad_img, pad_info = utils.apply_pad(img) 219 | pad_gt= utils.apply_pad(gt, padinfo = pad_info)#h,w,n 220 | sample = {'image': pad_img, 'gt': pad_gt} 221 | 222 | 223 | if self.retname: 224 | if self.seq_name is None: 225 | obj_id = self.all_seqs_list[idx][0] 226 | img_path = self.img_list[self.all_seqs_list[idx][1]] 227 | else: 228 | obj_id = self.obj_id 229 | img_path = self.img_list[idx] 230 | seq_name = img_path.split('/')[-2] 231 | frame_id = img_path.split('/')[-1].split('.')[-2] 232 | sample['meta'] = {'seq_name': seq_name, 233 | 'frame_id': frame_id, 234 | 'obj_id': obj_id, 235 | 'im_size': (img.shape[0], img.shape[1]), 236 | 'pad_size': (pad_img.shape[0], pad_img.shape[1]), 237 | 'pad_info': pad_info} 238 | 239 | if self.transform is not None: 240 | sample = self.transform(sample) 241 | 242 | return sample 243 | 244 | 245 | def make_img_gt_mask_pair(self, idx): 246 | """ 247 | Make the image-ground-truth pair 248 | """ 249 | prev_round_mask_tmp = self.prevmasks_list[idx] 250 | if self.seq_name is None: 251 | obj_id = self.all_seqs_list[idx][0] 252 | img_path = self.img_list[self.all_seqs_list[idx][1]] 253 | label_path = self.labels[self.all_seqs_list[idx][1]] 254 | else: 255 | obj_id = self.obj_id 256 | img_path = self.img_list[idx] 257 | label_path = self.labels[idx] 258 | seq_name = img_path.split('/')[-2] 259 | n_obj = 1 if isinstance(obj_id, int) else len(obj_id) 260 | img = cv2.imread(os.path.join(self.db_root_dir, img_path)) 261 | img = np.array(img, dtype=np.float32) 262 | if self.rgb: 263 | img = img[:, :, [2, 1, 0]] 264 | 265 | if label_path is not None: 266 | label = Image.open(os.path.join(self.db_root_dir, label_path)) 267 | else: 268 | if self.batch_gt: 269 | gt = np.zeros(np.append(img.shape[:-1], n_obj), dtype=np.float32) 270 | else: 271 | gt = np.zeros(img.shape[:-1], dtype=np.float32) 272 | 273 | if label_path is not None: 274 | gt_tmp = np.array(label, dtype=np.uint8) 275 | if self.batch_gt: 276 | gt = np.zeros(np.append(n_obj, gt_tmp.shape), dtype=np.float32) 277 | for ii, k in enumerate(obj_id): 278 | gt[ii, :, :] = gt_tmp == k 279 | gt = gt.transpose((1, 2, 0)) 280 | else: 281 | gt = (gt_tmp == obj_id).astype(np.float32) 282 | 283 | if self.batch_gt: 284 | prev_round_mask = np.zeros(np.append(img.shape[:-1], n_obj), dtype=np.float32) 285 | for ii, k in 
enumerate(obj_id): 286 | prev_round_mask[:, :, ii] = prev_round_mask_tmp == k 287 | else: 288 | prev_round_mask = (prev_round_mask_tmp == obj_id).astype(np.float32) 289 | 290 | return img, gt, prev_round_mask 291 | 292 | def get_img_size(self): 293 | img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0])) 294 | return list(img.shape[:2]) 295 | 296 | def __str__(self): 297 | return 'DAVIS2017' 298 | 299 | if __name__ =='__main__': 300 | a = DAVIS2017(split='val', custom_frames=[21,22], seq_name='gold-fish', rgb=True, no_gt=False, retname=True,prev_round_masks=np.zeros([40,480,854])) 301 | c= a.__getitem__(0) 302 | b=1 -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | from davisinteractive.utils.operations import bresenham 6 | 7 | def mkdir(paths): 8 | if not isinstance(paths, (list, tuple)): 9 | paths = [paths] 10 | for path in paths: 11 | if not os.path.isdir(path): 12 | os.makedirs(path) 13 | 14 | 15 | class logger: 16 | def __init__(self, log_file): 17 | self.log_file = log_file 18 | 19 | def printNlog(self,str2print): 20 | print(str2print) 21 | with open(self.log_file, 'a') as f: 22 | f.write(str2print + '\n') 23 | f.close() 24 | 25 | 26 | def apply_pad(img, padinfo=None): 27 | if padinfo: # ((hpad,hpad),(wpad,wpad)) 28 | (hpad, wpad) = padinfo 29 | if len(img.shape)==3 : pad_img = np.pad(img, (hpad, wpad, (0, 0)), mode='reflect') # H,W,3 30 | else: pad_img = np.pad(img, (hpad, wpad), mode='reflect') #H,W 31 | return pad_img 32 | else: 33 | h, w = img.shape[0:2] 34 | new_h = h + 32 - h % 32 35 | new_w = w + 32 - w % 32 36 | # print(new_h, new_w) 37 | lh, uh = (new_h - h) / 2, (new_h - h) / 2 + (new_h - h) % 2 38 | lw, uw = (new_w - w) / 2, (new_w - w) / 2 + (new_w - w) % 2 39 | lh, uh, lw, uw = int(lh), int(uh), int(lw), int(uw) 40 | if len(img.shape)==3 : pad_img = np.pad(img, ((lh, uh), (lw, uw), (0, 0)), mode='reflect') # H,W,3 41 | else: pad_img = np.pad(img, ((lh, uh), (lw, uw)), mode='reflect') # H,W 42 | info = ((lh, uh), (lw, uw)) 43 | 44 | return pad_img, info 45 | 46 | 47 | def get_prop_list(annotated_frames, annotated_now, num_frames, proportion = 1.0, get_close_anno_frames = False): 48 | 49 | aligned_anno = sorted(annotated_frames) 50 | overlap = aligned_anno.count(annotated_now) 51 | for i in range(overlap): 52 | aligned_anno.remove(annotated_now) 53 | 54 | start_frame, end_frame = 0, num_frames -1 55 | for i in range(len(aligned_anno)): 56 | if aligned_anno[i] > annotated_now: 57 | end_frame = aligned_anno[i] - 1 58 | break 59 | aligned_anno.reverse() 60 | for i in range(len(aligned_anno)): 61 | if aligned_anno[i] < annotated_now: 62 | start_frame = aligned_anno[i]+1 63 | break 64 | 65 | if get_close_anno_frames: 66 | close_frames_round=dict() # 1st column: iaction idx, 2nd column: the close frames 67 | annotated_frames.reverse() 68 | try: close_frames_round["left"] = len(annotated_frames) - annotated_frames.index(start_frame-1) - 1 69 | except: print('No left annotated fr') 70 | try: close_frames_round["right"] = len(annotated_frames) - annotated_frames.index(end_frame) - 1 71 | except: print('No right annotated fr') 72 | 73 | if proportion != 1.0: 74 | if start_frame!=0: 75 | start_frame = annotated_now - int((annotated_now-start_frame)*proportion + 0.5) 76 | if end_frame != num_frames-1: 77 | end_frame = annotated_now + int((end_frame - annotated_now) * proportion + 0.5) 78 
| prop_list = list(range(annotated_now,start_frame-1,-1)) + list(range(annotated_now,end_frame+1)) 79 | if len(prop_list)==0: 80 | prop_list = [annotated_now] 81 | 82 | if not get_close_anno_frames: 83 | return prop_list 84 | 85 | else: 86 | return prop_list, close_frames_round 87 | 88 | 89 | def scribble_to_image(scribbles, currentframe, obj_id, prev_mask, dilation=5, 90 | nocare_area=None, bresenhamtf=True, blur=True, singleimg=False, seperate_pos_neg = False): 91 | """ Rasterize the scribbles drawn on `currentframe` into numpy maps with the same shape as `prev_mask`: a positive map for the strokes of `obj_id` and a negative map for the strokes of the other objects, returned separately or combined depending on `seperate_pos_neg`. 92 | 93 | 94 | """ 95 | h,w = prev_mask.shape 96 | regions2exclude_on_maskneg = prev_mask!=obj_id 97 | mask = np.zeros([h,w]) 98 | mask_neg = np.zeros([h,w]) 99 | if not singleimg: 100 | scribbles = scribbles[currentframe] 101 | 102 | 103 | for scribble in scribbles: 104 | points_scribble = np.round(np.array(scribble['path']) * np.array((w, h))).astype(np.int) 105 | if bresenhamtf and len(points_scribble) > 1: 106 | all_points = bresenham(points_scribble) 107 | else: 108 | all_points = points_scribble 109 | 110 | if obj_id==0: 111 | raise NotImplementedError 112 | else: 113 | if scribble['object_id'] == obj_id: 114 | mask[all_points[:, 1] - 1, all_points[:, 0] - 1] = 1 115 | else: 116 | mask_neg[all_points[:, 1] - 1, all_points[:, 0] - 1] = 1 117 | # else: 118 | # mask_neg[all_points[:, 1] - 1, all_points[:, 0] - 1] = 1 119 | 120 | scr_gt, _ = scrimg_postprocess(mask, dilation=dilation, nocare_area=nocare_area, blur=blur, blursize=(5, 5)) 121 | scr_gt_neg, _ = scrimg_postprocess(mask_neg, dilation=dilation, nocare_area=nocare_area, blur=blur, blursize=(5, 5)) 122 | scr_gt_neg[regions2exclude_on_maskneg] = 0 123 | 124 | if seperate_pos_neg: 125 | return scr_gt.astype(np.float32), scr_gt_neg.astype(np.float32) 126 | else: 127 | scr_img = scr_gt - scr_gt_neg 128 | return scr_img.astype(np.float32) 129 | 130 | 131 | def scrimg_postprocess(scr, dilation=7, nocare_area=21, blur = False, blursize=(5, 5), var = 6.0, custom_blur = None): 132 | 133 | # Compute foreground 134 | if scr.max() == 1: 135 | kernel_fg = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation, dilation)) 136 | fg = cv2.dilate(scr.astype(np.uint8), kernel=kernel_fg).astype(scr.dtype) 137 | else: 138 | fg = scr 139 | 140 | # Compute nocare area 141 | if nocare_area is None: 142 | nocare = None 143 | else: 144 | kernel_nc = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (nocare_area, nocare_area)) 145 | nocare = cv2.dilate(fg, kernel=kernel_nc) - fg 146 | if blur: 147 | fg = cv2.GaussianBlur(fg,ksize=blursize,sigmaX=var) 148 | elif custom_blur: 149 | c_kernel = np.array([[1,2,3,2,1],[2,4,9,4,2],[3,9,64,9,3],[2,4,9,4,2],[1,2,3,2,1]]) 150 | c_kernel = c_kernel/np.sum(c_kernel) 151 | fg = cv2.filter2D(fg,ddepth=-1,kernel = c_kernel) 152 | 153 | return fg, nocare 154 | -------------------------------------------------------------------------------- /libs/utils_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def combine_masks_with_batch(masks, n_obj, th=0.5, return_as_onehot = False): 4 | """ Combine per-object probability masks into a single label map. 5 | 6 | The fusion rule is the following: 7 | 8 | * `max_per_pixel`: Computes the final mask taking the pixel with the highest 9 | probability for every object. 10 | 11 | # Arguments 12 | masks: Tensor with shape [B, nobj, H, W]. H and W must be the same across the batch. 13 | th: Float. Threshold that an object's probability must exceed for the pixel to be assigned.
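        For example, with masks of shape [B, 2, H, W] and th=0.5, each output pixel
        becomes 1 or 2 for the winning object whose probability exceeds 0.5, and
        stays 0 where no object passes the threshold.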
14 | 15 | # Returns 16 | [B, 1, H, W] label map, or [B, nobj, H, W] one-hot tensor when return_as_onehot is True 17 | """ 18 | 19 | # masks : B, nobj, h, w 20 | # output : B, 1, h, w (label map) or B, nobj, h, w (one-hot) 21 | marker = torch.argmax(masks, dim=1, keepdim=True) # most probable object per pixel 22 | if not return_as_onehot: 23 | out_mask = torch.unsqueeze(torch.zeros_like(masks)[:, 0], 1) # [B, 1, H, W] 24 | for obj_id in range(n_obj): 25 | tmp_mask = (marker == obj_id) * (masks[:, obj_id].unsqueeze(1) > th) 26 | 27 | out_mask[tmp_mask] = obj_id + 1 # object labels start at 1; 0 is background 28 | 29 | if return_as_onehot: 30 | out_mask = torch.zeros_like(masks) # [B, nobj, H, W] 31 | for obj_id in range(n_obj): 32 | tmp_mask = (marker == obj_id) * (masks[:, obj_id].unsqueeze(1) > th) 33 | 34 | out_mask[:, obj_id] = tmp_mask[:, 0].float() # device-agnostic; the original cast assumed CUDA 35 | 36 | return out_mask 37 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /networks/atnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.deeplab.aspp import ASPP 6 | from networks.deeplab.backbone.resnet import SEResNet50 7 | from networks.correlation_package.correlation import Correlation 8 | from networks.ltm_transfer import LTM_transfer 9 | 10 | 11 | class ATnet(nn.Module): 12 | def __init__(self, pretrained=1, resfix=False, corr_displacement=4, corr_stride=2): 13 | super(ATnet, self).__init__() 14 | print("Constructing ATnet architecture...") 15 | 16 | self.encoder_6ch = Encoder_6ch(resfix) 17 | self.encoder_3ch = Encoder_3ch(resfix) 18 | self.indicator_encoder = ConverterEncoder() 19 | self.decoder_iact = Decoder() 20 | self.decoder_prop = Decoder_prop() 21 | 22 | self.ltm_local_affinity = Correlation(pad_size=corr_displacement * corr_stride, kernel_size=1, 23 | max_displacement=corr_displacement * corr_stride, 24 | stride1=1, stride2=corr_stride, corr_multiply=1) 25 | self.ltm_transfer = LTM_transfer(md=corr_displacement, stride=corr_stride) 26 | 27 | self.prev_conv1x1 = nn.Conv2d(256, 256, kernel_size=1, padding=0) # 1/4, 256 28 | self.conv1x1 = nn.Conv2d(2048*2, 2048, kernel_size=1, padding=0) # 1/16, 2048 29 | 30 | self.refer_weight = None 31 | self._initialize_weights(pretrained) 32 | 33 | def forward_ANet(self, x): # x: Bx6xHxW (see Encoder_6ch) -> Bx1xHxW mask estimate 34 | r5, r4, r3, r2 = self.encoder_6ch(x) 35 | estimated_mask, m2 = self.decoder_iact(r5, r3, r2, only_return_feature=False) 36 | r5_indicator = self.indicator_encoder(r5, m2) 37 | return estimated_mask, r5_indicator 38 | 39 | def forward_TNet(self, anno_propEnc_r5_list, targframe_3ch, anno_iactEnc_r5_list, r2_prev, predmask_prev, debug_f_mask=False): # 1/16, 2048 40 | f_targ, _, r3_targ, r2_targ = self.encoder_3ch(targframe_3ch) 41 | f_mask_r5 = self.correlation_global_transfer(anno_propEnc_r5_list, f_targ, anno_iactEnc_r5_list) # 1/16, 2048 42 | 43 | r2_targ_c = self.prev_conv1x1(r2_targ) 44 | r2_prev = self.prev_conv1x1(r2_prev) 45 | f_mask_r2 = self.correlation_local_transfer(r2_prev, r2_targ_c, predmask_prev) # 1/4, 1 [B,1,H/4,W/4] 46 | 47 | r5_concat = torch.cat([f_targ, f_mask_r5], dim=1) # 1/16, 2048*2 48 | r5_concat = self.conv1x1(r5_concat) 49 | estimated_mask, m2 = self.decoder_prop(r5_concat, r3_targ, r2_targ, f_mask_r2) 50 | 51 | if not debug_f_mask: 52 | return estimated_mask, r2_targ 53 | else: 54 | return
estimated_mask, r2_targ, f_mask_r2 55 | 56 | def correlation_global_transfer(self, anno_feature_list, targ_feature, anno_indicator_feature_list ): 57 | ''' 58 | :param anno_feature_list: [B,C,H,W] x list (N values in list) 59 | :param targ_feature: [B,C,H,W] 60 | :param anno_indicator_feature_list: [B,C,H,W] x list (N values in list) 61 | :return targ_mask_feature: [B,C,H,W] 62 | ''' 63 | 64 | b, c, h, w = anno_indicator_feature_list[0].size() # b means n_objs 65 | targ_feature = targ_feature.view(b, c, h * w) # [B, C, HxW] 66 | n_features = len(anno_feature_list) 67 | anno_feature = [] 68 | for f_idx in range(n_features): 69 | anno_feature.append(anno_feature_list[f_idx].view(b, c, h * w).transpose(1, 2)) # [B, HxW', C] 70 | anno_feature = torch.cat(anno_feature, dim=1) # [B, NxHxW', C] 71 | sim_feature = torch.bmm(anno_feature, targ_feature) # [B, NxHxW', HxW] 72 | sim_feature = F.softmax(sim_feature, dim=2) / n_features # [B, NxHxW', HxW] 73 | anno_indicator_feature = [] 74 | for f_idx in range(n_features): 75 | anno_indicator_feature.append(anno_indicator_feature_list[f_idx].view(b, c, h * w)) # [B, C, HxW'] 76 | anno_indicator_feature = torch.cat(anno_indicator_feature, dim=-1) # [B, C, NxHxW'] 77 | targ_mask_feature = torch.bmm(anno_indicator_feature, sim_feature) # [B, C, HxW] 78 | targ_mask_feature = targ_mask_feature.view(b, c, h, w) 79 | 80 | return targ_mask_feature 81 | 82 | def correlation_local_transfer(self, r2_prev, r2_targ, predmask_prev): 83 | ''' 84 | 85 | :param r2_prev: [B,C,H,W] 86 | :param r2_targ: [B,C,H,W] 87 | :param predmask_prev: [B,1,4*H,4*W] 88 | :return targ_mask_feature_r2: [B,1,H,W] 89 | ''' 90 | 91 | predmask_prev = F.interpolate(predmask_prev, scale_factor=0.25, mode='bilinear',align_corners=True) # B,1,H,W 92 | sim_feature = self.ltm_local_affinity.forward(r2_targ,r2_prev,) # B,D^2,H,W 93 | sim_feature = F.softmax(sim_feature, dim=2) # B,D^2,H,W 94 | predmask_targ = self.ltm_transfer.forward(sim_feature, predmask_prev, apply_softmax_on_simfeature = False) # B,1,H,W 95 | 96 | return predmask_targ 97 | 98 | def _initialize_weights(self, pretrained): 99 | for m in self.modules(): 100 | if pretrained: 101 | break 102 | else: 103 | if isinstance(m, nn.Conv2d): 104 | m.weight.data.normal_(0, 0.001) 105 | if m.bias is not None: 106 | m.bias.data.zero_() 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | elif isinstance(m, nn.Linear): 111 | m.weight.data.normal_(0, 0.01) 112 | m.bias.data.zero_() 113 | 114 | 115 | class Encoder_3ch(nn.Module): 116 | # T-Net Encoder 117 | def __init__(self, resfix): 118 | super(Encoder_3ch, self).__init__() 119 | 120 | self.conv0_3ch = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True) 121 | 122 | resnet = SEResNet50(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True) 123 | self.bn1 = resnet.bn1 124 | self.relu = resnet.relu # 1/2, 64 125 | self.maxpool = resnet.maxpool 126 | 127 | self.res2 = resnet.layer1 # 1/4, 256 128 | self.res3 = resnet.layer2 # 1/8, 512 129 | self.res4 = resnet.layer3 # 1/16, 1024 130 | self.res5 = resnet.layer4 # 1/16, 2048 131 | 132 | # freeze BNs 133 | if resfix: 134 | for m in self.modules(): 135 | if isinstance(m, nn.BatchNorm2d): 136 | for p in m.parameters(): 137 | p.requires_grad = False 138 | 139 | def forward(self, x): 140 | x = self.conv0_3ch(x) # 1/2, 64 141 | x = self.bn1(x) 142 | c1 = self.relu(x) # 1/2, 64 143 | x = self.maxpool(c1) # 1/4, 64 144 | r2 = self.res2(x) # 1/4, 256 145 | r3 = self.res3(r2) # 1/8, 512 146 | 
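# note: SEResNet50 is constructed with output_stride=16, so layer3 reaches
# 1/16 resolution and layer4 keeps it there by using dilation instead of a
# stride; r4 and r5 below therefore share the same spatial size.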
r4 = self.res4(r3) # 1/16, 1024 147 | r5 = self.res5(r4) # 1/16, 2048 148 | 149 | return r5, r4, r3, r2 150 | 151 | def forward_r2(self,x): 152 | x = self.conv0_3ch(x) # 1/2, 64 153 | x = self.bn1(x) 154 | c1 = self.relu(x) # 1/2, 64 155 | x = self.maxpool(c1) # 1/4, 64 156 | r2 = self.res2(x) # 1/4, 256 157 | return r2 158 | 159 | 160 | class Encoder_6ch(nn.Module): 161 | # A-Net Encoder 162 | def __init__(self, resfix): 163 | super(Encoder_6ch, self).__init__() 164 | 165 | self.conv0_6ch = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=True) 166 | 167 | resnet = SEResNet50(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True) 168 | self.bn1 = resnet.bn1 169 | self.relu = resnet.relu # 1/2, 64 170 | self.maxpool = resnet.maxpool 171 | 172 | self.res2 = resnet.layer1 # 1/4, 256 173 | self.res3 = resnet.layer2 # 1/8, 512 174 | self.res4 = resnet.layer3 # 1/16, 1024 175 | self.res5 = resnet.layer4 # 1/16, 2048 176 | 177 | # freeze BNs 178 | if resfix: 179 | for m in self.modules(): 180 | if isinstance(m, nn.BatchNorm2d): 181 | for p in m.parameters(): 182 | p.requires_grad = False 183 | 184 | def forward(self, x): 185 | 186 | x = self.conv0_6ch(x) # 1/2, 64 187 | x = self.bn1(x) 188 | c1 = self.relu(x) # 1/2, 64 189 | x = self.maxpool(c1) # 1/4, 64 190 | r2 = self.res2(x) # 1/4, 256 191 | r3 = self.res3(r2) # 1/8, 512 192 | r4 = self.res4(r3) # 1/16, 1024 193 | r5 = self.res5(r4) # 1/16, 2048 194 | 195 | return r5, r4, r3, r2 196 | 197 | 198 | class Decoder(nn.Module): 199 | # A-Net Decoder 200 | def __init__(self): 201 | super(Decoder, self).__init__() 202 | mdim = 256 203 | 204 | self.aspp_decoder = ASPP(backbone='res', output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=1) 205 | self.convG0 = nn.Conv2d(2048, mdim, kernel_size=3, padding=1) 206 | self.convG1 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1) 207 | self.convG2 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1) 208 | 209 | self.RF3 = Refine(512, mdim) # 1/16 -> 1/8 210 | self.RF2 = Refine(256, mdim) # 1/8 -> 1/4 211 | 212 | self.lastconv = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False), 213 | nn.BatchNorm2d(256), 214 | nn.ReLU(), 215 | nn.Dropout(0.5), 216 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 217 | nn.BatchNorm2d(256), 218 | nn.ReLU(), 219 | nn.Dropout(0.1), 220 | nn.Conv2d(256, 1, kernel_size=1, stride=1)) 221 | 222 | def forward(self, r5, r3_targ, r2_targ, only_return_feature = False): 223 | 224 | aspp_out = self.aspp_decoder(r5) #1/16 mdim 225 | aspp_out = F.interpolate(aspp_out, scale_factor=4, mode='bilinear',align_corners=True) #1/4 mdim 226 | m4 = self.convG0(F.relu(r5)) # out: # 1/16, mdim 227 | m4 = self.convG1(F.relu(m4)) # out: # 1/16, mdim 228 | m4 = self.convG2(F.relu(m4)) # out: # 1/16, mdim 229 | 230 | 231 | m3 = self.RF3(r3_targ, m4) # out: 1/8, mdim 232 | m2 = self.RF2(r2_targ, m3) # out: 1/4, mdim 233 | m2 = torch.cat((m2, aspp_out), dim=1) # out: 1/4, mdim*2 234 | 235 | if only_return_feature: 236 | return m2 237 | 238 | x = self.lastconv(m2) 239 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True) 240 | 241 | return x, m2 242 | 243 | 244 | class Decoder_prop(nn.Module): 245 | # T-Net Decoder 246 | def __init__(self): 247 | super(Decoder_prop, self).__init__() 248 | mdim = 256 249 | 250 | self.aspp_decoder = ASPP(backbone='res', output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=1) 251 | self.convG0 = nn.Conv2d(2048, mdim, kernel_size=3, padding=1) 252 | self.convG1 = nn.Conv2d(mdim, mdim, 
kernel_size=3, padding=1) 253 | self.convG2 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1) 254 | 255 | self.RF3 = Refine(512, mdim) # 1/16 -> 1/8 256 | self.RF2 = Refine(256, mdim) # 1/8 -> 1/4 257 | 258 | self.lastconv = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False), 259 | nn.BatchNorm2d(256), 260 | nn.ReLU(), 261 | nn.Dropout(0.5), 262 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 263 | nn.BatchNorm2d(256), 264 | nn.ReLU(), 265 | nn.Dropout(0.1), 266 | nn.Conv2d(256, 1, kernel_size=1, stride=1)) 267 | 268 | def forward(self, r5, r3_targ, r2_targ, f_mask_r2): 269 | 270 | aspp_out = self.aspp_decoder(r5) #1/16 mdim 271 | aspp_out = F.interpolate(aspp_out, scale_factor=4, mode='bilinear',align_corners=True) #1/4 mdim 272 | m4 = self.convG0(F.relu(r5)) # out: # 1/16, mdim 273 | m4 = self.convG1(F.relu(m4)) # out: # 1/16, mdim 274 | m4 = self.convG2(F.relu(m4)) # out: # 1/16, mdim 275 | 276 | m3 = self.RF3(r3_targ, m4) # out: 1/8, mdim 277 | m3 = m3 + 0.5 * F.interpolate(f_mask_r2, scale_factor=0.5, mode='bilinear',align_corners=True) #1/4 mdim 278 | m2 = self.RF2(r2_targ, m3) # out: 1/4, mdim 279 | m2 = m2 + 0.5 * f_mask_r2 280 | m2 = torch.cat((m2, aspp_out), dim=1) # out: 1/4, mdim*2 281 | 282 | x = self.lastconv(m2) 283 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True) 284 | 285 | return x, m2 286 | 287 | 288 | class ConverterEncoder(nn.Module): 289 | def __init__(self): 290 | super(ConverterEncoder, self).__init__() 291 | # [1/4, 512] to [1/8, 1024] 292 | downsample1 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1, stride=2, bias=False), 293 | nn.BatchNorm2d(1024), 294 | ) 295 | self.block1 = SEBottleneck(512, 256, stride = 2, downsample = downsample1) 296 | # [1/8, 1024] to [1/16, 2048] 297 | downsample2 = nn.Sequential(nn.Conv2d(1024, 2048, kernel_size=1, stride=2, bias=False), 298 | nn.BatchNorm2d(2048), 299 | ) 300 | self.block2 = SEBottleneck(1024, 512, stride = 2, downsample=downsample2) 301 | self.conv1x1 = nn.Conv2d(2048 * 2, 2048, kernel_size=1, padding=0) # 1/16, 2048 302 | 303 | def forward(self, r5, m2): 304 | ''' 305 | 306 | :param r5: 1/16, 2048 307 | :param m2: 1/4, 512 308 | :return: 309 | ''' 310 | x = self.block1(m2) 311 | x = self.block2(x) 312 | x = torch.cat((x,r5),dim=1) 313 | x = self.conv1x1(x) 314 | 315 | return x 316 | 317 | 318 | class SEBottleneck(nn.Module): 319 | expansion = 4 320 | 321 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=nn.BatchNorm2d): 322 | super(SEBottleneck, self).__init__() 323 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 324 | self.bn1 = BatchNorm(planes) 325 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 326 | dilation=dilation, padding=dilation, bias=False) 327 | self.bn2 = BatchNorm(planes) 328 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 329 | self.bn3 = BatchNorm(planes * self.expansion) 330 | self.relu = nn.ReLU(inplace=True) 331 | # SE 332 | self.global_pool = nn.AdaptiveAvgPool2d(1) 333 | self.conv_down = nn.Conv2d( 334 | planes * 4, planes // 4, kernel_size=1, bias=False) 335 | self.conv_up = nn.Conv2d( 336 | planes // 4, planes * 4, kernel_size=1, bias=False) 337 | self.sig = nn.Sigmoid() 338 | 339 | self.downsample = downsample 340 | self.stride = stride 341 | self.dilation = dilation 342 | 343 | def forward(self, x): 344 | residual = x 345 | 346 | out = self.conv1(x) 347 | out = self.bn1(out) 348 | out = 
self.relu(out) 349 | 350 | out = self.conv2(out) 351 | out = self.bn2(out) 352 | out = self.relu(out) 353 | 354 | out = self.conv3(out) 355 | out = self.bn3(out) 356 | 357 | out1 = self.global_pool(out) 358 | out1 = self.conv_down(out1) 359 | out1 = self.relu(out1) 360 | out1 = self.conv_up(out1) 361 | out1 = self.sig(out1) 362 | 363 | if self.downsample is not None: 364 | residual = self.downsample(x) 365 | 366 | res = out1 * out + residual 367 | res = self.relu(res) 368 | 369 | return res 370 | 371 | 372 | class SELayer(nn.Module): 373 | def __init__(self, channel, reduction=16): 374 | super(SELayer, self).__init__() 375 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 376 | self.fc = nn.Sequential( 377 | nn.Linear(channel, channel // reduction, bias=False), 378 | nn.ReLU(inplace=True), 379 | nn.Linear(channel // reduction, channel, bias=False), 380 | nn.Sigmoid() 381 | ) 382 | 383 | def forward(self, x): 384 | b, c, _, _ = x.size() 385 | y = self.avg_pool(x).view(b, c) 386 | y = self.fc(y).view(b, c, 1, 1) 387 | return x * y.expand_as(x) 388 | 389 | 390 | class Refine(nn.Module): 391 | def __init__(self, inplanes, planes, scale_factor=2): 392 | super(Refine, self).__init__() 393 | self.convFS1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1) 394 | self.convFS2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 395 | self.convFS3 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 396 | self.convMM1 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 397 | self.convMM2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 398 | self.scale_factor = scale_factor 399 | 400 | def forward(self, f, pm): 401 | s = self.convFS1(f) 402 | sr = self.convFS2(F.relu(s)) 403 | sr = self.convFS3(F.relu(sr)) 404 | s = s + sr 405 | 406 | m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear',align_corners=True) 407 | mr = self.convMM1(F.relu(m)) 408 | mr = self.convMM2(F.relu(mr)) 409 | m = m + mr 410 | return m 411 | -------------------------------------------------------------------------------- /networks/correlation_package.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/correlation_package.zip -------------------------------------------------------------------------------- /networks/deeplab/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/aspp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/__pycache__/aspp.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.deeplab.sync_batchnorm.batchnorm import 
SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm, pretrained): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight(pretrained) 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self,pretrained): 24 | for m in self.modules(): 25 | if pretrained: 26 | break 27 | else: 28 | 29 | if isinstance(m, nn.Conv2d): 30 | torch.nn.init.kaiming_normal_(m.weight) 31 | elif isinstance(m, SynchronizedBatchNorm2d): 32 | m.weight.data.fill_(1) 33 | m.bias.data.zero_() 34 | elif isinstance(m, nn.BatchNorm2d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | 38 | class ASPP(nn.Module): 39 | def __init__(self, backbone, output_stride, BatchNorm, pretrained): 40 | super(ASPP, self).__init__() 41 | if backbone == 'drn': 42 | inplanes = 512 43 | elif backbone == 'mobilenet': 44 | inplanes = 320 45 | else: 46 | inplanes = 2048 47 | if output_stride == 16: 48 | dilations = [1, 6, 12, 18] 49 | elif output_stride == 8: 50 | dilations = [1, 12, 24, 36] 51 | else: 52 | raise NotImplementedError 53 | 54 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm, pretrained=pretrained) 55 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm, pretrained=pretrained) 56 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm, pretrained=pretrained) 57 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm, pretrained=pretrained) 58 | 59 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 60 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 61 | BatchNorm(256), 62 | nn.ReLU()) 63 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 64 | self.bn1 = BatchNorm(256) 65 | self.relu = nn.ReLU() 66 | self.dropout = nn.Dropout(0.5) 67 | self._init_weight(pretrained) 68 | 69 | def forward(self, x): 70 | x1 = self.aspp1(x) 71 | x2 = self.aspp2(x) 72 | x3 = self.aspp3(x) 73 | x4 = self.aspp4(x) 74 | x5 = self.global_avg_pool(x) 75 | # if type(x4.size()[2]) != int: 76 | # tmpsize = (x4.size()[2].item(),x4.size()[3].item()) 77 | # else: 78 | # tmpsize = (x4.size()[2],x4.size()[3]) 79 | # x5 = F.interpolate(x5, size=(14,14), mode='bilinear', align_corners=True) 80 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 81 | 82 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 83 | 84 | x = self.conv1(x) 85 | x = self.bn1(x) 86 | x = self.relu(x) 87 | 88 | return self.dropout(x) 89 | 90 | def _init_weight(self,pretrained): 91 | for m in self.modules(): 92 | if pretrained: 93 | break 94 | else: 95 | if isinstance(m, nn.Conv2d): 96 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 98 | torch.nn.init.kaiming_normal_(m.weight) 99 | elif isinstance(m, SynchronizedBatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | 107 | def build_aspp(backbone, output_stride, BatchNorm,pretrained): 108 | return ASPP(backbone, output_stride, BatchNorm, pretrained) -------------------------------------------------------------------------------- /networks/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.deeplab.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/drn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/backbone/__pycache__/drn.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/backbone/__pycache__/xception.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/' 7 | 8 | model_urls = { 9 | 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = 
nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 251 | dilation=2, BatchNorm=BatchNorm) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 253 | dilation=4, BatchNorm=BatchNorm) 254 | 255 | self._init_weight() 256 | 257 | def _init_weight(self): 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 262 | elif isinstance(m, SynchronizedBatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | elif isinstance(m, nn.BatchNorm2d): 266 | m.weight.data.fill_(1) 267 | m.bias.data.zero_() 268 | 269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 270 | downsample = None 271 | if stride != 1 or self.inplanes != planes * block.expansion: 272 | downsample = nn.Sequential( 273 | nn.Conv2d(self.inplanes, planes * block.expansion, 274 | kernel_size=1, stride=stride, bias=False), 275 | BatchNorm(planes * block.expansion), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 280 | self.inplanes = planes * block.expansion 281 | for i in range(1, blocks): 282 | layers.append(block(self.inplanes, planes, 283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm)) 284 | 285 | return nn.Sequential(*layers) 286 | 287 | def forward(self, x): 288 | x = self.conv1(x) 289 | x = self.bn1(x) 290 | x = self.relu(x) 291 | x = self.maxpool(x) 292 | 293 | x = self.layer1(x) 294 | x = self.layer2(x) 295 | x = self.layer3(x) 296 | x = self.layer4(x) 297 | 298 | return x 299 | 300 | def drn_a_50(BatchNorm, pretrained=True): 301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 304 | return model 305 | 306 | 307 | def drn_c_26(BatchNorm, pretrained=True): 308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm) 309 | if pretrained: 310 | pretrained = model_zoo.load_url(model_urls['drn-c-26']) 311 | del pretrained['fc.weight'] 312 | del pretrained['fc.bias'] 313 | model.load_state_dict(pretrained) 314 | return model 315 | 316 | 317 | def drn_c_42(BatchNorm, pretrained=True): 318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 319 | if pretrained: 320 | pretrained = model_zoo.load_url(model_urls['drn-c-42']) 321 | del pretrained['fc.weight'] 322 | del pretrained['fc.bias'] 323 | model.load_state_dict(pretrained) 324 | return model 325 | 326 | 327 | def drn_c_58(BatchNorm, pretrained=True): 328 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 329 | if pretrained: 330 | pretrained = model_zoo.load_url(model_urls['drn-c-58']) 331 | del pretrained['fc.weight'] 332 | del pretrained['fc.bias'] 333 | model.load_state_dict(pretrained) 334 | return model 335 | 336 | 337 | def drn_d_22(BatchNorm, pretrained=True): 338 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm) 339 | if pretrained: 340 | pretrained = model_zoo.load_url(model_urls['drn-d-22']) 341 | del pretrained['fc.weight'] 342 | del pretrained['fc.bias'] 343 | model.load_state_dict(pretrained) 344 | return model 345 | 346 | 347 | def drn_d_24(BatchNorm, pretrained=True): 348 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm) 349 | if pretrained: 350 | pretrained = model_zoo.load_url(model_urls['drn-d-24']) 351 | del pretrained['fc.weight'] 352 | del pretrained['fc.bias'] 353 | model.load_state_dict(pretrained) 354 | return model 355 | 356 | 357 | def drn_d_38(BatchNorm, pretrained=True): 358 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 359 | if pretrained: 360 | pretrained = model_zoo.load_url(model_urls['drn-d-38']) 361 | del pretrained['fc.weight'] 362 | del pretrained['fc.bias'] 363 | model.load_state_dict(pretrained) 364 | return model 
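# These constructors all follow the same pattern: build the DRN, fetch the
# matching checkpoint from model_urls, and delete the classifier ('fc')
# weights before loading. Caveat: model_urls above has no 'drn-d-24' or
# 'drn-d-40' entry, so drn_d_24 and drn_d_40 raise a KeyError when called
# with pretrained=True; construct those variants with pretrained=False, e.g.
#
#   model = drn_d_40(BatchNorm=nn.BatchNorm2d, pretrained=False)
#   x, low_level_feat = model(torch.rand(1, 3, 512, 512))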
365 | 366 | 367 | def drn_d_40(BatchNorm, pretrained=True): 368 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm) 369 | if pretrained: 370 | pretrained = model_zoo.load_url(model_urls['drn-d-40']) 371 | del pretrained['fc.weight'] 372 | del pretrained['fc.bias'] 373 | model.load_state_dict(pretrained) 374 | return model 375 | 376 | 377 | def drn_d_54(BatchNorm, pretrained=True): 378 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 379 | if pretrained: 380 | pretrained = model_zoo.load_url(model_urls['drn-d-54']) 381 | del pretrained['fc.weight'] 382 | del pretrained['fc.bias'] 383 | model.load_state_dict(pretrained) 384 | return model 385 | 386 | 387 | def drn_d_105(BatchNorm, pretrained=True): 388 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 389 | if pretrained: 390 | pretrained = model_zoo.load_url(model_urls['drn-d-105']) 391 | del pretrained['fc.weight'] 392 | del pretrained['fc.bias'] 393 | model.load_state_dict(pretrained) 394 | return model 395 | 396 | if __name__ == "__main__": 397 | import torch 398 | model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True) 399 | input = torch.rand(1, 3, 512, 512) 400 | output, low_level_feat = model(input) 401 | print(output.size()) 402 | print(low_level_feat.size()) 403 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, 
dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * self.expansion) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class SEBottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 49 | super(SEBottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = BatchNorm(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | dilation=dilation, padding=dilation, bias=False) 54 | self.bn2 = BatchNorm(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 56 | self.bn3 = BatchNorm(planes * self.expansion) 57 | self.relu = nn.ReLU(inplace=True) 58 | # SE 59 | self.global_pool = nn.AdaptiveAvgPool2d(1) 60 | self.conv_down = nn.Conv2d( 61 | planes * 4, planes // 4, kernel_size=1, bias=False) 62 | self.conv_up = nn.Conv2d( 63 | planes // 4, planes * 4, kernel_size=1, bias=False) 64 | self.sig = nn.Sigmoid() 65 | 66 | self.downsample = downsample 67 | self.stride = stride 68 | self.dilation = dilation 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | out1 = self.global_pool(out) 85 | out1 = self.conv_down(out1) 86 | out1 = self.relu(out1) 87 | out1 = self.conv_up(out1) 88 | out1 = self.sig(out1) 89 | 90 | if self.downsample is not None: 91 | residual = 
self.downsample(x) 92 | 93 | res = out1 * out + residual 94 | res = self.relu(res) 95 | 96 | return res 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, modelname = 'res101'): 102 | self.inplanes = 64 103 | self.modelname = modelname 104 | super(ResNet, self).__init__() 105 | blocks = [1, 2, 4] 106 | if output_stride == 16: 107 | strides = [1, 2, 2, 1] 108 | dilations = [1, 1, 1, 2] 109 | elif output_stride == 8: 110 | strides = [1, 2, 1, 1] 111 | dilations = [1, 1, 2, 4] 112 | else: 113 | raise NotImplementedError 114 | 115 | # Modules 116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 117 | self.bn1 = BatchNorm(64) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | 121 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 122 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 124 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 125 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 126 | self._init_weight() 127 | if pretrained: 128 | self._load_pretrained_model() 129 | 130 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 131 | downsample = None 132 | if stride != 1 or self.inplanes != planes * block.expansion: 133 | downsample = nn.Sequential( 134 | nn.Conv2d(self.inplanes, planes * block.expansion, 135 | kernel_size=1, stride=stride, bias=False), 136 | BatchNorm(planes * block.expansion), 137 | ) 138 | 139 | layers = [] 140 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 148 | downsample = None 149 | if stride != 1 or self.inplanes != planes * block.expansion: 150 | downsample = nn.Sequential( 151 | nn.Conv2d(self.inplanes, planes * block.expansion, 152 | kernel_size=1, stride=stride, bias=False), 153 | BatchNorm(planes * block.expansion), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 158 | downsample=downsample, BatchNorm=BatchNorm)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, len(blocks)): 161 | layers.append(block(self.inplanes, planes, stride=1, 162 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, input): 167 | x = self.conv1(input) 168 | x = self.bn1(x) 169 | x = self.relu(x) 170 | x = self.maxpool(x) 171 | 172 | x = self.layer1(x) #256 128 128 173 | low_level_feat = x 174 | x = self.layer2(x) #512 64 64 175 | x = self.layer3(x) #1024 32 32 176 | x = self.layer4(x) #2048 32 32 177 | return x, low_level_feat 178 | 179 | def _init_weight(self): 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d): 182 | n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels 183 | m.weight.data.normal_(0, math.sqrt(2. / n)) 184 | elif isinstance(m, SynchronizedBatchNorm2d): 185 | m.weight.data.fill_(1) 186 | m.bias.data.zero_() 187 | elif isinstance(m, nn.BatchNorm2d): 188 | m.weight.data.fill_(1) 189 | m.bias.data.zero_() 190 | 191 | def _load_pretrained_model(self): 192 | if self.modelname =='res101': 193 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 194 | elif self.modelname == 'res50': 195 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') 196 | elif self.modelname == 'SEres50': 197 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') 198 | else: raise NotImplementedError 199 | model_dict = {} 200 | state_dict = self.state_dict() 201 | for k, v in pretrain_dict.items(): 202 | if k in state_dict: 203 | model_dict[k] = v 204 | state_dict.update(model_dict) 205 | self.load_state_dict(state_dict) 206 | 207 | def ResNet101(output_stride, BatchNorm, pretrained=True,): 208 | """Constructs a ResNet-101 model. 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained, modelname='res101') 213 | return model 214 | 215 | def ResNet50(output_stride, BatchNorm, pretrained=True): 216 | """Constructs a ResNet-50 model. 217 | 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | """ 221 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained, modelname='res50') 222 | return model 223 | 224 | def SEResNet50(output_stride, BatchNorm, pretrained=True): 225 | """Constructs a ResNet-50 model. 
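Unlike ResNet50 above, this variant stacks SEBottleneck blocks; since
_load_pretrained_model copies only the parameters whose names match the
vanilla ResNet-50 checkpoint, the SE branches (conv_down/conv_up) keep
their random initialization.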
226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(SEBottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained, modelname='SEres50') 231 | return model 232 | 233 | if __name__ == "__main__": 234 | import torch 235 | model = ResNet50(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 236 | input = torch.rand(1, 3, 512, 512) 237 | output, low_level_feat = model(input) 238 | print(output.size()) 239 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /networks/deeplab/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = 
self.skipbn(skip) 86 | else: 87 | skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Alighed Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 | if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, 
159 |         self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
160 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
161 | 
162 |         # Exit flow
163 |         self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
164 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
165 | 
166 |         self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
167 |         self.bn3 = BatchNorm(1536)
168 | 
169 |         self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
170 |         self.bn4 = BatchNorm(1536)
171 | 
172 |         self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
173 |         self.bn5 = BatchNorm(2048)
174 | 
175 |         # Init weights
176 |         self._init_weight()
177 | 
178 |         # Load pretrained model
179 |         if pretrained:
180 |             self._load_pretrained_model()
181 | 
182 |     def forward(self, x):
183 |         # Entry flow
184 |         x = self.conv1(x)
185 |         x = self.bn1(x)
186 |         x = self.relu(x)
187 | 
188 |         x = self.conv2(x)
189 |         x = self.bn2(x)
190 |         x = self.relu(x)
191 | 
192 |         x = self.block1(x)
193 |         # add relu here
194 |         x = self.relu(x)
195 |         low_level_feat = x
196 |         x = self.block2(x)
197 |         x = self.block3(x)
198 | 
199 |         # Middle flow
200 |         x = self.block4(x)
201 |         x = self.block5(x)
202 |         x = self.block6(x)
203 |         x = self.block7(x)
204 |         x = self.block8(x)
205 |         x = self.block9(x)
206 |         x = self.block10(x)
207 |         x = self.block11(x)
208 |         x = self.block12(x)
209 |         x = self.block13(x)
210 |         x = self.block14(x)
211 |         x = self.block15(x)
212 |         x = self.block16(x)
213 |         x = self.block17(x)
214 |         x = self.block18(x)
215 |         x = self.block19(x)
216 | 
217 |         # Exit flow
218 |         x = self.block20(x)
219 |         x = self.relu(x)
220 |         x = self.conv3(x)
221 |         x = self.bn3(x)
222 |         x = self.relu(x)
223 | 
224 |         x = self.conv4(x)
225 |         x = self.bn4(x)
226 |         x = self.relu(x)
227 | 
228 |         x = self.conv5(x)
229 |         x = self.bn5(x)
230 |         x = self.relu(x)
231 | 
232 |         return x, low_level_feat
233 | 
234 |     def _init_weight(self):
235 |         for m in self.modules():
236 |             if isinstance(m, nn.Conv2d):
237 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
238 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
239 |             elif isinstance(m, SynchronizedBatchNorm2d):
240 |                 m.weight.data.fill_(1)
241 |                 m.bias.data.zero_()
242 |             elif isinstance(m, nn.BatchNorm2d):
243 |                 m.weight.data.fill_(1)
244 |                 m.bias.data.zero_()
245 | 
246 | 
247 |     def _load_pretrained_model(self):
248 |         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
249 |         model_dict = {}
250 |         state_dict = self.state_dict()
251 | 
252 |         for k, v in pretrain_dict.items():
253 |             if k in state_dict:  # NOTE: checking against state_dict; the original 'if k in model_dict' compared against the empty dict and silently skipped every pretrained weight
254 |                 if 'pointwise' in k:
255 |                     v = v.unsqueeze(-1).unsqueeze(-1)
256 |                 if k.startswith('block11'):
257 |                     model_dict[k] = v
258 |                     model_dict[k.replace('block11', 'block12')] = v
259 |                     model_dict[k.replace('block11', 'block13')] = v
260 |                     model_dict[k.replace('block11', 'block14')] = v
261 |                     model_dict[k.replace('block11', 'block15')] = v
262 |                     model_dict[k.replace('block11', 'block16')] = v
263 |                     model_dict[k.replace('block11', 'block17')] = v
264 |                     model_dict[k.replace('block11', 'block18')] = v
265 |                     model_dict[k.replace('block11', 'block19')] = v
266 |                 elif k.startswith('block12'):
267 |                     model_dict[k.replace('block12', 'block20')] = v
268 |                 elif k.startswith('bn3'):
269 |                     model_dict[k] = v
270 |                     model_dict[k.replace('bn3', 'bn4')] = v
271 |                 elif k.startswith('conv4'):
272 |                     model_dict[k.replace('conv4', 'conv5')] = v
273 |                 elif k.startswith('bn4'):
274 |                     model_dict[k.replace('bn4', 'bn5')] = v
275 |                 else:
276 |                     model_dict[k] = v
277 |         state_dict.update(model_dict)
278 |         self.load_state_dict(state_dict)
279 | 
280 | 
281 | 
282 | if __name__ == "__main__":
283 |     import torch
284 |     model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
285 |     input = torch.rand(1, 3, 512, 512)
286 |     output, low_level_feat = model(input)
287 |     print(output.size())
288 |     print(low_level_feat.size())
--------------------------------------------------------------------------------
/networks/deeplab/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | class Decoder(nn.Module):
8 |     def __init__(self, num_classes, backbone, BatchNorm):
9 |         super(Decoder, self).__init__()
10 |         if backbone == 'resnet' or backbone == 'drn':
11 |             low_level_inplanes = 256
12 |         elif backbone == 'xception':
13 |             low_level_inplanes = 128
14 |         elif backbone == 'mobilenet':
15 |             low_level_inplanes = 24
16 |         else:
17 |             raise NotImplementedError
18 | 
19 |         self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 |         self.bn1 = BatchNorm(48)
21 |         self.relu = nn.ReLU()
22 |         self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
23 |                                        BatchNorm(256),
24 |                                        nn.ReLU(),
25 |                                        nn.Dropout(0.5),
26 |                                        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
27 |                                        BatchNorm(256),
28 |                                        nn.ReLU(),
29 |                                        nn.Dropout(0.1),
30 |                                        nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
31 |         self._init_weight()
32 | 
33 | 
34 |     def forward(self, x, low_level_feat):
35 |         low_level_feat = self.conv1(low_level_feat)
36 |         low_level_feat = self.bn1(low_level_feat)
37 |         low_level_feat = self.relu(low_level_feat)
38 | 
39 |         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
40 |         x = torch.cat((x, low_level_feat), dim=1)
41 |         x = self.last_conv(x)
42 | 
43 |         return x
44 | 
45 |     def _init_weight(self):
46 |         for m in self.modules():
47 |             if isinstance(m, nn.Conv2d):
48 |                 torch.nn.init.kaiming_normal_(m.weight)
49 |             elif isinstance(m, SynchronizedBatchNorm2d):
50 |                 m.weight.data.fill_(1)
51 |                 m.bias.data.zero_()
52 |             elif isinstance(m, nn.BatchNorm2d):
53 |                 m.weight.data.fill_(1)
54 |                 m.bias.data.zero_()
55 | 
56 | def build_decoder(num_classes, backbone, BatchNorm):
57 |     return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
/networks/deeplab/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from networks.deeplab.aspp import build_aspp
6 | from networks.deeplab.decoder import build_decoder
7 | from networks.deeplab.backbone import build_backbone
8 | 
9 | class DeepLab(nn.Module):
10 |     def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
11 |                  sync_bn=True, freeze_bn=False):
12 |         super(DeepLab, self).__init__()
13 |         if backbone == 'drn':
14 |             output_stride = 8
15 | 
16 |         if sync_bn == True:
17 |             BatchNorm = SynchronizedBatchNorm2d
18 |         else:
19 |             BatchNorm = nn.BatchNorm2d
20 | 
21 |         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
22 |         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
23 |         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
24 | 
25 |         if freeze_bn:
26 |             self.freeze_bn()
27 | 
28 |     def forward(self, input):
29 |         x, low_level_feat = self.backbone(input)
30 |         x = self.aspp(x)
31 |         x = self.decoder(x, low_level_feat)
32 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
33 | 
34 |         return x
35 | 
36 |     def freeze_bn(self):
37 |         for m in self.modules():
38 |             if isinstance(m, SynchronizedBatchNorm2d):
39 |                 m.eval()
40 |             elif isinstance(m, nn.BatchNorm2d):
41 |                 m.eval()
42 | 
43 |     def get_1x_lr_params(self):
44 |         modules = [self.backbone]
45 |         for i in range(len(modules)):
46 |             for m in modules[i].named_modules():
47 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
48 |                         or isinstance(m[1], nn.BatchNorm2d):
49 |                     for p in m[1].parameters():
50 |                         if p.requires_grad:
51 |                             yield p
52 | 
53 |     def get_10x_lr_params(self):
54 |         modules = [self.aspp, self.decoder]
55 |         for i in range(len(modules)):
56 |             for m in modules[i].named_modules():
57 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
58 |                         or isinstance(m[1], nn.BatchNorm2d):
59 |                     for p in m[1].parameters():
60 |                         if p.requires_grad:
61 |                             yield p
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     model = DeepLab(backbone='mobilenet', output_stride=16)
66 |     model.eval()
67 |     input = torch.rand(1, 3, 513, 513)
68 |     output = model(input)
69 |     print(output.size())
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/networks/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 | 
19 | from .comm import SyncMaster
20 | 
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 | 
23 | 
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 | 
28 | 
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 | 
33 | 
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | 
37 | 
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 | 
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 | 
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 | 
48 |     def forward(self, input):
49 |         # If this is not a parallel computation or we are in evaluation mode, use PyTorch's implementation.
50 |         if not (self._is_parallel and self.training):
51 |             return F.batch_norm(
52 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
53 |                 self.training, self.momentum, self.eps)
54 | 
55 |         # Resize the input to (B, C, -1).
56 |         input_shape = input.size()
57 |         input = input.view(input.size(0), self.num_features, -1)
58 | 
59 |         # Compute the sum and square-sum.
60 |         sum_size = input.size(0) * input.size(2)
61 |         input_sum = _sum_ft(input)
62 |         input_ssum = _sum_ft(input ** 2)
63 | 
64 |         # Reduce-and-broadcast the statistics.
65 |         if self._parallel_id == 0:
66 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 |         else:
68 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 | 
70 |         # Compute the output.
71 |         if self.affine:
72 |             # MJY:: Fuse the multiplication for speed.
73 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 |         else:
75 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 | 
77 |         # Reshape it.
78 |         return output.view(input_shape)
79 | 
80 |     def __data_parallel_replicate__(self, ctx, copy_id):
81 |         self._is_parallel = True
82 |         self._parallel_id = copy_id
83 | 
84 |         # parallel_id == 0 means master device.
85 |         if self._parallel_id == 0:
86 |             ctx.sync_master = self._sync_master
87 |         else:
88 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 | 
90 |     def _data_parallel_master(self, intermediates):
91 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 | 
93 |         # Always using the same "device order" makes the ReduceAdd operation faster.
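        # (Flow: each replica contributes its per-device sum and square-sum; the master computes
        # one global mean/inv_std pair and broadcasts it back, so every copy normalizes with
        # identical statistics.)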
94 |         # Thanks to: Tete Xiao (http://tetexiao.com/)
95 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 | 
97 |         to_reduce = [i[1][:2] for i in intermediates]
98 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
99 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
100 | 
101 |         sum_size = sum([i[1].sum_size for i in intermediates])
102 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 | 
105 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 | 
107 |         outputs = []
108 |         for i, rec in enumerate(intermediates):
109 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 | 
111 |         return outputs
112 | 
113 |     def _compute_mean_std(self, sum_, ssum, size):
114 |         """Compute the mean and standard-deviation with sum and square-sum. This method
115 |         also maintains the moving average on the master device."""
116 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 |         mean = sum_ / size
118 |         sumvar = ssum - sum_ * mean
119 |         unbias_var = sumvar / (size - 1)
120 |         bias_var = sumvar / size
121 | 
122 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 | 
125 |         return mean, bias_var.clamp(self.eps) ** -0.5
126 | 
127 | 
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 |     mini-batch.
131 |     .. math::
132 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 |     standard-deviation are reduced across all devices during training.
135 |     For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 |     is also easy to implement, but the statistics might be inaccurate.
139 |     Instead, in this synchronized version, the statistics will be computed
140 |     over all training samples distributed on multiple devices.
141 | 
142 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 |     as the built-in PyTorch implementation.
144 |     The mean and standard-deviation are calculated per-dimension over
145 |     the mini-batches and gamma and beta are learnable parameter vectors
146 |     of size C (where C is the input size).
147 |     During training, this layer keeps a running estimate of its computed mean
148 |     and variance. The running sum is kept with a default momentum of 0.1.
149 |     During evaluation, this running mean/variance is used for normalization.
150 |     Because the BatchNorm is done over the `C` dimension, computing statistics
151 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
152 |     Args:
153 |         num_features: num_features from an expected input of size
154 |             `batch_size x num_features [x width]`
155 |         eps: a value added to the denominator for numerical stability.
156 |             Default: 1e-5
157 |         momentum: the value used for the running_mean and running_var
158 |             computation. Default: 0.1
159 |         affine: a boolean value that when set to ``True``, gives the layer learnable
160 |             affine parameters. Default: ``True``
161 |     Shape:
162 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 |     Examples:
165 |         >>> # With Learnable Parameters
166 |         >>> m = SynchronizedBatchNorm1d(100)
167 |         >>> # Without Learnable Parameters
168 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 |         >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 |         >>> output = m(input)
171 |     """
172 | 
173 |     def _check_input_dim(self, input):
174 |         if input.dim() != 2 and input.dim() != 3:
175 |             raise ValueError('expected 2D or 3D input (got {}D input)'
176 |                              .format(input.dim()))
177 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 | 
179 | 
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 |     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 |     of 3d inputs
183 |     .. math::
184 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 |     This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 |     standard-deviation are reduced across all devices during training.
187 |     For example, when one uses `nn.DataParallel` to wrap the network during
188 |     training, PyTorch's implementation normalizes the tensor on each device using
189 |     the statistics only on that device, which accelerates the computation and
190 |     is also easy to implement, but the statistics might be inaccurate.
191 |     Instead, in this synchronized version, the statistics will be computed
192 |     over all training samples distributed on multiple devices.
193 | 
194 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 |     as the built-in PyTorch implementation.
196 |     The mean and standard-deviation are calculated per-dimension over
197 |     the mini-batches and gamma and beta are learnable parameter vectors
198 |     of size C (where C is the input size).
199 |     During training, this layer keeps a running estimate of its computed mean
200 |     and variance. The running sum is kept with a default momentum of 0.1.
201 |     During evaluation, this running mean/variance is used for normalization.
202 |     Because the BatchNorm is done over the `C` dimension, computing statistics
203 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
204 |     Args:
205 |         num_features: num_features from an expected input of
206 |             size batch_size x num_features x height x width
207 |         eps: a value added to the denominator for numerical stability.
208 |             Default: 1e-5
209 |         momentum: the value used for the running_mean and running_var
210 |             computation. Default: 0.1
211 |         affine: a boolean value that when set to ``True``, gives the layer learnable
212 |             affine parameters. Default: ``True``
213 |     Shape:
214 |         - Input: :math:`(N, C, H, W)`
215 |         - Output: :math:`(N, C, H, W)` (same shape as input)
216 |     Examples:
217 |         >>> # With Learnable Parameters
218 |         >>> m = SynchronizedBatchNorm2d(100)
219 |         >>> # Without Learnable Parameters
220 |         >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 |         >>> output = m(input)
223 |     """
224 | 
225 |     def _check_input_dim(self, input):
226 |         if input.dim() != 4:
227 |             raise ValueError('expected 4D input (got {}D input)'
228 |                              .format(input.dim()))
229 |         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 | 
231 | 
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 |     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 |     of 4d inputs
235 |     .. math::
236 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 |     This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 |     standard-deviation are reduced across all devices during training.
239 |     For example, when one uses `nn.DataParallel` to wrap the network during
240 |     training, PyTorch's implementation normalizes the tensor on each device using
241 |     the statistics only on that device, which accelerates the computation and
242 |     is also easy to implement, but the statistics might be inaccurate.
243 |     Instead, in this synchronized version, the statistics will be computed
244 |     over all training samples distributed on multiple devices.
245 | 
246 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 |     as the built-in PyTorch implementation.
248 |     The mean and standard-deviation are calculated per-dimension over
249 |     the mini-batches and gamma and beta are learnable parameter vectors
250 |     of size C (where C is the input size).
251 |     During training, this layer keeps a running estimate of its computed mean
252 |     and variance. The running sum is kept with a default momentum of 0.1.
253 |     During evaluation, this running mean/variance is used for normalization.
254 |     Because the BatchNorm is done over the `C` dimension, computing statistics
255 |     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 |     or Spatio-temporal BatchNorm.
257 |     Args:
258 |         num_features: num_features from an expected input of
259 |             size batch_size x num_features x depth x height x width
260 |         eps: a value added to the denominator for numerical stability.
261 |             Default: 1e-5
262 |         momentum: the value used for the running_mean and running_var
263 |             computation. Default: 0.1
264 |         affine: a boolean value that when set to ``True``, gives the layer learnable
265 |             affine parameters. Default: ``True``
266 |     Shape:
267 |         - Input: :math:`(N, C, D, H, W)`
268 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 |     Examples:
270 |         >>> # With Learnable Parameters
271 |         >>> m = SynchronizedBatchNorm3d(100)
272 |         >>> # Without Learnable Parameters
273 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 |         >>> output = m(input)
276 |     """
277 | 
278 |     def _check_input_dim(self, input):
279 |         if input.dim() != 5:
280 |             raise ValueError('expected 5D input (got {}D input)'
281 |                              .format(input.dim()))
282 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected,
61 |     and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
63 |     back to each slave device.
64 |     """
65 | 
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 | 
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 | 
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 | 
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
96 | 
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback will be invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 | 
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 | 
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 | 
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 | 
125 |         return results[0][1]
126 | 
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 | 
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 |     Note that, as all modules are isomorphic, we assign each sub-module a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 |     of any slave copies.
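    (For example, `_SynchronizedBatchNorm` relies on this ordering: the master copy installs its
    `SyncMaster` into the shared context before any slave copy calls `register_slave`.)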
36 |     """
37 |     master_copy = modules[0]
38 |     nr_modules = len(list(master_copy.modules()))
39 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
40 | 
41 |     for i, module in enumerate(modules):
42 |         for j, m in enumerate(module.modules()):
43 |             if hasattr(m, '__data_parallel_replicate__'):
44 |                 m.__data_parallel_replicate__(ctxs[j], i)
45 | 
46 | 
47 | class DataParallelWithCallback(DataParallel):
48 |     """
49 |     Data Parallel with a replication callback.
50 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the copies are created by
51 |     the original `replicate` function.
52 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 |     Examples:
54 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 |         # sync_bn.__data_parallel_replicate__ will be invoked.
57 |     """
58 | 
59 |     def replicate(self, module, device_ids):
60 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 |         execute_replication_callbacks(modules)
62 |         return modules
63 | 
64 | 
65 | def patch_replication_callback(data_parallel):
66 |     """
67 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 |     Useful when you have a customized `DataParallel` implementation.
69 |     Examples:
70 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 |         > patch_replication_callback(sync_bn)
73 |         # this is equivalent to
74 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 |     """
77 | 
78 |     assert isinstance(data_parallel, DataParallel)
79 | 
80 |     old_replicate = data_parallel.replicate
81 | 
82 |     @functools.wraps(old_replicate)
83 |     def new_replicate(module, device_ids):
84 |         modules = old_replicate(module, device_ids)
85 |         execute_replication_callbacks(modules)
86 |         return modules
87 | 
88 |     data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/networks/deeplab/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import unittest
12 | 
13 | import numpy as np
14 | from torch.autograd import Variable
15 | 
16 | 
17 | def as_numpy(v):
18 |     if isinstance(v, Variable):
19 |         v = v.data
20 |     return v.cpu().numpy()
21 | 
22 | 
23 | class TorchTestCase(unittest.TestCase):
24 |     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 |         npa, npb = as_numpy(a), as_numpy(b)
26 |         self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol),
28 |             'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 |         )
30 | 
--------------------------------------------------------------------------------
/networks/ltm_transfer.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | class LTM_transfer(nn.Module):
9 |     def __init__(self, md=4, stride=1):
10 |         super(LTM_transfer, self).__init__()
11 |         self.md = md  # displacement (default = 4 pixels)
12 |         self.range = (md*2 + 1) ** 2  # (default = (4*2+1)**2 = 81)
13 |         self.grid = None
14 |         self.Channelwise_sum = None
15 | 
16 |         d_u = torch.linspace(-self.md * stride, self.md * stride, 2 * self.md + 1).view(1, -1).repeat((2 * self.md + 1, 1)).view(self.range, 1)  # (self.range, 1); (81,1) with the default md=4
17 |         d_v = torch.linspace(-self.md * stride, self.md * stride, 2 * self.md + 1).view(-1, 1).repeat((1, 2 * self.md + 1)).view(self.range, 1)  # (self.range, 1); (81,1) with the default md=4
18 |         self.d = torch.cat((d_u, d_v), dim=1).cuda()  # (self.range, 2); (81,2) with the default md=4
19 | 
20 |     def L2normalize(self, x, d=1):
21 |         eps = 1e-6
22 |         norm = x ** 2
23 |         norm = norm.sum(dim=d, keepdim=True) + eps
24 |         norm = norm ** (0.5)
25 |         return (x/norm)
26 | 
27 |     def UniformGrid(self, Input):
28 |         '''
29 |         Make a uniform grid
30 |         :param Input: tensor (N,C,H,W)
31 |         :return grid: (1,2,H,W)
32 |         '''
33 |         # torchHorizontal = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(N, 1, H, W)
34 |         # torchVertical = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(N, 1, H, W)
35 |         # grid = torch.cat([torchHorizontal, torchVertical], 1).cuda()
36 | 
37 |         _, _, H, W = Input.size()
38 |         # mesh grid
39 |         xx = torch.arange(0, W).view(1, 1, 1, W).expand(1, 1, H, W)
40 |         yy = torch.arange(0, H).view(1, 1, H, 1).expand(1, 1, H, W)
41 | 
42 |         grid = torch.cat((xx, yy), 1).float()
43 | 
44 |         if Input.is_cuda:
45 |             grid = grid.cuda()
46 | 
47 |         return grid
48 | 
49 |     def warp(self, x, BM_d):
50 |         vgrid = self.grid + BM_d  # [N2HW] # [(2d+1)^2, 2, H, W]
51 |         # scale grid to [-1,1]
52 |         vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(x.size(3) - 1, 1) - 1.0
53 |         vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(x.size(2) - 1, 1) - 1.0
54 | 
55 |         vgrid = vgrid.permute(0, 2, 3, 1)
56 |         output = nn.functional.grid_sample(x, vgrid, mode='bilinear', padding_mode='border')  # 800MB memory occupied (d=2,C=64,H=256,W=256)
57 |         mask = torch.autograd.Variable(torch.ones(x.size())).cuda()
58 |         mask = nn.functional.grid_sample(mask, vgrid)  # 300MB memory occupied (d=2,C=64,H=256,W=256)
59 | 
60 |         mask = mask.masked_fill_(mask<0.999,0)
61 |         mask = mask.masked_fill_(mask>0,1)
62 | 
63 |         return output * mask
64 | 
65 |     def forward(self, sim_feature, f_map, apply_softmax_on_simfeature=True):
66 |         '''
67 |         Transfer the previous frame mask using the local similarity (correlation) weights
68 |         :param sim_feature: correlation feature based on the operating frame's HW (N, (2d+1)^2, H, W)
69 |         :param f_map: previous frame mask (N, 1, H, W)
70 |         :return: transferred mask (N, 1, H, W)
71 |         '''
72 |         # feature1 = self.L2normalize(feature1)
73 |         # feature2 = self.L2normalize(feature2)
74 | 
75 |         B_size, C_size, H_size, W_size = f_map.size()
76 | 
77 |         if self.grid is None:
78 |             # Initialize the uniform grid on first use
79 |             self.grid = self.UniformGrid(f_map)
80 | 
81 |         if H_size != self.grid.size(2) or W_size != self.grid.size(3):
82 |             # Update the uniform grid to fit the input tensor shape
83 |             self.grid = self.UniformGrid(f_map)
84 | 
85 | 
86 |         # Displacement volume ((2d+1)^2, 2, H, W); d = (i, j), i in [-md, md] & j in [-md, md]
87 |         D_vol = self.d.view(self.range, 2, 1, 1).expand(-1, -1, H_size, W_size)  # [(2d+1)^2, 2, H, W]
88 | 
89 |         if apply_softmax_on_simfeature:
90 |             sim_feature = F.softmax(sim_feature, dim=1)  # B,D^2,H,W
91 |         f_map = self.warp(f_map.transpose(0, 1).expand(self.range,-1,-1,-1), D_vol).transpose(0, 1)  # B,D^2,H,W
92 | 
93 |         f_map = torch.sum(torch.mul(sim_feature, f_map), dim=1, keepdim=True)  # B,1,H,W
94 | 
95 |         return f_map  # B,1,H,W
96 | 
--------------------------------------------------------------------------------
/results/test_result_davisframework/IVOS-ATNet_JF_example/summary.json:
--------------------------------------------------------------------------------
1 | {"auc": 0.8150432614019845, "metric_at_threshold": {"threshold": 60, "J_AND_F": 0.8271279419489322}, "curve": {"time": [0.0, 9.487896331151326, 16.278953917821248, 21.859388258722092, 26.12112567424774, 29.627880615658228, 32.95456397533417, 35.73726238409678, 38.56137614780002, 488.0], "J_AND_F": [0.0, 0.690892643367614, 0.7512954328742666, 0.7875513696076725, 0.8018014877265979, 0.809771518288365, 0.8169255790804818, 0.822880847995248, 0.8271279419489322, 0.8271279419489322]}}
--------------------------------------------------------------------------------
/results/test_result_davisframework/IVOS-ATNet_JF_example/summary_graph_0.827.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/results/test_result_davisframework/IVOS-ATNet_JF_example/summary_graph_0.827.png
--------------------------------------------------------------------------------
/results/test_result_davisframework/IVOS-ATNet_J_example/summary.json:
--------------------------------------------------------------------------------
1 | {"auc": 0.7781946615240677, "metric_at_threshold": {"threshold": 60, "J": 0.7895668677930413}, "curve": {"time": [0.0, 9.543367825614082, 16.36207738187578, 21.817193643252054, 26.396136294470892, 29.93320824040307, 33.12130610413021, 35.9824918879403, 38.66884519788954, 488.0], "J": [0.0, 0.6634230419673162, 0.7219753407865033, 0.7551480870563894, 0.7675534925001973, 0.7759233108062968, 0.78285585613275, 0.787008686963769, 0.7895668677930413, 0.7895668677930413]}}
--------------------------------------------------------------------------------
/results/test_result_davisframework/IVOS-ATNet_J_example/summary_graph_0.790.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuk6heo/IVOS-ATNet/1cf574953a96bd680c518c6362b510fd103ff271/results/test_result_davisframework/IVOS-ATNet_J_example/summary_graph_0.790.png
--------------------------------------------------------------------------------