├── .gitattributes
├── LICENSE
├── README.md
├── assert
│   └── framework.png
├── create_list_nyuv2.py
├── dataloader
│   ├── __pycache__
│   │   ├── nyu_transform.cpython-35.pyc
│   │   ├── nyudv2_dataloader.cpython-35.pyc
│   │   └── nyudv2_dataloader_224.cpython-35.pyc
│   ├── nyu_transform.py
│   └── nyudv2_dataloader.py
├── evaluate.py
├── models
│   ├── R_CLSTM_modules.py
│   ├── __init__.py
│   ├── backbone_dict.py
│   ├── densenet.py
│   ├── loss.py
│   ├── modules.py
│   ├── net.py
│   ├── refinenet_dict.py
│   ├── resnet.py
│   └── senet.py
├── options
│   ├── __init__.py
│   ├── testopt.py
│   └── trainopt.py
├── train.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bdseal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://opensource.org/licenses/MIT)
2 | [](https://www.python.org/)
3 | [](https://pytorch.org/)
4 |
5 | # Exploiting Temporal Consistency for Real-Time Video Depth Estimation
6 | This is the UNOFFICIAL implementation of the paper [***Exploiting Temporal Consistency for Real-Time Video Depth Estimation***](https://arxiv.org/abs/1908.03706), ***ICCV 2019, Haokui Zhang, Chunhua Shen, Ying Li, Yuanzhouhan Cao, Yu Liu, Youliang Yan.***
7 |
8 | You can find the official implementation (WITHOUT TRAINING SCRIPTS) [here](https://github.com/hkzhang91/ST-CLSTM).
9 |
10 | ## Framework
11 | ![framework](./assert/framework.png)
12 |
13 | ## Dependencies
14 | - [Python3.6](https://www.python.org/downloads/)
15 | - [PyTorch 1.0+](https://pytorch.org/)
16 | - [NYU Depth v2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
17 |
18 | ## Pre-processed Data
19 | We didn't preprocess data as in the official implementation. Instead, we use the dataset shared by [Junjie Hu](https://github.com/JunjH/Revisiting_Single_Depth_Estimation), which is also used by [SARPN](https://github.com/Xt-Chen/SARPN/blob/master/README.md).
20 | You can download the pre-processed data from [here](https://drive.google.com/file/d/1WoOZOBpOWfmwe7bknWS5PMUCLBPFKTOw/view?usp=sharing).
21 |
22 | When you have downloaded the dataset, run the following command to create the training list.
23 | ```bash
24 | python create_list_nyuv2.py
25 | ```
26 |
27 | You can also follow the procedure of [ST-CLSTM](https://github.com/hkzhang91/ST-CLSTM) to preprocess the data. It is based on the official Matlab [Toolbox](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). If Matlab is unavailable to you, there is also a [Python Port Toolbox](https://github.com/GabrielMajeri/nyuv2-python-toolbox) by [GabrielMajeri](https://github.com/GabrielMajeri) for processing the raw dataset, which provides a higher-level interface to the labeled subset, raw dataset extraction and preprocessing, and data augmentation.
28 |
29 | The final folder structure is shown below.
30 | ```
31 | data_root
32 | |- raw_nyu_v2_250k
33 | | |- train
34 | | | |- basement_0001a
35 | | | | |- rgb
36 | | | | | |- rgb_00000.jpg
37 | | | | | |_ ...
38 | | | | |- depth
39 | | | | | |- depth_00000.png
40 | | | | | |_ ...
41 | | | |_ ...
42 | | |- test_fps_30_fl5_end
43 | | | |- 0000
44 | | | | |- rgb
45 | | | | | |- rgb_00000.jpg
46 | | | | | |- rgb_00001.jpg
47 | | | | | |- ...
48 | | | | | |- rgb_00004.jpg
49 | | | | |- depth
50 | | | | | |- depth_00000.png
51 | | | | | |- depth_00001.png
52 | | | | | |- ...
53 | | | | | |- depth_00004.png
54 | | | |- ...
55 | | |- test_fps_30_fl4_end
56 | | |- test_fps_30_fl3_end
57 | ```
58 | ## Train
59 | As an example, use the following command to train on NYUDV2.
60 |
61 | ```bash
62 | CUDA_VISIBLE_DEVICES="0,1,2,3" python train.py --epochs 20 --batch_size 128 \
63 | --resume --do_summary --backbone resnet18 --refinenet R_CLSTM_5 \
64 | --trainlist_path ./data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json \
65 | --root_path ./data/ --checkpoint_dir ./checkpoint/ --logdir ./log/
66 | ```
67 | ## Evaluation
68 | Use the following command to evaluate the trained model on the ST-CLSTM [test data](https://github.com/hkzhang91/ST-CLSTM).
69 |
70 | ```bash
71 | CUDA_VISIBLE_DEVICES="0" python evaluate.py --batch_size 1 --backbone resnet18 --refinenet R_CLSTM_5 --loadckpt ./checkpoint/ \
72 | --testlist_path ./data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_test.json \
73 | --root_path ./data/st-clstm/
74 | ```
75 | ## Pretrained Model
76 | You can download the pretrained model: [NYUDV2](https://github.com/hkzhang91/ST-CLSTM/tree/master/CLSTM_Depth_Estimation-master/prediction/trained_models).
77 |
78 | ## Citation
79 |
80 | ```bibtex
81 | @inproceedings{zhang2019temporal,
82 | title = {Exploiting Temporal Consistency for Real-Time Video Depth Estimation},
83 | author = {Haokui Zhang and Chunhua Shen and Ying Li and Yuanzhouhan Cao and Yu Liu and Youliang Yan},
84 | booktitle = {International Conference on Computer Vision},
85 | year = {2019}
86 | }
87 | ```
--------------------------------------------------------------------------------
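After running the `create_list_nyuv2.py` step described in the README above, the generated training list can be inspected directly. The sketch below is illustrative only: the JSON path is the script's default output location and the field names are taken from `create_list_nyuv2.py`.

```python
import json

# Default output path of create_list_nyuv2.py (dataset=raw_nyu_v2_250k, fps=30, fl=5, op=0).
list_path = './data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json'

with open(list_path) as f:
    train_list = json.load(f)

# Each entry describes one training clip: fl RGB frames, fl depth frames,
# the scene they come from, and the index of the frame used for evaluation.
entry = train_list[0]
print(entry['scene_name'], entry['test_index'])
print(entry['rgb_index'])    # list of fl relative RGB frame paths
print(entry['depth_index'])  # list of fl relative depth frame paths
```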
/assert/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/assert/framework.png
--------------------------------------------------------------------------------
/create_list_nyuv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from glob import glob
6 | # from natsort import natsorted
7 | import re
8 |
9 | parser = argparse.ArgumentParser(description='raw_nyu_v2')
10 | parser.add_argument('--dataset', type=str, default='raw_nyu_v2_250k')
11 | parser.add_argument('--test_loc', type=str, default='end')
12 | parser.add_argument('--fps', type=int, default=30)
13 | parser.add_argument('--fl', type=int, default=5)
14 | parser.add_argument('--overlap', type=int, default=0)
15 | parser.add_argument('--list_save_dir', type=str, default='./data_list')
16 | parser.add_argument('--source_dir', type=str, default='./data/')
17 | args = parser.parse_args()
18 |
19 | args.jpg_png_save_dir = args.source_dir #+ args.dataset
20 |
21 |
22 | def make_if_not_exist(path):
23 | if not os.path.exists(path):
24 | os.makedirs(path)
25 |
26 |
27 | def video_split(frame_len, frame_train, interval, overlap):
28 | sample_interval = frame_train - overlap
29 | indices = []
30 | for start in range(interval):
31 | index_list = list(range(start, frame_len - frame_train * interval + 1, sample_interval))
32 | [indices.append(list(range(num, num+frame_train*interval, interval))) for num in index_list]
33 | indices.append(list(range(frame_len - frame_train * interval - start, frame_len - start, interval)))
34 |
35 | return indices
36 |
37 | def atoi(text):
38 | return int(text) if text.isdigit() else text
39 |
40 | def natural_keys(text):
41 | '''
42 | alist.sort(key=natural_keys) sorts in human order
43 | http://nedbatchelder.com/blog/200712/human_sorting.html
44 | (See Toothy's implementation in the comments)
45 | '''
46 | return [ atoi(c) for c in re.split(r'(\d+)', text) ]
47 |
48 | def create_dict(dataset, list_save_dir, jpg_png_save_dir):
49 | train_dir = os.path.join(jpg_png_save_dir, 'nyu2_train')
50 | # test_dir = os.path.join(jpg_png_save_dir, 'test_fps{}_fl{}_{}'.format(args.fps, args.fl,args.test_loc))
51 | interval = 30 // args.fps
52 |
53 | train_dict=[]
54 | subset_list = os.listdir(train_dir)
55 |
56 | for subset in subset_list:
57 | subset_source_dir = os.path.join(train_dir, subset)
58 | print(subset_source_dir)
59 | rgb_list = glob(subset_source_dir + '/*.jpg')
60 | depth_list = glob(subset_source_dir + '/*.png')
61 |
62 | rgb_list.sort(key=natural_keys)
63 | depth_list.sort(key=natural_keys)
64 | print(rgb_list)
65 |
66 |
67 | rgb_list_new = ['/'.join(rgb_info.split('/')[-5:]) for rgb_info in rgb_list]
68 | depth_list_new = ['/'.join(depth_info.split('/')[-5:]) for depth_info in depth_list]
69 |
70 | indices = video_split(len(depth_list), args.fl, interval, args.overlap)
71 |
72 | for index in indices:
73 | rgb_index = []
74 | depth_index = []
75 | [rgb_index.append(rgb_list_new[id]) for id in index]
76 | [depth_index.append(depth_list_new[id]) for id in index]
77 |
78 | train_info = {
79 | 'rgb_index': rgb_index,
80 | 'depth_index': depth_index,
81 | 'scene_name': subset,
82 | "test_index": args.fl-1,
83 | }
84 | train_dict.append(train_info)
85 |
86 | # test_dict = []
87 | # subset_list = os.listdir(test_dir)
88 | # for subset in subset_list:
89 | # subset_source_dir = os.path.join(test_dir, subset)
90 | # rgb_list = glob(subset_source_dir + '/rgb/rgb_*.jpg')
91 | # depth_list = glob(subset_source_dir + '/depth/depth_*.png')
92 | # rgb_list.sort()
93 | # depth_list.sort()
94 |
95 | # rgb_list_new = ['/'.join(rgb_info.split('/')[-5:]) for rgb_info in rgb_list]
96 | # depth_list_new = ['/'.join(depth_info.split('/')[-5:]) for depth_info in depth_list]
97 |
98 | # test_index = int(open(subset_source_dir + '/depth/frame_index.txt').read())
99 |
100 | # test_info = {
101 | # 'rgb_index': rgb_list_new,
102 | # 'depth_index': depth_list_new,
103 | # 'scene_name': subset,
104 | # 'test_index': test_index
105 | # }
106 | # test_dict.append(test_info)
107 |
108 | list_save_dir = os.path.join(list_save_dir, dataset)
109 | make_if_not_exist(list_save_dir)
110 | train_info_save = list_save_dir + '/{}_fps{}_fl{}_op{}_{}_train.json'.format(dataset, args.fps, args.fl, args.overlap, args.test_loc)
111 | # test_info_save = list_save_dir + '/{}_fps{}_fl{}_op{}_{}_test.json'.format(dataset, args.fps, args.fl, args.overlap, args.test_loc)
112 |
113 | with open(train_info_save, 'w') as dst_file:
114 | json.dump(train_dict, dst_file)
115 | # with open(test_info_save, 'w') as dst_file:
116 | # json.dump(test_dict, dst_file)
117 |
118 | if __name__ == '__main__':
119 |
120 | create_dict(args.dataset, args.list_save_dir, args.jpg_png_save_dir)
--------------------------------------------------------------------------------
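A minimal, self-contained sketch of how `video_split` above groups frame indices into fixed-length clips (assuming the module is importable; note that it parses its command-line arguments at import time):

```python
# Illustrative only: with 8 frames, clip length 5, frame interval 1 and no
# overlap, video_split yields one clip starting at frame 0 plus one clip
# anchored at the end of the sequence.
from create_list_nyuv2 import video_split  # parses its argparse defaults on import

clips = video_split(frame_len=8, frame_train=5, interval=1, overlap=0)
print(clips)  # [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7]]
```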
/dataloader/__pycache__/nyu_transform.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyu_transform.cpython-35.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/nyudv2_dataloader.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyudv2_dataloader.cpython-35.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/nyudv2_dataloader_224.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihaox/ST-CLSTM/a7494489ca85a9b2ddb8840ef31668dfeaabc85c/dataloader/__pycache__/nyudv2_dataloader_224.cpython-35.pyc
--------------------------------------------------------------------------------
/dataloader/nyu_transform.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from PIL import Image, ImageOps
5 | import collections
6 | try:
7 | import accimage
8 | except ImportError:
9 | accimage = None
10 | import random
11 | import scipy.ndimage as ndimage
12 |
13 | import pdb
14 |
15 |
16 | def _is_pil_image(img):
17 | if accimage is not None:
18 | return isinstance(img, (Image.Image, accimage.Image))
19 | else:
20 | return isinstance(img, Image.Image)
21 |
22 |
23 | def _is_numpy_image(img):
24 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
25 |
26 |
27 | class RandomRotate(object):
28 | """Random rotation of the image from -angle to angle (in degrees)
29 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
30 | angle: max angle of the rotation
31 | interpolation order: Default: 2 (bilinear)
32 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
33 | diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off.
34 | """
35 |
36 | def __init__(self, angle, diff_angle=0, order=2, reshape=False):
37 | self.angle = angle
38 | self.reshape = reshape
39 | self.order = order
40 |
41 | def __call__(self, sample):
42 | image, depth = sample['image'], sample['depth']
43 |
44 | applied_angle = random.uniform(-self.angle, self.angle)
45 | angle1 = applied_angle
46 | angle1_rad = angle1 * np.pi / 180
47 |
48 | image = ndimage.interpolation.rotate(
49 | image, angle1, reshape=self.reshape, order=self.order)
50 | depth = ndimage.interpolation.rotate(
51 | depth, angle1, reshape=self.reshape, order=self.order)
52 |
53 | image = Image.fromarray(image)
54 | depth = Image.fromarray(depth)
55 |
56 | return {'image': image, 'depth': depth}
57 |
58 | class RandomHorizontalFlip(object):
59 |
60 | def __call__(self, sample):
61 | image, depth = sample['image'], sample['depth']
62 |
63 | if not _is_pil_image(image):
64 | raise TypeError(
65 | 'img should be PIL Image. Got {}'.format(type(image)))
66 | if not _is_pil_image(depth):
67 | raise TypeError(
68 | 'img should be PIL Image. Got {}'.format(type(depth)))
69 |
70 | if random.random() < 0.5:
71 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
72 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
73 |
74 | return {'image': image, 'depth': depth}
75 |
76 |
77 | class Scale(object):
78 | """ Rescales the inputs and target arrays to the given 'size'.
79 | 'size' will be the size of the smaller edge.
80 | For example, if height > width, then image will be
81 | rescaled to (size * height / width, size)
82 | size: size of the smaller edge
83 | interpolation order: Default: 2 (bilinear)
84 | """
85 |
86 | def __init__(self, size):
87 | self.size = size
88 |
89 | def __call__(self, sample):
90 | image, depth = sample['image'], sample['depth']
91 |
92 | image = self.changeScale(image, self.size)
93 | depth = self.changeScale(depth, self.size,Image.NEAREST)
94 |
95 | return {'image': image, 'depth': depth}
96 |
97 | def changeScale(self, img, size, interpolation=Image.BILINEAR):
98 |
99 | if not _is_pil_image(img):
100 | raise TypeError(
101 | 'img should be PIL Image. Got {}'.format(type(img)))
102 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
103 | raise TypeError('Got inappropriate size arg: {}'.format(size))
104 |
105 | if isinstance(size, int):
106 | w, h = img.size
107 | if (w <= h and w == size) or (h <= w and h == size):
108 | return img
109 | if w < h:
110 | ow = size
111 | oh = int(size * h / w)
112 | return img.resize((ow, oh), interpolation)
113 | else:
114 | oh = size
115 | ow = int(size * w / h)
116 | return img.resize((ow, oh), interpolation)
117 | else:
118 | return img.resize(size[::-1], interpolation)
119 |
120 |
121 | class CenterCrop(object):
122 | def __init__(self, size_image, size_depth):
123 | self.size_image = size_image
124 | self.size_depth = size_depth
125 |
126 | def __call__(self, sample):
127 | image, depth = sample['image'], sample['depth']
128 | ### crop image and depth to (304, 228)
129 | image = self.centerCrop(image, self.size_image)
130 | depth = self.centerCrop(depth, self.size_image)
131 | ### resize depth to (152, 114) downsample 2
132 | ow, oh = self.size_depth
133 | depth = depth.resize((ow, oh))
134 |
135 | return {'image': image, 'depth': depth}
136 |
137 | def centerCrop(self, image, size):
138 |
139 | w1, h1 = image.size
140 |
141 | tw, th = size
142 |
143 | if w1 == tw and h1 == th:
144 | return image
145 | ## (320-304) / 2. = 8
146 | ## (240-228) / 2. = 8
147 | x1 = int(round((w1 - tw) / 2.))
148 | y1 = int(round((h1 - th) / 2.))
149 |
150 | image = image.crop((x1, y1, tw + x1, th + y1))
151 |
152 | return image
153 |
154 |
155 | class ToTensor(object):
156 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
157 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
158 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
159 | """
160 | def __init__(self,is_test=False):
161 | self.is_test = is_test
162 |
163 | def __call__(self, sample):
164 | image, depth = sample['image'], sample['depth']
165 | """
166 | Args:
167 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
168 | Returns:
169 | Tensor: Converted image.
170 | """
171 | # ground truth depth of training samples is stored in 8-bit while test samples are saved in 16 bit
172 | image = self.to_tensor(image)
173 | if self.is_test:
174 | depth = self.to_tensor(depth).float()/1000
175 | else:
176 | depth = self.to_tensor(depth).float()*10
177 | return {'image': image, 'depth': depth}
178 |
179 | def to_tensor(self, pic):
180 | if not(_is_pil_image(pic) or _is_numpy_image(pic)):
181 | raise TypeError(
182 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
183 |
184 | if isinstance(pic, np.ndarray):
185 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
186 | ## convert image to (0,1)
187 | return img.float().div(255)
188 |
189 | if accimage is not None and isinstance(pic, accimage.Image):
190 | nppic = np.zeros(
191 | [pic.channels, pic.height, pic.width], dtype=np.float32)
192 | pic.copyto(nppic)
193 | return torch.from_numpy(nppic)
194 |
195 | # handle PIL Image
196 | if pic.mode == 'I':
197 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
198 | elif pic.mode == 'I;16':
199 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
200 | else:
201 | img = torch.ByteTensor(
202 | torch.ByteStorage.from_buffer(pic.tobytes()))
203 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
204 | if pic.mode == 'YCbCr':
205 | nchannel = 3
206 | elif pic.mode == 'I;16':
207 | nchannel = 1
208 | else:
209 | nchannel = len(pic.mode)
210 | img = img.view(pic.size[1], pic.size[0], nchannel)
211 | # put it from HWC to CHW format
212 | # yikes, this transpose takes 80% of the loading time/CPU
213 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
214 | if isinstance(img, torch.ByteTensor):
215 | return img.float().div(255)
216 | else:
217 | return img
218 |
219 |
220 | class Lighting(object):
221 |
222 | def __init__(self, alphastd, eigval, eigvec):
223 | self.alphastd = alphastd
224 | self.eigval = eigval
225 | self.eigvec = eigvec
226 |
227 | def __call__(self, sample):
228 | image, depth = sample['image'], sample['depth']
229 | if self.alphastd == 0:
230 | return {'image': image, 'depth': depth}
231 |
232 | alpha = image.new().resize_(3).normal_(0, self.alphastd)
233 | rgb = self.eigvec.type_as(image).clone()\
234 | .mul(alpha.view(1, 3).expand(3, 3))\
235 | .mul(self.eigval.view(1, 3).expand(3, 3))\
236 | .sum(1).squeeze()
237 |
238 | image = image.add(rgb.view(3, 1, 1).expand_as(image))
239 |
240 | return {'image': image, 'depth': depth}
241 |
242 |
243 | class Grayscale(object):
244 |
245 | def __call__(self, img):
246 | gs = img.clone()
247 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
248 | gs[1].copy_(gs[0])
249 | gs[2].copy_(gs[0])
250 | return gs
251 |
252 |
253 | class Saturation(object):
254 |
255 | def __init__(self, var):
256 | self.var = var
257 |
258 | def __call__(self, img):
259 | gs = Grayscale()(img)
260 | alpha = random.uniform(-self.var, self.var)
261 | return img.lerp(gs, alpha)
262 |
263 |
264 | class Brightness(object):
265 |
266 | def __init__(self, var):
267 | self.var = var
268 |
269 | def __call__(self, img):
270 | gs = img.new().resize_as_(img).zero_()
271 | alpha = random.uniform(-self.var, self.var)
272 |
273 | return img.lerp(gs, alpha)
274 |
275 |
276 | class Contrast(object):
277 |
278 | def __init__(self, var):
279 | self.var = var
280 |
281 | def __call__(self, img):
282 | gs = Grayscale()(img)
283 | gs.fill_(gs.mean())
284 | alpha = random.uniform(-self.var, self.var)
285 | return img.lerp(gs, alpha)
286 |
287 |
288 | class RandomOrder(object):
289 | """ Composes several transforms together in random order.
290 | """
291 |
292 | def __init__(self, transforms):
293 | self.transforms = transforms
294 |
295 | def __call__(self, sample):
296 | image, depth = sample['image'], sample['depth']
297 |
298 | if self.transforms is None:
299 | return {'image': image, 'depth': depth}
300 | order = torch.randperm(len(self.transforms))
301 | for i in order:
302 | image = self.transforms[i](image)
303 |
304 | return {'image': image, 'depth': depth}
305 |
306 |
307 | class ColorJitter(RandomOrder):
308 |
309 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
310 | self.transforms = []
311 | if brightness != 0:
312 | self.transforms.append(Brightness(brightness))
313 | if contrast != 0:
314 | self.transforms.append(Contrast(contrast))
315 | if saturation != 0:
316 | self.transforms.append(Saturation(saturation))
317 |
318 |
319 | class Normalize(object):
320 | def __init__(self, mean, std):
321 | self.mean = mean
322 | self.std = std
323 |
324 | def __call__(self, sample):
325 | """
326 | Args:
327 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
328 | Returns:
329 | Tensor: Normalized image.
330 | """
331 | image, depth = sample['image'], sample['depth']
332 |
333 | image = self.normalize(image, self.mean, self.std)
334 |
335 | return {'image': image, 'depth': depth}
336 |
337 | def normalize(self, tensor, mean, std):
338 | """Normalize a tensor image with mean and standard deviation.
339 | See ``Normalize`` for more details.
340 | Args:
341 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
342 | mean (sequence): Sequence of means for R, G, B channels respectively.
343 | std (sequence): Sequence of standard deviations for R, G, B channels
344 | respectively.
345 | Returns:
346 | Tensor: Normalized image.
347 | """
348 |
349 | # TODO: make efficient
350 | for t, m, s in zip(tensor, mean, std):
351 | t.sub_(m).div_(s)
352 | return tensor
--------------------------------------------------------------------------------
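A minimal sketch (with illustrative image sizes) of how these transforms are meant to be chained: they operate on a `{'image': PIL.Image, 'depth': PIL.Image}` dict rather than on a bare image, mirroring the pipeline in `dataloader/nyudv2_dataloader.py`.

```python
from PIL import Image
from dataloader.nyu_transform import Scale, CenterCrop, ToTensor

# Dummy 640x480 RGB frame and depth map standing in for real NYU data.
sample = {'image': Image.new('RGB', (640, 480)),
          'depth': Image.new('L', (640, 480))}

sample = Scale(240)(sample)                          # smaller edge -> 240 (image becomes 320x240)
sample = CenterCrop([304, 228], [152, 114])(sample)  # crop image, downsample depth
sample = ToTensor()(sample)                          # PIL -> float tensors

print(sample['image'].shape)  # torch.Size([3, 228, 304])
print(sample['depth'].shape)  # torch.Size([1, 114, 152])
```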
/dataloader/nyudv2_dataloader.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision import transforms, utils
5 | from PIL import Image
6 | import random
7 | import json
8 | import os
9 | from dataloader.nyu_transform import *
10 |
11 | try:
12 | import accimage
13 | except ImportError:
14 | accimage = None
15 |
16 |
17 | def load_annotation_data(dict_file_dir):
18 | with open(dict_file_dir, 'r') as data_file:
19 | return json.load(data_file)
20 |
21 |
22 | def pil_loader(path):
23 | return Image.open(path)
24 |
25 |
26 | def video_loader(root_dir, frame_indices):
27 | video = []
28 | for index in frame_indices:
29 | image_path = os.path.join(root_dir, index)
30 | if os.path.exists(image_path):
31 | video.append(pil_loader(image_path))
32 | else:
33 | return video
34 |
35 | return video
36 |
37 |
38 | class depthDataset(Dataset):
39 | """Face Landmarks dataset."""
40 |
41 | def __init__(self, dict_dir, root_dir, transform=None, is_test=False):
42 | self.data_dict = load_annotation_data(dict_dir)
43 | self.root_dir = root_dir
44 | self.transform = transform
45 |
46 | def __getitem__(self, idx):
47 | rgb_index = self.data_dict[idx]['rgb_index']
48 | depth_index = self.data_dict[idx]['depth_index']
49 | test_index = self.data_dict[idx]['test_index']
50 |
51 | rgb_clips = video_loader(self.root_dir, rgb_index)
52 | depth_clips = video_loader(self.root_dir, depth_index)
53 |
54 | rgb_tensor = []
55 | depth_tensor = []
56 | depth_scaled_tensor = []
57 | for rgb_clip, depth_clip in zip(rgb_clips, depth_clips):
58 | sample = {'image': rgb_clip, 'depth': depth_clip}
59 | sample_new = self.transform(sample)
60 | rgb_tensor.append(sample_new['image'])
61 | depth_tensor.append(sample_new['depth'])
62 |
63 | return torch.stack(rgb_tensor, 0).permute(1, 0, 2, 3), \
64 | torch.stack(depth_tensor, 0).permute(1, 0, 2, 3), \
65 | test_index
66 |
67 | def __len__(self):
68 | return len(self.data_dict)
69 |
70 | def getTrainingData_NYUDV2(batch_size=64, dict_dir=None, root_dir=None):
71 | __imagenet_pca = {
72 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
73 | 'eigvec': torch.Tensor([
74 | [-0.5675, 0.7192, 0.4009],
75 | [-0.5808, -0.0045, -0.8140],
76 | [-0.5836, -0.6948, 0.4203],
77 | ])
78 | }
79 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
80 | 'std': [0.229, 0.224, 0.225]}
81 | transformed_training = depthDataset(dict_dir=dict_dir,
82 | root_dir = root_dir,
83 | transform=transforms.Compose([
84 | Scale(240),
85 | RandomHorizontalFlip(),
86 | RandomRotate(5),
87 | CenterCrop([304, 228], [152, 114]),
88 | ToTensor(),
89 | Lighting(0.1, __imagenet_pca[
90 | 'eigval'], __imagenet_pca['eigvec']),
91 | ColorJitter(
92 | brightness=0.4,
93 | contrast=0.4,
94 | saturation=0.4,
95 | ),
96 | Normalize(__imagenet_stats['mean'],
97 | __imagenet_stats['std'])
98 | ]))
99 |
100 | dataloader_training = DataLoader(transformed_training, batch_size, shuffle=True, num_workers=1, pin_memory=False)
101 |
102 | return dataloader_training
103 |
104 | def getTestingData_NYUDV2(batch_size=64, dict_dir=None, root_dir=None, num_workers=4):
105 |
106 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
107 | 'std': [0.229, 0.224, 0.225]}
108 |
109 |
110 | transformed_testing = depthDataset(dict_dir=dict_dir,
111 | root_dir=root_dir,
112 | transform=transforms.Compose([
113 | Scale(240),
114 | CenterCrop([304, 228], [152, 114]),
115 | ToTensor(is_test=True),
116 | Normalize(__imagenet_stats['mean'],
117 | __imagenet_stats['std'])
118 | ]))
119 |
120 | dataloader_testing = DataLoader(transformed_testing, batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)
121 |
122 | return dataloader_testing
123 |
--------------------------------------------------------------------------------
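A minimal sketch of building the training loader and checking the clip tensors it yields (the paths are placeholders; it assumes the training list from `create_list_nyuv2.py` and the pre-processed data are in place):

```python
from dataloader.nyudv2_dataloader import getTrainingData_NYUDV2

train_loader = getTrainingData_NYUDV2(
    batch_size=4,
    dict_dir='./data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json',
    root_dir='./data/')

rgb, depth, test_index = next(iter(train_loader))
# Clips are stacked along a temporal axis: (batch, channels, frames, H, W).
print(rgb.shape)    # e.g. torch.Size([4, 3, 5, 228, 304])
print(depth.shape)  # e.g. torch.Size([4, 1, 5, 114, 152])
```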
/evaluate.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import os
5 | from collections import OrderedDict
6 | import torch.nn as nn
7 | import torch.nn.parallel
8 | import torch.nn.functional
9 | from utils import *
10 | from options import get_args
11 | from dataloader import nyudv2_dataloader
12 | from models.backbone_dict import backbone_dict
13 | from models import modules
14 | from models import net
15 |
16 |
17 | args = get_args('test')
18 | # load nyud v2 test set
19 | TestImgLoader = nyudv2_dataloader.getTestingData_NYUDV2(args.batch_size, args.testlist_path, args.root_path)
20 | # model
21 | backbone = backbone_dict[args.backbone]()
22 | Encoder = modules.E_resnet(backbone)
23 |
24 | if args.backbone in ['resnet50']:
25 | model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048], refinenet=args.refinenet)
26 | elif args.backbone in ['resnet18', 'resnet34']:
27 | model = net.model(Encoder, num_features=512, block_channel=[64, 128, 256, 512], refinenet=args.refinenet)
28 |
29 | model = nn.DataParallel(model).cuda()
30 |
31 | # load test model
32 | if args.loadckpt is not None and args.loadckpt.endswith('.pth.tar'):
33 | print("loading the specific model in checkpoint_dir: {}".format(args.loadckpt))
34 | state_dict = torch.load(args.loadckpt)
35 | model.load_state_dict(state_dict)
36 | elif os.path.isdir(args.loadckpt):
37 | all_saved_ckpts = [ckpt for ckpt in os.listdir(args.loadckpt) if ckpt.endswith(".pth.tar")]
38 | print(all_saved_ckpts)
39 | all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x:int(x.split('_')[-1].split('.')[0]))
40 | loadckpt = os.path.join(args.loadckpt, all_saved_ckpts[-1])
41 | start_epoch = int(all_saved_ckpts[-1].split('_')[-1].split('.')[0])
42 | print("loading the latest model in checkpoint_dir: {}".format(loadckpt))
43 | state_dict = torch.load(loadckpt)
44 | model.load_state_dict(state_dict)
45 | else:
46 | print("You have not loaded any models.")
47 |
48 | def test():
49 |
50 | model.eval()
51 | with torch.no_grad():
52 | for batch_idx, sample in enumerate(TestImgLoader):
53 | print("Processing the {}th image!".format(batch_idx))
54 | image, depth = sample[0], sample[1]
55 | depth = depth.cuda()
56 | image = image.cuda()
57 |
58 | image = torch.autograd.Variable(image)
59 | depth = torch.autograd.Variable(depth)
60 |
61 | start = time.time()
62 | pred = model(image)
63 | end = time.time()
64 | running_time = end - start
65 |
66 | print(pred.size())
67 | print(depth.size())
68 |
69 | pred_ = np.squeeze(pred.data.cpu().numpy())
70 | depth_ = np.squeeze(depth.cpu().numpy())
71 |
72 | print(np.shape(pred_))
73 | print(np.shape(depth_))
74 |
75 | for seq_idx in range(len(pred_)):
76 | print(seq_idx)
77 | print(np.shape(depth_[0:]))
78 |
79 | depth = depth_[seq_idx]
80 | pred = pred_[seq_idx]
81 |
82 | d_min = min(np.min(depth), np.min(pred))
83 | d_max = max(np.max(depth), np.max(pred))
84 | # depth = colored_depthmap(depth, d_min, d_max)
85 | # pred = colored_depthmap(pred, d_min, d_max)
86 | depth = colored_depthmap(depth)
87 | pred = colored_depthmap(pred)
88 |
89 | print(d_min)
90 | print(d_max)
91 |
92 | filename = os.path.join('./samples/depth_' + str(seq_idx) + '.png')
93 | save_image(depth, filename)
94 |
95 | filename = os.path.join('./samples/pred_' + str(seq_idx) + '.png')
96 | save_image(pred, filename)
97 |
98 | # if metrics_s is not None:
99 | # metrics_s(torch.stack(pred_new, 0).cpu(), torch.stack(depth_new, 0))
100 |
101 | # result_s = metrics_s.loss_get()
102 | # print(result_s)
103 |
104 |
105 | if __name__ == '__main__':
106 | test()
--------------------------------------------------------------------------------
/models/R_CLSTM_modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 | import time
5 |
6 | def maps_2_cubes(x, b, d):
7 | x_b, x_c, x_h, x_w = x.shape
8 | x = x.contiguous().view(b, d, x_c, x_h, x_w)
9 |
10 | return x.permute(0, 2, 1, 3, 4)
11 |
12 |
13 | def maps_2_maps(x, b, d):
14 | x_b, x_c, x_h, x_w = x.shape
15 | x = x.contiguous().view(b, d * x_c, x_h, x_w)
16 |
17 | return x
18 |
19 |
20 | class R(nn.Module):
21 | def __init__(self, block_channel):
22 | super(R, self).__init__()
23 |
24 | num_features = 64 + block_channel[3] // 32
25 | self.conv0 = nn.Conv2d(num_features, num_features,
26 | kernel_size=5, stride=1, padding=2, bias=False)
27 | self.bn0 = nn.BatchNorm2d(num_features)
28 |
29 | self.conv1 = nn.Conv2d(num_features, num_features,
30 | kernel_size=5, stride=1, padding=2, bias=False)
31 | self.bn1 = nn.BatchNorm2d(num_features)
32 |
33 | self.conv2 = nn.Conv2d(
34 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
35 |
36 | def forward(self, x):
37 | x0 = self.conv0(x)
38 | x0 = self.bn0(x0)
39 | x0 = F.relu(x0)
40 |
41 | x1 = self.conv1(x0)
42 | x1 = self.bn1(x1)
43 | h = F.relu(x1)
44 |
45 | pred_depth = self.conv2(h)
46 |
47 | return h, pred_depth
48 |
49 |
50 | class R_2(nn.Module):
51 | def __init__(self, block_channel):
52 | super(R_2, self).__init__()
53 |
54 | num_features = 64 + block_channel[3] // 32 + 4
55 | self.conv0 = nn.Conv2d(num_features, num_features,
56 | kernel_size=5, stride=1, padding=2, bias=False)
57 | self.bn0 = nn.BatchNorm2d(num_features)
58 |
59 | self.conv1 = nn.Conv2d(num_features, num_features,
60 | kernel_size=5, stride=1, padding=2, bias=False)
61 | self.bn1 = nn.BatchNorm2d(num_features)
62 |
63 | self.conv2 = nn.Conv2d(
64 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
65 |
66 | self.convh = nn.Conv2d(
67 | num_features, 4, kernel_size=3, stride=1, padding=1, bias=True)
68 |
69 | def forward(self, x):
70 | x0 = self.conv0(x)
71 | x0 = self.bn0(x0)
72 | x0 = F.relu(x0)
73 |
74 | x1 = self.conv1(x0)
75 | x1 = self.bn1(x1)
76 | x1 = F.relu(x1)
77 |
78 | h = self.convh(x1)
79 | pred_depth = self.conv2(x1)
80 |
81 | return h, pred_depth
82 |
83 |
84 | class R_d(nn.Module):
85 | def __init__(self, block_channel):
86 | super(R_d, self).__init__()
87 |
88 | num_features = 64 + block_channel[3] // 32 + 4
89 | self.conv0 = nn.Conv2d(num_features, num_features,
90 | kernel_size=5, stride=1, padding=2, bias=False)
91 | self.bn0 = nn.BatchNorm2d(num_features)
92 |
93 | self.conv1 = nn.Conv2d(num_features, num_features,
94 | kernel_size=5, stride=1, padding=2, bias=False)
95 | self.bn1 = nn.BatchNorm2d(num_features)
96 |
97 | self.dropout = nn.Dropout2d(p=0.5)
98 |
99 | self.conv2 = nn.Conv2d(
100 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
101 |
102 | self.convh = nn.Conv2d(
103 | num_features, 4, kernel_size=3, stride=1, padding=1, bias=True)
104 |
105 | def forward(self, x):
106 | x0 = self.conv0(x)
107 | x0 = self.bn0(x0)
108 | x0 = F.relu(x0)
109 |
110 | x1 = self.conv1(x0)
111 | x1 = self.bn1(x1)
112 | x1 = F.relu(x1)
113 |
114 | h = self.convh(x1)
115 | x1 = self.dropout(x1)
116 | pred_depth = self.conv2(x1)
117 |
118 | return h, pred_depth
119 |
120 |
121 | class R_3(nn.Module):
122 | def __init__(self, block_channel):
123 | super(R_3, self).__init__()
124 |
125 | num_features = 64 + block_channel[3] // 32 + 8
126 | self.conv0 = nn.Conv2d(num_features, num_features,
127 | kernel_size=5, stride=1, padding=2, bias=False)
128 | self.bn0 = nn.BatchNorm2d(num_features)
129 |
130 | self.conv1 = nn.Conv2d(num_features, num_features,
131 | kernel_size=5, stride=1, padding=2, bias=False)
132 | self.bn1 = nn.BatchNorm2d(num_features)
133 |
134 | self.conv2 = nn.Conv2d(
135 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
136 |
137 | self.convh = nn.Conv2d(
138 | num_features, 8, kernel_size=3, stride=1, padding=1, bias=True)
139 |
140 | def forward(self, x):
141 | x0 = self.conv0(x)
142 | x0 = self.bn0(x0)
143 | x0 = F.relu(x0)
144 |
145 | x1 = self.conv1(x0)
146 | x1 = self.bn1(x1)
147 | x1 = F.relu(x1)
148 |
149 | h = self.convh(x1)
150 | pred_depth = self.conv2(x1)
151 |
152 | return h, pred_depth
153 |
154 |
155 | class R_10(nn.Module):
156 | def __init__(self, block_channel):
157 | super(R_10, self).__init__()
158 |
159 | num_features = 64 + block_channel[3] // 32 + 8
160 | self.conv0 = nn.Conv2d(num_features, num_features,
161 | kernel_size=5, stride=1, padding=2, bias=False)
162 | self.bn0 = nn.BatchNorm2d(num_features)
163 |
164 | self.conv1 = nn.Conv2d(num_features, num_features,
165 | kernel_size=5, stride=1, padding=2, bias=False)
166 | self.bn1 = nn.BatchNorm2d(num_features)
167 |
168 | self.conv2 = nn.Conv2d(
169 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
170 |
171 | self.convh = nn.Conv2d(
172 | num_features, 8, kernel_size=3, stride=1, padding=1, bias=True)
173 |
174 | def forward(self, x):
175 | x0 = self.conv0(x)
176 | x0 = self.bn0(x0)
177 | x0 = F.relu(x0)
178 |
179 | x1 = self.conv1(x0)
180 | x1 = self.bn1(x1)
181 | x1 = F.relu(x1)
182 |
183 | h = self.convh(x1)
184 | pred_depth = F.tanh(self.conv2(x1))
185 |
186 | return h, pred_depth
187 |
188 |
189 | class R_CLSTM_1(nn.Module):
190 | def __init__(self, block_channel):
191 | super(R_CLSTM_1, self).__init__()
192 | num_features = 64 + block_channel[3] // 32
193 | self.Refine = R(block_channel)
194 | self.F_t = nn.Sequential(
195 | nn.Conv2d(in_channels=num_features+num_features,
196 | out_channels=num_features,
197 | kernel_size=3,
198 | padding=1,
199 | ),
200 | nn.Sigmoid()
201 | )
202 | self.I_t = nn.Sequential(
203 | nn.Conv2d(in_channels=num_features + num_features,
204 | out_channels=num_features,
205 | kernel_size=3,
206 | padding=1,
207 | ),
208 | nn.Sigmoid()
209 | )
210 | self.C_t = nn.Sequential(
211 | nn.Conv2d(in_channels=num_features + num_features,
212 | out_channels=num_features,
213 | kernel_size=3,
214 | padding=1,
215 | ),
216 | nn.Tanh()
217 | )
218 | self.Q_t = nn.Sequential(
219 | nn.Conv2d(in_channels=num_features + num_features,
220 | out_channels=num_features,
221 | kernel_size=3,
222 | padding=1,
223 | ),
224 | nn.Sigmoid()
225 | )
226 |
227 | def forward(self, input_tensor, b, d):
228 | input_tensor = maps_2_cubes(input_tensor, b, d)
229 | b, c, d, h, w = input_tensor.shape
230 | h_state_init = torch.zeros(b, c, h, w).to('cuda')
231 | c_state_init = torch.zeros(b, c, h, w).to('cuda')
232 |
233 | seq_len = d
234 |
235 | h_state, c_state = h_state_init, c_state_init
236 | output_inner = []
237 | for t in range(seq_len):
238 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1)
239 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
240 |
241 | h_state, p_depth = self.Refine(c_state * self.Q_t(input_cat))
242 |
243 | output_inner.append(p_depth)
244 |
245 | layer_output = torch.stack(output_inner, dim=2)
246 |
247 | return layer_output
248 |
249 |
250 | class R_CLSTM_2(nn.Module):
251 | def __init__(self, block_channel):
252 | super(R_CLSTM_2, self).__init__()
253 | num_features = 64 + block_channel[3] // 32
254 | self.Refine = R(block_channel)
255 | self.F_t = nn.Sequential(
256 | nn.Conv2d(in_channels=num_features + num_features,
257 | out_channels=num_features // 2,
258 | kernel_size=3,
259 | padding=1,
260 | ),
261 | nn.Sigmoid()
262 | )
263 | self.I_t = nn.Sequential(
264 | nn.Conv2d(in_channels=num_features + num_features,
265 | out_channels=num_features // 2,
266 | kernel_size=3,
267 | padding=1,
268 | ),
269 | nn.Sigmoid()
270 | )
271 | self.C_t = nn.Sequential(
272 | nn.Conv2d(in_channels=num_features + num_features,
273 | out_channels=num_features // 2,
274 | kernel_size=3,
275 | padding=1,
276 | ),
277 | nn.Tanh()
278 | )
279 | self.Q_t = nn.Sequential(
280 | nn.Conv2d(in_channels=num_features + num_features,
281 | out_channels=num_features // 2,
282 | kernel_size=3,
283 | padding=1,
284 | ),
285 | nn.Sigmoid()
286 | )
287 |
288 | def forward(self, input_tensor, b, d):
289 | input_tensor = maps_2_cubes(input_tensor, b, d)
290 | b, c, d, h, w = input_tensor.shape
291 | h_state_init = torch.zeros(b, c, h, w).to('cuda')
292 | c_state_init = torch.zeros(b, c // 2, h, w).to('cuda')
293 |
294 | seq_len = d
295 |
296 | h_state, c_state = h_state_init, c_state_init
297 | output_inner = []
298 | for t in range(seq_len):
299 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1)
300 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
301 |
302 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
303 |
304 | output_inner.append(p_depth)
305 |
306 | layer_output = torch.stack(output_inner, dim=2)
307 |
308 | return layer_output
309 |
310 |
311 | class R_CLSTM_3(nn.Module):
312 | def __init__(self, block_channel):
313 | super(R_CLSTM_3, self).__init__()
314 | num_features = 64 + block_channel[3] // 32
315 | self.Refine = R_2(block_channel)
316 | self.F_t = nn.Sequential(
317 | nn.Conv2d(in_channels=num_features+4,
318 | out_channels=4,
319 | kernel_size=3,
320 | padding=1,
321 | ),
322 | nn.Sigmoid()
323 | )
324 | self.I_t = nn.Sequential(
325 | nn.Conv2d(in_channels=num_features + 4,
326 | out_channels=4,
327 | kernel_size=3,
328 | padding=1,
329 | ),
330 | nn.Sigmoid()
331 | )
332 | self.C_t = nn.Sequential(
333 | nn.Conv2d(in_channels=num_features + 4,
334 | out_channels=4,
335 | kernel_size=3,
336 | padding=1,
337 | ),
338 | nn.Tanh()
339 | )
340 | self.Q_t = nn.Sequential(
341 | nn.Conv2d(in_channels=num_features + 4,
342 | out_channels=num_features,
343 | kernel_size=3,
344 | padding=1,
345 | ),
346 | nn.Sigmoid()
347 | )
348 |
349 | def forward(self, input_tensor, b, d):
350 | input_tensor = maps_2_cubes(input_tensor, b, d)
351 | b, c, d, h, w = input_tensor.shape
352 | h_state_init = torch.zeros(b, 4, h, w).to('cuda')
353 | c_state_init = torch.zeros(b, 4, h, w).to('cuda')
354 |
355 | seq_len = d
356 |
357 | h_state, c_state = h_state_init, c_state_init
358 | output_inner = []
359 | for t in range(seq_len):
360 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1)
361 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
362 |
363 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
364 |
365 | output_inner.append(p_depth)
366 |
367 | layer_output = torch.stack(output_inner, dim=2)
368 |
369 | return layer_output
370 |
371 |
372 | class R_CLSTM_4(nn.Module):
373 | def __init__(self, block_channel):
374 | super(R_CLSTM_4, self).__init__()
375 | num_features = 64 + block_channel[3] // 32
376 | self.Refine = R_d(block_channel)
377 | self.F_t = nn.Sequential(
378 | nn.Conv2d(in_channels=num_features+4,
379 | out_channels=4,
380 | kernel_size=3,
381 | padding=1,
382 | ),
383 | nn.Sigmoid()
384 | )
385 | self.I_t = nn.Sequential(
386 | nn.Conv2d(in_channels=num_features + 4,
387 | out_channels=4,
388 | kernel_size=3,
389 | padding=1,
390 | ),
391 | nn.Sigmoid()
392 | )
393 | self.C_t = nn.Sequential(
394 | nn.Conv2d(in_channels=num_features + 4,
395 | out_channels=4,
396 | kernel_size=3,
397 | padding=1,
398 | ),
399 | nn.Tanh()
400 | )
401 | self.Q_t = nn.Sequential(
402 | nn.Conv2d(in_channels=num_features + 4,
403 | out_channels=num_features,
404 | kernel_size=3,
405 | padding=1,
406 | ),
407 | nn.Sigmoid()
408 | )
409 |
410 | def forward(self, input_tensor, b, d):
411 | input_tensor = maps_2_cubes(input_tensor, b, d)
412 | b, c, d, h, w = input_tensor.shape
413 | h_state_init = torch.zeros(b, 4, h, w).to('cuda')
414 | c_state_init = torch.zeros(b, 4, h, w).to('cuda')
415 |
416 | seq_len = d
417 |
418 | h_state, c_state = h_state_init, c_state_init
419 | output_inner = []
420 | for t in range(seq_len):
421 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1)
422 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
423 |
424 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
425 |
426 | output_inner.append(p_depth)
427 |
428 | layer_output = torch.stack(output_inner, dim=2)
429 |
430 | return layer_output
431 |
432 |
433 | class R_CLSTM_5(nn.Module):
434 | def __init__(self, block_channel):
435 | super(R_CLSTM_5, self).__init__()
436 | num_features = 64 + block_channel[3] // 32
437 | self.Refine = R_3(block_channel)
438 | self.F_t = nn.Sequential(
439 | nn.Conv2d(in_channels=num_features+8,
440 | out_channels=8,
441 | kernel_size=3,
442 | padding=1,
443 | ),
444 | nn.Sigmoid()
445 | )
446 | self.I_t = nn.Sequential(
447 | nn.Conv2d(in_channels=num_features + 8,
448 | out_channels=8,
449 | kernel_size=3,
450 | padding=1,
451 | ),
452 | nn.Sigmoid()
453 | )
454 | self.C_t = nn.Sequential(
455 | nn.Conv2d(in_channels=num_features + 8,
456 | out_channels=8,
457 | kernel_size=3,
458 | padding=1,
459 | ),
460 | nn.Tanh()
461 | )
462 | self.Q_t = nn.Sequential(
463 | nn.Conv2d(in_channels=num_features + 8,
464 | out_channels=num_features,
465 | kernel_size=3,
466 | padding=1,
467 | ),
468 | nn.Sigmoid()
469 | )
470 |
471 | def forward(self, input_tensor, b, d):
472 | input_tensor = maps_2_cubes(input_tensor, b, d)
473 | b, c, d, h, w = input_tensor.shape
474 | h_state_init = torch.zeros(b, 8, h, w).to('cuda')
475 | c_state_init = torch.zeros(b, 8, h, w).to('cuda')
476 |
477 | seq_len = d
478 |
479 | h_state, c_state = h_state_init, c_state_init
480 | output_inner = []
481 | start = time.time()
482 | for t in range(seq_len):
483 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1)
484 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
485 |
486 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
487 |
488 | output_inner.append(p_depth)
489 |
490 | layer_output = torch.stack(output_inner, dim=2)
491 | torch.cuda.synchronize()
492 | # print(time.time()-start)
493 |
494 | return layer_output
495 |
496 |
497 | class R_cell(nn.Module):
498 | def __init__(self, block_channel, cell_width=16):
499 | super(R_cell, self).__init__()
500 |
501 | num_features = 64 + block_channel[3] // 32
502 | self.conv0 = nn.Conv2d(num_features + cell_width, num_features,
503 | kernel_size=5, stride=1, padding=2, bias=False)
504 | self.bn0 = nn.BatchNorm2d(num_features)
505 |
506 | self.conv1 = nn.Conv2d(num_features, num_features,
507 | kernel_size=5, stride=1, padding=2, bias=False)
508 | self.bn1 = nn.BatchNorm2d(num_features)
509 |
510 | self.conv2 = nn.Conv2d(
511 | num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
512 |
513 | self.convh = nn.Conv2d(
514 | num_features, cell_width, kernel_size=3, stride=1, padding=1, bias=True)
515 |
516 | def forward(self, x):
517 | x0 = self.conv0(x)
518 | x0 = self.bn0(x0)
519 | x0 = F.relu(x0)
520 |
521 | x1 = self.conv1(x0)
522 | x1 = self.bn1(x1)
523 | x1 = F.relu(x1)
524 |
525 | h = self.convh(x1)
526 | pred_depth = self.conv2(x1)
527 |
528 | return h, pred_depth
529 |
530 |
531 | class R_CLSTM_6(nn.Module):
532 | def __init__(self, block_channel):
533 | super(R_CLSTM_6, self).__init__()
534 | num_features = 64 + block_channel[3] // 32
535 | self.cell_width = 8
536 | self.Refine = R_cell(block_channel, self.cell_width)
537 | self.F_t = nn.Sequential(
538 | nn.Conv2d(in_channels=num_features + self.cell_width,
539 | out_channels=self.cell_width,
540 | kernel_size=3,
541 | padding=1,
542 | ),
543 | nn.Sigmoid()
544 | )
545 | self.I_t = nn.Sequential(
546 | nn.Conv2d(in_channels=num_features + self.cell_width,
547 | out_channels=self.cell_width,
548 | kernel_size=3,
549 | padding=1,
550 | ),
551 | nn.Sigmoid()
552 | )
553 | self.C_t = nn.Sequential(
554 | nn.Conv2d(in_channels=num_features + self.cell_width,
555 | out_channels=self.cell_width,
556 | kernel_size=3,
557 | padding=1,
558 | ),
559 | nn.Tanh()
560 | )
561 | self.Q_t = nn.Sequential(
562 | nn.Conv2d(in_channels=num_features + self.cell_width,
563 | out_channels=num_features,
564 | kernel_size=3,
565 | padding=1,
566 | ),
567 | nn.Sigmoid()
568 | )
569 |
570 | def forward(self, input_tensor, b, d):
571 | input_tensor = maps_2_cubes(input_tensor, b, d)
572 | b, c, d, h, w = input_tensor.shape
573 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
574 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
575 |
576 | seq_len = d
577 |
578 | h_state, c_state = h_state_init, c_state_init
579 | output_inner = []
580 | for t in range(seq_len):
581 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1)
582 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
583 |
584 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
585 |
586 | output_inner.append(p_depth)
587 |
588 | layer_output = torch.stack(output_inner, dim=2)
589 |
590 | return layer_output
591 |
592 |
593 | class R_CLSTM_7(nn.Module):
594 | def __init__(self, block_channel):
595 | super(R_CLSTM_7, self).__init__()
596 | num_features = 64 + block_channel[3] // 32
597 | self.cell_width = 12
598 | self.Refine = R_cell(block_channel, self.cell_width)
599 | self.F_t = nn.Sequential(
600 | nn.Conv2d(in_channels=num_features + self.cell_width,
601 | out_channels=self.cell_width,
602 | kernel_size=3,
603 | padding=1,
604 | ),
605 | nn.Sigmoid()
606 | )
607 | self.I_t = nn.Sequential(
608 | nn.Conv2d(in_channels=num_features + self.cell_width,
609 | out_channels=self.cell_width,
610 | kernel_size=3,
611 | padding=1,
612 | ),
613 | nn.Sigmoid()
614 | )
615 | self.C_t = nn.Sequential(
616 | nn.Conv2d(in_channels=num_features + self.cell_width,
617 | out_channels=self.cell_width,
618 | kernel_size=3,
619 | padding=1,
620 | ),
621 | nn.Tanh()
622 | )
623 | self.Q_t = nn.Sequential(
624 | nn.Conv2d(in_channels=num_features + self.cell_width,
625 | out_channels=num_features,
626 | kernel_size=3,
627 | padding=1,
628 | ),
629 | nn.Sigmoid()
630 | )
631 |
632 | def forward(self, input_tensor, b, d):
633 | input_tensor = maps_2_cubes(input_tensor, b, d)
634 | b, c, d, h, w = input_tensor.shape
635 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
636 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
637 |
638 | seq_len = d
639 |
640 | h_state, c_state = h_state_init, c_state_init
641 | output_inner = []
642 | for t in range(seq_len):
643 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1)
644 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
645 |
646 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
647 |
648 | output_inner.append(p_depth)
649 |
650 | layer_output = torch.stack(output_inner, dim=2)
651 |
652 | return layer_output
653 |
654 |
655 | class R_CLSTM_8(nn.Module):
656 | def __init__(self, block_channel):
657 | super(R_CLSTM_8, self).__init__()
658 | num_features = 64 + block_channel[3] // 32
659 | self.cell_width = 16
660 | self.Refine = R_cell(block_channel, self.cell_width)
661 | self.F_t = nn.Sequential(
662 | nn.Conv2d(in_channels=num_features + self.cell_width,
663 | out_channels=self.cell_width,
664 | kernel_size=3,
665 | padding=1,
666 | ),
667 | nn.Sigmoid()
668 | )
669 | self.I_t = nn.Sequential(
670 | nn.Conv2d(in_channels=num_features + self.cell_width,
671 | out_channels=self.cell_width,
672 | kernel_size=3,
673 | padding=1,
674 | ),
675 | nn.Sigmoid()
676 | )
677 | self.C_t = nn.Sequential(
678 | nn.Conv2d(in_channels=num_features + self.cell_width,
679 | out_channels=self.cell_width,
680 | kernel_size=3,
681 | padding=1,
682 | ),
683 | nn.Tanh()
684 | )
685 | self.Q_t = nn.Sequential(
686 | nn.Conv2d(in_channels=num_features + self.cell_width,
687 | out_channels=num_features,
688 | kernel_size=3,
689 | padding=1,
690 | ),
691 | nn.Sigmoid()
692 | )
693 |
694 | def forward(self, input_tensor, b, d):
695 | input_tensor = maps_2_cubes(input_tensor, b, d)
696 | b, c, d, h, w = input_tensor.shape
697 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
698 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
699 |
700 | seq_len = d
701 |
702 | h_state, c_state = h_state_init, c_state_init
703 | output_inner = []
704 | for t in range(seq_len):
705 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1)
706 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
707 |
708 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
709 |
710 | output_inner.append(p_depth)
711 |
712 | layer_output = torch.stack(output_inner, dim=2)
713 |
714 | return layer_output
715 |
716 |
717 | class R_CLSTM_9(nn.Module):
718 | def __init__(self, block_channel):
719 | super(R_CLSTM_9, self).__init__()
720 | num_features = 64 + block_channel[3] // 32
721 | self.cell_width = 32
722 | self.Refine = R_cell(block_channel, self.cell_width)
723 | self.F_t = nn.Sequential(
724 | nn.Conv2d(in_channels=num_features + self.cell_width,
725 | out_channels=self.cell_width,
726 | kernel_size=3,
727 | padding=1,
728 | ),
729 | nn.Sigmoid()
730 | )
731 | self.I_t = nn.Sequential(
732 | nn.Conv2d(in_channels=num_features + self.cell_width,
733 | out_channels=self.cell_width,
734 | kernel_size=3,
735 | padding=1,
736 | ),
737 | nn.Sigmoid()
738 | )
739 | self.C_t = nn.Sequential(
740 | nn.Conv2d(in_channels=num_features + self.cell_width,
741 | out_channels=self.cell_width,
742 | kernel_size=3,
743 | padding=1,
744 | ),
745 | nn.Tanh()
746 | )
747 | self.Q_t = nn.Sequential(
748 | nn.Conv2d(in_channels=num_features + self.cell_width,
749 | out_channels=num_features,
750 | kernel_size=3,
751 | padding=1,
752 | ),
753 | nn.Sigmoid()
754 | )
755 |
756 | def forward(self, input_tensor, b, d):
757 | input_tensor = maps_2_cubes(input_tensor, b, d)
758 | b, c, d, h, w = input_tensor.shape
759 | h_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
760 | c_state_init = torch.zeros(b, self.cell_width, h, w).to('cuda')
761 |
762 | seq_len = d
763 |
764 | h_state, c_state = h_state_init, c_state_init
765 | output_inner = []
766 | for t in range(seq_len):
767 | input_cat = torch.cat((input_tensor[:, :, t, :, :], h_state), dim=1)
768 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
769 |
770 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
771 |
772 | output_inner.append(p_depth)
773 |
774 | layer_output = torch.stack(output_inner, dim=2)
775 |
776 | return layer_output
777 |
778 |
779 | class R_CLSTM_10(nn.Module):
780 | def __init__(self, block_channel):
781 | super(R_CLSTM_10, self).__init__()
782 | num_features = 64 + block_channel[3] // 32
783 | self.Refine = R_10(block_channel)
784 | self.F_t = nn.Sequential(
785 | nn.Conv2d(in_channels=num_features+8,
786 | out_channels=8,
787 | kernel_size=3,
788 | padding=1,
789 | ),
790 | nn.Sigmoid()
791 | )
792 | self.I_t = nn.Sequential(
793 | nn.Conv2d(in_channels=num_features + 8,
794 | out_channels=8,
795 | kernel_size=3,
796 | padding=1,
797 | ),
798 | nn.Sigmoid()
799 | )
800 | self.C_t = nn.Sequential(
801 | nn.Conv2d(in_channels=num_features + 8,
802 | out_channels=8,
803 | kernel_size=3,
804 | padding=1,
805 | ),
806 | nn.Tanh()
807 | )
808 | self.Q_t = nn.Sequential(
809 | nn.Conv2d(in_channels=num_features + 8,
810 | out_channels=num_features,
811 | kernel_size=3,
812 | padding=1,
813 | ),
814 | nn.Sigmoid()
815 | )
816 |
817 | def forward(self, input_tensor, b, d):
818 | input_tensor = maps_2_cubes(input_tensor, b, d)
819 | b, c, d, h, w = input_tensor.shape
820 | h_state_init = torch.zeros(b, 8, h, w).to('cuda')
821 | c_state_init = torch.zeros(b, 8, h, w).to('cuda')
822 |
823 | seq_len = d
824 |
825 | h_state, c_state = h_state_init, c_state_init
826 | output_inner = []
827 | for t in range(seq_len):
828 | input_cat = torch.cat((input_tensor[:,:,t,:,:], h_state), dim=1)
829 | c_state = self.F_t(input_cat) * c_state + self.I_t(input_cat) * self.C_t(input_cat)
830 |
831 | h_state, p_depth = self.Refine(torch.cat((c_state, self.Q_t(input_cat)), 1))
832 |
833 | output_inner.append(p_depth)
834 |
835 | layer_output = torch.stack(output_inner, dim=2)
836 |
837 | return layer_output
838 |
839 |
840 |
--------------------------------------------------------------------------------
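A minimal sketch (shapes are illustrative; a CUDA device is required because the hidden and cell states above are created on `'cuda'`) of running the `R_CLSTM_5` refinement module on a stack of per-frame decoder features. With `block_channel=[64, 128, 256, 512]`, the resnet18/resnet34 setting used in `evaluate.py`, each frame's feature map has `64 + 512 // 32 = 80` channels:

```python
import torch
from models.R_CLSTM_modules import R_CLSTM_5

b, d, h, w = 2, 5, 114, 152                  # batch, clip length, feature map size (illustrative)
refine = R_CLSTM_5(block_channel=[64, 128, 256, 512]).cuda()

feats = torch.randn(b * d, 80, h, w).cuda()  # per-frame features, stacked over batch and time
depths = refine(feats, b, d)                 # ConvLSTM refinement over the clip

print(depths.shape)                          # torch.Size([2, 1, 5, 114, 152])
```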
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | from models.modules import E_resnet, E_densenet, E_senet
5 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
6 | from models.densenet import densenet161, densenet121, densenet169, densenet201
7 | from models.senet import senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d
8 | import pdb
9 |
10 | __models_small__ = {
11 | 'ResNet18': lambda :E_resnet(resnet18(pretrained = True)),
12 | 'ResNet34': lambda :E_resnet(resnet34(pretrained = True)),
13 | 'ResNet50': lambda :E_resnet(resnet50(pretrained = True)),
14 | 'ResNet101': lambda :E_resnet(resnet101(pretrained = True)),
15 | 'ResNet152': lambda :E_resnet(resnet152(pretrained = True)),
16 | 'DenseNet121': lambda :E_densenet(densenet121(pretrained = True)),
17 | 'DenseNet161': lambda :E_densenet(densenet161(pretrained = True)),
18 | 'DenseNet169': lambda :E_densenet(densenet169(pretrained = True)),
19 | 'DenseNet201': lambda :E_densenet(densenet201(pretrained = True)),
20 | 'SENet154': lambda :E_senet(senet154(pretrained="imagenet")),
21 | 'SE_ResNet50': lambda :E_senet(se_resnet50(pretrained="imagenet")),
22 | 'SE_ResNet101': lambda :E_senet(se_resnet101(pretrained="imagenet")),
23 | 'SE_ResNet152': lambda :E_senet(se_resnet152(pretrained="imagenet")),
24 | 'SE_ResNext50_32x4d': lambda :E_senet(se_resnext50_32x4d(pretrained="imagenet")),
25 | 'SE_ResNext101_32x4d': lambda :E_senet(se_resnext101_32x4d(pretrained="imagenet"))
26 | }
27 |
28 |
29 | def get_models(args):
30 | backbone = args.backbone
31 |
32 |     if args.pretrained_dir and os.getenv('TORCH_MODEL_ZOO') != args.pretrained_dir:  # guard against pretrained_dir being None
33 | os.environ['TORCH_MODEL_ZOO'] = args.pretrained_dir
34 | else:
35 | pass
36 |
37 | return __models_small__[backbone]()
38 |
39 |
--------------------------------------------------------------------------------
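A hedged usage sketch of `get_models` (the `Namespace` fields mirror `options/trainopt.py`; note that the keys of `__models_small__` are capitalized, unlike the lowercase names in `models/backbone_dict.py`):

```python
# Hypothetical usage; 'ResNet18' and './pretrained' are placeholder values.
from argparse import Namespace
from models import get_models

args = Namespace(backbone='ResNet18', pretrained_dir='./pretrained')
encoder = get_models(args)  # an E_resnet wrapping an ImageNet-pretrained resnet18
```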
/models/backbone_dict.py:
--------------------------------------------------------------------------------
1 | from models.resnet import resnet18, resnet34, resnet50
2 |
3 |
4 | backbone_dict = {
5 | 'resnet18': resnet18,
6 | 'resnet34': resnet34,
7 | 'resnet50': resnet50,
8 | }
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from collections import OrderedDict
7 |
8 |
9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
10 |
11 |
12 | model_urls = {
13 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
17 | }
18 |
19 |
20 | def densenet121(pretrained=False, **kwargs):
21 | r"""Densenet-121 model from
22 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
23 |
24 | Args:
25 | pretrained (bool): If True, returns a model pre-trained on ImageNet
26 | """
27 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
28 | **kwargs)
29 | if pretrained:
30 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
31 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
32 | # They are also in the checkpoints in model_urls. This pattern is used
33 | # to find such keys.
34 | pattern = re.compile(
35 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
36 | state_dict = model_zoo.load_url(model_urls['densenet121'])
37 | for key in list(state_dict.keys()):
38 | res = pattern.match(key)
39 | if res:
40 | new_key = res.group(1) + res.group(2)
41 | state_dict[new_key] = state_dict[key]
42 | del state_dict[key]
43 | model.load_state_dict(state_dict)
44 | return model
45 |
46 |
47 | def densenet169(pretrained=False, **kwargs):
48 | r"""Densenet-169 model from
49 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
50 |
51 | Args:
52 | pretrained (bool): If True, returns a model pre-trained on ImageNet
53 | """
54 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
55 | **kwargs)
56 | if pretrained:
57 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
58 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
59 | # They are also in the checkpoints in model_urls. This pattern is used
60 | # to find such keys.
61 | pattern = re.compile(
62 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
63 | state_dict = model_zoo.load_url(model_urls['densenet169'])
64 | for key in list(state_dict.keys()):
65 | res = pattern.match(key)
66 | if res:
67 | new_key = res.group(1) + res.group(2)
68 | state_dict[new_key] = state_dict[key]
69 | del state_dict[key]
70 | model.load_state_dict(state_dict)
71 | return model
72 |
73 |
74 | def densenet201(pretrained=False, **kwargs):
75 | r"""Densenet-201 model from
76 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
77 |
78 | Args:
79 | pretrained (bool): If True, returns a model pre-trained on ImageNet
80 | """
81 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
82 | **kwargs)
83 | if pretrained:
84 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
85 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
86 | # They are also in the checkpoints in model_urls. This pattern is used
87 | # to find such keys.
88 | pattern = re.compile(
89 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
90 | state_dict = model_zoo.load_url(model_urls['densenet201'])
91 | for key in list(state_dict.keys()):
92 | res = pattern.match(key)
93 | if res:
94 | new_key = res.group(1) + res.group(2)
95 | state_dict[new_key] = state_dict[key]
96 | del state_dict[key]
97 | model.load_state_dict(state_dict)
98 | return model
99 |
100 |
101 | def densenet161(pretrained=False, **kwargs):
102 | r"""Densenet-161 model from
103 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
104 |
105 | Args:
106 | pretrained (bool): If True, returns a model pre-trained on ImageNet
107 | """
108 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
109 | **kwargs)
110 | if pretrained:
111 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
112 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
113 | # They are also in the checkpoints in model_urls. This pattern is used
114 | # to find such keys.
115 | pattern = re.compile(
116 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
117 | state_dict = model_zoo.load_url(model_urls['densenet161'])
118 | for key in list(state_dict.keys()):
119 | res = pattern.match(key)
120 | if res:
121 | new_key = res.group(1) + res.group(2)
122 | state_dict[new_key] = state_dict[key]
123 | del state_dict[key]
124 | model.load_state_dict(state_dict)
125 | return model
126 |
127 | class _DenseLayer(nn.Sequential):
128 |
129 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
130 | super(_DenseLayer, self).__init__()
131 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
132 | self.add_module('relu1', nn.ReLU(inplace=True)),
133 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
134 | growth_rate, kernel_size=1, stride=1, bias=False)),
135 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
136 | self.add_module('relu2', nn.ReLU(inplace=True)),
137 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
138 | kernel_size=3, stride=1, padding=1, bias=False)),
139 | self.drop_rate = drop_rate
140 |
141 | def forward(self, x):
142 | new_features = super(_DenseLayer, self).forward(x)
143 | if self.drop_rate > 0:
144 | new_features = F.dropout(
145 | new_features, p=self.drop_rate, training=self.training)
146 | return torch.cat([x, new_features], 1)
147 |
148 |
149 | class _DenseBlock(nn.Sequential):
150 |
151 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
152 | super(_DenseBlock, self).__init__()
153 | for i in range(num_layers):
154 | layer = _DenseLayer(num_input_features + i *
155 | growth_rate, growth_rate, bn_size, drop_rate)
156 | self.add_module('denselayer%d' % (i + 1), layer)
157 |
158 |
159 | class _Transition(nn.Sequential):
160 |
161 | def __init__(self, num_input_features, num_output_features):
162 | super(_Transition, self).__init__()
163 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
164 | self.add_module('relu', nn.ReLU(inplace=True))
165 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
166 | kernel_size=1, stride=1, bias=False))
167 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
168 |
169 |
170 |
171 | class DenseNet(nn.Module):
172 |
173 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
174 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
175 |
176 | super(DenseNet, self).__init__()
177 |
178 | # First convolution
179 | self.features = nn.Sequential(OrderedDict([
180 | ('conv0', nn.Conv2d(3, num_init_features,
181 | kernel_size=7, stride=2, padding=3, bias=False)),
182 | ('norm0', nn.BatchNorm2d(num_init_features)),
183 | ('relu0', nn.ReLU(inplace=True)),
184 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
185 | ]))
186 |
187 | # Each denseblock
188 | num_features = num_init_features
189 | for i, num_layers in enumerate(block_config):
190 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
191 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
192 | self.features.add_module('denseblock%d' % (i + 1), block)
193 | num_features = num_features + num_layers * growth_rate
194 | if i != len(block_config) - 1:
195 | trans = _Transition(
196 | num_input_features=num_features, num_output_features=num_features // 2)
197 | self.features.add_module('transition%d' % (i + 1), trans)
198 | num_features = num_features // 2
199 | # print(str(i), num_features)
200 |
201 | # Final batch norm
202 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
203 | self.num_features = num_features
204 |
205 | # Linear layer
206 | self.classifier = nn.Linear(num_features, num_classes)
207 |
208 |
209 | def forward(self, x):
210 | features = self.features(x)
211 | out = F.relu(features, inplace=True)
212 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(
213 | features.size(0), -1)
214 | out = self.classifier(out)
215 | return out
216 |
217 |
--------------------------------------------------------------------------------
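The checkpoint key renaming performed by each `densenetXXX` constructor can be illustrated on one assumed legacy key:

```python
import re

# Legacy DenseNet checkpoints contain dotted sub-module names such as 'norm.1';
# the regex used in the constructors above rewrites them to the dot-free 'norm1'.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
old_key = 'features.denseblock1.denselayer1.norm.1.weight'  # assumed example key
match = pattern.match(old_key)
new_key = match.group(1) + match.group(2)
print(new_key)  # features.denseblock1.denselayer1.norm1.weight
```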
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | def adjust_gt(gt_depth, pred_depth):
7 | adjusted_gt = []
8 | for each_depth in pred_depth:
9 | adjusted_gt.append(F.interpolate(gt_depth, size=[each_depth.size(2), each_depth.size(3)],
10 | mode='bilinear', align_corners=True))
11 | return adjusted_gt
12 |
13 | class Sobel(nn.Module):
14 | def __init__(self):
15 | super(Sobel, self).__init__()
16 | self.edge_conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1, bias=False)
17 | edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
18 | edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
19 | edge_k = np.stack((edge_kx, edge_ky))
20 |
21 | edge_k = torch.from_numpy(edge_k).float().view(2, 1, 3, 3)
22 | self.edge_conv.weight = nn.Parameter(edge_k)
23 |
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def forward(self, x):
28 | out = self.edge_conv(x)
29 | out = out.contiguous().view(-1, 2, x.size(2), x.size(3))
30 |
31 | return out
32 |
33 | def cal_spatial_loss(output, depth_gt):
34 |
35 | losses=[]
36 |
37 |     for depth_index in range(len(output)):  # NOTE: depth_index is unused, so the same loss is accumulated len(output) times
38 |
39 | cos = nn.CosineSimilarity(dim=1, eps=0)
40 | get_gradient = Sobel().cuda()
41 | ones = torch.ones(depth_gt.size(0), 1, depth_gt.size(2),depth_gt.size(3)).float().cuda()
42 | ones = torch.autograd.Variable(ones)
43 | depth_grad = get_gradient(depth_gt)
44 | output_grad = get_gradient(output)
45 | depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(depth_gt)
46 | depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(depth_gt)
47 | output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(depth_gt)
48 | output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(depth_gt)
49 |
50 | depth_normal = torch.cat((-depth_grad_dx, -depth_grad_dy, ones), 1)
51 | output_normal = torch.cat((-output_grad_dx, -output_grad_dy, ones), 1)
52 |
53 | cof = 0.5
54 |
55 | loss_depth = torch.log(torch.abs(output - depth_gt) + cof).mean()
56 | loss_dx = torch.log(torch.abs(output_grad_dx - depth_grad_dx) + cof).mean()
57 | loss_dy = torch.log(torch.abs(output_grad_dy - depth_grad_dy) + cof).mean()
58 | loss_normal = torch.abs(1 - cos(output_normal, depth_normal)).mean()
59 |
60 | loss = loss_depth + loss_normal + (loss_dx + loss_dy)
61 |
62 | losses.append(loss)
63 |
64 | spatial_loss = sum(losses)
65 |
66 | return spatial_loss
67 |
68 | def cal_temporal_loss(pred_cls, gt_cls):
69 | return F.binary_cross_entropy_with_logits(pred_cls, gt_cls)
70 |
--------------------------------------------------------------------------------
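For reference, with `cof = 0.5` the quantity accumulated in each loop iteration of `cal_spatial_loss` combines a log-L1 depth term, log-L1 Sobel-gradient terms and a surface-normal term (writing d for the prediction, d* for the ground truth, and an overline for the mean over pixels):

```latex
\mathcal{L}_{\text{spatial}} =
    \overline{\log\big(|d - d^{*}| + 0.5\big)}
  + \overline{\log\big(|\partial_x d - \partial_x d^{*}| + 0.5\big)}
  + \overline{\log\big(|\partial_y d - \partial_y d^{*}| + 0.5\big)}
  + \overline{\big|1 - \cos(\mathbf{n}, \mathbf{n}^{*})\big|},
\qquad \mathbf{n} = (-\partial_x d,\ -\partial_y d,\ 1)
```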
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | class _UpProjection(nn.Sequential):
6 |
7 | def __init__(self, num_input_features, num_output_features):
8 | super(_UpProjection, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(num_input_features, num_output_features,
11 | kernel_size=5, stride=1, padding=2, bias=False)
12 | self.bn1 = nn.BatchNorm2d(num_output_features)
13 | self.relu = nn.ReLU(inplace=True)
14 | self.conv1_2 = nn.Conv2d(num_output_features, num_output_features,
15 | kernel_size=3, stride=1, padding=1, bias=False)
16 | self.bn1_2 = nn.BatchNorm2d(num_output_features)
17 |
18 | self.conv2 = nn.Conv2d(num_input_features, num_output_features,
19 | kernel_size=5, stride=1, padding=2, bias=False)
20 | self.bn2 = nn.BatchNorm2d(num_output_features)
21 |
22 | def forward(self, x, size):
23 |         x = F.interpolate(x, size=size, mode='bilinear', align_corners=True)  # F.upsample is deprecated in favour of F.interpolate
24 | x_conv1 = self.relu(self.bn1(self.conv1(x)))
25 | bran1 = self.bn1_2(self.conv1_2(x_conv1))
26 | bran2 = self.bn2(self.conv2(x))
27 |
28 | out = self.relu(bran1 + bran2)
29 |
30 | return out
31 |
32 | class E_resnet(nn.Module):
33 |
34 | def __init__(self, original_model, num_features = 2048):
35 | super(E_resnet, self).__init__()
36 | self.conv1 = original_model.conv1
37 | self.bn1 = original_model.bn1
38 | self.relu = original_model.relu
39 | self.maxpool = original_model.maxpool
40 |
41 | self.layer1 = original_model.layer1
42 | self.layer2 = original_model.layer2
43 | self.layer3 = original_model.layer3
44 | self.layer4 = original_model.layer4
45 |
46 |
47 | def forward(self, x):
48 | x = self.conv1(x)
49 | x = self.bn1(x)
50 | x = self.relu(x)
51 | x = self.maxpool(x)
52 |
53 | x_block1 = self.layer1(x)
54 | x_block2 = self.layer2(x_block1)
55 | x_block3 = self.layer3(x_block2)
56 | x_block4 = self.layer4(x_block3)
57 |
58 | return x_block1, x_block2, x_block3, x_block4
59 |
60 | class E_densenet(nn.Module):
61 |
62 | def __init__(self, original_model, num_features = 2208):
63 | super(E_densenet, self).__init__()
64 | self.features = original_model.features
65 |
66 | def forward(self, x):
67 | x01 = self.features[0](x)
68 | x02 = self.features[1](x01)
69 | x03 = self.features[2](x02)
70 | x04 = self.features[3](x03)
71 |
72 | x_block1 = self.features[4](x04)
73 | x_block1 = self.features[5][0](x_block1)
74 | x_block1 = self.features[5][1](x_block1)
75 | x_block1 = self.features[5][2](x_block1)
76 | x_tran1 = self.features[5][3](x_block1)
77 |
78 | x_block2 = self.features[6](x_tran1)
79 | x_block2 = self.features[7][0](x_block2)
80 | x_block2 = self.features[7][1](x_block2)
81 | x_block2 = self.features[7][2](x_block2)
82 | x_tran2 = self.features[7][3](x_block2)
83 |
84 | x_block3 = self.features[8](x_tran2)
85 | x_block3 = self.features[9][0](x_block3)
86 | x_block3 = self.features[9][1](x_block3)
87 | x_block3 = self.features[9][2](x_block3)
88 | x_tran3 = self.features[9][3](x_block3)
89 |
90 | x_block4 = self.features[10](x_tran3)
91 | x_block4 = F.relu(self.features[11](x_block4))
92 |
93 | return x_block1, x_block2, x_block3, x_block4
94 |
95 | class E_senet(nn.Module):
96 |
97 | def __init__(self, original_model, num_features = 2048):
98 | super(E_senet, self).__init__()
99 | self.base = nn.Sequential(*list(original_model.children())[:-3])
100 |
101 | def forward(self, x):
102 | x = self.base[0](x)
103 | x_block1 = self.base[1](x)
104 | x_block2 = self.base[2](x_block1)
105 | x_block3 = self.base[3](x_block2)
106 | x_block4 = self.base[4](x_block3)
107 |
108 | return x_block1, x_block2, x_block3, x_block4
109 |
110 | class D(nn.Module):
111 |
112 | def __init__(self, num_features = 2048):
113 | super(D, self).__init__()
114 | self.conv = nn.Conv2d(num_features, num_features //
115 | 2, kernel_size=1, stride=1, bias=False)
116 | num_features = num_features // 2
117 | self.bn = nn.BatchNorm2d(num_features)
118 |
119 | self.up1 = _UpProjection(
120 | num_input_features=num_features, num_output_features=num_features // 2)
121 | num_features = num_features // 2
122 |
123 | self.up2 = _UpProjection(
124 | num_input_features=num_features, num_output_features=num_features // 2)
125 | num_features = num_features // 2
126 |
127 | self.up3 = _UpProjection(
128 | num_input_features=num_features, num_output_features=num_features // 2)
129 | num_features = num_features // 2
130 |
131 | self.up4 = _UpProjection(
132 | num_input_features=num_features, num_output_features=num_features // 2)
133 | num_features = num_features // 2
134 |
135 |
136 | def forward(self, x_block1, x_block2, x_block3, x_block4):
137 | x_d0 = F.relu(self.bn(self.conv(x_block4)))
138 | x_d1 = self.up1(x_d0, [x_block3.size(2), x_block3.size(3)])
139 | x_d2 = self.up2(x_d1, [x_block2.size(2), x_block2.size(3)])
140 | x_d3 = self.up3(x_d2, [x_block1.size(2), x_block1.size(3)])
141 | x_d4 = self.up4(x_d3, [x_block1.size(2)*2, x_block1.size(3)*2])
142 |
143 | return x_d4
144 |
145 | class MFF(nn.Module):
146 |
147 | def __init__(self, block_channel, num_features=64):
148 |
149 | super(MFF, self).__init__()
150 |
151 | self.up1 = _UpProjection(
152 | num_input_features=block_channel[0], num_output_features=16)
153 |
154 | self.up2 = _UpProjection(
155 | num_input_features=block_channel[1], num_output_features=16)
156 |
157 | self.up3 = _UpProjection(
158 | num_input_features=block_channel[2], num_output_features=16)
159 |
160 | self.up4 = _UpProjection(
161 | num_input_features=block_channel[3], num_output_features=16)
162 |
163 | self.conv = nn.Conv2d(
164 | num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False)
165 | self.bn = nn.BatchNorm2d(num_features)
166 |
167 |
168 | def forward(self, x_block1, x_block2, x_block3, x_block4, size):
169 | x_m1 = self.up1(x_block1, size)
170 | x_m2 = self.up2(x_block2, size)
171 | x_m3 = self.up3(x_block3, size)
172 | x_m4 = self.up4(x_block4, size)
173 |
174 | x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
175 | x = F.relu(x)
176 |
177 | return x
178 |
--------------------------------------------------------------------------------
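A small bookkeeping sketch (resnet18/resnet34 channels assumed) of how the decoder `D` and `MFF` outputs combine into the `num_features = 64 + block_channel[3] // 32` expected by the `R_CLSTM_*` refinement modules:

```python
# D halves the channel count once in its 1x1 conv and once per up-projection,
# so its output carries block_channel[3] // 32 channels; MFF always emits 4 * 16 = 64.
block_channel = [64, 128, 256, 512]                # resnet18 / resnet34 (assumed)
d_out = block_channel[3] // 2 // 2 // 2 // 2 // 2  # 512 -> 256 -> 128 -> 64 -> 32 -> 16
mff_out = 4 * 16
num_features = mff_out + d_out                     # 80
assert num_features == 64 + block_channel[3] // 32
print(num_features)
```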
/models/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models import modules
4 | from models.refinenet_dict import refinenet_dict
5 |
6 |
7 | def cubes_2_maps(cubes):
8 | b, c, d, h, w = cubes.shape
9 | cubes = cubes.permute(0, 2, 1, 3, 4)
10 |
11 | return cubes.contiguous().view(b*d, c, h, w), b, d
12 |
13 |
14 | class model(nn.Module):
15 | def __init__(self, Encoder, num_features, block_channel, refinenet):
16 |
17 | super(model, self).__init__()
18 |
19 | self.E = Encoder
20 | self.D = modules.D(num_features)
21 | self.MFF = modules.MFF(block_channel)
22 | self.R = refinenet_dict[refinenet](block_channel)
23 |
24 | def forward(self, x):
25 | x, b, d = cubes_2_maps(x)
26 | x_block1, x_block2, x_block3, x_block4 = self.E(x)
27 | x_decoder = self.D(x_block1, x_block2, x_block3, x_block4)
28 | x_mff = self.MFF(x_block1, x_block2, x_block3, x_block4,[x_decoder.size(2),x_decoder.size(3)])
29 | out = self.R(torch.cat((x_decoder, x_mff), 1), b, d)
30 |
31 | return out
32 |
33 | class C_C3D_1(nn.Module):
34 |
35 | def __init__(self, num_classes=2, init_width=32, input_channels=1):
36 | self.inplanes = init_width
37 | super(C_C3D_1, self).__init__()
38 | self.conv1 = nn.Sequential(
39 | nn.Conv3d(input_channels, init_width, kernel_size=5, stride=2, padding=2, bias=False),
40 | nn.BatchNorm3d(init_width),
41 | nn.ReLU(inplace=True),
42 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
43 | )
44 | self.conv2 = nn.Sequential(
45 | nn.Conv3d(init_width, init_width*2, kernel_size=3, stride=1, padding=1, bias=False),
46 | nn.BatchNorm3d(init_width*2),
47 | nn.ReLU(inplace=True),
48 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
49 | )
50 |
51 | init_width = init_width * 2
52 | self.conv3 = nn.Sequential(
53 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False),
54 | nn.BatchNorm3d(init_width * 2),
55 | nn.ReLU(inplace=True),
56 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
57 | )
58 |
59 | init_width = init_width * 2
60 | self.conv4 = nn.Sequential(
61 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False),
62 | nn.BatchNorm3d(init_width * 2),
63 | nn.ReLU(inplace=True),
64 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
65 | )
66 |
67 | self.avgpool = nn.AdaptiveAvgPool3d(1)
68 | self.fc = nn.Linear(init_width * 2, num_classes)
69 |
70 | def forward(self, x):
71 | x = self.conv1(x)
72 | x = self.conv2(x)
73 | x = self.conv3(x)
74 | x = self.conv4(x)
75 | x = self.avgpool(x)
76 | x = x.view(x.size(0), -1)
77 | x = self.fc(x)
78 |
79 | return x
80 |
81 | class C_C3D_2(nn.Module):
82 |
83 | def __init__(self, num_classes=2, init_width=32, input_channels=1):
84 | self.inplanes = init_width
85 | super(C_C3D_2, self).__init__()
86 | self.conv1 = nn.Sequential(
87 | nn.Conv3d(input_channels, init_width, kernel_size=5, stride=2, padding=2, bias=False),
88 | nn.BatchNorm3d(init_width),
89 | nn.ReLU(inplace=True),
90 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
91 | )
92 | self.conv2 = nn.Sequential(
93 | nn.Conv3d(init_width, init_width*2, kernel_size=3, stride=1, padding=1, bias=False),
94 | nn.BatchNorm3d(init_width*2),
95 | nn.ReLU(inplace=True),
96 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
97 | )
98 |
99 | init_width = init_width * 2
100 | self.conv3 = nn.Sequential(
101 | nn.Conv3d(init_width, init_width * 2, kernel_size=3, stride=1, padding=1, bias=False),
102 | nn.BatchNorm3d(init_width * 2),
103 | nn.ReLU(inplace=True),
104 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
105 | )
106 |
107 | self.avgpool = nn.AdaptiveAvgPool3d(1)
108 | self.fc = nn.Linear(init_width * 2, num_classes)
109 |
110 | def forward(self, x):
111 | x = self.conv1(x)
112 | x = self.conv2(x)
113 | x = self.conv3(x)
114 | x = self.avgpool(x)
115 | x = x.view(x.size(0), -1)
116 | x = self.fc(x)
117 |
118 | return x
--------------------------------------------------------------------------------
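A shape-checking sketch (toy sizes assumed) of the clip/frame reshaping done by `cubes_2_maps`, whose inverse `maps_2_cubes` lives in `models/R_CLSTM_modules.py`:

```python
import torch
from models.net import cubes_2_maps

clip = torch.randn(2, 3, 5, 224, 224)   # (batch, channels, frames, H, W)
frames, b, d = cubes_2_maps(clip)        # frames are folded into the batch axis for the 2D encoder
assert frames.shape == (2 * 5, 3, 224, 224) and (b, d) == (2, 5)
```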
/models/refinenet_dict.py:
--------------------------------------------------------------------------------
1 | from models.R_CLSTM_modules import (R_CLSTM_1, R_CLSTM_2, R_CLSTM_3,
2 | R_CLSTM_4, R_CLSTM_5, R_CLSTM_6,
3 | R_CLSTM_7, R_CLSTM_8, R_CLSTM_9)
4 |
5 |
6 |
7 | refinenet_dict = {
8 | 'R_CLSTM_1': R_CLSTM_1,
9 | 'R_CLSTM_2': R_CLSTM_2,
10 | 'R_CLSTM_3': R_CLSTM_3,
11 | 'R_CLSTM_4': R_CLSTM_4,
12 | 'R_CLSTM_5': R_CLSTM_5,
13 | 'R_CLSTM_6': R_CLSTM_6,
14 | 'R_CLSTM_7': R_CLSTM_7,
15 | 'R_CLSTM_8': R_CLSTM_8,
16 | 'R_CLSTM_9': R_CLSTM_9
17 | }
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 | import torch
6 | import numpy as np
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152']
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1):
22 | "3x3 convolution with padding"
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=1, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion = 1
29 |
30 | def __init__(self, inplanes, planes, stride=1, downsample=None):
31 | super(BasicBlock, self).__init__()
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = nn.BatchNorm2d(planes)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(x)
52 |
53 | out += residual
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70 | self.bn3 = nn.BatchNorm2d(planes * 4)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 |
100 | def __init__(self, block, layers, num_classes=1000):
101 | self.inplanes = 64
102 | super(ResNet, self).__init__()
103 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
104 | bias=False)
105 | self.bn1 = nn.BatchNorm2d(64)
106 | self.relu = nn.ReLU(inplace=True)
107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
108 | self.layer1 = self._make_layer(block, 64, layers[0])
109 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
111 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
112 | self.avgpool = nn.AvgPool2d(7, stride=1)
113 | self.fc = nn.Linear(512 * block.expansion, num_classes)
114 |
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv2d):
117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118 | m.weight.data.normal_(0, math.sqrt(2. / n))
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
124 | downsample = None
125 | # stride=1, 64 != 256
126 | if stride != 1 or self.inplanes != planes * block.expansion:
127 | downsample = nn.Sequential(
128 | nn.Conv2d(self.inplanes, planes * block.expansion,
129 | kernel_size=1, stride=stride, bias=False),
130 | nn.BatchNorm2d(planes * block.expansion),
131 | )
132 |
133 | layers = []
134 | layers.append(block(self.inplanes, planes, stride, downsample))
135 | self.inplanes = planes * block.expansion
136 | for i in range(1, blocks):
137 | layers.append(block(self.inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x):
142 | x = self.conv1(x)
143 | x = self.bn1(x)
144 | x = self.relu(x)
145 | x = self.maxpool(x)
146 |
147 | x = self.layer1(x)
148 | x = self.layer2(x)
149 | x = self.layer3(x)
150 | x = self.layer4(x)
151 |
152 | x = self.avgpool(x)
153 | x = x.view(x.size(0), -1)
154 | x = self.fc(x)
155 |
156 | return x
157 |
158 | def resnet18(pretrained=True, **kwargs):
159 | """Constructs a ResNet-18 model.
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if pretrained:
165 | from collections import OrderedDict
166 | pretrained_state = model_zoo.load_url(model_urls['resnet18'])
167 | model_state = model.state_dict()
168 | selected_state = OrderedDict()
169 | for k, v in pretrained_state.items():
170 | if k in model_state and v.size() == model_state[k].size():
171 | #print('pretrain..',k)
172 | selected_state[k] = v
173 | model_state.update(selected_state)
174 | model.load_state_dict(model_state)
175 | return model
176 |
177 |
178 | def resnet34(pretrained=False, **kwargs):
179 | """Constructs a ResNet-34 model.
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | """
183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
184 | if pretrained:
185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
186 | return model
187 |
188 |
189 | def resnet50(pretrained=False, **kwargs):
190 | """Constructs a ResNet-50 model.
191 | Args:
192 | pretrained (bool): If True, returns a model pre-trained on ImageNet
193 | """
194 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
195 | if pretrained:
196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
197 | return model
198 |
199 |
200 | def resnet101(pretrained=False, **kwargs):
201 | """Constructs a ResNet-101 model.
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | """
205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
206 | if pretrained:
207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
208 | return model
209 |
210 |
211 | def resnet152(pretrained=False, **kwargs):
212 | """Constructs a ResNet-152 model.
213 | Args:
214 | pretrained (bool): If True, returns a model pre-trained on ImageNet
215 | """
216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
217 | if pretrained:
218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
219 | return model
220 |
--------------------------------------------------------------------------------
/models/senet.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNet code gently borrowed from
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | """
5 | from collections import OrderedDict
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | from torch.utils import model_zoo
11 | import copy
12 | import numpy as np
13 |
14 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
15 | 'se_resnext50_32x4d', 'se_resnext101_32x4d']
16 |
17 | pretrained_settings = {
18 | 'senet154': {
19 | 'imagenet': {
20 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
21 | 'input_space': 'RGB',
22 | 'input_size': [3, 224, 224],
23 | 'input_range': [0, 1],
24 | 'mean': [0.485, 0.456, 0.406],
25 | 'std': [0.229, 0.224, 0.225],
26 | 'num_classes': 1000
27 | }
28 | },
29 | 'se_resnet50': {
30 | 'imagenet': {
31 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
32 | 'input_space': 'RGB',
33 | 'input_size': [3, 224, 224],
34 | 'input_range': [0, 1],
35 | 'mean': [0.485, 0.456, 0.406],
36 | 'std': [0.229, 0.224, 0.225],
37 | 'num_classes': 1000
38 | }
39 | },
40 | 'se_resnet101': {
41 | 'imagenet': {
42 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
43 | 'input_space': 'RGB',
44 | 'input_size': [3, 224, 224],
45 | 'input_range': [0, 1],
46 | 'mean': [0.485, 0.456, 0.406],
47 | 'std': [0.229, 0.224, 0.225],
48 | 'num_classes': 1000
49 | }
50 | },
51 | 'se_resnet152': {
52 | 'imagenet': {
53 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
54 | 'input_space': 'RGB',
55 | 'input_size': [3, 224, 224],
56 | 'input_range': [0, 1],
57 | 'mean': [0.485, 0.456, 0.406],
58 | 'std': [0.229, 0.224, 0.225],
59 | 'num_classes': 1000
60 | }
61 | },
62 | 'se_resnext50_32x4d': {
63 | 'imagenet': {
64 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
65 | 'input_space': 'RGB',
66 | 'input_size': [3, 224, 224],
67 | 'input_range': [0, 1],
68 | 'mean': [0.485, 0.456, 0.406],
69 | 'std': [0.229, 0.224, 0.225],
70 | 'num_classes': 1000
71 | }
72 | },
73 | 'se_resnext101_32x4d': {
74 | 'imagenet': {
75 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
76 | 'input_space': 'RGB',
77 | 'input_size': [3, 224, 224],
78 | 'input_range': [0, 1],
79 | 'mean': [0.485, 0.456, 0.406],
80 | 'std': [0.229, 0.224, 0.225],
81 | 'num_classes': 1000
82 | }
83 | },
84 | }
85 |
86 |
87 | class SEModule(nn.Module):
88 |
89 | def __init__(self, channels, reduction):
90 | super(SEModule, self).__init__()
91 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
92 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
93 | padding=0)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
96 | padding=0)
97 | self.sigmoid = nn.Sigmoid()
98 |
99 | def forward(self, x):
100 | module_input = x
101 | x = self.avg_pool(x)
102 | x = self.fc1(x)
103 | x = self.relu(x)
104 | x = self.fc2(x)
105 | x = self.sigmoid(x)
106 | return module_input * x
107 |
108 |
109 | class Bottleneck(nn.Module):
110 | """
111 | Base class for bottlenecks that implements `forward()` method.
112 | """
113 | def forward(self, x):
114 | residual = x
115 |
116 | out = self.conv1(x)
117 | out = self.bn1(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv2(out)
121 | out = self.bn2(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv3(out)
125 | out = self.bn3(out)
126 |
127 | if self.downsample is not None:
128 | residual = self.downsample(x)
129 |
130 | out = self.se_module(out) + residual
131 | out = self.relu(out)
132 |
133 | return out
134 |
135 |
136 | class SEBottleneck(Bottleneck):
137 | """
138 | Bottleneck for SENet154.
139 | """
140 | expansion = 4
141 |
142 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
143 | downsample=None):
144 |
145 | super(SEBottleneck, self).__init__()
146 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
147 | self.bn1 = nn.BatchNorm2d(planes * 2)
148 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
149 | stride=stride, padding=1, groups=groups,
150 | bias=False)
151 | self.bn2 = nn.BatchNorm2d(planes * 4)
152 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
153 | bias=False)
154 | self.bn3 = nn.BatchNorm2d(planes * 4)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.se_module = SEModule(planes * 4, reduction=reduction)
157 | self.downsample = downsample
158 | self.stride = stride
159 |
160 |
161 | class SEResNetBottleneck(Bottleneck):
162 | """
163 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
164 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
165 | (the latter is used in the torchvision implementation of ResNet).
166 | """
167 | expansion = 4
168 |
169 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
170 | downsample=None):
171 | super(SEResNetBottleneck, self).__init__()
172 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
173 | stride=stride)
174 | self.bn1 = nn.BatchNorm2d(planes)
175 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
176 | groups=groups, bias=False)
177 | self.bn2 = nn.BatchNorm2d(planes)
178 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
179 | self.bn3 = nn.BatchNorm2d(planes * 4)
180 | self.relu = nn.ReLU(inplace=True)
181 | self.se_module = SEModule(planes * 4, reduction=reduction)
182 | self.downsample = downsample
183 | self.stride = stride
184 |
185 |
186 | class SEResNeXtBottleneck(Bottleneck):
187 | """
188 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
189 | """
190 | expansion = 4
191 |
192 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
193 | downsample=None, base_width=4):
194 | super(SEResNeXtBottleneck, self).__init__()
195 |
196 | width = int(planes * base_width / 64) * groups
197 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
198 | stride=1)
199 | self.bn1 = nn.BatchNorm2d(width)
200 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
201 | padding=1, groups=groups, bias=False)
202 | self.bn2 = nn.BatchNorm2d(width)
203 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
204 | self.bn3 = nn.BatchNorm2d(planes * 4)
205 | self.relu = nn.ReLU(inplace=True)
206 | self.se_module = SEModule(planes * 4, reduction=reduction)
207 | self.downsample = downsample
208 | self.stride = stride
209 |
210 |
211 | class SENet(nn.Module):
212 |
213 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
214 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
215 | downsample_padding=1, num_classes=1000):
216 | """
217 | Parameters
218 | ----------
219 | block (nn.Module): Bottleneck class.
220 | - For SENet154: SEBottleneck
221 | - For SE-ResNet models: SEResNetBottleneck
222 | - For SE-ResNeXt models: SEResNeXtBottleneck
223 | layers (list of ints): Number of residual blocks for 4 layers of the
224 | network (layer1...layer4).
225 | groups (int): Number of groups for the 3x3 convolution in each
226 | bottleneck block.
227 | - For SENet154: 64
228 | - For SE-ResNet models: 1
229 | - For SE-ResNeXt models: 32
230 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
231 | - For all models: 16
232 | dropout_p (float or None): Drop probability for the Dropout layer.
233 | If `None` the Dropout layer is not used.
234 | - For SENet154: 0.2
235 | - For SE-ResNet models: None
236 | - For SE-ResNeXt models: None
237 | inplanes (int): Number of input channels for layer1.
238 | - For SENet154: 128
239 | - For SE-ResNet models: 64
240 | - For SE-ResNeXt models: 64
241 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
242 | a single 7x7 convolution in layer0.
243 | - For SENet154: True
244 | - For SE-ResNet models: False
245 | - For SE-ResNeXt models: False
246 | downsample_kernel_size (int): Kernel size for downsampling convolutions
247 | in layer2, layer3 and layer4.
248 | - For SENet154: 3
249 | - For SE-ResNet models: 1
250 | - For SE-ResNeXt models: 1
251 | downsample_padding (int): Padding for downsampling convolutions in
252 | layer2, layer3 and layer4.
253 | - For SENet154: 1
254 | - For SE-ResNet models: 0
255 | - For SE-ResNeXt models: 0
256 | num_classes (int): Number of outputs in `last_linear` layer.
257 | - For all models: 1000
258 | """
259 | super(SENet, self).__init__()
260 | self.inplanes = inplanes
261 | if input_3x3:
262 | layer0_modules = [
263 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
264 | bias=False)),
265 | ('bn1', nn.BatchNorm2d(64)),
266 | ('relu1', nn.ReLU(inplace=True)),
267 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
268 | bias=False)),
269 | ('bn2', nn.BatchNorm2d(64)),
270 | ('relu2', nn.ReLU(inplace=True)),
271 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
272 | bias=False)),
273 | ('bn3', nn.BatchNorm2d(inplanes)),
274 | ('relu3', nn.ReLU(inplace=True)),
275 | ]
276 | else:
277 | layer0_modules = [
278 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
279 | padding=3, bias=False)),
280 | ('bn1', nn.BatchNorm2d(inplanes)),
281 | ('relu1', nn.ReLU(inplace=True)),
282 | ]
283 | # To preserve compatibility with Caffe weights `ceil_mode=True`
284 | # is used instead of `padding=1`.
285 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
286 | ceil_mode=True)))
287 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
288 | self.layer1 = self._make_layer(
289 | block,
290 | planes=64,
291 | blocks=layers[0],
292 | groups=groups,
293 | reduction=reduction,
294 | downsample_kernel_size=1,
295 | downsample_padding=0
296 | )
297 | self.layer2 = self._make_layer(
298 | block,
299 | planes=128,
300 | blocks=layers[1],
301 | stride=2,
302 | groups=groups,
303 | reduction=reduction,
304 | downsample_kernel_size=downsample_kernel_size,
305 | downsample_padding=downsample_padding
306 | )
307 | self.layer3 = self._make_layer(
308 | block,
309 | planes=256,
310 | blocks=layers[2],
311 | stride=2,
312 | groups=groups,
313 | reduction=reduction,
314 | downsample_kernel_size=downsample_kernel_size,
315 | downsample_padding=downsample_padding
316 | )
317 | self.layer4 = self._make_layer(
318 | block,
319 | planes=512,
320 | blocks=layers[3],
321 | stride=2,
322 | groups=groups,
323 | reduction=reduction,
324 | downsample_kernel_size=downsample_kernel_size,
325 | downsample_padding=downsample_padding
326 | )
327 | self.avg_pool = nn.AvgPool2d(7, stride=1)
328 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
329 | self.last_linear = nn.Linear(512 * block.expansion, num_classes)
330 |
331 |
332 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
333 | downsample_kernel_size=1, downsample_padding=0):
334 |
335 | downsample = None
336 |
337 | if stride != 1 or self.inplanes != planes * block.expansion:
338 |
339 | downsample = nn.Sequential(
340 | nn.Conv2d(self.inplanes, planes * block.expansion,
341 | kernel_size=downsample_kernel_size, stride=stride,
342 | padding=downsample_padding, bias=False),
343 | nn.BatchNorm2d(planes * block.expansion),
344 | )
345 |
346 | layers = []
347 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
348 | downsample))
349 | self.inplanes = planes * block.expansion
350 | for i in range(1, blocks):
351 | layers.append(block(self.inplanes, planes, groups, reduction))
352 |
353 | return nn.Sequential(*layers)
354 |
355 |
356 | def features(self, x):
357 | x = self.layer0(x)
358 | x = self.layer1(x)
359 | x = self.layer2(x)
360 | x = self.layer3(x)
361 | x = self.layer4(x)
362 |
363 | return x
364 |
365 |
366 | def logits(self, x):
367 | x = self.avg_pool(x)
368 | if self.dropout is not None:
369 | x = self.dropout(x)
370 | x = x.view(x.size(0), -1)
371 | x = self.last_linear(x)
372 | return x
373 |
374 |     def forward(self, x):
375 | x = self.features(x)
376 | x = self.logits(x)
377 | return x
378 |
379 | def initialize_pretrained_model(model, num_classes, settings):
380 | assert num_classes == settings['num_classes'], \
381 | 'num_classes should be {}, but is {}'.format(
382 | settings['num_classes'], num_classes)
383 | model.load_state_dict(model_zoo.load_url(settings['url']))
384 | model.input_space = settings['input_space']
385 | model.input_size = settings['input_size']
386 | model.input_range = settings['input_range']
387 | model.mean = settings['mean']
388 | model.std = settings['std']
389 |
390 |
391 | def senet154(num_classes=1000, pretrained='imagenet'):
392 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
393 | dropout_p=0.2, num_classes=num_classes)
394 | if pretrained is not None:
395 | settings = pretrained_settings['senet154'][pretrained]
396 | initialize_pretrained_model(model, num_classes, settings)
397 | return model
398 |
399 | def se_resnet50(num_classes=1000, pretrained='imagenet'):
400 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
401 | dropout_p=0.2, inplanes=64, input_3x3=False,
402 | downsample_kernel_size=1, downsample_padding=0,
403 | num_classes=num_classes)
404 | if pretrained is not None:
405 | settings = pretrained_settings['se_resnet50'][pretrained]
406 | initialize_pretrained_model(model, num_classes, settings)
407 | return model
408 |
409 |
410 | def se_resnet101(num_classes=1000, pretrained='imagenet'):
411 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
412 | dropout_p=0.2, inplanes=64, input_3x3=False,
413 | downsample_kernel_size=1, downsample_padding=0,
414 | num_classes=num_classes)
415 | if pretrained is not None:
416 | settings = pretrained_settings['se_resnet101'][pretrained]
417 | initialize_pretrained_model(model, num_classes, settings)
418 | return model
419 |
420 |
421 | def se_resnet152(num_classes=1000, pretrained='imagenet'):
422 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
423 | dropout_p=0.2, inplanes=64, input_3x3=False,
424 | downsample_kernel_size=1, downsample_padding=0,
425 | num_classes=num_classes)
426 | if pretrained is not None:
427 | settings = pretrained_settings['se_resnet152'][pretrained]
428 | initialize_pretrained_model(model, num_classes, settings)
429 | return model
430 |
431 |
432 | def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
433 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
434 | dropout_p=0.2, inplanes=64, input_3x3=False,
435 | downsample_kernel_size=1, downsample_padding=0,
436 | num_classes=num_classes)
437 | if pretrained is not None:
438 | settings = pretrained_settings['se_resnext50_32x4d'][pretrained]
439 | initialize_pretrained_model(model, num_classes, settings)
440 | return model
441 |
442 |
443 | def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
444 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
445 | dropout_p=0.2, inplanes=64, input_3x3=False,
446 | downsample_kernel_size=1, downsample_padding=0,
447 | num_classes=num_classes)
448 | if pretrained is not None:
449 | settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
450 | initialize_pretrained_model(model, num_classes, settings)
451 | return model
452 |
--------------------------------------------------------------------------------
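A minimal sketch (toy sizes assumed) of the channel re-weighting performed by `SEModule`: each channel is squeezed to a scalar, passed through the two 1x1 convolutions, and used to rescale the original feature map.

```python
import torch
from models.senet import SEModule

se = SEModule(channels=64, reduction=16)
x = torch.randn(2, 64, 28, 28)
y = se(x)                  # same shape; every channel is scaled by a learned weight in (0, 1)
assert y.shape == x.shape
```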
/options/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainopt import _get_train_opt
2 | from .testopt import _get_test_opt
3 |
4 | def get_args(mode):
5 | args = None
6 | if mode == 'train':
7 | args = _get_train_opt()
8 | elif mode == 'test':
9 | args = _get_test_opt()
10 | else:
11 | raise ValueError("Invalid mode selection!")
12 |
13 | return args
14 |
--------------------------------------------------------------------------------
/options/testopt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def _get_test_opt():
4 |     parser = argparse.ArgumentParser(description = 'Evaluate performance of ST-CLSTM on the NYU Depth v2 test set')
5 | parser.add_argument('--testlist_path', required=True, help='the path of testlist')
6 | parser.add_argument('--root_path', required=True, help="the root path of dataset")
7 | parser.add_argument('--backbone', type=str, default='resnet18')
8 | parser.add_argument('--refinenet', type=str, default='R_CLSTM_5')
9 | parser.add_argument('--batch_size', type=int, default=1, help='testing batch size')
10 | parser.add_argument('--loadckpt', required=True, help="the path of the loaded model")
11 | # parse arguments
12 | return parser.parse_args()
13 |
--------------------------------------------------------------------------------
/options/trainopt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def _get_train_opt():
4 | parser = argparse.ArgumentParser(description = 'Monocular Depth Estimation')
5 | parser.add_argument('--trainlist_path', required=True, help='the path of trainlist', default='/media3/x00532679/project/ST-SARPN/data_list/raw_nyu_v2_250k/raw_nyu_v2_250k_fps30_fl5_op0_end_train.json')
6 | parser.add_argument('--root_path', required=True, help="the root path of dataset", default='/media3/x00532679/data/')
7 | parser.add_argument('--batch_size', type=int, default=16, help='training batch size')
8 | parser.add_argument('--epochs', default=20, type=int, help='number of epochs')
9 | parser.add_argument('--backbone', type=str, default='resnet18')
10 | parser.add_argument('--refinenet', type=str, default='R_CLSTM_5')
11 |     parser.add_argument('--logdir', required=True, help="the directory to save tensorboard logs", default='./checkpoint')
12 | parser.add_argument('--checkpoint_dir', required=True, help="the directory to save the checkpoints", default='./log_224')
13 | parser.add_argument('--loadckpt', type=str)
14 | parser.add_argument('--overlap', type=int, default=0)
15 | parser.add_argument('--use_cuda', type=bool, default=True)
16 | parser.add_argument('--devices', type=str, default='0')
17 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
18 | parser.add_argument('--resume', action='store_true',default=False, help='continue training the model')
19 | parser.add_argument('--momentum', default=0.9, type=float, help='Momentum parameter used in the Optimizer.')
20 | parser.add_argument('--epsilon', default=0.001, type=float, help='epsilon')
21 | parser.add_argument('--optimizer_name', default="adam", type=str, help="Optimizer selection")
22 | parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
23 | parser.add_argument('--do_summary', action='store_true', default=False, help='whether do summary or not')
24 | parser.add_argument('--pretrained_dir', required=False,type=str, help="the path of pretrained models")
25 | return parser.parse_args()
26 |
--------------------------------------------------------------------------------
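Only the `required=True` arguments above have to be supplied; a typical invocation with placeholder paths might look like `python train.py --trainlist_path ./data_list/train_list.json --root_path ./data --logdir ./logs --checkpoint_dir ./checkpoints --backbone resnet18 --refinenet R_CLSTM_5`.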
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch.autograd import Variable
10 | from tensorboardX import SummaryWriter
11 | from utils import *
12 | from options import get_args
13 | from dataloader import nyudv2_dataloader
14 | from models.loss import cal_spatial_loss, cal_temporal_loss
15 | from models.backbone_dict import backbone_dict
16 | from models import modules
17 | from models import net
18 |
19 | cudnn.benchmark = True
20 | args = get_args('train')
21 |
22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
23 |
24 | # Create folder
25 | makedir(args.checkpoint_dir)
26 | makedir(args.logdir)
27 |
28 | # creat summary logger
29 | logger = SummaryWriter(args.logdir)
30 |
31 | # dataset, dataloader
32 | TrainImgLoader = nyudv2_dataloader.getTrainingData_NYUDV2(args.batch_size, args.trainlist_path, args.root_path)
33 | # model, optimizer
34 | device = 'cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu'
35 |
36 | backbone = backbone_dict[args.backbone]()
37 | Encoder = modules.E_resnet(backbone)
38 |
39 | if args.backbone in ['resnet50']:
40 | model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048], refinenet=args.refinenet)
41 | elif args.backbone in ['resnet18', 'resnet34']:
42 | model = net.model(Encoder, num_features=512, block_channel=[64, 128, 256, 512], refinenet=args.refinenet)
43 |
44 | model = nn.DataParallel(model).cuda()
45 |
46 | disc = net.C_C3D_1().cuda()
47 |
48 | optimizer = build_optimizer(model = model,
49 | learning_rate=args.lr,
50 | optimizer_name=args.optimizer_name,
51 | weight_decay = args.weight_decay,
52 | epsilon=args.epsilon,
53 | momentum=args.momentum
54 | )
55 |
56 | start_epoch = 0
57 |
58 | if args.resume:
59 | all_saved_ckpts = [ckpt for ckpt in os.listdir(args.checkpoint_dir) if ckpt.endswith(".pth.tar")]
60 | print(all_saved_ckpts)
61 | all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x:int(x.split('_')[-1].split('.')[0]))
62 | loadckpt = os.path.join(args.checkpoint_dir, all_saved_ckpts[-1])
63 | start_epoch = int(all_saved_ckpts[-1].split('_')[-1].split('.')[0])
64 |     print("loading the latest model in checkpoint_dir: {}".format(loadckpt))
65 | state_dict = torch.load(loadckpt)
66 | model.load_state_dict(state_dict)
67 | elif args.loadckpt is not None:
68 | print("loading model {}".format(args.loadckpt))
69 |     start_epoch = int(args.loadckpt.split('_')[-1].split('.')[0])
70 | state_dict = torch.load(args.loadckpt)
71 | model.load_state_dict(state_dict)
72 | else:
73 | print("start at epoch {}".format(start_epoch))
74 |
75 | def train():
76 | for epoch in range(start_epoch, args.epochs):
77 | adjust_learning_rate(optimizer, epoch, args.lr)
78 | batch_time = AverageMeter()
79 | losses = AverageMeter()
80 | model.train()
81 | end = time.time()
82 | for batch_idx, sample in enumerate(TrainImgLoader):
83 |
84 |             image, depth = sample[0], sample[1]  # (b, c, d, h, w)
85 |
86 | depth = depth.cuda()
87 | image = image.cuda()
88 | image = torch.autograd.Variable(image)
89 | depth = torch.autograd.Variable(depth)
90 | optimizer.zero_grad()
91 | global_step = len(TrainImgLoader) * epoch + batch_idx
92 | gt_depth = depth
93 | pred_depth = model(image)#(b, c, d, h, w)
94 |
95 | # Calculate the total loss
96 | spatial_losses=[]
97 | for seq_idx in range(image.size(2)):
98 | spatial_loss = cal_spatial_loss(pred_depth[:,:,seq_idx,:,:], gt_depth[:,:,seq_idx,:,:])
99 | spatial_losses.append(spatial_loss)
100 | spatial_loss = sum(spatial_losses)
101 |
102 | pred_cls = disc(pred_depth)
103 | gt_cls = disc(gt_depth)
104 | temporal_loss = cal_temporal_loss(pred_cls, gt_cls)
105 |
106 | loss = spatial_loss + 0.1 * temporal_loss
107 |
108 | losses.update(loss.item(), image.size(0))
109 | loss.backward()
110 | optimizer.step()
111 |
112 | batch_time.update(time.time() - end)
113 | end = time.time()
114 |
115 | batchSize = depth.size(0)
116 |
117 | print(('Epoch: [{0}][{1}/{2}]\t'
118 | 'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
119 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'
120 | .format(epoch, batch_idx, len(TrainImgLoader), batch_time=batch_time, loss=losses)))
121 |
122 | if (epoch+1)%1 == 0:
123 |             save_checkpoint(model.state_dict(), filename=os.path.join(args.checkpoint_dir, "ResNet18_checkpoints_small_" + str(epoch + 1) + ".pth.tar"))
124 |
125 | if __name__ == '__main__':
126 | train()
127 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import matplotlib
7 | import matplotlib.cm
8 | import torchvision.utils as vutils
9 | from models.loss import Sobel
10 | import matplotlib.pyplot as plt
11 | cmap = plt.cm.viridis
12 |
13 |
14 | def draw_losses(logger, loss, global_step):
15 | name = "train_loss"
16 | logger.add_scalar(name, loss, global_step)
17 |
18 | def draw_images(logger, all_draw_image, global_step):
19 | for image_name, images in all_draw_image.items():
20 | if images.shape[1] == 1:
21 | images = colormap(images)
22 | elif images.shape[1] == 3:
23 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
24 | 'std': [0.229, 0.224, 0.225]}
25 | for channel in np.arange(images.shape[1]):
26 | images[:, channel, :, :] = images[:, channel, :, :] * __imagenet_stats["std"][channel] + __imagenet_stats["mean"][channel]
27 |
28 | if len(images.shape) == 3:
29 | images = images[np.newaxis, :, :, :]
30 | if images.shape[0]>4:
31 | images = images[:4, :, :, :]
32 |
33 | logger.add_image(image_name, images, global_step)
34 |
35 | def save_image(img_merge, filename):
36 | img_merge = Image.fromarray(img_merge.astype('uint8'))
37 | img_merge.save(filename)
38 |
39 |
40 | def colored_depthmap(depth, d_min=None, d_max=None):
41 | if d_min is None:
42 | d_min = np.min(depth)
43 | if d_max is None:
44 | d_max = np.max(depth)
45 | depth_relative = (depth - d_min) / (d_max - d_min)
46 | return 255 * cmap(depth_relative)[:,:,:3] # H, W, C
47 |
48 |
49 | def merge_into_row(input, depth_target, depth_pred):
50 | rgb = np.transpose(input.cpu().numpy(), (1,2,0)) # H, W, C
51 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
52 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
53 |
54 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
55 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
56 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
57 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
58 |
59 | # img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
60 | # return img_merge
61 | return rgb, depth_target_col, depth_pred_col
62 |
63 |
64 | def makedir(directory):
65 | if not os.path.exists(directory):
66 | os.makedirs(directory)
67 |
68 | def adjust_learning_rate(optimizer, epoch, init_lr):
69 |
70 | lr = init_lr * (0.1 ** (epoch // 5))
71 |
72 | for param_group in optimizer.param_groups:
73 | param_group['lr'] = lr
74 |
75 | class AverageMeter(object):
76 | def __init__(self):
77 | self.reset()
78 |
79 | def reset(self):
80 | self.val = 0
81 | self.avg = 0
82 | self.sum = 0
83 | self.count = 0
84 |
85 | def update(self, val, n=1):
86 | self.val = val
87 | self.sum += val * n
88 | self.count += n
89 | self.avg = self.sum / self.count
90 |
91 | def save_checkpoint(state, filename):
92 | torch.save(state, filename)
93 |
94 |
95 | def edge_detection(depth):
96 | get_edge = Sobel().cuda()  # note: constructs the Sobel filter on every call and assumes a CUDA device
97 |
98 | edge_xy = get_edge(depth)
99 | edge_sobel = torch.pow(edge_xy[:, 0, :, :], 2) + \
100 | torch.pow(edge_xy[:, 1, :, :], 2)
101 | edge_sobel = torch.sqrt(edge_sobel)
102 |
103 | return edge_sobel
104 |
105 | def build_optimizer(model,
106 | learning_rate,
107 | optimizer_name='rmsprop',
108 | weight_decay=1e-5,
109 | epsilon=0.001,
110 | momentum=0.9):
111 | """Build optimizer"""
112 | if optimizer_name == "sgd":
113 | print("Using SGD optimizer.")
114 | optimizer = torch.optim.SGD(model.parameters(),
115 | lr = learning_rate,
116 | momentum=momentum,
117 | weight_decay=weight_decay)
118 |
119 | elif optimizer_name == 'rmsprop':
120 | print("Using RMSProp optimizer.")
121 | optimizer = torch.optim.RMSprop(model.parameters(),
122 | lr = learning_rate,
123 | eps = epsilon,
124 | weight_decay = weight_decay,
125 | momentum = momentum
126 | )
127 | elif optimizer_name == 'adam':
128 | print("Using Adam optimizer.")
129 | optimizer = torch.optim.Adam(model.parameters(),
130 | lr = learning_rate, weight_decay=weight_decay)
131 | return optimizer  # note: 'optimizer' is never assigned if optimizer_name is not 'sgd', 'rmsprop' or 'adam'
132 |
133 |
134 |
135 |
136 | # Original script: https://github.com/fangchangma/sparse-to-dense/blob/master/utils.lua
137 |
138 |
139 | def lg10(x):
140 | return torch.div(torch.log(x), math.log(10))
141 |
142 | def maxOfTwo(x, y):
143 | z = x.clone()
144 | maskYLarger = torch.lt(x, y)
145 | z[maskYLarger.detach()] = y[maskYLarger.detach()]
146 | return z
147 |
148 | def nValid(x):
149 | return torch.sum(torch.eq(x, x).float())
150 |
151 | def nNanElement(x):
152 | return torch.sum(torch.ne(x, x).float())
153 |
154 | def getNanMask(x):
155 | return torch.ne(x, x)
156 |
157 | def setNanToZero(input, target):
158 | nanMask = getNanMask(target)
159 | nValidElement = nValid(target)
160 |
161 | _input = input.clone()
162 | _target = target.clone()
163 |
164 | _input[nanMask] = 0
165 | _target[nanMask] = 0
166 |
167 | return _input, _target, nanMask, nValidElement
168 |
169 |
170 | def evaluateError(output, target):
171 | errors = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
172 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0}
173 |
174 | _output, _target, nanMask, nValidElement = setNanToZero(output, target)
175 |
176 | if (nValidElement.data.cpu().numpy() > 0):
177 | diffMatrix = torch.abs(_output - _target)
178 |
179 | errors['MSE'] = torch.sum(torch.pow(diffMatrix, 2)) / nValidElement
180 |
181 | errors['MAE'] = torch.sum(diffMatrix) / nValidElement
182 |
183 | realMatrix = torch.div(diffMatrix, _target)
184 | realMatrix[nanMask] = 0
185 | errors['ABS_REL'] = torch.sum(realMatrix) / nValidElement
186 |
187 | LG10Matrix = torch.abs(lg10(_output) - lg10(_target))
188 | LG10Matrix[nanMask] = 0
189 | errors['LG10'] = torch.sum(LG10Matrix) / nValidElement
190 | yOverZ = torch.div(_output, _target)
191 | zOverY = torch.div(_target, _output)
192 |
193 | maxRatio = maxOfTwo(yOverZ, zOverY)
194 |
195 | errors['DELTA1'] = torch.sum(
196 | torch.le(maxRatio, 1.25).float()) / nValidElement
197 | errors['DELTA2'] = torch.sum(
198 | torch.le(maxRatio, math.pow(1.25, 2)).float()) / nValidElement
199 | errors['DELTA3'] = torch.sum(
200 | torch.le(maxRatio, math.pow(1.25, 3)).float()) / nValidElement
201 |
202 | errors['MSE'] = float(errors['MSE'].data.cpu().numpy())
203 | errors['ABS_REL'] = float(errors['ABS_REL'].data.cpu().numpy())
204 | errors['LG10'] = float(errors['LG10'].data.cpu().numpy())
205 | errors['MAE'] = float(errors['MAE'].data.cpu().numpy())
206 | errors['DELTA1'] = float(errors['DELTA1'].data.cpu().numpy())
207 | errors['DELTA2'] = float(errors['DELTA2'].data.cpu().numpy())
208 | errors['DELTA3'] = float(errors['DELTA3'].data.cpu().numpy())
209 |
210 | return errors
211 |
212 |
213 | def addErrors(errorSum, errors, batchSize):
214 | errorSum['MSE']=errorSum['MSE'] + errors['MSE'] * batchSize
215 | errorSum['ABS_REL']=errorSum['ABS_REL'] + errors['ABS_REL'] * batchSize
216 | errorSum['LG10']=errorSum['LG10'] + errors['LG10'] * batchSize
217 | errorSum['MAE']=errorSum['MAE'] + errors['MAE'] * batchSize
218 |
219 | errorSum['DELTA1']=errorSum['DELTA1'] + errors['DELTA1'] * batchSize
220 | errorSum['DELTA2']=errorSum['DELTA2'] + errors['DELTA2'] * batchSize
221 | errorSum['DELTA3']=errorSum['DELTA3'] + errors['DELTA3'] * batchSize
222 |
223 | return errorSum
224 |
225 |
226 | def averageErrors(errorSum, N):
227 | averageError={'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
228 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0}
229 |
230 | averageError['MSE'] = errorSum['MSE'] / N
231 | averageError['ABS_REL'] = errorSum['ABS_REL'] / N
232 | averageError['LG10'] = errorSum['LG10'] / N
233 | averageError['MAE'] = errorSum['MAE'] / N
234 |
235 | averageError['DELTA1'] = errorSum['DELTA1'] / N
236 | averageError['DELTA2'] = errorSum['DELTA2'] / N
237 | averageError['DELTA3'] = errorSum['DELTA3'] / N
238 |
239 | return averageError
240 |
241 |
242 | def colormap(image, cmap="jet"):
243 | image_min = torch.min(image)
244 | image_max = torch.max(image)
245 | image = (image - image_min) / (image_max - image_min)
246 | image = torch.squeeze(image)
247 |
248 | # quantize
249 | indices = torch.round(image * 255).long()
250 | # gather
251 | cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
252 |
253 | colors = cm(np.arange(256))[:, :3]
254 | colors = torch.cuda.FloatTensor(colors)  # assumes CUDA: both the colour table and the input image live on the GPU
255 | color_map = colors[indices].transpose(2, 3).transpose(1, 2)
256 |
257 | return color_map
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
--------------------------------------------------------------------------------
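The error helpers at the end of utils.py are meant to be chained per batch: evaluateError() computes the standard depth metrics for one prediction/target pair, addErrors() accumulates them weighted by batch size, and averageErrors() divides by the total sample count. A minimal sketch of that pipeline on synthetic tensors, assuming it is run from the repository root so utils can be imported; the shapes, the random data, and the RMSE derivation at the end are illustrative and not taken from evaluate.py:

    import torch
    from utils import evaluateError, addErrors, averageErrors

    errorSum = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
                'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0}
    total_samples = 0

    for _ in range(4):                                # stand-in for a validation loader
        pred = torch.rand(2, 1, 114, 152) * 10 + 0.1  # fake positive depths in metres
        gt = torch.rand(2, 1, 114, 152) * 10 + 0.1    # fake ground truth
        batch_size = gt.size(0)

        errors = evaluateError(pred, gt)              # per-batch metrics
        errorSum = addErrors(errorSum, errors, batch_size)
        total_samples += batch_size

    avg = averageErrors(errorSum, total_samples)
    avg['RMSE'] = avg['MSE'] ** 0.5                   # 'RMSE' is never filled in by evaluateError()
    print(avg)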
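colored_depthmap() and save_image() give a quick way to dump a colour-coded depth map to disk without going through TensorBoard, while merge_into_row() produces the RGB/target/prediction triplet for side-by-side comparisons. A short, self-contained usage sketch on a random array (the array and file name are made up):

    import numpy as np
    from utils import colored_depthmap, save_image

    depth = np.random.uniform(0.5, 10.0, size=(114, 152))  # fake depth map in metres
    vis = colored_depthmap(depth)                           # H x W x 3, values in [0, 255]
    save_image(vis, "depth_vis.png")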