├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   └── vos.py
├── flow_inference
│   ├── __init__.py
│   ├── flow_inference.py
│   ├── models.py
│   ├── models.pyc
│   └── networks_flow
│       ├── FlowNetC.py
│       ├── FlowNetC.pyc
│       ├── FlowNetFusion.py
│       ├── FlowNetFusion.pyc
│       ├── FlowNetS.py
│       ├── FlowNetS.pyc
│       ├── FlowNetSD.py
│       ├── FlowNetSD.pyc
│       ├── __init__.py
│       ├── __init__.pyc
│       ├── __pycache__
│       │   ├── FlowNetC.cpython-36.pyc
│       │   ├── FlowNetFusion.cpython-36.pyc
│       │   ├── FlowNetS.cpython-36.pyc
│       │   ├── FlowNetSD.cpython-36.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   └── submodules.cpython-36.pyc
│       ├── channelnorm_package
│       │   ├── __init__.py
│       │   ├── __init__.pyc
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── _ext
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   └── __init__.cpython-36.pyc
│       │   │   └── channelnorm
│       │   │       ├── __init__.py
│       │   │       ├── __pycache__
│       │   │       │   └── __init__.cpython-36.pyc
│       │   │       └── _channelnorm.so
│       │   ├── build.py
│       │   ├── functions
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── channelnorm.cpython-36.pyc
│       │   │   ├── channelnorm.py
│       │   │   └── channelnorm.pyc
│       │   ├── make.sh
│       │   ├── modules
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── channelnorm.cpython-36.pyc
│       │   │   ├── channelnorm.py
│       │   │   └── channelnorm.pyc
│       │   └── src
│       │       ├── ChannelNorm_cuda.c
│       │       ├── ChannelNorm_cuda.h
│       │       ├── ChannelNorm_kernel.cu
│       │       ├── ChannelNorm_kernel.h
│       │       └── ChannelNorm_kernel.o
│       ├── correlation_package
│       │   ├── __init__.py
│       │   ├── __init__.pyc
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── _ext
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   └── __init__.cpython-36.pyc
│       │   │   └── correlation
│       │   │       ├── __init__.py
│       │   │       ├── __pycache__
│       │   │       │   └── __init__.cpython-36.pyc
│       │   │       └── _correlation.so
│       │   ├── build.py
│       │   ├── functions
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── correlation.cpython-36.pyc
│       │   │   ├── correlation.py
│       │   │   └── correlation.pyc
│       │   ├── make.sh
│       │   ├── modules
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── correlation.cpython-36.pyc
│       │   │   ├── correlation.py
│       │   │   └── correlation.pyc
│       │   └── src
│       │       ├── correlation.c
│       │       ├── correlation.h
│       │       ├── correlation_cuda.c
│       │       ├── correlation_cuda.h
│       │       ├── correlation_cuda_kernel.cu
│       │       ├── correlation_cuda_kernel.h
│       │       └── correlation_cuda_kernel.o
│       ├── resample2d_package
│       │   ├── __init__.py
│       │   ├── __init__.pyc
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── _ext
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   └── __init__.cpython-36.pyc
│       │   │   └── resample2d
│       │   │       ├── __init__.py
│       │   │       ├── __pycache__
│       │   │       │   └── __init__.cpython-36.pyc
│       │   │       └── _resample2d.so
│       │   ├── build.py
│       │   ├── functions
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── resample2d.cpython-36.pyc
│       │   │   ├── resample2d.py
│       │   │   └── resample2d.pyc
│       │   ├── make.sh
│       │   ├── modules
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   └── resample2d.cpython-36.pyc
│       │   │   ├── resample2d.py
│       │   │   └── resample2d.pyc
│       │   └── src
│       │       ├── Resample2d_cuda.c
│       │       ├── Resample2d_cuda.h
│       │       ├── Resample2d_kernel.cu
│       │       ├── Resample2d_kernel.h
│       │       └── Resample2d_kernel.o
│       ├── submodules.py
│       └── submodules.pyc
├── ft_davis.sh
├── infer_davis.py
├── infer_ytv.py
├── networks
│   ├── __init__.py
│   └── agssvos.py
├── run_davis.sh
├── run_ytv.sh
├── test_davis.sh
├── test_davis_ft.sh
├── tools
│   ├── __init__.py
│   ├── preprocess.py
│   ├── utils.py
│   └── visualize.py
├── train_davis.py
├── train_ytv.py
├── val_davis.sh
├── val_davis_ft.sh
└── val_ytv.sh

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 DV Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AGSS-VOS: Attention Guided Single-Shot Video Object Segmentation ([ICCV 2019](https://openaccess.thecvf.com/content_ICCV_2019/papers/Lin_AGSS-VOS_Attention_Guided_Single-Shot_Video_Object_Segmentation_ICCV_2019_paper.pdf))

## Prerequisites

- Python 3.6
- NVIDIA GPU
- Ubuntu
- PyTorch 0.4.0

## Model training and evaluation

### Data Preparation
1. Download [YouTubeVOS](https://drive.google.com/open?id=1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f).
2. Download [DAVIS-2017](https://davischallenge.org/davis2017/code.html).
3. We provide the split annotations for all datasets, and the meta.json for DAVIS-2017, [here](https://drive.google.com/drive/folders/1sSYjIbuPieL3XfM4lFR0THEjGZyGM2qq?usp=sharing).
4. Symlink the corresponding train/validation/test datasets and json files into the `data` folder:
```
data
├── youtube_vos
│   ├── train
│   │   ├── JPEGImages
│   │   ├── Split_Annotations
│   │   ├── meta.json
│   ├── valid
│   │   ├── JPEGImages
│   │   ├── Split_Annotations
│   │   ├── meta.json
│   ├── valid_all_frames
│   │   ├── JPEGImages
├── davis2017
│   ├── trainval
│   │   ├── JPEGImages
│   │   ├── Split_Annotations
│   │   ├── train_meta.json
│   │   ├── val_meta.json
│   ├── test
│   │   ├── JPEGImages
│   │   ├── Split_Annotations
│   │   ├── test_meta.json
```

### Model Preparation
1. Download [RGMP](https://www.dropbox.com/s/gt0kivrb2hlavi2/weights.pth?dl=0) and place it (weights.pth) in the 'checkpoints' folder.
2. Download [the pretrained models](https://drive.google.com/drive/folders/1sSYjIbuPieL3XfM4lFR0THEjGZyGM2qq?usp=sharing) and place them in the 'checkpoints' folder:
```
checkpoints
├── weights.pth
├── train_ytv
│   ├── model_4.pth
├── train_davis
│   ├── model_199.pth
├── ft_davis
│   ├── model_99.pth
```
3. 
Download [FlowNet2C](https://drive.google.com/file/d/1BFT6b7KgKJC8rA59RmOVAXRM_S7aSfKE/view?usp=sharing) and place it at 'flow_inference/models/FlowNet2-C_checkpoint.pth.tar'. Then run
```
sh make.sh
```
in each of the 'channelnorm_package', 'correlation_package' and 'resample2d_package' folders under 'flow_inference/networks_flow/'. Make sure the PyTorch version is 0.4.0. (A minimal usage sketch of the flow wrapper is given at the end of this README.)

### Training
1. To train on the YouTube-VOS training set:
```
sh run_ytv.sh
```
2. To train on the DAVIS-2017 training set:
```
sh run_davis.sh
```
The checkpoints will be saved in the 'Outputs' folder.

### Finetuning
1. To finetune on the DAVIS-2017 training set with the model pretrained on the YouTube-VOS training set:
```
sh ft_davis.sh
```

### Validation
1. To run inference on the YouTube-VOS validation set:
```
sh val_ytv.sh
```
2. To run inference on the DAVIS-2017 validation set:
```
sh val_davis.sh
```
3. To run inference on the DAVIS-2017 test-dev set:
```
sh test_davis.sh
```
4. To run inference on the DAVIS-2017 validation/test-dev set with the finetuned model:
```
sh val_davis_ft.sh
```
```
sh test_davis_ft.sh
```

The results will be saved in the 'val_dir' or 'test_dir' folder.
You can change '--restore' in the scripts to validate your own training results.
You can also add '--show-img' to the scripts to save visualized image results.


#### This software is for Non-commercial Research Purposes only.

### Contact
If you have any questions, please feel free to contact the authors.

Huaijia Lin

### Citation

If you use our code, please consider citing our paper:

```
@inproceedings{lin2019agss,
  title={AGSS-VOS: Attention Guided Single-Shot Video Object Segmentation},
  author={Lin, Huaijia and Qi, Xiaojuan and Jia, Jiaya},
  booktitle={ICCV},
  year={2019}
}
```

## Acknowledgments
Parts of this code were derived from [RGMP](https://github.com/xanderchf/RGMP) and [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch).
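### Flow wrapper usage (sketch)
For reference, a minimal sketch of how the flow wrapper defined in `flow_inference/flow_inference.py` can be called once the packages above are compiled. The frame paths are placeholders, and passing an empty `argparse.Namespace` is an assumption (the wrapper fills in `rgb_max`, `fp16` and `grads` itself); treat this as an illustration rather than a supported entry point:
```python
import argparse
import sys

import cv2

sys.path.append('flow_inference')  # so `models` and `networks_flow` resolve as top-level imports
from flow_inference.flow_inference import Inference_flow

# Builds FlowNet2C on the GPU and loads the checkpoint's 'state_dict'.
flow_net = Inference_flow(argparse.Namespace(), model_name="FlowNet2C",
                          restore_path="flow_inference/models/FlowNet2-C_checkpoint.pth.tar")

img1 = cv2.imread('frame_00000.jpg', cv2.IMREAD_COLOR)  # H x W x 3, uint8
img2 = cv2.imread('frame_00001.jpg', cv2.IMREAD_COLOR)
flow = flow_net.infer(img1, img2)  # CUDA FloatTensor, H x W x 2 (flow order: img2 -> img1)
```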
126 | 127 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/vos.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import logging 4 | import math 5 | import numpy as np 6 | import numpy.random as npr 7 | import os 8 | import os.path as osp 9 | import random 10 | from random import shuffle 11 | import sys 12 | from torch.utils.data import Dataset 13 | 14 | sys.path.append('./tools') 15 | import preprocess 16 | 17 | class Trainset(Dataset): 18 | def __init__(self, root_data, json_meta_list, 19 | mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225), 20 | sample_size=5, test_mode=False, random_flip=True, 21 | fix_size=False, spec_vid=None, spec_range=None, spec_obj_ind=None, 22 | abandon_len1=False, min_obj_num=1, step=5, half_size=True, 23 | random_ref=True, random_skip=False): 24 | self.mean = mean 25 | self.std = std 26 | self.files = [] 27 | self.files_obj_comb = [] 28 | self.sample_size = sample_size 29 | self.test_mode = test_mode 30 | self.random_flip = random_flip 31 | self.fix_size = fix_size 32 | self.spec_range = spec_range 33 | self.half_size = half_size 34 | self.random_ref = random_ref 35 | self.random_skip = random_skip 36 | 37 | if spec_vid is None: 38 | assert spec_range is None 39 | assert spec_obj_ind is None 40 | 41 | meta_list = json.loads(open(json_meta_list).read())['videos'] 42 | 43 | for name, vid in meta_list.items(): 44 | if spec_vid is not None and name != spec_vid: 45 | continue 46 | min_idx = 1000 47 | max_idx = -1 48 | for k, obj_ind in vid['objects'].items(): 49 | min_idx = min(min_idx, int(obj_ind['frames'][0])) 50 | max_idx = max(max_idx, int(obj_ind['frames'][-1])) 51 | 52 | vid_seq = [] 53 | obj_num = len(vid['objects']) 54 | idx_map = dict() 55 | for idx in range(min_idx, max_idx+1, step): 56 | img_file = osp.join(root_data, 'JPEGImages/%s/%05d.jpg' % (name, idx)) 57 | idx_map['%05d' % idx] = len(vid_seq) 58 | vid_seq.append({ 59 | 'img': img_file, 60 | 'lab': [None for _ in range(obj_num)], 61 | 'attr': [0 for _ in range(obj_num)] 62 | }) 63 | obj_cnt = 0 64 | for k, obj_ind in vid['objects'].items(): 65 | if spec_obj_ind is not None and int(k) not in spec_obj_ind: 66 | continue 67 | for frame in obj_ind['frames']: 68 | id = idx_map[frame] 69 | lab_file = osp.join(root_data, 'Split_Annotations/%s/%s/%s.png' % (name, k, frame)) 70 | vid_seq[id]['lab'][obj_cnt] = lab_file 71 | vid_seq[id]['attr'][obj_cnt] = 1 #if cv2.imread(lab_file, cv2.IMREAD_GRAYSCALE).sum()>0 else 0 72 | obj_cnt += 1 73 | self.files.append(vid_seq) 74 | obj_attr = dict() 75 | for vid_f in vid_seq: 76 | if sum(vid_f['attr']) >= min_obj_num: 77 | tmp_vid = tuple(vid_f['attr']) 78 | if tmp_vid not in obj_attr: 79 | obj_attr[tmp_vid] = 0 80 | obj_attr[tmp_vid] += 1 81 | for k,v in obj_attr.items(): 82 | if abandon_len1 and v == 1: 83 | continue 84 | self.files_obj_comb.append((len(self.files)-1,k)) 85 | self.files_obj_comb = self.files_obj_comb 86 | random.seed(20170624) 87 | npr.seed(20170624) 88 | 89 | def __len__(self): 90 | return len(self.files_obj_comb) 91 | 92 | def __getitem__(self, item): 93 | vid_obj_comb = self.files_obj_comb[item] # (vid_name, obj_ind) 94 | video_list = 
self.files[vid_obj_comb[0]].copy() 95 | if self.spec_range is not None: 96 | video_list = video_list[self.spec_range[0]:self.spec_range[1]] 97 | ### flipping the whole video is also an augmentation ### 98 | if not self.test_mode and random.uniform(0,1) > 0.5: 99 | video_list = video_list[::-1] 100 | assert len(video_list) > 0 101 | if len(video_list) >= 2: 102 | ### make sure all the objects are involved ### 103 | vid_attr = vid_obj_comb[1] 104 | cand_start_list = [i for i in range(len(video_list)) if tuple(video_list[i]['attr']) == vid_attr] 105 | video_select = [] 106 | ### randomly skip some frames ### 107 | if self.random_skip: 108 | skip = random.randint(1,5) 109 | else: 110 | skip = 1 111 | if len(cand_start_list) > self.sample_size*skip: 112 | start = random.randint(0, len(cand_start_list)-2) 113 | for k in cand_start_list[start:start+self.sample_size*skip][::skip]: 114 | video_select.append(video_list[k]) 115 | else: 116 | for k in cand_start_list: 117 | video_select.append(video_list[k]) 118 | ### randomly select the reference frame ### 119 | if self.random_ref: 120 | k = random.randint(0, len(cand_start_list)-1) 121 | k = cand_start_list[k] 122 | video_select = video_list[k:k+1] + video_select 123 | else: 124 | video_select = video_list 125 | 126 | 127 | img_set = [] 128 | lab_set = [] 129 | ori_img_set = [] 130 | 131 | flip_flag = random.uniform(0,1) > 0.5 and self.random_flip 132 | attr = video_select[0]['attr'] 133 | 134 | for idx,datafiles in enumerate(video_select): 135 | image = cv2.imread(datafiles['img'], cv2.IMREAD_COLOR) 136 | labels = [] 137 | for i,lab_name in enumerate(datafiles['lab']): 138 | if attr[i] == 1: 139 | if lab_name is not None: 140 | lab = cv2.imread(lab_name, cv2.IMREAD_GRAYSCALE) 141 | labels.append(lab) 142 | else: 143 | labels.append(np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)) 144 | 145 | if self.fix_size: 146 | dsize = (640,320) 147 | else: 148 | dsize = preprocess.get_dsize(image, half_size=self.half_size, scale=64) 149 | 150 | image = cv2.resize(image, dsize, interpolation=cv2.INTER_LINEAR) 151 | for i in range(len(labels)): 152 | if labels[i] is not None: 153 | labels[i] = cv2.resize(labels[i], dsize, interpolation=cv2.INTER_NEAREST) 154 | 155 | ori_img = image.copy() 156 | 157 | image = preprocess.norm(image, self.mean, self.std) 158 | 159 | ### flip each frame ### 160 | if flip_flag: 161 | image = image[:,::-1] 162 | for i in range(len(labels)): 163 | if labels[i] is not None: 164 | labels[i] = labels[i][:,::-1] 165 | ori_img = ori_img[:,::-1] 166 | 167 | image = image.transpose((2, 0, 1)) 168 | image_cat = [] 169 | label_cat = [] 170 | for i in range(len(labels)): 171 | image_cat.append(image[np.newaxis,:]) 172 | label_cat.append(labels[i][np.newaxis,:]) 173 | 174 | image_cat = np.concatenate(image_cat, axis=0) 175 | label_cat = np.concatenate(label_cat, axis=0) 176 | img_set.append(image_cat[:,np.newaxis,:]) 177 | lab_set.append(label_cat[:,np.newaxis,:]) 178 | ori_img_set.append(ori_img[np.newaxis,:]) 179 | 180 | img_set = np.concatenate((img_set), axis=1) 181 | lab_set = np.concatenate((lab_set), axis=1) 182 | ori_img_set = np.concatenate((ori_img_set), axis=0) 183 | 184 | return img_set, lab_set, ori_img_set 185 | 186 | 187 | class Valset(Dataset): 188 | def __init__(self, root_data, root_all_data, json_meta_list, 189 | mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225), 190 | sample_size=5, all_frames=True, test_mode=False, fix_size=True, 191 | spec_vid=None, spec_obj_ind=None, half_size=True): 192 | self.mean = mean 193 | self.std = std 194 | 
self.files = [] 195 | if all_frames: 196 | self.sample_size = (sample_size-1)*5+1 197 | else: 198 | self.sample_size = sample_size 199 | self.test_mode = test_mode 200 | self.fix_size = fix_size 201 | self.half_size = half_size 202 | 203 | meta_list = json.loads(open(json_meta_list).read())['videos'] 204 | if spec_obj_ind is not None: 205 | assert spec_vid is not None 206 | 207 | for name, vid in meta_list.items(): 208 | if spec_vid is not None and name not in spec_vid: 209 | continue 210 | min_idx = 1000 211 | max_idx = -1 212 | for k, obj_ind in vid['objects'].items(): 213 | min_idx = min(min_idx, int(obj_ind['frames'][0])) 214 | max_idx = max(max_idx, int(obj_ind['frames'][-1])) 215 | 216 | vid_info = dict() 217 | vid_seq = [] 218 | idx_map = dict() 219 | step = 1 if all_frames or test_mode else 5 220 | for idx in range(min_idx, max_idx+1, step): 221 | img_file = osp.join(root_all_data, 'JPEGImages/%s/%05d.jpg' % (name, idx)) 222 | idx_map['%05d' % idx] = len(vid_seq) 223 | vid_seq.append(img_file) 224 | 225 | vid_info['imgs'] = vid_seq 226 | vid_info['name'] = name 227 | vid_info['min_idx'] = min_idx 228 | vid_lab = [] 229 | obj_num = 0 230 | for k, obj_ind in vid['objects'].items(): 231 | if spec_obj_ind is not None: 232 | if int(k) not in spec_obj_ind: 233 | continue 234 | obj_num = max(obj_num, int(k)) 235 | frame = obj_ind['frames'][0] 236 | lab_file = osp.join(root_data, 'Split_Annotations/%s/%s/%s.png' % (name, k, frame)) 237 | obj_lab = dict() 238 | obj_lab['obj_ind'] = k 239 | obj_lab['lab_file'] = lab_file 240 | obj_lab['start_idx'] = int(frame)-min_idx 241 | vid_lab.append(obj_lab) 242 | vid_info['obj_num'] = obj_num 243 | vid_info['labs'] = vid_lab 244 | 245 | self.files.append(vid_info) 246 | random.seed(20170624) 247 | 248 | def __len__(self): 249 | return len(self.files) 250 | 251 | def __getitem__(self, item): 252 | vid_info = self.files[item].copy() 253 | 254 | img_set = [] 255 | lab_set = [] 256 | ori_img_set = [] 257 | ori_shape = None 258 | dsize = None 259 | 260 | ### gen img_set & ori_img_set ### 261 | for img_name in vid_info['imgs']: 262 | image = cv2.imread(img_name, cv2.IMREAD_COLOR) 263 | if ori_shape is None: 264 | ori_shape = image.shape[:2] 265 | 266 | if self.fix_size: 267 | # dsize = (640,320) ## for ytv 268 | dsize = (832,448) ## for davis 269 | else: 270 | dsize = preprocess.get_dsize(image, half_size=self.half_size, scale=64) 271 | 272 | image = cv2.resize(image, dsize, interpolation=cv2.INTER_LINEAR) 273 | ori_img = image.copy() 274 | 275 | image = preprocess.norm(image, self.mean, self.std) 276 | image = image.transpose((2, 0, 1)) 277 | 278 | img_set.append(image[np.newaxis,:]) 279 | ori_img_set.append(ori_img[np.newaxis,:]) 280 | 281 | img_set = np.concatenate((img_set), axis=0) 282 | ori_img_set = np.concatenate((ori_img_set), axis=0) 283 | 284 | obj_ind = [] 285 | obj_start_idx = [] 286 | ### gen lab_set ### 287 | for obj_lab in vid_info['labs']: 288 | lab = cv2.imread(obj_lab['lab_file'], cv2.IMREAD_GRAYSCALE) 289 | lab = cv2.resize(lab, dsize, interpolation=cv2.INTER_NEAREST) 290 | lab_set.append(lab[np.newaxis,:]) 291 | obj_ind.append(obj_lab['obj_ind']) 292 | obj_start_idx.append(obj_lab['start_idx']) 293 | 294 | lab_set = np.concatenate((lab_set), axis=0) 295 | 296 | 297 | return img_set, lab_set, ori_img_set, vid_info['name'], vid_info['min_idx'], vid_info['obj_num'], \ 298 | obj_ind, obj_start_idx, ori_shape 299 | 300 | -------------------------------------------------------------------------------- /flow_inference/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/__init__.py -------------------------------------------------------------------------------- /flow_inference/flow_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | import argparse, os 6 | 7 | import numpy as np 8 | 9 | import models 10 | import time 11 | import timeit 12 | from networks_flow.resample2d_package.modules.resample2d import Resample2d 13 | import cv2 14 | 15 | class Inference_flow(nn.Module): 16 | def __init__(self, args_flow, model_name="FlowNet2C", 17 | restore_path="flow_inference/models/FlowNet2-C_checkpoint.pth.tar", 18 | train_flow=False, resume=None): 19 | super(Inference_flow, self).__init__() 20 | args_flow.rgb_max = 255.0 21 | args_flow.fp16 = False 22 | args_flow.grads = {} 23 | 24 | if model_name == 'FlowNet2': 25 | self.model = models.FlowNet2(args_flow).cuda() 26 | elif model_name == 'FlowNet2C': 27 | self.model = models.FlowNet2C(args_flow).cuda() 28 | elif model_name == 'FlowNet2S': 29 | self.model = models.FlowNet2S(args_flow).cuda() 30 | elif model_name == 'FlowNet2SD': 31 | self.model = models.FlowNet2SD(args_flow).cuda() 32 | elif model_name == 'FlowNet2CS': 33 | self.model = models.FlowNet2CS(args_flow).cuda() 34 | elif model_name == 'FlowNet2CSS': 35 | self.model = models.FlowNet2CSS(args_flow).cuda() 36 | else: 37 | assert False, "No such model %s" % (model_name) 38 | print("loading %s pretrained model..." % (model_name)) 39 | if train_flow: 40 | self.model.train() 41 | else: 42 | self.model.eval() 43 | if resume is not None: 44 | self.model.load_state_dict(torch.load(resume)['flow']) 45 | else: 46 | self.model.load_state_dict(torch.load(restore_path)['state_dict']) 47 | 48 | ## flow order: img2 -> img1 ## 49 | def infer(self, img1, img2, scale=1): 50 | #assert img1.shape[0] % 64 == 0 and img1.shape[1] % 64 == 0, "shape should be n*64, but got shape (%d, %d, %d)" \ 51 | # % (img1.shape[0], img1.shape[1], img1.shape[2]) 52 | # resize the inputs to a multiple-of-64 shape 53 | 54 | ori_h = img1.shape[0] 55 | ori_w = img1.shape[1] 56 | resize_h = ori_h 57 | resize_w = ori_w 58 | if scale != 1: 59 | resize_h = resize_h // scale 60 | resize_w = resize_w // scale 61 | 62 | resize_h = resize_h // 64 * 64 63 | resize_w = resize_w // 64 * 64 64 | if resize_h == 0: 65 | resize_h = 64 66 | if resize_w == 0: 67 | resize_w = 64 68 | if ori_h != resize_h or ori_w != resize_w: 69 | ratio_h = float(ori_h) / resize_h 70 | ratio_w = float(ori_w) / resize_w 71 | img1 = cv2.resize(img1, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR) 72 | img2 = cv2.resize(img2, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR) 73 | 74 | assert img1.shape == img2.shape, "image1 and image2 should have the same shape!" 
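# NOTE: below, the two frames are stacked along a new temporal axis, giving a
# (1, 3, 2, H, W) batch -- the input layout the FlowNet2 models expect.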
75 | 76 | # concat and forward 77 | images = [img1, img2] 78 | images = np.array(images).transpose(3, 0, 1, 2) # C 2 H W 79 | images = images[np.newaxis, :, :, :, :] 80 | images = Variable(torch.from_numpy(images.astype(np.float32)).cuda(), requires_grad=True) 81 | 82 | out_flo = self.model(images)[0].data.cpu().numpy().transpose(1, 2, 0) 83 | 84 | # resize back to original size 85 | if ori_h != resize_h or ori_w != resize_w: 86 | out_flo = cv2.resize(out_flo, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR) 87 | out_flo[:, :, 0] = out_flo[:, :, 0] * ratio_w 88 | out_flo[:, :, 1] = out_flo[:, :, 1] * ratio_h 89 | 90 | return torch.FloatTensor(out_flo).cuda() 91 | 92 | 93 | -------------------------------------------------------------------------------- /flow_inference/models.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/models.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from networks_flow.correlation_package.modules.correlation import Correlation 9 | 10 | from networks_flow.submodules import * 11 | 'Parameter count , 39,175,298 ' 12 | 13 | class FlowNetC(nn.Module): 14 | def __init__(self,args, batchNorm=True, div_flow = 20): 15 | super(FlowNetC,self).__init__() 16 | 17 | self.batchNorm = batchNorm 18 | self.div_flow = div_flow 19 | 20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 24 | 25 | if args.fp16: 26 | self.corr = nn.Sequential( 27 | tofp32(), 28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), 29 | tofp16()) 30 | else: 31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 32 | 33 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 34 | self.conv3_1 = conv(self.batchNorm, 473, 256) 35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 36 | self.conv4_1 = conv(self.batchNorm, 512, 512) 37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 38 | self.conv5_1 = conv(self.batchNorm, 512, 512) 39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 40 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 41 | 42 | self.deconv5 = deconv(1024,512) 43 | self.deconv4 = deconv(1026,256) 44 | self.deconv3 = deconv(770,128) 45 | self.deconv2 = deconv(386,64) 46 | 47 | self.predict_flow6 = predict_flow(1024) 48 | self.predict_flow5 = predict_flow(1026) 49 | self.predict_flow4 = predict_flow(770) 50 | self.predict_flow3 = predict_flow(386) 51 | self.predict_flow2 = predict_flow(194) 52 | 53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | if m.bias is not None: 61 | 
init.uniform_(m.bias) 62 | init.xavier_uniform_(m.weight) 63 | 64 | if isinstance(m, nn.ConvTranspose2d): 65 | if m.bias is not None: 66 | init.uniform_(m.bias) 67 | init.xavier_uniform_(m.weight) 68 | # init_deconv_bilinear(m.weight) 69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 70 | 71 | def forward(self, x): 72 | x1 = x[:,0:3,:,:] 73 | x2 = x[:,3::,:,:] 74 | 75 | out_conv1a = self.conv1(x1) 76 | out_conv2a = self.conv2(out_conv1a) 77 | out_conv3a = self.conv3(out_conv2a) 78 | 79 | # FlowNetC bottom input stream 80 | out_conv1b = self.conv1(x2) 81 | 82 | out_conv2b = self.conv2(out_conv1b) 83 | out_conv3b = self.conv3(out_conv2b) 84 | 85 | # Merge streams 86 | out_corr = self.corr(out_conv3a, out_conv3b) # False 87 | out_corr = self.corr_activation(out_corr) 88 | 89 | # Redirect top input stream and concatenate 90 | out_conv_redir = self.conv_redir(out_conv3a) 91 | 92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) 93 | 94 | # Merged conv layers 95 | out_conv3_1 = self.conv3_1(in_conv3_1) 96 | 97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) 98 | 99 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 100 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 101 | 102 | flow6 = self.predict_flow6(out_conv6) 103 | flow6_up = self.upsampled_flow6_to_5(flow6) 104 | out_deconv5 = self.deconv5(out_conv6) 105 | 106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 107 | 108 | flow5 = self.predict_flow5(concat5) 109 | flow5_up = self.upsampled_flow5_to_4(flow5) 110 | out_deconv4 = self.deconv4(concat5) 111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 112 | 113 | flow4 = self.predict_flow4(concat4) 114 | flow4_up = self.upsampled_flow4_to_3(flow4) 115 | out_deconv3 = self.deconv3(concat4) 116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) 117 | 118 | flow3 = self.predict_flow3(concat3) 119 | flow3_up = self.upsampled_flow3_to_2(flow3) 120 | out_deconv2 = self.deconv2(concat3) 121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) 122 | 123 | flow2 = self.predict_flow2(concat2) 124 | 125 | if self.training: 126 | return flow2,flow3,flow4,flow5,flow6 127 | else: 128 | return flow2, 129 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetC.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/FlowNetC.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetFusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from networks_flow.submodules import * 9 | 'Parameter count = 581,226' 10 | 11 | class FlowNetFusion(nn.Module): 12 | def __init__(self,args, batchNorm=True): 13 | super(FlowNetFusion,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 11, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | 22 | self.deconv1 = deconv(128,32) 23 | self.deconv0 = deconv(162,16) 24 | 25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32) 26 | self.inter_conv0 = i_conv(self.batchNorm, 
82, 16) 27 | 28 | self.predict_flow2 = predict_flow(128) 29 | self.predict_flow1 = predict_flow(32) 30 | self.predict_flow0 = predict_flow(16) 31 | 32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | if m.bias is not None: 38 | init.uniform_(m.bias) 39 | init.xavier_uniform_(m.weight) 40 | 41 | if isinstance(m, nn.ConvTranspose2d): 42 | if m.bias is not None: 43 | init.uniform_(m.bias) 44 | init.xavier_uniform_(m.weight) 45 | # init_deconv_bilinear(m.weight) 46 | 47 | def forward(self, x): 48 | out_conv0 = self.conv0(x) 49 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 50 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 51 | 52 | flow2 = self.predict_flow2(out_conv2) 53 | flow2_up = self.upsampled_flow2_to_1(flow2) 54 | out_deconv1 = self.deconv1(out_conv2) 55 | 56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1) 57 | out_interconv1 = self.inter_conv1(concat1) 58 | flow1 = self.predict_flow1(out_interconv1) 59 | flow1_up = self.upsampled_flow1_to_0(flow1) 60 | out_deconv0 = self.deconv0(concat1) 61 | 62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1) 63 | out_interconv0 = self.inter_conv0(concat0) 64 | flow0 = self.predict_flow0(out_interconv0) 65 | 66 | return flow0 67 | 68 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetFusion.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/FlowNetFusion.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetS.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | import math 10 | import numpy as np 11 | 12 | from networks_flow.submodules import * 13 | 'Parameter count : 38,676,504 ' 14 | 15 | class FlowNetS(nn.Module): 16 | def __init__(self, args, input_channels = 12, batchNorm=True): 17 | super(FlowNetS,self).__init__() 18 | 19 | self.batchNorm = batchNorm 20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv3_1 = conv(self.batchNorm, 256, 256) 24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 25 | self.conv4_1 = conv(self.batchNorm, 512, 512) 26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 27 | self.conv5_1 = conv(self.batchNorm, 512, 512) 28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 29 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 30 | 31 | self.deconv5 = deconv(1024,512) 32 | self.deconv4 = deconv(1026,256) 33 | self.deconv3 = deconv(770,128) 34 | self.deconv2 = deconv(386,64) 35 | 36 | self.predict_flow6 = predict_flow(1024) 37 | self.predict_flow5 = predict_flow(1026) 38 | self.predict_flow4 = predict_flow(770) 39 | self.predict_flow3 = predict_flow(386) 40 | self.predict_flow2 = predict_flow(194) 41 | 42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 44 | 
self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | if m.bias is not None: 50 | init.uniform_(m.bias) 51 | init.xavier_uniform_(m.weight) 52 | 53 | if isinstance(m, nn.ConvTranspose2d): 54 | if m.bias is not None: 55 | init.uniform_(m.bias) 56 | init.xavier_uniform_(m.weight) 57 | # init_deconv_bilinear(m.weight) 58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 59 | 60 | def forward(self, x): 61 | out_conv1 = self.conv1(x) 62 | 63 | out_conv2 = self.conv2(out_conv1) 64 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 65 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 66 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 67 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 68 | 69 | flow6 = self.predict_flow6(out_conv6) 70 | flow6_up = self.upsampled_flow6_to_5(flow6) 71 | out_deconv5 = self.deconv5(out_conv6) 72 | 73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 74 | flow5 = self.predict_flow5(concat5) 75 | flow5_up = self.upsampled_flow5_to_4(flow5) 76 | out_deconv4 = self.deconv4(concat5) 77 | 78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 79 | flow4 = self.predict_flow4(concat4) 80 | flow4_up = self.upsampled_flow4_to_3(flow4) 81 | out_deconv3 = self.deconv3(concat4) 82 | 83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 84 | flow3 = self.predict_flow3(concat3) 85 | flow3_up = self.upsampled_flow3_to_2(flow3) 86 | out_deconv2 = self.deconv2(concat3) 87 | 88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 89 | flow2 = self.predict_flow2(concat2) 90 | 91 | if self.training: 92 | return flow2,flow3,flow4,flow5,flow6 93 | else: 94 | return flow2, 95 | 96 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetS.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/FlowNetS.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetSD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from networks_flow.submodules import * 9 | 'Parameter count = 45,371,666' 10 | 11 | class FlowNetSD(nn.Module): 12 | def __init__(self, args, batchNorm=True): 13 | super(FlowNetSD,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 6, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2) 22 | self.conv3_1 = conv(self.batchNorm, 256, 256) 23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 24 | self.conv4_1 = conv(self.batchNorm, 512, 512) 25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 26 | self.conv5_1 = conv(self.batchNorm, 512, 512) 27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 28 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 29 | 30 | self.deconv5 = deconv(1024,512) 31 | self.deconv4 = deconv(1026,256) 32 | self.deconv3 = deconv(770,128) 
33 | self.deconv2 = deconv(386,64) 34 | 35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512) 36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256) 37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128) 38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64) 39 | 40 | self.predict_flow6 = predict_flow(1024) 41 | self.predict_flow5 = predict_flow(512) 42 | self.predict_flow4 = predict_flow(256) 43 | self.predict_flow3 = predict_flow(128) 44 | self.predict_flow2 = predict_flow(64) 45 | 46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 50 | 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | if m.bias is not None: 54 | init.uniform_(m.bias) 55 | init.xavier_uniform_(m.weight) 56 | 57 | if isinstance(m, nn.ConvTranspose2d): 58 | if m.bias is not None: 59 | init.uniform_(m.bias) 60 | init.xavier_uniform_(m.weight) 61 | # init_deconv_bilinear(m.weight) 62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 63 | 64 | 65 | 66 | def forward(self, x): 67 | out_conv0 = self.conv0(x) 68 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 69 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 70 | 71 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 72 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 73 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 74 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 75 | 76 | flow6 = self.predict_flow6(out_conv6) 77 | flow6_up = self.upsampled_flow6_to_5(flow6) 78 | out_deconv5 = self.deconv5(out_conv6) 79 | 80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 81 | out_interconv5 = self.inter_conv5(concat5) 82 | flow5 = self.predict_flow5(out_interconv5) 83 | 84 | flow5_up = self.upsampled_flow5_to_4(flow5) 85 | out_deconv4 = self.deconv4(concat5) 86 | 87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 88 | out_interconv4 = self.inter_conv4(concat4) 89 | flow4 = self.predict_flow4(out_interconv4) 90 | flow4_up = self.upsampled_flow4_to_3(flow4) 91 | out_deconv3 = self.deconv3(concat4) 92 | 93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 94 | out_interconv3 = self.inter_conv3(concat3) 95 | flow3 = self.predict_flow3(out_interconv3) 96 | flow3_up = self.upsampled_flow3_to_2(flow3) 97 | out_deconv2 = self.deconv2(concat3) 98 | 99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 100 | out_interconv2 = self.inter_conv2(concat2) 101 | flow2 = self.predict_flow2(out_interconv2) 102 | 103 | if self.training: 104 | return flow2,flow3,flow4,flow5,flow6 105 | else: 106 | return flow2, 107 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/FlowNetSD.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/FlowNetSD.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__init__.py -------------------------------------------------------------------------------- 
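A note on the decoder bookkeeping shared by the FlowNet variants above: each `concat` stacks an encoder skip connection, a deconv output and a 2-channel upsampled flow, which is where channel counts such as 1026 and 770 come from. A minimal sketch checking the arithmetic for `concat5` in FlowNetS (tensor sizes are illustrative only):

```python
import torch

out_conv5   = torch.zeros(1, 512, 8, 8)  # encoder skip connection
out_deconv5 = torch.zeros(1, 512, 8, 8)  # deconv(1024, 512) applied to out_conv6
flow6_up    = torch.zeros(1, 2, 8, 8)    # coarse flow upsampled by a stride-2 ConvTranspose2d

concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
assert concat5.shape[1] == 1026          # matches predict_flow5(1026) and deconv4(1026) above
```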
/flow_inference/networks_flow/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/FlowNetC.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/FlowNetC.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/FlowNetFusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/FlowNetFusion.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/FlowNetS.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/FlowNetS.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/FlowNetSD.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/FlowNetSD.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/__pycache__/submodules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/__pycache__/submodules.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/__init__.pyc -------------------------------------------------------------------------------- 
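The channelnorm_package whose files follow implements a single CUDA op: `ChannelNorm` reduces a `(B, C, H, W)` tensor to `(B, 1, H, W)` by taking the L2 norm over the channel dimension at every pixel (see the CUDA kernel further below). A brief usage sketch, assuming the compiled `_channelnorm.so` from `make.sh`, a CUDA device, PyTorch 0.4, and `flow_inference` on `sys.path`:

```python
import torch
from networks_flow.channelnorm_package.modules.channelnorm import ChannelNorm

flow = torch.randn(1, 2, 4, 4).cuda()  # e.g. a 2-channel flow field (u, v)
mag = ChannelNorm()(flow)              # (1, 1, 4, 4): sqrt(u^2 + v^2) at each pixel
```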
/flow_inference/networks_flow/channelnorm_package/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/_ext/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/_ext/channelnorm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._channelnorm import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/_ext/channelnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/_ext/channelnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/_ext/channelnorm/_channelnorm.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/_ext/channelnorm/_channelnorm.so -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.ffi 4 | 5 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' 6 | 7 | Headers = [] 8 | Sources = [] 9 | Defines = [] 10 | Objects = [] 11 | 12 | if torch.cuda.is_available() == True: 13 | Headers += ['src/ChannelNorm_cuda.h'] 14 | Sources += ['src/ChannelNorm_cuda.c'] 15 | Defines += [('WITH_CUDA', None)] 16 | Objects += ['src/ChannelNorm_kernel.o'] 17 | 18 | ffi = torch.utils.ffi.create_extension( 19 | name='_ext.channelnorm', 20 | headers=Headers, 21 | sources=Sources, 22 | verbose=False, 23 | with_cuda=True, 24 | package=False, 25 | 
relative_to=this_folder, 26 | define_macros=Defines, 27 | extra_objects=[os.path.join(this_folder, Object) for Object in Objects] 28 | ) 29 | 30 | if __name__ == '__main__': 31 | ffi.build() -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/functions/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/functions/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/__pycache__/channelnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/functions/__pycache__/channelnorm.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/channelnorm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from .._ext import channelnorm 3 | 4 | 5 | class ChannelNormFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, norm_deg=2): 9 | assert input1.is_contiguous() 10 | b, _, h, w = input1.size() 11 | output = input1.new(b, 1, h, w).zero_() 12 | 13 | channelnorm.ChannelNorm_cuda_forward(input1, output, norm_deg) 14 | ctx.save_for_backward(input1, output) 15 | ctx.norm_deg = norm_deg 16 | 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input1, output = ctx.saved_tensors 22 | 23 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 24 | 25 | channelnorm.ChannelNorm_cuda_backward(input1, output, grad_output.data, 26 | grad_input1.data, ctx.norm_deg) 27 | 28 | return grad_input1, None 29 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/functions/channelnorm.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/functions/channelnorm.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/make.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 3 | 4 | cd src 5 | echo "Compiling channelnorm kernels by nvcc..." 6 | rm ChannelNorm_kernel.o 7 | rm -r ../_ext 8 | 9 | nvcc -c -o ChannelNorm_kernel.o ChannelNorm_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC 10 | 11 | cd ../ 12 | python build.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/modules/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/modules/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/__pycache__/channelnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/modules/__pycache__/channelnorm.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/channelnorm.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | 3 | from ..functions.channelnorm import ChannelNormFunction 4 | 5 | 6 | class ChannelNorm(Module): 7 | 8 | def __init__(self, norm_deg=2): 9 | super(ChannelNorm, self).__init__() 10 | self.norm_deg = norm_deg 11 | 12 | def forward(self, input1): 13 | return ChannelNormFunction.apply(input1, self.norm_deg) 14 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/modules/channelnorm.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/modules/channelnorm.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | #include <THC/THCGeneral.h> 3 | 4 | #include "ChannelNorm_kernel.h" 5 | 6 | extern THCState* state; 7 | 8 | 
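/* Entry points exposed to Python through the torch.utils.ffi wrapper assembled
   by build.py; both simply dispatch to the CUDA kernels declared in
   ChannelNorm_kernel.h. */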
int ChannelNorm_cuda_forward(THCudaTensor* input1, THCudaTensor* output, int norm_deg) { 9 | ChannelNorm_kernel_forward(state, input1, output, norm_deg); 10 | return 1; 11 | } 12 | 13 | 14 | int ChannelNorm_cuda_backward(THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg) { 15 | ChannelNorm_kernel_backward(state, input1, output, gradOutput, gradInput1, norm_deg); 16 | return 1; 17 | } -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_cuda.h: -------------------------------------------------------------------------------- 1 | int ChannelNorm_cuda_forward(THCudaTensor* input1, THCudaTensor* output, int norm_deg); 2 | 3 | int ChannelNorm_cuda_backward(THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg); -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <THC.h> 2 | #include <THCGeneral.h> 3 | 4 | #define CUDA_NUM_THREADS 512 5 | #define THREADS_PER_BLOCK 64 6 | 7 | #define DIM0(TENSOR) ((TENSOR).x) 8 | #define DIM1(TENSOR) ((TENSOR).y) 9 | #define DIM2(TENSOR) ((TENSOR).z) 10 | #define DIM3(TENSOR) ((TENSOR).w) 11 | 12 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 13 | 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | __global__ void kernel_ChannelNorm_updateOutput(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, float* output, const long4 output_size, const long4 output_stride, int norm_deg) { 20 | int index = blockIdx.x * blockDim.x + threadIdx.x; 21 | 22 | if (index >= n) { 23 | return; 24 | } 25 | 26 | int dim_b = DIM0(output_size); 27 | int dim_c = DIM1(output_size); 28 | int dim_h = DIM2(output_size); 29 | int dim_w = DIM3(output_size); 30 | int dim_chw = dim_c * dim_h * dim_w; 31 | 32 | int b = ( index / dim_chw ) % dim_b; 33 | int y = ( index / dim_w ) % dim_h; 34 | int x = ( index ) % dim_w; 35 | 36 | int i1dim_c = DIM1(input1_size); 37 | int i1dim_h = DIM2(input1_size); 38 | int i1dim_w = DIM3(input1_size); 39 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; 40 | int i1dim_hw = i1dim_h * i1dim_w; 41 | 42 | float result = 0.0; 43 | 44 | for (int c = 0; c < i1dim_c; ++c) { 45 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; 46 | float val = input1[i1Index]; 47 | result += val * val; 48 | } 49 | result = sqrt(result); 50 | output[index] = result; 51 | } 52 | 53 | 54 | __global__ void kernel_ChannelNorm_backward_input1(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, 55 | const float* output, const long4 output_size, const long4 output_stride, const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 56 | float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int norm_deg) { 57 | int index = blockIdx.x * blockDim.x + threadIdx.x; 58 | 59 | if (index >= n) { 60 | return; 61 | } 62 | 63 | float val = 0.0; 64 | 65 | int dim_b = DIM0(gradInput_size); 66 | int dim_c = DIM1(gradInput_size); 67 | int dim_h = DIM2(gradInput_size); 68 | int dim_w = DIM3(gradInput_size); 69 | int dim_chw = dim_c * dim_h * dim_w; 70 | int dim_hw = 
dim_h * dim_w; 71 | 72 | int b = ( index / dim_chw ) % dim_b; 73 | int y = ( index / dim_w ) % dim_h; 74 | int x = ( index ) % dim_w; 75 | 76 | 77 | int outIndex = b * dim_hw + y * dim_w + x; 78 | val = gradOutput[outIndex] * input1[index] / (output[outIndex]+1e-9); 79 | gradInput[index] = val; 80 | 81 | } 82 | 83 | void ChannelNorm_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* output, int norm_deg) { 84 | int n = 0; 85 | 86 | const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); 87 | const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); 88 | 89 | const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); 90 | const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); 91 | 92 | n = THCudaTensor_nElement(state, output); 93 | kernel_ChannelNorm_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( 94 | n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride, 95 | norm_deg); 96 | 97 | THCudaCheck(cudaGetLastError()); 98 | } 99 | 100 | void ChannelNorm_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg) { 101 | int n = 0; 102 | 103 | const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); 104 | const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); 105 | 106 | const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); 107 | const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); 108 | 109 | const long4 gradOutput_size = make_long4(gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]); 110 | const long4 gradOutput_stride = make_long4(gradOutput->stride[0], gradOutput->stride[1], gradOutput->stride[2], gradOutput->stride[3]); 111 | 112 | const long4 gradInput1_size = make_long4(gradInput1->size[0], gradInput1->size[1], gradInput1->size[2], gradInput1->size[3]); 113 | const long4 gradInput1_stride = make_long4(gradInput1->stride[0], gradInput1->stride[1], gradInput1->stride[2], gradInput1->stride[3]); 114 | 115 | n = THCudaTensor_nElement(state, gradInput1); 116 | kernel_ChannelNorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( 117 | n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride, 118 | THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput1), gradInput1_size, gradInput1_stride, 119 | norm_deg 120 | ); 121 | 122 | THCudaCheck(cudaGetLastError()); 123 | } 124 | 125 | #ifdef __cplusplus 126 | } 127 | #endif -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | void ChannelNorm_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* output, 
int norm_deg); 6 | 7 | 8 | void ChannelNorm_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg); 9 | 10 | #ifdef __cplusplus 11 | } 12 | #endif -------------------------------------------------------------------------------- /flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/channelnorm_package/src/ChannelNorm_kernel.o -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/_ext/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/_ext/correlation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._correlation import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/_ext/correlation/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/_ext/correlation/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/_ext/correlation/_correlation.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/_ext/correlation/_correlation.so -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.ffi 4 | 5 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' 6 | 7 | Headers = [] 8 | Sources = [] 9 | Defines = [] 10 | Objects = [] 11 | 12 | if torch.cuda.is_available() == True: 13 | Headers += ['src/correlation_cuda.h'] 14 | Sources += ['src/correlation_cuda.c'] 15 | Defines += [('WITH_CUDA', None)] 16 | Objects += ['src/correlation_cuda_kernel.o'] 17 | 18 | ffi = torch.utils.ffi.create_extension( 19 | name='_ext.correlation', 20 | headers=Headers, 21 | sources=Sources, 22 | verbose=False, 23 | with_cuda=True, 24 | package=False, 25 | relative_to=this_folder, 26 | define_macros=Defines, 27 | extra_objects=[os.path.join(this_folder, Object) for Object in Objects] 28 | ) 29 | 30 | if __name__ == '__main__': 31 | ffi.build() -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/functions/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/functions/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/functions/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/functions/__pycache__/correlation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/functions/__pycache__/correlation.cpython-36.pyc -------------------------------------------------------------------------------- 
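For orientation, here is a minimal usage sketch of this package once make.sh and build.py have produced _ext/correlation. It is not code from this repo: the import assumes flow_inference/networks_flow is on sys.path, and the hyper-parameters are illustrative FlowNetC-style values rather than ones read from FlowNetC.py. The Correlation module itself is defined in modules/correlation.py below.

    import torch
    from torch.autograd import Variable

    # Correlation wraps the compiled _ext.correlation kernels.
    from correlation_package.modules.correlation import Correlation

    corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                       stride1=1, stride2=2, corr_multiply=1)

    # Feature maps of two consecutive frames; the op is CUDA-only and the
    # autograd Function asserts that both inputs are contiguous.
    feat1 = Variable(torch.randn(1, 256, 48, 64).cuda().contiguous())
    feat2 = Variable(torch.randn(1, 256, 48, 64).cuda().contiguous())

    # ((max_displacement/stride2)*2 + 1)**2 = 441 output channels here.
    cost_volume = corr(feat1, feat2)   # -> (1, 441, 48, 64)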
/flow_inference/networks_flow/correlation_package/functions/correlation.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from .._ext import correlation 3 | 4 | 5 | class CorrelationFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, 9 | input1, 10 | input2, 11 | pad_size=3, 12 | kernel_size=3, 13 | max_displacement=20, 14 | stride1=1, 15 | stride2=2, 16 | corr_multiply=1): 17 | assert input1.is_contiguous() 18 | assert input2.is_contiguous() 19 | 20 | ctx.save_for_backward(input1, input2) 21 | ctx.pad_size = pad_size 22 | ctx.kernel_size = kernel_size 23 | ctx.max_displacement = max_displacement 24 | ctx.stride1 = stride1 25 | ctx.stride2 = stride2 26 | ctx.corr_multiply = corr_multiply 27 | 28 | rbot1 = input1.new() 29 | rbot2 = input2.new() 30 | output = input1.new() 31 | 32 | correlation.Correlation_forward_cuda( 33 | input1, input2, rbot1, rbot2, output, pad_size, kernel_size, 34 | max_displacement, stride1, stride2, corr_multiply) 35 | 36 | return output 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | assert grad_output.is_contiguous() 41 | 42 | input1, input2 = ctx.saved_tensors 43 | 44 | rbot1 = input1.new() 45 | rbot2 = input2.new() 46 | 47 | grad_input1 = Variable(input1.new()) 48 | grad_input2 = Variable(input2.new()) 49 | 50 | correlation.Correlation_backward_cuda( 51 | input1, input2, rbot1, rbot2, grad_output.data, grad_input1.data, 52 | grad_input2.data, ctx.pad_size, ctx.kernel_size, 53 | ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) 54 | 55 | return (grad_input1, grad_input2) + (None, ) * 6 56 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/functions/correlation.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/functions/correlation.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 3 | 4 | cd src 5 | 6 | echo "Compiling correlation kernels by nvcc..." 
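# The build is two-stage: nvcc first compiles the raw CUDA kernel into a
# relocatable object file, then build.py links that object into the cffi
# extension (_ext/correlation) that the Python wrappers import. The stale
# object file and _ext directory are removed first so that a PyTorch or
# CUDA upgrade cannot leave a mismatched shared library behind.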
7 | 8 | rm correlation_cuda_kernel.o 9 | rm -r ../_ext 10 | 11 | nvcc -c -o correlation_cuda_kernel.o correlation_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 12 | 13 | cd ../ 14 | python build.py 15 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/modules/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/modules/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/__pycache__/correlation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/modules/__pycache__/correlation.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/correlation.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | 3 | from ..functions.correlation import CorrelationFunction 4 | 5 | 6 | class Correlation(Module): 7 | 8 | def __init__(self, 9 | pad_size=0, 10 | kernel_size=0, 11 | max_displacement=0, 12 | stride1=1, 13 | stride2=2, 14 | corr_multiply=1): 15 | super(Correlation, self).__init__() 16 | self.pad_size = pad_size 17 | self.kernel_size = kernel_size 18 | self.max_displacement = max_displacement 19 | self.stride1 = stride1 20 | self.stride2 = stride2 21 | self.corr_multiply = corr_multiply 22 | 23 | def forward(self, input1, input2): 24 | return CorrelationFunction.apply(input1, input2, self.pad_size, 25 | self.kernel_size, 26 | self.max_displacement, self.stride1, 27 | self.stride2, self.corr_multiply) 28 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/modules/correlation.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/modules/correlation.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation.c: 
-------------------------------------------------------------------------------- 1 | #include <TH.h> 2 | 3 | int Correlation_forward_cpu(THFloatTensor *input1, 4 | THFloatTensor *input2, 5 | THFloatTensor *rInput1, 6 | THFloatTensor *rInput2, 7 | THFloatTensor *output, 8 | int pad_size, 9 | int kernel_size, 10 | int max_displacement, 11 | int stride1, 12 | int stride2, 13 | int corr_type_multiply) 14 | { 15 | return 1; 16 | } 17 | 18 | int Correlation_backward_cpu(THFloatTensor *input1, 19 | THFloatTensor *input2, 20 | THFloatTensor *rInput1, 21 | THFloatTensor *rInput2, 22 | THFloatTensor *gradOutput, 23 | THFloatTensor *gradInput1, 24 | THFloatTensor *gradInput2, 25 | int pad_size, 26 | int kernel_size, 27 | int max_displacement, 28 | int stride1, 29 | int stride2, 30 | int corr_type_multiply) 31 | { 32 | return 1; 33 | } 34 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation.h: -------------------------------------------------------------------------------- 1 | int Correlation_forward_cpu(THFloatTensor *input1, 2 | THFloatTensor *input2, 3 | THFloatTensor *rInput1, 4 | THFloatTensor *rInput2, 5 | THFloatTensor *output, 6 | int pad_size, 7 | int kernel_size, 8 | int max_displacement, 9 | int stride1, 10 | int stride2, 11 | int corr_type_multiply); 12 | 13 | int Correlation_backward_cpu(THFloatTensor *input1, 14 | THFloatTensor *input2, 15 | THFloatTensor *rInput1, 16 | THFloatTensor *rInput2, 17 | THFloatTensor *gradOutput, 18 | THFloatTensor *gradInput1, 19 | THFloatTensor *gradInput2, 20 | int pad_size, 21 | int kernel_size, 22 | int max_displacement, 23 | int stride1, 24 | int stride2, 25 | int corr_type_multiply); 26 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC.h> 2 | #include <THCGeneral.h> 3 | 4 | #include "correlation_cuda_kernel.h" 5 | 6 | #define real float 7 | 8 | // symbol to be automatically resolved by PyTorch libs 9 | extern THCState *state; 10 | 11 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *output, 12 | int pad_size, 13 | int kernel_size, 14 | int max_displacement, 15 | int stride1, 16 | int stride2, 17 | int corr_type_multiply) 18 | { 19 | 20 | int batchSize = input1->size[0]; 21 | int nInputChannels = input1->size[1]; 22 | int inputHeight = input1->size[2]; 23 | int inputWidth = input1->size[3]; 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil((float)(paddedInputHeight - 2 * border_radius) / (float)stride1); 34 | int outputwidth = ceil((float)(paddedInputWidth - 2 * border_radius) / (float)stride1); 35 | 36 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 37 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 38 | THCudaTensor_resize4d(state, output, batchSize, nOutputChannels, outputHeight, outputwidth); 39 | 40 | THCudaTensor_fill(state, rInput1, 0); 41 | THCudaTensor_fill(state,
rInput2, 0); 42 | THCudaTensor_fill(state, output, 0); 43 | 44 | int success = 0; 45 | success = Correlation_forward_cuda_kernel( THCudaTensor_data(state, output), 46 | THCudaTensor_size(state, output, 0), 47 | THCudaTensor_size(state, output, 1), 48 | THCudaTensor_size(state, output, 2), 49 | THCudaTensor_size(state, output, 3), 50 | THCudaTensor_stride(state, output, 0), 51 | THCudaTensor_stride(state, output, 1), 52 | THCudaTensor_stride(state, output, 2), 53 | THCudaTensor_stride(state, output, 3), 54 | 55 | THCudaTensor_data(state, input1), 56 | THCudaTensor_size(state, input1, 1), 57 | THCudaTensor_size(state, input1, 2), 58 | THCudaTensor_size(state, input1, 3), 59 | THCudaTensor_stride(state, input1, 0), 60 | THCudaTensor_stride(state, input1, 1), 61 | THCudaTensor_stride(state, input1, 2), 62 | THCudaTensor_stride(state, input1, 3), 63 | 64 | THCudaTensor_data(state, input2), 65 | THCudaTensor_size(state, input2, 1), 66 | THCudaTensor_stride(state, input2, 0), 67 | THCudaTensor_stride(state, input2, 1), 68 | THCudaTensor_stride(state, input2, 2), 69 | THCudaTensor_stride(state, input2, 3), 70 | 71 | THCudaTensor_data(state, rInput1), 72 | THCudaTensor_data(state, rInput2), 73 | 74 | pad_size, 75 | kernel_size, 76 | max_displacement, 77 | stride1, 78 | stride2, 79 | corr_type_multiply, 80 | 81 | THCState_getCurrentStream(state)); 82 | 83 | THCudaTensor_free(state, rInput1); 84 | THCudaTensor_free(state, rInput2); 85 | 86 | //check for errors 87 | if (!success) { 88 | THError("aborting"); 89 | } 90 | 91 | return 1; 92 | 93 | } 94 | 95 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *gradOutput, 96 | THCudaTensor *gradInput1, THCudaTensor *gradInput2, 97 | int pad_size, 98 | int kernel_size, 99 | int max_displacement, 100 | int stride1, 101 | int stride2, 102 | int corr_type_multiply) 103 | { 104 | 105 | int batchSize = input1->size[0]; 106 | int nInputChannels = input1->size[1]; 107 | int paddedInputHeight = input1->size[2]+ 2 * pad_size; 108 | int paddedInputWidth = input1->size[3]+ 2 * pad_size; 109 | 110 | int height = input1->size[2]; 111 | int width = input1->size[3]; 112 | 113 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 114 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 115 | THCudaTensor_resize4d(state, gradInput1, batchSize, nInputChannels, height, width); 116 | THCudaTensor_resize4d(state, gradInput2, batchSize, nInputChannels, height, width); 117 | 118 | THCudaTensor_fill(state, rInput1, 0); 119 | THCudaTensor_fill(state, rInput2, 0); 120 | THCudaTensor_fill(state, gradInput1, 0); 121 | THCudaTensor_fill(state, gradInput2, 0); 122 | 123 | int success = 0; 124 | success = Correlation_backward_cuda_kernel( 125 | THCudaTensor_data(state, gradOutput), 126 | THCudaTensor_size(state, gradOutput, 0), 127 | THCudaTensor_size(state, gradOutput, 1), 128 | THCudaTensor_size(state, gradOutput, 2), 129 | THCudaTensor_size(state, gradOutput, 3), 130 | THCudaTensor_stride(state, gradOutput, 0), 131 | THCudaTensor_stride(state, gradOutput, 1), 132 | THCudaTensor_stride(state, gradOutput, 2), 133 | THCudaTensor_stride(state, gradOutput, 3), 134 | 135 | THCudaTensor_data(state, input1), 136 | THCudaTensor_size(state, input1, 1), 137 | THCudaTensor_size(state, input1, 2), 138 | THCudaTensor_size(state, input1, 3), 139 | THCudaTensor_stride(state, input1, 0), 140 | 
THCudaTensor_stride(state, input1, 1), 141 | THCudaTensor_stride(state, input1, 2), 142 | THCudaTensor_stride(state, input1, 3), 143 | 144 | THCudaTensor_data(state, input2), 145 | THCudaTensor_stride(state, input2, 0), 146 | THCudaTensor_stride(state, input2, 1), 147 | THCudaTensor_stride(state, input2, 2), 148 | THCudaTensor_stride(state, input2, 3), 149 | 150 | THCudaTensor_data(state, gradInput1), 151 | THCudaTensor_stride(state, gradInput1, 0), 152 | THCudaTensor_stride(state, gradInput1, 1), 153 | THCudaTensor_stride(state, gradInput1, 2), 154 | THCudaTensor_stride(state, gradInput1, 3), 155 | 156 | THCudaTensor_data(state, gradInput2), 157 | THCudaTensor_size(state, gradInput2, 1), 158 | THCudaTensor_stride(state, gradInput2, 0), 159 | THCudaTensor_stride(state, gradInput2, 1), 160 | THCudaTensor_stride(state, gradInput2, 2), 161 | THCudaTensor_stride(state, gradInput2, 3), 162 | 163 | THCudaTensor_data(state, rInput1), 164 | THCudaTensor_data(state, rInput2), 165 | pad_size, 166 | kernel_size, 167 | max_displacement, 168 | stride1, 169 | stride2, 170 | corr_type_multiply, 171 | THCState_getCurrentStream(state)); 172 | 173 | THCudaTensor_free(state, rInput1); 174 | THCudaTensor_free(state, rInput2); 175 | 176 | if (!success) { 177 | THError("aborting"); 178 | } 179 | return 1; 180 | } 181 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation_cuda.h: -------------------------------------------------------------------------------- 1 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, 2 | THCudaTensor *output, 3 | int pad_size, 4 | int kernel_size, 5 | int max_displacement, 6 | int stride1, 7 | int stride2, 8 | int corr_type_multiply); 9 | 10 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, 11 | THCudaTensor *gradOutput, THCudaTensor *gradInput1, THCudaTensor *gradInput2, 12 | int pad_size, 13 | int kernel_size, 14 | int max_displacement, 15 | int stride1, 16 | int stride2, 17 | int corr_type_multiply); 18 | 19 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | 3 | #include "correlation_cuda_kernel.h" 4 | 5 | #define real float 6 | 7 | #define CUDA_NUM_THREADS 1024 8 | #define THREADS_PER_BLOCK 32 9 | 10 | __global__ void channels_first(float* input, float* rinput, int channels, int height, int width, int pad_size) 11 | { 12 | // n (batch size), c (num of channels), y (height), x (width) 13 | int n = blockIdx.x; 14 | int y = blockIdx.y; 15 | int x = blockIdx.z; 16 | 17 | int ch_off = threadIdx.x; 18 | float value; 19 | 20 | int dimcyx = channels * height * width; 21 | int dimyx = height * width; 22 | 23 | int p_dimx = (width + 2 * pad_size); 24 | int p_dimy = (height + 2 * pad_size); 25 | int p_dimyxc = channels * p_dimy * p_dimx; 26 | int p_dimxc = p_dimx * channels; 27 | 28 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 29 | value = input[n * dimcyx + c * dimyx + y * width + x]; 30 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 31 | } 32 | } 33 | 34 | __global__ void Correlation_forward( float *output, int nOutputChannels, int outputHeight, int outputWidth, 35 | float *rInput1, int
nInputChannels, int inputHeight, int inputWidth, 36 | float *rInput2, 37 | int pad_size, 38 | int kernel_size, 39 | int max_displacement, 40 | int stride1, 41 | int stride2) 42 | { 43 | // n (batch size), c (num of channels), y (height), x (width) 44 | 45 | int pInputWidth = inputWidth + 2 * pad_size; 46 | int pInputHeight = inputHeight + 2 * pad_size; 47 | 48 | int kernel_rad = (kernel_size - 1) / 2; 49 | int displacement_rad = max_displacement / stride2; 50 | int displacement_size = 2 * displacement_rad + 1; 51 | 52 | int n = blockIdx.x; 53 | int y1 = blockIdx.y * stride1 + max_displacement + kernel_rad; 54 | int x1 = blockIdx.z * stride1 + max_displacement + kernel_rad; 55 | int c = threadIdx.x; 56 | 57 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 58 | int pdimxc = pInputWidth * nInputChannels; 59 | int pdimc = nInputChannels; 60 | 61 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 62 | int tdimyx = outputHeight * outputWidth; 63 | int tdimx = outputWidth; 64 | 65 | float nelems = kernel_size * kernel_size * pdimc; 66 | 67 | __shared__ float prod_sum[THREADS_PER_BLOCK]; 68 | 69 | // no significant speed-up in using chip memory for input1 sub-data, 70 | // not enough chip memory size to accomodate memory per block for input2 sub-data 71 | // instead i've used device memory for both 72 | 73 | // element-wise product along channel axis 74 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj ) { 75 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti ) { 76 | prod_sum[c] = 0; 77 | int x2 = x1 + ti*stride2; 78 | int y2 = y1 + tj*stride2; 79 | 80 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 81 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 82 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) { 83 | int indx1 = n * pdimyxc + (y1+j) * pdimxc + (x1 + i) * pdimc + ch; 84 | int indx2 = n * pdimyxc + (y2+j) * pdimxc + (x2 + i) * pdimc + ch; 85 | 86 | prod_sum[c] += rInput1[indx1] * rInput2[indx2]; 87 | } 88 | } 89 | } 90 | 91 | // accumulate 92 | __syncthreads(); 93 | if (c == 0) { 94 | float reduce_sum = 0; 95 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) { 96 | reduce_sum += prod_sum[index]; 97 | } 98 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad); 99 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z; 100 | output[tindx] = reduce_sum / nelems; 101 | } 102 | 103 | } 104 | } 105 | 106 | } 107 | 108 | __global__ void Correlation_backward_input1(int item, float *gradInput1, int nInputChannels, int inputHeight, int inputWidth, 109 | float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 110 | float *rInput2, 111 | int pad_size, 112 | int kernel_size, 113 | int max_displacement, 114 | int stride1, 115 | int stride2) 116 | { 117 | // n (batch size), c (num of channels), y (height), x (width) 118 | 119 | int n = item; 120 | int y = blockIdx.x * stride1 + pad_size; 121 | int x = blockIdx.y * stride1 + pad_size; 122 | int c = blockIdx.z; 123 | int tch_off = threadIdx.x; 124 | 125 | int kernel_rad = (kernel_size - 1) / 2; 126 | int displacement_rad = max_displacement / stride2; 127 | int displacement_size = 2 * displacement_rad + 1; 128 | 129 | int xmin = (x - kernel_rad - max_displacement) / stride1; 130 | int ymin = (y - kernel_rad - max_displacement) / stride1; 131 | 132 | int xmax = (x + kernel_rad - max_displacement) / stride1; 133 | int ymax = (y + kernel_rad - max_displacement) / stride1; 134 | 135 | if (xmax < 0 || ymax < 0 || xmin >= 
outputWidth || ymin >= outputHeight) { 136 | // assumes gradInput1 is pre-allocated and zero filled 137 | return; 138 | } 139 | 140 | if (xmin > xmax || ymin > ymax) { 141 | // assumes gradInput1 is pre-allocated and zero filled 142 | return; 143 | } 144 | 145 | xmin = max(0,xmin); 146 | xmax = min(outputWidth-1,xmax); 147 | 148 | ymin = max(0,ymin); 149 | ymax = min(outputHeight-1,ymax); 150 | 151 | int pInputWidth = inputWidth + 2 * pad_size; 152 | int pInputHeight = inputHeight + 2 * pad_size; 153 | 154 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 155 | int pdimxc = pInputWidth * nInputChannels; 156 | int pdimc = nInputChannels; 157 | 158 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 159 | int tdimyx = outputHeight * outputWidth; 160 | int tdimx = outputWidth; 161 | 162 | int odimcyx = nInputChannels * inputHeight* inputWidth; 163 | int odimyx = inputHeight * inputWidth; 164 | int odimx = inputWidth; 165 | 166 | float nelems = kernel_size * kernel_size * nInputChannels; 167 | 168 | __shared__ float prod_sum[THREADS_PER_BLOCK]; 169 | prod_sum[tch_off] = 0; 170 | 171 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 172 | 173 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 174 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 175 | 176 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 177 | 178 | float val2 = rInput2[indx2]; 179 | 180 | for (int j = ymin; j <= ymax; ++j) { 181 | for (int i = xmin; i <= xmax; ++i) { 182 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 183 | prod_sum[tch_off] += gradOutput[tindx] * val2; 184 | } 185 | } 186 | } 187 | __syncthreads(); 188 | 189 | if(tch_off == 0) { 190 | float reduce_sum = 0; 191 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 192 | reduce_sum += prod_sum[idx]; 193 | } 194 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 195 | gradInput1[indx1] = reduce_sum / nelems; 196 | } 197 | 198 | } 199 | 200 | __global__ void Correlation_backward_input2(int item, float *gradInput2, int nInputChannels, int inputHeight, int inputWidth, 201 | float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 202 | float *rInput1, 203 | int pad_size, 204 | int kernel_size, 205 | int max_displacement, 206 | int stride1, 207 | int stride2) 208 | { 209 | // n (batch size), c (num of channels), y (height), x (width) 210 | 211 | int n = item; 212 | int y = blockIdx.x * stride1 + pad_size; 213 | int x = blockIdx.y * stride1 + pad_size; 214 | int c = blockIdx.z; 215 | 216 | int tch_off = threadIdx.x; 217 | 218 | int kernel_rad = (kernel_size - 1) / 2; 219 | int displacement_rad = max_displacement / stride2; 220 | int displacement_size = 2 * displacement_rad + 1; 221 | 222 | int pInputWidth = inputWidth + 2 * pad_size; 223 | int pInputHeight = inputHeight + 2 * pad_size; 224 | 225 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 226 | int pdimxc = pInputWidth * nInputChannels; 227 | int pdimc = nInputChannels; 228 | 229 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 230 | int tdimyx = outputHeight * outputWidth; 231 | int tdimx = outputWidth; 232 | 233 | int odimcyx = nInputChannels * inputHeight* inputWidth; 234 | int odimyx = inputHeight * inputWidth; 235 | int odimx = inputWidth; 236 | 237 | float nelems = kernel_size * kernel_size * nInputChannels; 238 | 239 | __shared__ float prod_sum[THREADS_PER_BLOCK]; 240 | prod_sum[tch_off] = 0; 241 | 242 | for (int tc = 
tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 243 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 244 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 245 | 246 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 247 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 248 | 249 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 250 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 251 | 252 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 253 | // assumes gradInput2 is pre-allocated and zero filled 254 | continue; 255 | } 256 | 257 | if (xmin > xmax || ymin > ymax) { 258 | // assumes gradInput2 is pre-allocated and zero filled 259 | continue; 260 | } 261 | 262 | xmin = max(0,xmin); 263 | xmax = min(outputWidth-1,xmax); 264 | 265 | ymin = max(0,ymin); 266 | ymax = min(outputHeight-1,ymax); 267 | 268 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 269 | float val1 = rInput1[indx1]; 270 | 271 | for (int j = ymin; j <= ymax; ++j) { 272 | for (int i = xmin; i <= xmax; ++i) { 273 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 274 | prod_sum[tch_off] += gradOutput[tindx] * val1; 275 | } 276 | } 277 | } 278 | 279 | __syncthreads(); 280 | 281 | if(tch_off == 0) { 282 | float reduce_sum = 0; 283 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 284 | reduce_sum += prod_sum[idx]; 285 | } 286 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 287 | gradInput2[indx2] = reduce_sum / nelems; 288 | } 289 | 290 | } 291 | 292 | #ifdef __cplusplus 293 | extern "C" { 294 | #endif 295 | 296 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output, 297 | /*THCudaTensor_size(state, output, 0)*/ int ob, 298 | /*THCudaTensor_size(state, output, 1)*/ int oc, 299 | /*THCudaTensor_size(state, output, 2)*/ int oh, 300 | /*THCudaTensor_size(state, output, 3)*/ int ow, 301 | /*THCudaTensor_stride(state, output, 0)*/ int osb, 302 | /*THCudaTensor_stride(state, output, 1)*/ int osc, 303 | /*THCudaTensor_stride(state, output, 2)*/ int osh, 304 | /*THCudaTensor_stride(state, output, 3)*/ int osw, 305 | 306 | /*THCudaTensor_data(state, input1)*/ float *input1, 307 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 308 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 309 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 310 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 311 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 312 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 313 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 314 | 315 | /*THCudaTensor_data(state, input2)*/ float *input2, 316 | /*THCudaTensor_size(state, input2, 1)*/ int gc, 317 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 318 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 319 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 320 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 321 | 322 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1, 323 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2, 324 | int pad_size, 325 | int kernel_size, 326 | int max_displacement, 327 | int stride1, 328 | int stride2, 329 | int corr_type_multiply, 330 | /*THCState_getCurrentStream(state)*/ cudaStream_t stream) 331 | { 332 | int batchSize = ob; 333 | 334 | int nInputChannels = ic; 335 | int inputWidth = iw; 336 | int inputHeight = ih; 337 | 338 | int nOutputChannels = oc; 339 | int outputWidth = ow; 340 
| int outputHeight = oh; 341 | 342 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 343 | dim3 threads_block(THREADS_PER_BLOCK); 344 | 345 | channels_first<<<blocks_grid, threads_block, 0, stream>>> (input1,rInput1, nInputChannels, inputHeight, inputWidth,pad_size); 346 | channels_first<<<blocks_grid, threads_block, 0, stream>>> (input2,rInput2, nInputChannels, inputHeight, inputWidth, pad_size); 347 | 348 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 349 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 350 | 351 | Correlation_forward <<< totalBlocksCorr, threadsPerBlock, 0, stream >>> 352 | (output, nOutputChannels, outputHeight, outputWidth, 353 | rInput1, nInputChannels, inputHeight, inputWidth, 354 | rInput2, 355 | pad_size, 356 | kernel_size, 357 | max_displacement, 358 | stride1, 359 | stride2); 360 | 361 | // check for errors 362 | cudaError_t err = cudaGetLastError(); 363 | if (err != cudaSuccess) { 364 | printf("error in Correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 365 | return 0; 366 | } 367 | 368 | return 1; 369 | } 370 | 371 | int Correlation_backward_cuda_kernel( 372 | /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput, 373 | /*THCudaTensor_size(state, gradOutput, 0)*/ int gob, 374 | /*THCudaTensor_size(state, gradOutput, 1)*/ int goc, 375 | /*THCudaTensor_size(state, gradOutput, 2)*/ int goh, 376 | /*THCudaTensor_size(state, gradOutput, 3)*/ int gow, 377 | /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb, 378 | /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc, 379 | /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh, 380 | /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw, 381 | 382 | /*THCudaTensor_data(state, input1)*/ float* input1, 383 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 384 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 385 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 386 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 387 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 388 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 389 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 390 | 391 | /*THCudaTensor_data(state, input2)*/ float *input2, 392 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 393 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 394 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 395 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 396 | 397 | /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1, 398 | /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb, 399 | /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc, 400 | /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish, 401 | /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw, 402 | 403 | /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2, 404 | /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc, 405 | /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb, 406 | /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc, 407 | /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh, 408 | /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw, 409 | 410 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1, 411 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2, 412 | int pad_size, 413 | int kernel_size, 414 | int max_displacement, 415 | int stride1, 416 | int stride2, 417 | int corr_type_multiply, 418 | /*THCState_getCurrentStream(state)*/cudaStream_t stream) 419 | { 420 | 421 | int batchSize = gob; 422 | int num = batchSize; 423 | 424 | int nInputChannels = ic; 425 | int inputWidth = iw; 426 | int inputHeight = ih; 427 | 428 | int
nOutputChannels = goc; 429 | int outputWidth = gow; 430 | int outputHeight = goh; 431 | 432 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 433 | dim3 threads_block(THREADS_PER_BLOCK); 434 | 435 | channels_first<<<blocks_grid, threads_block, 0, stream>>> (input1, rInput1, nInputChannels,inputHeight, inputWidth, pad_size); 436 | channels_first<<<blocks_grid, threads_block, 0, stream>>> (input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size); 437 | 438 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 439 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 440 | 441 | for (int n = 0; n < num; ++n) { 442 | Correlation_backward_input1 <<< totalBlocksCorr, threadsPerBlock, 0, stream >>> ( 443 | n, gradInput1, nInputChannels, inputHeight, inputWidth, 444 | gradOutput, nOutputChannels, outputHeight, outputWidth, 445 | rInput2, 446 | pad_size, 447 | kernel_size, 448 | max_displacement, 449 | stride1, 450 | stride2); 451 | } 452 | 453 | for(int n = 0; n < batchSize; n++) { 454 | Correlation_backward_input2<<<totalBlocksCorr, threadsPerBlock, 0, stream>>>( 455 | n, gradInput2, nInputChannels, inputHeight, inputWidth, 456 | gradOutput, nOutputChannels, outputHeight, outputWidth, 457 | rInput1, 458 | pad_size, 459 | kernel_size, 460 | max_displacement, 461 | stride1, 462 | stride2); 463 | } 464 | 465 | // check for errors 466 | cudaError_t err = cudaGetLastError(); 467 | if (err != cudaSuccess) { 468 | printf("error in Correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 469 | return 0; 470 | } 471 | 472 | return 1; 473 | } 474 | 475 | #ifdef __cplusplus 476 | } 477 | #endif 478 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output, 6 | /*THCudaTensor_size(state, output, 0)*/ int ob, 7 | /*THCudaTensor_size(state, output, 1)*/ int oc, 8 | /*THCudaTensor_size(state, output, 2)*/ int oh, 9 | /*THCudaTensor_size(state, output, 3)*/ int ow, 10 | /*THCudaTensor_stride(state, output, 0)*/ int osb, 11 | /*THCudaTensor_stride(state, output, 1)*/ int osc, 12 | /*THCudaTensor_stride(state, output, 2)*/ int osh, 13 | /*THCudaTensor_stride(state, output, 3)*/ int osw, 14 | 15 | /*THCudaTensor_data(state, input1)*/ float *input1, 16 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 17 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 18 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 19 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 20 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 21 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 22 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 23 | 24 | /*THCudaTensor_data(state, input2)*/ float *input2, 25 | /*THCudaTensor_size(state, input2, 1)*/ int gc, 26 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 27 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 28 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 29 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 30 | 31 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1, 32 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2, 33 | int pad_size, 34 | int kernel_size, 35 | int max_displacement, 36 | int stride1, 37 | int stride2, 38 | int corr_type_multiply, 39 | /*THCState_getCurrentStream(state)*/ cudaStream_t stream); 40 | 41 | int Correlation_backward_cuda_kernel( 42 | /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput, 43 | /*THCudaTensor_size(state, gradOutput,
0)*/ int gob, 44 | /*THCudaTensor_size(state, gradOutput, 1)*/ int goc, 45 | /*THCudaTensor_size(state, gradOutput, 2)*/ int goh, 46 | /*THCudaTensor_size(state, gradOutput, 3)*/ int gow, 47 | /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb, 48 | /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc, 49 | /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh, 50 | /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw, 51 | 52 | /*THCudaTensor_data(state, input1)*/ float* input1, 53 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 54 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 55 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 56 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 57 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 58 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 59 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 60 | 61 | /*THCudaTensor_data(state, input2)*/ float *input2, 62 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 63 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 64 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 65 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 66 | 67 | /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1, 68 | /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb, 69 | /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc, 70 | /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish, 71 | /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw, 72 | 73 | /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2, 74 | /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc, 75 | /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb, 76 | /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc, 77 | /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh, 78 | /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw, 79 | 80 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1, 81 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2, 82 | int pad_size, 83 | int kernel_size, 84 | int max_displacement, 85 | int stride1, 86 | int stride2, 87 | int corr_type_multiply, 88 | /*THCState_getCurrentStream(state)*/cudaStream_t stream); 89 | 90 | #ifdef __cplusplus 91 | } 92 | #endif 93 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/correlation_package/src/correlation_cuda_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/correlation_package/src/correlation_cuda_kernel.o -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/__init__.pyc -------------------------------------------------------------------------------- 
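As with the correlation package, a minimal usage sketch (again not code from this repo) for the op defined by the files below, assuming the extension has been built and flow_inference/networks_flow is on sys.path. Resample2d bilinearly warps a tensor by a per-pixel 2-channel flow field:

    import torch
    from torch.autograd import Variable

    from resample2d_package.modules.resample2d import Resample2d

    warp = Resample2d()  # kernel_size=1, i.e. plain bilinear sampling

    image = Variable(torch.randn(1, 3, 48, 64).cuda())               # e.g. frame t+1
    flow = Variable(torch.randn(1, 2, 48, 64).cuda().contiguous())   # (dx, dy) per pixel

    # Samples image at (x + dx, y + dy); the output takes flow's batch and
    # spatial size and image's channel count (see Resample2dFunction.forward below).
    warped = warp(image, flow)   # -> (1, 3, 48, 64)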
/flow_inference/networks_flow/resample2d_package/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/_ext/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/_ext/resample2d/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._resample2d import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/_ext/resample2d/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/_ext/resample2d/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/_ext/resample2d/_resample2d.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/_ext/resample2d/_resample2d.so -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.ffi 4 | 5 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' 6 | 7 | Headers = [] 8 | Sources = [] 9 | Defines = [] 10 | Objects = [] 11 | 12 | if torch.cuda.is_available() == True: 13 | Headers += ['src/Resample2d_cuda.h'] 14 | Sources += ['src/Resample2d_cuda.c'] 15 | Defines += [('WITH_CUDA', None)] 16 | Objects += ['src/Resample2d_kernel.o'] 17 | 18 | ffi = torch.utils.ffi.create_extension( 19 | name='_ext.resample2d', 20 | headers=Headers, 21 | sources=Sources, 22 | verbose=False, 23 | with_cuda=True, 24 | package=False, 25 | relative_to=this_folder, 26 
| define_macros=Defines, 27 | extra_objects=[os.path.join(this_folder, Object) for Object in Objects] 28 | ) 29 | 30 | if __name__ == '__main__': 31 | ffi.build() -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/functions/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/functions/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/__pycache__/resample2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/functions/__pycache__/resample2d.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/resample2d.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from .._ext import resample2d 3 | 4 | 5 | class Resample2dFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, input2, kernel_size=1): 9 | assert input1.is_contiguous() 10 | assert input2.is_contiguous() 11 | 12 | ctx.save_for_backward(input1, input2) 13 | ctx.kernel_size = kernel_size 14 | 15 | _, d, _, _ = input1.size() 16 | b, _, h, w = input2.size() 17 | output = input1.new(b, d, h, w).zero_() 18 | 19 | resample2d.Resample2d_cuda_forward(input1, input2, output, kernel_size) 20 | 21 | return output 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | assert grad_output.is_contiguous() 26 | 27 | input1, input2 = ctx.saved_tensors 28 | 29 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 30 | grad_input2 = Variable(input1.new(input2.size()).zero_()) 31 | 32 | resample2d.Resample2d_cuda_backward(input1, input2, grad_output.data, 33 | grad_input1.data, grad_input2.data, 34 | ctx.kernel_size) 35 | 36 | return grad_input1, grad_input2, None 37 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/functions/resample2d.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/functions/resample2d.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 3 | 4 | cd src 5 | echo "Compiling resample2d kernels by nvcc..." 6 | rm Resample2d_kernel.o 7 | rm -r ../_ext 8 | 9 | nvcc -c -o Resample2d_kernel.o Resample2d_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC 10 | 11 | cd ../ 12 | python build.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/modules/__init__.py -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/modules/__init__.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/__pycache__/resample2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/modules/__pycache__/resample2d.cpython-36.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/resample2d.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | 3 | from ..functions.resample2d import Resample2dFunction 4 | 5 | 6 | class Resample2d(Module): 7 | 8 | def __init__(self, kernel_size=1): 9 | super(Resample2d, self).__init__() 10 | self.kernel_size = kernel_size 11 | 12 | def forward(self, input1, input2): 13 | input1_c = input1.contiguous() 14 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size) 15 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/modules/resample2d.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/modules/resample2d.pyc -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/src/Resample2d_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC.h> 2 | #include <THCGeneral.h> 3 | 4 | #include "Resample2d_kernel.h" 5 | 6 | extern THCState* state; 7 | 8 | int Resample2d_cuda_forward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size) { 9 | Resample2d_kernel_forward(state, input1, input2, output, kernel_size); 10 | return 1; 11 | } 12 | 13 | int Resample2d_cuda_backward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size) { 14 | Resample2d_kernel_backward(state, input1, input2, gradOutput, gradInput1, gradInput2, kernel_size); 15 | 16 | return 1; 17 | } -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/src/Resample2d_cuda.h: -------------------------------------------------------------------------------- 1 | int Resample2d_cuda_forward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size); 2 | 3 | int Resample2d_cuda_backward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size); -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/src/Resample2d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <THC.h> 2 | #include <THCGeneral.h> 3 | #include <stdio.h> 4 | #include <iostream> 5 | 6 | #define CUDA_NUM_THREADS 512 7 | #define THREADS_PER_BLOCK 64 8 | 9 | #define DIM0(TENSOR) ((TENSOR).x) 10 | #define DIM1(TENSOR) ((TENSOR).y) 11 | #define DIM2(TENSOR) ((TENSOR).z) 12 | #define DIM3(TENSOR) ((TENSOR).w) 13 | 14 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | __global__ void kernel_Resample2d_updateOutput(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, 21 | const float* input2, const long4 input2_size, const long4 input2_stride, float* output, const long4 output_size, const long4 output_stride, int kernel_size) { 22 | int index = blockIdx.x * blockDim.x + threadIdx.x; 23 | 24 | if (index >= n) { 25 | return; 26 | } 27 | 28 | float val = 0.0; 29 | 30 | int dim_b = DIM0(output_size); 31 | int dim_c = DIM1(output_size); 32 | int dim_h = DIM2(output_size); 33 | int dim_w = DIM3(output_size); 34 | int dim_chw = dim_c * dim_h * dim_w; 35 | int dim_hw = dim_h * dim_w; 36 | 37 | int b = ( index / dim_chw ) % dim_b; 38 | int c = ( index / dim_hw ) % dim_c; 39 | int y = ( index / dim_w ) % dim_h; 40 | int x = ( index ) % dim_w; 41 | 42 | float dx = DIM3_INDEX(input2, b, 0, y, x); 43 | float dy = DIM3_INDEX(input2, b, 1, y, x); 44 | 45 | float xf = float(x) + dx; 46 | float yf = float(y) + dy; 47 | float alpha = xf - floor(xf); // alpha 48 | float beta = yf - floor(yf); // beta 49 | 50 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 51 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 52 | int yT = max(min( int (floor(yf)),
dim_h-1), 0); 53 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 54 | 55 | for (int fy = 0; fy < kernel_size; fy += 1) { 56 | for (int fx = 0; fx < kernel_size; fx += 1) { 57 | val += (1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx); 58 | val += (alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx); 59 | val += (1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx); 60 | val += (alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx); 61 | } 62 | } 63 | 64 | output[index] = val; 65 | 66 | } 67 | 68 | 69 | __global__ void kernel_Resample2d_backward_input1( 70 | const int n, const float* input1, const long4 input1_size, const long4 input1_stride, const float* input2, const long4 input2_size, const long4 input2_stride, 71 | const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { 72 | 73 | int index = blockIdx.x * blockDim.x + threadIdx.x; 74 | 75 | if (index >= n) { 76 | return; 77 | } 78 | 79 | int dim_b = DIM0(gradOutput_size); 80 | int dim_c = DIM1(gradOutput_size); 81 | int dim_h = DIM2(gradOutput_size); 82 | int dim_w = DIM3(gradOutput_size); 83 | int dim_chw = dim_c * dim_h * dim_w; 84 | int dim_hw = dim_h * dim_w; 85 | 86 | int b = ( index / dim_chw ) % dim_b; 87 | int c = ( index / dim_hw ) % dim_c; 88 | int y = ( index / dim_w ) % dim_h; 89 | int x = ( index ) % dim_w; 90 | 91 | float dx = DIM3_INDEX(input2, b, 0, y, x); 92 | float dy = DIM3_INDEX(input2, b, 1, y, x); 93 | 94 | float xf = float(x) + dx; 95 | float yf = float(y) + dy; 96 | float alpha = xf - int(xf); // alpha 97 | float beta = yf - int(yf); // beta 98 | 99 | int idim_h = DIM2(input1_size); 100 | int idim_w = DIM3(input1_size); 101 | 102 | int xL = max(min( int (floor(xf)), idim_w-1), 0); 103 | int xR = max(min( int (floor(xf)+1), idim_w -1), 0); 104 | int yT = max(min( int (floor(yf)), idim_h-1), 0); 105 | int yB = max(min( int (floor(yf)+1), idim_h-1), 0); 106 | 107 | for (int fy = 0; fy < kernel_size; fy += 1) { 108 | for (int fx = 0; fx < kernel_size; fx += 1) { 109 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 110 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 111 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 112 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 113 | } 114 | } 115 | 116 | } 117 | 118 | __global__ void kernel_Resample2d_backward_input2( 119 | const int n, const float* input1, const long4 input1_size, const long4 input1_stride, const float* input2, const long4 input2_size, const long4 input2_stride, 120 | const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { 121 | 122 | int index = blockIdx.x * blockDim.x + threadIdx.x; 123 | 124 | if (index >= n) { 125 | return; 126 | } 127 | 128 | float output = 0.0; 129 | int kernel_rad = (kernel_size - 1)/2; 130 | 131 | int dim_b = DIM0(gradInput_size); 132 | int dim_c = DIM1(gradInput_size); 133 | int dim_h = DIM2(gradInput_size); 134 | int dim_w = DIM3(gradInput_size); 135 | int dim_chw = dim_c * dim_h * dim_w; 136 | int dim_hw = dim_h * dim_w; 137 | 138 | int b = ( index / 
dim_chw ) % dim_b; 139 | int c = ( index / dim_hw ) % dim_c; 140 | int y = ( index / dim_w ) % dim_h; 141 | int x = ( index ) % dim_w; 142 | 143 | int odim_c = DIM1(gradOutput_size); 144 | 145 | float dx = DIM3_INDEX(input2, b, 0, y, x); 146 | float dy = DIM3_INDEX(input2, b, 1, y, x); 147 | 148 | float xf = float(x) + dx; 149 | float yf = float(y) + dy; 150 | 151 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 152 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 153 | int yT = max(min( int (floor(yf)), dim_h-1), 0); 154 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 155 | 156 | if (c % 2) { 157 | float gamma = 1 - (xf - floor(xf)); // alpha 158 | for (int i = 0; i <= 2*kernel_rad; ++i) { 159 | for (int j = 0; j <= 2*kernel_rad; ++j) { 160 | for (int ch = 0; ch < odim_c; ++ch) { 161 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 162 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 163 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 164 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 165 | } 166 | } 167 | } 168 | } 169 | else { 170 | float gamma = 1 - (yf - floor(yf)); // alpha 171 | for (int i = 0; i <= 2*kernel_rad; ++i) { 172 | for (int j = 0; j <= 2*kernel_rad; ++j) { 173 | for (int ch = 0; ch < odim_c; ++ch) { 174 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 175 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 176 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 177 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 178 | } 179 | } 180 | } 181 | 182 | } 183 | 184 | gradInput[index] = output; 185 | 186 | } 187 | 188 | void Resample2d_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size) { 189 | int n = 0; 190 | 191 | const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); 192 | const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); 193 | 194 | const long4 input2_size = make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]); 195 | const long4 input2_stride = make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]); 196 | 197 | const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); 198 | const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); 199 | 200 | n = THCudaTensor_nElement(state, output); 201 | kernel_Resample2d_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( 202 | n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, 203 | THCudaTensor_data(state, output), output_size, output_stride, kernel_size); 204 | 205 | THCudaCheck(cudaGetLastError()); 206 | } 207 | 208 | void Resample2d_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* 
gradInput2, int kernel_size) { 209 | int n = 0; 210 | 211 | const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); 212 | const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); 213 | 214 | const long4 input2_size = make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]); 215 | const long4 input2_stride = make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]); 216 | 217 | const long4 gradOutput_size = make_long4(gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]); 218 | const long4 gradOutput_stride = make_long4(gradOutput->stride[0], gradOutput->stride[1], gradOutput->stride[2], gradOutput->stride[3]); 219 | 220 | const long4 gradInput1_size = make_long4(gradInput1->size[0], gradInput1->size[1], gradInput1->size[2], gradInput1->size[3]); 221 | const long4 gradInput1_stride = make_long4(gradInput1->stride[0], gradInput1->stride[1], gradInput1->stride[2], gradInput1->stride[3]); 222 | 223 | n = THCudaTensor_nElement(state, gradOutput); 224 | kernel_Resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( 225 | n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, 226 | THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput1), gradInput1_size, gradInput1_stride, kernel_size 227 | ); 228 | 229 | const long4 gradInput2_size = make_long4(gradInput2->size[0], gradInput2->size[1], gradInput2->size[2], gradInput2->size[3]); 230 | const long4 gradInput2_stride = make_long4(gradInput2->stride[0], gradInput2->stride[1], gradInput2->stride[2], gradInput2->stride[3]); 231 | 232 | n = THCudaTensor_nElement(state, gradInput2); 233 | kernel_Resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( 234 | n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, 235 | THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput2), gradInput2_size, gradInput2_stride, kernel_size 236 | ); 237 | THCudaCheck(cudaGetLastError()); 238 | } 239 | 240 | #ifdef __cplusplus 241 | } 242 | #endif -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/src/Resample2d_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | void Resample2d_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size); 6 | 7 | void Resample2d_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif -------------------------------------------------------------------------------- /flow_inference/networks_flow/resample2d_package/src/Resample2d_kernel.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/resample2d_package/src/Resample2d_kernel.o -------------------------------------------------------------------------------- /flow_inference/networks_flow/submodules.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | 7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 8 | if batchNorm: 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), 11 | nn.BatchNorm2d(out_planes), 12 | nn.LeakyReLU(0.1,inplace=True) 13 | ) 14 | else: 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 17 | nn.LeakyReLU(0.1,inplace=True) 18 | ) 19 | 20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): 21 | if batchNorm: 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 29 | ) 30 | 31 | def predict_flow(in_planes): 32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 33 | 34 | def deconv(in_planes, out_planes): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 37 | nn.LeakyReLU(0.1,inplace=True) 38 | ) 39 | 40 | class tofp16(nn.Module): 41 | def __init__(self): 42 | super(tofp16, self).__init__() 43 | 44 | def forward(self, input): 45 | return input.half() 46 | 47 | 48 | class tofp32(nn.Module): 49 | def __init__(self): 50 | super(tofp32, self).__init__() 51 | 52 | def forward(self, input): 53 | return input.float() 54 | 55 | 56 | def init_deconv_bilinear(weight): 57 | f_shape = weight.size() 58 | heigh, width = f_shape[-2], f_shape[-1] 59 | f = np.ceil(width/2.0) 60 | c = (2 * f - 1 - f % 2) / (2.0 * f) 61 | bilinear = np.zeros([heigh, width]) 62 | for x in range(width): 63 | for y in range(heigh): 64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) 65 | bilinear[x, y] = value 66 | weight.data.fill_(0.) 
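# A worked note on the copy loop that follows (illustrative values, not from
# the source): the same 2D bilinear stencil is written into every
# (out_channel, in_channel) slice of the deconv weight. For a 4-wide kernel,
# f = ceil(4/2) = 2 and c = (2*2 - 1 - 2%2) / (2*2) = 0.75, so the per-axis
# factor (1 - |x/f - c|) evaluates to:
#   x = 0, 1, 2, 3  ->  0.25, 0.75, 0.75, 0.25
# and the outer product of that vector with itself is the standard bilinear
# upsampling kernel. Note that bilinear[x, y] above swaps row and column
# indices, which is harmless here because the stencil is symmetric.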
67 | for i in range(f_shape[0]): 68 | for j in range(f_shape[1]): 69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear) 70 | 71 | 72 | def save_grad(grads, name): 73 | def hook(grad): 74 | grads[name] = grad 75 | return hook 76 | 77 | ''' 78 | def save_grad(grads, name): 79 | def hook(grad): 80 | grads[name] = grad 81 | return hook 82 | import torch 83 | from channelnorm_package.modules.channelnorm import ChannelNorm 84 | model = ChannelNorm().cuda() 85 | grads = {} 86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True) 87 | a.register_hook(save_grad(grads, 'a')) 88 | b = model(a) 89 | y = torch.mean(b) 90 | y.backward() 91 | 92 | ''' 93 | -------------------------------------------------------------------------------- /flow_inference/networks_flow/submodules.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/flow_inference/networks_flow/submodules.pyc -------------------------------------------------------------------------------- /ft_davis.sh: -------------------------------------------------------------------------------- 1 | name=ft_davis 2 | echo $name 3 | tgt_dir=Outputs/$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 train_davis.py \ 8 | --root-data='data/davis2017/trainval' \ 9 | --root-all-data='data/davis2017/trainval' \ 10 | --meta-list='data/davis2017/trainval/train_meta.json' \ 11 | --restore='checkpoints/train_ytv/model_4.pth' \ 12 | --finetune \ 13 | --epoch=100 \ 14 | --random-ref \ 15 | --random-crop \ 16 | --lr-atn \ 17 | --loss-iou-maxmin \ 18 | --batch-size=1 \ 19 | --start-epoch=0 \ 20 | --sample-size=8 \ 21 | --lr=1e-6 \ 22 | --gpu='3' \ 23 | --sample-dir=$tgt_dir'/sample' \ 24 | --snapshot-dir=$tgt_dir'/snapshot' \ 25 | --fix-lr=0 \ 26 | 2>&1 | tee $tgt_dir/train.log 27 | -------------------------------------------------------------------------------- /infer_davis.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import os.path as osp 5 | import time 6 | import logging 7 | import numpy as np 8 | import argparse 9 | from copy import deepcopy 10 | from tqdm import tqdm 11 | 12 | import torch 13 | assert torch.__version__ == '0.4.0' 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | from torch.utils import data 19 | import torch.nn.functional as F 20 | from torch.autograd import Variable 21 | 22 | from dataset.vos import Valset 23 | from networks.agssvos import AGSSVOS 24 | import sys 25 | sys.path.append('./flow_inference') 26 | from flow_inference.flow_inference import Inference_flow 27 | from tools import preprocess, visualize, utils 28 | import timeit 29 | import cv2 30 | 31 | 32 | def get_parser(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--batch-size', type=int) 35 | parser.add_argument('--root-data', type=str) 36 | parser.add_argument('--root-all-data', type=str) 37 | parser.add_argument('--list-path', type=str) 38 | parser.add_argument('--epoch', type=int) 39 | parser.add_argument('--start-epoch', type=int, default=0) 40 | parser.add_argument('--sample-size', type=int, default=10) 41 | parser.add_argument('--gpu', type=str, default='0') 42 | parser.add_argument('--lr', type=float) 43 | parser.add_argument('--fix-size', action='store_true') 44 | parser.add_argument('--stop-iou', 
type=float, default=0.8) 45 | parser.add_argument('--restore', type=str, default=None) 46 | parser.add_argument('--sample-dir', type=str) 47 | parser.add_argument('--snapshot-dir', type=str) 48 | parser.add_argument('--crop_size', type=int, default=512) 49 | parser.add_argument('--resize_h', type=int, default=360) 50 | parser.add_argument('--resize_w', type=int, default=640) 51 | parser.add_argument('--rgb_max', type=float, default=255.) 52 | parser.add_argument('--div_flow', type=int, default=20) 53 | parser.add_argument('--ignore_label', type=int, default=255) 54 | parser.add_argument('--scale_min', type=float, default=0.5, help='minimum random scale') 55 | parser.add_argument('--scale_max', type=float, default=2.0, help='maximum random scale') 56 | parser.add_argument('--rotate_min', type=float, default=-10, help='minimum random rotate') 57 | parser.add_argument('--rotate_max', type=float, default=10, help='maximum random rotate') 58 | parser.add_argument('--flow_checkpoint_path', type=str, default='models/FlowNet2-C_checkpoint.pth.tar', 59 | help='pretrained model for flownetC') 60 | parser.add_argument('--save-dir', type=str) 61 | parser.add_argument('--test-mode', type=str) 62 | parser.add_argument('--spec-vid', type=str, default=None) 63 | parser.add_argument('--spec-obj-ind', type=str, default=None) 64 | parser.add_argument('--show-img', action='store_true', help='whether save the visualized image result') 65 | return parser 66 | 67 | # get logger 68 | def get_logger(): 69 | logger = logging.getLogger() 70 | logger.setLevel(logging.INFO) 71 | handler = logging.StreamHandler() 72 | fmt = "[%(asctime)s line %(lineno)d] %(message)s" 73 | handler.setFormatter(logging.Formatter(fmt)) 74 | logger.addHandler(handler) 75 | return logger 76 | 77 | def show(images, labels, preds): 78 | os.system('rm %s/*' % args.sample_dir) 79 | for i_bs in range(images.shape[0]): 80 | for j_bs in range(labels.shape[1]): 81 | path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#' 82 | image = visualize.denorm_image(images[i_bs, j_bs, :]) 83 | label = visualize.vis_label(labels[i_bs, j_bs], 1, 128) 84 | cv2.imwrite(path + 'img.jpg', image) 85 | cv2.imwrite(path + 'lab.jpg', label) 86 | 87 | for j_bs in range(preds.shape[1]): 88 | path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#' 89 | pred = preds[i_bs,j_bs]*255 90 | cv2.imwrite(path + 'pred.jpg', pred) 91 | 92 | def main(): 93 | global args, logger, writer 94 | args = get_parser().parse_args() 95 | logger = get_logger() 96 | print(args) 97 | 98 | test_mode = int(args.test_mode) > 0 99 | 100 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 101 | 102 | # setting up model 103 | model = AGSSVOS().cuda() 104 | model = torch.nn.DataParallel(model).cuda() 105 | model.eval() 106 | flow_infer = Inference_flow(args, train_flow=False, resume=args.restore) 107 | 108 | for m in model.module.Encoder.modules(): 109 | if isinstance(m, nn.BatchNorm2d): 110 | m.eval() 111 | for p in m.parameters(): 112 | p.requires_grad = False 113 | 114 | 115 | if args.restore != None: 116 | assert os.path.isfile(args.restore), "no restore file found at %s" % (args.restore) 117 | logger.info("loading from %s" % (args.restore)) 118 | 119 | checkpoint = torch.load(args.restore)['seg'] 120 | model.load_state_dict(checkpoint) 121 | 122 | del checkpoint 123 | torch.cuda.empty_cache() 124 | 125 | spec_vid = args.spec_vid if not test_mode else None 126 | spec_obj_ind = [1] if not test_mode else None 127 | # print('spec_vid, spec_obj_ind', spec_vid, spec_obj_ind) 128 | 129 | 
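# A note on the batch layout consumed below (shapes inferred from the
# unpacking and the inline comments in the loop, so treat this as a sketch):
#   img:  1xTx3xHxW normalized frames, expanded to KxTx3xHxW for K objects
#   lab:  KxHxW binary first-appearance masks, one per object
#   ori_img: THW3 un-normalized frames, used only for flow estimation
#   obj_start_idx: frame index at which each object's annotation begins
# e.g. a hypothetical 2-object, 40-frame clip at 360x640 would give
#   img.shape == (2, 40, 3, 360, 640) after expand, lab.shape == (2, 360, 640)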
testloader = data.DataLoader( 130 | Valset(root_data=args.root_data, root_all_data=args.root_all_data, json_meta_list=args.list_path, 131 | sample_size=args.sample_size, fix_size=args.fix_size, half_size=False, 132 | test_mode=True, spec_vid=spec_vid, spec_obj_ind=spec_obj_ind), 133 | batch_size=args.batch_size, shuffle=False, num_workers=1, pin_memory=True) 134 | 135 | # training 136 | tot_iter = len(testloader) 137 | logger.info("Total iteration per epoch is %d" % (tot_iter)) 138 | vid_time = [] 139 | vid_time_cv2 = [] 140 | vid_frm_num = [] 141 | 142 | for i_iter, batch in enumerate(testloader): 143 | img, lab, ori_img, vid_name, min_idx, obj_num, obj_ind, obj_start_idx, ori_shape = batch 144 | 145 | vid_name = vid_name[0] 146 | 147 | logger.info('vid_name = %s' % vid_name) 148 | sample_dir = args.sample_dir+'/'+vid_name+'/' 149 | save_dir = args.save_dir+'/'+vid_name+'/' 150 | 151 | if not osp.exists(sample_dir): 152 | os.makedirs(sample_dir) 153 | if not osp.exists(save_dir): 154 | os.makedirs(save_dir) 155 | 156 | start_time = timeit.default_timer() 157 | start_time_cv2 = cv2.getTickCount() 158 | 159 | img = img.cuda().float() 160 | lab = lab[0].cuda().float() # KHW 161 | img = img.expand(lab.shape[0],-1,-1,-1,-1) # KT3HW 162 | ori_img = ori_img[0].numpy() # THW3 163 | min_idx = min_idx.item() 164 | obj_num = obj_num.item() 165 | obj_ind = [int(a[0]) for a in obj_ind] 166 | obj_start_idx = [a.item() for a in obj_start_idx] 167 | ori_shape = (ori_shape[0].item(), ori_shape[1].item()) 168 | 169 | preds = [] 170 | preds.append(lab[:,0:1].contiguous()) 171 | 172 | save_lab = torch.zeros(img.shape[1], obj_num+1, img.shape[-2], img.shape[-1]).cuda().float() 173 | save_lab[:,0,:] += 0.5 174 | 175 | obj_start_idx_set = set(obj_start_idx) 176 | for start in range(img.shape[1]): 177 | if start not in obj_start_idx_set: 178 | continue 179 | 180 | ref_lab = [] 181 | obj_ind_tmp = [] 182 | for k,s in enumerate(obj_start_idx): 183 | if s == start: 184 | obj_ind_tmp.append(obj_ind[k]) 185 | ref_lab.append(lab[k:k+1]) 186 | ref_lab = torch.cat(ref_lab, dim=0) 187 | ref_img = img[:ref_lab.shape[0],start,:] 188 | 189 | preds = [torch.zeros(ref_lab.shape).cuda().float().unsqueeze(1)] 190 | cnt = 0 191 | for k,s in enumerate(obj_start_idx): 192 | if s == start: 193 | preds[0][cnt,0] = lab[k] 194 | save_lab[s,obj_ind[k]] = lab[k] 195 | cnt += 1 196 | 197 | ref_lab = ref_lab.max(0,keepdim=True)[0] 198 | with torch.no_grad(): 199 | ms = model.forward(ref_img, ref_lab) 200 | 201 | for i in tqdm(range(start+1, img.shape[1])): 202 | with torch.no_grad(): 203 | flow = flow_infer.infer(ori_img[i], ori_img[i-1]) 204 | 205 | prev_lab = utils.flow_warp_tensor(preds[-1], flow) 206 | 207 | merge_preds = prev_lab.max(0)[0] 208 | with torch.no_grad(): 209 | output, _ = model.forward(img[:ref_lab.shape[0],i], merge_preds, prev_lab.squeeze(1), ref=ms) 210 | 211 | output = output.detach() 212 | 213 | preds.append(output.contiguous()) 214 | for idx,ind in enumerate(obj_ind_tmp): 215 | save_lab[i,ind,:] = output[idx,0] 216 | end_time = timeit.default_timer() 217 | end_time_cv2 = cv2.getTickCount() 218 | 219 | vid_time.append(end_time-start_time) 220 | vid_time_cv2.append((end_time_cv2-start_time_cv2)/cv2.getTickFrequency()) 221 | vid_frm_num.append(img.shape[1]) 222 | 223 | step = 1 224 | for i,lab in enumerate(save_lab): 225 | if i % step == 0: 226 | img = cv2.resize(ori_img[i], (ori_shape[1],ori_shape[0]), interpolation=cv2.INTER_LINEAR) 227 | visualize.show_save_lab_savesmall(i+min_idx, img, lab.cpu().numpy(), 
sample_dir, save_dir, 228 | show=args.show_img, save=True) 229 | 230 | logger.info(('sum', sum(vid_time), sum(vid_time_cv2), sum(vid_frm_num))) 231 | logger.info(('time', sum(vid_time)/sum(vid_frm_num), sum(vid_time_cv2)/sum(vid_frm_num))) 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /infer_ytv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import os.path as osp 5 | import time 6 | import logging 7 | import numpy as np 8 | import argparse 9 | from copy import deepcopy 10 | from tqdm import tqdm 11 | 12 | import torch 13 | assert torch.__version__ == '0.4.0' 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | from torch.utils import data 19 | import torch.nn.functional as F 20 | from torch.autograd import Variable 21 | 22 | from dataset.vos import Valset 23 | from networks.agssvos import AGSSVOS 24 | sys.path.append('flow_inference') 25 | from flow_inference.flow_inference import Inference_flow 26 | from tools import preprocess, visualize, utils 27 | import timeit 28 | import cv2 29 | 30 | 31 | def get_parser(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--batch-size', type=int) 34 | parser.add_argument('--root-data', type=str, help='sampled video') 35 | parser.add_argument('--root-all-data', type=str, help='full video') 36 | parser.add_argument('--list-path', type=str) 37 | parser.add_argument('--epoch', type=int) 38 | parser.add_argument('--start-epoch', type=int, default=0) 39 | parser.add_argument('--sample-size', type=int, default=10) 40 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu(s)') 41 | parser.add_argument('--lr', type=float) 42 | parser.add_argument('--restore', type=str, default=None) 43 | parser.add_argument('--sample-dir', type=str) 44 | parser.add_argument('--snapshot-dir', type=str) 45 | parser.add_argument('--crop_size', type=int, default=512) 46 | parser.add_argument('--resize_h', type=int, default=360) 47 | parser.add_argument('--resize_w', type=int, default=640) 48 | parser.add_argument('--rgb_max', type=float, default=255.) 
49 | parser.add_argument('--div_flow', type=int, default=20) 50 | parser.add_argument('--ignore_label', type=int, default=255) 51 | parser.add_argument('--scale_min', type=float, default=0.5, help='minimum random scale') 52 | parser.add_argument('--scale_max', type=float, default=2.0, help='maximum random scale') 53 | parser.add_argument('--rotate_min', type=float, default=-10, help='minimum random rotate') 54 | parser.add_argument('--rotate_max', type=float, default=10, help='maximum random rotate') 55 | parser.add_argument('--flow_checkpoint_path', type=str, default='models/FlowNet2-C_checkpoint.pth.tar', 56 | help='pretrained model for flownetC') 57 | parser.add_argument('--save-dir', type=str) 58 | parser.add_argument('--test-mode', type=str) 59 | parser.add_argument('--spec-vid', type=str, default=None) 60 | parser.add_argument('--spec-obj-ind', type=str, default=None) 61 | parser.add_argument('--show-step', type=int, default=5) 62 | parser.add_argument('--flow-scale', type=int, default=1) 63 | parser.add_argument('--flow-model', type=str, default='FlowNet2C') 64 | parser.add_argument('--show-img', action='store_true', help='whether save the visualized image result') 65 | return parser 66 | 67 | ### get logger ### 68 | def get_logger(): 69 | logger = logging.getLogger() 70 | logger.setLevel(logging.INFO) 71 | handler = logging.StreamHandler() 72 | fmt = "[%(asctime)s line %(lineno)d] %(message)s" 73 | handler.setFormatter(logging.Formatter(fmt)) 74 | logger.addHandler(handler) 75 | return logger 76 | 77 | def show(images, labels, preds): 78 | os.system('rm %s/*' % args.sample_dir) 79 | for i_bs in range(images.shape[0]): 80 | for j_bs in range(labels.shape[1]): 81 | path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#' 82 | image = visualize.denorm_image(images[i_bs, j_bs, :]) 83 | label = visualize.vis_label(labels[i_bs, j_bs], 1, 128) 84 | cv2.imwrite(path + 'img.jpg', image) 85 | cv2.imwrite(path + 'lab.jpg', label) 86 | 87 | for j_bs in range(preds.shape[1]): 88 | path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#' 89 | pred = preds[i_bs,j_bs]*255 90 | cv2.imwrite(path + 'pred.jpg', pred) 91 | 92 | def main(): 93 | global args, logger, writer 94 | args = get_parser().parse_args() 95 | logger = get_logger() 96 | 97 | test_mode = int(args.test_mode) > 0 98 | assert test_mode 99 | 100 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 101 | 102 | ## setting up model ## 103 | model = AGSSVOS().cuda() 104 | model = torch.nn.DataParallel(model).cuda() 105 | model.eval() 106 | flow_infer = Inference_flow(args, resume=args.restore) 107 | 108 | for m in model.module.Encoder.modules(): 109 | if isinstance(m, nn.BatchNorm2d): 110 | m.eval() 111 | for p in m.parameters(): 112 | p.requires_grad = False 113 | 114 | if args.restore != None: 115 | assert os.path.isfile(args.restore), "no restore file found at %s" % (args.restore) 116 | logger.info("loading from %s" % (args.restore)) 117 | 118 | checkpoint = torch.load(args.restore) 119 | 120 | state = model.state_dict() 121 | checkpoint = {k: v for k, v in checkpoint['seg'].items() if k in state} 122 | state.update(checkpoint) 123 | model.load_state_dict(state) 124 | 125 | del checkpoint 126 | torch.cuda.empty_cache() 127 | 128 | ### for debug ### 129 | spec_vid = args.spec_vid if not test_mode else None 130 | spec_obj_ind = [1,2,3] if not test_mode else None 131 | # print('spec_vid, spec_obj_ind', spec_vid, spec_obj_ind) 132 | 133 | testloader = data.DataLoader( 134 | Valset(root_data=args.root_data, 
root_all_data=args.root_all_data, json_meta_list=args.list_path, 135 | sample_size=args.sample_size, fix_size=False, half_size=True, 136 | test_mode=True, spec_vid=spec_vid, spec_obj_ind=spec_obj_ind), 137 | batch_size=args.batch_size, shuffle=False, num_workers=1, pin_memory=True) 138 | 139 | 140 | tot_iter = len(testloader) 141 | logger.info("Total iteration per epoch is %d" % (tot_iter)) 142 | vid_time = [] 143 | vid_time_cv2 = [] 144 | vid_frm_num = [] 145 | 146 | for i_iter, batch in enumerate(testloader): 147 | img, lab, ori_img, vid_name, min_idx, obj_num, obj_ind, obj_start_idx, ori_shape = batch 148 | 149 | vid_name = vid_name[0] 150 | 151 | logger.info('vid_name = %s' % vid_name) 152 | sample_dir = args.sample_dir+'/'+vid_name+'/' 153 | save_dir = args.save_dir+'/'+vid_name+'/' 154 | 155 | if not osp.exists(sample_dir): 156 | os.makedirs(sample_dir) 157 | if not osp.exists(save_dir): 158 | os.makedirs(save_dir) 159 | 160 | start_time = timeit.default_timer() 161 | start_time_cv2 = cv2.getTickCount() 162 | 163 | img = img.cuda().float() 164 | lab = lab[0].cuda().float() # KHW 165 | img = img.expand(lab.shape[0],-1,-1,-1,-1) # KT3HW 166 | ori_img = ori_img[0].numpy() # THW3 167 | min_idx = min_idx.item() 168 | obj_num = obj_num.item() 169 | obj_ind = [int(a[0]) for a in obj_ind] 170 | obj_start_idx = [a.item() for a in obj_start_idx] 171 | ori_shape = (ori_shape[0].item(), ori_shape[1].item()) 172 | 173 | preds = [] 174 | preds.append(lab[:,0:1].contiguous()) 175 | 176 | save_lab = torch.zeros(img.shape[1], obj_num+1, img.shape[-2], img.shape[-1]).cuda().float() 177 | save_lab[:,0,:] += 0.5 178 | 179 | obj_start_idx_set = set(obj_start_idx) 180 | assert 0 in obj_start_idx_set 181 | for start in range(img.shape[1]): 182 | if start not in obj_start_idx_set: 183 | continue 184 | ref_lab = [] 185 | obj_ind_tmp = [] 186 | for k,s in enumerate(obj_start_idx): 187 | if s == start: 188 | obj_ind_tmp.append(obj_ind[k]) 189 | ref_lab.append(lab[k:k+1]) 190 | ref_lab = torch.cat(ref_lab, dim=0) 191 | ref_img = img[:ref_lab.shape[0],start,:] 192 | 193 | preds = [torch.zeros(ref_lab.shape).cuda().float().unsqueeze(1)] 194 | cnt = 0 195 | for k,s in enumerate(obj_start_idx): 196 | if s == start: 197 | preds[0][cnt,0] = lab[k] 198 | save_lab[s,obj_ind[k]] = lab[k] 199 | cnt += 1 200 | 201 | ref_lab = ref_lab.max(0,keepdim=True)[0] 202 | with torch.no_grad(): 203 | ms = model.forward(ref_img, ref_lab) 204 | 205 | for i in tqdm(range(start+1, img.shape[1])): 206 | flow = flow_infer.infer(ori_img[i], ori_img[i-1], scale=args.flow_scale) 207 | prev_lab = utils.flow_warp_tensor(preds[-1], flow) 208 | 209 | merge_preds = prev_lab.max(0)[0] 210 | with torch.no_grad(): 211 | output, _ = model.forward(img[:ref_lab.shape[0],i], merge_preds, prev_lab.squeeze(1), ref=ms) 212 | 213 | output = output.detach() 214 | 215 | for idx in range(prev_lab.shape[0]): 216 | if (prev_lab[idx]>0.5).sum()==0: 217 | output[idx,0] *= 0 218 | 219 | preds.append(output.contiguous()) 220 | for idx,ind in enumerate(obj_ind_tmp): 221 | save_lab[i,ind,:] = output[idx,0] 222 | end_time = timeit.default_timer() 223 | end_time_cv2 = cv2.getTickCount() 224 | 225 | vid_time.append(end_time-start_time) 226 | vid_time_cv2.append((end_time_cv2-start_time_cv2)/cv2.getTickFrequency()) 227 | vid_frm_num.append(img.shape[1]) 228 | 229 | ### save & show result ### 230 | step = 5 if test_mode else args.show_step 231 | for i,lab in enumerate(save_lab): 232 | ### the only vid that has special interval ### 233 | if vid_name=='8aa47fac99': 234 | 
step = 3 if i < 80 else 5 235 | if i % step == 0: 236 | img = cv2.resize(ori_img[i], (ori_shape[1],ori_shape[0]), interpolation=cv2.INTER_LINEAR) 237 | visualize.show_save_lab_savesmall(i+min_idx, img, lab.cpu().numpy(), sample_dir, save_dir, 238 | show=args.show_img, save=True) 239 | 240 | logger.info(('sum', sum(vid_time), sum(vid_time_cv2), sum(vid_frm_num))) 241 | logger.info(('time', sum(vid_time)/sum(vid_frm_num), sum(vid_time_cv2)/sum(vid_frm_num))) 242 | 243 | if __name__ == '__main__': 244 | main() 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/networks/__init__.py -------------------------------------------------------------------------------- /networks/agssvos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from torch.utils import data 7 | import torch.utils.model_zoo as model_zoo 8 | from torchvision import models 9 | import logging 10 | import cv2 11 | import timeit 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, init_atn=False, freeze=False): 15 | super(Encoder, self).__init__() 16 | 17 | self.conv1_p = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=True) 18 | 19 | resnet = models.resnet50(pretrained=True) 20 | self.conv1 = resnet.conv1 21 | self.bn1 = resnet.bn1 22 | self.relu = resnet.relu # 1/2, 64 23 | self.maxpool = resnet.maxpool 24 | 25 | self.res2 = resnet.layer1 # 1/4, 256 26 | self.res3 = resnet.layer2 # 1/8, 512 27 | self.res4 = resnet.layer3 # 1/16, 1024 28 | self.res5 = resnet.layer4 # 1/32, 2048 29 | 30 | self.atn = nn.Sequential( 31 | nn.Conv2d(257, 256, kernel_size=1, padding=0, bias=True), 32 | nn.ReLU(True), 33 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True), 34 | nn.ReLU(True), 35 | nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2, bias=True) 36 | ) 37 | 38 | if init_atn: 39 | self._initialize_weights([self.atn]) 40 | 41 | if freeze: 42 | # freeze BNs 43 | for m in self.modules(): 44 | if isinstance(m, nn.BatchNorm2d): 45 | for p in m.parameters(): 46 | p.requires_grad = False 47 | p.track_running_stats = False 48 | 49 | def _initialize_weights(self, mods, zero=False): 50 | for s in mods: 51 | for m in s: 52 | if isinstance(m, nn.Conv2d): 53 | if not zero: 54 | m.weight.data.normal_(0, 0.01) 55 | else: 56 | m.weight.data.zero_() 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.weight.data.normal_(0, 0.01) 64 | m.bias.data.zero_() 65 | elif isinstance(m, nn.ConvTranspose2d): 66 | m.weight.data.zero_() 67 | m.weight.data = interp_surgery(m) 68 | 69 | 70 | def forward(self, in_f, in_p, objr2=False): 71 | f = in_f 72 | p = torch.unsqueeze(in_p, dim=1).float() # add channel dim 73 | x = self.conv1(f) + self.conv1_p(p) 74 | x = self.bn1(x) 75 | c1 = self.relu(x) # 1/2, 64 76 | x = self.maxpool(c1) # 1/4, 64 77 | r2 = self.res2(x) # 1/4, 64 78 | 79 | if objr2: 80 | p_s4 = F.upsample(p, r2.shape[-2:], mode='bilinear', align_corners=True) 81 | r2_atn = self.atn(torch.cat((r2, p_s4), dim=1)) 82 | return r2, r2_atn 83 | 84 | r3 = self.res3(r2) # 
1/8, 128 85 | r4 = self.res4(r3) # 1/16, 256 86 | r5 = self.res5(r4) # 1/32, 512 87 | 88 | return r5, r4, r3, r2 89 | 90 | class GC(nn.Module): 91 | def __init__(self, inplanes, planes, kh=7, kw=7): 92 | super(GC, self).__init__() 93 | self.conv_l1 = nn.Conv2d(inplanes, 256, kernel_size=(kh, 1), 94 | padding=(int(kh/2), 0)) 95 | self.conv_l2 = nn.Conv2d(256, planes, kernel_size=(1, kw), 96 | padding=(0, int(kw/2))) 97 | self.conv_r1 = nn.Conv2d(inplanes, 256, kernel_size=(1, kw), 98 | padding=(0, int(kw/2))) 99 | self.conv_r2 = nn.Conv2d(256, planes, kernel_size=(kh, 1), 100 | padding=(int(kh/2), 0)) 101 | 102 | def forward(self, x): 103 | x_l = self.conv_l2(self.conv_l1(x)) 104 | x_r = self.conv_r2(self.conv_r1(x)) 105 | x = x_l + x_r 106 | return x 107 | 108 | 109 | class Refine(nn.Module): 110 | def __init__(self, inplanes, planes, scale_factor=2): 111 | super(Refine, self).__init__() 112 | self.convFS1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1) 113 | self.convFS2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 114 | self.convFS3 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 115 | self.convMM1 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 116 | self.convMM2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 117 | self.scale_factor = scale_factor 118 | 119 | def forward(self, f, pm): 120 | s = self.convFS1(f) 121 | sr = self.convFS2(F.relu(s)) 122 | sr = self.convFS3(F.relu(sr)) 123 | s = s + sr 124 | if s.shape[-1] == pm.shape[-1]: 125 | m = s + pm 126 | else: 127 | m = s + F.upsample(pm, scale_factor=self.scale_factor, mode='bilinear') 128 | 129 | mr = self.convMM1(F.relu(m)) 130 | mr = self.convMM2(F.relu(mr)) 131 | m = m + mr 132 | return m 133 | 134 | 135 | class Decoder(nn.Module): 136 | def __init__(self, output_dim=1): 137 | super(Decoder, self).__init__() 138 | mdim = 256 139 | self.GC = GC(4096, mdim) # 1/32 -> 1/32 140 | self.convG1 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1) 141 | self.convG2 = nn.Conv2d(mdim, mdim, kernel_size=3, padding=1) 142 | self.RF4 = Refine(1024, mdim) # 1/16 -> 1/8 143 | self.RF3 = Refine(512, mdim) # 1/8 -> 1/4 144 | self.RF2 = Refine(256, mdim) # 1/4 -> 1 145 | 146 | self.pred2 = nn.Conv2d(mdim, output_dim, kernel_size=(3,3), padding=(1,1), stride=1) 147 | 148 | def _p_norm(self, p): 149 | p = F.softmax(p,1) 150 | p = p[:,1:] 151 | bg = 1-p.max(0,keepdim=True)[0] 152 | p = torch.cat((bg,p),dim=0) 153 | p = torch.clamp(p, 1e-7, 1-(1e-7)) 154 | p = p/(1-p) 155 | p = p/(p.sum(0,keepdim=True)) 156 | p = p[1:] 157 | return p 158 | 159 | def forward(self, r5, x5, r4, r3, r2, r2_obj, r2_atn): 160 | x = torch.cat((r5, x5), dim=1) 161 | 162 | x = self.GC(x) 163 | r = self.convG1(F.relu(x)) 164 | r = self.convG2(F.relu(r)) 165 | m5 = x + r # out: 1/32, 64 166 | m4 = self.RF4(r4, m5) # out: 1/16, 64 167 | m3 = self.RF3(r3, m4) # out: 1/8, 64 168 | 169 | m3 = m3.expand(r2_atn.shape[0],-1,-1,-1) 170 | 171 | m3_cat = m3 * r2_atn 172 | 173 | m2_cat = self.RF2(r2_obj, m3_cat) # out: 1/4, 64 174 | 175 | p2 = self.pred2(F.relu(m2_cat)) 176 | p = F.upsample(p2, scale_factor=4, mode='bilinear') 177 | p = self._p_norm(p) 178 | 179 | return p 180 | 181 | 182 | class AGSSVOS(nn.Module): 183 | def __init__(self, output_dim=2, init_atn=True, freeze=True): 184 | super(AGSSVOS, self).__init__() 185 | self.Encoder = Encoder(init_atn=init_atn, freeze=freeze) 186 | self.Decoder = Decoder(output_dim=output_dim) 187 | 188 | def forward(self, x, l_merge, l_obj=None, ref=None): 189 | if ref is None: 190 | r5, r4, r3, r2 = 
self.Encoder.forward(x[0:1], l_merge) 191 | return r5 192 | else: 193 | r5, r4, r3, r2 = self.Encoder.forward(x[0:1], l_merge) 194 | r2_obj, r2_atn = self.Encoder.forward(x, l_obj, objr2=True) 195 | p = self.Decoder.forward(r5, ref, r4, r3, r2, r2_obj, r2_atn) 196 | 197 | return p, r5 198 | 199 | -------------------------------------------------------------------------------- /run_davis.sh: -------------------------------------------------------------------------------- 1 | name=train_davis 2 | echo $name 3 | tgt_dir=Outputs/$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 train_davis.py \ 8 | --root-data='data/davis2017/trainval' \ 9 | --root-all-data='data/davis2017/trainval' \ 10 | --meta-list='data/davis2017/trainval/train_meta.json' \ 11 | --restore='checkpoints/weights.pth' \ 12 | --epoch=200 \ 13 | --random-ref \ 14 | --random-crop \ 15 | --lr-atn \ 16 | --loss-iou-maxmin \ 17 | --batch-size=1 \ 18 | --start-epoch=0 \ 19 | --sample-size=8 \ 20 | --lr=1e-5 \ 21 | --gpu='5' \ 22 | --sample-dir=$tgt_dir'/sample' \ 23 | --snapshot-dir=$tgt_dir'/snapshot' \ 24 | --fix-lr=0 \ 25 | 2>&1 | tee $tgt_dir/train.log 26 | -------------------------------------------------------------------------------- /run_ytv.sh: -------------------------------------------------------------------------------- 1 | name=train_ytv 2 | echo $name 3 | tgt_dir=Outputs/$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 train_ytv.py \ 8 | --root-data='data/youtube_vos/train/' \ 9 | --meta-list='data/youtube_vos/train/meta.json' \ 10 | --restore='checkpoints/weights.pth' \ 11 | --batch-size=1 \ 12 | --start-epoch=0 \ 13 | --epoch=5 \ 14 | --random-ref \ 15 | --lr-atn \ 16 | --loss-iou-maxmin \ 17 | --sample-size=8 \ 18 | --lr=1e-5 \ 19 | --gpu='6' \ 20 | --sample-dir=$tgt_dir'/sample' \ 21 | --snapshot-dir=$tgt_dir'/snapshot' \ 22 | 2>&1 | tee $tgt_dir/train.log 23 | -------------------------------------------------------------------------------- /test_davis.sh: -------------------------------------------------------------------------------- 1 | name=train_davis 2 | echo $name 3 | tgt_dir='test_dir/'$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 infer_davis.py \ 8 | --root-data='data/davis2017/test' \ 9 | --root-all-data='data/davis2017/test' \ 10 | --list-path='data/davis2017/test/test_meta.json' \ 11 | --restore='checkpoints/train_davis/model_199.pth' \ 12 | --batch-size=1 \ 13 | --start-epoch=0 \ 14 | --epoch=1 \ 15 | --sample-size=20 \ 16 | --lr=1e-5 \ 17 | --gpu='2' \ 18 | --sample-dir=$tgt_dir'/sample' \ 19 | --save-dir=$tgt_dir'/save' \ 20 | --test-mode=1 \ 21 | 2>&1 | tee $tgt_dir/val.log 22 | -------------------------------------------------------------------------------- /test_davis_ft.sh: -------------------------------------------------------------------------------- 1 | name=ft_davis 2 | echo $name 3 | tgt_dir='test_dir/'$name 4 | if [ ! 
-d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 infer_davis.py \ 8 | --root-data='data/davis2017/test' \ 9 | --root-all-data='data/davis2017/test' \ 10 | --list-path='data/davis2017/test/test_meta.json' \ 11 | --restore='checkpoints/ft_davis/model_99.pth' \ 12 | --batch-size=1 \ 13 | --start-epoch=0 \ 14 | --epoch=1 \ 15 | --sample-size=20 \ 16 | --lr=1e-6 \ 17 | --gpu='3' \ 18 | --sample-dir=$tgt_dir'/sample' \ 19 | --save-dir=$tgt_dir'/save' \ 20 | --test-mode=1 \ 21 | 2>&1 | tee $tgt_dir/val.log 22 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/AGSS-VOS/e9272365aa45bf098316d7111238fe0ab8df8a17/tools/__init__.py -------------------------------------------------------------------------------- /tools/preprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from scipy.ndimage import distance_transform_edt as Dte 5 | import torch 6 | 7 | 8 | def resize(image, new_size, label=False): 9 | r""" 10 | resize an image so that its longer side matches new_size 11 | :param image: HW3 or HW1, both ok 12 | :param new_size: an int giving the target length of the longer side 13 | :return: resized image 14 | """ 15 | if max(image.shape[0], image.shape[1]) == new_size: 16 | return image.copy() 17 | else: 18 | if image.shape[0] > image.shape[1]: 19 | dh = new_size; dw = dh * image.shape[1] // image.shape[0] 20 | else: 21 | dw = new_size; dh = dw * image.shape[0] // image.shape[1] 22 | if label: 23 | img = cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_NEAREST) 24 | else: 25 | img = cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_LINEAR) 26 | return img 27 | 28 | def resize_tuple(image, new_size, label=False): 29 | r""" 30 | resize an image to fit within new_size while keeping the aspect ratio 31 | :param image: HW3 or HW1, both ok 32 | :param new_size: tuple, (h,w) 33 | :return: resized image 34 | """ 35 | if (image.shape[0], image.shape[1]) == new_size: 36 | return image.copy() 37 | else: 38 | if image.shape[0]*1./new_size[0] > image.shape[1]*1./new_size[1]: 39 | dh = new_size[0] 40 | dw = dh * image.shape[1] // image.shape[0] 41 | else: 42 | dw = new_size[1] 43 | dh = dw * image.shape[0] // image.shape[1] 44 | if label: 45 | img = cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_NEAREST) 46 | else: 47 | img = cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_LINEAR) 48 | return img 49 | 50 | 51 | def resize_square(image, new_size, label=False): 52 | r""" 53 | resize an image so that both sides match the new size 54 | :param image: HW3 or HW1, both ok 55 | :param new_size: an int giving the value of both sides 56 | :return: resized image 57 | """ 58 | if label: 59 | img = cv2.resize(image, dsize=(new_size, new_size), interpolation=cv2.INTER_NEAREST) 60 | else: 61 | img = cv2.resize(image, dsize=(new_size, new_size), interpolation=cv2.INTER_LINEAR) 62 | return img 63 | 64 | def resize_scale(image, scale, label=False): 65 | r""" 66 | downscale an image by an integer factor 67 | :param image: HW3 or HW1, both ok 68 | :param scale: an int divisor applied to both sides 69 | :return: resized image 70 | """ 71 | if image.shape[0] % scale != 0 or image.shape[1] % scale != 0: 72 | assert image.shape[0] % scale == 0 73 | assert image.shape[1] % scale == 0 74 | dh = image.shape[0] // scale 75 | dw = image.shape[1] // scale 76 | if label: 77 | img = 
cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_NEAREST) 78 | else: 79 | img = cv2.resize(image, dsize=(dw, dh), interpolation=cv2.INTER_LINEAR) 80 | return img 81 | 82 | 83 | def crop_tensor(label, lab_idx=1, shift=10): 84 | """ 85 | 86 | :param label: tensor, HW 87 | :param lab_idx: 88 | :param shift: 89 | :return: 90 | """ 91 | coord = torch.nonzero((label>0.5) == lab_idx) 92 | if coord.shape[0] == 0: 93 | return None, None, None, None 94 | h1,h2 = coord[:,0].min(), coord[:,0].max() 95 | w1,w2 = coord[:,1].min(), coord[:,1].max() 96 | det_h1 = det_h2 = det_w1 = det_w2 = shift 97 | h1 = max(h1.item()-det_h1, 0) 98 | h2 = min(h2.item()+det_h2, label.shape[0]-1) 99 | w1 = max(w1.item()-det_w1, 0) 100 | w2 = min(w2.item()+det_w2, label.shape[1]-1) 101 | return h1, h2, w1, w2 102 | 103 | 104 | def crop(label, lab_idx=255, scale=0.2, shift=10, mode='scale', in_scale=0.1, out_scale=1.0, out_shift=300): 105 | r""" 106 | return a croped roi with a given scale/shift 107 | :param label: HW 2D, numpy 108 | :param lab_idx: 109 | :param scale: 110 | :param mode: 'scale' or 'shift' 111 | :return: the leftest coord and rightest coord, closed interval 112 | """ 113 | coord = np.nonzero((label>0.5) == lab_idx) 114 | h1,h2 = coord[0].min(), coord[0].max() 115 | w1,w2 = coord[1].min(), coord[1].max() 116 | if mode == 'scale': 117 | det_h = int((h2-h1)*scale) 118 | det_w = int((w2-w1)*scale) 119 | det_h1 = det_h2 = det_h 120 | det_w1 = det_w2 = det_w 121 | elif mode == 'shift': 122 | det_h = shift 123 | det_w = shift 124 | det_h1 = det_h2 = det_h 125 | det_w1 = det_w2 = det_w 126 | elif mode == 'jitter': 127 | s = random.uniform(-scale, scale) 128 | det_h1 = int((h2-h1)*s) + random.uniform(0, out_shift) 129 | s = random.uniform(-scale, scale) 130 | det_h2 = int((h2-h1)*s) + random.uniform(0, out_shift) 131 | s = random.uniform(-scale, scale) 132 | det_w1 = int((w2-w1)*s) + random.uniform(0, out_shift) 133 | s = random.uniform(-scale, scale) 134 | det_w2 = int((w2-w1)*s) + random.uniform(0, out_shift) 135 | elif mode == 'none': 136 | det_h1 = det_w1 = det_h2 = det_w2 = 0 137 | elif mode == 'jitter_in_out': 138 | s = random.uniform(-in_scale, out_scale) 139 | det_h1 = int((h2-h1)*s) 140 | s = random.uniform(-in_scale, out_scale) 141 | det_h2 = int((h2-h1)*s) 142 | s = random.uniform(-in_scale, out_scale) 143 | det_w1 = int((w2-w1)*s) 144 | s = random.uniform(-in_scale, out_scale) 145 | det_w2 = int((w2-w1)*s) 146 | else: 147 | assert False, mode 148 | h1 = max(h1-det_h1, 0) 149 | h2 = min(h2+det_h2, label.shape[0]-1) 150 | w1 = max(w1-det_w1, 0) 151 | w2 = min(w2+det_w2, label.shape[1]-1) 152 | return h1, h2, w1, w2 153 | 154 | 155 | def crop_wh_ratio(label, lab_idx=255, shift=10, ratio=2): 156 | """ 157 | give a crop box wrt given ratio of wh 158 | :param label: 159 | :param lab_idx: 160 | :param shift: 161 | :param ratio: w:h = ratio:1 162 | :return: 163 | """ 164 | coord = np.nonzero((label>0.5) == lab_idx) 165 | h1,h2 = coord[0].min(), coord[0].max() 166 | w1,w2 = coord[1].min(), coord[1].max() 167 | 168 | det_h = shift 169 | det_w = shift 170 | det_h1 = det_h2 = det_h 171 | det_w1 = det_w2 = det_w 172 | 173 | h1 = max(h1-det_h1, 0) 174 | h2 = min(h2+det_h2, label.shape[0]-1) 175 | w1 = max(w1-det_w1, 0) 176 | w2 = min(w2+det_w2, label.shape[1]-1) 177 | 178 | 179 | return h1, h2, w1, w2 180 | 181 | 182 | def norm(image, 183 | mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225), 184 | scale=255.): 185 | r""" 186 | 187 | :param image: numpy 188 | :param mean: 189 | :param std: 190 | :param scale: 
191 | :return: numpy, float 192 | """ 193 | img = image[:, :, ::-1].astype(np.float) 194 | img /= scale 195 | img -= mean 196 | img /= std 197 | return img 198 | 199 | def prepare(img, label=False): 200 | """ 201 | 202 | :param img: img(3d) or lab(2d), np 203 | :param label: whether label 204 | :return: 205 | """ 206 | if not label: 207 | img = norm(img.copy()) 208 | img = torch.FloatTensor(img).cuda().float() 209 | img = img.transpose(1,2).transpose(0,1) 210 | return img.unsqueeze(0) 211 | else: 212 | lab = img.copy() 213 | lab = torch.FloatTensor(lab).cuda().float() 214 | return lab.unsqueeze(0) 215 | 216 | 217 | def norm_4d(image, mean, std, scale=255.): 218 | r""" 219 | 220 | :param image: numpy, nchw 221 | :param mean: 222 | :param std: 223 | :param scale: 224 | :return: numpy, float 225 | """ 226 | img = image[:, :, :, ::-1].astype(np.float) 227 | img /= scale 228 | img -= mean 229 | img /= std 230 | return img 231 | 232 | 233 | def add_edge(label, kernel_erode, kernel_dilate, edge_value=255): 234 | r""" 235 | if kernel=0, means it do not need to dilate/erode 236 | :param label: HW, numpy, uint8 237 | :param kernel_erode: >=0 238 | :param kernel_dilate: >=0 239 | :return: uint8 240 | """ 241 | if kernel_dilate > 0: 242 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_dilate, kernel_dilate)) 243 | lab_dilate = cv2.dilate(label, kernel) 244 | else: 245 | lab_dilate = label.copy() 246 | if kernel_erode > 0: 247 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_erode, kernel_erode)) 248 | lab_erode = cv2.erode(label, kernel) 249 | else: 250 | lab_erode = label.copy() 251 | edge = lab_dilate - lab_erode 252 | lab = label * (edge == 0) + edge * edge_value 253 | return lab 254 | 255 | def get_gaussian_mask(shape, point_set, rho=10., reverse=False, bg_zero=False): 256 | """ 257 | :param shape: 258 | :param point_set: 259 | :param rho: 260 | :param reverse: if True, means point in edge is (x,y), else (h,w) 261 | :return: 262 | """ 263 | mask = np.ones(shape).astype(np.float) 264 | for p in point_set: 265 | if reverse: 266 | mask[p[1], p[0]] = 0 267 | else: 268 | mask[p[0], p[1]] = 0 269 | mask = Dte(mask) 270 | mask /= rho 271 | mask = np.exp(-mask * mask / 2.) 272 | if bg_zero: 273 | mask = np.minimum(mask, 1.) 274 | else: 275 | mask = 1 - np.minimum(mask, 1.) 
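# At this point mask holds, per pixel, exp(-d^2 / (2 * rho^2)), where d is
# the Euclidean distance (via distance_transform_edt) to the nearest point
# in point_set; with bg_zero=True the gaussian peaks stay bright, otherwise
# the map is inverted. Minimal sketch of the effect on a hypothetical 5x5
# grid with a single center point:
#   m = get_gaussian_mask((5, 5), [(2, 2)], rho=1., bg_zero=True)
#   m[2, 2] == 1.0, and values decay towards 0 with distance from (2, 2)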
276 | return mask 277 | 278 | def gen_gaussian_map(label, num, edge_mask=255, bg_zero=False): 279 | """ 280 | 281 | :param label: HW label map with edge pixels marked by edge_mask 282 | :param num: number of edge points to sample for the gaussian mask 283 | :param edge_mask: label value that marks edge pixels 284 | :param bg_zero: if True, the bg is filled with zero, else one 285 | :return: 286 | """ 287 | if num > 0: 288 | edge = (label == edge_mask).astype(np.uint8) 289 | edge = np.nonzero(edge) 290 | edge = np.array(edge).transpose() 291 | if edge.shape[0] == 0: 292 | assert False, 'no edge' 293 | step = edge.shape[0] // num 294 | if step == 0: 295 | step += 1 296 | edge = edge[::step] 297 | mask = get_gaussian_mask(label.shape, edge, bg_zero=bg_zero) 298 | else: 299 | if bg_zero: 300 | mask = np.zeros(label.shape).astype(np.float) 301 | else: 302 | mask = np.ones(label.shape).astype(np.float) 303 | return mask 304 | 305 | def get_center_point(label_ori, lab_idx=1): 306 | """ 307 | get the center point of a 2D map (the interior point farthest from the boundary) 308 | :param label_ori: HW, 2D, Tensor 309 | :param lab_idx: value of the object 310 | :return: list with length=1, [(h,w)] 311 | """ 312 | label = (label_ori == lab_idx).cpu().numpy().astype(np.float) 313 | shape = label.shape 314 | lab_pad = cv2.copyMakeBorder(label, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 315 | value=(0.)) 316 | lab_pad = Dte(lab_pad) 317 | lab_pad = lab_pad[1:-1, 1:-1].view() 318 | idx = np.argmax(lab_pad) 319 | coord = (idx // shape[1], idx % shape[1]) 320 | return [coord] 321 | 322 | def mask2box(mask, mode='none'): 323 | h1,h2,w1,w2 = crop((mask>0.5).astype(np.uint8), lab_idx=1, mode=mode) 324 | return np.array([w1,h1,w2,h2]) 325 | 326 | def get_dsize(image, half_size=False, scale=8, max_size=700): 327 | if not half_size or max(image.shape[0], image.shape[1])<=max_size: 328 | dh = image.shape[0]//scale*scale 329 | dw = image.shape[1]//scale*scale 330 | else: 331 | dh = image.shape[0]//2//scale*scale 332 | dw = image.shape[1]//2//scale*scale 333 | if dh*dw > 1280*720: 334 | dh = image.shape[0]//4*3 335 | dw = image.shape[1]//4*3 336 | dh = dh//8*8 337 | dw = dw//8*8 338 | if dh*dw > 1280*720: 339 | if image.shape[0] > image.shape[1]: 340 | dh = 640 341 | dw = 320 342 | else: 343 | dw = 640 344 | dh = 320 345 | dw = max(dw, scale) 346 | dh = max(dh, scale) 347 | dsize = (dw,dh) 348 | return dsize 349 | 350 | def get_dsize_align(image, longsize=640, scale=32): 351 | if max(image.shape[0], image.shape[1]) == longsize: 352 | dh = image.shape[0] 353 | dw = image.shape[1] 354 | else: 355 | if image.shape[0] > image.shape[1]: 356 | dh = longsize 357 | dw = image.shape[1] * longsize // image.shape[0] 358 | else: 359 | dw = longsize 360 | dh = image.shape[0] * longsize // image.shape[1] 361 | dsize = (dw//scale*scale,dh//scale*scale) 362 | return dsize 363 | 364 | 365 | def get_dsize_ratio(image, ratio=0.5, scale=8): 366 | h = image.shape[0] 367 | w = image.shape[1] 368 | dh = int(h*ratio) // scale*scale 369 | dw = int(w*ratio) // scale*scale 370 | dsize = (dw,dh) 371 | return dsize 372 | 373 | 374 | def mask2rect(mask, context=50): 375 | box = mask2box(mask) 376 | box[0] = max(box[0]-context, 0) 377 | box[1] = max(box[1]-context, 0) 378 | box[2] = min(box[2]+context, mask.shape[1]-1) 379 | box[3] = min(box[3]+context, mask.shape[0]-1) 380 | ret = np.zeros(mask.shape, dtype=np.uint8) 381 | ret[box[1]:box[3]+1, box[0]:box[2]+1] = 1 382 | return ret 383 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import logging 4 | import math 5 | import numpy as np 6 | 
--------------------------------------------------------------------------------
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import logging
4 | import math
5 | import numpy as np
6 | import numpy.random as npr
7 | import torch
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 | import torch.nn.functional as f
11 | import timeit
12 | 
13 | def mkdir(*lis):
14 |     r""" mkdir -p a new path
15 |     :param lis: a dir set, like mkdir('aaa','bbb')
16 |     :return: void
17 |     """
18 |     for s in lis:
19 |         if s is not None and not os.path.exists(s):
20 |             os.makedirs(s)
21 | 
22 | 
23 | def lr_poly(base_lr, i_iter, max_iter, epoch, tot_epoch, power=0.9):
24 |     r"""
25 |     compute the polynomially decayed learning rate
26 |     :param base_lr: initial learning rate
27 |     :param i_iter: current iteration within the epoch
28 |     :param max_iter: number of iterations per epoch
29 |     :param epoch: current epoch
30 |     :param tot_epoch: total number of epochs
31 |     :param power: decay exponent
32 |     :return: decayed learning rate
33 |     """
34 |     cur = epoch*max_iter + i_iter
35 |     tot = max_iter * tot_epoch
36 |     lr = base_lr * ((1 - float(cur) / tot) ** power)
37 |     return lr
38 | 
39 | 
40 | def adjust_optim(optimizer, lr, start, end, scale=1):
41 |     r"""
42 |     set optimizer.param_groups[start:end] to lr*scale
43 |     :param optimizer: torch optimizer
44 |     :param lr: base learning rate
45 |     :param start: first param group index
46 |     :param end: one past the last param group index
47 |     :param scale: multiplier on lr
48 |     :return: void
49 |     """
50 |     for i in range(start, end):
51 |         optimizer.param_groups[i]['lr'] = lr * scale
52 | 
53 | def adjust_optim_all(optimizer, lr, scale_lr=None, scale=10.):
54 |     '''
55 |     set lr for every param group, scaling by `scale` where scale_lr is True
56 |     :param optimizer: torch optimizer
57 |     :param lr: base learning rate
58 |     :param scale_lr: optional list of bools, one per param group
59 |     :param scale: multiplier for the flagged groups
60 |     :return: void
61 |     '''
62 |     for idx, param_group in enumerate(optimizer.param_groups):
63 |         if scale_lr is not None and scale_lr[idx]:
64 |             param_group['lr'] = lr * scale
65 |         else:
66 |             param_group['lr'] = lr
67 | 
68 | def calc_remain_time(run_time, i_iter, max_iter, epoch, tot_epoch):
69 |     r"""
70 |     estimate the remaining training time from the time of one iteration
71 |     :param run_time: seconds per iteration
72 |     :param i_iter: current iteration within the epoch
73 |     :param max_iter: iterations per epoch
74 |     :param epoch: current epoch
75 |     :param tot_epoch: total number of epochs
76 |     :return: human-readable string
77 |     """
78 |     remain = run_time*(max_iter-i_iter + max_iter*(tot_epoch-epoch-1))
79 |     remain = int(remain) // 60
80 |     minute = remain % 60
81 |     remain //= 60
82 |     return '%d hour %d min' % (remain, minute)
83 | 
84 | 
85 | def flow_warp_tensor(mask, flow, coord=None):
86 |     """
87 |     warp `mask` with optical flow via grid_sample
88 |     :param mask: NxCxHxW, Tensor
89 |     :param flow: HxWx2, flow from the current to the previous frame
90 |     :return: warped mask, NxCxHxW
91 |     """
92 |     if coord is None:
93 |         shape = mask.shape[-2:]
94 |         coord = torch.ones(shape).cuda().long()
95 |         coord = torch.nonzero(coord).float()
96 |     flow = flow2coord(flow, coord)
97 |     flow = flow.unsqueeze(0)
98 |     if mask.shape[0] > 1:
99 |         flow = flow.expand(mask.shape[0], -1, -1, -1)
100 |     mask = f.grid_sample(mask, flow)
101 |     return mask
102 | 
103 | 
104 | def flow2coord(flow, coord, norm=True):
105 |     if not isinstance(flow, torch.Tensor):
106 |         flow = torch.Tensor(flow).cuda().float()
107 |     shape = flow.shape[:2]
108 | 
109 |     coord2 = coord.clone()
110 |     coord[:, 0] = coord2[:, 1]
111 |     coord[:, 1] = coord2[:, 0]
112 | 
113 |     coord = coord.view(shape[0], shape[1], 2)
114 | 
115 |     coord += flow
116 |     if norm:
117 |         coord[:,:,0] = (coord[:,:,0]-shape[1]/2.) / (shape[1]/2.)
118 |         coord[:,:,1] = (coord[:,:,1]-shape[0]/2.) / (shape[0]/2.)
119 | 
120 |     return coord
121 | 
122 | 
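
flow_warp_tensor builds its grid_sample grid by adding the flow to (x, y) pixel coordinates and normalizing to [-1, 1], so a zero flow should leave the input unchanged. A quick check (a sketch; it needs a CUDA device because the helpers above call .cuda() internally, and it uses a constant mask so the result is insensitive to the grid's half-pixel normalization offset):

import torch
from tools import utils

mask = torch.ones(2, 1, 64, 64).cuda()       # N x C x H x W, constant map
flow = torch.zeros(64, 64, 2).cuda()         # H x W x 2, zero motion
warped = utils.flow_warp_tensor(mask, flow)
print((warped - mask).abs().max().item())    # ~0: zero flow leaves a constant mask unchanged
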
123 | def loss_calc_iou(pred, label, unify=False, optim_hard=False, square=False, eps=1e-7, ignore_index=255):
124 |     """
125 |     IoU = |min(P,Y)| / |max(P,Y)|
126 |     :param pred: N1HW, variable, must be in [0,1]
127 |     :param label: N1HW, tensor
128 |     :return: variable with one value
129 |     """
130 |     if not unify:
131 |         pred = pred.view(pred.shape[0], -1)
132 |         label = label.view(label.shape[0], -1)
133 |     gt = Variable((label==1), requires_grad=False).cuda().float()
134 |     mask = Variable((label!=ignore_index), requires_grad=False).cuda().float()
135 |     if unify:
136 |         loss_seg = 1. - ((torch.min(gt, pred)*mask).sum()) / ((torch.max(gt, pred)*mask).sum()+eps)
137 |     else:
138 |         ones = Variable(torch.ones(pred.shape[0]), requires_grad=False).cuda().float()
139 |         loss_seg = ones - ((torch.min(gt, pred)*mask).sum(1)) / ((torch.max(gt, pred)*mask).sum(1)+eps)
140 |     if square:
141 |         loss_seg = (loss_seg**2)/2
142 |     if not unify and optim_hard:
143 |         return torch.max(loss_seg)
144 |     else:
145 |         return torch.mean(loss_seg)
146 | 
147 | 
148 | def loss_calc_iou_v2(pred, label, unify=False, optim_hard=False, square=False, eps=1e-7, ignore_index=255):
149 |     """
150 |     IoU = |P*Y| / |P+Y-P*Y|
151 |     :param pred: N1HW, variable, must be in [0,1]
152 |     :param label: N1HW, tensor
153 |     :return: variable with one value
154 |     """
155 |     if not unify:
156 |         pred = pred.view(pred.shape[0], -1)
157 |         label = label.view(label.shape[0], -1)
158 |     gt = Variable((label==1), requires_grad=False).cuda().float()
159 |     mask = Variable((label!=ignore_index), requires_grad=False).cuda().float()
160 |     if unify:
161 |         loss_seg = 1. - (((gt*pred)*mask).sum()) / (((gt+pred-gt*pred)*mask).sum()+eps)
162 |     else:
163 |         ones = Variable(torch.ones(pred.shape[0]), requires_grad=False).cuda().float()
164 |         intsec = gt*pred
165 |         union = gt+pred-intsec
166 |         loss_seg = ones - ((intsec*mask).sum(1)) / ((union*mask).sum(1)+eps)
167 |     if square:
168 |         loss_seg = (loss_seg**2)/2
169 |     if not unify and optim_hard:
170 |         return torch.max(loss_seg)
171 |     else:
172 |         return torch.mean(loss_seg)
173 | 
174 | 
175 | def calc_iou(pred, label, threshold=0.5, ignore_index=255, merge=False):
176 |     """
177 |     calc the intersection over union
178 |     :param pred: N1HW or NHW, Tensor
179 |     :param label: NHW or N1HW, Tensor
180 |     :return: IoU merged over the batch if merge=True, else a per-sample list
181 |     """
182 |     pred = pred.view(pred.shape[0], -1)
183 |     label = label.view(label.shape[0], -1)
184 |     pred = (pred > threshold).long()
185 |     mask = (label != ignore_index).long()
186 |     if mask.sum() == 0:
187 |         raise ValueError('all pixels are ignored, mask.sum() == 0')
188 |     intsec = (label * pred) * mask
189 |     union = (label + pred - intsec) * mask
190 |     if merge:
191 |         if union.sum().item() == 0:
192 |             if intsec.sum().item() == 0:
193 |                 return 1.
194 |             else:
195 |                 return 0.
196 |         else:
197 |             return intsec.sum().item()*1./union.sum().item()
198 |     else:
199 |         iou = []
200 |         for i in range(intsec.shape[0]):
201 |             if union[i].sum() == 0:
202 |                 iou.append(1. if intsec[i].sum()==0. else 0.)
203 |             else:
204 |                 iou.append(intsec[i].sum().item()*1./union[i].sum().item())
205 |         return iou
206 | 
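
Both soft-IoU losses go to (near-)zero for a perfect prediction, and calc_iou then reports an IoU of 1. A toy check of that behavior (a sketch; CUDA is required since the losses move their inputs onto the GPU):

import torch
from tools import utils

label = torch.zeros(1, 1, 8, 8).long()
label[:, :, 2:6, 2:6] = 1
pred = label.float().cuda()                                     # a "perfect" soft prediction in [0, 1]
print(utils.loss_calc_iou_v2(pred, label, unify=False).item())  # ~0
print(utils.calc_iou(pred.data, label.cuda(), merge=True))      # 1.0
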
--------------------------------------------------------------------------------
/tools/visualize.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | 
4 | 
5 | def denorm_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), scale=255.):
6 |     r"""
7 |     img_norm = (img/scale - mean) / std
8 |     undo that normalization so the image can be shown
9 |     :param image: 3xHxW, numpy, float
10 |     :param mean: per-channel mean used for normalization
11 |     :param std: per-channel std used for normalization
12 |     :return: original BGR image, np.uint8
13 |     """
14 |     img = image.transpose((1,2,0))
15 |     img = (img*std+mean)*scale
16 |     img = img[:,:,::-1].astype(np.uint8)
17 |     return img
18 | 
19 | 
20 | def vis_label(label, vis_index, vis_value):
21 |     r"""
22 |     label[label==vis_index] = vis_value
23 |     :param label: HxW, numpy
24 |     :param vis_index: [0,255]
25 |     :param vis_value: [0,255]
26 |     :return: uint8 visualization
27 |     """
28 |     lab = label.copy()
29 |     lab[label == vis_index] = vis_value
30 |     lab = lab.astype(np.uint8)
31 |     return lab
32 | 
33 | def vis_mask(mask, scale):
34 |     """
35 |     mask *= scale
36 |     :param mask: HW 2D, numpy
37 |     :param scale: [0,255], int
38 |     :return: uint8 visualization
39 |     """
40 |     ret = mask*scale
41 |     ret = ret.astype(np.uint8)
42 |     return ret
43 | 
44 | def vis_flow(flow, scale=50.):
45 |     flow = flow.copy()*scale
46 |     hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
47 |     hsv[...,1] = 255
48 | 
49 |     mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1])
50 |     hsv[...,0] = ang*180/np.pi/2
51 |     hsv[...,2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
52 |     bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
53 | 
54 |     return bgr
55 | 
56 | 
57 | def show_save_lab_savesmall(idx, img_ori, cur_lab, sample_dir=None, save_dir=None, show=True, save=False):
58 |     lab_c = np.argmax(cur_lab, axis=0).astype(np.uint8)
59 |     lab_c = cv2.resize(lab_c, (img_ori.shape[1], img_ori.shape[0]), interpolation=cv2.INTER_NEAREST)
60 |     if save:
61 |         cv2.imwrite('%s/%05d.png' % (save_dir, idx), lab_c)
62 |     if not show:
63 |         return
64 |     img = img_ori.copy().astype(float)
65 |     max_c = lab_c.max()
66 |     rat = 0.3
67 |     base = 255
68 |     ### k==0 denotes background ###
69 |     for k in range(1, max_c+1):
70 |         if k == 1:
71 |             # img[lab_c==k, 0] = 255
72 |             img[lab_c==k, 0] *= rat
73 |             img[lab_c==k, 0] += (1-rat)*base
74 |         elif k == 2:
75 |             img[lab_c==k, 1] *= rat
76 |             img[lab_c==k, 1] += (1-rat)*base
77 |             # img[lab_c==k, 1] = 255
78 |         elif k == 3:
79 |             # img[lab_c==k, 0] = 255
80 |             img[lab_c==k, 2] *= rat
81 |             img[lab_c==k, 2] += (1-rat)*base
82 |             # img[lab_c==k, 2] *= rat
83 |             # img[lab_c==k, 2] += (1-rat)*base
84 |         elif k == 4:
85 |             # img[lab_c==k, 0] = 0
86 |             img[lab_c==k, 1] *= rat
87 |             img[lab_c==k, 1] += (1-rat)*128
88 |         elif k == 5:
89 |             img[lab_c==k, 1] = 0
90 |         elif k == 6:
91 |             img[lab_c==k, 2] = 0
92 |         elif k == 7:
93 |             img[lab_c==k, 0] = 128
94 |         elif k == 8:
95 |             img[lab_c==k, 1] = 128
96 |         elif k == 9:
97 |             img[lab_c==k, 2] = 128
98 |         elif k == 10:
99 |             img[lab_c==k, 0] = 90
100 |         else:
101 |             raise ValueError('at most 10 objects are supported for visualization')
102 |     cv2.imwrite('%s/%05d.jpg' % (sample_dir, idx), img)
103 | 
104 | 
--------------------------------------------------------------------------------
/train_davis.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | import os
4 | import time
5 | import logging
6 | import numpy as np
7 | import argparse
8 | import random
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import
torch.backends.cudnn as cudnn 14 | import torch.optim as optim 15 | from torch.utils import data 16 | import torch.nn.functional as F 17 | from torch.autograd import Variable 18 | 19 | from dataset.vos import Trainset 20 | from networks.agssvos import AGSSVOS 21 | sys.path.append('flow_inference') 22 | from flow_inference.flow_inference import Inference_flow 23 | from tools import preprocess, visualize, utils 24 | import timeit 25 | import cv2 26 | 27 | 28 | def get_parser(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--batch-size', type=int) 31 | parser.add_argument('--root-data', type=str) 32 | parser.add_argument('--root-all-data', type=str) 33 | parser.add_argument('--meta-list', type=str) 34 | parser.add_argument('--epoch', type=int) 35 | parser.add_argument('--start-epoch', type=int, default=0) 36 | parser.add_argument('--sample-size', type=int, default=10) 37 | parser.add_argument('--gpu', type=str, default='0') 38 | parser.add_argument('--lr', type=float) 39 | parser.add_argument('--finetune', action='store_true') 40 | parser.add_argument('--init-atn', action='store_true') 41 | parser.add_argument('--freeze', action='store_true') 42 | parser.add_argument('--set-bn-no-update', action='store_true') 43 | parser.add_argument('--random-crop', action='store_true') 44 | parser.add_argument('--iou-thr-per-obj', action='store_true') 45 | parser.add_argument('--lr-atn', action='store_true') 46 | parser.add_argument('--lr-after-atn', action='store_true') 47 | parser.add_argument('--three-frames-data', action='store_true') 48 | parser.add_argument('--loss-iou-maxmin', action='store_true') 49 | parser.add_argument('--random-ref', action='store_true') 50 | parser.add_argument('--random-skip', action='store_true') 51 | parser.add_argument('--restore', type=str, default=None) 52 | parser.add_argument('--sample-dir', type=str) 53 | parser.add_argument('--snapshot-dir', type=str) 54 | parser.add_argument('--crop_size', type=int, default=512) 55 | parser.add_argument('--resize_h', type=int, default=360) 56 | parser.add_argument('--resize_w', type=int, default=640) 57 | parser.add_argument('--rgb_max', type=float, default=255.) 
58 |     parser.add_argument('--div_flow', type=int, default=20)
59 |     parser.add_argument('--ignore_label', type=int, default=255)
60 |     parser.add_argument('--scale_min', type=float, default=0.5, help='minimum random scale')
61 |     parser.add_argument('--scale_max', type=float, default=2.0, help='maximum random scale')
62 |     parser.add_argument('--rotate_min', type=float, default=-10, help='minimum random rotation')
63 |     parser.add_argument('--rotate_max', type=float, default=10, help='maximum random rotation')
64 |     parser.add_argument('--flow_checkpoint_path', type=str, default='models/FlowNet2-C_checkpoint.pth.tar',
65 |                         help='pretrained model for FlowNetC')
66 |     parser.add_argument('--fix-lr', type=int, default=0)
67 |     parser.add_argument('--show-img', action='store_true', help='show intermediate results')
68 |     return parser
69 | 
70 | # get logger
71 | def get_logger():
72 |     logger = logging.getLogger('train')
73 |     logger.setLevel(logging.INFO)
74 |     handler = logging.StreamHandler()
75 |     fmt = "[%(asctime)s line %(lineno)d] %(message)s"
76 |     handler.setFormatter(logging.Formatter(fmt))
77 |     logger.addHandler(handler)
78 |     return logger
79 | 
80 | def show(images, labels, preds, prev_labs):
81 |     os.system('rm %s/*' % args.sample_dir)
82 |     for i_bs in range(images.shape[0]):
83 |         for j_bs in range(labels.shape[1]):
84 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#'
85 |             image = visualize.denorm_image(images[i_bs, j_bs, :])
86 |             label = visualize.vis_label(labels[i_bs, j_bs], 1, 128)
87 |             cv2.imwrite(path + 'img.jpg', image)
88 |             cv2.imwrite(path + 'lab.jpg', label)
89 | 
90 |         for j_bs in range(preds.shape[1]):
91 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#'
92 |             pred = preds[i_bs,j_bs]*255
93 |             cv2.imwrite(path + 'pred.jpg', pred)
94 | 
95 |         for j_bs in range(prev_labs.shape[1]):
96 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs+2) + '#'
97 |             prev_lab = prev_labs[i_bs,j_bs]*255
98 |             cv2.imwrite(path + 'prev_lab.jpg', prev_lab)
99 | 
100 | def main():
101 |     global args, logger, writer
102 |     args = get_parser().parse_args()
103 |     logger_train = get_logger()
104 |     random.seed(20170624)
105 |     logger_train.info(args)
106 | 
107 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
108 |     utils.mkdir(args.snapshot_dir, args.sample_dir)
109 | 
110 |     # set up the model
111 |     model = AGSSVOS(init_atn=args.init_atn, freeze=args.freeze).cuda()
112 |     model = torch.nn.DataParallel(model).cuda()
113 |     model.train()
114 | 
115 |     for m in model.module.Encoder.modules():
116 |         if isinstance(m, nn.BatchNorm2d):
117 |             m.eval()
118 |             if args.set_bn_no_update:
119 |                 for p in m.parameters():
120 |                     p.requires_grad = False
121 | 
122 |     if args.restore is not None:
123 |         assert os.path.isfile(args.restore), "no restore file found at %s" % (args.restore)
124 |         logger_train.info("loading from %s" % (args.restore))
125 | 
126 |         state = model.state_dict()
127 |         checkpoint = torch.load(args.restore)
128 |         if args.finetune:
129 |             checkpoint = checkpoint['seg']
130 |         checkpoint = {k: v for k, v in checkpoint.items() if k in state}
131 |         state.update(checkpoint)
132 |         model.load_state_dict(state)
133 | 
134 |         del checkpoint
135 |         torch.cuda.empty_cache()
136 | 
137 |     if args.finetune:
138 |         flow_infer = Inference_flow(args, train_flow=True, resume=args.restore)
139 |     else:
140 |         flow_infer = Inference_flow(args, train_flow=True)
141 | 
142 |     params = []
143 |     scale_lr = []
144 |     assert args.lr_atn != args.lr_after_atn
145 |     for key, value in dict(model.module.named_parameters()).items():
146 |         if args.lr_atn
and ('atn' in key or 'pred2' in key or 'RF2' in key) and not args.finetune: 147 | flag = True 148 | elif args.lr_after_atn and ('atn' in key or 'pred2' in key or 'RF2' in key) and not args.finetune: 149 | flag = True 150 | else: 151 | flag = False 152 | if value.requires_grad: 153 | if flag: 154 | scale_lr.append(True) 155 | print('lrx10', key) 156 | else: 157 | scale_lr.append(False) 158 | params += [{'params':[value],'lr':args.lr*10 if flag else args.lr , 'weight_decay': 4e-5}] 159 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=4e-5) 160 | spec_vid = None 161 | spec_obj_ind = None 162 | 163 | trainloader = data.DataLoader( 164 | Trainset(root_data=args.root_data, json_meta_list=args.meta_list, 165 | sample_size=args.sample_size, test_mode=False, spec_vid=spec_vid, spec_obj_ind=spec_obj_ind, 166 | step=1, fix_size=False, half_size=False, random_ref=args.random_ref, random_skip=args.random_skip), 167 | batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True) 168 | 169 | # training 170 | tot_iter = len(trainloader) 171 | logger_train.info("Total iteration per epoch is %d" % (tot_iter)) 172 | tot_time = [] 173 | loss_set = [] 174 | iou_set = [] 175 | optimizer.zero_grad() 176 | 177 | for epoch in range(args.start_epoch, args.epoch): 178 | for i_iter, batch in enumerate(trainloader): 179 | start_time = timeit.default_timer() 180 | 181 | img, lab, ori_img = batch 182 | 183 | img = img[0].cuda().float() 184 | lab = lab[0].cuda().float() 185 | ori_img = ori_img.numpy() 186 | # img KT3HW, lab KTHW, ori_img, KTHW3# 187 | 188 | ### It may be better to move this augmentation into the dataset preprocessing ## 189 | if random.uniform(0,1)>0.5 and args.random_crop: 190 | ### random resize ### 191 | coord = [1e4,1e4,0,0] 192 | lab_agno = lab.sum(0) 193 | val_cnt = 0 194 | for i in range(lab_agno.shape[0]): 195 | idx = torch.nonzero(lab_agno[i]>0) 196 | if idx.shape[0] == 0: 197 | continue 198 | val_cnt += 1 199 | h0 = idx[:,0].min().item() 200 | w0 = idx[:,1].min().item() 201 | h1 = idx[:,0].max().item() 202 | w1 = idx[:,1].max().item() 203 | coord[0] = min(coord[0], h0) 204 | coord[1] = min(coord[1], w0) 205 | coord[2] = max(coord[2], h1) 206 | coord[3] = max(coord[3], w1) 207 | if val_cnt < 2: 208 | logger_train.info(('The number of frames that have label is less than 2, continue..')) 209 | continue 210 | ori_shape = lab.shape[-2:] 211 | rand_coord = [0]*4 212 | 213 | if random.uniform(0,1) > 0.3: 214 | scale = random.uniform(0,1) 215 | else: 216 | scale = 1 217 | rand_coord[0] = coord[0] * scale 218 | rand_coord[1] = coord[1] * scale 219 | rand_coord[2] = (ori_shape[0]-coord[2]-1)*(1-scale)+coord[2]+1 220 | rand_coord[3] = (ori_shape[1]-coord[3]-1)*(1-scale)+coord[3]+1 221 | for j in range(4): 222 | rand_coord[j] = int(rand_coord[j]) 223 | 224 | old_img = img.clone() 225 | old_lab = lab.clone() 226 | ori_img = torch.FloatTensor(ori_img).cuda().transpose(-1,-2).transpose(-2,-3) 227 | old_ori_img = ori_img.clone() 228 | 229 | old_lab = old_lab[:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 230 | lab = F.upsample(old_lab, ori_shape, mode='bilinear', align_corners=True) 231 | lab = (lab>0.5).float() 232 | for i in range(img.shape[0]): 233 | img_obj = old_img[i,:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 234 | img[i] = F.upsample(img_obj, ori_shape, mode='bilinear', align_corners=True) 235 | img_obj = old_ori_img[0,:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 236 | ori_img[0] = F.upsample(img_obj, ori_shape, 
mode='bilinear', align_corners=True) 237 | ori_img = ori_img.transpose(-2,-3).transpose(-1,-2).cpu().numpy().astype(np.uint8) 238 | 239 | ### end of random resize ### 240 | 241 | if lab.shape[1] == 1: 242 | logger_train.info('lab.shape[1](vid_len) == 1, continue..') 243 | continue 244 | 245 | lr = utils.lr_poly(args.lr, i_iter, tot_iter, epoch, args.epoch) 246 | utils.adjust_optim_all(optimizer, lr, scale_lr) 247 | preds = [] 248 | prev_labs = [] 249 | preds.append(lab[:,0:1].contiguous()) 250 | preds.append(lab[:,1:2].contiguous()) 251 | merge_preds_ref = lab[:,0:1].contiguous().sum(0) 252 | for i in range(2, img.shape[1], 1): 253 | ms = model.forward(img[:,0], merge_preds_ref) 254 | flow = flow_infer.infer(ori_img[0,i], ori_img[0,i-1]) 255 | prev_lab = utils.flow_warp_tensor(preds[i-1], flow) 256 | 257 | prev_labs.append(prev_lab.detach()) 258 | merge_preds = prev_lab.max(0)[0] 259 | 260 | output, _ = model.forward(img[:,i], merge_preds, prev_lab.squeeze(1), ref=ms) 261 | 262 | cur_lab = lab[:,i].contiguous() 263 | 264 | if args.loss_iou_maxmin: 265 | cur_loss = utils.loss_calc_iou(output, cur_lab.unsqueeze(1), unify=False, optim_hard=False, 266 | square=False) # try this 267 | else: 268 | cur_loss = utils.loss_calc_iou_v2(output, cur_lab.unsqueeze(1), unify=False, optim_hard=False, 269 | square=False) # try this 270 | 271 | loss_set.append(cur_loss.item()) 272 | 273 | iou = utils.calc_iou(output.data, cur_lab.long(), merge=False) 274 | iou_set.append(np.mean(iou)) 275 | 276 | optimizer.zero_grad() 277 | cur_loss.backward() 278 | optimizer.step() 279 | 280 | if args.iou_thr_per_obj: 281 | output = output.detach() 282 | new_output = torch.zeros_like(output).cuda().float() 283 | for j in range(new_output.shape[0]): 284 | if iou[j] > 0.5: 285 | new_output[j] = output[j] 286 | else: 287 | new_output[j] = lab[j:j+1,i] 288 | new_output = new_output.contiguous() 289 | preds.append(new_output.detach()) 290 | else: 291 | if np.mean(iou) > 0.5: 292 | preds.append(output.detach()) 293 | else: 294 | preds.append(cur_lab.unsqueeze(1).detach()) 295 | 296 | end_time = timeit.default_timer() 297 | tot_time.append(end_time - start_time) 298 | 299 | if i_iter % 200 == 0: 300 | logger_train.info('show at %s' % args.sample_dir) 301 | try: 302 | preds = torch.cat(preds, dim=1) 303 | prev_labs = torch.cat(prev_labs, dim=1) 304 | except Exception as e: 305 | print(e) 306 | print('Ignore.. 
Continue..') 307 | continue 308 | if args.show_img: 309 | show(img.data.cpu().numpy(), lab.data.cpu().numpy(), preds.data.cpu().numpy().astype(np.float), 310 | prev_labs.data.cpu().numpy().astype(np.float32)) 311 | 312 | if i_iter % 20 == 0: 313 | run_time = np.mean(tot_time) 314 | rem_time = utils.calc_remain_time(run_time, i_iter, tot_iter, epoch, args.epoch) 315 | logger_train.info('iter = %d of %d in epoch = %d of %d, remain_time = %s' % 316 | (i_iter, tot_iter, epoch, args.epoch, rem_time)) 317 | tot_time = [] 318 | logger_train.info('lr = %f, loss = %f, iou = %f' % (lr, np.mean(loss_set), np.mean(iou_set))) 319 | loss_set = [] 320 | iou_set = [] 321 | 322 | if epoch % (args.epoch//5) == 0 or epoch == args.epoch - 1: 323 | path = os.path.join(args.snapshot_dir, 'model_' + str(epoch) + '.pth') 324 | logger_train.info('save model at %s' % path) 325 | torch.save({'seg':model.state_dict(), 'flow':flow_infer.model.state_dict()}, path) 326 | 327 | 328 | if __name__ == '__main__': 329 | main() 330 | 331 | 332 | 333 | -------------------------------------------------------------------------------- /train_ytv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import time 5 | import logging 6 | import numpy as np 7 | import argparse 8 | import random 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim as optim 15 | from torch.utils import data 16 | import torch.nn.functional as F 17 | from torch.autograd import Variable 18 | 19 | from dataset.vos import Trainset 20 | from networks.agssvos import AGSSVOS 21 | sys.path.append('flow_inference') 22 | from flow_inference.flow_inference import Inference_flow 23 | from tools import preprocess, visualize, utils 24 | import timeit 25 | import cv2 26 | 27 | 28 | def get_parser(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--batch-size', type=int) 31 | parser.add_argument('--root-data', type=str) 32 | parser.add_argument('--root-all-data', type=str) 33 | parser.add_argument('--meta-list', type=str) 34 | parser.add_argument('--epoch', type=int) 35 | parser.add_argument('--start-epoch', type=int, default=0) 36 | parser.add_argument('--sample-size', type=int, default=10) 37 | parser.add_argument('--gpu', type=str, default='0') 38 | parser.add_argument('--lr', type=float) 39 | parser.add_argument('--finetune', action='store_true') 40 | parser.add_argument('--init-atn', action='store_true') 41 | parser.add_argument('--freeze', action='store_true') 42 | parser.add_argument('--set-bn-no-update', action='store_true') 43 | parser.add_argument('--random-crop', action='store_true') 44 | parser.add_argument('--iou-thr-per-obj', action='store_true') 45 | parser.add_argument('--lr-atn', action='store_true') 46 | parser.add_argument('--lr-after-atn', action='store_true') 47 | parser.add_argument('--three-frames-data', action='store_true') 48 | parser.add_argument('--loss-iou-maxmin', action='store_true') 49 | parser.add_argument('--random-ref', action='store_true') 50 | parser.add_argument('--random-skip', action='store_true') 51 | parser.add_argument('--step-size', type=float, default=4) 52 | parser.add_argument('--restore', type=str, default=None) 53 | parser.add_argument('--sample-dir', type=str) 54 | parser.add_argument('--snapshot-dir', type=str) 55 | parser.add_argument('--crop_size', type=int, default=512) 56 | parser.add_argument('--resize_h', type=int, default=360) 57 | 
parser.add_argument('--resize_w', type=int, default=640)
58 |     parser.add_argument('--rgb_max', type=float, default=255.)
59 |     parser.add_argument('--div_flow', type=int, default=20)
60 |     parser.add_argument('--ignore_label', type=int, default=255)
61 |     parser.add_argument('--scale_min', type=float, default=0.5, help='minimum random scale')
62 |     parser.add_argument('--scale_max', type=float, default=2.0, help='maximum random scale')
63 |     parser.add_argument('--rotate_min', type=float, default=-10, help='minimum random rotation')
64 |     parser.add_argument('--rotate_max', type=float, default=10, help='maximum random rotation')
65 |     parser.add_argument('--flow_checkpoint_path', type=str, default='models/FlowNet2-C_checkpoint.pth.tar',
66 |                         help='pretrained model for FlowNetC')
67 |     parser.add_argument('--fix-lr', type=int, default=0)
68 |     parser.add_argument('--show-img', action='store_true', help='show intermediate results')
69 |     return parser
70 | 
71 | # get logger
72 | def get_logger():
73 |     logger = logging.getLogger('train')
74 |     logger.setLevel(logging.INFO)
75 |     handler = logging.StreamHandler()
76 |     fmt = "[%(asctime)s line %(lineno)d] %(message)s"
77 |     handler.setFormatter(logging.Formatter(fmt))
78 |     logger.addHandler(handler)
79 |     return logger
80 | 
81 | def show(images, labels, preds, prev_labs):
82 |     os.system('rm %s/*' % args.sample_dir)
83 |     for i_bs in range(images.shape[0]):
84 |         for j_bs in range(labels.shape[1]):
85 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#'
86 |             image = visualize.denorm_image(images[i_bs, j_bs, :])
87 |             label = visualize.vis_label(labels[i_bs, j_bs], 1, 128)
88 |             cv2.imwrite(path + 'img.jpg', image)
89 |             cv2.imwrite(path + 'lab.jpg', label)
90 | 
91 |         for j_bs in range(preds.shape[1]):
92 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs) + '#'
93 |             pred = preds[i_bs,j_bs]*255
94 |             cv2.imwrite(path + 'pred.jpg', pred)
95 | 
96 |         for j_bs in range(prev_labs.shape[1]):
97 |             path = args.sample_dir + '/' + str(i_bs)+'_'+str(j_bs+2) + '#'
98 |             prev_lab = prev_labs[i_bs,j_bs]*255
99 |             cv2.imwrite(path + 'prev_lab.jpg', prev_lab)
100 | 
101 | def main():
102 |     global args, logger, writer
103 |     args = get_parser().parse_args()
104 |     logger_train = get_logger()
105 |     random.seed(20170624)
106 |     logger_train.info(args)
107 | 
108 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
109 |     utils.mkdir(args.snapshot_dir, args.sample_dir)
110 | 
111 |     # set up the model
112 |     model = AGSSVOS(init_atn=args.init_atn, freeze=args.freeze).cuda()
113 |     model = torch.nn.DataParallel(model).cuda()
114 |     model.train()
115 | 
116 |     for m in model.module.Encoder.modules():
117 |         if isinstance(m, nn.BatchNorm2d):
118 |             m.eval()
119 |             if args.set_bn_no_update:
120 |                 for p in m.parameters():
121 |                     p.requires_grad = False
122 | 
123 |     if args.restore is not None:
124 |         assert os.path.isfile(args.restore), "no restore file found at %s" % (args.restore)
125 |         logger_train.info("loading from %s" % (args.restore))
126 | 
127 |         state = model.state_dict()
128 |         checkpoint = torch.load(args.restore)
129 |         if args.finetune:
130 |             checkpoint = checkpoint['seg']
131 |         checkpoint = {k: v for k, v in checkpoint.items() if k in state}
132 |         state.update(checkpoint)
133 |         model.load_state_dict(state)
134 | 
135 |         del checkpoint
136 |         torch.cuda.empty_cache()
137 | 
138 |     if args.finetune:
139 |         flow_infer = Inference_flow(args, train_flow=True, resume=args.restore)
140 |     else:
141 |         flow_infer = Inference_flow(args, train_flow=True)
142 | 
143 |     params = []
144 |     scale_lr = []
145 |     assert
args.lr_atn != args.lr_after_atn 146 | for key, value in dict(model.module.named_parameters()).items(): 147 | if args.lr_atn and ('atn' in key or 'pred2' in key or 'RF2' in key): 148 | flag = True 149 | elif args.lr_after_atn and ('atn' in key or 'pred2' in key or 'RF2' in key): 150 | flag = True 151 | else: 152 | flag = False 153 | if value.requires_grad: 154 | if flag: 155 | scale_lr.append(True) 156 | print('lrx10', key) 157 | else: 158 | scale_lr.append(False) 159 | params += [{'params':[value],'lr':args.lr*10 if flag else args.lr , 'weight_decay': 4e-5}] 160 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=4e-5) 161 | spec_vid = None 162 | spec_obj_ind = None 163 | 164 | trainloader = data.DataLoader( 165 | Trainset(root_data=args.root_data, json_meta_list=args.meta_list, 166 | sample_size=args.sample_size, test_mode=False, spec_vid=spec_vid, spec_obj_ind=spec_obj_ind, 167 | step=5, fix_size=False, half_size=True, random_ref=args.random_ref, random_skip=args.random_skip), 168 | batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True) 169 | 170 | 171 | # training 172 | tot_iter = len(trainloader) 173 | logger_train.info("Total iteration per epoch is %d" % (tot_iter)) 174 | tot_time = [] 175 | loss_set = [] 176 | iou_set = [] 177 | optimizer.zero_grad() 178 | 179 | for epoch in range(args.start_epoch, args.epoch): 180 | for i_iter, batch in enumerate(trainloader): 181 | start_time = timeit.default_timer() 182 | 183 | img, lab, ori_img = batch 184 | 185 | img = img[0].cuda().float() 186 | lab = lab[0].cuda().float() 187 | ori_img = ori_img.numpy() 188 | # img KT3HW, lab KTHW, ori_img, KTHW3# 189 | 190 | ### It may be better to move this augmentation into the dataset preprocessing ## 191 | if random.uniform(0,1)>0.5 and args.random_crop: 192 | ### random resize ### 193 | coord = [1e4,1e4,0,0] 194 | lab_agno = lab.sum(0) 195 | val_cnt = 0 196 | for i in range(lab_agno.shape[0]): 197 | idx = torch.nonzero(lab_agno[i]>0) 198 | if idx.shape[0] == 0: 199 | continue 200 | val_cnt += 1 201 | h0 = idx[:,0].min().item() 202 | w0 = idx[:,1].min().item() 203 | h1 = idx[:,0].max().item() 204 | w1 = idx[:,1].max().item() 205 | coord[0] = min(coord[0], h0) 206 | coord[1] = min(coord[1], w0) 207 | coord[2] = max(coord[2], h1) 208 | coord[3] = max(coord[3], w1) 209 | if val_cnt < 2: 210 | logger_train.info(('The number of frames that have label is less than 2, continue..')) 211 | continue 212 | ori_shape = lab.shape[-2:] 213 | rand_coord = [0]*4 214 | 215 | if random.uniform(0,1) > 0.3: 216 | scale = random.uniform(0,1) 217 | else: 218 | scale = 1 219 | rand_coord[0] = coord[0] * scale 220 | rand_coord[1] = coord[1] * scale 221 | rand_coord[2] = (ori_shape[0]-coord[2]-1)*(1-scale)+coord[2]+1 222 | rand_coord[3] = (ori_shape[1]-coord[3]-1)*(1-scale)+coord[3]+1 223 | for j in range(4): 224 | rand_coord[j] = int(rand_coord[j]) 225 | 226 | old_img = img.clone() 227 | old_lab = lab.clone() 228 | ori_img = torch.FloatTensor(ori_img).cuda().transpose(-1,-2).transpose(-2,-3) 229 | old_ori_img = ori_img.clone() 230 | 231 | old_lab = old_lab[:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 232 | lab = F.upsample(old_lab, ori_shape, mode='bilinear', align_corners=True) 233 | lab = (lab>0.5).float() 234 | for i in range(img.shape[0]): 235 | img_obj = old_img[i,:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 236 | img[i] = F.upsample(img_obj, ori_shape, mode='bilinear', align_corners=True) 237 | img_obj = 
old_ori_img[0,:,:,rand_coord[0]:rand_coord[2]+1,rand_coord[1]:rand_coord[3]+1] 238 | ori_img[0] = F.upsample(img_obj, ori_shape, mode='bilinear', align_corners=True) 239 | ori_img = ori_img.transpose(-2,-3).transpose(-1,-2).cpu().numpy().astype(np.uint8) 240 | 241 | ### end of random resize ### 242 | 243 | if lab.shape[1] == 1: 244 | logger_train.info('lab.shape[1](vid_len) == 1, continue..') 245 | continue 246 | 247 | lr = utils.lr_poly(args.lr, i_iter, tot_iter, epoch, args.epoch) 248 | utils.adjust_optim_all(optimizer, lr, scale_lr) 249 | preds = [] 250 | prev_labs = [] 251 | preds.append(lab[:,0:1].contiguous()) 252 | preds.append(lab[:,1:2].contiguous()) 253 | merge_preds_ref = lab[:,0:1].contiguous().sum(0) 254 | for i in range(2, img.shape[1], 1): 255 | ms = model.forward(img[:,0], merge_preds_ref) 256 | flow = flow_infer.infer(ori_img[0,i], ori_img[0,i-1]) 257 | prev_lab = utils.flow_warp_tensor(preds[i-1], flow) 258 | 259 | prev_labs.append(prev_lab.detach()) 260 | merge_preds = prev_lab.max(0)[0] 261 | 262 | output, _ = model.forward(img[:,i], merge_preds, prev_lab.squeeze(1), ref=ms) 263 | 264 | cur_lab = lab[:,i].contiguous() 265 | 266 | if args.loss_iou_maxmin: 267 | cur_loss = utils.loss_calc_iou(output, cur_lab.unsqueeze(1), unify=False, optim_hard=False, 268 | square=False) # try this 269 | else: 270 | cur_loss = utils.loss_calc_iou_v2(output, cur_lab.unsqueeze(1), unify=False, optim_hard=False, 271 | square=False) # try this 272 | 273 | loss_set.append(cur_loss.item()) 274 | 275 | iou = utils.calc_iou(output.data, cur_lab.long(), merge=False) 276 | iou_set.append(np.mean(iou)) 277 | 278 | optimizer.zero_grad() 279 | cur_loss.backward() 280 | optimizer.step() 281 | 282 | if args.iou_thr_per_obj: 283 | output = output.detach() 284 | new_output = torch.zeros_like(output).cuda().float() 285 | for j in range(new_output.shape[0]): 286 | if iou[j] > 0.5: 287 | new_output[j] = output[j] 288 | else: 289 | new_output[j] = lab[j:j+1,i] 290 | new_output = new_output.contiguous() 291 | preds.append(new_output.detach()) 292 | else: 293 | if np.mean(iou) > 0.5: 294 | preds.append(output.detach()) 295 | else: 296 | preds.append(cur_lab.unsqueeze(1).detach()) 297 | 298 | end_time = timeit.default_timer() 299 | tot_time.append(end_time - start_time) 300 | 301 | if i_iter % 200 == 0: 302 | logger_train.info('show at %s' % args.sample_dir) 303 | try: 304 | preds = torch.cat(preds, dim=1) 305 | prev_labs = torch.cat(prev_labs, dim=1) 306 | except Exception as e: 307 | print(e) 308 | print('Ignore.. 
Continue..') 309 | continue 310 | if args.show_img: 311 | show(img.data.cpu().numpy(), lab.data.cpu().numpy(), preds.data.cpu().numpy().astype(np.float), 312 | prev_labs.data.cpu().numpy().astype(np.float32)) 313 | 314 | if i_iter % 20 == 0: 315 | run_time = np.mean(tot_time) 316 | rem_time = utils.calc_remain_time(run_time, i_iter, tot_iter, epoch, args.epoch) 317 | logger_train.info('iter = %d of %d in epoch = %d of %d, remain_time = %s' % 318 | (i_iter, tot_iter, epoch, args.epoch, rem_time)) 319 | tot_time = [] 320 | logger_train.info('lr = %f, loss = %f, iou = %f' % (lr, np.mean(loss_set), np.mean(iou_set))) 321 | loss_set = [] 322 | iou_set = [] 323 | 324 | 325 | if epoch % (args.epoch//5) == 0 or epoch == args.epoch - 1: 326 | path = os.path.join(args.snapshot_dir, 'model_' + str(epoch) + '.pth') 327 | logger_train.info('save model at %s' % path) 328 | torch.save({'seg':model.state_dict(), 'flow':flow_infer.model.state_dict()}, path) 329 | 330 | 331 | if __name__ == '__main__': 332 | main() 333 | 334 | 335 | 336 | -------------------------------------------------------------------------------- /val_davis.sh: -------------------------------------------------------------------------------- 1 | name=train_davis 2 | echo $name 3 | tgt_dir='val_dir/'$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 infer_davis.py \ 8 | --root-data='data/davis2017/trainval' \ 9 | --root-all-data='data/davis2017/trainval' \ 10 | --list-path='data/davis2017/trainval/val_meta.json' \ 11 | --restore='checkpoints/train_davis/model_199.pth' \ 12 | --batch-size=1 \ 13 | --start-epoch=0 \ 14 | --epoch=1 \ 15 | --sample-size=20 \ 16 | --lr=1e-5 \ 17 | --gpu='7' \ 18 | --sample-dir=$tgt_dir'/sample' \ 19 | --save-dir=$tgt_dir'/save' \ 20 | --test-mode=1 \ 21 | 2>&1 | tee $tgt_dir/val.log 22 | -------------------------------------------------------------------------------- /val_davis_ft.sh: -------------------------------------------------------------------------------- 1 | name=ft_davis 2 | echo $name 3 | tgt_dir='val_dir/'$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 infer_davis.py \ 8 | --root-data='data/davis2017/trainval' \ 9 | --root-all-data='data/davis2017/trainval' \ 10 | --list-path='data/davis2017/trainval/val_meta.json' \ 11 | --restore='checkpoints/ft_davis/model_99.pth' \ 12 | --batch-size=1 \ 13 | --start-epoch=0 \ 14 | --epoch=1 \ 15 | --sample-size=20 \ 16 | --lr=1e-5 \ 17 | --gpu='4' \ 18 | --sample-dir=$tgt_dir'/sample' \ 19 | --save-dir=$tgt_dir'/save' \ 20 | --test-mode=1 \ 21 | 2>&1 | tee $tgt_dir/val.log 22 | -------------------------------------------------------------------------------- /val_ytv.sh: -------------------------------------------------------------------------------- 1 | name=train_ytv 2 | echo $name 3 | tgt_dir='val_dir/'$name 4 | if [ ! -d $tgt_dir ]; then 5 | mkdir -p $tgt_dir 6 | fi 7 | python3 infer_ytv.py \ 8 | --root-data='data/youtube_vos/valid/' \ 9 | --root-all-data='data/youtube_vos/valid_all_frames/' \ 10 | --list-path='data/youtube_vos/valid/meta.json' \ 11 | --restore='checkpoints/train_ytv/model_4.pth' \ 12 | --batch-size=1 \ 13 | --start-epoch=0 \ 14 | --epoch=1 \ 15 | --sample-size=20 \ 16 | --lr=1e-5 \ 17 | --gpu='4' \ 18 | --sample-dir=$tgt_dir'/sample' \ 19 | --save-dir=$tgt_dir'/save' \ 20 | --test-mode=1 \ 21 | --show-step=1 \ 22 | 2>&1 | tee $tgt_dir/val.log 23 | --------------------------------------------------------------------------------
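
Both train_davis.py and train_ytv.py anneal the learning rate with the poly schedule from tools/utils.py (lr_poly, applied each iteration through adjust_optim_all). A standalone sketch of what that schedule produces; the numbers are illustrative, not the settings used above:

from tools import utils

base_lr, iters_per_epoch, tot_epoch = 1e-5, 1000, 10
for epoch in range(0, 10, 3):
    lr = utils.lr_poly(base_lr, i_iter=0, max_iter=iters_per_epoch,
                       epoch=epoch, tot_epoch=tot_epoch)
    print(epoch, lr)   # decays from 1e-5 toward 0 as (1 - cur/tot) ** 0.9
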