├── LICENSE ├── README.md ├── augs.py ├── checkpoints └── .gitignore ├── configs ├── GN.yaml └── GNS.yaml ├── criteria.py ├── datas └── .gitignore ├── datasets.py ├── exts ├── guideconv.cpp ├── guideconv_kernel.cu └── setup.py ├── models.py ├── optimizers.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jie Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Learning Guided Convolutional Network for Depth Completion](https://arxiv.org/pdf/1908.01238). 2 | 3 | 4 | ## Introduction 5 | 6 | This is the pytorch implementation of our paper. 7 | 8 | ## Dependency 9 | ``` 10 | PyTorch 1.4 11 | PyTorch-Encoding v1.4.0 12 | ``` 13 | 14 | ## Setup 15 | Compile the C++ and CUDA code: 16 | ``` 17 | cd exts 18 | python setup.py install 19 | ``` 20 | 21 | ## Dataset 22 | Please download KITTI [depth completion](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) 23 | dataset. 24 | The structure of data directory: 25 | ``` 26 | └── datas 27 | └── kitti 28 | ├── data_depth_annotated 29 | │   ├── train 30 | │   └── val 31 | ├── data_depth_velodyne 32 | │   ├── train 33 | │   └── val 34 | ├── raw 35 | │   ├── 2011_09_26 36 | │   ├── 2011_09_28 37 | │   ├── 2011_09_29 38 | │   ├── 2011_09_30 39 | │   └── 2011_10_03 40 | ├── test_depth_completion_anonymous 41 | │   ├── image 42 | │   ├── intrinsics 43 | │   └── velodyne_raw 44 | └── val_selection_cropped 45 | ├── groundtruth_depth 46 | ├── image 47 | ├── intrinsics 48 | └── velodyne_raw 49 | ``` 50 | 51 | ## Configs 52 | The config of different settings: 53 | - GN.yaml 54 | - GNS.yaml 55 | 56 | *Compared to **GN**, **GNS** uses fewer parameters to generate the guided kernels, 57 | but achieves slightly better results.* 58 | 59 | 60 | ## Trained Models 61 | You can directly download the trained model and put it in *checkpoints*: 62 | - [GN](https://drive.google.com/file/d/1-sa2pnMMjSv2dV2bRwuyLxPr1onmVykj/view?usp=sharing) 63 | - [GNS](https://drive.google.com/file/d/16tVrZQEDBucgjZmTjZl4iFkklkjfeDcs/view?usp=sharing) 64 | 65 | ## Train 66 | You can also train by yourself: 67 | ``` 68 | python train.py 69 | ``` 70 | *Pay attention to the settings in the config file (e.g. 
gpu id).* 71 | 72 | ## Test 73 | With the trained model, 74 | you can test and save depth images. 75 | ``` 76 | python test.py 77 | ``` 78 | 79 | ## Citation 80 | If you find this work useful in your research, please consider citing: 81 | ``` 82 | @article{guidenet, 83 | title={Learning guided convolutional network for depth completion}, 84 | author={Tang, Jie and Tian, Fei-Peng and Feng, Wei and Li, Jian and Tan, Ping}, 85 | journal={IEEE Transactions on Image Processing}, 86 | volume={30}, 87 | pages={1116--1129}, 88 | year={2020}, 89 | publisher={IEEE} 90 | } 91 | ``` -------------------------------------------------------------------------------- /augs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: augs.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/14 8:27 PM 7 | 8 | import numpy as np 9 | 10 | __all__ = [ 11 | 'Compose', 12 | 'Norm', 13 | 'Jitter', 14 | 'Flip', 15 | ] 16 | 17 | 18 | class Compose(object): 19 | """ 20 | Sequential operations on input images, (i.e. rgb, lidar and depth). 21 | """ 22 | 23 | def __init__(self, transforms): 24 | self.transforms = transforms 25 | 26 | def __call__(self, rgb, lidar, depth): 27 | for t in self.transforms: 28 | rgb, lidar, depth = t(rgb, lidar, depth) 29 | return rgb, lidar, depth 30 | 31 | 32 | class Norm(object): 33 | """ 34 | normalize rgb image. 35 | """ 36 | 37 | def __init__(self, mean, std): 38 | self.mean = np.array(mean) 39 | self.std = np.array(std) 40 | 41 | def __call__(self, rgb, lidar, depth): 42 | rgb = (rgb - self.mean) / self.std 43 | return rgb, lidar, depth 44 | 45 | 46 | class Jitter(object): 47 | """ 48 | borrow from https://github.com/kujason/avod/blob/master/avod/datasets/kitti/kitti_aug.py 49 | """ 50 | 51 | def __call__(self, rgb, lidar, depth): 52 | pca = compute_pca(rgb) 53 | rgb = add_pca_jitter(rgb, pca) 54 | return rgb, lidar, depth 55 | 56 | 57 | class Flip(object): 58 | """ 59 | random horizontal flip of images. 
60 | """ 61 | 62 | def __call__(self, rgb, lidar, depth): 63 | flip = bool(np.random.randint(2)) 64 | if flip: 65 | rgb = rgb[:, ::-1, :] 66 | lidar = lidar[:, ::-1, :] 67 | depth = depth[:, ::-1, :] 68 | return rgb, lidar, depth 69 | 70 | 71 | def compute_pca(image): 72 | """ 73 | calculate PCA of image 74 | """ 75 | 76 | reshaped_data = image.reshape(-1, 3) 77 | reshaped_data = (reshaped_data / 255.0).astype(np.float32) 78 | covariance = np.cov(reshaped_data.T) 79 | e_vals, e_vecs = np.linalg.eigh(covariance) 80 | pca = np.sqrt(e_vals) * e_vecs 81 | return pca 82 | 83 | 84 | def add_pca_jitter(img_data, pca): 85 | """ 86 | add a multiple of principle components with Gaussian noise 87 | """ 88 | new_img_data = np.copy(img_data).astype(np.float32) / 255.0 89 | magnitude = np.random.randn(3) * 0.1 90 | noise = (pca * magnitude).sum(axis=1) 91 | 92 | new_img_data = new_img_data + noise 93 | np.clip(new_img_data, 0.0, 1.0, out=new_img_data) 94 | new_img_data = (new_img_data * 255).astype(np.uint8) 95 | 96 | return new_img_data 97 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /configs/GN.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | data_config: 3 | kitti: 4 | path: datas/kitti 5 | gpu_ids: 6 | - 4 7 | - 5 8 | loss: MSE 9 | lr_config: 10 | MultiStepLR: 11 | gamma: 0.5 12 | last_epoch: -1 13 | milestones: 14 | - 5 15 | - 10 16 | - 15 17 | manual_seed: 0 18 | metric: RMSE 19 | model: GN 20 | name: GN 21 | nepoch: 20 22 | num_workers: 4 23 | optim_config: 24 | AdamW: 25 | lr: 0.001 26 | weight_decay: 0.05 27 | resume_seed: 6288 28 | start_epoch: 0 29 | test_aug_configs: 30 | - Norm: 31 | mean: 32 | - 90.995 33 | - 96.2278 34 | - 94.3213 35 | std: 36 | - 79.2382 37 | - 80.5267 38 | - 82.1483 39 | test_epoch: 15 40 | test_iters: 500 41 | train_aug_configs: 42 | - Jitter 43 | - Flip 44 | - Norm: 45 | mean: 46 | - 90.995 47 | - 96.2278 48 | - 94.3213 49 | std: 50 | - 79.2382 51 | - 80.5267 52 | - 82.1483 53 | tta: true 54 | vis: true 55 | vis_iters: 100 56 | -------------------------------------------------------------------------------- /configs/GNS.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | data_config: 3 | kitti: 4 | path: datas/kitti 5 | gpu_ids: 6 | - 6 7 | - 7 8 | loss: MSE 9 | lr_config: 10 | MultiStepLR: 11 | gamma: 0.5 12 | last_epoch: -1 13 | milestones: 14 | - 5 15 | - 10 16 | - 15 17 | manual_seed: 0 18 | metric: RMSE 19 | model: GNS 20 | name: GNS 21 | nepoch: 20 22 | num_workers: 4 23 | optim_config: 24 | AdamW: 25 | lr: 0.001 26 | weight_decay: 0.05 27 | resume_seed: 1600 28 | start_epoch: 0 29 | test_aug_configs: 30 | - Norm: 31 | mean: 32 | - 90.995 33 | - 96.2278 34 | - 94.3213 35 | std: 36 | - 79.2382 37 | - 80.5267 38 | - 82.1483 39 | test_epoch: 15 40 | test_iters: 500 41 | train_aug_configs: 42 | - Jitter 43 | - Flip 44 | - Norm: 45 | mean: 46 | - 90.995 47 | - 96.2278 48 | - 94.3213 49 | std: 50 | - 79.2382 51 | - 80.5267 52 | - 82.1483 53 | tta: true 54 | vis: true 55 | vis_iters: 100 56 | -------------------------------------------------------------------------------- /criteria.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: criteria.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/14 7:51 PM 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | __all__ = [ 12 | 'RMSE', 13 | 'MSE', 14 | ] 15 | 16 | 17 | class RMSE(nn.Module): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def forward(self, outputs, target, *args): 23 | val_pixels = (target > 1e-3).float().cuda() 24 | err = (target * val_pixels - outputs * val_pixels) ** 2 25 | loss = torch.sum(err.view(err.size(0), 1, -1), -1, keepdim=True) 26 | cnt = torch.sum(val_pixels.view(val_pixels.size(0), 1, -1), -1, keepdim=True) 27 | return torch.sqrt(loss / cnt) 28 | 29 | 30 | class MSE(nn.Module): 31 | 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, outputs, target, *args): 36 | val_pixels = (target > 1e-3).float().cuda() 37 | loss = target * val_pixels - outputs * val_pixels 38 | return loss ** 2 39 | -------------------------------------------------------------------------------- /datas/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: datasets.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/14 8:08 PM 7 | 8 | import os 9 | import numpy as np 10 | import glob 11 | from PIL import Image 12 | import torch.utils.data as data 13 | 14 | __all__ = [ 15 | 'kitti', 16 | ] 17 | 18 | 19 | class kitti(data.Dataset): 20 | """ 21 | kitti depth completion dataset: http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion 22 | """ 23 | 24 | def __init__(self, path='../datas/kitti', mode='train', height=256, width=1216, return_idx=False, return_size=False, 25 | transform=None): 26 | self.base_dir = path 27 | self.height = height 28 | self.width = width 29 | self.mode = mode 30 | self.return_idx = return_idx 31 | self.return_size = return_size 32 | self.transform = transform 33 | if mode in ['train', 'val']: 34 | self.depth_path = os.path.join(self.base_dir, 'data_depth_annotated', mode) 35 | self.lidar_path = os.path.join(self.base_dir, 'data_depth_velodyne', mode) 36 | self.depths = list(sorted(glob.iglob(self.depth_path + "/**/*.png", recursive=True))) 37 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/**/*.png", recursive=True))) 38 | elif mode == 'selval': 39 | self.depth_path = os.path.join(self.base_dir, 'val_selection_cropped', 'groundtruth_depth') 40 | self.lidar_path = os.path.join(self.base_dir, 'val_selection_cropped', 'velodyne_raw') 41 | self.image_path = os.path.join(self.base_dir, 'val_selection_cropped', 'image') 42 | self.depths = list(sorted(glob.iglob(self.depth_path + "/*.png", recursive=True))) 43 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True))) 44 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True))) 45 | elif mode == 'test': 46 | self.lidar_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'velodyne_raw') 47 | self.image_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'image') 48 | self.lidars = 
list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True))) 49 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True))) 50 | self.depths = self.lidars 51 | else: 52 | raise ValueError("Unknown mode: {}".format(mode)) 53 | assert (len(self.depths) == len(self.lidars)) 54 | self.names = [os.path.split(path)[-1] for path in self.depths] 55 | 56 | def __len__(self): 57 | return len(self.depths) 58 | 59 | def __getitem__(self, index): 60 | 61 | depth = self.pull_DEPTH(self.depths[index]) 62 | depth = np.expand_dims(depth, axis=2) 63 | lidar = self.pull_DEPTH(self.lidars[index]) 64 | lidar = np.expand_dims(lidar, axis=2) 65 | file_names = self.depths[index].split('/') 66 | if self.mode in ['train', 'val']: 67 | rgb_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5], 68 | file_names[-2], 'data', file_names[-1]) 69 | elif self.mode in ['selval', 'test']: 70 | rgb_path = self.images[index] 71 | else: 72 | ValueError("Unknown mode: {}".format(self.mode)) 73 | rgb = self.pull_RGB(rgb_path) 74 | rgb = rgb.astype(np.float32) 75 | lidar = lidar.astype(np.float32) 76 | depth = depth.astype(np.float32) 77 | shape = lidar.shape 78 | if self.transform: 79 | rgb, lidar, depth = self.transform(rgb, lidar, depth) 80 | rgb = rgb.transpose(2, 0, 1).astype(np.float32) 81 | lidar = lidar.transpose(2, 0, 1).astype(np.float32) 82 | depth = depth.transpose(2, 0, 1).astype(np.float32) 83 | lp = (rgb.shape[2] - self.width) // 2 84 | rgb = rgb[:, -self.height:, lp:lp + self.width] 85 | lidar = lidar[:, -self.height:, lp:lp + self.width] 86 | depth = depth[:, -self.height:, lp:lp + self.width] 87 | output = [rgb, lidar, depth] 88 | if self.return_idx: 89 | output.append(np.array([index], dtype=int)) 90 | if self.return_size: 91 | output.append(np.array(shape[:2], dtype=int)) 92 | return output 93 | 94 | def pull_RGB(self, path): 95 | img = np.array(Image.open(path).convert('RGB'), dtype=np.uint8) 96 | return img 97 | 98 | def pull_DEPTH(self, path): 99 | depth_png = np.array(Image.open(path), dtype=int) 100 | assert (np.max(depth_png) > 255) 101 | depth_image = (depth_png / 256.).astype(np.float32) 102 | return depth_image 103 | -------------------------------------------------------------------------------- /exts/guideconv.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 09/02/19. 
3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 11 | size_t K); 12 | 13 | void 14 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci, 15 | size_t Co, size_t B, size_t K); 16 | 17 | 18 | at::Tensor Conv2dLocal_F( 19 | at::Tensor a, // BCHW 20 | at::Tensor b // BCKKHW 21 | ) { 22 | int N1, N2, Ci, Co, K, B; 23 | B = a.size(0); 24 | Ci = a.size(1); 25 | N1 = a.size(2); 26 | N2 = a.size(3); 27 | Co = Ci; 28 | K = sqrt(b.size(1) / Co); 29 | auto c = at::zeros_like(a); 30 | Conv2d_LF_Cuda(a, b, c, N1, N2, Ci, Co, B, K); 31 | return c; 32 | } 33 | 34 | 35 | std::tuple Conv2dLocal_B( 36 | at::Tensor a, 37 | at::Tensor b, 38 | at::Tensor gc 39 | ) { 40 | int N1, N2, Ci, Co, K, B; 41 | B = a.size(0); 42 | Ci = a.size(1); 43 | N1 = a.size(2); 44 | N2 = a.size(3); 45 | Co = Ci; 46 | K = sqrt(b.size(1) / Co); 47 | auto ga = at::zeros_like(a); 48 | auto gb = at::zeros_like(b); 49 | Conv2d_LB_Cuda(a, b, ga, gb, gc, N1, N2, Ci, Co, B, K); 50 | return std::make_tuple(ga, gb); 51 | } 52 | 53 | 54 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m 55 | ) { 56 | m.def("Conv2dLocal_F", &Conv2dLocal_F, "Conv2dLocal Forward (CUDA)"); 57 | m.def("Conv2dLocal_B", &Conv2dLocal_B, "Conv2dLocal Backward (CUDA)"); 58 | } -------------------------------------------------------------------------------- /exts/guideconv_kernel.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 09/02/19. 3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace { 10 | 11 | template 12 | __global__ void 13 | conv2d_kernel_lf(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ z, size_t N1, 14 | size_t N2, size_t Ci, size_t Co, size_t B, 15 | size_t K) { 16 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 17 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 18 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 19 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 20 | for (int b = 0; b < B; b++) { 21 | scalar_t result = 0; 22 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 23 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 24 | 25 | if ((row_index + i < 0) || (row_index + i >= N1) || (col_index + j < 0) || 26 | (col_index + j >= N2)) { 27 | continue; 28 | } 29 | 30 | result += x[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index + i) * N2 + col_index + j] * 31 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 32 | (i + (K - 1) / 2) * K * N1 * N2 + 33 | (j + (K - 1) / 2) * N1 * N2 + row_index * N2 + col_index]; 34 | } 35 | } 36 | z[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 37 | } 38 | } 39 | } 40 | 41 | 42 | template 43 | __global__ void conv2d_kernel_lb(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ gx, 44 | scalar_t *__restrict__ gy, scalar_t *__restrict__ gz, size_t N1, size_t N2, 45 | size_t Ci, size_t Co, size_t B, 46 | size_t K) { 47 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 48 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 49 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 50 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 51 | for (int b = 0; b < B; b++) { 52 | scalar_t result = 0; 53 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 54 | for (int j = 
-int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 55 | 56 | if ((row_index - i < 0) || (row_index - i >= N1) || (col_index - j < 0) || 57 | (col_index - j >= N2)) { 58 | continue; 59 | } 60 | result += gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 61 | ] * 62 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 63 | (i + (K - 1) / 2) * K * N1 * N2 + 64 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j]; 65 | gy[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + (i + (K - 1) / 2) * K * N1 * N2 + 66 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j] = 67 | gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 68 | ] * x[b * N1 * N2 * Ci + cha_index * N1 * N2 + row_index * N2 + col_index]; 69 | 70 | } 71 | } 72 | gx[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 73 | } 74 | } 75 | } 76 | } 77 | 78 | 79 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 80 | size_t K) { 81 | dim3 blockSize(32, 32, 1); 82 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 83 | (Co + blockSize.z - 1) / blockSize.z); 84 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LF", ([&] { 85 | conv2d_kernel_lf << < gridSize, blockSize >> > ( 86 | x.data(), y.data(), z.data(), 87 | N1, N2, Ci, Co, B, K); 88 | })); 89 | } 90 | 91 | 92 | void 93 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci, 94 | size_t Co, size_t B, size_t K) { 95 | dim3 blockSize(32, 32, 1); 96 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 97 | (Co + blockSize.z - 1) / blockSize.z); 98 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LB", ([&] { 99 | conv2d_kernel_lb << < gridSize, blockSize >> > ( 100 | x.data(), y.data(), 101 | gx.data(), gy.data(), gz.data(), 102 | N1, N2, Ci, Co, B, K); 103 | })); 104 | } 105 | -------------------------------------------------------------------------------- /exts/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='GuideConv', 6 | ext_modules=[ 7 | CUDAExtension('GuideConv', [ 8 | 'guideconv.cpp', 9 | 'guideconv_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: model.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/14 7:50 PM 7 | 8 | import torch 9 | import torch.nn as nn 10 | from scipy.stats import truncnorm 11 | import math 12 | from torch.autograd import Function 13 | import encoding 14 | import GuideConv 15 | 16 | __all__ = [ 17 | 'GN', 18 | 'GNS', 19 | ] 20 | 21 | 22 | def Conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | def Conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 
| 33 | class Conv2dLocal_F(Function): 34 | @staticmethod 35 | def forward(ctx, input, weight): 36 | ctx.save_for_backward(input, weight) 37 | output = GuideConv.Conv2dLocal_F(input, weight) 38 | return output 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | input, weight = ctx.saved_tensors 43 | grad_output = grad_output.contiguous() 44 | grad_input, grad_weight = GuideConv.Conv2dLocal_B(input, weight, grad_output) 45 | return grad_input, grad_weight 46 | 47 | 48 | class Conv2dLocal(nn.Module): 49 | def __init__(self, ): 50 | super().__init__() 51 | 52 | def forward(self, input, weight): 53 | output = Conv2dLocal_F.apply(input, weight) 54 | return output 55 | 56 | 57 | class Basic2d(nn.Module): 58 | def __init__(self, in_channels, out_channels, norm_layer=None, kernel_size=3, padding=1): 59 | super().__init__() 60 | if norm_layer: 61 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 62 | stride=1, padding=padding, bias=False) 63 | else: 64 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 65 | stride=1, padding=padding, bias=True) 66 | self.conv = nn.Sequential(conv, ) 67 | if norm_layer: 68 | self.conv.add_module('bn', norm_layer(out_channels)) 69 | self.conv.add_module('relu', nn.ReLU(inplace=True)) 70 | 71 | def forward(self, x): 72 | out = self.conv(x) 73 | return out 74 | 75 | 76 | class Basic2dTrans(nn.Module): 77 | def __init__(self, in_channels, out_channels, norm_layer=None): 78 | super().__init__() 79 | if norm_layer is None: 80 | norm_layer = nn.BatchNorm2d 81 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, 82 | stride=2, padding=1, output_padding=1, bias=False) 83 | self.bn = norm_layer(out_channels) 84 | self.relu = nn.ReLU(inplace=True) 85 | 86 | def forward(self, x): 87 | out = self.conv(x) 88 | out = self.bn(out) 89 | out = self.relu(out) 90 | return out 91 | 92 | 93 | class Basic2dLocal(nn.Module): 94 | def __init__(self, out_channels, norm_layer=None): 95 | super().__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | 99 | self.conv = Conv2dLocal() 100 | self.bn = norm_layer(out_channels) 101 | self.relu = nn.ReLU(inplace=True) 102 | 103 | def forward(self, input, weight): 104 | out = self.conv(input, weight) 105 | out = self.bn(out) 106 | out = self.relu(out) 107 | return out 108 | 109 | 110 | class Guide(nn.Module): 111 | 112 | def __init__(self, input_planes, weight_planes, norm_layer=None, weight_ks=3): 113 | super().__init__() 114 | if norm_layer is None: 115 | norm_layer = nn.BatchNorm2d 116 | self.local = Basic2dLocal(input_planes, norm_layer) 117 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 118 | self.conv11 = Basic2d(input_planes + weight_planes, input_planes, None) 119 | self.conv12 = nn.Conv2d(input_planes, input_planes * 9, kernel_size=weight_ks, padding=weight_ks // 2) 120 | self.conv21 = Basic2d(input_planes + weight_planes, input_planes, None) 121 | self.conv22 = nn.Conv2d(input_planes, input_planes * input_planes, kernel_size=1, padding=0) 122 | self.br = nn.Sequential( 123 | norm_layer(num_features=input_planes), 124 | nn.ReLU(inplace=True), 125 | ) 126 | self.conv3 = Basic2d(input_planes, input_planes, norm_layer) 127 | 128 | def forward(self, input, weight): 129 | B, Ci, H, W = input.shape 130 | weight = torch.cat([input, weight], 1) 131 | weight11 = self.conv11(weight) 132 | weight12 = self.conv12(weight11) 133 | weight21 = self.conv21(weight) 134 | weight21 = 
self.pool(weight21) 135 | weight22 = self.conv22(weight21).view(B, -1, Ci) 136 | out = self.local(input, weight12).view(B, Ci, -1) 137 | out = torch.bmm(weight22, out).view(B, Ci, H, W) 138 | out = self.br(out) 139 | out = self.conv3(out) 140 | return out 141 | 142 | 143 | class BasicBlock(nn.Module): 144 | expansion = 1 145 | __constants__ = ['downsample'] 146 | 147 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, act=True): 148 | super().__init__() 149 | if norm_layer is None: 150 | norm_layer = nn.BatchNorm2d 151 | self.conv1 = Conv3x3(inplanes, planes, stride) 152 | self.bn1 = norm_layer(planes) 153 | self.relu = nn.ReLU(inplace=True) 154 | self.conv2 = Conv3x3(planes, planes) 155 | self.bn2 = norm_layer(planes) 156 | self.downsample = downsample 157 | self.stride = stride 158 | self.act = act 159 | 160 | def forward(self, x): 161 | identity = x 162 | out = self.conv1(x) 163 | out = self.bn1(out) 164 | out = self.relu(out) 165 | out = self.conv2(out) 166 | out = self.bn2(out) 167 | if self.downsample is not None: 168 | identity = self.downsample(x) 169 | out += identity 170 | if self.act: 171 | out = self.relu(out) 172 | return out 173 | 174 | 175 | class GuideNet(nn.Module): 176 | """ 177 | Not activate at the ref 178 | Init change to trunctated norm 179 | """ 180 | 181 | def __init__(self, block=BasicBlock, bc=16, img_layers=[2, 2, 2, 2, 2], 182 | depth_layers=[2, 2, 2, 2, 2], norm_layer=nn.BatchNorm2d, guide=Guide, weight_ks=3): 183 | super().__init__() 184 | self._norm_layer = norm_layer 185 | 186 | self.conv_img = Basic2d(3, bc * 2, norm_layer=norm_layer, kernel_size=5, padding=2) 187 | in_channels = bc * 2 188 | self.inplanes = in_channels 189 | self.layer1_img = self._make_layer(block, in_channels * 2, img_layers[0], stride=2) 190 | 191 | self.guide1 = guide(in_channels * 2, in_channels * 2, norm_layer, weight_ks) 192 | self.inplanes = in_channels * 2 * block.expansion 193 | self.layer2_img = self._make_layer(block, in_channels * 4, img_layers[1], stride=2) 194 | 195 | self.guide2 = guide(in_channels * 4, in_channels * 4, norm_layer, weight_ks) 196 | self.inplanes = in_channels * 4 * block.expansion 197 | self.layer3_img = self._make_layer(block, in_channels * 8, img_layers[2], stride=2) 198 | 199 | self.guide3 = guide(in_channels * 8, in_channels * 8, norm_layer, weight_ks) 200 | self.inplanes = in_channels * 8 * block.expansion 201 | self.layer4_img = self._make_layer(block, in_channels * 8, img_layers[3], stride=2) 202 | 203 | self.guide4 = guide(in_channels * 8, in_channels * 8, norm_layer, weight_ks) 204 | self.inplanes = in_channels * 8 * block.expansion 205 | self.layer5_img = self._make_layer(block, in_channels * 8, img_layers[4], stride=2) 206 | 207 | self.layer2d_img = Basic2dTrans(in_channels * 4, in_channels * 2, norm_layer) 208 | self.layer3d_img = Basic2dTrans(in_channels * 8, in_channels * 4, norm_layer) 209 | self.layer4d_img = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer) 210 | self.layer5d_img = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer) 211 | 212 | self.conv_lidar = Basic2d(1, bc * 2, norm_layer=None, kernel_size=5, padding=2) 213 | 214 | self.inplanes = in_channels 215 | self.layer1_lidar = self._make_layer(block, in_channels * 2, depth_layers[0], stride=2) 216 | self.inplanes = in_channels * 2 * block.expansion 217 | self.layer2_lidar = self._make_layer(block, in_channels * 4, depth_layers[1], stride=2) 218 | self.inplanes = in_channels * 4 * block.expansion 219 | self.layer3_lidar = 
self._make_layer(block, in_channels * 8, depth_layers[2], stride=2) 220 | self.inplanes = in_channels * 8 * block.expansion 221 | self.layer4_lidar = self._make_layer(block, in_channels * 8, depth_layers[3], stride=2) 222 | self.inplanes = in_channels * 8 * block.expansion 223 | self.layer5_lidar = self._make_layer(block, in_channels * 8, depth_layers[4], stride=2) 224 | 225 | self.layer1d = Basic2dTrans(in_channels * 2, in_channels, norm_layer) 226 | self.layer2d = Basic2dTrans(in_channels * 4, in_channels * 2, norm_layer) 227 | self.layer3d = Basic2dTrans(in_channels * 8, in_channels * 4, norm_layer) 228 | self.layer4d = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer) 229 | self.layer5d = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer) 230 | 231 | self.conv = nn.Conv2d(bc * 2, 1, kernel_size=3, stride=1, padding=1) 232 | self.ref = block(bc * 2, bc * 2, norm_layer=norm_layer, act=False) 233 | 234 | self._initialize_weights() 235 | 236 | def forward(self, img, lidar): 237 | c0_img = self.conv_img(img) 238 | c1_img = self.layer1_img(c0_img) 239 | c2_img = self.layer2_img(c1_img) 240 | c3_img = self.layer3_img(c2_img) 241 | c4_img = self.layer4_img(c3_img) 242 | c5_img = self.layer5_img(c4_img) 243 | dc5_img = self.layer5d_img(c5_img) 244 | c4_mix = dc5_img + c4_img 245 | dc4_img = self.layer4d_img(c4_mix) 246 | c3_mix = dc4_img + c3_img 247 | dc3_img = self.layer3d_img(c3_mix) 248 | c2_mix = dc3_img + c2_img 249 | dc2_img = self.layer2d_img(c2_mix) 250 | c1_mix = dc2_img + c1_img 251 | 252 | c0_lidar = self.conv_lidar(lidar) 253 | c1_lidar = self.layer1_lidar(c0_lidar) 254 | c1_lidar_dyn = self.guide1(c1_lidar, c1_mix) 255 | c2_lidar = self.layer2_lidar(c1_lidar_dyn) 256 | c2_lidar_dyn = self.guide2(c2_lidar, c2_mix) 257 | c3_lidar = self.layer3_lidar(c2_lidar_dyn) 258 | c3_lidar_dyn = self.guide3(c3_lidar, c3_mix) 259 | c4_lidar = self.layer4_lidar(c3_lidar_dyn) 260 | c4_lidar_dyn = self.guide4(c4_lidar, c4_mix) 261 | c5_lidar = self.layer5_lidar(c4_lidar_dyn) 262 | c5 = c5_img + c5_lidar 263 | dc5 = self.layer5d(c5) 264 | c4 = dc5 + c4_lidar_dyn 265 | dc4 = self.layer4d(c4) 266 | c3 = dc4 + c3_lidar_dyn 267 | dc3 = self.layer3d(c3) 268 | c2 = dc3 + c2_lidar_dyn 269 | dc2 = self.layer2d(c2) 270 | c1 = dc2 + c1_lidar_dyn 271 | dc1 = self.layer1d(c1) 272 | c0 = dc1 + c0_lidar 273 | output = self.ref(c0) 274 | output = self.conv(output) 275 | return (output,) 276 | 277 | def _make_layer(self, block, planes, blocks, stride=1): 278 | norm_layer = self._norm_layer 279 | downsample = None 280 | if stride != 1 or self.inplanes != planes * block.expansion: 281 | downsample = nn.Sequential( 282 | Conv1x1(self.inplanes, planes * block.expansion, stride), 283 | norm_layer(planes * block.expansion), 284 | ) 285 | 286 | layers = [] 287 | layers.append(block(self.inplanes, planes, stride, downsample, norm_layer)) 288 | self.inplanes = planes * block.expansion 289 | for _ in range(1, blocks): 290 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) 291 | 292 | return nn.Sequential(*layers) 293 | 294 | def _initialize_weights(self): 295 | def truncated_normal_(num, mean=0., std=1.): 296 | lower = -2 * std 297 | upper = 2 * std 298 | X = truncnorm((lower - mean) / std, (upper - mean) / std, loc=mean, scale=std) 299 | samples = X.rvs(num) 300 | output = torch.from_numpy(samples) 301 | return output 302 | 303 | for m in self.modules(): 304 | if isinstance(m, nn.Conv2d): 305 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 306 | data = 
truncated_normal_(m.weight.nelement(), mean=0, std=math.sqrt(1.3 * 2. / n)) 307 | data = data.type_as(m.weight.data) 308 | m.weight.data = data.view_as(m.weight.data) 309 | if m.bias is not None: 310 | nn.init.zeros_(m.bias) 311 | 312 | 313 | def GN(): 314 | return GuideNet(norm_layer=encoding.nn.SyncBatchNorm, guide=Guide) 315 | 316 | 317 | def GNS(): 318 | return GuideNet(norm_layer=encoding.nn.SyncBatchNorm, guide=Guide, weight_ks=1) 319 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: optimizers.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/15 4:59 PM 7 | """ 8 | This is a fixup as pytorch 1.4.0 can not import AdamW directly from torch.optim 9 | """ 10 | 11 | from torch.optim import * 12 | from torch.optim.adamw import AdamW -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: test.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/16 4:47 PM 7 | 8 | import os 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 11 | import torch 12 | import yaml 13 | from easydict import EasyDict as edict 14 | import datasets 15 | import encoding 16 | 17 | def test(): 18 | net.eval() 19 | for batch_idx, (rgb, lidar, _, idx, ori_size) in enumerate(testloader): 20 | with torch.no_grad(): 21 | if config.tta: 22 | rgbf = torch.flip(rgb, [-1]) 23 | lidarf = torch.flip(lidar, [-1]) 24 | rgbs = torch.cat([rgb, rgbf], 0) 25 | lidars = torch.cat([lidar, lidarf], 0) 26 | rgbs, lidars = rgbs.cuda(), lidars.cuda() 27 | depth_preds, = net(rgbs, lidars) 28 | depth_pred, depth_predf = depth_preds.split(depth_preds.shape[0] // 2) 29 | depth_predf = torch.flip(depth_predf, [-1]) 30 | depth_pred = (depth_pred + depth_predf) / 2. 
31 | else: 32 | rgb, lidar = rgb.cuda(), lidar.cuda() 33 | depth_pred, = net(rgb, lidar) 34 | depth_pred[depth_pred < 0] = 0 35 | depth_pred = depth_pred.cpu().squeeze(1).numpy() 36 | idx = idx.cpu().squeeze(1).numpy() 37 | ori_size = ori_size.cpu().numpy() 38 | name = [testset.names[i] for i in idx] 39 | save_result(config, depth_pred, name, ori_size) 40 | 41 | 42 | if __name__ == '__main__': 43 | # config_name = 'GN.yaml' 44 | config_name = 'GNS.yaml' 45 | with open(os.path.join('configs', config_name), 'r') as file: 46 | config_data = yaml.load(file, Loader=yaml.FullLoader) 47 | config = edict(config_data) 48 | from utils import * 49 | 50 | transform = init_aug(config.test_aug_configs) 51 | key, params = config.data_config.popitem() 52 | dataset = getattr(datasets, key) 53 | testset = dataset(**params, mode='test', transform=transform, return_idx=True, return_size=True) 54 | testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, num_workers=config.num_workers, 55 | shuffle=False, pin_memory=True) 56 | print('num_test = {}'.format(len(testset))) 57 | net = init_net(config) 58 | torch.cuda.empty_cache() 59 | torch.backends.cudnn.benchmark = True 60 | net.cuda() 61 | net = encoding.parallel.DataParallelModel(net) 62 | net = resume_state(config, net) 63 | test() 64 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: train.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/14 7:50 PM 7 | 8 | import os 9 | import torch 10 | import yaml 11 | from easydict import EasyDict as edict 12 | 13 | 14 | def train(epoch): 15 | global iters 16 | Avg = AverageMeter() 17 | for batch_idx, (rgb, lidar, depth) in enumerate(trainloader): 18 | if epoch >= config.test_epoch and iters % config.test_iters == 0: 19 | test() 20 | net.train() 21 | rgb, lidar, depth = rgb.cuda(), lidar.cuda(), depth.cuda() 22 | optimizer.zero_grad() 23 | output = net(rgb, lidar) 24 | loss = criterion(output, depth).mean() 25 | loss.backward() 26 | optimizer.step() 27 | Avg.update(loss.item()) 28 | iters += 1 29 | if config.vis and batch_idx % config.vis_iters == 0: 30 | print('Epoch {} Idx {} Loss {:.4f}'.format(epoch, batch_idx, Avg.avg)) 31 | 32 | 33 | def test(): 34 | global best_metric 35 | Avg = AverageMeter() 36 | net.eval() 37 | for batch_idx, (rgb, lidar, depth) in enumerate(testloader): 38 | rgb, lidar, depth = rgb.cuda(), lidar.cuda(), depth.cuda() 39 | with torch.no_grad(): 40 | output = net(rgb, lidar) 41 | prec = metric(output, depth).mean() 42 | Avg.update(prec.item(), rgb.size(0)) 43 | if Avg.avg < best_metric: 44 | best_metric = Avg.avg 45 | save_state(config, net) 46 | print('Best Result: {:.4f}\n'.format(best_metric)) 47 | 48 | 49 | if __name__ == '__main__': 50 | # config_name = 'GN.yaml' 51 | config_name = 'GNS.yaml' 52 | with open(os.path.join('configs', config_name), 'r') as file: 53 | config_data = yaml.load(file, Loader=yaml.FullLoader) 54 | config = edict(config_data) 55 | print(config.name) 56 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(gpu_id) for gpu_id in config.gpu_ids]) 57 | from utils import * 58 | 59 | init_seed(config) 60 | trainloader, testloader = init_dataset(config) 61 | net = init_net(config) 62 | criterion = init_loss(config) 63 | metric = init_metric(config) 64 | net, criterion, metric = init_cuda(net, criterion, metric) 65 | optimizer = init_optim(config, 
net) 66 | lr_scheduler = init_lr_scheduler(config, optimizer) 67 | iters = 0 68 | best_metric = 100 69 | for epoch in range(config.start_epoch, config.nepoch): 70 | train(epoch) 71 | lr_scheduler.step() 72 | print('Best Results: {:.4f}\n'.format(best_metric)) 73 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @Filename: utils.py 4 | # @Project: GuideNet 5 | # @Author: jie 6 | # @Time: 2021/3/15 5:25 PM 7 | 8 | import os 9 | import torch 10 | import random 11 | import numpy as np 12 | import augs 13 | import models 14 | import datasets 15 | import optimizers 16 | import encoding 17 | import criteria 18 | from PIL import Image 19 | 20 | __all__ = [ 21 | 'AverageMeter', 22 | 'init_seed', 23 | 'init_aug', 24 | 'init_dataset', 25 | 'init_cuda', 26 | 'init_net', 27 | 'init_loss', 28 | 'init_metric', 29 | 'init_optim', 30 | 'init_lr_scheduler', 31 | 'save_state', 32 | 'resume_state', 33 | 'save_result', 34 | ] 35 | 36 | 37 | class AverageMeter(object): 38 | def __init__(self): 39 | self.reset() 40 | 41 | def reset(self): 42 | self.val = 0 43 | self.avg = 0 44 | self.sum = 0 45 | self.count = 0 46 | 47 | def update(self, val, n=1): 48 | self.val = val 49 | self.sum += val * n 50 | self.count += n 51 | self.avg = self.sum / self.count 52 | 53 | 54 | def config_param(model): 55 | param_groups = [] 56 | other_params = [] 57 | for name, param in model.named_parameters(): 58 | if len(param.shape) == 1: 59 | g = {'params': [param], 'weight_decay': 0.0} 60 | param_groups.append(g) 61 | else: 62 | other_params.append(param) 63 | param_groups.append({'params': other_params}) 64 | return param_groups 65 | 66 | 67 | def save_state(config, model): 68 | print('==> Saving model ...') 69 | env_name = config.name + '_' + str(config.manual_seed) 70 | save_path = os.path.join('checkpoints', env_name) 71 | os.makedirs(save_path, exist_ok=True) 72 | model_state_dict = model.state_dict() 73 | state_dict = { 74 | 'net': model_state_dict, 75 | } 76 | torch.save(state_dict, os.path.join(save_path, 'result.pth')) 77 | 78 | 79 | def resume_state(config, model): 80 | env_name = config.name + '_' + str(config.resume_seed) 81 | cp_path = os.path.join('checkpoints', env_name, 'result.pth') 82 | resume_model = torch.load(cp_path)['net'] 83 | model.load_state_dict(resume_model, strict=True) 84 | return model 85 | 86 | 87 | def pad_rep(image, ori_size): 88 | h, w = image.shape 89 | oh, ow = ori_size 90 | pl = (ow - w) // 2 91 | pr = ow - w - pl 92 | pt = oh - h 93 | image_pad = np.pad(image, pad_width=((pt, 0), (pl, pr)), mode='edge') 94 | return image_pad 95 | 96 | 97 | def save_result(config, depths, names, ori_sizes=None): 98 | env_name = config.name + '_' + str(config.resume_seed) 99 | save_path = os.path.join('results', env_name) 100 | os.makedirs(save_path, exist_ok=True) 101 | for i in range(depths.shape[0]): 102 | depth, name = depths[i], names[i] 103 | if ori_sizes is not None: 104 | depth = pad_rep(depth, ori_sizes[i]) 105 | filename = os.path.join(save_path, name) 106 | img = (depth * 256.0).astype('uint16') 107 | Img = Image.fromarray(img) 108 | Img.save(filename) 109 | 110 | 111 | def init_seed(config): 112 | if config.manual_seed == 0: 113 | config.manual_seed = random.randint(1, 10000) 114 | print("Random Seed: ", config.manual_seed) 115 | torch.initial_seed() 116 | random.seed(config.manual_seed) 117 | 
np.random.seed(config.manual_seed) 118 | torch.manual_seed(config.manual_seed) 119 | torch.cuda.manual_seed_all(config.manual_seed) 120 | 121 | 122 | def init_net(config): 123 | return getattr(models, config.model)() 124 | 125 | 126 | def init_loss(config): 127 | return getattr(criteria, config.loss)() 128 | 129 | 130 | def init_metric(config): 131 | return getattr(criteria, config.metric)() 132 | 133 | 134 | def init_aug(aug_config): 135 | transform = [] 136 | for x in aug_config: 137 | print(x) 138 | if type(x) == str: 139 | transform.append(getattr(augs, x)()) 140 | else: 141 | key, params = x.popitem() 142 | transform.append(getattr(augs, key)(**params)) 143 | return augs.Compose(transform) 144 | 145 | 146 | def init_dataset(config): 147 | train_transform = init_aug(config.train_aug_configs) 148 | test_transform = init_aug(config.test_aug_configs) 149 | key, params = config.data_config.popitem() 150 | dataset = getattr(datasets, key) 151 | trainset = dataset(**params, mode='train', transform=train_transform) 152 | testset = dataset(**params, mode='selval', transform=test_transform) 153 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, 154 | num_workers=config.num_workers, shuffle=True, drop_last=True, 155 | pin_memory=True) 156 | testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, 157 | num_workers=config.num_workers, shuffle=True, drop_last=True, 158 | pin_memory=True) 159 | print('num_train = {}, num_test = {}'.format(len(trainset), len(testset))) 160 | return trainloader, testloader 161 | 162 | 163 | def init_cuda(net, criterion, metric): 164 | torch.cuda.empty_cache() 165 | net.cuda() 166 | criterion.cuda() 167 | metric.cuda() 168 | net = encoding.parallel.DataParallelModel(net) 169 | criterion = encoding.parallel.DataParallelCriterion(criterion) 170 | metric = encoding.parallel.DataParallelCriterion(metric) 171 | torch.backends.cudnn.benchmark = True 172 | return net, criterion, metric 173 | 174 | 175 | def init_optim(config, net): 176 | key, params = config.optim_config.popitem() 177 | return getattr(optimizers, key)(config_param(net), **params) 178 | 179 | 180 | def init_lr_scheduler(config, optimizer): 181 | key, params = config.lr_config.popitem() 182 | return getattr(torch.optim.lr_scheduler, key)(optimizer, **params) 183 | --------------------------------------------------------------------------------
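For reference, the operator implemented by the CUDA extension (`GuideConv.Conv2dLocal_F`, wrapped by `Conv2dLocal` in models.py) is a depthwise, spatially-variant K×K correlation: every pixel of every channel gets its own K×K kernel predicted by the guidance branch. The following is a minimal pure-PyTorch sketch of the same computation, assuming the `(B, C*K*K, H, W)` weight layout implied by the index arithmetic in `guideconv_kernel.cu`; the helper name `conv2d_local_reference` is hypothetical and not part of the repository.

```python
import torch
import torch.nn.functional as F


def conv2d_local_reference(x, w, K=3):
    """Depthwise, spatially-variant KxK correlation (reference only).

    x: (B, C, H, W) input features.
    w: (B, C*K*K, H, W) per-pixel kernels, laid out as (B, C, K, K, H, W)
       to match the indexing used in guideconv_kernel.cu.
    """
    B, C, H, W = x.shape
    w = w.view(B, C, K * K, H, W)
    # Gather the KxK neighbourhood of every pixel; zero padding mirrors the
    # boundary handling of the CUDA kernel (out-of-range taps contribute 0).
    patches = F.unfold(x, kernel_size=K, padding=K // 2).view(B, C, K * K, H, W)
    # Per-pixel, per-channel weighted sum over the KxK window.
    return (patches * w).sum(dim=2)
```

On small inputs this should agree with the compiled operator, which gives a quick sanity check after running `python setup.py install`, e.g. comparing `GuideConv.Conv2dLocal_F(x, w)` against `conv2d_local_reference(x, w)` with `torch.allclose` on CUDA tensors.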
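The custom autograd wrapper can also be checked numerically with `torch.autograd.gradcheck` once the extension is built. The snippet below is a hypothetical test, not part of the repository; it assumes a CUDA device, since the extension only provides GPU kernels, and uses double precision, which `AT_DISPATCH_FLOATING_TYPES` covers.

```python
import torch
from models import Conv2dLocal_F  # wraps GuideConv.Conv2dLocal_F / Conv2dLocal_B


def check_conv2d_local(device='cuda'):
    # Small double-precision tensors keep the finite-difference check cheap and tight.
    x = torch.randn(1, 2, 6, 6, dtype=torch.double, device=device, requires_grad=True)
    w = torch.randn(1, 2 * 9, 6, 6, dtype=torch.double, device=device, requires_grad=True)
    return torch.autograd.gradcheck(Conv2dLocal_F.apply, (x, w), eps=1e-6, atol=1e-4)
```

If the forward and backward kernels are consistent, `check_conv2d_local()` returns `True`.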