├── .gitattributes ├── LICENSE ├── README.md ├── assert └── framework.png ├── create_list_nyuv2.py ├── dataloader ├── __pycache__ │ ├── nyu_transform.cpython-35.pyc │ ├── nyudv2_dataloader.cpython-35.pyc │ └── nyudv2_dataloader_224.cpython-35.pyc ├── nyu_transform.py └── nyudv2_dataloader.py ├── evaluate.py ├── models ├── R_CLSTM_modules.py ├── __init__.py ├── backbone_dict.py ├── densenet.py ├── loss.py ├── modules.py ├── net.py ├── refinenet_dict.py ├── resnet.py └── senet.py ├── options ├── __init__.py ├── testopt.py └── trainopt.py ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 bdseal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | [![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg)](https://www.python.org/) 3 | [![PyTorch 1.0](https://img.shields.io/badge/pytorch-1.0-green.svg)](https://pytorch.org/) 4 | 5 | # Exploiting Temporal Consistency for Real-Time Video Depth Estimation 6 | This is the UNOFFICIAL implementation of the paper [***Exploiting Temporal Consistency for Real-Time Video Depth Estimation***](https://arxiv.org/abs/1908.03706), ***ICCV 2019, Haokui Zhang, Chunhua Shen, Ying Li, Yuanzhouhan Cao, Yu Liu, Youliang Yan.*** 7 | 8 | You can find official implementation (WITHOUT TRAINING SCRIPTS) [here](https://github.com/hkzhang91/ST-CLSTM). 9 | 10 | ## Framework 11 | ![](./assert/framework.png) 12 | 13 | ## Dependencies 14 | - [Python3.6](https://www.python.org/downloads/) 15 | - [PyTorch 1.0+](https://pytorch.org/) 16 | - [NYU Depth v2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) 17 | 18 | ## Pre-processed Data 19 | We didn't preprocess data as in the official implementation. 
Instead, we use the dataset shared by [Junjie Hu](https://github.com/JunjH/Revisiting_Single_Depth_Estimation), which is also used by [SARPN](https://github.com/Xt-Chen/SARPN/blob/master/README.md). 20 | You can download the pre-processed data from [here](https://drive.google.com/file/d/1WoOZOBpOWfmwe7bknWS5PMUCLBPFKTOw/view?usp=sharing). 21 | 22 | When you have downloaded the dataset, run the following command to create the training list. 23 | ```bash 24 | python create_list_nyuv2.py 25 | ``` 26 | 27 | You can also follow the procedure of [ST-CLSTM](https://github.com/hkzhang91/ST-CLSTM) to preprocess the data. It is based on the official Matlab [Toolbox](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). If Matlab is unavailable to you, there is also a [Python Port Toolbox](https://github.com/GabrielMajeri/nyuv2-python-toolbox) by [GabrielMajeri](https://github.com/GabrielMajeri) for processing the raw dataset, which provides a higher-level interface to the labeled subset, raw dataset extraction and preprocessing, and data augmentation. 28 | 29 | The final folder structure is shown below. 30 | ``` 31 | data_root 32 | |- raw_nyu_v2_250k 33 | | |- train 34 | | | |- basement_0001a 35 | | | | |- rgb 36 | | | | | |- rgb_00000.jpg 37 | | | | | |_ ... 38 | | | | |- depth 39 | | | | | |- depth_00000.png 40 | | | | | |_ ... 41 | | | |_ ... 42 | | |- test_fps_30_fl5_end 43 | | | |- 0000 44 | | | | |- rgb 45 | | | | | |- rgb_00000.jpg 46 | | | | | |- rgb_00001.jpg 47 | | | | | |- ... 48 | | | | | |- rgb_00004.jpg 49 | | | | |- depth 50 | | | | | |- depth_00000.png 51 | | | | | |- depth_00001.png 52 | | | | | |- ... 53 | | | | | |- depth_00004.png 54 | | | |- ... 55 | | |- test_fps_30_fl4_end 56 | | |- test_fps_30_fl3_end 57 | ``` 58 | ## Train 59 | As an example, use the following command to train on NYUDV2.
60 | 61 | ```bash 62 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py --epochs 20 --batch_size 128 \ 63 | --resume --do_summary --backbone resnet18 --refinenet R_CLSTM_5 \ 64 | --trainlist_path ./data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json \ 65 | --root_path ./data/ --checkpoint_dir ./checkpoint/ --logdir ./log/ 66 | ``` 67 | ## Evaluation 68 | Use the following command to evaluate the trained model on ST-CLSTM [test data](https://github.com/hkzhang91/ST-CLSTM).
69 | 70 | ```bash 71 | CUDA_VISIBLE_DEVICES="0" python evaluate.py --batch_size 1 --backbone resnet18 --refinenet R_CLSTM_5 --loadckpt ./checkpoint/ \ 72 | --testlist_path ./data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_test.json \ 73 | --root_path ./data/st-clstm/ 74 | ``` 75 | ## Pretrained Model 76 | You can download the pretrained model: [NYUDV2](https://github.com/hkzhang91/ST-CLSTM/tree/master/CLSTM_Depth_Estimation-master/prediction/trained_models). 77 | 78 | ## Citation 79 | 80 | ```bibtex 81 | @inproceedings{zhang2019temporal, 82 | title = {Exploiting Temporal Consistency for Real-Time Video Depth Estimation}, 83 | author = {Haokui Zhang and Chunhua Shen and Ying Li and Yuanzhouhan Cao and Yu Liu and Youliang Yan}, 84 | conference={International Conference on Computer Vision}, 85 | year = {2019} 86 | } 87 | ``` -------------------------------------------------------------------------------- /assert/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/assert/framework.png -------------------------------------------------------------------------------- /create_list_nyuv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | from glob import glob 6 | # from natsort import natsorted 7 | import re 8 | 9 | parser = argparse.ArgumentParser(description='raw_nyu_v2') 10 | parser.add_argument('--dataset', type=str, default='raw_nyu_v2_250k') 11 | parser.add_argument('--test_loc', type=str, default='end') 12 | parser.add_argument('--fps', type=int, default=30) 13 | parser.add_argument('--fl', type=int, default=5) 14 | parser.add_argument('--overlap', type=int, default=0) 15 | parser.add_argument('--list_save_dir', type=str, default='./data_list') 16 | parser.add_argument('--source_dir', type=str, default='./data/') 17 | args = parser.parse_args() 18 | 19 | args.jpg_png_save_dir = args.source_dir #+ args.dataset 20 | 21 | 22 | def make_if_not_exist(path): 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | 27 | def video_split(frame_len, frame_train, interval, overlap): 28 | sample_interval = frame_train - overlap 29 | indices = [] 30 | for start in range(interval): 31 | index_list = list(range(start, frame_len - frame_train * interval + 1, sample_interval)) 32 | [indices.append(list(range(num, num+frame_train*interval, interval))) for num in index_list] 33 | indices.append(list(range(frame_len - frame_train * interval - start, frame_len - start, interval))) 34 | 35 | return indices 36 | 37 | def atoi(text): 38 | return int(text) if text.isdigit() else text 39 | 40 | def natural_keys(text): 41 | ''' 42 | alist.sort(key=natural_keys) sorts in human order 43 | http://nedbatchelder.com/blog/200712/human_sorting.html 44 | (See Toothy's implementation in the comments) 45 | ''' 46 | return [ atoi(c) for c in re.split(r'(\d+)', text) ] 47 | 48 | def create_dict(dataset, list_save_dir, jpg_png_save_dir): 49 | train_dir = os.path.join(jpg_png_save_dir, 'nyu2_train') 50 | # test_dir = os.path.join(jpg_png_save_dir, 'test_fps{}_fl{}_{}'.format(args.fps, args.fl,args.test_loc)) 51 | interval = 30 // args.fps 52 | 53 | train_dict=[] 54 | subset_list = os.listdir(train_dir) 55 | 56 | for subset in subset_list: 57 | subset_source_dir = os.path.join(train_dir, subset) 58 | print(subset_source_dir) 59 | rgb_list = glob(subset_source_dir + 
'/*.jpg') 60 | depth_list = glob(subset_source_dir + '/*.png') 61 | 62 | rgb_list.sort(key=natural_keys) 63 | depth_list.sort(key=natural_keys) 64 | print(rgb_list) 65 | 66 | 67 | rgb_list_new = ['/'.join(rgb_info.split('/')[-5:]) for rgb_info in rgb_list] 68 | depth_list_new = ['/'.join(depth_info.split('/')[-5:]) for depth_info in depth_list] 69 | 70 | indices = video_split(len(depth_list), args.fl, interval, args.overlap) 71 | 72 | for index in indices: 73 | rgb_index = [] 74 | depth_index = [] 75 | [rgb_index.append(rgb_list_new[id]) for id in index] 76 | [depth_index.append(depth_list_new[id]) for id in index] 77 | 78 | train_info = { 79 | 'rgb_index': rgb_index, 80 | 'depth_index': depth_index, 81 | 'scene_name': subset, 82 | "test_index": args.fl-1, 83 | } 84 | train_dict.append(train_info) 85 | 86 | # test_dict = [] 87 | # subset_list = os.listdir(test_dir) 88 | # for subset in subset_list: 89 | # subset_source_dir = os.path.join(test_dir, subset) 90 | # rgb_list = glob(subset_source_dir + '/rgb/rgb_*.jpg') 91 | # depth_list = glob(subset_source_dir + '/depth/depth_*.png') 92 | # rgb_list.sort() 93 | # depth_list.sort() 94 | 95 | # rgb_list_new = ['/'.join(rgb_info.split('/')[-5:]) for rgb_info in rgb_list] 96 | # depth_list_new = ['/'.join(depth_info.split('/')[-5:]) for depth_info in depth_list] 97 | 98 | # test_index = int(open(subset_source_dir + '/depth/frame_index.txt').read()) 99 | 100 | # test_info = { 101 | # 'rgb_index': rgb_list_new, 102 | # 'depth_index': depth_list_new, 103 | # 'scene_name': subset, 104 | # 'test_index': test_index 105 | # } 106 | # test_dict.append(test_info) 107 | 108 | list_save_dir = os.path.join(list_save_dir, dataset) 109 | make_if_not_exist(list_save_dir) 110 | train_info_save = list_save_dir + '/{}_fps{}_fl{}_op{}_{}_train.json'.format(dataset, args.fps, args.fl, args.overlap, args.test_loc) 111 | # test_info_save = list_save_dir + '/{}_fps{}_fl{}_op{}_{}_test.json'.format(dataset, args.fps, args.fl, args.overlap, args.test_loc) 112 | 113 | with open(train_info_save, 'w') as dst_file: 114 | json.dump(train_dict, dst_file) 115 | # with open(test_info_save, 'w') as dst_file: 116 | # json.dump(test_dict, dst_file) 117 | 118 | if __name__ == '__main__': 119 | 120 | create_dict(args.dataset, args.list_save_dir, args.jpg_png_save_dir) -------------------------------------------------------------------------------- /dataloader/__pycache__/nyu_transform.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyu_transform.cpython-35.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/nyudv2_dataloader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyudv2_dataloader.cpython-35.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/nyudv2_dataloader_224.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyudv2_dataloader_224.cpython-35.pyc -------------------------------------------------------------------------------- /dataloader/nyu_transform.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from PIL import Image, ImageOps 5 | import collections 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import random 11 | import scipy.ndimage as ndimage 12 | 13 | import pdb 14 | 15 | 16 | def _is_pil_image(img): 17 | if accimage is not None: 18 | return isinstance(img, (Image.Image, accimage.Image)) 19 | else: 20 | return isinstance(img, Image.Image) 21 | 22 | 23 | def _is_numpy_image(img): 24 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 25 | 26 | 27 | class RandomRotate(object): 28 | """Random rotation of the image from -angle to angle (in degrees) 29 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation 30 | angle: max angle of the rotation 31 | interpolation order: Default: 2 (bilinear) 32 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image. 33 | diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off. 34 | """ 35 | 36 | def __init__(self, angle, diff_angle=0, order=2, reshape=False): 37 | self.angle = angle 38 | self.reshape = reshape 39 | self.order = order 40 | 41 | def __call__(self, sample): 42 | image, depth = sample['image'], sample['depth'] 43 | 44 | applied_angle = random.uniform(-self.angle, self.angle) 45 | angle1 = applied_angle 46 | angle1_rad = angle1 * np.pi / 180 47 | 48 | image = ndimage.interpolation.rotate( 49 | image, angle1, reshape=self.reshape, order=self.order) 50 | depth = ndimage.interpolation.rotate( 51 | depth, angle1, reshape=self.reshape, order=self.order) 52 | 53 | image = Image.fromarray(image) 54 | depth = Image.fromarray(depth) 55 | 56 | return {'image': image, 'depth': depth} 57 | 58 | class RandomHorizontalFlip(object): 59 | 60 | def __call__(self, sample): 61 | image, depth = sample['image'], sample['depth'] 62 | 63 | if not _is_pil_image(image): 64 | raise TypeError( 65 | 'img should be PIL Image. Got {}'.format(type(img))) 66 | if not _is_pil_image(depth): 67 | raise TypeError( 68 | 'img should be PIL Image. Got {}'.format(type(depth))) 69 | 70 | if random.random() < 0.5: 71 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 72 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 73 | 74 | return {'image': image, 'depth': depth} 75 | 76 | 77 | class Scale(object): 78 | """ Rescales the inputs and target arrays to the given 'size'. 79 | 'size' will be the size of the smaller edge. 80 | For example, if height > width, then image will be 81 | rescaled to (size * height / width, size) 82 | size: size of the smaller edge 83 | interpolation order: Default: 2 (bilinear) 84 | """ 85 | 86 | def __init__(self, size): 87 | self.size = size 88 | 89 | def __call__(self, sample): 90 | image, depth = sample['image'], sample['depth'] 91 | 92 | image = self.changeScale(image, self.size) 93 | depth = self.changeScale(depth, self.size,Image.NEAREST) 94 | 95 | return {'image': image, 'depth': depth} 96 | 97 | def changeScale(self, img, size, interpolation=Image.BILINEAR): 98 | 99 | if not _is_pil_image(img): 100 | raise TypeError( 101 | 'img should be PIL Image. 
Got {}'.format(type(img))) 102 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 103 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 104 | 105 | if isinstance(size, int): 106 | w, h = img.size 107 | if (w <= h and w == size) or (h <= w and h == size): 108 | return img 109 | if w < h: 110 | ow = size 111 | oh = int(size * h / w) 112 | return img.resize((ow, oh), interpolation) 113 | else: 114 | oh = size 115 | ow = int(size * w / h) 116 | return img.resize((ow, oh), interpolation) 117 | else: 118 | return img.resize(size[::-1], interpolation) 119 | 120 | 121 | class CenterCrop(object): 122 | def __init__(self, size_image, size_depth): 123 | self.size_image = size_image 124 | self.size_depth = size_depth 125 | 126 | def __call__(self, sample): 127 | image, depth = sample['image'], sample['depth'] 128 | ### crop image and depth to (304, 228) 129 | image = self.centerCrop(image, self.size_image) 130 | depth = self.centerCrop(depth, self.size_image) 131 | ### resize depth to (152, 114) downsample 2 132 | ow, oh = self.size_depth 133 | depth = depth.resize((ow, oh)) 134 | 135 | return {'image': image, 'depth': depth} 136 | 137 | def centerCrop(self, image, size): 138 | 139 | w1, h1 = image.size 140 | 141 | tw, th = size 142 | 143 | if w1 == tw and h1 == th: 144 | return image 145 | ## (320-304) / 2. = 8 146 | ## (240-228) / 2. = 8 147 | x1 = int(round((w1 - tw) / 2.)) 148 | y1 = int(round((h1 - th) / 2.)) 149 | 150 | image = image.crop((x1, y1, tw + x1, th + y1)) 151 | 152 | return image 153 | 154 | 155 | class ToTensor(object): 156 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 157 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 158 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 159 | """ 160 | def __init__(self,is_test=False): 161 | self.is_test = is_test 162 | 163 | def __call__(self, sample): 164 | image, depth = sample['image'], sample['depth'] 165 | """ 166 | Args: 167 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 168 | Returns: 169 | Tensor: Converted image. 170 | """ 171 | # ground truth depth of training samples is stored in 8-bit while test samples are saved in 16 bit 172 | image = self.to_tensor(image) 173 | if self.is_test: 174 | depth = self.to_tensor(depth).float()/1000 175 | else: 176 | depth = self.to_tensor(depth).float()*10 177 | return {'image': image, 'depth': depth} 178 | 179 | def to_tensor(self, pic): 180 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 181 | raise TypeError( 182 | 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 183 | 184 | if isinstance(pic, np.ndarray): 185 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 186 | ## convert image to (0,1) 187 | return img.float().div(255) 188 | 189 | if accimage is not None and isinstance(pic, accimage.Image): 190 | nppic = np.zeros( 191 | [pic.channels, pic.height, pic.width], dtype=np.float32) 192 | pic.copyto(nppic) 193 | return torch.from_numpy(nppic) 194 | 195 | # handle PIL Image 196 | if pic.mode == 'I': 197 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 198 | elif pic.mode == 'I;16': 199 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 200 | else: 201 | img = torch.ByteTensor( 202 | torch.ByteStorage.from_buffer(pic.tobytes())) 203 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 204 | if pic.mode == 'YCbCr': 205 | nchannel = 3 206 | elif pic.mode == 'I;16': 207 | nchannel = 1 208 | else: 209 | nchannel = len(pic.mode) 210 | img = img.view(pic.size[1], pic.size[0], nchannel) 211 | # put it from HWC to CHW format 212 | # yikes, this transpose takes 80% of the loading time/CPU 213 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 214 | if isinstance(img, torch.ByteTensor): 215 | return img.float().div(255) 216 | else: 217 | return img 218 | 219 | 220 | class Lighting(object): 221 | 222 | def __init__(self, alphastd, eigval, eigvec): 223 | self.alphastd = alphastd 224 | self.eigval = eigval 225 | self.eigvec = eigvec 226 | 227 | def __call__(self, sample): 228 | image, depth = sample['image'], sample['depth'] 229 | if self.alphastd == 0: 230 | return image 231 | 232 | alpha = image.new().resize_(3).normal_(0, self.alphastd) 233 | rgb = self.eigvec.type_as(image).clone()\ 234 | .mul(alpha.view(1, 3).expand(3, 3))\ 235 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 236 | .sum(1).squeeze() 237 | 238 | image = image.add(rgb.view(3, 1, 1).expand_as(image)) 239 | 240 | return {'image': image, 'depth': depth} 241 | 242 | 243 | class Grayscale(object): 244 | 245 | def __call__(self, img): 246 | gs = img.clone() 247 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 248 | gs[1].copy_(gs[0]) 249 | gs[2].copy_(gs[0]) 250 | return gs 251 | 252 | 253 | class Saturation(object): 254 | 255 | def __init__(self, var): 256 | self.var = var 257 | 258 | def __call__(self, img): 259 | gs = Grayscale()(img) 260 | alpha = random.uniform(-self.var, self.var) 261 | return img.lerp(gs, alpha) 262 | 263 | 264 | class Brightness(object): 265 | 266 | def __init__(self, var): 267 | self.var = var 268 | 269 | def __call__(self, img): 270 | gs = img.new().resize_as_(img).zero_() 271 | alpha = random.uniform(-self.var, self.var) 272 | 273 | return img.lerp(gs, alpha) 274 | 275 | 276 | class Contrast(object): 277 | 278 | def __init__(self, var): 279 | self.var = var 280 | 281 | def __call__(self, img): 282 | gs = Grayscale()(img) 283 | gs.fill_(gs.mean()) 284 | alpha = random.uniform(-self.var, self.var) 285 | return img.lerp(gs, alpha) 286 | 287 | 288 | class RandomOrder(object): 289 | """ Composes several transforms together in random order. 
290 | """ 291 | 292 | def __init__(self, transforms): 293 | self.transforms = transforms 294 | 295 | def __call__(self, sample): 296 | image, depth = sample['image'], sample['depth'] 297 | 298 | if self.transforms is None: 299 | return {'image': image, 'depth': depth} 300 | order = torch.randperm(len(self.transforms)) 301 | for i in order: 302 | image = self.transforms[i](image) 303 | 304 | return {'image': image, 'depth': depth} 305 | 306 | 307 | class ColorJitter(RandomOrder): 308 | 309 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 310 | self.transforms = [] 311 | if brightness != 0: 312 | self.transforms.append(Brightness(brightness)) 313 | if contrast != 0: 314 | self.transforms.append(Contrast(contrast)) 315 | if saturation != 0: 316 | self.transforms.append(Saturation(saturation)) 317 | 318 | 319 | class Normalize(object): 320 | def __init__(self, mean, std): 321 | self.mean = mean 322 | self.std = std 323 | 324 | def __call__(self, sample): 325 | """ 326 | Args: 327 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 328 | Returns: 329 | Tensor: Normalized image. 330 | """ 331 | image, depth = sample['image'], sample['depth'] 332 | 333 | image = self.normalize(image, self.mean, self.std) 334 | 335 | return {'image': image, 'depth': depth} 336 | 337 | def normalize(self, tensor, mean, std): 338 | """Normalize a tensor image with mean and standard deviation. 339 | See ``Normalize`` for more details. 340 | Args: 341 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 342 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 343 | std (sequence): Sequence of standard deviations for R, G, B channels 344 | respecitvely. 345 | Returns: 346 | Tensor: Normalized image. 347 | """ 348 | 349 | # TODO: make efficient 350 | for t, m, s in zip(tensor, mean, std): 351 | t.sub_(m).div_(s) 352 | return tensor -------------------------------------------------------------------------------- /dataloader/nyudv2_dataloader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms, utils 5 | from PIL import Image 6 | import random 7 | import json 8 | import os 9 | from dataloader.nyu_transform import * 10 | 11 | try: 12 | import accimage 13 | except ImportError: 14 | accimage = None 15 | 16 | 17 | def load_annotation_data(dict_file_dir): 18 | with open(dict_file_dir, 'r') as data_file: 19 | return json.load(data_file) 20 | 21 | 22 | def pil_loader(path): 23 | return Image.open(path) 24 | 25 | 26 | def video_loader(root_dir, frame_indices): 27 | video = [] 28 | for index in frame_indices: 29 | image_path = os.path.join(root_dir, index) 30 | if os.path.exists(image_path): 31 | video.append(pil_loader(image_path)) 32 | else: 33 | return video 34 | 35 | return video 36 | 37 | 38 | class depthDataset(Dataset): 39 | """Face Landmarks dataset.""" 40 | 41 | def __init__(self, dict_dir, root_dir, transform=None, is_test=False): 42 | self.data_dict = load_annotation_data(dict_dir) 43 | self.root_dir = root_dir 44 | self.transform = transform 45 | 46 | def __getitem__(self, idx): 47 | rgb_index = self.data_dict[idx]['rgb_index'] 48 | depth_index = self.data_dict[idx]['depth_index'] 49 | test_index = self.data_dict[idx]['test_index'] 50 | 51 | rgb_clips = video_loader(self.root_dir, rgb_index) 52 | depth_clips = video_loader(self.root_dir, depth_index) 53 | 54 | rgb_tensor = [] 55 | 
depth_tensor = [] 56 | depth_scaled_tensor = [] 57 | for rgb_clip, depth_clip in zip(rgb_clips, depth_clips): 58 | sample = {'image': rgb_clip, 'depth': depth_clip} 59 | sample_new = self.transform(sample) 60 | rgb_tensor.append(sample_new['image']) 61 | depth_tensor.append(sample_new['depth']) 62 | 63 | return torch.stack(rgb_tensor, 0).permute(1, 0, 2, 3), \ 64 | torch.stack(depth_tensor, 0).permute(1, 0, 2, 3), \ 65 | test_index 66 | 67 | def __len__(self): 68 | return len(self.data_dict) 69 | 70 | def getTrainingData_NYUDV2(batch_size=64, dict_dir=None, root_dir=None): 71 | __imagenet_pca = { 72 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 73 | 'eigvec': torch.Tensor([ 74 | [-0.5675, 0.7192, 0.4009], 75 | [-0.5808, -0.0045, -0.8140], 76 | [-0.5836, -0.6948, 0.4203], 77 | ]) 78 | } 79 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 80 | 'std': [0.229, 0.224, 0.225]} 81 | transformed_training = depthDataset(dict_dir=dict_dir, 82 | root_dir = root_dir, 83 | transform=transforms.Compose([ 84 | Scale(240), 85 | RandomHorizontalFlip(), 86 | RandomRotate(5), 87 | CenterCrop([304, 228], [152, 114]), 88 | ToTensor(), 89 | Lighting(0.1, __imagenet_pca[ 90 | 'eigval'], __imagenet_pca['eigvec']), 91 | ColorJitter( 92 | brightness=0.4, 93 | contrast=0.4, 94 | saturation=0.4, 95 | ), 96 | Normalize(__imagenet_stats['mean'], 97 | __imagenet_stats['std']) 98 | ])) 99 | 100 | dataloader_training = DataLoader(transformed_training, batch_size, shuffle=True, num_workers=1, pin_memory=False) 101 | 102 | return dataloader_training 103 | 104 | def getTestingData_NYUDV2(batch_size=64, dict_dir=None, root_dir=None, num_workers=4): 105 | 106 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 107 | 'std': [0.229, 0.224, 0.225]} 108 | 109 | 110 | transformed_testing = depthDataset(dict_dir=dict_dir, 111 | root_dir=root_dir, 112 | transform=transforms.Compose([ 113 | Scale(240), 114 | CenterCrop([304, 228], [152, 114]), 115 | ToTensor(is_test=True), 116 | Normalize(__imagenet_stats['mean'], 117 | __imagenet_stats['std']) 118 | ])) 119 | 120 | dataloader_testing = DataLoader(transformed_testing, batch_size, shuffle=False, num_workers=num_workers, pin_memory=False) 121 | 122 | return dataloader_testing 123 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import os 5 | from collections import OrderedDict 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.nn.functional 9 | from utils import * 10 | from options import get_args 11 | from dataloader import nyudv2_dataloader 12 | from models.backbone_dict import backbone_dict 13 | from models import modules 14 | from models import net 15 | 16 | 17 | args = get_args('test') 18 | # lode nyud v2 test set 19 | TestImgLoader = nyudv2_dataloader.getTestingData_NYUDV2(args.batch_size, args.testlist_path, args.root_path) 20 | # model 21 | backbone = backbone_dict[args.backbone]() 22 | Encoder = modules.E_resnet(backbone) 23 | 24 | if args.backbone in ['resnet50']: 25 | model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048], refinenet=args.refinenet) 26 | elif args.backbone in ['resnet18', 'resnet34']: 27 | model = net.model(Encoder, num_features=512, block_channel=[64, 128, 256, 512], refinenet=args.refinenet) 28 | 29 | model = nn.DataParallel(model).cuda() 30 | 31 | # load test model 32 | if args.loadckpt is not None and 
args.loadckpt.endswith('.pth.tar'): 33 | print("loading the specific model in checkpoint_dir: {}".format(args.loadckpt)) 34 | state_dict = torch.load(args.loadckpt) 35 | model.load_state_dict(state_dict) 36 | elif os.path.isdir(args.loadckpt): 37 | all_saved_ckpts = [ckpt for ckpt in os.listdir(args.loadckpt) if ckpt.endswith(".pth.tar")] 38 | print(all_saved_ckpts) 39 | all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x:int(x.split('_')[-1].split('.')[0])) 40 | loadckpt = os.path.join(args.loadckpt, all_saved_ckpts[-1]) 41 | start_epoch = int(all_saved_ckpts[-1].split('_')[-1].split('.')[0]) 42 | print("loading the lastest model in checkpoint_dir: {}".format(loadckpt)) 43 | state_dict = torch.load(loadckpt) 44 | model.load_state_dict(state_dict) 45 | else: 46 | print("You have not loaded any models.") 47 | 48 | def test(): 49 | 50 | model.eval() 51 | with torch.no_grad(): 52 | for batch_idx, sample in enumerate(TestImgLoader): 53 | print("Processing the {}th image!".format(batch_idx)) 54 | image, depth = sample[0], sample[1] 55 | depth = depth.cuda() 56 | image = image.cuda() 57 | 58 | image = torch.autograd.Variable(image) 59 | depth = torch.autograd.Variable(depth) 60 | 61 | start = time.time() 62 | pred = model(image) 63 | end = time.time() 64 | running_time = end - start 65 | 66 | print(pred.size()) 67 | print(depth.size()) 68 | 69 | pred_ = np.squeeze(pred.data.cpu().numpy()) 70 | depth_ = np.squeeze(depth.cpu().numpy()) 71 | 72 | print(np.shape(pred_)) 73 | print(np.shape(depth_)) 74 | 75 | for seq_idx in range(len(pred_)): 76 | print(seq_idx) 77 | print(np.shape(depth_[0:])) 78 | 79 | depth = depth_[seq_idx] 80 | pred = pred_[seq_idx] 81 | 82 | d_min = min(np.min(depth), np.min(pred)) 83 | d_max = max(np.max(depth), np.max(pred)) 84 | # depth = colored_depthmap(depth, d_min, d_max) 85 | # pred = colored_depthmap(pred, d_min, d_max) 86 | depth = colored_depthmap(depth) 87 | pred = colored_depthmap(pred) 88 | 89 | print(d_min) 90 | print(d_max) 91 | 92 | filename = os.path.join('./samples/depth_' + str(seq_idx) + '.png') 93 | save_image(depth, filename) 94 | 95 | filename = os.path.join('./samples/pred_' + str(seq_idx) + '.png') 96 | save_image(pred, filename) 97 | 98 | # if metrics_s is not None: 99 | # metrics_s(torch.stack(pred_new, 0).cpu(), torch.stack(depth_new, 0)) 100 | 101 | # result_s = metrics_s.loss_get() 102 | # print(result_s) 103 | 104 | 105 | if __name__ == '__main__': 106 | test() -------------------------------------------------------------------------------- /models/R_CLSTM_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | import time 5 | 6 | def maps_2_cubes(x, b, d): 7 | x_b, x_c, x_h, x_w = x.shape 8 | x = x.contiguous().view(b, d, x_c, x_h, x_w) 9 | 10 | return x.permute(0, 2, 1, 3, 4) 11 | 12 | 13 | def maps_2_maps(x, b, d): 14 | x_b, x_c, x_h, x_w = x.shape 15 | x = x.contiguous().view(b, d * x_c, x_h, x_w) 16 | 17 | return x 18 | 19 | 20 | class R(nn.Module): 21 | def __init__(self, block_channel): 22 | super(R, self).__init__() 23 | 24 | num_features = 64 + block_channel[3] // 32 25 | self.conv0 = nn.Conv2d(num_features, num_features, 26 | kernel_size=5, stride=1, padding=2, bias=False) 27 | self.bn0 = nn.BatchNorm2d(num_features) 28 | 29 | self.conv1 = nn.Conv2d(num_features, num_features, 30 | kernel_size=5, stride=1, padding=2, bias=False) 31 | self.bn1 = nn.BatchNorm2d(num_features) 32 | 33 | self.conv2 = nn.Conv2d( 34 | 
num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 35 | 36 | def forward(self, x): 37 | x0 = self.conv0(x) 38 | x0 = self.bn0(x0) 39 | x0 = F.relu(x0) 40 | 41 | x1 = self.conv1(x0) 42 | x1 = self.bn1(x1) 43 | h = F.relu(x1) 44 | 45 | pred_depth = self.conv2(h) 46 | 47 | return h, pred_depth 48 | 49 | 50 | class R_2(nn.Module): 51 | def __init__(self, block_channel): 52 | super(R_2, self).__init__() 53 | 54 | num_features = 64 + block_channel[3] // 32 + 4 55 | self.conv0 = nn.Conv2d(num_features, num_features, 56 | kernel_size=5, stride=1, padding=2, bias=False) 57 | self.bn0 = nn.BatchNorm2d(num_features) 58 | 59 | self.conv1 = nn.Conv2d(num_features, num_features, 60 | kernel_size=5, stride=1, padding=2, bias=False) 61 | self.bn1 = nn.BatchNorm2d(num_features) 62 | 63 | self.conv2 = nn.Conv2d( 64 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 65 | 66 | self.convh = nn.Conv2d( 67 | num_features, 4, kernel_size=3, stride=1, padding=1, bias=True) 68 | 69 | def forward(self, x): 70 | x0 = self.conv0(x) 71 | x0 = self.bn0(x0) 72 | x0 = F.relu(x0) 73 | 74 | x1 = self.conv1(x0) 75 | x1 = self.bn1(x1) 76 | x1 = F.relu(x1) 77 | 78 | h = self.convh(x1) 79 | pred_depth = self.conv2(x1) 80 | 81 | return h, pred_depth 82 | 83 | 84 | class R_d(nn.Module): 85 | def __init__(self, block_channel): 86 | super(R_d, self).__init__() 87 | 88 | num_features = 64 + block_channel[3] // 32 + 4 89 | self.conv0 = nn.Conv2d(num_features, num_features, 90 | kernel_size=5, stride=1, padding=2, bias=False) 91 | self.bn0 = nn.BatchNorm2d(num_features) 92 | 93 | self.conv1 = nn.Conv2d(num_features, num_features, 94 | kernel_size=5, stride=1, padding=2, bias=False) 95 | self.bn1 = nn.BatchNorm2d(num_features) 96 | 97 | self.dropout = nn.Dropout2d(p=0.5) 98 | 99 | self.conv2 = nn.Conv2d( 100 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 101 | 102 | self.convh = nn.Conv2d( 103 | num_features, 4, kernel_size=3, stride=1, padding=1, bias=True) 104 | 105 | def forward(self, x): 106 | x0 = self.conv0(x) 107 | x0 = self.bn0(x0) 108 | x0 = F.relu(x0) 109 | 110 | x1 = self.conv1(x0) 111 | x1 = self.bn1(x1) 112 | x1 = F.relu(x1) 113 | 114 | h = self.convh(x1) 115 | x1 = self.dropout(x1) 116 | pred_depth = self.conv2(x1) 117 | 118 | return h, pred_depth 119 | 120 | 121 | class R_3(nn.Module): 122 | def __init__(self, block_channel): 123 | super(R_3, self).__init__() 124 | 125 | num_features = 64 + block_channel[3] // 32 + 8 126 | self.conv0 = nn.Conv2d(num_features, num_features, 127 | kernel_size=5, stride=1, padding=2, bias=False) 128 | self.bn0 = nn.BatchNorm2d(num_features) 129 | 130 | self.conv1 = nn.Conv2d(num_features, num_features, 131 | kernel_size=5, stride=1, padding=2, bias=False) 132 | self.bn1 = nn.BatchNorm2d(num_features) 133 | 134 | self.conv2 = nn.Conv2d( 135 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 136 | 137 | self.convh = nn.Conv2d( 138 | num_features, 8, kernel_size=3, stride=1, padding=1, bias=True) 139 | 140 | def forward(self, x): 141 | x0 = self.conv0(x) 142 | x0 = self.bn0(x0) 143 | x0 = F.relu(x0) 144 | 145 | x1 = self.conv1(x0) 146 | x1 = self.bn1(x1) 147 | x1 = F.relu(x1) 148 | 149 | h = self.convh(x1) 150 | pred_depth = self.conv2(x1) 151 | 152 | return h, pred_depth 153 | 154 | 155 | class R_10(nn.Module): 156 | def __init__(self, block_channel): 157 | super(R_10, self).__init__() 158 | 159 | num_features = 64 + block_channel[3] // 32 + 8 160 | self.conv0 = nn.Conv2d(num_features, num_features, 161 | kernel_size=5, stride=1, 
padding=2, bias=False) 162 | self.bn0 = nn.BatchNorm2d(num_features) 163 | 164 | self.conv1 = nn.Conv2d(num_features, num_features, 165 | kernel_size=5, stride=1, padding=2, bias=False) 166 | self.bn1 = nn.BatchNorm2d(num_features) 167 | 168 | self.conv2 = nn.Conv2d( 169 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 170 | 171 | self.convh = nn.Conv2d( 172 | num_features, 8, kernel_size=3, stride=1, padding=1, bias=True) 173 | 174 | def forward(self, x): 175 | x0 = self.conv0(x) 176 | x0 = self.bn0(x0) 177 | x0 = F.relu(x0) 178 | 179 | x1 = self.conv1(x0) 180 | x1 = self.bn1(x1) 181 | x1 = F.relu(x1) 182 | 183 | h = self.convh(x1) 184 | pred_depth = F.tanh(self.conv2(x1)) 185 | 186 | return h, pred_depth 187 | 188 | 189 | class R_CLSTM_1(nn.Module): 190 | def __init__(self, block_channel): 191 | super(R_CLSTM_1, self).__init__() 192 | num_features = 64 + block_channel[3] // 32 193 | self.Refine = R(block_channel) 194 | self.F_t = nn.Sequential( 195 | nn.Conv2d(in_channels=num_features+num_features, 196 | out_channels=num_features, 197 | kernel_size=3, 198 | padding=1, 199 | ), 200 | nn.Sigmoid() 201 | ) 202 | self.I_t = nn.Sequential( 203 | nn.Conv2d(in_channels=num_features + num_features, 204 | out_channels=num_features, 205 | kernel_size=3, 206 | padding=1, 207 | ), 208 | nn.Sigmoid() 209 | ) 210 | self.C_t = nn.Sequential( 211 | nn.Conv2d(in_channels=num_features + num_features, 212 | out_channels=num_features, 213 | kernel_size=3, 214 | padding=1, 215 | ), 216 | nn.Tanh() 217 | ) 218 | self.Q_t = nn.Sequential( 219 | nn.Conv2d(in_channels=num_features + num_features, 220 | out_channels=num_features, 221 | kernel_size=3, 222 | padding=1, 223 | ), 224 | nn.Sigmoid() 225 | ) 226 | 227 | def forward(self, input_tensor, b, d): 228 | input_tensor = maps_2_cubes(input_tensor, b, d) 229 | b, c, d, h, w = input_tensor.shape 230 | h_state_init = torch.zeros(b, c, h, w).to('cuda') 231 | c_state_init = torch.zeros(b, c, h, w).to('cuda') 232 | 233 | seq_len = d 234 | 235 | h_state, c_state = h_state_init, c_state_init 236 | output_inner = [] 237 | for t in range(seq_len): 238 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1) 239 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 240 | 241 | h_state, p_depth = self.Refine(c_state * self.Q_t(input_cat)) 242 | 243 | output_inner.append(p_depth) 244 | 245 | layer_output = torch.stack(output_inner, dim=2) 246 | 247 | return layer_output 248 | 249 | 250 | class R_CLSTM_2(nn.Module): 251 | def __init__(self, block_channel): 252 | super(R_CLSTM_2, self).__init__() 253 | num_features = 64 + block_channel[3] // 32 254 | self.Refine = R(block_channel) 255 | self.F_t = nn.Sequential( 256 | nn.Conv2d(in_channels=num_features + num_features, 257 | out_channels=num_features // 2, 258 | kernel_size=3, 259 | padding=1, 260 | ), 261 | nn.Sigmoid() 262 | ) 263 | self.I_t = nn.Sequential( 264 | nn.Conv2d(in_channels=num_features + num_features, 265 | out_channels=num_features // 2, 266 | kernel_size=3, 267 | padding=1, 268 | ), 269 | nn.Sigmoid() 270 | ) 271 | self.C_t = nn.Sequential( 272 | nn.Conv2d(in_channels=num_features + num_features, 273 | out_channels=num_features // 2, 274 | kernel_size=3, 275 | padding=1, 276 | ), 277 | nn.Tanh() 278 | ) 279 | self.Q_t = nn.Sequential( 280 | nn.Conv2d(in_channels=num_features + num_features, 281 | out_channels=num_features // 2, 282 | kernel_size=3, 283 | padding=1, 284 | ), 285 | nn.Sigmoid() 286 | ) 287 | 288 | def forward(self, input_tensor, b, d): 
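        # Fold the stacked per-frame feature maps (b*d, c, h, w) back into a clip tensor (b, c, d, h, w), then unroll the ConvLSTM over the d frames: the F_t/I_t/C_t gates update the cell state, and the refinement block consumes the cell state together with the Q_t output to produce the next hidden state and the current frame's depth map.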
289 | input_tensor = maps_2_cubes(input_tensor, b, d) 290 | b, c, d, h, w = input_tensor.shape 291 | h_state_init = torch.zeros(b, c, h, w).to('cuda') 292 | c_state_init = torch.zeros(b, c // 2, h, w).to('cuda') 293 | 294 | seq_len = d 295 | 296 | h_state, c_state = h_state_init, c_state_init 297 | output_inner = [] 298 | for t in range(seq_len): 299 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1) 300 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 301 | 302 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 303 | 304 | output_inner.append(p_depth) 305 | 306 | layer_output = torch.stack(output_inner, dim=2) 307 | 308 | return layer_output 309 | 310 | 311 | class R_CLSTM_3(nn.Module): 312 | def __init__(self, block_channel): 313 | super(R_CLSTM_3, self).__init__() 314 | num_features = 64 + block_channel[3] // 32 315 | self.Refine = R_2(block_channel) 316 | self.F_t = nn.Sequential( 317 | nn.Conv2d(in_channels=num_features+4, 318 | out_channels=4, 319 | kernel_size=3, 320 | padding=1, 321 | ), 322 | nn.Sigmoid() 323 | ) 324 | self.I_t = nn.Sequential( 325 | nn.Conv2d(in_channels=num_features + 4, 326 | out_channels=4, 327 | kernel_size=3, 328 | padding=1, 329 | ), 330 | nn.Sigmoid() 331 | ) 332 | self.C_t = nn.Sequential( 333 | nn.Conv2d(in_channels=num_features + 4, 334 | out_channels=4, 335 | kernel_size=3, 336 | padding=1, 337 | ), 338 | nn.Tanh() 339 | ) 340 | self.Q_t = nn.Sequential( 341 | nn.Conv2d(in_channels=num_features + 4, 342 | out_channels=num_features, 343 | kernel_size=3, 344 | padding=1, 345 | ), 346 | nn.Sigmoid() 347 | ) 348 | 349 | def forward(self, input_tensor, b, d): 350 | input_tensor = maps_2_cubes(input_tensor, b, d) 351 | b, c, d, h, w = input_tensor.shape 352 | h_state_init = torch.zeros(b, 4, h, w).to('cuda') 353 | c_state_init = torch.zeros(b, 4, h, w).to('cuda') 354 | 355 | seq_len = d 356 | 357 | h_state, c_state = h_state_init, c_state_init 358 | output_inner = [] 359 | for t in range(seq_len): 360 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1) 361 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 362 | 363 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 364 | 365 | output_inner.append(p_depth) 366 | 367 | layer_output = torch.stack(output_inner, dim=2) 368 | 369 | return layer_output 370 | 371 | 372 | class R_CLSTM_4(nn.Module): 373 | def __init__(self, block_channel): 374 | super(R_CLSTM_4, self).__init__() 375 | num_features = 64 + block_channel[3] // 32 376 | self.Refine = R_d(block_channel) 377 | self.F_t = nn.Sequential( 378 | nn.Conv2d(in_channels=num_features+4, 379 | out_channels=4, 380 | kernel_size=3, 381 | padding=1, 382 | ), 383 | nn.Sigmoid() 384 | ) 385 | self.I_t = nn.Sequential( 386 | nn.Conv2d(in_channels=num_features + 4, 387 | out_channels=4, 388 | kernel_size=3, 389 | padding=1, 390 | ), 391 | nn.Sigmoid() 392 | ) 393 | self.C_t = nn.Sequential( 394 | nn.Conv2d(in_channels=num_features + 4, 395 | out_channels=4, 396 | kernel_size=3, 397 | padding=1, 398 | ), 399 | nn.Tanh() 400 | ) 401 | self.Q_t = nn.Sequential( 402 | nn.Conv2d(in_channels=num_features + 4, 403 | out_channels=num_features, 404 | kernel_size=3, 405 | padding=1, 406 | ), 407 | nn.Sigmoid() 408 | ) 409 | 410 | def forward(self, input_tensor, b, d): 411 | input_tensor = maps_2_cubes(input_tensor, b, d) 412 | b, c, d, h, w = input_tensor.shape 413 | h_state_init = torch.zeros(b, 4, h, w).to('cuda') 414 
| c_state_init = torch.zeros(b, 4, h, w).to('cuda') 415 | 416 | seq_len = d 417 | 418 | h_state, c_state = h_state_init, c_state_init 419 | output_inner = [] 420 | for t in range(seq_len): 421 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1) 422 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 423 | 424 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 425 | 426 | output_inner.append(p_depth) 427 | 428 | layer_output = torch.stack(output_inner, dim=2) 429 | 430 | return layer_output 431 | 432 | 433 | class R_CLSTM_5(nn.Module): 434 | def __init__(self, block_channel): 435 | super(R_CLSTM_5, self).__init__() 436 | num_features = 64 + block_channel[3] // 32 437 | self.Refine = R_3(block_channel) 438 | self.F_t = nn.Sequential( 439 | nn.Conv2d(in_channels=num_features+8, 440 | out_channels=8, 441 | kernel_size=3, 442 | padding=1, 443 | ), 444 | nn.Sigmoid() 445 | ) 446 | self.I_t = nn.Sequential( 447 | nn.Conv2d(in_channels=num_features + 8, 448 | out_channels=8, 449 | kernel_size=3, 450 | padding=1, 451 | ), 452 | nn.Sigmoid() 453 | ) 454 | self.C_t = nn.Sequential( 455 | nn.Conv2d(in_channels=num_features + 8, 456 | out_channels=8, 457 | kernel_size=3, 458 | padding=1, 459 | ), 460 | nn.Tanh() 461 | ) 462 | self.Q_t = nn.Sequential( 463 | nn.Conv2d(in_channels=num_features + 8, 464 | out_channels=num_features, 465 | kernel_size=3, 466 | padding=1, 467 | ), 468 | nn.Sigmoid() 469 | ) 470 | 471 | def forward(self, input_tensor, b, d): 472 | input_tensor = maps_2_cubes(input_tensor, b, d) 473 | b, c, d, h, w = input_tensor.shape 474 | h_state_init = torch.zeros(b, 8, h, w).to('cuda') 475 | c_state_init = torch.zeros(b, 8, h, w).to('cuda') 476 | 477 | seq_len = d 478 | 479 | h_state, c_state = h_state_init, c_state_init 480 | output_inner = [] 481 | start = time.time() 482 | for t in range(seq_len): 483 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1) 484 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 485 | 486 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 487 | 488 | output_inner.append(p_depth) 489 | 490 | layer_output = torch.stack(output_inner, dim=2) 491 | torch.cuda.synchronize() 492 | # print(time.time()-start) 493 | 494 | return layer_output 495 | 496 | 497 | class R_cell(nn.Module): 498 | def __init__(self, block_channel, cell_width=16): 499 | super(R_cell, self).__init__() 500 | 501 | num_features = 64 + block_channel[3] // 32 502 | self.conv0 = nn.Conv2d(num_features + cell_width, num_features, 503 | kernel_size=5, stride=1, padding=2, bias=False) 504 | self.bn0 = nn.BatchNorm2d(num_features) 505 | 506 | self.conv1 = nn.Conv2d(num_features, num_features, 507 | kernel_size=5, stride=1, padding=2, bias=False) 508 | self.bn1 = nn.BatchNorm2d(num_features) 509 | 510 | self.conv2 = nn.Conv2d( 511 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 512 | 513 | self.convh = nn.Conv2d( 514 | num_features, cell_width, kernel_size=3, stride=1, padding=1, bias=True) 515 | 516 | def forward(self, x): 517 | x0 = self.conv0(x) 518 | x0 = self.bn0(x0) 519 | x0 = F.relu(x0) 520 | 521 | x1 = self.conv1(x0) 522 | x1 = self.bn1(x1) 523 | x1 = F.relu(x1) 524 | 525 | h = self.convh(x1) 526 | pred_depth = self.conv2(x1) 527 | 528 | return h, pred_depth 529 | 530 | 531 | class R_CLSTM_6(nn.Module): 532 | def __init__(self, block_channel): 533 | super(R_CLSTM_6, self).__init__() 534 | num_features = 64 + block_channel[3] 
// 32 535 | self.cell_width = 8 536 | self.Refine = R_cell(block_channel, self.cell_width) 537 | self.F_t = nn.Sequential( 538 | nn.Conv2d(in_channels=num_features + self.cell_width, 539 | out_channels=self.cell_width, 540 | kernel_size=3, 541 | padding=1, 542 | ), 543 | nn.Sigmoid() 544 | ) 545 | self.I_t = nn.Sequential( 546 | nn.Conv2d(in_channels=num_features + self.cell_width, 547 | out_channels=self.cell_width, 548 | kernel_size=3, 549 | padding=1, 550 | ), 551 | nn.Sigmoid() 552 | ) 553 | self.C_t = nn.Sequential( 554 | nn.Conv2d(in_channels=num_features + self.cell_width, 555 | out_channels=self.cell_width, 556 | kernel_size=3, 557 | padding=1, 558 | ), 559 | nn.Tanh() 560 | ) 561 | self.Q_t = nn.Sequential( 562 | nn.Conv2d(in_channels=num_features + self.cell_width, 563 | out_channels=num_features, 564 | kernel_size=3, 565 | padding=1, 566 | ), 567 | nn.Sigmoid() 568 | ) 569 | 570 | def forward(self, input_tensor, b, d): 571 | input_tensor = maps_2_cubes(input_tensor, b, d) 572 | b, c, d, h, w = input_tensor.shape 573 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 574 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 575 | 576 | seq_len = d 577 | 578 | h_state, c_state = h_state_init, c_state_init 579 | output_inner = [] 580 | for t in range(seq_len): 581 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1) 582 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 583 | 584 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 585 | 586 | output_inner.append(p_depth) 587 | 588 | layer_output = torch.stack(output_inner, dim=2) 589 | 590 | return layer_output 591 | 592 | 593 | class R_CLSTM_7(nn.Module): 594 | def __init__(self, block_channel): 595 | super(R_CLSTM_7, self).__init__() 596 | num_features = 64 + block_channel[3] // 32 597 | self.cell_width = 12 598 | self.Refine = R_cell(block_channel, self.cell_width) 599 | self.F_t = nn.Sequential( 600 | nn.Conv2d(in_channels=num_features + self.cell_width, 601 | out_channels=self.cell_width, 602 | kernel_size=3, 603 | padding=1, 604 | ), 605 | nn.Sigmoid() 606 | ) 607 | self.I_t = nn.Sequential( 608 | nn.Conv2d(in_channels=num_features + self.cell_width, 609 | out_channels=self.cell_width, 610 | kernel_size=3, 611 | padding=1, 612 | ), 613 | nn.Sigmoid() 614 | ) 615 | self.C_t = nn.Sequential( 616 | nn.Conv2d(in_channels=num_features + self.cell_width, 617 | out_channels=self.cell_width, 618 | kernel_size=3, 619 | padding=1, 620 | ), 621 | nn.Tanh() 622 | ) 623 | self.Q_t = nn.Sequential( 624 | nn.Conv2d(in_channels=num_features + self.cell_width, 625 | out_channels=num_features, 626 | kernel_size=3, 627 | padding=1, 628 | ), 629 | nn.Sigmoid() 630 | ) 631 | 632 | def forward(self, input_tensor, b, d): 633 | input_tensor = maps_2_cubes(input_tensor, b, d) 634 | b, c, d, h, w = input_tensor.shape 635 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 636 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 637 | 638 | seq_len = d 639 | 640 | h_state, c_state = h_state_init, c_state_init 641 | output_inner = [] 642 | for t in range(seq_len): 643 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1) 644 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 645 | 646 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 647 | 648 | output_inner.append(p_depth) 649 | 650 | layer_output = torch.stack(output_inner, dim=2) 651 | 652 
| return layer_output 653 | 654 | 655 | class R_CLSTM_8(nn.Module): 656 | def __init__(self, block_channel): 657 | super(R_CLSTM_8, self).__init__() 658 | num_features = 64 + block_channel[3] // 32 659 | self.cell_width = 16 660 | self.Refine = R_cell(block_channel, self.cell_width) 661 | self.F_t = nn.Sequential( 662 | nn.Conv2d(in_channels=num_features + self.cell_width, 663 | out_channels=self.cell_width, 664 | kernel_size=3, 665 | padding=1, 666 | ), 667 | nn.Sigmoid() 668 | ) 669 | self.I_t = nn.Sequential( 670 | nn.Conv2d(in_channels=num_features + self.cell_width, 671 | out_channels=self.cell_width, 672 | kernel_size=3, 673 | padding=1, 674 | ), 675 | nn.Sigmoid() 676 | ) 677 | self.C_t = nn.Sequential( 678 | nn.Conv2d(in_channels=num_features + self.cell_width, 679 | out_channels=self.cell_width, 680 | kernel_size=3, 681 | padding=1, 682 | ), 683 | nn.Tanh() 684 | ) 685 | self.Q_t = nn.Sequential( 686 | nn.Conv2d(in_channels=num_features + self.cell_width, 687 | out_channels=num_features, 688 | kernel_size=3, 689 | padding=1, 690 | ), 691 | nn.Sigmoid() 692 | ) 693 | 694 | def forward(self, input_tensor, b, d): 695 | input_tensor = maps_2_cubes(input_tensor, b, d) 696 | b, c, d, h, w = input_tensor.shape 697 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 698 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 699 | 700 | seq_len = d 701 | 702 | h_state, c_state = h_state_init, c_state_init 703 | output_inner = [] 704 | for t in range(seq_len): 705 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1) 706 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 707 | 708 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 709 | 710 | output_inner.append(p_depth) 711 | 712 | layer_output = torch.stack(output_inner, dim=2) 713 | 714 | return layer_output 715 | 716 | 717 | class R_CLSTM_9(nn.Module): 718 | def __init__(self, block_channel): 719 | super(R_CLSTM_9, self).__init__() 720 | num_features = 64 + block_channel[3] // 32 721 | self.cell_width = 32 722 | self.Refine = R_cell(block_channel, self.cell_width) 723 | self.F_t = nn.Sequential( 724 | nn.Conv2d(in_channels=num_features + self.cell_width, 725 | out_channels=self.cell_width, 726 | kernel_size=3, 727 | padding=1, 728 | ), 729 | nn.Sigmoid() 730 | ) 731 | self.I_t = nn.Sequential( 732 | nn.Conv2d(in_channels=num_features + self.cell_width, 733 | out_channels=self.cell_width, 734 | kernel_size=3, 735 | padding=1, 736 | ), 737 | nn.Sigmoid() 738 | ) 739 | self.C_t = nn.Sequential( 740 | nn.Conv2d(in_channels=num_features + self.cell_width, 741 | out_channels=self.cell_width, 742 | kernel_size=3, 743 | padding=1, 744 | ), 745 | nn.Tanh() 746 | ) 747 | self.Q_t = nn.Sequential( 748 | nn.Conv2d(in_channels=num_features + self.cell_width, 749 | out_channels=num_features, 750 | kernel_size=3, 751 | padding=1, 752 | ), 753 | nn.Sigmoid() 754 | ) 755 | 756 | def forward(self, input_tensor, b, d): 757 | input_tensor = maps_2_cubes(input_tensor, b, d) 758 | b, c, d, h, w = input_tensor.shape 759 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 760 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda') 761 | 762 | seq_len = d 763 | 764 | h_state, c_state = h_state_init, c_state_init 765 | output_inner = [] 766 | for t in range(seq_len): 767 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1) 768 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 769 | 
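            # Concatenate the updated cell state with the Q_t gate output and run the refinement head, which returns the next hidden state and the depth prediction for this frame.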
770 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 771 | 772 | output_inner.append(p_depth) 773 | 774 | layer_output = torch.stack(output_inner, dim=2) 775 | 776 | return layer_output 777 | 778 | 779 | class R_CLSTM_10(nn.Module): 780 | def __init__(self, block_channel): 781 | super(R_CLSTM_10, self).__init__() 782 | num_features = 64 + block_channel[3] // 32 783 | self.Refine = R_10(block_channel) 784 | self.F_t = nn.Sequential( 785 | nn.Conv2d(in_channels=num_features+8, 786 | out_channels=8, 787 | kernel_size=3, 788 | padding=1, 789 | ), 790 | nn.Sigmoid() 791 | ) 792 | self.I_t = nn.Sequential( 793 | nn.Conv2d(in_channels=num_features + 8, 794 | out_channels=8, 795 | kernel_size=3, 796 | padding=1, 797 | ), 798 | nn.Sigmoid() 799 | ) 800 | self.C_t = nn.Sequential( 801 | nn.Conv2d(in_channels=num_features + 8, 802 | out_channels=8, 803 | kernel_size=3, 804 | padding=1, 805 | ), 806 | nn.Tanh() 807 | ) 808 | self.Q_t = nn.Sequential( 809 | nn.Conv2d(in_channels=num_features + 8, 810 | out_channels=num_features, 811 | kernel_size=3, 812 | padding=1, 813 | ), 814 | nn.Sigmoid() 815 | ) 816 | 817 | def forward(self, input_tensor, b, d): 818 | input_tensor = maps_2_cubes(input_tensor, b, d) 819 | b, c, d, h, w = input_tensor.shape 820 | h_state_init = torch.zeros(b, 8, h, w).to('cuda') 821 | c_state_init = torch.zeros(b, 8, h, w).to('cuda') 822 | 823 | seq_len = d 824 | 825 | h_state, c_state = h_state_init, c_state_init 826 | output_inner = [] 827 | for t in range(seq_len): 828 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1) 829 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat) 830 | 831 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1)) 832 | 833 | output_inner.append(p_depth) 834 | 835 | layer_output = torch.stack(output_inner, dim=2) 836 | 837 | return layer_output 838 | 839 | 840 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | from models.modules import E_resnet, E_densenet, E_senet 5 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 6 | from models.densenet import densenet161, densenet121, densenet169, densenet201 7 | from models.senet import senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d 8 | import pdb 9 | 10 | __models_small__ = { 11 | 'ResNet18': lambda :E_resnet(resnet18(pretrained = True)), 12 | 'ResNet34': lambda :E_resnet(resnet34(pretrained = True)), 13 | 'ResNet50': lambda :E_resnet(resnet50(pretrained = True)), 14 | 'ResNet101': lambda :E_resnet(resnet101(pretrained = True)), 15 | 'ResNet152': lambda :E_resnet(resnet152(pretrained = True)), 16 | 'DenseNet121': lambda :E_densenet(densenet121(pretrained = True)), 17 | 'DenseNet161': lambda :E_densenet(densenet161(pretrained = True)), 18 | 'DenseNet169': lambda :E_densenet(densenet169(pretrained = True)), 19 | 'DenseNet201': lambda :E_densenet(densenet201(pretrained = True)), 20 | 'SENet154': lambda :E_senet(senet154(pretrained="imagenet")), 21 | 'SE_ResNet50': lambda :E_senet(se_resnet50(pretrained="imagenet")), 22 | 'SE_ResNet101': lambda :E_senet(se_resnet101(pretrained="imagenet")), 23 | 'SE_ResNet152': lambda :E_senet(se_resnet152(pretrained="imagenet")), 24 | 'SE_ResNext50_32x4d': lambda :E_senet(se_resnext50_32x4d(pretrained="imagenet")), 25 | 
'SE_ResNext101_32x4d': lambda :E_senet(se_resnext101_32x4d(pretrained="imagenet")) 26 | } 27 | 28 | 29 | def get_models(args): 30 | backbone = args.backbone 31 | 32 | if os.getenv('TORCH_MODEL_ZOO') != args.pretrained_dir: 33 | os.environ['TORCH_MODEL_ZOO'] = args.pretrained_dir 34 | else: 35 | pass 36 | 37 | return __models_small__[backbone]() 38 | 39 | -------------------------------------------------------------------------------- /models/backbone_dict.py: -------------------------------------------------------------------------------- 1 | from models.resnet import resnet18, resnet34, resnet50 2 | 3 | 4 | backbone_dict = { 5 | 'resnet18': resnet18, 6 | 'resnet34': resnet34, 7 | 'resnet50': resnet50, 8 | } -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | 9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 10 | 11 | 12 | model_urls = { 13 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 17 | } 18 | 19 | 20 | def densenet121(pretrained=False, **kwargs): 21 | r"""Densenet-121 model from 22 | `"Densely Connected Convolutional Networks" `_ 23 | 24 | Args: 25 | pretrained (bool): If True, returns a model pre-trained on ImageNet 26 | """ 27 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 28 | **kwargs) 29 | if pretrained: 30 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 31 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 32 | # They are also in the checkpoints in model_urls. This pattern is used 33 | # to find such keys. 34 | pattern = re.compile( 35 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 36 | state_dict = model_zoo.load_url(model_urls['densenet121']) 37 | for key in list(state_dict.keys()): 38 | res = pattern.match(key) 39 | if res: 40 | new_key = res.group(1) + res.group(2) 41 | state_dict[new_key] = state_dict[key] 42 | del state_dict[key] 43 | model.load_state_dict(state_dict) 44 | return model 45 | 46 | 47 | def densenet169(pretrained=False, **kwargs): 48 | r"""Densenet-169 model from 49 | `"Densely Connected Convolutional Networks" `_ 50 | 51 | Args: 52 | pretrained (bool): If True, returns a model pre-trained on ImageNet 53 | """ 54 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 55 | **kwargs) 56 | if pretrained: 57 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 58 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 59 | # They are also in the checkpoints in model_urls. This pattern is used 60 | # to find such keys. 
61 | pattern = re.compile( 62 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 63 | state_dict = model_zoo.load_url(model_urls['densenet169']) 64 | for key in list(state_dict.keys()): 65 | res = pattern.match(key) 66 | if res: 67 | new_key = res.group(1) + res.group(2) 68 | state_dict[new_key] = state_dict[key] 69 | del state_dict[key] 70 | model.load_state_dict(state_dict) 71 | return model 72 | 73 | 74 | def densenet201(pretrained=False, **kwargs): 75 | r"""Densenet-201 model from 76 | `"Densely Connected Convolutional Networks" `_ 77 | 78 | Args: 79 | pretrained (bool): If True, returns a model pre-trained on ImageNet 80 | """ 81 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 82 | **kwargs) 83 | if pretrained: 84 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 85 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 86 | # They are also in the checkpoints in model_urls. This pattern is used 87 | # to find such keys. 88 | pattern = re.compile( 89 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 90 | state_dict = model_zoo.load_url(model_urls['densenet201']) 91 | for key in list(state_dict.keys()): 92 | res = pattern.match(key) 93 | if res: 94 | new_key = res.group(1) + res.group(2) 95 | state_dict[new_key] = state_dict[key] 96 | del state_dict[key] 97 | model.load_state_dict(state_dict) 98 | return model 99 | 100 | 101 | def densenet161(pretrained=False, **kwargs): 102 | r"""Densenet-161 model from 103 | `"Densely Connected Convolutional Networks" `_ 104 | 105 | Args: 106 | pretrained (bool): If True, returns a model pre-trained on ImageNet 107 | """ 108 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 109 | **kwargs) 110 | if pretrained: 111 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 112 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 113 | # They are also in the checkpoints in model_urls. This pattern is used 114 | # to find such keys. 
115 | pattern = re.compile( 116 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 117 | state_dict = model_zoo.load_url(model_urls['densenet161']) 118 | for key in list(state_dict.keys()): 119 | res = pattern.match(key) 120 | if res: 121 | new_key = res.group(1) + res.group(2) 122 | state_dict[new_key] = state_dict[key] 123 | del state_dict[key] 124 | model.load_state_dict(state_dict) 125 | return model 126 | 127 | class _DenseLayer(nn.Sequential): 128 | 129 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 130 | super(_DenseLayer, self).__init__() 131 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 132 | self.add_module('relu1', nn.ReLU(inplace=True)), 133 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 134 | growth_rate, kernel_size=1, stride=1, bias=False)), 135 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 136 | self.add_module('relu2', nn.ReLU(inplace=True)), 137 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 138 | kernel_size=3, stride=1, padding=1, bias=False)), 139 | self.drop_rate = drop_rate 140 | 141 | def forward(self, x): 142 | new_features = super(_DenseLayer, self).forward(x) 143 | if self.drop_rate > 0: 144 | new_features = F.dropout( 145 | new_features, p=self.drop_rate, training=self.training) 146 | return torch.cat([x, new_features], 1) 147 | 148 | 149 | class _DenseBlock(nn.Sequential): 150 | 151 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 152 | super(_DenseBlock, self).__init__() 153 | for i in range(num_layers): 154 | layer = _DenseLayer(num_input_features + i * 155 | growth_rate, growth_rate, bn_size, drop_rate) 156 | self.add_module('denselayer%d' % (i + 1), layer) 157 | 158 | 159 | class _Transition(nn.Sequential): 160 | 161 | def __init__(self, num_input_features, num_output_features): 162 | super(_Transition, self).__init__() 163 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 164 | self.add_module('relu', nn.ReLU(inplace=True)) 165 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 166 | kernel_size=1, stride=1, bias=False)) 167 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 168 | 169 | 170 | 171 | class DenseNet(nn.Module): 172 | 173 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 174 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 175 | 176 | super(DenseNet, self).__init__() 177 | 178 | # First convolution 179 | self.features = nn.Sequential(OrderedDict([ 180 | ('conv0', nn.Conv2d(3, num_init_features, 181 | kernel_size=7, stride=2, padding=3, bias=False)), 182 | ('norm0', nn.BatchNorm2d(num_init_features)), 183 | ('relu0', nn.ReLU(inplace=True)), 184 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 185 | ])) 186 | 187 | # Each denseblock 188 | num_features = num_init_features 189 | for i, num_layers in enumerate(block_config): 190 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 191 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 192 | self.features.add_module('denseblock%d' % (i + 1), block) 193 | num_features = num_features + num_layers * growth_rate 194 | if i != len(block_config) - 1: 195 | trans = _Transition( 196 | num_input_features=num_features, num_output_features=num_features // 2) 197 | self.features.add_module('transition%d' % (i + 1), trans) 198 | num_features = num_features // 2 199 | # 
print(str(i), num_features) 200 | 201 | # Final batch norm 202 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 203 | self.num_features = num_features 204 | 205 | # Linear layer 206 | self.classifier = nn.Linear(num_features, num_classes) 207 | 208 | 209 | def forward(self, x): 210 | features = self.features(x) 211 | out = F.relu(features, inplace=True) 212 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view( 213 | features.size(0), -1) 214 | out = self.classifier(out) 215 | return out 216 | 217 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | def adjust_gt(gt_depth, pred_depth): 7 | adjusted_gt = [] 8 | for each_depth in pred_depth: 9 | adjusted_gt.append(F.interpolate(gt_depth, size=[each_depth.size(2), each_depth.size(3)], 10 | mode='bilinear', align_corners=True)) 11 | return adjusted_gt 12 | 13 | class Sobel(nn.Module): 14 | def __init__(self): 15 | super(Sobel, self).__init__() 16 | self.edge_conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1, bias=False) 17 | edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 18 | edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 19 | edge_k = np.stack((edge_kx, edge_ky)) 20 | 21 | edge_k = torch.from_numpy(edge_k).float().view(2, 1, 3, 3) 22 | self.edge_conv.weight = nn.Parameter(edge_k) 23 | 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def forward(self, x): 28 | out = self.edge_conv(x) 29 | out = out.contiguous().view(-1, 2, x.size(2), x.size(3)) 30 | 31 | return out 32 | 33 | def cal_spatial_loss(output, depth_gt): 34 | 35 | losses=[] 36 | 37 | for depth_index in range(len(output)): 38 | 39 | cos = nn.CosineSimilarity(dim=1, eps=0) 40 | get_gradient = Sobel().cuda() 41 | ones = torch.ones(depth_gt.size(0), 1, depth_gt.size(2),depth_gt.size(3)).float().cuda() 42 | ones = torch.autograd.Variable(ones) 43 | depth_grad = get_gradient(depth_gt) 44 | output_grad = get_gradient(output) 45 | depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(depth_gt) 46 | depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(depth_gt) 47 | output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(depth_gt) 48 | output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(depth_gt) 49 | 50 | depth_normal = torch.cat((-depth_grad_dx, -depth_grad_dy, ones), 1) 51 | output_normal = torch.cat((-output_grad_dx, -output_grad_dy, ones), 1) 52 | 53 | cof = 0.5 54 | 55 | loss_depth = torch.log(torch.abs(output - depth_gt) + cof).mean() 56 | loss_dx = torch.log(torch.abs(output_grad_dx - depth_grad_dx) + cof).mean() 57 | loss_dy = torch.log(torch.abs(output_grad_dy - depth_grad_dy) + cof).mean() 58 | loss_normal = torch.abs(1 - cos(output_normal, depth_normal)).mean() 59 | 60 | loss = loss_depth + loss_normal + (loss_dx + loss_dy) 61 | 62 | losses.append(loss) 63 | 64 | spatial_loss = sum(losses) 65 | 66 | return spatial_loss 67 | 68 | def cal_temporal_loss(pred_cls, gt_cls): 69 | return F.binary_cross_entropy_with_logits(pred_cls, gt_cls) 70 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class _UpProjection(nn.Sequential): 6 | 7 
| def __init__(self, num_input_features, num_output_features): 8 | super(_UpProjection, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(num_input_features, num_output_features, 11 | kernel_size=5, stride=1, padding=2, bias=False) 12 | self.bn1 = nn.BatchNorm2d(num_output_features) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.conv1_2 = nn.Conv2d(num_output_features, num_output_features, 15 | kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn1_2 = nn.BatchNorm2d(num_output_features) 17 | 18 | self.conv2 = nn.Conv2d(num_input_features, num_output_features, 19 | kernel_size=5, stride=1, padding=2, bias=False) 20 | self.bn2 = nn.BatchNorm2d(num_output_features) 21 | 22 | def forward(self, x, size): 23 | x = F.upsample(x, size=size, mode='bilinear', align_corners=True) 24 | x_conv1 = self.relu(self.bn1(self.conv1(x))) 25 | bran1 = self.bn1_2(self.conv1_2(x_conv1)) 26 | bran2 = self.bn2(self.conv2(x)) 27 | 28 | out = self.relu(bran1 + bran2) 29 | 30 | return out 31 | 32 | class E_resnet(nn.Module): 33 | 34 | def __init__(self, original_model, num_features = 2048): 35 | super(E_resnet, self).__init__() 36 | self.conv1 = original_model.conv1 37 | self.bn1 = original_model.bn1 38 | self.relu = original_model.relu 39 | self.maxpool = original_model.maxpool 40 | 41 | self.layer1 = original_model.layer1 42 | self.layer2 = original_model.layer2 43 | self.layer3 = original_model.layer3 44 | self.layer4 = original_model.layer4 45 | 46 | 47 | def forward(self, x): 48 | x = self.conv1(x) 49 | x = self.bn1(x) 50 | x = self.relu(x) 51 | x = self.maxpool(x) 52 | 53 | x_block1 = self.layer1(x) 54 | x_block2 = self.layer2(x_block1) 55 | x_block3 = self.layer3(x_block2) 56 | x_block4 = self.layer4(x_block3) 57 | 58 | return x_block1, x_block2, x_block3, x_block4 59 | 60 | class E_densenet(nn.Module): 61 | 62 | def __init__(self, original_model, num_features = 2208): 63 | super(E_densenet, self).__init__() 64 | self.features = original_model.features 65 | 66 | def forward(self, x): 67 | x01 = self.features[0](x) 68 | x02 = self.features[1](x01) 69 | x03 = self.features[2](x02) 70 | x04 = self.features[3](x03) 71 | 72 | x_block1 = self.features[4](x04) 73 | x_block1 = self.features[5][0](x_block1) 74 | x_block1 = self.features[5][1](x_block1) 75 | x_block1 = self.features[5][2](x_block1) 76 | x_tran1 = self.features[5][3](x_block1) 77 | 78 | x_block2 = self.features[6](x_tran1) 79 | x_block2 = self.features[7][0](x_block2) 80 | x_block2 = self.features[7][1](x_block2) 81 | x_block2 = self.features[7][2](x_block2) 82 | x_tran2 = self.features[7][3](x_block2) 83 | 84 | x_block3 = self.features[8](x_tran2) 85 | x_block3 = self.features[9][0](x_block3) 86 | x_block3 = self.features[9][1](x_block3) 87 | x_block3 = self.features[9][2](x_block3) 88 | x_tran3 = self.features[9][3](x_block3) 89 | 90 | x_block4 = self.features[10](x_tran3) 91 | x_block4 = F.relu(self.features[11](x_block4)) 92 | 93 | return x_block1, x_block2, x_block3, x_block4 94 | 95 | class E_senet(nn.Module): 96 | 97 | def __init__(self, original_model, num_features = 2048): 98 | super(E_senet, self).__init__() 99 | self.base = nn.Sequential(*list(original_model.children())[:-3]) 100 | 101 | def forward(self, x): 102 | x = self.base[0](x) 103 | x_block1 = self.base[1](x) 104 | x_block2 = self.base[2](x_block1) 105 | x_block3 = self.base[3](x_block2) 106 | x_block4 = self.base[4](x_block3) 107 | 108 | return x_block1, x_block2, x_block3, x_block4 109 | 110 | class D(nn.Module): 111 | 112 | def __init__(self, num_features = 2048): 113 | 
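        # Decoder (descriptive comment): a 1x1 convolution first halves the
        # encoder's top-level channel count, then four _UpProjection stages each
        # halve the channels again while bilinearly upsampling, so forward() ends
        # at twice the spatial size of x_block1 with num_features // 32 channels.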
super(D, self).__init__() 114 | self.conv = nn.Conv2d(num_features, num_features // 115 | 2, kernel_size=1, stride=1, bias=False) 116 | num_features = num_features // 2 117 | self.bn = nn.BatchNorm2d(num_features) 118 | 119 | self.up1 = _UpProjection( 120 | num_input_features=num_features, num_output_features=num_features // 2) 121 | num_features = num_features // 2 122 | 123 | self.up2 = _UpProjection( 124 | num_input_features=num_features, num_output_features=num_features // 2) 125 | num_features = num_features // 2 126 | 127 | self.up3 = _UpProjection( 128 | num_input_features=num_features, num_output_features=num_features // 2) 129 | num_features = num_features // 2 130 | 131 | self.up4 = _UpProjection( 132 | num_input_features=num_features, num_output_features=num_features // 2) 133 | num_features = num_features // 2 134 | 135 | 136 | def forward(self, x_block1, x_block2, x_block3, x_block4): 137 | x_d0 = F.relu(self.bn(self.conv(x_block4))) 138 | x_d1 = self.up1(x_d0, [x_block3.size(2), x_block3.size(3)]) 139 | x_d2 = self.up2(x_d1, [x_block2.size(2), x_block2.size(3)]) 140 | x_d3 = self.up3(x_d2, [x_block1.size(2), x_block1.size(3)]) 141 | x_d4 = self.up4(x_d3, [x_block1.size(2)*2, x_block1.size(3)*2]) 142 | 143 | return x_d4 144 | 145 | class MFF(nn.Module): 146 | 147 | def __init__(self, block_channel, num_features=64): 148 | 149 | super(MFF, self).__init__() 150 | 151 | self.up1 = _UpProjection( 152 | num_input_features=block_channel[0], num_output_features=16) 153 | 154 | self.up2 = _UpProjection( 155 | num_input_features=block_channel[1], num_output_features=16) 156 | 157 | self.up3 = _UpProjection( 158 | num_input_features=block_channel[2], num_output_features=16) 159 | 160 | self.up4 = _UpProjection( 161 | num_input_features=block_channel[3], num_output_features=16) 162 | 163 | self.conv = nn.Conv2d( 164 | num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False) 165 | self.bn = nn.BatchNorm2d(num_features) 166 | 167 | 168 | def forward(self, x_block1, x_block2, x_block3, x_block4, size): 169 | x_m1 = self.up1(x_block1, size) 170 | x_m2 = self.up2(x_block2, size) 171 | x_m3 = self.up3(x_block3, size) 172 | x_m4 = self.up4(x_block4, size) 173 | 174 | x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1))) 175 | x = F.relu(x) 176 | 177 | return x 178 | -------------------------------------------------------------------------------- /models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import modules 4 | from models.refinenet_dict import refinenet_dict 5 | 6 | 7 | def cubes_2_maps(cubes): 8 | b, c, d, h, w = cubes.shape 9 | cubes = cubes.permute(0, 2, 1, 3, 4) 10 | 11 | return cubes.contiguous().view(b*d, c, h, w), b, d 12 | 13 | 14 | class model(nn.Module): 15 | def __init__(self, Encoder, num_features, block_channel, refinenet): 16 | 17 | super(model, self).__init__() 18 | 19 | self.E = Encoder 20 | self.D = modules.D(num_features) 21 | self.MFF = modules.MFF(block_channel) 22 | self.R = refinenet_dict[refinenet](block_channel) 23 | 24 | def forward(self, x): 25 | x, b, d = cubes_2_maps(x) 26 | x_block1, x_block2, x_block3, x_block4 = self.E(x) 27 | x_decoder = self.D(x_block1, x_block2, x_block3, x_block4) 28 | x_mff = self.MFF(x_block1, x_block2, x_block3, x_block4,[x_decoder.size(2),x_decoder.size(3)]) 29 | out = self.R(torch.cat((x_decoder, x_mff), 1), b, d) 30 | 31 | return out 32 | 33 | class C_C3D_1(nn.Module): 34 | 35 | def __init__(self, 
num_classes=2, init_width=32, input_channels=1): 36 | self.inplanes = init_width 37 | super(C_C3D_1, self).__init__() 38 | self.conv1 = nn.Sequential( 39 | nn.Conv3d(input_channels, init_width, kernel_size=5, stride=2, padding=2, bias=False), 40 | nn.BatchNorm3d(init_width), 41 | nn.ReLU(inplace=True), 42 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 43 | ) 44 | self.conv2 = nn.Sequential( 45 | nn.Conv3d(init_width, init_width*2, kernel_size=3, stride=1, padding=1, bias=False), 46 | nn.BatchNorm3d(init_width*2), 47 | nn.ReLU(inplace=True), 48 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 49 | ) 50 | 51 | init_width = init_width * 2 52 | self.conv3 = nn.Sequential( 53 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False), 54 | nn.BatchNorm3d(init_width * 2), 55 | nn.ReLU(inplace=True), 56 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 57 | ) 58 | 59 | init_width = init_width * 2 60 | self.conv4 = nn.Sequential( 61 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False), 62 | nn.BatchNorm3d(init_width * 2), 63 | nn.ReLU(inplace=True), 64 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 65 | ) 66 | 67 | self.avgpool = nn.AdaptiveAvgPool3d(1) 68 | self.fc = nn.Linear(init_width * 2, num_classes) 69 | 70 | def forward(self, x): 71 | x = self.conv1(x) 72 | x = self.conv2(x) 73 | x = self.conv3(x) 74 | x = self.conv4(x) 75 | x = self.avgpool(x) 76 | x = x.view(x.size(0), -1) 77 | x = self.fc(x) 78 | 79 | return x 80 | 81 | class C_C3D_2(nn.Module): 82 | 83 | def __init__(self, num_classes=2, init_width=32, input_channels=1): 84 | self.inplanes = init_width 85 | super(C_C3D_2, self).__init__() 86 | self.conv1 = nn.Sequential( 87 | nn.Conv3d(input_channels, init_width, kernel_size=5, stride=2, padding=2, bias=False), 88 | nn.BatchNorm3d(init_width), 89 | nn.ReLU(inplace=True), 90 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 91 | ) 92 | self.conv2 = nn.Sequential( 93 | nn.Conv3d(init_width, init_width*2, kernel_size=3, stride=1, padding=1, bias=False), 94 | nn.BatchNorm3d(init_width*2), 95 | nn.ReLU(inplace=True), 96 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 97 | ) 98 | 99 | init_width = init_width * 2 100 | self.conv3 = nn.Sequential( 101 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False), 102 | nn.BatchNorm3d(init_width * 2), 103 | nn.ReLU(inplace=True), 104 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 105 | ) 106 | 107 | self.avgpool = nn.AdaptiveAvgPool3d(1) 108 | self.fc = nn.Linear(init_width * 2, num_classes) 109 | 110 | def forward(self, x): 111 | x = self.conv1(x) 112 | x = self.conv2(x) 113 | x = self.conv3(x) 114 | x = self.avgpool(x) 115 | x = x.view(x.size(0), -1) 116 | x = self.fc(x) 117 | 118 | return x -------------------------------------------------------------------------------- /models/refinenet_dict.py: -------------------------------------------------------------------------------- 1 | from models.R_CLSTM_modules import (R_CLSTM_1, R_CLSTM_2, R_CLSTM_3, 2 | R_CLSTM_4, R_CLSTM_5, R_CLSTM_6, 3 | R_CLSTM_7, R_CLSTM_8, R_CLSTM_9) 4 | 5 | 6 | 7 | refinenet_dict = { 8 | 'R_CLSTM_1': R_CLSTM_1, 9 | 'R_CLSTM_2': R_CLSTM_2, 10 | 'R_CLSTM_3': R_CLSTM_3, 11 | 'R_CLSTM_4': R_CLSTM_4, 12 | 'R_CLSTM_5': R_CLSTM_5, 13 | 'R_CLSTM_6': R_CLSTM_6, 14 | 'R_CLSTM_7': R_CLSTM_7, 15 | 'R_CLSTM_8': R_CLSTM_8, 16 | 'R_CLSTM_9': R_CLSTM_9 17 | } -------------------------------------------------------------------------------- /models/resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | import torch 6 | import numpy as np 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, block, layers, num_classes=1000): 101 | self.inplanes = 64 102 | super(ResNet, self).__init__() 103 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 104 | bias=False) 105 | self.bn1 = nn.BatchNorm2d(64) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 108 | self.layer1 = self._make_layer(block, 64, layers[0]) 109 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 111 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 112 | self.avgpool = nn.AvgPool2d(7, stride=1) 113 | 
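        # Note: the fully connected classification head below (together with the
        # avgpool above) is kept from the original torchvision definition; the
        # depth encoder E_resnet in models/modules.py reuses only conv1..layer4.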
self.fc = nn.Linear(512 * block.expansion, num_classes) 114 | 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. / n)) 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | # stride=1, 64 != 256 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.Conv2d(self.inplanes, planes * block.expansion, 129 | kernel_size=1, stride=stride, bias=False), 130 | nn.BatchNorm2d(planes * block.expansion), 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | x = self.maxpool(x) 146 | 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x = self.layer4(x) 151 | 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | 156 | return x 157 | 158 | def resnet18(pretrained=True, **kwargs): 159 | """Constructs a ResNet-18 model. 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | from collections import OrderedDict 166 | pretrained_state = model_zoo.load_url(model_urls['resnet18']) 167 | model_state = model.state_dict() 168 | selected_state = OrderedDict() 169 | for k, v in pretrained_state.items(): 170 | if k in model_state and v.size() == model_state[k].size(): 171 | #print('pretrain..',k) 172 | selected_state[k] = v 173 | model_state.update(selected_state) 174 | model.load_state_dict(model_state) 175 | return model 176 | 177 | 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 186 | return model 187 | 188 | 189 | def resnet50(pretrained=False, **kwargs): 190 | """Constructs a ResNet-50 model. 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 197 | return model 198 | 199 | 200 | def resnet101(pretrained=False, **kwargs): 201 | """Constructs a ResNet-101 model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 208 | return model 209 | 210 | 211 | def resnet152(pretrained=False, **kwargs): 212 | """Constructs a ResNet-152 model. 
213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 217 | if pretrained: 218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 219 | return model 220 | -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet code gently borrowed from 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | """ 5 | from collections import OrderedDict 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | from torch.utils import model_zoo 11 | import copy 12 | import numpy as np 13 | 14 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 15 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 16 | 17 | pretrained_settings = { 18 | 'senet154': { 19 | 'imagenet': { 20 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 21 | 'input_space': 'RGB', 22 | 'input_size': [3, 224, 224], 23 | 'input_range': [0, 1], 24 | 'mean': [0.485, 0.456, 0.406], 25 | 'std': [0.229, 0.224, 0.225], 26 | 'num_classes': 1000 27 | } 28 | }, 29 | 'se_resnet50': { 30 | 'imagenet': { 31 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 32 | 'input_space': 'RGB', 33 | 'input_size': [3, 224, 224], 34 | 'input_range': [0, 1], 35 | 'mean': [0.485, 0.456, 0.406], 36 | 'std': [0.229, 0.224, 0.225], 37 | 'num_classes': 1000 38 | } 39 | }, 40 | 'se_resnet101': { 41 | 'imagenet': { 42 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 43 | 'input_space': 'RGB', 44 | 'input_size': [3, 224, 224], 45 | 'input_range': [0, 1], 46 | 'mean': [0.485, 0.456, 0.406], 47 | 'std': [0.229, 0.224, 0.225], 48 | 'num_classes': 1000 49 | } 50 | }, 51 | 'se_resnet152': { 52 | 'imagenet': { 53 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 54 | 'input_space': 'RGB', 55 | 'input_size': [3, 224, 224], 56 | 'input_range': [0, 1], 57 | 'mean': [0.485, 0.456, 0.406], 58 | 'std': [0.229, 0.224, 0.225], 59 | 'num_classes': 1000 60 | } 61 | }, 62 | 'se_resnext50_32x4d': { 63 | 'imagenet': { 64 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 65 | 'input_space': 'RGB', 66 | 'input_size': [3, 224, 224], 67 | 'input_range': [0, 1], 68 | 'mean': [0.485, 0.456, 0.406], 69 | 'std': [0.229, 0.224, 0.225], 70 | 'num_classes': 1000 71 | } 72 | }, 73 | 'se_resnext101_32x4d': { 74 | 'imagenet': { 75 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 76 | 'input_space': 'RGB', 77 | 'input_size': [3, 224, 224], 78 | 'input_range': [0, 1], 79 | 'mean': [0.485, 0.456, 0.406], 80 | 'std': [0.229, 0.224, 0.225], 81 | 'num_classes': 1000 82 | } 83 | }, 84 | } 85 | 86 | 87 | class SEModule(nn.Module): 88 | 89 | def __init__(self, channels, reduction): 90 | super(SEModule, self).__init__() 91 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 92 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 93 | padding=0) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 96 | padding=0) 97 | self.sigmoid = nn.Sigmoid() 98 | 99 | def forward(self, x): 100 | module_input = x 101 | x = self.avg_pool(x) 102 | x = self.fc1(x) 103 | x = self.relu(x) 104 | x = self.fc2(x) 105 | x = 
self.sigmoid(x) 106 | return module_input * x 107 | 108 | 109 | class Bottleneck(nn.Module): 110 | """ 111 | Base class for bottlenecks that implements `forward()` method. 112 | """ 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out = self.se_module(out) + residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class SEBottleneck(Bottleneck): 137 | """ 138 | Bottleneck for SENet154. 139 | """ 140 | expansion = 4 141 | 142 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 143 | downsample=None): 144 | 145 | super(SEBottleneck, self).__init__() 146 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 147 | self.bn1 = nn.BatchNorm2d(planes * 2) 148 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 149 | stride=stride, padding=1, groups=groups, 150 | bias=False) 151 | self.bn2 = nn.BatchNorm2d(planes * 4) 152 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 153 | bias=False) 154 | self.bn3 = nn.BatchNorm2d(planes * 4) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.se_module = SEModule(planes * 4, reduction=reduction) 157 | self.downsample = downsample 158 | self.stride = stride 159 | 160 | 161 | class SEResNetBottleneck(Bottleneck): 162 | """ 163 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 164 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 165 | (the latter is used in the torchvision implementation of ResNet). 166 | """ 167 | expansion = 4 168 | 169 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 170 | downsample=None): 171 | super(SEResNetBottleneck, self).__init__() 172 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 173 | stride=stride) 174 | self.bn1 = nn.BatchNorm2d(planes) 175 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 176 | groups=groups, bias=False) 177 | self.bn2 = nn.BatchNorm2d(planes) 178 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 179 | self.bn3 = nn.BatchNorm2d(planes * 4) 180 | self.relu = nn.ReLU(inplace=True) 181 | self.se_module = SEModule(planes * 4, reduction=reduction) 182 | self.downsample = downsample 183 | self.stride = stride 184 | 185 | 186 | class SEResNeXtBottleneck(Bottleneck): 187 | """ 188 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
189 | """ 190 | expansion = 4 191 | 192 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 193 | downsample=None, base_width=4): 194 | super(SEResNeXtBottleneck, self).__init__() 195 | 196 | width = int(planes * base_width / 64) * groups 197 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 198 | stride=1) 199 | self.bn1 = nn.BatchNorm2d(width) 200 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 201 | padding=1, groups=groups, bias=False) 202 | self.bn2 = nn.BatchNorm2d(width) 203 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 204 | self.bn3 = nn.BatchNorm2d(planes * 4) 205 | self.relu = nn.ReLU(inplace=True) 206 | self.se_module = SEModule(planes * 4, reduction=reduction) 207 | self.downsample = downsample 208 | self.stride = stride 209 | 210 | 211 | class SENet(nn.Module): 212 | 213 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 214 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 215 | downsample_padding=1, num_classes=1000): 216 | """ 217 | Parameters 218 | ---------- 219 | block (nn.Module): Bottleneck class. 220 | - For SENet154: SEBottleneck 221 | - For SE-ResNet models: SEResNetBottleneck 222 | - For SE-ResNeXt models: SEResNeXtBottleneck 223 | layers (list of ints): Number of residual blocks for 4 layers of the 224 | network (layer1...layer4). 225 | groups (int): Number of groups for the 3x3 convolution in each 226 | bottleneck block. 227 | - For SENet154: 64 228 | - For SE-ResNet models: 1 229 | - For SE-ResNeXt models: 32 230 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 231 | - For all models: 16 232 | dropout_p (float or None): Drop probability for the Dropout layer. 233 | If `None` the Dropout layer is not used. 234 | - For SENet154: 0.2 235 | - For SE-ResNet models: None 236 | - For SE-ResNeXt models: None 237 | inplanes (int): Number of input channels for layer1. 238 | - For SENet154: 128 239 | - For SE-ResNet models: 64 240 | - For SE-ResNeXt models: 64 241 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 242 | a single 7x7 convolution in layer0. 243 | - For SENet154: True 244 | - For SE-ResNet models: False 245 | - For SE-ResNeXt models: False 246 | downsample_kernel_size (int): Kernel size for downsampling convolutions 247 | in layer2, layer3 and layer4. 248 | - For SENet154: 3 249 | - For SE-ResNet models: 1 250 | - For SE-ResNeXt models: 1 251 | downsample_padding (int): Padding for downsampling convolutions in 252 | layer2, layer3 and layer4. 253 | - For SENet154: 1 254 | - For SE-ResNet models: 0 255 | - For SE-ResNeXt models: 0 256 | num_classes (int): Number of outputs in `last_linear` layer. 
257 | - For all models: 1000 258 | """ 259 | super(SENet, self).__init__() 260 | self.inplanes = inplanes 261 | if input_3x3: 262 | layer0_modules = [ 263 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 264 | bias=False)), 265 | ('bn1', nn.BatchNorm2d(64)), 266 | ('relu1', nn.ReLU(inplace=True)), 267 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 268 | bias=False)), 269 | ('bn2', nn.BatchNorm2d(64)), 270 | ('relu2', nn.ReLU(inplace=True)), 271 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 272 | bias=False)), 273 | ('bn3', nn.BatchNorm2d(inplanes)), 274 | ('relu3', nn.ReLU(inplace=True)), 275 | ] 276 | else: 277 | layer0_modules = [ 278 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 279 | padding=3, bias=False)), 280 | ('bn1', nn.BatchNorm2d(inplanes)), 281 | ('relu1', nn.ReLU(inplace=True)), 282 | ] 283 | # To preserve compatibility with Caffe weights `ceil_mode=True` 284 | # is used instead of `padding=1`. 285 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 286 | ceil_mode=True))) 287 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 288 | self.layer1 = self._make_layer( 289 | block, 290 | planes=64, 291 | blocks=layers[0], 292 | groups=groups, 293 | reduction=reduction, 294 | downsample_kernel_size=1, 295 | downsample_padding=0 296 | ) 297 | self.layer2 = self._make_layer( 298 | block, 299 | planes=128, 300 | blocks=layers[1], 301 | stride=2, 302 | groups=groups, 303 | reduction=reduction, 304 | downsample_kernel_size=downsample_kernel_size, 305 | downsample_padding=downsample_padding 306 | ) 307 | self.layer3 = self._make_layer( 308 | block, 309 | planes=256, 310 | blocks=layers[2], 311 | stride=2, 312 | groups=groups, 313 | reduction=reduction, 314 | downsample_kernel_size=downsample_kernel_size, 315 | downsample_padding=downsample_padding 316 | ) 317 | self.layer4 = self._make_layer( 318 | block, 319 | planes=512, 320 | blocks=layers[3], 321 | stride=2, 322 | groups=groups, 323 | reduction=reduction, 324 | downsample_kernel_size=downsample_kernel_size, 325 | downsample_padding=downsample_padding 326 | ) 327 | self.avg_pool = nn.AvgPool2d(7, stride=1) 328 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 329 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 330 | 331 | 332 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 333 | downsample_kernel_size=1, downsample_padding=0): 334 | 335 | downsample = None 336 | 337 | if stride != 1 or self.inplanes != planes * block.expansion: 338 | 339 | downsample = nn.Sequential( 340 | nn.Conv2d(self.inplanes, planes * block.expansion, 341 | kernel_size=downsample_kernel_size, stride=stride, 342 | padding=downsample_padding, bias=False), 343 | nn.BatchNorm2d(planes * block.expansion), 344 | ) 345 | 346 | layers = [] 347 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 348 | downsample)) 349 | self.inplanes = planes * block.expansion 350 | for i in range(1, blocks): 351 | layers.append(block(self.inplanes, planes, groups, reduction)) 352 | 353 | return nn.Sequential(*layers) 354 | 355 | 356 | def features(self, x): 357 | x = self.layer0(x) 358 | x = self.layer1(x) 359 | x = self.layer2(x) 360 | x = self.layer3(x) 361 | x = self.layer4(x) 362 | 363 | return x 364 | 365 | 366 | def logits(self, x): 367 | x = self.avg_pool(x) 368 | if self.dropout is not None: 369 | x = self.dropout(x) 370 | x = x.view(x.size(0), -1) 371 | x = self.last_linear(x) 372 | return x 373 | 374 | def forward(self, x,x_): 375 | x = 
self.features(x) 376 | x = self.logits(x) 377 | return x 378 | 379 | def initialize_pretrained_model(model, num_classes, settings): 380 | assert num_classes == settings['num_classes'], \ 381 | 'num_classes should be {}, but is {}'.format( 382 | settings['num_classes'], num_classes) 383 | model.load_state_dict(model_zoo.load_url(settings['url'])) 384 | model.input_space = settings['input_space'] 385 | model.input_size = settings['input_size'] 386 | model.input_range = settings['input_range'] 387 | model.mean = settings['mean'] 388 | model.std = settings['std'] 389 | 390 | 391 | def senet154(num_classes=1000, pretrained='imagenet'): 392 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, 393 | dropout_p=0.2, num_classes=num_classes) 394 | if pretrained is not None: 395 | settings = pretrained_settings['senet154'][pretrained] 396 | initialize_pretrained_model(model, num_classes, settings) 397 | return model 398 | 399 | def se_resnet50(num_classes=1000, pretrained='imagenet'): 400 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, 401 | dropout_p=0.2, inplanes=64, input_3x3=False, 402 | downsample_kernel_size=1, downsample_padding=0, 403 | num_classes=num_classes) 404 | if pretrained is not None: 405 | settings = pretrained_settings['se_resnet50'][pretrained] 406 | initialize_pretrained_model(model, num_classes, settings) 407 | return model 408 | 409 | 410 | def se_resnet101(num_classes=1000, pretrained='imagenet'): 411 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, 412 | dropout_p=0.2, inplanes=64, input_3x3=False, 413 | downsample_kernel_size=1, downsample_padding=0, 414 | num_classes=num_classes) 415 | if pretrained is not None: 416 | settings = pretrained_settings['se_resnet101'][pretrained] 417 | initialize_pretrained_model(model, num_classes, settings) 418 | return model 419 | 420 | 421 | def se_resnet152(num_classes=1000, pretrained='imagenet'): 422 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, 423 | dropout_p=0.2, inplanes=64, input_3x3=False, 424 | downsample_kernel_size=1, downsample_padding=0, 425 | num_classes=num_classes) 426 | if pretrained is not None: 427 | settings = pretrained_settings['se_resnet152'][pretrained] 428 | initialize_pretrained_model(model, num_classes, settings) 429 | return model 430 | 431 | 432 | def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'): 433 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 434 | dropout_p=0.2, inplanes=64, input_3x3=False, 435 | downsample_kernel_size=1, downsample_padding=0, 436 | num_classes=num_classes) 437 | if pretrained is not None: 438 | settings = pretrained_settings['se_resnext50_32x4d'][pretrained] 439 | initialize_pretrained_model(model, num_classes, settings) 440 | return model 441 | 442 | 443 | def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'): 444 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 445 | dropout_p=0.2, inplanes=64, input_3x3=False, 446 | downsample_kernel_size=1, downsample_padding=0, 447 | num_classes=num_classes) 448 | if pretrained is not None: 449 | settings = pretrained_settings['se_resnext101_32x4d'][pretrained] 450 | initialize_pretrained_model(model, num_classes, settings) 451 | return model 452 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainopt import _get_train_opt 2 
| from .testopt import _get_test_opt 3 | 4 | def get_args(mode): 5 | args = None 6 | if mode == 'train': 7 | args = _get_train_opt() 8 | elif mode == 'test': 9 | args = _get_test_opt() 10 | else: 11 | raise ValueError("Invalid mode selection!") 12 | 13 | return args 14 | -------------------------------------------------------------------------------- /options/testopt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def _get_test_opt(): 4 | parser = argparse.ArgumentParser(description = 'Evaluate performance of SARPN on NYU-D v2 test set') 5 | parser.add_argument('--testlist_path', required=True, help='the path of testlist') 6 | parser.add_argument('--root_path', required=True, help="the root path of dataset") 7 | parser.add_argument('--backbone', type=str, default='resnet18') 8 | parser.add_argument('--refinenet', type=str, default='R_CLSTM_5') 9 | parser.add_argument('--batch_size', type=int, default=1, help='testing batch size') 10 | parser.add_argument('--loadckpt', required=True, help="the path of the loaded model") 11 | # parse arguments 12 | return parser.parse_args() 13 | -------------------------------------------------------------------------------- /options/trainopt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def _get_train_opt(): 4 | parser = argparse.ArgumentParser(description = 'Monocular Depth Estimation') 5 | parser.add_argument('--trainlist_path', required=True, help='the path of trainlist', default='/media3/x00532679/project/ST-SARPN/data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json') 6 | parser.add_argument('--root_path', required=True, help="the root path of dataset", default='/media3/x00532679/data/') 7 | parser.add_argument('--batch_size', type=int, default=16, help='training batch size') 8 | parser.add_argument('--epochs', default=20, type=int, help='number of epochs') 9 | parser.add_argument('--backbone', type=str, default='resnet18') 10 | parser.add_argument('--refinenet', type=str, default='R_CLSTM_5') 11 | parser.add_argument('--logdir', required=True, help="the directory to save logs and checkpoints", default='./checkpoint') 12 | parser.add_argument('--checkpoint_dir', required=True, help="the directory to save the checkpoints", default='./log_224') 13 | parser.add_argument('--loadckpt', type=str) 14 | parser.add_argument('--overlap', type=int, default=0) 15 | parser.add_argument('--use_cuda', type=bool, default=True) 16 | parser.add_argument('--devices', type=str, default='0') 17 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate') 18 | parser.add_argument('--resume', action='store_true',default=False, help='continue training the model') 19 | parser.add_argument('--momentum', default=0.9, type=float, help='Momentum parameter used in the Optimizer.') 20 | parser.add_argument('--epsilon', default=0.001, type=float, help='epsilon') 21 | parser.add_argument('--optimizer_name', default="adam", type=str, help="Optimizer selection") 22 | parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay') 23 | parser.add_argument('--do_summary', action='store_true', default=False, help='whether do summary or not') 24 | parser.add_argument('--pretrained_dir', required=False,type=str, help="the path of pretrained models") 25 | return parser.parse_args() 26 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | from tensorboardX import SummaryWriter 11 | from utils import * 12 | from options import get_args 13 | from dataloader import nyudv2_dataloader 14 | from models.loss import cal_spatial_loss, cal_temporal_loss 15 | from models.backbone_dict import backbone_dict 16 | from models import modules 17 | from models import net 18 | 19 | cudnn.benchmark = True 20 | args = get_args('train') 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices 23 | 24 | # Create folder 25 | makedir(args.checkpoint_dir) 26 | makedir(args.logdir) 27 | 28 | # creat summary logger 29 | logger = SummaryWriter(args.logdir) 30 | 31 | # dataset, dataloader 32 | TrainImgLoader = nyudv2_dataloader.getTrainingData_NYUDV2(args.batch_size, args.trainlist_path, args.root_path) 33 | # model, optimizer 34 | device = 'cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu' 35 | 36 | backbone = backbone_dict[args.backbone]() 37 | Encoder = modules.E_resnet(backbone) 38 | 39 | if args.backbone in ['resnet50']: 40 | model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048], refinenet=args.refinenet) 41 | elif args.backbone in ['resnet18', 'resnet34']: 42 | model = net.model(Encoder, num_features=512, block_channel=[64, 128, 256, 512], refinenet=args.refinenet) 43 | 44 | model = nn.DataParallel(model).cuda() 45 | 46 | disc = net.C_C3D_1().cuda() 47 | 48 | optimizer = build_optimizer(model = model, 49 | learning_rate=args.lr, 50 | optimizer_name=args.optimizer_name, 51 | weight_decay = args.weight_decay, 52 | epsilon=args.epsilon, 53 | momentum=args.momentum 54 | ) 55 | 56 | start_epoch = 0 57 | 58 | if args.resume: 59 | all_saved_ckpts = [ckpt for ckpt in os.listdir(args.checkpoint_dir) if ckpt.endswith(".pth.tar")] 60 | print(all_saved_ckpts) 61 | all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x:int(x.split('_')[-1].split('.')[0])) 62 | loadckpt = os.path.join(args.checkpoint_dir, all_saved_ckpts[-1]) 63 | start_epoch = int(all_saved_ckpts[-1].split('_')[-1].split('.')[0]) 64 | print("loading the lastest model in checkpoint_dir: {}".format(loadckpt)) 65 | state_dict = torch.load(loadckpt) 66 | model.load_state_dict(state_dict) 67 | elif args.loadckpt is not None: 68 | print("loading model {}".format(args.loadckpt)) 69 | start_epoch = args.loadckpt.split('_')[-1].split('.')[0] 70 | state_dict = torch.load(args.loadckpt) 71 | model.load_state_dict(state_dict) 72 | else: 73 | print("start at epoch {}".format(start_epoch)) 74 | 75 | def train(): 76 | for epoch in range(start_epoch, args.epochs): 77 | adjust_learning_rate(optimizer, epoch, args.lr) 78 | batch_time = AverageMeter() 79 | losses = AverageMeter() 80 | model.train() 81 | end = time.time() 82 | for batch_idx, sample in enumerate(TrainImgLoader): 83 | 84 | image, depth = sample[0], sample[1]#(b,c,d,w,h) 85 | 86 | depth = depth.cuda() 87 | image = image.cuda() 88 | image = torch.autograd.Variable(image) 89 | depth = torch.autograd.Variable(depth) 90 | optimizer.zero_grad() 91 | global_step = len(TrainImgLoader) * epoch + batch_idx 92 | gt_depth = depth 93 | pred_depth = model(image)#(b, c, d, h, w) 94 | 95 | # Calculate the total loss 96 | spatial_losses=[] 97 | for seq_idx in range(image.size(2)): 98 | spatial_loss = 
cal_spatial_loss(pred_depth[:,:,seq_idx,:,:], gt_depth[:,:,seq_idx,:,:]) 99 | spatial_losses.append(spatial_loss) 100 | spatial_loss = sum(spatial_losses) 101 | 102 | pred_cls = disc(pred_depth) 103 | gt_cls = disc(gt_depth) 104 | temporal_loss = cal_temporal_loss(pred_cls, gt_cls) 105 | 106 | loss = spatial_loss + 0.1 * temporal_loss 107 | 108 | losses.update(loss.item(), image.size(0)) 109 | loss.backward() 110 | optimizer.step() 111 | 112 | batch_time.update(time.time() - end) 113 | end = time.time() 114 | 115 | batchSize = depth.size(0) 116 | 117 | print(('Epoch: [{0}][{1}/{2}]\t' 118 | 'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t' 119 | 'Loss {loss.val:.4f} ({loss.avg:.4f})' 120 | .format(epoch, batch_idx, len(TrainImgLoader), batch_time=batch_time, loss=losses))) 121 | 122 | if (epoch+1)%1 == 0: 123 | save_checkpoint(model.state_dict(), filename=args.checkpoint_dir + "ResNet18_checkpoints_small_" + str(epoch + 1) + ".pth.tar") 124 | 125 | if __name__ == '__main__': 126 | train() 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import matplotlib 7 | import matplotlib.cm 8 | import torchvision.utils as vutils 9 | from models_small.loss import Sobel 10 | import matplotlib.pyplot as plt 11 | cmap = plt.cm.viridis 12 | 13 | 14 | def draw_losses(logger, loss, global_step): 15 | name = "train_loss" 16 | logger.add_scalar(name, loss, global_step) 17 | 18 | def draw_images(logger, all_draw_image, global_step): 19 | for image_name, images in all_draw_image.items(): 20 | if images.shape[1] == 1: 21 | images = colormap(images) 22 | elif images.shape[1] == 3: 23 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 24 | 'std': [0.229, 0.224, 0.225]} 25 | for channel in np.arange(images.shape[1]): 26 | images[:, channel, :, :] = images[:, channel, :, :] * __imagenet_stats["std"][channel] + __imagenet_stats["mean"][channel] 27 | 28 | if len(images.shape) == 3: 29 | images = images[np.newaxis, :, :, :] 30 | if images.shape[0]>4: 31 | images = images[:4, :, :, :] 32 | 33 | logger.add_image(image_name, images, global_step) 34 | 35 | def save_image(img_merge, filename): 36 | img_merge = Image.fromarray(img_merge.astype('uint8')) 37 | img_merge.save(filename) 38 | 39 | 40 | def colored_depthmap(depth, d_min=None, d_max=None): 41 | if d_min is None: 42 | d_min = np.min(depth) 43 | if d_max is None: 44 | d_max = np.max(depth) 45 | depth_relative = (depth - d_min) / (d_max - d_min) 46 | return 255 * cmap(depth_relative)[:,:,:3] # H, W, C 47 | 48 | 49 | def merge_into_row(input, depth_target, depth_pred): 50 | rgb = np.transpose(input.cpu().numpy(), (1,2,0)) # H, W, C 51 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 52 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 53 | 54 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) 55 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) 56 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 57 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 58 | 59 | # img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) 60 | # return img_merge 61 | return rgb, depth_target_col, depth_pred_col 62 | 63 | 64 | def makedir(directory): 65 | if not os.path.exists(directory): 66 | os.makedirs(directory) 67 | 68 | def adjust_learning_rate(optimizer, epoch, 
init_lr): 69 | 70 | lr = init_lr * (0.1 ** (epoch // 5)) 71 | 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = lr 74 | 75 | class AverageMeter(object): 76 | def __init__(self): 77 | self.reset() 78 | 79 | def reset(self): 80 | self.val = 0 81 | self.avg = 0 82 | self.sum = 0 83 | self.count = 0 84 | 85 | def update(self, val, n=1): 86 | self.val = val 87 | self.sum += val * n 88 | self.count += n 89 | self.avg = self.sum / self.count 90 | 91 | def save_checkpoint(state, filename): 92 | torch.save(state, filename) 93 | 94 | 95 | def edge_detection(depth): 96 | get_edge = Sobel().cuda() 97 | 98 | edge_xy = get_edge(depth) 99 | edge_sobel = torch.pow(edge_xy[:, 0, :, :], 2) + \ 100 | torch.pow(edge_xy[:, 1, :, :], 2) 101 | edge_sobel = torch.sqrt(edge_sobel) 102 | 103 | return edge_sobel 104 | 105 | def build_optimizer(model, 106 | learning_rate, 107 | optimizer_name='rmsprop', 108 | weight_decay=1e-5, 109 | epsilon=0.001, 110 | momentum=0.9): 111 | """Build optimizer""" 112 | if optimizer_name == "sgd": 113 | print("Using SGD optimizer.") 114 | optimizer = torch.optim.SGD(model.parameters(), 115 | lr = learning_rate, 116 | momentum=momentum, 117 | weight_decay=weight_decay) 118 | 119 | elif optimizer_name == 'rmsprop': 120 | print("Using RMSProp optimizer.") 121 | optimizer = torch.optim.RMSprop(model.parameters(), 122 | lr = learning_rate, 123 | eps = epsilon, 124 | weight_decay = weight_decay, 125 | momentum = momentum 126 | ) 127 | elif optimizer_name == 'adam': 128 | print("Using Adam optimizer.") 129 | optimizer = torch.optim.Adam(model.parameters(), 130 | lr = learning_rate, weight_decay=weight_decay) 131 | return optimizer 132 | 133 | 134 | 135 | 136 | #original script: https://github.com/fangchangma/sparse-to-dense/blob/master/utils.lua 137 | 138 | 139 | def lg10(x): 140 | return torch.div(torch.log(x), math.log(10)) 141 | 142 | def maxOfTwo(x, y): 143 | z = x.clone() 144 | maskYLarger = torch.lt(x, y) 145 | z[maskYLarger.detach()] = y[maskYLarger.detach()] 146 | return z 147 | 148 | def nValid(x): 149 | return torch.sum(torch.eq(x, x).float()) 150 | 151 | def nNanElement(x): 152 | return torch.sum(torch.ne(x, x).float()) 153 | 154 | def getNanMask(x): 155 | return torch.ne(x, x) 156 | 157 | def setNanToZero(input, target): 158 | nanMask = getNanMask(target) 159 | nValidElement = nValid(target) 160 | 161 | _input = input.clone() 162 | _target = target.clone() 163 | 164 | _input[nanMask] = 0 165 | _target[nanMask] = 0 166 | 167 | return _input, _target, nanMask, nValidElement 168 | 169 | 170 | def evaluateError(output, target): 171 | errors = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0, 172 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0} 173 | 174 | _output, _target, nanMask, nValidElement = setNanToZero(output, target) 175 | 176 | if (nValidElement.data.cpu().numpy() > 0): 177 | diffMatrix = torch.abs(_output - _target) 178 | 179 | errors['MSE'] = torch.sum(torch.pow(diffMatrix, 2)) / nValidElement 180 | 181 | errors['MAE'] = torch.sum(diffMatrix) / nValidElement 182 | 183 | realMatrix = torch.div(diffMatrix, _target) 184 | realMatrix[nanMask] = 0 185 | errors['ABS_REL'] = torch.sum(realMatrix) / nValidElement 186 | 187 | LG10Matrix = torch.abs(lg10(_output) - lg10(_target)) 188 | LG10Matrix[nanMask] = 0 189 | errors['LG10'] = torch.sum(LG10Matrix) / nValidElement 190 | yOverZ = torch.div(_output, _target) 191 | zOverY = torch.div(_target, _output) 192 | 193 | maxRatio = maxOfTwo(yOverZ, zOverY) 194 | 195 | errors['DELTA1'] = torch.sum( 196 | 
torch.le(maxRatio, 1.25).float()) / nValidElement 197 | errors['DELTA2'] = torch.sum( 198 | torch.le(maxRatio, math.pow(1.25, 2)).float()) / nValidElement 199 | errors['DELTA3'] = torch.sum( 200 | torch.le(maxRatio, math.pow(1.25, 3)).float()) / nValidElement 201 | 202 | errors['MSE'] = float(errors['MSE'].data.cpu().numpy()) 203 | errors['ABS_REL'] = float(errors['ABS_REL'].data.cpu().numpy()) 204 | errors['LG10'] = float(errors['LG10'].data.cpu().numpy()) 205 | errors['MAE'] = float(errors['MAE'].data.cpu().numpy()) 206 | errors['DELTA1'] = float(errors['DELTA1'].data.cpu().numpy()) 207 | errors['DELTA2'] = float(errors['DELTA2'].data.cpu().numpy()) 208 | errors['DELTA3'] = float(errors['DELTA3'].data.cpu().numpy()) 209 | 210 | return errors 211 | 212 | 213 | def addErrors(errorSum, errors, batchSize): 214 | errorSum['MSE']=errorSum['MSE'] + errors['MSE'] * batchSize 215 | errorSum['ABS_REL']=errorSum['ABS_REL'] + errors['ABS_REL'] * batchSize 216 | errorSum['LG10']=errorSum['LG10'] + errors['LG10'] * batchSize 217 | errorSum['MAE']=errorSum['MAE'] + errors['MAE'] * batchSize 218 | 219 | errorSum['DELTA1']=errorSum['DELTA1'] + errors['DELTA1'] * batchSize 220 | errorSum['DELTA2']=errorSum['DELTA2'] + errors['DELTA2'] * batchSize 221 | errorSum['DELTA3']=errorSum['DELTA3'] + errors['DELTA3'] * batchSize 222 | 223 | return errorSum 224 | 225 | 226 | def averageErrors(errorSum, N): 227 | averageError={'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0, 228 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0} 229 | 230 | averageError['MSE'] = errorSum['MSE'] / N 231 | averageError['ABS_REL'] = errorSum['ABS_REL'] / N 232 | averageError['LG10'] = errorSum['LG10'] / N 233 | averageError['MAE'] = errorSum['MAE'] / N 234 | 235 | averageError['DELTA1'] = errorSum['DELTA1'] / N 236 | averageError['DELTA2'] = errorSum['DELTA2'] / N 237 | averageError['DELTA3'] = errorSum['DELTA3'] / N 238 | 239 | return averageError 240 | 241 | 242 | def colormap(image, cmap="jet"): 243 | image_min = torch.min(image) 244 | image_max = torch.max(image) 245 | image = (image - image_min) / (image_max - image_min) 246 | image = torch.squeeze(image) 247 | 248 | # quantize 249 | indices = torch.round(image * 255).long() 250 | # gather 251 | cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') 252 | 253 | colors = cm(np.arange(256))[:, :3] 254 | colors = torch.cuda.FloatTensor(colors) 255 | color_map = colors[indices].transpose(2, 3).transpose(1, 2) 256 | 257 | return color_map 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | --------------------------------------------------------------------------------
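Below is a minimal, hypothetical sketch (not part of the repository) of how the metric helpers in `utils.py` are typically combined during evaluation: per-batch errors from `evaluateError` are accumulated with `addErrors` and normalised with `averageErrors`. The random tensors merely stand in for batches from the NYUDv2 test loader and for `model(image)`. One caveat: `utils.py` imports `Sobel` from `models_small.loss`, whereas the class is defined in `models/loss.py` in this tree, so that import path may need adjusting before `utils` can be imported.
```python
import torch
from utils import evaluateError, addErrors, averageErrors

# Running sums for every metric reported by evaluateError().
error_sum = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
             'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0}
total_samples = 0

# Placeholder batches: in practice `gt` comes from the NYUDv2 test loader and
# `pred` from model(image), reshaped to (batch, 1, H, W) per frame.
for _ in range(4):
    gt = torch.rand(2, 1, 114, 152) * 9.9 + 0.1    # fake ground-truth depth (metres)
    pred = gt + 0.05 * torch.randn_like(gt)        # fake prediction

    errors = evaluateError(pred, gt)               # per-batch metrics
    error_sum = addErrors(error_sum, errors, gt.size(0))
    total_samples += gt.size(0)

avg = averageErrors(error_sum, total_samples)
avg['RMSE'] = avg['MSE'] ** 0.5                    # RMSE derived here from the averaged MSE
print({k: round(float(v), 4) for k, v in avg.items()})
```
The repository's `evaluate.py` presumably follows the same accumulate-then-average pattern over the test split; its exact interface is not reproduced here.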