├── .DS_Store ├── .gitignore ├── README.md ├── dataloaders ├── __init__.py ├── kitti_dataloader │ └── __init__.py ├── nyu_dataloader │ ├── __init__.py │ ├── dataloader.py │ ├── dense_to_sparse.py │ ├── nyu_dataloader.py │ └── transforms.py └── path.py ├── libs ├── __init__.py ├── criterion │ ├── __init__.py │ └── criteria.py ├── metrics.py ├── scheduler │ ├── __init__.py │ └── scheduler.py ├── trainers │ ├── __init__.py │ ├── multi_gpu_trainer.py │ └── single_gpu_trainer.py └── utils.py ├── main.py ├── network ├── __init__.py ├── libs │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── encoding.py │ │ ├── operation.py │ │ └── pac.py │ ├── inplace_abn │ │ ├── __init__.py │ │ ├── _ext │ │ │ └── __init__.py │ │ ├── bn.py │ │ ├── build.py │ │ ├── build.sh │ │ ├── dense.py │ │ ├── functions.py │ │ ├── misc.py │ │ ├── residual.py │ │ └── src │ │ │ ├── bn.cu │ │ │ ├── bn.h │ │ │ ├── bn.o │ │ │ ├── common.h │ │ │ ├── lib_cffi.cpp │ │ │ └── lib_cffi.h │ └── post_process │ │ ├── CSPN.py │ │ ├── CSPN_new.py │ │ ├── CSPN_ours.py │ │ └── __init__.py ├── unet_cspn_nyu.py ├── unet_ours.py └── utils.py ├── options.py └── result └── nyu.PNG /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/CSPN_monodepth/7363bf749b8df4ea29f1a4fa9eebddbf97cf3f4b/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | .idea/ 29 | /.idea 30 | *.iml 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 | .dmypy.json
115 | dmypy.json
116 |
117 | # Pyre type checker
118 | .pyre/
119 |
120 | # JPG PNG
121 | *.jpg
122 | *.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSPN implemented in PyTorch 0.4.1
2 |
3 |
4 | ### Introduction
5 | This is a PyTorch (0.4.1) implementation of [Depth Estimation via Affinity Learned with Convolutional Spatial Propagation Network](http://arxiv.org/abs/1808.00150). At present, we provide training scripts for the NYU Depth V2 dataset, covering both depth completion and monocular depth estimation. KITTI support will be available soon!
6 |
7 | ### Faster Implementation
8 | We re-implement CSPN using [Pixel-Adaptive Convolution](http://arxiv.org/abs/1904.05373).
9 |
10 | ### Multi-GPU
11 | Multi-GPU training is implemented on top of [inplace abn](http://arxiv.org/abs/1712.02616).
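### CSPN in a Nutshell
For readers new to CSPN, below is a minimal, illustrative sketch of a single propagation step, not the exact code in `network/libs/post_process/CSPN.py` (which is heavily optimized) and not the PAC-based re-implementation: the network predicts affinities for the 8 neighbors of every pixel, the affinities are normalized so the update is a stable diffusion, and for depth completion the known sparse points are re-pinned after every step.

```python
import torch
import torch.nn.functional as F

def cspn_step(depth, affinity, sparse_depth=None):
    """One CSPN propagation step. depth: (B, 1, H, W); affinity: (B, 8, H, W)."""
    # Normalize affinities by the sum of absolute values and give the
    # leftover weight to the center pixel, as in the paper.
    norm_aff = affinity / affinity.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
    center = 1.0 - norm_aff.sum(dim=1, keepdim=True)

    # Collect every pixel's 3x3 neighborhood and drop the center (index 4).
    patches = F.unfold(depth, kernel_size=3, padding=1)               # (B, 9, H*W)
    patches = patches.view(depth.size(0), 9, depth.size(2), depth.size(3))
    neighbors = torch.cat([patches[:, :4], patches[:, 5:]], dim=1)    # (B, 8, H, W)

    out = center * depth + (norm_aff * neighbors).sum(dim=1, keepdim=True)
    if sparse_depth is not None:  # depth completion: re-pin the known points
        valid = (sparse_depth > 0).type_as(out)
        out = (1 - valid) * out + valid * sparse_depth
    return out
```

In practice the step is applied for a fixed number of iterations (e.g. 24 in the paper's NYU setting); more iterations enlarge the effective receptive field of the refinement.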
12 |
13 | ### Results
14 | Method | Implementation details | rel | rmse | log10 | Delta1 | Delta2 | Delta3
15 | :-------| :------: | :------: | :------: | :------: | :------: | :------: | :------:
16 | Paper | batch size=24 epoch=40 | 0.016 | 0.117 | - | 0.992 | 0.999 | 1.000
17 | Our_impl | batch size=8 iteration=100k | 0.018 | 0.127 | 0.008 | 0.991 | 0.998 | 1.000
18 | Our_CSPN | batch size=8 iteration=100k | 0.018 | 0.127 | 0.008 | 0.991 | 0.998 | 1.000
19 |
20 | ![Image text](https://github.com/dontLoveBugs/CSPN_monodepth/blob/master/result/nyu.PNG)
21 |
22 | ### Third-Party Libs
23 | [inplace abn](https://github.com/mapillary/inplace_abn)
24 |
25 | [Pixel-Adaptive Convolution](https://github.com/NVlabs/pacnet)
26 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time : 2019-05-19 15:53
5 | @Author : Wang Xin
6 | @Email : wangxin_buaa@163.com
7 | @File : __init__.py
8 | """
9 |
10 |
11 | def create_loader(args, mode='train'):
12 |     if args.dataset.lower() == 'nyu':
13 |         from dataloaders.nyu_dataloader import create_loader
14 |         return create_loader(args, mode=mode)
15 |     elif args.dataset.lower() == 'kitti':
16 |         raise NotImplementedError  # KITTI loader is not available yet
17 |     else:
18 |         raise NotImplementedError
19 |
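# Hedged usage sketch for create_loader above. The argument names mirror how
# they are consumed in dataloaders/nyu_dataloader/__init__.py, but the exact
# defaults live in options.py, so treat the values below as assumptions:
#
#   from argparse import Namespace
#   from dataloaders import create_loader
#
#   args = Namespace(dataset='nyu', modality='rgbd', sparsifier='uar',
#                    num_samples=500, max_depth=10.0, batch_size=8, workers=4)
#   train_loader = create_loader(args, mode='train')
#   val_loader = create_loader(args, mode='val')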
--------------------------------------------------------------------------------
/dataloaders/kitti_dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time : 2019-05-19 15:54
5 | @Author : Wang Xin
6 | @Email : wangxin_buaa@163.com
7 | @File : __init__.py
8 | """
9 | from functools import cmp_to_key
10 | # NOTE: the rest of this module is leftover scratch code, not a KITTI loader.
11 | x = [[1, 2], [2, 1], [3, 4]]
12 |
13 | def mycmp(x, y):
14 |     if x[1] == y[1]:
15 |         return x[0] - y[0]
16 |     return x[1] - y[1]
17 |
18 | print(x)
19 | x = sorted(x, key=cmp_to_key(mycmp))
20 | print(x)
21 |
22 | y = set()  # set(0) raises TypeError; an empty set is set()
23 |
24 |
25 | from queue import PriorityQueue
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time : 2019-05-19 15:53
5 | @Author : Wang Xin
6 | @Email : wangxin_buaa@163.com
7 | @File : __init__.py
8 | """
9 |
10 |
11 | def create_loader(args, mode='train'):
12 |     # Data loading code
13 |     print('=> creating ', mode, ' loader ...')
14 |     import os
15 |     from dataloaders.path import Path
16 |     root_dir = Path.db_root_dir(args.dataset)
17 |
18 |     # sparsifier is a class for generating random sparse depth input from the ground truth
19 |     import numpy as np
20 |     sparsifier = None
21 |     max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
22 |     from dataloaders.nyu_dataloader.dense_to_sparse import UniformSampling
23 |     from dataloaders.nyu_dataloader.dense_to_sparse import SimulatedStereo
24 |     if args.sparsifier == UniformSampling.name:
25 |         sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
26 |     elif args.sparsifier == SimulatedStereo.name:
27 |         sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
28 |
29 |     from dataloaders.nyu_dataloader.nyu_dataloader import NYUDataset
30 |
31 |     import torch
32 |     if mode.lower() == 'train':
33 |         traindir = os.path.join(root_dir, 'train')
34 |
35 |         if os.path.exists(traindir):
36 |             print('Train dataset "{}" exists!'.format(traindir))
37 |         else:
38 |             print('Train dataset "{}" does not exist!'.format(traindir))
39 |             exit(-1)
40 |         train_dataset = NYUDataset(traindir, type='train',
41 |                                    modality=args.modality, sparsifier=sparsifier)
42 |         # worker_init_fn ensures different sampling patterns for each data loading thread
43 |         train_loader = torch.utils.data.DataLoader(
44 |             train_dataset, batch_size=args.batch_size, shuffle=True,
45 |             num_workers=args.workers, pin_memory=True, sampler=None,
46 |             worker_init_fn=lambda work_id: np.random.seed(work_id))
47 |
48 |         return train_loader
49 |
50 |     elif mode.lower() == 'val':
51 |         valdir = os.path.join(root_dir, 'val')
52 |         if os.path.exists(valdir):
53 |             print('Val dataset "{}" exists!'.format(valdir))
54 |         else:
55 |             print('Val dataset "{}" does not exist!'.format(valdir))
56 |             exit(-1)
57 |         val_dataset = NYUDataset(valdir, type='val',
58 |                                  modality=args.modality, sparsifier=sparsifier)
59 |         val_loader = torch.utils.data.DataLoader(val_dataset,
60 |                                                  batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
61 |
62 |         return val_loader
63 |
64 |     else:
65 |         raise NotImplementedError
66 |
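# A sketch of what the two sparsifiers selected above produce:
# UniformSampling ('uar') keeps each valid ground-truth pixel with probability
# num_samples / n_valid, so the expected number of retained depth points is
# about num_samples. For example, with num_samples=500 and roughly 60k valid
# pixels, each pixel survives with p ~ 500/60000 ~ 0.008, i.e. the sparse
# input covers well under 1% of the image. SimulatedStereo ('sim_stereo')
# instead keeps points along image edges, to mimic where stereo matching is
# confident. Both classes are defined in dense_to_sparse.py below.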
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import numpy as np
4 | import torch.utils.data as data
5 | import h5py
6 | import dataloaders.nyu_dataloader.transforms as transforms
7 |
8 | IMG_EXTENSIONS = ['.h5', ]
9 |
10 |
11 | def is_image_file(filename):
12 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 |
15 | def find_classes(dir):
16 |     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
17 |     classes.sort()
18 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
19 |     return classes, class_to_idx
20 |
21 |
22 | def make_dataset(dir, class_to_idx):
23 |     images = []
24 |     dir = os.path.expanduser(dir)
25 |     for target in sorted(os.listdir(dir)):
26 |         d = os.path.join(dir, target)
27 |         if not os.path.isdir(d):
28 |             continue
29 |         for root, _, fnames in sorted(os.walk(d)):
30 |             for fname in sorted(fnames):
31 |                 if is_image_file(fname):
32 |                     path = os.path.join(root, fname)
33 |                     item = (path, class_to_idx[target])
34 |                     images.append(item)
35 |     return images
36 |
37 |
38 | def h5_loader(path):
39 |     h5f = h5py.File(path, "r")
40 |     rgb = np.array(h5f['rgb'])
41 |     rgb = np.transpose(rgb, (1, 2, 0))
42 |     depth = np.array(h5f['depth'])
43 |     return rgb, depth
44 |
45 |
46 | # def rgb2grayscale(rgb):
47 | #     return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
48 |
49 | to_tensor = transforms.ToTensor()
50 |
51 |
52 | class MyDataloader(data.Dataset):
53 |     modality_names = ['rgb', 'rgbd', 'd']  # , 'g', 'gd'
54 |     color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
55 |
56 |     def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
57 |         classes, class_to_idx = find_classes(root)
58 |         imgs = make_dataset(root, class_to_idx)
59 |         assert len(imgs) > 0, "Found 0 images in subfolders of: " + root + "\n"
60 |         print("Found {} images in {} folder.".format(len(imgs), type))
61 |         self.root = root
62 |         self.imgs = imgs
63 |         self.classes = classes
64 |         self.class_to_idx = class_to_idx
65 |         if type == 'train':
66 |             self.transform = self.train_transform
67 |         elif type == 'val':
68 |             self.transform = self.val_transform
69 |         else:
70 |             raise (RuntimeError("Invalid dataset type: " + type + "\n"
71 |                                 "Supported dataset types are: train, val"))
72 |         self.loader = loader
73 |         self.sparsifier = sparsifier
74 |
75 |         assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
76 |                                                   "Supported modality types are: " + ', '.join(self.modality_names)
77 |         self.modality = modality
78 |
79 |     def train_transform(self, rgb, depth):
80 |         raise (RuntimeError("train_transform() is not implemented."))
81 |
82 |     def val_transform(self, rgb, depth):
83 |         raise (RuntimeError("val_transform() is not implemented."))
84 |
85 |     def create_sparse_depth(self, rgb, depth):
86 |         if self.sparsifier is None:
87 |             return depth
88 |         else:
89 |             mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
90 |             sparse_depth = np.zeros(depth.shape)
91 |             sparse_depth[mask_keep] = depth[mask_keep]
92 |             return sparse_depth
93 |
94 |     def create_rgbd(self, rgb, depth):
95 |         sparse_depth = self.create_sparse_depth(rgb, depth)
96 |         rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
97 |         return rgbd
98 |
99 |     def __getraw__(self, index):
100 |         """
101 |         Args:
102 |             index (int): Index
103 |
104 |         Returns:
105 |             tuple: (rgb, depth) the raw data.
106 |         """
107 |         path, target = self.imgs[index]
108 |         rgb, depth = self.loader(path)
109 |         return rgb, depth
110 |
111 |     def __getitem__(self, index):
112 |         rgb, depth = self.__getraw__(index)
113 |         if self.transform is not None:
114 |             rgb_np, depth_np = self.transform(rgb, depth)
115 |         else:
116 |             raise (RuntimeError("transform not defined"))
117 |
118 |         if self.modality == 'rgb':
119 |             input_np = rgb_np
120 |         elif self.modality == 'rgbd':
121 |             input_np = self.create_rgbd(rgb_np, depth_np)
122 |         elif self.modality == 'd':
123 |             input_np = self.create_sparse_depth(rgb_np, depth_np)
124 |
125 |         input_tensor = to_tensor(input_np)
126 |         while input_tensor.dim() < 3:
127 |             input_tensor = input_tensor.unsqueeze(0)
128 |         depth_tensor = to_tensor(depth_np)
129 |         depth_tensor = depth_tensor.unsqueeze(0)
130 |
131 |         return input_tensor, depth_tensor
132 |
133 |     def __len__(self):
134 |         return len(self.imgs)
135 |
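# Shape sketch for the modality branches above (assuming the 228x304 NYU
# output size set in nyu_dataloader.py):
#   'rgb'  -> input_tensor (3, 228, 304): color only
#   'rgbd' -> input_tensor (4, 228, 304): color plus a sparse depth channel
#             produced by create_rgbd
#   'd'    -> input_tensor (1, 228, 304): sparse depth only
# The target depth_tensor is always the dense map, (1, 228, 304).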
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader/dense_to_sparse.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2018/12/6 15:10
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 | import numpy as np
9 | import cv2
10 |
11 |
12 | def rgb2grayscale(rgb):
13 |     return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
14 |
15 |
16 | class DenseToSparse:
17 |     def __init__(self):
18 |         pass
19 |
20 |     def dense_to_sparse(self, rgb, depth):
21 |         pass
22 |
23 |     def __repr__(self):
24 |         pass
25 |
26 |
27 | class UniformSampling(DenseToSparse):
28 |     name = "uar"
29 |
30 |     def __init__(self, num_samples, max_depth=np.inf):
31 |         DenseToSparse.__init__(self)
32 |         self.num_samples = num_samples
33 |         self.max_depth = max_depth
34 |
35 |     def __repr__(self):
36 |         return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth)
37 |
38 |     def dense_to_sparse(self, rgb, depth):
39 |         """
40 |         Samples pixels with `num_samples`/#pixels probability in `depth`.
41 |         Only pixels with a depth of at most `max_depth` are considered.
42 |         If no `max_depth` is given, all valid pixels are sampled from.
43 |         """
44 |         mask_keep = depth > 0
45 |         if self.max_depth is not np.inf:
46 |             mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
47 |         n_keep = np.count_nonzero(mask_keep)
48 |         if n_keep == 0:
49 |             return mask_keep
50 |         else:
51 |             prob = float(self.num_samples) / n_keep
52 |             return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob)
53 |
54 |
55 | class SimulatedStereo(DenseToSparse):
56 |     name = "sim_stereo"
57 |
58 |     def __init__(self, num_samples, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1):
59 |         DenseToSparse.__init__(self)
60 |         self.num_samples = num_samples
61 |         self.max_depth = max_depth
62 |         self.dilate_kernel = dilate_kernel
63 |         self.dilate_iterations = dilate_iterations
64 |
65 |     def __repr__(self):
66 |         return "%s{ns=%d,md=%f,dil=%d.%d}" % \
67 |                (self.name, self.num_samples, self.max_depth, self.dilate_kernel, self.dilate_iterations)
68 |
69 |     # We do not use cv2.Canny, since that applies non-max suppression.
70 |     # Instead we simply:
71 |     #   convert RGB to intensities,
72 |     #   smooth with a Gaussian,
73 |     #   take simple Sobel gradients,
74 |     #   threshold the edge gradient,
75 |     #   and dilate.
76 |     def dense_to_sparse(self, rgb, depth):
77 |         gray = rgb2grayscale(rgb)
78 |         blurred = cv2.GaussianBlur(gray, (5, 5), 0)
79 |         gx = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=5)
80 |         gy = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=5)
81 |
82 |         depth_mask = np.bitwise_and(depth != 0.0, depth <= self.max_depth)
83 |
84 |         edge_fraction = float(self.num_samples) / np.size(depth)
85 |
86 |         mag = cv2.magnitude(gx, gy)
87 |         min_mag = np.percentile(mag[depth_mask], 100 * (1.0 - edge_fraction))
88 |         mag_mask = mag >= min_mag
89 |
90 |         if self.dilate_iterations >= 0:
91 |             kernel = np.ones((self.dilate_kernel, self.dilate_kernel), dtype=np.uint8)
92 |             mag_mask = cv2.dilate(mag_mask.astype(np.uint8), kernel, iterations=self.dilate_iterations) > 0
93 |
94 |         mask = np.bitwise_and(mag_mask, depth_mask)
95 |         return mask
96 |
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader/nyu_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.nyu_dataloader.transforms as transforms
3 | from dataloaders.nyu_dataloader.dataloader import MyDataloader
4 |
5 | iheight, iwidth = 480, 640  # raw image size
6 |
7 |
8 | class NYUDataset(MyDataloader):
9 |     def __init__(self, root, type, sparsifier=None, modality='rgb'):
10 |         super(NYUDataset, self).__init__(root, type, sparsifier, modality)
11 |         self.output_size = (228, 304)
12 |
13 |     def train_transform(self, rgb, depth):
14 |         s = np.random.uniform(1.0, 1.5)  # random scaling
15 |         depth_np = depth / s
16 |         angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
17 |         do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip
18 |
19 |         # perform 1st step of data augmentation
20 |         transform = transforms.Compose([
21 |             transforms.Resize(250.0 / iheight),  # this is for computational efficiency, since rotation can be slow
22 |             transforms.Rotate(angle),
23 |             transforms.Resize(s),
24 |             transforms.CenterCrop(self.output_size),
25 |             transforms.HorizontalFlip(do_flip)
26 |         ])
27 |         rgb_np = transform(rgb)
28 |         rgb_np = self.color_jitter(rgb_np)  # random color jittering
29 |         rgb_np = np.asfarray(rgb_np, dtype='float') / 255
30 |         depth_np = transform(depth_np)
31 |
32 |         return rgb_np, depth_np
33 |
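# Why `depth_np = depth / s` in train_transform above: enlarging the image by
# a factor s makes the scene appear s times closer, so the ground-truth depth
# must be divided by the same s to stay geometrically consistent (a standard
# scale-augmentation trick in prior depth-estimation work). Quick sanity check
# of the arithmetic:
#
#   s = 1.5
#   assert abs((3.0 / s) - 2.0) < 1e-9   # a 3 m wall upscaled 1.5x reads 2 m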
34 |     def val_transform(self, rgb, depth):
35 |         depth_np = depth
36 |         transform = transforms.Compose([
37 |             transforms.Resize(240.0 / iheight),
38 |             transforms.CenterCrop(self.output_size),
39 |         ])
40 |         rgb_np = transform(rgb)
41 |         rgb_np = np.asfarray(rgb_np, dtype='float') / 255
42 |         depth_np = transform(depth_np)
43 |
44 |         return rgb_np, depth_np
45 |
--------------------------------------------------------------------------------
/dataloaders/path.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/1/21 21:27
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 |
9 | class Path(object):
10 |     @staticmethod
11 |     def db_root_dir(database):
12 |         if database == 'kitti':
13 |             return '/data/wangxin/KITTI'
14 |         elif database == 'nyu':
15 |             return 'D:\\DATASETS\\nyudepthv2\\nyudepthv2'
16 |         else:
17 |             print('Database {} not available.'.format(database))
18 |             raise NotImplementedError
19 |
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/1/24 16:48
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
--------------------------------------------------------------------------------
/libs/criterion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/3/2 18:18
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 | from libs.criterion.criteria import MaskedL1Loss, MaskedMSELoss, L1_log
9 | from libs.criterion.criteria import CriterionDSN, Criterion_No_DSN
10 |
11 | key_to_criteria = {
12 |     'l1': MaskedL1Loss,
13 |     'l2': MaskedMSELoss,
14 |     'l1_log': L1_log
15 | }
16 |
17 |
18 | def get_criteria(args):
19 |     if args.criterion in key_to_criteria:
20 |         criterion = key_to_criteria[args.criterion]()
21 |     else:
22 |         print('no available criterion method named ', args.criterion)
23 |         raise NotImplementedError
24 |
25 |     if args.loss_wrapper.lower() == 'dsn':
26 |         criterion = CriterionDSN(criterion=criterion)
27 |     else:
28 |         criterion = Criterion_No_DSN(criterion=criterion)
29 |
30 |     return criterion
--------------------------------------------------------------------------------
/libs/criterion/criteria.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/10/23 20:04
3 | # @Author : Wang Xin
4 | # @Email : wangxin_buaa@163.com
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 |
13 |
14 | class MaskedMSELoss(nn.Module):
15 |     def __init__(self):
16 |         super(MaskedMSELoss, self).__init__()
17 |
18 |     def forward(self, pred, target):
19 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
20 |         valid_mask = (target > 0).detach()
21 |         diff = target - pred
22 |         diff = diff[valid_mask]
23 |         self.loss = (diff ** 2).mean()
24 |         return self.loss
25 |
26 |
27 | class MaskedL1Loss(nn.Module):
28 |     def __init__(self):
29 |         super(MaskedL1Loss, self).__init__()
30 |
31 |     def forward(self, pred, target):
32 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
33 |         valid_mask = (target > 0).detach()
34 |         # print('target size:', target.size())
35 |         # print('pred size:', pred.size())
36 |         diff = target - pred
37 |         diff = diff[valid_mask]
38 |         self.loss = diff.abs().mean()
39 |         return self.loss
40 |
41 |
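# Tiny worked example of the masking convention shared by the criteria above
# (pixels with target == 0 have no ground truth and must not contribute):
#
#   import torch
#   pred = torch.tensor([[[[2.0, 5.0]]]])
#   target = torch.tensor([[[[1.0, 0.0]]]])   # second pixel: missing depth
#   loss = MaskedL1Loss()(pred, target)
#   assert abs(loss.item() - 1.0) < 1e-6      # only |2 - 1| is averaged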
42 | class berHuLoss(nn.Module): 43 | def __init__(self): 44 | super(berHuLoss, self).__init__() 45 | 46 | def forward(self, pred, target): 47 | assert pred.dim() == target.dim(), "inconsistent dimensions" 48 | 49 | huber_c = torch.max(pred - target) 50 | huber_c = 0.2 * huber_c 51 | 52 | valid_mask = (target > 0).detach() 53 | diff = target - pred 54 | diff = diff[valid_mask] 55 | diff = diff.abs() 56 | 57 | huber_mask = (diff > huber_c).detach() 58 | 59 | diff2 = diff[huber_mask] 60 | diff2 = diff2 ** 2 61 | 62 | self.loss = torch.cat((diff, diff2)).mean() 63 | 64 | return self.loss 65 | 66 | 67 | class RMSE_log(nn.Module): 68 | def __init__(self): 69 | super(RMSE_log, self).__init__() 70 | 71 | def forward(self, fake, real): 72 | if not fake.shape == real.shape: 73 | _, _, H, W = real.shape 74 | fake = F.upsample(fake, size=(H, W), mode='bilinear') 75 | loss = torch.sqrt(torch.mean(torch.abs(torch.log(real) - torch.log(fake)) ** 2)) 76 | return loss 77 | 78 | 79 | class L1(nn.Module): 80 | def __init__(self): 81 | super(L1, self).__init__() 82 | 83 | def forward(self, fake, real): 84 | if not fake.shape == real.shape: 85 | _, _, H, W = real.shape 86 | fake = F.upsample(fake, size=(H, W), mode='bilinear') 87 | loss = torch.mean(torch.abs(10. * real - 10. * fake)) 88 | return loss 89 | 90 | 91 | class L1_log(nn.Module): 92 | def __init__(self): 93 | super(L1_log, self).__init__() 94 | 95 | def forward(self, fake, real): 96 | assert fake.dim() == real.dim(), "inconsistent dimensions" 97 | 98 | if not fake.shape == real.shape: 99 | _, _, H, W = real.shape 100 | fake = F.interpolate(fake, size=(H, W), mode='bilinear', align_corners=True) 101 | 102 | valid_mask = (real > 0).detach() 103 | real = real[valid_mask] 104 | fake = fake[valid_mask] 105 | 106 | loss = torch.mean(torch.abs(torch.log(real) - torch.log(fake))) 107 | return loss 108 | 109 | 110 | class BerHu(nn.Module): 111 | def __init__(self, threshold=0.2): 112 | super(BerHu, self).__init__() 113 | self.threshold = threshold 114 | 115 | def forward(self, real, fake): 116 | mask = real > 0 117 | if not fake.shape == real.shape: 118 | _, _, H, W = real.shape 119 | fake = F.upsample(fake, size=(H, W), mode='bilinear') 120 | fake = fake * mask 121 | diff = torch.abs(real - fake) 122 | delta = self.threshold * torch.max(diff).data.cpu().numpy()[0] 123 | 124 | part1 = -F.threshold(-diff, -delta, 0.) 125 | part2 = F.threshold(diff ** 2 - delta ** 2, 0., -delta ** 2.) + delta ** 2 126 | part2 = part2 / (2. * delta) 127 | 128 | loss = part1 + part2 129 | loss = torch.sum(loss) 130 | return loss 131 | 132 | 133 | class RMSE(nn.Module): 134 | def __init__(self): 135 | super(RMSE, self).__init__() 136 | 137 | def forward(self, fake, real): 138 | if not fake.shape == real.shape: 139 | _, _, H, W = real.shape 140 | fake = F.upsample(fake, size=(H, W), mode='bilinear') 141 | loss = torch.sqrt(torch.mean(torch.abs(10. * real - 10. 
* fake) ** 2)) 142 | return loss 143 | 144 | 145 | class GradLoss(nn.Module): 146 | def __init__(self): 147 | super(GradLoss, self).__init__() 148 | 149 | # L1 norm 150 | def forward(self, pred, target): 151 | assert pred.dim() == target.dim(), "inconsistent dimensions" 152 | return torch.mean(torch.abs(target - pred)) 153 | 154 | 155 | class NormalLoss(nn.Module): 156 | def __init__(self): 157 | super(NormalLoss, self).__init__() 158 | 159 | def forward(self, grad_fake, grad_real): 160 | assert grad_fake.dim() == grad_real.dim(), "inconsistent dimensions" 161 | 162 | # prod = (grad_fake[:, :, None, :] @ grad_real[:, :, :, None]).squeeze(-1).squeeze(-1) 163 | prod = grad_fake * grad_real 164 | fake_norm = torch.sqrt(torch.sum(grad_fake ** 2, dim=-1)) 165 | real_norm = torch.sqrt(torch.sum(grad_real ** 2, dim=-1)) 166 | 167 | return 1 - torch.mean(prod / (fake_norm * real_norm)) 168 | 169 | 170 | class Criterion_No_DSN(nn.Module): 171 | ''' 172 | No DSN : We don't need to consider other supervision for the model. 173 | ''' 174 | 175 | def __init__(self, criterion=None): 176 | super(Criterion_No_DSN, self).__init__() 177 | self.criterion = criterion 178 | 179 | def forward(self, preds, target): 180 | h, w = target.size(2), target.size(3) 181 | 182 | if h != preds[0].size(2) or w != preds[0].size(3): 183 | scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) 184 | else: 185 | scale_pred = preds[0] 186 | loss = self.criterion(scale_pred, target) 187 | 188 | return loss 189 | 190 | 191 | class CriterionDSN(nn.Module): 192 | ''' 193 | DSN : We need to consider the other supervision for the model. 194 | ''' 195 | 196 | def __init__(self, criterion=None): 197 | super(CriterionDSN, self).__init__() 198 | self.criterion = criterion 199 | 200 | def forward(self, preds, target): 201 | h, w = target.size(2), target.size(3) 202 | 203 | # print('dsn target size:', target.size()) 204 | # print('dsn h = ', h) 205 | # print('dsn w = ', w) 206 | 207 | if h != preds[0].size(2) or w != preds[0].size(3): 208 | scale_pred = F.upsample(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) 209 | else: 210 | scale_pred = preds[0] 211 | loss1 = self.criterion(scale_pred, target) 212 | 213 | if h != preds[1].size(2) or w != preds[1].size(3): 214 | scale_pred = F.upsample(input=preds[1], size=(h, w), mode='bilinear', align_corners=True) 215 | else: 216 | scale_pred = preds[1] 217 | loss2 = self.criterion(scale_pred, target) 218 | 219 | return loss1 + loss2 * 0.4 220 | -------------------------------------------------------------------------------- /libs/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/10/23 19:53 3 | # @Author : Wang Xin 4 | # @Email : wangxin_buaa@163.com 5 | 6 | 7 | import torch 8 | import math 9 | import numpy as np 10 | 11 | import torch.nn.functional as F 12 | 13 | 14 | def log10(x): 15 | """Convert a new tensor with the base-10 logarithm of the elements of x. 
""" 16 | return torch.log(x) / math.log(10) 17 | 18 | 19 | class Result(object): 20 | def __init__(self): 21 | self.irmse, self.imae = 0, 0 22 | self.mse, self.rmse, self.mae = 0, 0, 0 23 | self.absrel, self.lg10 = 0, 0 24 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 25 | self.data_time, self.gpu_time = 0, 0 26 | 27 | self.loss = 0 28 | 29 | def set_to_worst(self): 30 | self.irmse, self.imae = np.inf, np.inf 31 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf 32 | self.absrel, self.lg10 = np.inf, np.inf 33 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 34 | self.data_time, self.gpu_time = 0, 0 35 | 36 | self.loss = np.inf 37 | 38 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, 39 | gpu_time, data_time, loss=None): 40 | self.irmse, self.imae = irmse, imae 41 | self.mse, self.rmse, self.mae = mse, rmse, mae 42 | self.absrel, self.lg10 = absrel, lg10 43 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 44 | self.data_time, self.gpu_time = data_time, gpu_time 45 | 46 | if loss: 47 | self.loss = loss 48 | 49 | def evaluate(self, output, target, loss=None): 50 | 51 | # print('target dim', target.dim()) 52 | h, w = target.size(2), target.size(3) 53 | 54 | if h != output.size(2) or w != output.size(3): 55 | output = F.upsample(input=output, size=(h, w), mode='bilinear', align_corners=True) 56 | 57 | valid_mask = target>0 58 | output = output[valid_mask] 59 | target = target[valid_mask] 60 | 61 | abs_diff = (output - target).abs() 62 | 63 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 64 | self.rmse = math.sqrt(self.mse) 65 | self.mae = float(abs_diff.mean()) 66 | self.lg10 = float((log10(output) - log10(target)).abs().mean()) 67 | self.absrel = float((abs_diff / target).mean()) 68 | 69 | maxRatio = torch.max(output / target, target / output) 70 | self.delta1 = float((maxRatio < 1.25).float().mean()) 71 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 72 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 73 | self.data_time = 0 74 | self.gpu_time = 0 75 | 76 | inv_output = 1 / output 77 | inv_target = 1 / target 78 | abs_inv_diff = (inv_output - inv_target).abs() 79 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 80 | self.imae = float(abs_inv_diff.mean()) 81 | 82 | if loss: 83 | self.loss = loss 84 | 85 | 86 | class AverageMeter(object): 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.count = 0.0 92 | 93 | self.sum_irmse, self.sum_imae = 0, 0 94 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 95 | self.sum_absrel, self.sum_lg10 = 0, 0 96 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 97 | self.sum_data_time, self.sum_gpu_time = 0, 0 98 | 99 | self.sum_loss = 0 100 | 101 | def update(self, result, gpu_time, data_time, n=1): 102 | self.count += n 103 | 104 | self.sum_irmse += n * result.irmse 105 | self.sum_imae += n * result.imae 106 | self.sum_mse += n * result.mse 107 | self.sum_rmse += n * result.rmse 108 | self.sum_mae += n * result.mae 109 | self.sum_absrel += n * result.absrel 110 | self.sum_lg10 += n * result.lg10 111 | self.sum_delta1 += n * result.delta1 112 | self.sum_delta2 += n * result.delta2 113 | self.sum_delta3 += n * result.delta3 114 | self.sum_data_time += n * data_time 115 | self.sum_gpu_time += n * gpu_time 116 | 117 | self.sum_loss += n * result.loss 118 | 119 | def average(self): 120 | avg = Result() 121 | avg.update( 122 | self.sum_irmse / self.count, self.sum_imae / self.count, 123 | self.sum_mse / self.count, 
self.sum_rmse / self.count, self.sum_mae / self.count,
124 |             self.sum_absrel / self.count, self.sum_lg10 / self.count,
125 |             self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count,
126 |             self.sum_gpu_time / self.count, self.sum_data_time / self.count, self.sum_loss / self.count)
127 |         return avg
128 |
--------------------------------------------------------------------------------
/libs/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/3/2 18:09
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 | from libs.scheduler.scheduler import PolynomialLR, WarmUpLR
10 |
11 |
12 | def get_schedular(optimizer, args):
13 |     if args.scheduler.lower() == 'poly_lr':
14 |         scheduler = PolynomialLR(optimizer, max_iter=args.max_iter, decay_iter=args.decay_iter, gamma=args.gamma)
15 |     elif args.scheduler.lower() == 'reduce_lr':
16 |         scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=args.factor, patience=args.lr_patience)
17 |     else:
18 |         print('ERROR in get_schedular: no scheduler implemented named ', args.scheduler.lower())
19 |         raise NotImplementedError
20 |     return scheduler
21 |
22 |
23 | def do_schedule(args, scheduler, it=None, len=None, metrics=None):
24 |     if args.scheduler.lower() == 'poly_lr':
25 |         scheduler.step()
26 |
27 |     elif args.scheduler.lower() == 'reduce_lr':
28 |         if it is None or len is None or metrics is None:
29 |             print('ERROR in do_schedule: it, len, or metrics is None.')
30 |             raise RuntimeError
31 |         if it % len == 0:
32 |             epoch = it // len
33 |             scheduler.step(epoch=epoch, metrics=metrics)
34 |     else:
35 |         print('ERROR in do_schedule: no scheduler implemented named ', args.scheduler.lower())
36 |         raise NotImplementedError
37 |
--------------------------------------------------------------------------------
/libs/scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/15 17:54
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 | from torch.optim.lr_scheduler import _LRScheduler
9 |
10 |
11 | class PolynomialLR(_LRScheduler):
12 |     def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1):
13 |         self.decay_iter = decay_iter
14 |         self.max_iter = max_iter
15 |         self.gamma = gamma
16 |         super(PolynomialLR, self).__init__(optimizer, last_epoch)
17 |
18 |
19 |
20 |     def get_lr(self):
21 |
22 |
23 |         if self.last_epoch % self.decay_iter != 0 or self.last_epoch > self.max_iter:
24 |             # not a decay step (or past max_iter): keep the current learning rates
25 |
26 |             return [group["lr"] for group in self.optimizer.param_groups]
27 |         else:
28 |             factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma
29 |
30 |             return [base_lr * factor for base_lr in self.base_lrs]
31 |
32 |
33 | class WarmUpLR(_LRScheduler):
34 |     def __init__(
35 |             self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
36 |     ):
37 |         self.mode = mode
38 |         self.scheduler = scheduler
39 |         self.warmup_iters = warmup_iters
40 |         self.gamma = gamma
41 |         super(WarmUpLR,
self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | cold_lrs = self.scheduler.get_lr() 45 | 46 | if self.last_epoch < self.warmup_iters: 47 | if self.mode == "linear": 48 | alpha = self.last_epoch / float(self.warmup_iters) 49 | factor = self.gamma * (1 - alpha) + alpha 50 | 51 | elif self.mode == "constant": 52 | factor = self.gamma 53 | else: 54 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 55 | 56 | return [factor * base_lr for base_lr in cold_lrs] 57 | 58 | return cold_lrs -------------------------------------------------------------------------------- /libs/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-05-20 18:33 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : __init__.py.py 8 | """ -------------------------------------------------------------------------------- /libs/trainers/multi_gpu_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-05-21 15:07 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : multi_gpu_trainer.py 8 | """ 9 | 10 | import os 11 | import time 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from tqdm import tqdm 16 | 17 | from dataloaders import create_loader 18 | from libs import utils 19 | from libs.criterion import get_criteria 20 | from libs.metrics import AverageMeter, Result 21 | from libs.scheduler import get_schedular, do_schedule 22 | 23 | from network.libs.base.encoding import DataParallelModel, DataParallelCriterion, DataParallelEvaluation 24 | from network.libs.base.base_model import EvaluationModule 25 | 26 | 27 | class trainer(object): 28 | 29 | def __init__(self, opt, model, optimizer, start_iter, best_result=None): 30 | self.opt = opt 31 | 32 | self.model = DataParallelModel(model).float().cuda() 33 | self.optimizer = optimizer 34 | self.scheduler = get_schedular(optimizer, self.opt) 35 | 36 | self.criterion = DataParallelCriterion(get_criteria(self.opt)).cuda() 37 | self.evaluation = DataParallelEvaluation(EvaluationModule()).cuda() 38 | 39 | self.output_directory = utils.get_save_path(self.opt) 40 | self.best_txt = os.path.join(self.output_directory, 'best.txt') 41 | self.logger = utils.get_logger(self.output_directory) 42 | opt.write_config(self.output_directory) 43 | 44 | self.st_iter, self.ed_iter = start_iter, self.opt.max_iter 45 | 46 | self.train_loader = create_loader(self.opt, mode='train') 47 | self.eval_loader = create_loader(self.opt, mode='val') 48 | 49 | if best_result: 50 | self.best_result = best_result 51 | else: 52 | self.best_result = Result() 53 | self.best_result.set_to_worst() 54 | 55 | # train 56 | # self.iter_save = len(self.train_loader) 57 | self.iter_save = 50 58 | self.train_meter = AverageMeter() 59 | self.eval_meter = AverageMeter() 60 | self.metric = self.best_result.absrel 61 | self.result = Result() 62 | 63 | # batch size in each GPU 64 | self.ebt = self.opt.batch_size // torch.cuda.device_count() 65 | 66 | def train_iter(self, it): 67 | # Clear gradients (ready to accumulate) 68 | self.optimizer.zero_grad() 69 | 70 | end = time.time() 71 | 72 | try: 73 | input, target = next(loader_iter) 74 | except: 75 | loader_iter = iter(self.train_loader) 76 | input, target = next(loader_iter) 77 | 78 | # _target = self.dc.discretize(target) 79 | 80 | data_time = time.time() - end 81 | 
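# A note on the `try/except` above: `loader_iter` is a local name, so every
# call raises NameError and rebuilds the iterator. With shuffle=True this
# still draws a random batch, but it re-creates the DataLoader workers on
# every iteration. A sketch of a cheaper variant that persists the iterator
# across calls (an assumption-level improvement, not the repo's code):
#
#   try:
#       input, target = next(self.loader_iter)
#   except (AttributeError, StopIteration):
#       self.loader_iter = iter(self.train_loader)
#       input, target = next(self.loader_iter)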
# print('data time = ', data_time)
82 |
83 |         # compute pred
84 |         end = time.time()
85 |         pred = self.model(input)  # @wx: note the output format (a list of predictions)
86 |
87 |         loss = self.criterion(pred, target)
88 |
89 |         # print('## backward 0')
90 |         # print('## loss = ', loss.item())
91 |         loss.backward()  # compute gradient and do SGD step
92 |         # print('## backward 1')
93 |         self.optimizer.step()
94 |         # print('## backward 2')
95 |
96 |         gpu_time = time.time() - end
97 |
98 |         # measure accuracy and record loss in each GPU
99 |         # irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, silog, delta1, delta2, delta3
100 |         out = self.evaluation(pred, target)
101 |         self.result.update(out[0].item(), out[1].item(), out[2].item(), out[3].item(), out[4].item(),
102 |                            out[5].item(), out[6].item(), out[7].item(), out[8].item(),
103 |                            out[9].item(), out[10].item(), out[11].item(), 0, 0, loss=loss.item())
104 |
105 |         self.train_meter.update(self.result, gpu_time, data_time, input.size(0))
106 |
107 |         avg = self.train_meter.average()
108 |         if it % self.opt.print_freq == 0:
109 |             print('=> output: {}'.format(self.output_directory))
110 |             print('Train Iter: [{0}/{1}]\t'
111 |                   't_Data={data_time:.3f}({average.data_time:.3f}) '
112 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
113 |                   'Loss={Loss:.5f}({average.loss:.5f}) '
114 |                   'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
115 |                   'REL={result.absrel:.2f}({average.absrel:.2f}) '
116 |                   'Log10={result.lg10:.3f}({average.lg10:.3f}) '
117 |                   'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
118 |                   'Delta2={result.delta2:.3f}({average.delta2:.3f}) '
119 |                   'Delta3={result.delta3:.3f}({average.delta3:.3f})'.format(
120 |                 it, self.opt.max_iter, data_time=data_time,
121 |                 gpu_time=gpu_time, Loss=loss.item(), result=self.result, average=avg))
122 |
123 |             self.logger.add_scalar('Train/Loss', avg.loss, it)
124 |             self.logger.add_scalar('Train/RMSE', avg.rmse, it)
125 |             self.logger.add_scalar('Train/rel', avg.absrel, it)
126 |             self.logger.add_scalar('Train/Log10', avg.lg10, it)
127 |             self.logger.add_scalar('Train/Delta1', avg.delta1, it)
128 |             self.logger.add_scalar('Train/Delta2', avg.delta2, it)
129 |             self.logger.add_scalar('Train/Delta3', avg.delta3, it)
130 |
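# Index map for the `out` tensor used in train_iter above (and in eval below),
# following the in-code comment: out[0]=irmse, out[1]=imae, out[2]=mse,
# out[3]=rmse, out[4]=mae, out[5]=absrel, out[6]=squared_rel, out[7]=lg10,
# out[8]=silog, out[9]=delta1, out[10]=delta2, out[11]=delta3. Note this is
# the 12-metric output of the multi-GPU EvaluationModule, which does not match
# the 10-metric Result.update signature in libs/metrics.py.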
131 |     def eval(self, it):
132 |
133 |         skip = len(self.eval_loader) // 9  # save images every skip iters
134 |         self.eval_meter.reset()
135 |
136 |         end = time.time()
137 |
138 |         for i, (input, target) in enumerate(self.eval_loader):
139 |
140 |             data_time = time.time() - end
141 |
142 |             # compute output
143 |             end = time.time()
144 |             with torch.no_grad():
145 |                 pred = self.model(input)
146 |
147 |             gpu_time = time.time() - end
148 |
149 |             end = time.time()
150 |
151 |             # measure accuracy and record loss
152 |             # print(input.size(0))
153 |             out = self.evaluation(pred, target)
154 |             self.result.update(out[0].item(), out[1].item(), out[2].item(), out[3].item(), out[4].item(),
155 |                                out[5].item(), out[6].item(), out[7].item(), out[8].item(),
156 |                                out[9].item(), out[10].item(), out[11].item(), 0, 0)
157 |             self.eval_meter.update(self.result, gpu_time, data_time, input.size(0))
158 |
159 |             if i % skip == 0:
160 |                 pred = pred[0][0]  # the first output from the first GPU
161 |
162 |                 # save 8 images for visualization
163 |                 h, w = target.size(2), target.size(3)
164 |                 if h != pred.size(2) or w != pred.size(3):
165 |                     pred = F.interpolate(input=pred, size=(h, w), mode='bilinear', align_corners=True)
166 |
167 |                 data = input[0]
168 |                 target = target[0]
169 |                 pred = pred[0]
170 |
171 |                 if self.opt.modality == 'd':
172 |                     img_merge = None
173 |                 else:
174 |                     if self.opt.modality == 'rgb':
175 |                         rgb = data
176 |                     elif self.opt.modality == 'rgbd':
177 |                         rgb = data[:3, :, :]
178 |                         depth = data[3:, :, :]
179 |
180 |                     if i == 0:
181 |                         if self.opt.modality == 'rgbd':
182 |                             img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred)
183 |                         else:
184 |                             img_merge = utils.merge_into_row(rgb, target, pred)
185 |
186 |                     elif (i < 8 * skip) and (i % skip == 0):
187 |                         if self.opt.modality == 'rgbd':
188 |                             row = utils.merge_into_row_with_gt(rgb, depth, target, pred)
189 |                         else:
190 |                             row = utils.merge_into_row(rgb, target, pred)
191 |                         img_merge = utils.add_row(img_merge, row)
192 |                     elif i == 8 * skip:
193 |                         filename = self.output_directory + '/comparison_' + str(it) + '.png'
194 |                         utils.save_image(img_merge, filename)
195 |
196 |             if (i + 1) % self.opt.print_freq == 0:
197 |                 print('Test: [{0}/{1}]\t'
198 |                       't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
199 |                       'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
200 |                       'REL={result.absrel:.2f}({average.absrel:.2f}) '
201 |                       'Log10={result.lg10:.3f}({average.lg10:.3f}) '
202 |                       'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
203 |                       'Delta2={result.delta2:.3f}({average.delta2:.3f}) '
204 |                       'Delta3={result.delta3:.3f}({average.delta3:.3f}) '.format(
205 |                     i + 1, len(self.eval_loader), gpu_time=gpu_time, result=self.result,
206 |                     average=self.eval_meter.average()))
207 |
208 |         avg = self.eval_meter.average()
209 |
210 |         self.logger.add_scalar('Test/RMSE', avg.rmse, it)
211 |         self.logger.add_scalar('Test/rel', avg.absrel, it)
212 |         self.logger.add_scalar('Test/Log10', avg.lg10, it)
213 |         self.logger.add_scalar('Test/Delta1', avg.delta1, it)
214 |         self.logger.add_scalar('Test/Delta2', avg.delta2, it)
215 |         self.logger.add_scalar('Test/Delta3', avg.delta3, it)
216 |
217 |         print('\n*\n'
218 |               'RMSE={average.rmse:.3f}\n'
219 |               'Rel={average.absrel:.3f}\n'
220 |               'Log10={average.lg10:.3f}\n'
221 |               'Delta1={average.delta1:.3f}\n'
222 |               'Delta2={average.delta2:.3f}\n'
223 |               'Delta3={average.delta3:.3f}\n'
224 |               't_GPU={time:.3f}\n'.format(
225 |             average=avg, time=avg.gpu_time))
226 |
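# Hedged sketch of how this trainer is driven end-to-end. main.py imports
# get_model and get_train_params from network; the optimizer construction and
# hyperparameter values below are assumptions, not verbatim from main.py:
#
#   import torch
#   from network import get_model, get_train_params
#
#   model = get_model(opt)
#   optimizer = torch.optim.SGD(get_train_params(model, opt), lr=0.01,
#                               momentum=0.9, weight_decay=1e-4)
#   t = trainer(opt, model, optimizer, start_iter=1)
#   t.train_eval()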
227 |     def train_eval(self):
228 |
229 |         for it in tqdm(range(self.st_iter, self.ed_iter + 1), total=self.ed_iter - self.st_iter + 1,
230 |                        leave=False, dynamic_ncols=True):
231 |             self.model.train()
232 |             self.train_iter(it)
233 |
234 |             if it % self.iter_save == 0:
235 |                 self.model.eval()
236 |                 self.eval(it)
237 |
238 |                 self.metric = self.eval_meter.average().absrel
239 |                 train_avg = self.train_meter.average()
240 |                 eval_avg = self.eval_meter.average()
241 |
242 |                 self.logger.add_scalars('TrainVal/rmse',
243 |                                         {'train_rmse': train_avg.rmse, 'test_rmse': eval_avg.rmse}, it)
244 |                 self.logger.add_scalars('TrainVal/rel',
245 |                                         {'train_rel': train_avg.absrel, 'test_rel': eval_avg.absrel}, it)
246 |                 self.logger.add_scalars('TrainVal/lg10',
247 |                                         {'train_lg10': train_avg.lg10, 'test_lg10': eval_avg.lg10}, it)
248 |                 self.logger.add_scalars('TrainVal/Delta1',
249 |                                         {'train_d1': train_avg.delta1, 'test_d1': eval_avg.delta1}, it)
250 |                 self.logger.add_scalars('TrainVal/Delta2',
251 |                                         {'train_d2': train_avg.delta2, 'test_d2': eval_avg.delta2}, it)
252 |                 self.logger.add_scalars('TrainVal/Delta3',
253 |                                         {'train_d3': train_avg.delta3, 'test_d3': eval_avg.delta3}, it)
254 |                 self.train_meter.reset()
255 |
256 |                 # remember best rel and save checkpoint
257 |                 is_best = eval_avg.absrel < self.best_result.absrel
258 |                 if is_best:
259 |                     self.best_result = eval_avg
260 |                     with open(self.best_txt, 'w') as txtfile:
261 |                         txtfile.write(
262 |                             "Iter={}, rmse={:.3f}, rel={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
263 |                             "t_gpu={:.4f}".format(it, eval_avg.rmse, eval_avg.absrel, eval_avg.lg10,
264 |                                                   eval_avg.delta1, eval_avg.delta2, eval_avg.delta3, eval_avg.gpu_time))
265 |
266 |                 # save checkpoint for each epoch
267 |                 utils.save_checkpoint({
268 |                     'args': self.opt,
269 |                     'epoch': it,
270 |                     'state_dict': self.model.state_dict(),
271 |                     'best_result': self.best_result,
272 |                     'optimizer': self.optimizer,
273 |                 }, is_best, it, self.output_directory)
274 |
275 |             # Update learning rate
276 |             do_schedule(self.opt, self.scheduler, it=it, len=self.iter_save, metrics=self.metric)
277 |
278 |             # record the change of learning_rate
279 |             for i, param_group in enumerate(self.optimizer.param_groups):
280 |                 old_lr = float(param_group['lr'])
281 |                 self.logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
282 |
283 |         self.logger.close()
--------------------------------------------------------------------------------
/libs/trainers/single_gpu_trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time : 2019-05-20 18:34
5 | @Author : Wang Xin
6 | @Email : wangxin_buaa@163.com
7 | @File : single_gpu_trainer.py
8 | """
9 |
10 | import os
11 | import time
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from tqdm import tqdm
16 |
17 | from libs import utils
18 | from libs.criterion import get_criteria
19 | from libs.metrics import AverageMeter, Result
20 | from libs.scheduler import get_schedular, do_schedule
21 |
22 |
23 | class trainer(object):
24 |
25 |     def __init__(self, opt, model, optimizer, start_iter, best_result=None):
26 |         self.opt = opt
27 |         self.model = model.cuda()
28 |         self.optimizer = optimizer
29 |         self.scheduler = get_schedular(optimizer, self.opt)
30 |         self.criterion = get_criteria(self.opt)
31 |
32 |
33 |
34 |         self.output_directory = utils.get_save_path(self.opt)
35 |         self.best_txt = os.path.join(self.output_directory, 'best.txt')
36 |         self.logger = utils.get_logger(self.output_directory)
37 |         opt.write_config(self.output_directory)
38 |
39 |         self.st_iter, self.ed_iter = start_iter, self.opt.max_iter
40 |
41 |         # data loader
42 |         from dataloaders import create_loader
43 |         self.train_loader = create_loader(self.opt, mode='train')
44 |         self.eval_loader = create_loader(self.opt, mode='val')
45 |
46 |         if best_result:
47 |             self.best_result = best_result
48 |         else:
49 |             self.best_result = Result()
50 |             self.best_result.set_to_worst()
51 |
52 |         # train parameters
53 |         self.iter_save = len(self.train_loader)
54 |
55 |         self.train_meter = AverageMeter()
56 |         self.eval_meter = AverageMeter()
57 |         self.metric = self.best_result.absrel
58 |         self.result = Result()
59 |
60 |     def train_iter(self, it):
61 |         # Clear gradients (ready to accumulate)
62 |         self.optimizer.zero_grad()
63 |
64 |         end = time.time()
65 |
66 |         try:
67 |             input, target = next(loader_iter)
68 |         except:
69 |             loader_iter = iter(self.train_loader)
70 |             input, target = next(loader_iter)
71 |
72 |         input, target = input.cuda(), target.cuda()
73 |         data_time = time.time() - end
74 |
75 |         # compute pred
76 |         end = time.time()
77 |         pred = self.model(input)  # @wx: note the output format (a list of predictions)
78 |
79 |         loss = self.criterion(pred, target)
80 |         loss.backward()  # compute gradient and do SGD step
81 |         self.optimizer.step()
82 |
83 |         gpu_time = time.time() - end
84 |
85 |         # measure accuracy and record loss
86 |         self.result.set_to_worst()
87 |
self.result.evaluate(pred[0], target, loss.item()) 88 | self.train_meter.update(self.result, gpu_time, data_time, input.size(0)) 89 | 90 | avg = self.train_meter.average() 91 | if it % self.opt.print_freq == 0: 92 | print('=> output: {}'.format(self.output_directory)) 93 | print('Train Iter: [{0}/{1}]\t' 94 | 't_Data={data_time:.3f}({average.data_time:.3f}) ' 95 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 96 | 'Loss={Loss:.5f}({average.loss:.5f}) ' 97 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 98 | 'REL={result.absrel:.2f}({average.absrel:.2f}) ' 99 | 'Log10={result.lg10:.3f}({average.lg10:.3f}) ' 100 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' 101 | 'Delta2={result.delta2:.3f}({average.delta2:.3f}) ' 102 | 'Delta3={result.delta3:.3f}({average.delta3:.3f})'.format( 103 | it, self.opt.max_iter, data_time=data_time, 104 | gpu_time=gpu_time, Loss=loss.item(), result=self.result, average=avg)) 105 | 106 | self.logger.add_scalar('Train/Loss', avg.loss, it) 107 | self.logger.add_scalar('Train/RMSE', avg.rmse, it) 108 | self.logger.add_scalar('Train/rel', avg.absrel, it) 109 | self.logger.add_scalar('Train/Log10', avg.lg10, it) 110 | self.logger.add_scalar('Train/Delta1', avg.delta1, it) 111 | self.logger.add_scalar('Train/Delta2', avg.delta2, it) 112 | self.logger.add_scalar('Train/Delta3', avg.delta3, it) 113 | 114 | def eval(self, it): 115 | 116 | skip = len(self.eval_loader) // 9 # save images every skip iters 117 | self.eval_meter.reset() 118 | 119 | for i, (input, target) in enumerate(self.eval_loader): 120 | 121 | end = time.time() 122 | input, target = input.cuda(), target.cuda() 123 | 124 | data_time = time.time() - end 125 | 126 | # compute output 127 | end = time.time() 128 | with torch.no_grad(): 129 | pred = self.model(input) 130 | 131 | gpu_time = time.time() - end 132 | 133 | # measure accuracy and record loss 134 | # print(input.size(0)) 135 | 136 | self.result.set_to_worst() 137 | self.result.evaluate(pred[0], target) 138 | self.eval_meter.update(self.result, gpu_time, data_time, input.size(0)) 139 | 140 | if i % skip == 0: 141 | pred = pred[0] 142 | 143 | # save 8 images for visualization 144 | h, w = target.size(2), target.size(3) 145 | if h != pred.size(2) or w != pred.size(3): 146 | pred = F.interpolate(input=pred, size=(h, w), mode='bilinear', align_corners=True) 147 | 148 | data = input[0] 149 | target = target[0] 150 | pred = pred[0] 151 | 152 | if self.opt.modality == 'd': 153 | img_merge = None 154 | else: 155 | if self.opt.modality == 'rgb': 156 | rgb = data 157 | elif self.opt.modality == 'rgbd': 158 | rgb = data[:3, :, :] 159 | depth = data[3:, :, :] 160 | 161 | if i == 0: 162 | if self.opt.modality == 'rgbd': 163 | img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred) 164 | else: 165 | img_merge = utils.merge_into_row(rgb, target, pred) 166 | 167 | elif (i < 8 * skip) and (i % skip == 0): 168 | if self.opt.modality == 'rgbd': 169 | row = utils.merge_into_row_with_gt(rgb, depth, target, pred) 170 | else: 171 | row = utils.merge_into_row(rgb, target, pred) 172 | img_merge = utils.add_row(img_merge, row) 173 | elif i == 8 * skip: 174 | filename = self.output_directory + '/comparison_' + str(it) + '.png' 175 | utils.save_image(img_merge, filename) 176 | 177 | if (i + 1) % self.opt.print_freq == 0: 178 | print('Test: [{0}/{1}]\t' 179 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 180 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 181 | 'REL={result.absrel:.2f}({average.absrel:.2f}) ' 182 | 
'Log10={result.lg10:.3f}({average.lg10:.3f}) '
183 |                       'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
184 |                       'Delta2={result.delta2:.3f}({average.delta2:.3f}) '
185 |                       'Delta3={result.delta3:.3f}({average.delta3:.3f}) '.format(
186 |                     i + 1, len(self.eval_loader), gpu_time=gpu_time, result=self.result,
187 |                     average=self.eval_meter.average()))
188 |
189 |         avg = self.eval_meter.average()
190 |
191 |         self.logger.add_scalar('Test/RMSE', avg.rmse, it)
192 |         self.logger.add_scalar('Test/rel', avg.absrel, it)
193 |         self.logger.add_scalar('Test/Log10', avg.lg10, it)
194 |         self.logger.add_scalar('Test/Delta1', avg.delta1, it)
195 |         self.logger.add_scalar('Test/Delta2', avg.delta2, it)
196 |         self.logger.add_scalar('Test/Delta3', avg.delta3, it)
197 |
198 |         print('\n*\n'
199 |               'RMSE={average.rmse:.3f}\n'
200 |               'Rel={average.absrel:.3f}\n'
201 |               'Log10={average.lg10:.3f}\n'
202 |               'Delta1={average.delta1:.3f}\n'
203 |               'Delta2={average.delta2:.3f}\n'
204 |               'Delta3={average.delta3:.3f}\n'
205 |               't_GPU={time:.3f}\n'.format(
206 |             average=avg, time=avg.gpu_time))
207 |
208 |     def train_eval(self):
209 |
210 |         for it in tqdm(range(self.st_iter, self.ed_iter + 1), total=self.ed_iter - self.st_iter + 1,
211 |                        leave=False, dynamic_ncols=True):
212 |             self.model.train()
213 |             self.train_iter(it)
214 |
215 |             # save the change of learning_rate
216 |             for i, param_group in enumerate(self.optimizer.param_groups):
217 |                 old_lr = float(param_group['lr'])
218 |                 self.logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
219 |
220 |             if it % self.iter_save == 0:
221 |                 self.model.eval()
222 |                 self.eval(it)
223 |
224 |                 self.metric = self.eval_meter.average().absrel
225 |                 train_avg = self.train_meter.average()
226 |                 eval_avg = self.eval_meter.average()
227 |
228 |                 self.logger.add_scalars('TrainVal/rmse',
229 |                                         {'train_rmse': train_avg.rmse, 'test_rmse': eval_avg.rmse}, it)
230 |                 self.logger.add_scalars('TrainVal/rel',
231 |                                         {'train_rel': train_avg.absrel, 'test_rel': eval_avg.absrel}, it)
232 |                 self.logger.add_scalars('TrainVal/lg10',
233 |                                         {'train_lg10': train_avg.lg10, 'test_lg10': eval_avg.lg10}, it)
234 |                 self.logger.add_scalars('TrainVal/Delta1',
235 |                                         {'train_d1': train_avg.delta1, 'test_d1': eval_avg.delta1}, it)
236 |                 self.logger.add_scalars('TrainVal/Delta2',
237 |                                         {'train_d2': train_avg.delta2, 'test_d2': eval_avg.delta2}, it)
238 |                 self.logger.add_scalars('TrainVal/Delta3',
239 |                                         {'train_d3': train_avg.delta3, 'test_d3': eval_avg.delta3}, it)
240 |
241 |                 self.train_meter.reset()
242 |
243 |                 # remember best rel and save checkpoint
244 |                 is_best = eval_avg.absrel < self.best_result.absrel
245 |                 if is_best:
246 |                     self.best_result = eval_avg
247 |                     with open(self.best_txt, 'w') as txtfile:
248 |                         txtfile.write(
249 |                             "Iter={}, rmse={:.3f}, rel={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
250 |                             "t_gpu={:.4f}".format(it, eval_avg.rmse, eval_avg.absrel, eval_avg.lg10,
251 |                                                   eval_avg.delta1, eval_avg.delta2, eval_avg.delta3, eval_avg.gpu_time))
252 |
253 |                 # save checkpoint for each epoch
254 |                 utils.save_checkpoint({
255 |                     'args': self.opt,
256 |                     'epoch': it,
257 |                     'state_dict': self.model.state_dict(),
258 |                     'best_result': self.best_result,
259 |                     'optimizer': self.optimizer,
260 |                 }, is_best, it, self.output_directory)
261 |
262 |             # Update learning rate
263 |             do_schedule(self.opt, self.scheduler, it=it, len=self.iter_save, metrics=self.metric)
264 |
265 |         self.logger.close()
266 |
--------------------------------------------------------------------------------
/libs/utils.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/10/21 20:57 3 | # @Author : Wang Xin 4 | # @Email : wangxin_buaa@163.com 5 | 6 | import glob 7 | import os 8 | import shutil 9 | import socket 10 | import torch 11 | from datetime import datetime 12 | from tensorboardX import SummaryWriter 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from PIL import Image 17 | 18 | cmap = plt.cm.jet 19 | 20 | 21 | def get_save_path(args): 22 | save_dir_root = os.getcwd() 23 | save_dir_root = os.path.join(save_dir_root, 'result', args.dataset, args.arch) 24 | if args.restore: 25 | return args.restore[:-len(args.restore.split('/')[-1])] 26 | else: 27 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 28 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 29 | 30 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 31 | 32 | if not os.path.exists(save_dir): 33 | os.makedirs(save_dir) 34 | return save_dir 35 | 36 | 37 | def get_logger(output_directory): 38 | log_path = os.path.join(output_directory, 'logs', 39 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 40 | if os.path.isdir(log_path): 41 | shutil.rmtree(log_path) 42 | os.makedirs(log_path) 43 | logger = SummaryWriter(log_path) 44 | return logger 45 | 46 | 47 | def write_config_file(args, output_directory): 48 | config_txt = os.path.join(output_directory, 'config.txt') 49 | 50 | # write training parameters to config file 51 | if not os.path.exists(config_txt): 52 | with open(config_txt, 'w') as txtfile: 53 | args_ = vars(args) 54 | args_str = '' 55 | for k, v in args_.items(): 56 | args_str = args_str + str(k) + ':' + str(v) + ',\t\n' 57 | txtfile.write(args_str) 58 | 59 | 60 | # save checkpoint 61 | def save_checkpoint(state, is_best, epoch, output_directory): 62 | checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar') 63 | torch.save(state, checkpoint_filename) 64 | if is_best: 65 | best_filename = os.path.join(output_directory, 'model_best.pth.tar') 66 | shutil.copyfile(checkpoint_filename, best_filename) 67 | 68 | 69 | def colored_depthmap(depth, d_min=None, d_max=None): 70 | if d_min is None: 71 | d_min = np.min(depth) 72 | if d_max is None: 73 | d_max = np.max(depth) 74 | depth_relative = (depth - d_min) / (d_max - d_min) 75 | return 255 * cmap(depth_relative)[:, :, :3] # H, W, C 76 | 77 | 78 | def merge_into_row(input, depth_target, depth_pred): 79 | # rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0)) # H, W, C 80 | rgb = np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0)) 81 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 82 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 83 | 84 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) 85 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) 86 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 87 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 88 | 89 | # print(rgb.shape, depth_target_col.shape, depth_pred_col.shape) 90 | img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) 91 | 92 | return img_merge 93 | 94 | 95 | def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred): 96 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0)) # H, W, C 97 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) 98 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 
99 |     depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
100 | 
101 |     d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
102 |     d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
103 |     depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
104 |     depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
105 |     depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
106 | 
107 |     img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
108 | 
109 |     return img_merge
110 | 
111 | 
112 | def merge_rgb_depth_into_row(input, depth):
113 |     rgb = np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))
114 |     depth_pred_cpu = np.squeeze(depth.data.cpu().numpy())
115 | 
116 |     d_min = np.min(depth_pred_cpu)
117 |     d_max = np.max(depth_pred_cpu)
118 |     depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
119 | 
120 |     # print(rgb.shape, depth_target_col.shape, depth_pred_col.shape)
121 |     img_merge = np.hstack([rgb, depth_pred_col])
122 | 
123 |     return img_merge
124 | 
125 | 
126 | def add_row(img_merge, row):
127 |     return np.vstack([img_merge, row])
128 | 
129 | 
130 | def save_image(img_merge, filename):
131 |     img_merge = Image.fromarray(img_merge.astype('uint8'))
132 |     img_merge.save(filename)
133 | 
134 | 
135 | # write feature maps
136 | fmap = plt.cm.jet
137 | 
138 | 
139 | def feature_map(feature, min=None, max=None):
140 |     if min is None:
141 |         min = np.min(feature)
142 |     if max is None:
143 |         max = np.max(feature)
144 | 
145 |     relative = (feature - min) / (max - min)
146 | 
147 |     return 255 * fmap(relative)[:, :, :3]
148 | 
149 | 
150 | def merge_features_into_row(features, features_num=9):
151 |     features_cpu = np.squeeze(features.cpu().numpy())
152 |     # print(features_cpu.shape)
153 |     f_min = np.min(features_cpu)
154 |     f_max = np.max(features_cpu)
155 | 
156 |     f = []
157 | 
158 |     for i in range(features_num):
159 |         f.append(feature_map(features_cpu[i], min=f_min, max=f_max))
160 |     img_merge = np.hstack(f)
161 | 
162 |     return img_merge
163 | 
164 | 
165 | def add_features_row(img_merge, row):
166 |     return np.vstack([img_merge, row])
167 | 
168 | 
169 | def save_features_map(img_merge, name):
170 |     img_merge = Image.fromarray(img_merge.astype('uint8'))
171 |     img_merge.save(name)
172 | 
173 | 
174 | def save_features(features, filename, features_num=9):
175 |     features_cpu = np.squeeze(features.cpu().numpy())
176 |     # print(features_cpu.shape)
177 |     f_min = np.min(features_cpu)
178 |     f_max = np.max(features_cpu)
179 | 
180 |     f = []
181 | 
182 |     for i in range(features_num):
183 |         f.append(feature_map(features_cpu[i], min=f_min, max=f_max))
184 | 
185 |     # print('f shape:', f[0].shape)
186 | 
187 |     img_merge = np.hstack(f)
188 | 
189 |     # print('img_merge shape:', img_merge.shape)
190 | 
191 |     img_merge = Image.fromarray(img_merge.astype('uint8'))
192 |     img_merge.save(filename)
193 | 
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time    : 2019-05-20 16:52
5 | @Author  : Wang Xin
6 | @Email   : wangxin_buaa@163.com
7 | @File    : main.py
8 | """
9 | 
10 | import os
11 | import random
12 | import numpy as np
13 | import torch
14 | from torch.backends import cudnn
15 | 
16 | 
17 | from network import get_model, get_train_params
18 | from options import Options
19 | 
20 | 
21 | def main():
22 |     opt = Options()
23 |     opt.parse_command()
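    # opt now carries the parsed command-line settings used below and in libs/utils.py
    # (dataset, arch, modality, gpu, manual_seed, restore, lr, momentum, weight_decay).
    # A minimal usage sketch of the feature-map helpers defined above (sizes assumed):
    #   feats = torch.randn(1, 9, 57, 76)      # hypothetical 1xCxHxW activation, C >= 9
    #   save_features(feats, 'features.png')   # tiles the first features_num channels in a row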
24 |     opt.print_items()
25 | 
26 |     # if a GPU id is given, expose only that device (single-GPU mode)
27 |     if opt.gpu:
28 |         print('Single GPU Mode.')
29 |         os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
30 | 
31 |     # set random seed
32 |     torch.manual_seed(opt.manual_seed)
33 |     torch.cuda.manual_seed(opt.manual_seed)
34 |     np.random.seed(opt.manual_seed)
35 |     random.seed(opt.manual_seed)
36 | 
37 |     cudnn.benchmark = True
38 | 
39 |     if torch.cuda.device_count() > 1:
40 |         print('Multi-GPU Mode.')
41 |         print("Let's use ", torch.cuda.device_count(), " GPUs!")
42 |     else:
43 |         print('Single GPU Mode.')
44 |         print("Let's use GPU:", opt.gpu)
45 | 
46 |     if opt.restore:
47 |         assert os.path.isfile(opt.restore), \
48 |             "=> no checkpoint found at '{}'".format(opt.restore)
49 |         print("=> loading checkpoint '{}'".format(opt.restore))
50 |         checkpoint = torch.load(opt.restore)
51 | 
52 |         start_iter = checkpoint['epoch'] + 1
53 |         best_result = checkpoint['best_result']
54 |         optimizer = checkpoint['optimizer']  # superseded by the SGD optimizer created below
55 | 
56 |         model = get_model(opt)
57 |         model.load_state_dict(checkpoint['state_dict'])
58 | 
59 |         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
60 |         del checkpoint  # clear memory
61 |         # del model_dict
62 |         torch.cuda.empty_cache()
63 |     else:
64 |         print("=> creating Model")
65 |         model = get_model(opt)
66 | 
67 |         print("=> model created.")
68 |         start_iter = 1
69 |         best_result = None
70 | 
71 |     # different modules may use different learning rates (see get_train_params)
72 |     train_params = get_train_params(opt, model)
73 |     optimizer = torch.optim.SGD(train_params, lr=opt.lr, momentum=opt.momentum,
74 |                                 weight_decay=opt.weight_decay)
75 | 
76 |     if torch.cuda.device_count() == 1:
77 |         from libs.trainers import single_gpu_trainer
78 |         trainer = single_gpu_trainer.trainer(opt, model, optimizer, start_iter, best_result)
79 |         trainer.train_eval()
80 |     else:
81 |         from libs.trainers import multi_gpu_trainer
82 |         trainer = multi_gpu_trainer.trainer(opt, model, optimizer, start_iter, best_result)
83 |         trainer.train_eval()
84 | 
85 | 
86 | if __name__ == '__main__':
87 |     main()
88 | 
-------------------------------------------------------------------------------- /network/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time    : 2019-05-19 15:29
5 | @Author  : Wang Xin
6 | @Email   : wangxin_buaa@163.com
7 | @File    : __init__.py
8 | """
9 | 
10 | 
11 | def get_model(opt):
12 |     if opt.dataset.lower() == 'kitti':
13 |         raise NotImplementedError
14 |     elif opt.dataset.lower() == 'nyu':
15 |         if opt.modality.lower() == 'rgb':
16 |             raise NotImplementedError
17 |         elif opt.modality.lower() == 'rgbd':
18 |             if opt.arch.lower() == 'unet':
19 |                 from network.unet_ours import resnet50
20 |                 return resnet50(pretrained=True)
21 |             else:
22 |                 raise NotImplementedError
23 |         else:
24 |             raise NotImplementedError
25 | 
26 | 
27 | def get_train_params(opt, model):
28 |     return model.parameters()  # currently a single parameter group for the whole model
-------------------------------------------------------------------------------- /network/libs/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time    : 2019/6/30 22:12
4 | @Author  : Wang Xin
5 | @Email   : wangxin_buaa@163.com
6 | """
-------------------------------------------------------------------------------- /network/libs/base/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time    : 2019-06-17 20:45
5 | 
@Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : __init__.py.py 8 | """ -------------------------------------------------------------------------------- /network/libs/base/base_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-05-19 13:29 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : base_model.py 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | import sys 15 | 16 | 17 | """ 18 | The methods for evaluation packaged by nn.Module 19 | """ 20 | 21 | 22 | class EvaluationModule(nn.Module): 23 | 24 | def __init__(self, depth_coefficients=None): 25 | super(EvaluationModule, self).__init__() 26 | self.dc = depth_coefficients 27 | 28 | def forward(self, pred, target): 29 | pred = pred[0] # pred is a list [pred, others(usually, multi scale pred)] 30 | h, w = target.size(2), target.size(3) 31 | 32 | # print('# eval 0') 33 | 34 | if self.dc: 35 | pred = F.softmax(pred, dim=1) 36 | pred = self.dc.serialize(pred) 37 | 38 | # print('# eval 1') 39 | # print('# eval:', pred.shape, target.shape) 40 | 41 | if h != pred.size(2) or w != pred.size(3): 42 | output = F.interpolate(input=pred, size=(h, w), mode='bilinear', align_corners=True) 43 | else: 44 | output = pred 45 | 46 | valid_mask = target > 0 47 | output = output[valid_mask] 48 | target = target[valid_mask] 49 | 50 | abs_diff = (output - target).abs() 51 | 52 | mse = (torch.pow(abs_diff, 2)).mean() 53 | rmse = torch.sqrt(mse) 54 | mae = abs_diff.mean() 55 | lg10 = (torch.log10(output) - torch.log10(target)).abs().mean() 56 | absrel = (abs_diff / target).mean() 57 | 58 | maxRatio = torch.max(output / target, target / output) 59 | delta1 = (maxRatio < 1.25).float().mean() 60 | delta2 = (maxRatio < 1.25 ** 2).float().mean() 61 | delta3 = (maxRatio < 1.25 ** 3).float().mean() 62 | # data_time = 0 63 | # gpu_time = 0 64 | 65 | inv_output = 1 / output 66 | inv_target = 1 / target 67 | abs_inv_diff = (inv_output - inv_target).abs() 68 | irmse = torch.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 69 | imae = abs_inv_diff.mean() 70 | 71 | # return irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, silog, delta1, delta2, delta3 72 | out = torch.tensor([irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3], 73 | dtype=torch.float, device=torch.cuda.current_device()) 74 | 75 | # print('# eval 2') 76 | return out 77 | 78 | -------------------------------------------------------------------------------- /network/libs/base/encoding.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Encoding Data Parallel""" 12 | import threading 13 | import functools 14 | import torch 15 | from torch.autograd import Variable, Function 16 | import torch.cuda.comm as comm 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 20 | 21 | 
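# A minimal usage sketch of EvaluationModule from base_model.py above, assuming a
# prediction wrapped in a list and a Bx1xHxW target in which 0 marks invalid pixels:
#   evaluator = EvaluationModule()
#   metrics = evaluator([pred], target)
#   rmse, delta1 = metrics[3].item(), metrics[7].item()   # ordering as in its forward()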
torch_ver = torch.__version__[:3]
22 | 
23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 'DataParallelEvaluation',
24 |            'patch_replication_callback']
25 | 
26 | 
27 | def allreduce(*inputs):
28 |     """Cross-GPU all-reduce autograd operation for calculating mean and
29 |     variance in SyncBN.
30 |     """
31 |     return AllReduce.apply(*inputs)
32 | 
33 | 
34 | class AllReduce(Function):
35 |     @staticmethod
36 |     def forward(ctx, num_inputs, *inputs):
37 |         ctx.num_inputs = num_inputs
38 |         ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
39 |         inputs = [inputs[i:i + num_inputs]
40 |                   for i in range(0, len(inputs), num_inputs)]
41 |         # sort before reduce sum
42 |         inputs = sorted(inputs, key=lambda i: i[0].get_device())
43 |         results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
44 |         outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
45 |         return tuple([t for tensors in outputs for t in tensors])
46 | 
47 |     @staticmethod
48 |     def backward(ctx, *inputs):
49 |         inputs = [i.data for i in inputs]
50 |         inputs = [inputs[i:i + ctx.num_inputs]
51 |                   for i in range(0, len(inputs), ctx.num_inputs)]
52 |         results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
53 |         outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
54 |         return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
55 | 
56 | 
57 | class Reduce(Function):
58 |     @staticmethod
59 |     def forward(ctx, *inputs):
60 |         ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
61 |         inputs = sorted(inputs, key=lambda i: i.get_device())
62 |         return comm.reduce_add(inputs)
63 | 
64 |     @staticmethod
65 |     def backward(ctx, gradOutput):
66 |         return Broadcast.apply(ctx.target_gpus, gradOutput)
67 | 
68 | 
69 | class DataParallelModel(DataParallel):
70 |     """Implements data parallelism at the module level.
71 | 
72 |     This container parallelizes the application of the given module by
73 |     splitting the input across the specified devices by chunking in the
74 |     batch dimension.
75 |     In the forward pass, the module is replicated on each device,
76 |     and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module.
77 |     Note that the outputs are not gathered; please use the compatible
78 |     :class:`encoding.parallel.DataParallelCriterion`.
79 | 
80 |     The batch size should be larger than the number of GPUs used. It should
81 |     also be an integer multiple of the number of GPUs so that each chunk is
82 |     the same size (so that each GPU processes the same number of samples).
83 | 
84 |     Args:
85 |         module: module to be parallelized
86 |         device_ids: CUDA devices (default: all devices)
87 | 
88 |     Reference:
89 |         Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
90 |         Amit Agrawal. "Context Encoding for Semantic Segmentation."
91 |         *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
92 | 
93 |     Example::
94 | 
95 |         >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
96 |         >>> y = net(x)
97 |     """
98 | 
99 |     def gather(self, outputs, output_device):
100 |         return outputs
101 | 
102 |     def replicate(self, module, device_ids):
103 |         modules = super(DataParallelModel, self).replicate(module, device_ids)
104 |         execute_replication_callbacks(modules)
105 |         return modules
106 | 
107 | 
108 | class DataParallelCriterion(DataParallel):
109 |     """
110 |     Calculate loss on multiple GPUs, which balances the memory usage for
111 |     Semantic Segmentation.
112 | 
113 |     The targets are split across the specified devices by chunking in
114 |     the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
115 | 
116 |     Reference:
117 |         Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
118 |         Amit Agrawal. "Context Encoding for Semantic Segmentation."
119 |         *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
120 | 
121 |     Example::
122 | 
123 |         >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
124 |         >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
125 |         >>> y = net(x)
126 |         >>> loss = criterion(y, target)
127 |     """
128 | 
129 |     def forward(self, inputs, *targets, **kwargs):
130 |         # inputs should already be scattered
131 |         # scattering the targets instead
132 |         if not self.device_ids:
133 |             return self.module(inputs, *targets, **kwargs)
134 |         targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
135 |         if len(self.device_ids) == 1:
136 |             return self.module(inputs, *targets[0], **kwargs[0])
137 |         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
138 |         outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
139 |         return Reduce.apply(*outputs) / len(outputs)
140 |         # return self.gather(outputs, self.output_device).mean()
141 | 
142 | 
143 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
144 |     assert len(modules) == len(inputs)
145 |     assert len(targets) == len(inputs)
146 |     if kwargs_tup:
147 |         assert len(modules) == len(kwargs_tup)
148 |     else:
149 |         kwargs_tup = ({},) * len(modules)
150 |     if devices is not None:
151 |         assert len(modules) == len(devices)
152 |     else:
153 |         devices = [None] * len(modules)
154 | 
155 |     lock = threading.Lock()
156 |     results = {}
157 |     if torch_ver != "0.3":
158 |         grad_enabled = torch.is_grad_enabled()
159 | 
160 |     def _worker(i, module, input, target, kwargs, device=None):
161 |         if torch_ver != "0.3":
162 |             torch.set_grad_enabled(grad_enabled)
163 |         if device is None:
164 |             device = get_a_var(input).get_device()
165 |         try:
166 |             if not isinstance(input, tuple):
167 |                 input = (input,)
168 |             with torch.cuda.device(device):
169 |                 output = module(*(input + target), **kwargs)
170 |             with lock:
171 |                 results[i] = output
172 |         except Exception as e:
173 |             with lock:
174 |                 results[i] = e
175 | 
176 |     if len(modules) > 1:
177 |         threads = [threading.Thread(target=_worker,
178 |                                     args=(i, module, input, target,
179 |                                           kwargs, device), )
180 |                    for i, (module, input, target, kwargs, device) in
181 |                    enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
182 | 
183 |         for thread in threads:
184 |             thread.start()
185 |         for thread in threads:
186 |             thread.join()
187 |     else:
188 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
189 | 
190 |     outputs = []
191 |     for i in range(len(inputs)):
192 |         output = results[i]
193 |         if isinstance(output, Exception):
194 |             raise output
195 |         outputs.append(output)
196 |     return outputs
197 | 
198 | 
199 | ###########################################################################
200 | # Adapted from Synchronized-BatchNorm-PyTorch.
201 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
202 | #
203 | class CallbackContext(object):
204 |     pass
205 | 
206 | 
207 | def execute_replication_callbacks(modules):
208 |     """
209 |     Execute a replication callback `__data_parallel_replicate__` on each module created
210 |     by original replication.
211 | 212 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 213 | 214 | Note that, as all modules are isomorphism, we assign each sub-module with a context 215 | (shared among multiple copies of this module on different devices). 216 | Through this context, different copies can share some information. 217 | 218 | We guarantee that the callback on the master copy (the first copy) will be called ahead 219 | of calling the callback of any slave copies. 220 | """ 221 | master_copy = modules[0] 222 | nr_modules = len(list(master_copy.modules())) 223 | ctxs = [CallbackContext() for _ in range(nr_modules)] 224 | 225 | for i, module in enumerate(modules): 226 | for j, m in enumerate(module.modules()): 227 | if hasattr(m, '__data_parallel_replicate__'): 228 | m.__data_parallel_replicate__(ctxs[j], i) 229 | 230 | 231 | def patch_replication_callback(data_parallel): 232 | """ 233 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 234 | Useful when you have customized `DataParallel` implementation. 235 | 236 | Examples: 237 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 238 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 239 | > patch_replication_callback(sync_bn) 240 | # this is equivalent to 241 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 242 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 243 | """ 244 | 245 | assert isinstance(data_parallel, DataParallel) 246 | 247 | old_replicate = data_parallel.replicate 248 | 249 | @functools.wraps(old_replicate) 250 | def new_replicate(module, device_ids): 251 | modules = old_replicate(module, device_ids) 252 | execute_replication_callbacks(modules) 253 | return modules 254 | 255 | data_parallel.replicate = new_replicate 256 | 257 | 258 | """ 259 | Added by WangXin 260 | Evaluation Using Multi-GPUs mode. 
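
Example (sketch; assumes an evaluation ``nn.Module`` such as EvaluationModule
from base_model.py, with device ids for illustration)::

    >>> net = DataParallelModel(model, device_ids=[0, 1])
    >>> evaluation = DataParallelEvaluation(EvaluationModule(), device_ids=[0, 1])
    >>> preds = net(x)                       # outputs stay scattered, one chunk per GPU
    >>> metrics = evaluation(preds, target)  # per-GPU metric tensors reduced to one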
261 | """
262 | 
263 | 
264 | class DataParallelEvaluation(DataParallel):
265 | 
266 |     def forward(self, inputs, *targets, **kwargs):
267 |         # inputs should already be scattered
268 |         # scattering the targets instead
269 |         if not self.device_ids:
270 |             return self.module(inputs, *targets, **kwargs)
271 |         targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
272 |         if len(self.device_ids) == 1:
273 |             return self.module(inputs, *targets[0], **kwargs[0])
274 |         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
275 |         outputs = _evaluation_parallel_apply(replicas, inputs, targets, kwargs)
276 |         return Reduce.apply(*outputs) / len(outputs)
277 | 
278 | 
279 | def _evaluation_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
280 |     assert len(modules) == len(inputs)
281 |     assert len(targets) == len(inputs)
282 |     if kwargs_tup:
283 |         assert len(modules) == len(kwargs_tup)
284 |     else:
285 |         kwargs_tup = ({},) * len(modules)
286 |     if devices is not None:
287 |         assert len(modules) == len(devices)
288 |     else:
289 |         devices = [None] * len(modules)
290 | 
291 |     lock = threading.Lock()
292 |     results = {}
293 |     if torch_ver != "0.3":
294 |         grad_enabled = torch.is_grad_enabled()
295 | 
296 |     def _worker(i, module, input, target, kwargs, device=None):
297 |         if torch_ver != "0.3":
298 |             torch.set_grad_enabled(grad_enabled)
299 |         if device is None:
300 |             device = get_a_var(input).get_device()
301 |         try:
302 |             if not isinstance(input, tuple):
303 |                 input = (input,)
304 |             with torch.cuda.device(device):
305 |                 output = module(*(input + target), **kwargs)
306 |             with lock:
307 |                 results[i] = output
308 |         except Exception as e:
309 |             with lock:
310 |                 results[i] = e
311 | 
312 |     if len(modules) > 1:
313 |         threads = [threading.Thread(target=_worker,
314 |                                     args=(i, module, input, target,
315 |                                           kwargs, device), )
316 |                    for i, (module, input, target, kwargs, device) in
317 |                    enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
318 | 
319 |         for thread in threads:
320 |             thread.start()
321 |         for thread in threads:
322 |             thread.join()
323 |     else:
324 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
325 | 
326 |     outputs = []
327 |     for i in range(len(inputs)):
328 |         output = results[i]
329 |         if isinstance(output, Exception):
330 |             raise output
331 |         outputs.append(output)
332 |     return outputs
-------------------------------------------------------------------------------- /network/libs/base/operation.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time    : 2019-06-24 22:08
5 | @Author  : Wang Xin
6 | @Email   : wangxin_buaa@163.com
7 | @File    : operation.py
8 | """
9 | 
10 | import torch as th
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | 
14 | 
15 | class BatchNorm2d_Relu(nn.Module):
16 | 
17 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
18 |                  track_running_stats=True, activation_type='leaky_relu'):
19 |         super(BatchNorm2d_Relu, self).__init__()
20 |         self.batchnorm = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
21 |                                         track_running_stats=track_running_stats)
22 | 
23 |         if activation_type.lower() == 'relu':
24 |             self.activation = nn.ReLU()
25 |         elif activation_type.lower() == 'leaky_relu':
26 |             self.activation = nn.LeakyReLU()
27 |         else:
28 |             raise NotImplementedError
29 | 
30 |     def forward(self, input):
31 |         return self.activation(self.batchnorm(input))
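32 | 
33 | 
34 | # A minimal smoke test sketch for BatchNorm2d_Relu; sizes are illustrative only.
35 | if __name__ == '__main__':
36 |     x = th.randn(2, 16, 8, 8)
37 |     bn_relu = BatchNorm2d_Relu(16, activation_type='relu')
38 |     assert bn_relu(x).shape == x.shape  # BN plus activation preserves the input shape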
-------------------------------------------------------------------------------- /network/libs/base/pac.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-07-02 10:37 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : adaptive_conv.py 8 | """ 9 | 10 | 11 | """ 12 | Reference: https://github.com/NVlabs/pacnet 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.autograd.function import Function, once_differentiable 19 | from torch.nn.modules.utils import _pair 20 | from torch._thnn import type2backend 21 | 22 | import math 23 | from numbers import Number 24 | 25 | try: 26 | import pyinn as P 27 | 28 | has_pyinn = True 29 | except ImportError: 30 | P = None 31 | has_pyinn = False 32 | pass 33 | 34 | 35 | def nd2col(input_nd, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, transposed=False, 36 | use_pyinn_if_possible=False): 37 | """ 38 | Shape: 39 | - Input: :math:`(N, C, L_{in})` 40 | - Output: :math:`(N, C, *kernel_size, *L_{out})` where 41 | :math:`L_{out} = floor((L_{in} + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)` for non-transposed 42 | :math:`L_{out} = (L_{in} - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1 + output_padding` for transposed 43 | """ 44 | n_dims = len(input_nd.shape[2:]) 45 | kernel_size = (kernel_size,) * n_dims if isinstance(kernel_size, Number) else kernel_size 46 | stride = (stride,) * n_dims if isinstance(stride, Number) else stride 47 | padding = (padding,) * n_dims if isinstance(padding, Number) else padding 48 | output_padding = (output_padding,) * n_dims if isinstance(output_padding, Number) else output_padding 49 | dilation = (dilation,) * n_dims if isinstance(dilation, Number) else dilation 50 | 51 | if transposed: 52 | assert n_dims == 2, 'Only 2D is supported for fractional strides.' 
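        # Note: in the plain non-transposed case nd2col is F.unfold plus a reshape;
        # e.g. a (1, 3, 8, 8) input with kernel_size=3, padding=1 yields a
        # (1, 3, 3, 3, 8, 8) tensor of 3x3 neighborhoods, which conv2d() below weights
        # per pixel with the adaptive kernel before summing over the window dimensions.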
53 | w_one = input_nd.new_ones(1, 1, 1, 1) 54 | pad = [(k - 1) * d - p for (k, d, p) in zip(kernel_size, dilation, padding)] 55 | input_nd = F.conv_transpose2d(input_nd, w_one, stride=stride) 56 | input_nd = F.pad(input_nd, (pad[1], pad[1] + output_padding[1], pad[0], pad[0] + output_padding[0])) 57 | stride = _pair(1) 58 | padding = _pair(0) 59 | 60 | (bs, nch), in_sz = input_nd.shape[:2], input_nd.shape[2:] 61 | out_sz = tuple([((i + 2 * p - d * (k - 1) - 1) // s + 1) 62 | for (i, k, d, p, s) in zip(in_sz, kernel_size, dilation, padding, stride)]) 63 | # Use PyINN if possible (about 15% faster) TODO confirm the speed-up 64 | if n_dims == 2 and dilation == 1 and has_pyinn and torch.cuda.is_available() and use_pyinn_if_possible: 65 | output = P.im2col(input_nd, kernel_size, stride, padding) 66 | else: 67 | output = F.unfold(input_nd, kernel_size, dilation, padding, stride) 68 | out_shape = (bs, nch) + tuple(kernel_size) + out_sz 69 | output = output.view(*out_shape).contiguous() 70 | return output 71 | 72 | 73 | class Conv2dFn(Function): 74 | @staticmethod 75 | def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1): 76 | (bs, ch), in_sz = input.shape[:2], input.shape[2:] 77 | if kernel.size(1) > 1 and kernel.size(1) != ch: 78 | raise ValueError('Incompatible input and kernel sizes.') 79 | ctx.input_size = in_sz 80 | ctx.kernel_size = _pair(kernel_size) 81 | ctx.kernel_ch = kernel.size(1) 82 | ctx.dilation = _pair(dilation) 83 | ctx.padding = _pair(padding) 84 | ctx.stride = _pair(stride) 85 | ctx.save_for_backward(input if ctx.needs_input_grad[1] else None, 86 | kernel if ctx.needs_input_grad[0] else None) 87 | ctx._backend = type2backend[input.type()] 88 | 89 | cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride) 90 | 91 | output = cols.view(bs, ch, *kernel.shape[2:]) * kernel 92 | output = torch.einsum('ijklmn->ijmn', (output,)) 93 | 94 | return output.clone() # TODO check whether a .clone() is needed here 95 | 96 | @staticmethod 97 | @once_differentiable 98 | def backward(ctx, grad_output): 99 | input, kernel = ctx.saved_tensors 100 | grad_input = grad_kernel = None 101 | (bs, ch), out_sz = grad_output.shape[:2], grad_output.shape[2:] 102 | if ctx.needs_input_grad[0]: 103 | grad_input = grad_output.new() 104 | grad_im2col_output = torch.einsum('ijmn,izklmn->ijklmn', (grad_output, kernel)) 105 | grad_im2col_output = grad_im2col_output.view(bs, -1, out_sz[0] * out_sz[1]) 106 | ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state, 107 | grad_im2col_output, 108 | grad_input, 109 | ctx.input_size[0], ctx.input_size[1], 110 | ctx.kernel_size[0], ctx.kernel_size[1], 111 | ctx.dilation[0], ctx.dilation[1], 112 | ctx.padding[0], ctx.padding[1], 113 | ctx.stride[0], ctx.stride[1]) 114 | if ctx.needs_input_grad[1]: 115 | cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride) 116 | cols = cols.view(bs, ch, ctx.kernel_size[0], ctx.kernel_size[1], out_sz[0], out_sz[1]) 117 | grad_kernel = torch.einsum('ijmn,ijklmn->ijklmn', (grad_output, cols)) 118 | if ctx.kernel_ch == 1: 119 | grad_kernel = grad_kernel.sum(dim=1, keepdim=True) 120 | 121 | return grad_input, grad_kernel, None, None, None, None 122 | 123 | 124 | def conv2d(input, kernel, kernel_size, stride=1, padding=0, dilation=1, native_impl=False): 125 | kernel_size = _pair(kernel_size) 126 | stride = _pair(stride) 127 | padding = _pair(padding) 128 | dilation = _pair(dilation) 129 | 130 | if native_impl: 131 | bs, in_ch, in_h, in_w = input.shape 132 | out_h = 
(in_h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 133 | out_w = (in_w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 134 | 135 | # im2col on input 136 | im_cols = nd2col(input, kernel_size, stride=stride, padding=padding, dilation=dilation) 137 | 138 | # main computation 139 | im_cols *= kernel 140 | output = im_cols.view(bs, in_ch, -1, out_h, out_w).sum(dim=2, keepdim=False) 141 | else: 142 | output = Conv2dFn.apply(input, kernel, kernel_size, stride, padding, dilation) 143 | 144 | return output -------------------------------------------------------------------------------- /network/libs/inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNWrapper, InPlaceABNSync, InPlaceABNSyncWrapper 2 | from .misc import GlobalAvgPool2d 3 | from .residual import IdentityResidualBlock 4 | from .dense import DenseModule 5 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from .__ext import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/bn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, Iterable 2 | from itertools import repeat 3 | 4 | try: 5 | # python 3 6 | from queue import Queue 7 | except ImportError: 8 | # python 2 9 | from Queue import Queue 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.autograd as autograd 14 | 15 | from .functions import inplace_abn, inplace_abn_sync 16 | 17 | 18 | def _pair(x): 19 | if isinstance(x, Iterable): 20 | return x 21 | return tuple(repeat(x, 2)) 22 | 23 | 24 | class ABN(nn.Sequential): 25 | """Activated Batch Normalization 26 | 27 | This gathers a `BatchNorm2d` and an activation function in a single module 28 | """ 29 | 30 | def __init__(self, num_features, activation=nn.ReLU(inplace=True), **kwargs): 31 | """Creates an Activated Batch Normalization module 32 | 33 | Parameters 34 | ---------- 35 | num_features : int 36 | Number of feature channels in the input and output. 37 | activation : nn.Module 38 | Module used as an activation function. 39 | kwargs 40 | All other arguments are forwarded to the `BatchNorm2d` constructor. 41 | """ 42 | super(ABN, self).__init__(OrderedDict([ 43 | ("bn", nn.BatchNorm2d(num_features, **kwargs)), 44 | ("act", activation) 45 | ])) 46 | 47 | 48 | class InPlaceABN(nn.Module): 49 | """InPlace Activated Batch Normalization""" 50 | 51 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 52 | """Creates an InPlace Activated Batch Normalization module 53 | 54 | Parameters 55 | ---------- 56 | num_features : int 57 | Number of feature channels in the input and output. 58 | eps : float 59 | Small constant to prevent numerical issues. 60 | momentum : float 61 | Momentum factor applied to compute running statistics as. 
62 | affine : bool 63 | If `True` apply learned scale and shift transformation after normalization. 64 | activation : str 65 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 66 | slope : float 67 | Negative slope for the `leaky_relu` activation. 68 | """ 69 | super(InPlaceABN, self).__init__() 70 | self.num_features = num_features 71 | self.affine = affine 72 | self.eps = eps 73 | self.momentum = momentum 74 | self.activation = activation 75 | self.slope = slope 76 | if self.affine: 77 | self.weight = nn.Parameter(torch.Tensor(num_features)) 78 | self.bias = nn.Parameter(torch.Tensor(num_features)) 79 | else: 80 | self.register_parameter('weight', None) 81 | self.register_parameter('bias', None) 82 | self.register_buffer('running_mean', torch.zeros(num_features)) 83 | self.register_buffer('running_var', torch.ones(num_features)) 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | self.running_mean.zero_() 88 | self.running_var.fill_(1) 89 | if self.affine: 90 | self.weight.data.fill_(1) 91 | self.bias.data.zero_() 92 | 93 | def forward(self, x): 94 | return inplace_abn(x, self.weight, self.bias, autograd.Variable(self.running_mean), 95 | autograd.Variable(self.running_var), self.training, self.momentum, self.eps, 96 | self.activation, self.slope) 97 | 98 | def __repr__(self): 99 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 100 | ' affine={affine}, activation={activation}' 101 | if self.activation == "leaky_relu": 102 | rep += ' slope={slope})' 103 | else: 104 | rep += ')' 105 | return rep.format(name=self.__class__.__name__, **self.__dict__) 106 | 107 | 108 | class InPlaceABNSync(nn.Module): 109 | """InPlace Activated Batch Normalization with cross-GPU synchronization 110 | 111 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`. 112 | """ 113 | 114 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 115 | slope=0.01): 116 | """Creates a synchronized, InPlace Activated Batch Normalization module 117 | 118 | Parameters 119 | ---------- 120 | num_features : int 121 | Number of feature channels in the input and output. 122 | devices : list of int or None 123 | IDs of the GPUs that will run the replicas of this module. 124 | eps : float 125 | Small constant to prevent numerical issues. 126 | momentum : float 127 | Momentum factor applied to compute running statistics as. 128 | affine : bool 129 | If `True` apply learned scale and shift transformation after normalization. 130 | activation : str 131 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 132 | slope : float 133 | Negative slope for the `leaky_relu` activation. 
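
        Example (sketch; device ids and tensor shape are illustrative)::

            >>> abn = InPlaceABNSync(64, devices=[0, 1], activation="leaky_relu", slope=0.01)
            >>> y = abn(x)  # x: an NxCxHxW CUDA tensor on one of the replica devices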
134 | """ 135 | super(InPlaceABNSync, self).__init__() 136 | self.num_features = num_features 137 | self.devices = devices if devices else list(range(torch.cuda.device_count())) 138 | self.affine = affine 139 | self.eps = eps 140 | self.momentum = momentum 141 | self.activation = activation 142 | self.slope = slope 143 | if self.affine: 144 | self.weight = nn.Parameter(torch.Tensor(num_features)) 145 | self.bias = nn.Parameter(torch.Tensor(num_features)) 146 | else: 147 | self.register_parameter('weight', None) 148 | self.register_parameter('bias', None) 149 | self.register_buffer('running_mean', torch.zeros(num_features)) 150 | self.register_buffer('running_var', torch.ones(num_features)) 151 | self.reset_parameters() 152 | 153 | # Initialize queues 154 | self.worker_ids = self.devices[1:] 155 | self.master_queue = Queue(len(self.worker_ids)) 156 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 157 | 158 | def reset_parameters(self): 159 | self.running_mean.zero_() 160 | self.running_var.fill_(1) 161 | if self.affine: 162 | self.weight.data.fill_(1) 163 | self.bias.data.zero_() 164 | 165 | def forward(self, x): 166 | if x.get_device() == self.devices[0]: 167 | # Master mode 168 | extra = { 169 | "is_master": True, 170 | "master_queue": self.master_queue, 171 | "worker_queues": self.worker_queues, 172 | "worker_ids": self.worker_ids 173 | } 174 | else: 175 | # Worker mode 176 | extra = { 177 | "is_master": False, 178 | "master_queue": self.master_queue, 179 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 180 | } 181 | 182 | return inplace_abn_sync(x, self.weight, self.bias, autograd.Variable(self.running_mean), 183 | autograd.Variable(self.running_var), extra, self.training, self.momentum, self.eps, 184 | self.activation, self.slope) 185 | 186 | def __repr__(self): 187 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 188 | ' affine={affine}, devices={devices}, activation={activation}' 189 | if self.activation == "leaky_relu": 190 | rep += ' slope={slope})' 191 | else: 192 | rep += ')' 193 | return rep.format(name=self.__class__.__name__, **self.__dict__) 194 | 195 | 196 | class InPlaceABNWrapper(nn.Module): 197 | """Wrapper module to make `InPlaceABN` compatible with `ABN`""" 198 | 199 | def __init__(self, *args, **kwargs): 200 | super(InPlaceABNWrapper, self).__init__() 201 | self.bn = InPlaceABN(*args, **kwargs) 202 | 203 | def forward(self, input): 204 | return self.bn(input) 205 | 206 | 207 | class InPlaceABNSyncWrapper(nn.Module): 208 | """Wrapper module to make `InPlaceABNSync` compatible with `ABN`""" 209 | 210 | def __init__(self, *args, **kwargs): 211 | super(InPlaceABNSyncWrapper, self).__init__() 212 | self.bn = InPlaceABNSync(*args, **kwargs) 213 | 214 | def forward(self, input): 215 | return self.bn(input) 216 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.ffi import create_extension 4 | 5 | sources = ['src/lib_cffi.cpp'] 6 | headers = ['src/lib_cffi.h'] 7 | extra_objects = ['src/bn.o'] 8 | with_cuda = True 9 | 10 | this_file = os.path.dirname(os.path.realpath(__file__)) 11 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 12 | 13 | ffi = create_extension( 14 | '_ext', 15 | headers=headers, 16 | sources=sources, 17 | relative_to=__file__, 18 | with_cuda=with_cuda, 19 | extra_objects=extra_objects, 20 | 
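    # Note: src/bn.o must exist before this script runs: build.sh alongside this file
    # compiles src/bn.cu with nvcc for the listed gencode targets, after which cffi
    # links the object into the _ext module imported by functions.py.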
extra_compile_args=["-std=c++11"] 21 | ) 22 | 23 | if __name__ == '__main__': 24 | ffi.build() 25 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Configuration 4 | CUDA_GENCODE="\ 5 | -gencode=arch=compute_70,code=sm_70 \ 6 | -gencode=arch=compute_61,code=sm_61 \ 7 | -gencode=arch=compute_52,code=sm_52 \ 8 | -gencode=arch=compute_50,code=sm_50" 9 | 10 | 11 | cd src 12 | nvcc -I/usr/local/cuda/include --expt-extended-lambda -O3 -c -o bn.o bn.cu -x cu -Xcompiler -fPIC -std=c++11 ${CUDA_GENCODE} 13 | cd .. 14 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/dense.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .bn import ABN 7 | 8 | 9 | class DenseModule(nn.Module): 10 | def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): 11 | super(DenseModule, self).__init__() 12 | self.in_channels = in_channels 13 | self.growth = growth 14 | self.layers = layers 15 | 16 | self.convs1 = nn.ModuleList() 17 | self.convs3 = nn.ModuleList() 18 | for i in range(self.layers): 19 | self.convs1.append(nn.Sequential(OrderedDict([ 20 | ("bn", norm_act(in_channels)), 21 | ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) 22 | ]))) 23 | self.convs3.append(nn.Sequential(OrderedDict([ 24 | ("bn", norm_act(self.growth * bottleneck_factor)), 25 | ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, 26 | dilation=dilation)) 27 | ]))) 28 | in_channels += self.growth 29 | 30 | @property 31 | def out_channels(self): 32 | return self.in_channels + self.growth * self.layers 33 | 34 | def forward(self, x): 35 | inputs = [x] 36 | for i in range(self.layers): 37 | x = torch.cat(inputs, dim=1) 38 | x = self.convs1[i](x) 39 | x = self.convs3[i](x) 40 | inputs += [x] 41 | 42 | return torch.cat(inputs, dim=1) -------------------------------------------------------------------------------- /network/libs/inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch.cuda.comm as comm 3 | from torch.autograd.function import once_differentiable 4 | 5 | from . 
import _ext 6 | 7 | # Activation names 8 | ACT_LEAKY_RELU = "leaky_relu" 9 | ACT_ELU = "elu" 10 | ACT_NONE = "none" 11 | 12 | 13 | def _check(fn, *args, **kwargs): 14 | success = fn(*args, **kwargs) 15 | if not success: 16 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 17 | 18 | 19 | def _broadcast_shape(x): 20 | out_size = [] 21 | for i, s in enumerate(x.size()): 22 | if i != 1: 23 | out_size.append(1) 24 | else: 25 | out_size.append(s) 26 | return out_size 27 | 28 | 29 | def _reduce(x): 30 | if len(x.size()) == 2: 31 | return x.sum(dim=0) 32 | else: 33 | n, c = x.size()[0:2] 34 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 35 | 36 | 37 | def _count_samples(x): 38 | count = 1 39 | for i, s in enumerate(x.size()): 40 | if i != 1: 41 | count *= s 42 | return count 43 | 44 | 45 | def _act_forward(ctx, x): 46 | if ctx.activation == ACT_LEAKY_RELU: 47 | _check(_ext.leaky_relu_cuda, x, ctx.slope) 48 | elif ctx.activation == ACT_ELU: 49 | _check(_ext.elu_cuda, x) 50 | elif ctx.activation == ACT_NONE: 51 | pass 52 | 53 | 54 | def _act_backward(ctx, x, dx): 55 | if ctx.activation == ACT_LEAKY_RELU: 56 | _check(_ext.leaky_relu_backward_cuda, x, dx, ctx.slope) 57 | _check(_ext.leaky_relu_cuda, x, 1. / ctx.slope) 58 | elif ctx.activation == ACT_ELU: 59 | _check(_ext.elu_backward_cuda, x, dx) 60 | _check(_ext.elu_inv_cuda, x) 61 | elif ctx.activation == ACT_NONE: 62 | pass 63 | 64 | 65 | def _check_contiguous(*args): 66 | if not all([mod is None or mod.is_contiguous() for mod in args]): 67 | raise ValueError("Non-contiguous input") 68 | 69 | 70 | class InPlaceABN(autograd.Function): 71 | @staticmethod 72 | def forward(ctx, x, weight, bias, running_mean, running_var, 73 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 74 | # Save context 75 | ctx.training = training 76 | ctx.momentum = momentum 77 | ctx.eps = eps 78 | ctx.activation = activation 79 | ctx.slope = slope 80 | 81 | n = _count_samples(x) 82 | 83 | if ctx.training: 84 | mean = x.new().resize_as_(running_mean) 85 | var = x.new().resize_as_(running_var) 86 | _check_contiguous(x, mean, var) 87 | _check(_ext.bn_mean_var_cuda, x, mean, var) 88 | 89 | # Update running stats 90 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 91 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1)) 92 | else: 93 | mean, var = running_mean, running_var 94 | 95 | _check_contiguous(x, mean, var, weight, bias) 96 | _check(_ext.bn_forward_cuda, 97 | x, mean, var, 98 | weight if weight is not None else x.new(), 99 | bias if bias is not None else x.new(), 100 | x, x, ctx.eps) 101 | 102 | # Activation 103 | _act_forward(ctx, x) 104 | 105 | # Output 106 | ctx.var = var 107 | ctx.save_for_backward(x, weight, bias, running_mean, running_var) 108 | ctx.mark_dirty(x) 109 | return x 110 | 111 | @staticmethod 112 | @once_differentiable 113 | def backward(ctx, dz): 114 | z, weight, bias, running_mean, running_var = ctx.saved_tensors 115 | dz = dz.contiguous() 116 | 117 | # Undo activation 118 | _act_backward(ctx, z, dz) 119 | 120 | if ctx.needs_input_grad[0]: 121 | dx = dz.new().resize_as_(dz) 122 | else: 123 | dx = None 124 | 125 | if ctx.needs_input_grad[1]: 126 | dweight = dz.new().resize_as_(running_mean).zero_() 127 | else: 128 | dweight = None 129 | 130 | if ctx.needs_input_grad[2]: 131 | dbias = dz.new().resize_as_(running_mean).zero_() 132 | else: 133 | dbias = None 134 | 135 | if ctx.training: 136 | edz = dz.new().resize_as_(running_mean) 137 | eydz = 
dz.new().resize_as_(running_mean) 138 | _check_contiguous(z, dz, weight, bias, edz, eydz) 139 | _check(_ext.bn_edz_eydz_cuda, 140 | z, dz, 141 | weight if weight is not None else dz.new(), 142 | bias if bias is not None else dz.new(), 143 | edz, eydz, ctx.eps) 144 | else: 145 | # TODO: implement CUDA backward for inference mode 146 | edz = dz.new().resize_as_(running_mean).zero_() 147 | eydz = dz.new().resize_as_(running_mean).zero_() 148 | 149 | _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias) 150 | _check(_ext.bn_backard_cuda, 151 | dz, z, ctx.var, 152 | weight if weight is not None else dz.new(), 153 | bias if bias is not None else dz.new(), 154 | edz, eydz, 155 | dx if dx is not None else dz.new(), 156 | dweight if dweight is not None else dz.new(), 157 | dbias if dbias is not None else dz.new(), 158 | ctx.eps) 159 | 160 | del ctx.var 161 | 162 | return dx, dweight, dbias, None, None, None, None, None, None, None 163 | 164 | 165 | class InPlaceABNSync(autograd.Function): 166 | @classmethod 167 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 168 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 169 | # Save context 170 | cls._parse_extra(ctx, extra) 171 | ctx.training = training 172 | ctx.momentum = momentum 173 | ctx.eps = eps 174 | ctx.activation = activation 175 | ctx.slope = slope 176 | 177 | n = _count_samples(x) * (ctx.master_queue.maxsize + 1) 178 | 179 | if ctx.training: 180 | mean = x.new().resize_(1, running_mean.size(0)) 181 | var = x.new().resize_(1, running_var.size(0)) 182 | _check_contiguous(x, mean, var) 183 | _check(_ext.bn_mean_var_cuda, x, mean, var) 184 | 185 | if ctx.is_master: 186 | means, vars = [mean], [var] 187 | for _ in range(ctx.master_queue.maxsize): 188 | mean_w, var_w = ctx.master_queue.get() 189 | ctx.master_queue.task_done() 190 | means.append(mean_w) 191 | vars.append(var_w) 192 | 193 | means = comm.gather(means) 194 | vars = comm.gather(vars) 195 | 196 | mean = means.mean(0) 197 | var = (vars + (mean - means) ** 2).mean(0) 198 | 199 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids) 200 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 201 | queue.put(ts) 202 | else: 203 | ctx.master_queue.put((mean, var)) 204 | mean, var = ctx.worker_queue.get() 205 | ctx.worker_queue.task_done() 206 | 207 | # Update running stats 208 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 209 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1)) 210 | else: 211 | mean, var = running_mean, running_var 212 | 213 | _check_contiguous(x, mean, var, weight, bias) 214 | _check(_ext.bn_forward_cuda, 215 | x, mean, var, 216 | weight if weight is not None else x.new(), 217 | bias if bias is not None else x.new(), 218 | x, x, ctx.eps) 219 | 220 | # Activation 221 | _act_forward(ctx, x) 222 | 223 | # Output 224 | ctx.var = var 225 | ctx.save_for_backward(x, weight, bias, running_mean, running_var) 226 | ctx.mark_dirty(x) 227 | return x 228 | 229 | @staticmethod 230 | @once_differentiable 231 | def backward(ctx, dz): 232 | z, weight, bias, running_mean, running_var = ctx.saved_tensors 233 | dz = dz.contiguous() 234 | 235 | # Undo activation 236 | _act_backward(ctx, z, dz) 237 | 238 | if ctx.needs_input_grad[0]: 239 | dx = dz.new().resize_as_(dz) 240 | else: 241 | dx = None 242 | 243 | if ctx.needs_input_grad[1]: 244 | dweight = dz.new().resize_as_(running_mean).zero_() 245 | else: 246 | dweight = None 247 | 248 | if 
ctx.needs_input_grad[2]: 249 | dbias = dz.new().resize_as_(running_mean).zero_() 250 | else: 251 | dbias = None 252 | 253 | if ctx.training: 254 | edz = dz.new().resize_as_(running_mean) 255 | eydz = dz.new().resize_as_(running_mean) 256 | _check_contiguous(z, dz, weight, bias, edz, eydz) 257 | _check(_ext.bn_edz_eydz_cuda, 258 | z, dz, 259 | weight if weight is not None else dz.new(), 260 | bias if bias is not None else dz.new(), 261 | edz, eydz, ctx.eps) 262 | 263 | if ctx.is_master: 264 | edzs, eydzs = [edz], [eydz] 265 | for _ in range(len(ctx.worker_queues)): 266 | edz_w, eydz_w = ctx.master_queue.get() 267 | ctx.master_queue.task_done() 268 | edzs.append(edz_w) 269 | eydzs.append(eydz_w) 270 | 271 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1) 272 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1) 273 | 274 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids) 275 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 276 | queue.put(ts) 277 | else: 278 | ctx.master_queue.put((edz, eydz)) 279 | edz, eydz = ctx.worker_queue.get() 280 | ctx.worker_queue.task_done() 281 | else: 282 | edz = dz.new().resize_as_(running_mean).zero_() 283 | eydz = dz.new().resize_as_(running_mean).zero_() 284 | 285 | _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias) 286 | _check(_ext.bn_backard_cuda, 287 | dz, z, ctx.var, 288 | weight if weight is not None else dz.new(), 289 | bias if bias is not None else dz.new(), 290 | edz, eydz, 291 | dx if dx is not None else dz.new(), 292 | dweight if dweight is not None else dz.new(), 293 | dbias if dbias is not None else dz.new(), 294 | ctx.eps) 295 | 296 | del ctx.var 297 | 298 | return dx, dweight, dbias, None, None, None, None, None, None, None, None 299 | 300 | @staticmethod 301 | def _parse_extra(ctx, extra): 302 | ctx.is_master = extra["is_master"] 303 | if ctx.is_master: 304 | ctx.master_queue = extra["master_queue"] 305 | ctx.worker_queues = extra["worker_queues"] 306 | ctx.worker_ids = extra["worker_ids"] 307 | else: 308 | ctx.master_queue = extra["master_queue"] 309 | ctx.worker_queue = extra["worker_queue"] 310 | 311 | 312 | inplace_abn = InPlaceABN.apply 313 | inplace_abn_sync = InPlaceABNSync.apply 314 | 315 | __all__ = ["inplace_abn", "inplace_abn_sync"] 316 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GlobalAvgPool2d(nn.Module): 5 | def __init__(self): 6 | """Global average pooling over the input's spatial dimensions""" 7 | super(GlobalAvgPool2d, self).__init__() 8 | 9 | def forward(self, inputs): 10 | in_size = inputs.size() 11 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 12 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | 5 | from .bn import ABN 6 | 7 | 8 | class IdentityResidualBlock(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | channels, 12 | stride=1, 13 | dilation=1, 14 | groups=1, 15 | norm_act=ABN, 16 | dropout=None): 17 | """Configurable identity-mapping residual block 18 | 19 | Parameters 20 | ---------- 21 | in_channels : int 22 | Number of input channels. 
23 |         channels : list of int
24 |             Number of channels in the internal feature maps. Can either have two or three elements: with two,
25 |             a basic residual block with two `3 x 3` convolutions is built; with three, a bottleneck block with
26 |             `1 x 1`, then `3 x 3`, then `1 x 1` convolutions is built.
27 |         stride : int
28 |             Stride of the first `3 x 3` convolution
29 |         dilation : int
30 |             Dilation to apply to the `3 x 3` convolutions.
31 |         groups : int
32 |             Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
33 |             bottleneck blocks.
34 |         norm_act : callable
35 |             Function to create normalization / activation Module.
36 |         dropout: callable
37 |             Function to create Dropout Module.
38 |         """
39 |         super(IdentityResidualBlock, self).__init__()
40 | 
41 |         # Check parameters for inconsistencies
42 |         if len(channels) != 2 and len(channels) != 3:
43 |             raise ValueError("channels must contain either two or three values")
44 |         if len(channels) == 2 and groups != 1:
45 |             raise ValueError("groups > 1 are only valid if len(channels) == 3")
46 | 
47 |         is_bottleneck = len(channels) == 3
48 |         need_proj_conv = stride != 1 or in_channels != channels[-1]
49 | 
50 |         self.bn1 = norm_act(in_channels)
51 |         if not is_bottleneck:
52 |             layers = [
53 |                 ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54 |                                     dilation=dilation)),
55 |                 ("bn2", norm_act(channels[0])),
56 |                 ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57 |                                     dilation=dilation))
58 |             ]
59 |             if dropout is not None:
60 |                 layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61 |         else:
62 |             layers = [
63 |                 ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64 |                 ("bn2", norm_act(channels[0])),
65 |                 ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66 |                                     groups=groups, dilation=dilation)),
67 |                 ("bn3", norm_act(channels[1])),
68 |                 ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69 |             ]
70 |             if dropout is not None:
71 |                 layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72 |         self.convs = nn.Sequential(OrderedDict(layers))
73 | 
74 |         if need_proj_conv:
75 |             self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76 | 
77 |     def forward(self, x):
78 |         if hasattr(self, "proj_conv"):
79 |             bn1 = self.bn1(x)
80 |             shortcut = self.proj_conv(bn1)
81 |         else:
82 |             shortcut = x.clone()
83 |             bn1 = self.bn1(x)
84 | 
85 |         out = self.convs(bn1)
86 |         out.add_(shortcut)
87 | 
88 |         return out
89 | 
-------------------------------------------------------------------------------- /network/libs/inplace_abn/src/bn.cu: --------------------------------------------------------------------------------
1 | #include <thrust/device_ptr.h>
2 | #include <thrust/transform.h>
3 | #include <thrust/execution_policy.h>
4 | 
5 | #include "common.h"
6 | #include "bn.h"
7 | 
8 | /*
9 |  * Device functions and data structures
10 |  */
11 | struct Float2 {
12 |   float v1, v2;
13 |   __device__ Float2() {}
14 |   __device__ Float2(float _v1, float _v2) : v1(_v1), v2(_v2) {}
15 |   __device__ Float2(float v) : v1(v), v2(v) {}
16 |   __device__ Float2(int v) : v1(v), v2(v) {}
17 |   __device__ Float2 &operator+=(const Float2 &a) {
18 |     v1 += a.v1;
19 |     v2 += a.v2;
20 |     return *this;
21 |   }
22 | };
23 | 
24 | struct SumOp {
25 |   __device__ SumOp(const float *t, int c, int s)
26 |       : tensor(t), C(c), S(s) {}
27 |   __device__ __forceinline__ float operator()(int batch, int plane, int n) {
28 |     return tensor[(batch * C + plane)
 * S + n];
29 |   }
30 |   const float *tensor;
31 |   const int C;
32 |   const int S;
33 | };
34 | 
35 | struct VarOp {
36 |   __device__ VarOp(float m, const float *t, int c, int s)
37 |       : mean(m), tensor(t), C(c), S(s) {}
38 |   __device__ __forceinline__ float operator()(int batch, int plane, int n) {
39 |     float val = tensor[(batch * C + plane) * S + n];
40 |     return (val - mean) * (val - mean);
41 |   }
42 |   const float mean;
43 |   const float *tensor;
44 |   const int C;
45 |   const int S;
46 | };
47 | 
48 | struct GradOp {
49 |   __device__ GradOp(float _gamma, float _beta, const float *_z, const float *_dz, int c, int s)
50 |       : gamma(_gamma), beta(_beta), z(_z), dz(_dz), C(c), S(s) {}
51 |   __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
52 |     float _y = (z[(batch * C + plane) * S + n] - beta) / gamma;
53 |     float _dz = dz[(batch * C + plane) * S + n];
54 |     return Float2(_dz, _y * _dz);
55 |   }
56 |   const float gamma;
57 |   const float beta;
58 |   const float *z;
59 |   const float *dz;
60 |   const int C;
61 |   const int S;
62 | };
63 | 
64 | static __device__ __forceinline__ float warpSum(float val) {
65 | #if __CUDA_ARCH__ >= 300
66 |   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
67 |     val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
68 |   }
69 | #else
70 |   __shared__ float values[MAX_BLOCK_SIZE];
71 |   values[threadIdx.x] = val;
72 |   __threadfence_block();
73 |   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
74 |   for (int i = 1; i < WARP_SIZE; i++) {
75 |     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
76 |   }
77 | #endif
78 |   return val;
79 | }
80 | 
81 | static __device__ __forceinline__ Float2 warpSum(Float2 value) {
82 |   value.v1 = warpSum(value.v1);
83 |   value.v2 = warpSum(value.v2);
84 |   return value;
85 | }
86 | 
87 | template <typename T, typename Op>
88 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
89 |   T sum = (T)0;
90 |   for (int batch = 0; batch < N; ++batch) {
91 |     for (int x = threadIdx.x; x < S; x += blockDim.x) {
92 |       sum += op(batch, plane, x);
93 |     }
94 |   }
95 | 
96 |   // sum over NumThreads within a warp
97 |   sum = warpSum(sum);
98 | 
99 |   // 'transpose', and reduce within warp again
100 |   __shared__ T shared[32];
101 |   __syncthreads();
102 |   if (threadIdx.x % WARP_SIZE == 0) {
103 |     shared[threadIdx.x / WARP_SIZE] = sum;
104 |   }
105 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
106 |     // zero out the other entries in shared
107 |     shared[threadIdx.x] = (T)0;
108 |   }
109 |   __syncthreads();
110 |   if (threadIdx.x / WARP_SIZE == 0) {
111 |     sum = warpSum(shared[threadIdx.x]);
112 |     if (threadIdx.x == 0) {
113 |       shared[0] = sum;
114 |     }
115 |   }
116 |   __syncthreads();
117 | 
118 |   // Everyone picks it up, should be broadcast into the whole gradInput
119 |   return shared[0];
120 | }
121 | 
122 | /*
123 |  * Kernels
124 |  */
125 | __global__ void mean_var_kernel(const float *x, float *mean, float *var, int N,
126 |                                 int C, int S) {
127 |   int plane = blockIdx.x;
128 |   float norm = 1.f / (N * S);
129 | 
130 |   float _mean = reduce<float>(SumOp(x, C, S), plane, N, C, S) * norm;
131 |   __syncthreads();
132 |   float _var = reduce<float>(VarOp(_mean, x, C, S), plane, N, C, S) * norm;
133 | 
134 |   if (threadIdx.x == 0) {
135 |     mean[plane] = _mean;
136 |     var[plane] = _var;
137 |   }
138 | }
139 | 
140 | __global__ void forward_kernel(const float *x, const float *mean,
141 |                                const float *var, const float *weight,
142 |                                const float *bias, float *y, float *z, float eps,
143 |                                int N, int C, int S) {
144 |   int plane = blockIdx.x;
145 | 
146 |   float _mean = mean[plane];
147 |   float _var = 
var[plane]; 148 | float invStd = 0; 149 | if (_var != 0.f || eps != 0.f) { 150 | invStd = 1 / sqrt(_var + eps); 151 | } 152 | 153 | float gamma = weight != 0 ? abs(weight[plane]) + eps : 1.f; 154 | float beta = bias != 0 ? bias[plane] : 0.f; 155 | for (int batch = 0; batch < N; ++batch) { 156 | for (int n = threadIdx.x; n < S; n += blockDim.x) { 157 | float _x = x[(batch * C + plane) * S + n]; 158 | float _y = (_x - _mean) * invStd; 159 | float _z = _y * gamma + beta; 160 | 161 | y[(batch * C + plane) * S + n] = _y; 162 | z[(batch * C + plane) * S + n] = _z; 163 | } 164 | } 165 | } 166 | 167 | __global__ void edz_eydz_kernel(const float *z, const float *dz, const float *weight, const float *bias, 168 | float *edz, float *eydz, float eps, int N, int C, int S) { 169 | int plane = blockIdx.x; 170 | float norm = 1.f / (N * S); 171 | 172 | float gamma = weight != 0 ? abs(weight[plane]) + eps : 1.f; 173 | float beta = bias != 0 ? bias[plane] : 0.f; 174 | 175 | Float2 res = reduce(GradOp(gamma, beta, z, dz, C, S), plane, N, C, S); 176 | float _edz = res.v1 * norm; 177 | float _eydz = res.v2 * norm; 178 | __syncthreads(); 179 | 180 | if (threadIdx.x == 0) { 181 | edz[plane] = _edz; 182 | eydz[plane] = _eydz; 183 | } 184 | } 185 | 186 | __global__ void backward_kernel(const float *dz, const float *z, const float *var, const float *weight, 187 | const float *bias, const float *edz, const float *eydz, float *dx, float *dweight, 188 | float *dbias, float eps, int N, int C, int S) { 189 | int plane = blockIdx.x; 190 | float _edz = edz[plane]; 191 | float _eydz = eydz[plane]; 192 | 193 | float gamma = weight != 0 ? abs(weight[plane]) + eps : 1.f; 194 | float beta = bias != 0 ? bias[plane] : 0.f; 195 | 196 | if (dx != 0) { 197 | float _var = var[plane]; 198 | float invStd = 0; 199 | if (_var != 0.f || eps != 0.f) { 200 | invStd = 1 / sqrt(_var + eps); 201 | } 202 | 203 | float mul = gamma * invStd; 204 | 205 | for (int batch = 0; batch < N; ++batch) { 206 | for (int n = threadIdx.x; n < S; n += blockDim.x) { 207 | float _dz = dz[(batch * C + plane) * S + n]; 208 | float _y = (z[(batch * C + plane) * S + n] - beta) / gamma; 209 | dx[(batch * C + plane) * S + n] = (_dz - _edz - _y * _eydz) * mul; 210 | } 211 | } 212 | } 213 | 214 | if (dweight != 0 || dbias != 0) { 215 | float norm = N * S; 216 | 217 | if (dweight != 0) { 218 | if (threadIdx.x == 0) { 219 | if (weight[plane] > 0) 220 | dweight[plane] += _eydz * norm; 221 | else if (weight[plane] < 0) 222 | dweight[plane] -= _eydz * norm; 223 | } 224 | } 225 | 226 | if (dbias != 0) { 227 | if (threadIdx.x == 0) { 228 | dbias[plane] += _edz * norm; 229 | } 230 | } 231 | } 232 | } 233 | 234 | /* 235 | * Implementations 236 | */ 237 | extern "C" int _bn_mean_var_cuda(int N, int C, int S, const float *x, float *mean, 238 | float *var, cudaStream_t stream) { 239 | // Run kernel 240 | dim3 blocks(C); 241 | dim3 threads(getNumThreads(S)); 242 | mean_var_kernel<<>>(x, mean, var, N, C, S); 243 | 244 | // Check for errors 245 | cudaError_t err = cudaGetLastError(); 246 | if (err != cudaSuccess) 247 | return 0; 248 | else 249 | return 1; 250 | } 251 | 252 | extern "C" int _bn_forward_cuda(int N, int C, int S, const float *x, 253 | const float *mean, const float *var, 254 | const float *weight, const float *bias, float *y, 255 | float *z, float eps, cudaStream_t stream) { 256 | // Run kernel 257 | dim3 blocks(C); 258 | dim3 threads(getNumThreads(S)); 259 | forward_kernel<<>>(x, mean, var, weight, bias, y, 260 | z, eps, N, C, S); 261 | 262 | // Check for errors 263 | 
cudaError_t err = cudaGetLastError(); 264 | if (err != cudaSuccess) 265 | return 0; 266 | else 267 | return 1; 268 | } 269 | 270 | extern "C" int _bn_edz_eydz_cuda(int N, int C, int S, const float *z, const float *dz, const float *weight, 271 | const float *bias, float *edz, float *eydz, float eps, cudaStream_t stream) { 272 | // Run kernel 273 | dim3 blocks(C); 274 | dim3 threads(getNumThreads(S)); 275 | edz_eydz_kernel<<>>(z, dz, weight, bias, edz, eydz, eps, N, C, S); 276 | 277 | // Check for errors 278 | cudaError_t err = cudaGetLastError(); 279 | if (err != cudaSuccess) 280 | return 0; 281 | else 282 | return 1; 283 | } 284 | 285 | extern "C" int _bn_backward_cuda(int N, int C, int S, const float *dz, const float *z, const float *var, 286 | const float *weight, const float *bias, const float *edz, const float *eydz, 287 | float *dx, float *dweight, float *dbias, float eps, cudaStream_t stream) { 288 | // Run kernel 289 | dim3 blocks(C); 290 | dim3 threads(getNumThreads(S)); 291 | backward_kernel<<>>(dz, z, var, weight, bias, edz, eydz, dx, dweight, dbias, 292 | eps, N, C, S); 293 | 294 | // Check for errors 295 | cudaError_t err = cudaGetLastError(); 296 | if (err != cudaSuccess) 297 | return 0; 298 | else 299 | return 1; 300 | } 301 | 302 | extern "C" int _leaky_relu_cuda(int N, float *x, float slope, cudaStream_t stream) { 303 | // Run using thrust 304 | thrust::device_ptr th_x = thrust::device_pointer_cast(x); 305 | thrust::transform_if(thrust::cuda::par.on(stream), th_x, th_x + N, th_x, 306 | [slope] __device__ (const float& x) { return x * slope; }, 307 | [] __device__ (const float& x) { return x < 0; }); 308 | 309 | // Check for errors 310 | cudaError_t err = cudaGetLastError(); 311 | if (err != cudaSuccess) 312 | return 0; 313 | else 314 | return 1; 315 | } 316 | 317 | extern "C" int _leaky_relu_backward_cuda(int N, const float *x, float *dx, float slope, cudaStream_t stream) { 318 | // Run using thrust 319 | thrust::device_ptr th_x = thrust::device_pointer_cast(x); 320 | thrust::device_ptr th_dx = thrust::device_pointer_cast(dx); 321 | thrust::transform_if(thrust::cuda::par.on(stream), th_dx, th_dx + N, th_x, th_dx, 322 | [slope] __device__ (const float& dx) { return dx * slope; }, 323 | [] __device__ (const float& x) { return x < 0; }); 324 | 325 | // Check for errors 326 | cudaError_t err = cudaGetLastError(); 327 | if (err != cudaSuccess) 328 | return 0; 329 | else 330 | return 1; 331 | } 332 | 333 | extern "C" int _elu_cuda(int N, float *x, cudaStream_t stream) { 334 | // Run using thrust 335 | thrust::device_ptr th_x = thrust::device_pointer_cast(x); 336 | thrust::transform_if(thrust::cuda::par.on(stream), th_x, th_x + N, th_x, 337 | [] __device__ (const float& x) { return exp(x) - 1.f; }, 338 | [] __device__ (const float& x) { return x < 0; }); 339 | 340 | // Check for errors 341 | cudaError_t err = cudaGetLastError(); 342 | if (err != cudaSuccess) 343 | return 0; 344 | else 345 | return 1; 346 | } 347 | 348 | extern "C" int _elu_backward_cuda(int N, const float *x, float *dx, cudaStream_t stream) { 349 | // Run using thrust 350 | thrust::device_ptr th_x = thrust::device_pointer_cast(x); 351 | thrust::device_ptr th_dx = thrust::device_pointer_cast(dx); 352 | thrust::transform_if(thrust::cuda::par.on(stream), th_dx, th_dx + N, th_x, th_x, th_dx, 353 | [] __device__ (const float& dx, const float& x) { return dx * (x + 1.f); }, 354 | [] __device__ (const float& x) { return x < 0; }); 355 | 356 | // Check for errors 357 | cudaError_t err = cudaGetLastError(); 358 | if (err 
!= cudaSuccess) 359 | return 0; 360 | else 361 | return 1; 362 | } 363 | 364 | extern "C" int _elu_inv_cuda(int N, float *x, cudaStream_t stream) { 365 | // Run using thrust 366 | thrust::device_ptr th_x = thrust::device_pointer_cast(x); 367 | thrust::transform_if(thrust::cuda::par.on(stream), th_x, th_x + N, th_x, 368 | [] __device__ (const float& x) { return log1p(x); }, 369 | [] __device__ (const float& x) { return x < 0; }); 370 | 371 | // Check for errors 372 | cudaError_t err = cudaGetLastError(); 373 | if (err != cudaSuccess) 374 | return 0; 375 | else 376 | return 1; 377 | } 378 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/src/bn.h: -------------------------------------------------------------------------------- 1 | #ifndef __BN__ 2 | #define __BN__ 3 | 4 | /* 5 | * Exported functions 6 | */ 7 | extern "C" int _bn_mean_var_cuda(int N, int C, int S, const float *x, float *mean, float *var, cudaStream_t); 8 | extern "C" int _bn_forward_cuda(int N, int C, int S, const float *x, const float *mean, const float *var, 9 | const float *weight, const float *bias, float *y, float *z, float eps, cudaStream_t); 10 | extern "C" int _bn_edz_eydz_cuda(int N, int C, int S, const float *z, const float *dz, const float *weight, 11 | const float *bias, float *edz, float *eydz, float eps, cudaStream_t stream); 12 | extern "C" int _bn_backward_cuda(int N, int C, int S, const float *dz, const float *z, const float *var, 13 | const float *weight, const float *bias, const float *edz, const float *eydz, float *dx, 14 | float *dweight, float *dbias, float eps, cudaStream_t stream); 15 | extern "C" int _leaky_relu_cuda(int N, float *x, float slope, cudaStream_t stream); 16 | extern "C" int _leaky_relu_backward_cuda(int N, const float *x, float *dx, float slope, cudaStream_t stream); 17 | extern "C" int _elu_cuda(int N, float *x, cudaStream_t stream); 18 | extern "C" int _elu_backward_cuda(int N, const float *x, float *dx, cudaStream_t stream); 19 | extern "C" int _elu_inv_cuda(int N, float *x, cudaStream_t stream); 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /network/libs/inplace_abn/src/bn.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/CSPN_monodepth/7363bf749b8df4ea29f1a4fa9eebddbf97cf3f4b/network/libs/inplace_abn/src/bn.o -------------------------------------------------------------------------------- /network/libs/inplace_abn/src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMON__ 2 | #define __COMMON__ 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | /* 12 | * Utility functions 13 | */ 14 | template 15 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 16 | unsigned int mask = 0xffffffff) { 17 | #if CUDART_VERSION >= 9000 18 | return __shfl_xor_sync(mask, value, laneMask, width); 19 | #else 20 | return __shfl_xor(value, laneMask, width); 21 | #endif 22 | } 23 | 24 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 25 | 26 | static int getNumThreads(int nElem) { 27 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 28 | for (int i = 0; i != 5; ++i) { 29 | if (nElem <= threadSizes[i]) { 30 | return threadSizes[i]; 31 | } 32 | } 33 | return MAX_BLOCK_SIZE; 34 | } 35 | 36 | 37 | #endif 
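A note on the arithmetic the kernels above implement (editorial summary under the kernels' own notation, not a file in this repository): InPlace-ABN stores only the activated output z and reconstructs the normalized activation y from it during the backward pass, which is what allows the buffers to be overwritten in place. With $\gamma = |weight| + \varepsilon$ kept strictly positive so the map stays invertible (see forward_kernel and GradOp):

$$y = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad z = \gamma\,y + \beta, \qquad \text{hence} \quad y = \frac{z - \beta}{\gamma}.$$

With $\overline{dz}$ and $\overline{y\,dz}$ the per-channel means produced by edz_eydz_kernel, backward_kernel evaluates

$$dx = \left(dz - \overline{dz} - y\,\overline{y\,dz}\right)\frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},$$

which is exactly `(_dz - _edz - _y * _eydz) * mul` with `mul = gamma * invStd` in the code above.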
-------------------------------------------------------------------------------- /network/libs/inplace_abn/src/lib_cffi.cpp: -------------------------------------------------------------------------------- 1 | // All functions assume that input and output tensors are already initialized 2 | // and have the correct dimensions 3 | #include 4 | 5 | // Forward definition of implementation functions 6 | extern "C" { 7 | int _bn_mean_var_cuda(int N, int C, int S, const float *x, float *mean, float *var, cudaStream_t); 8 | int _bn_forward_cuda(int N, int C, int S, const float *x, const float *mean, const float *var, const float *weight, 9 | const float *bias, float *y, float *z, float eps, cudaStream_t); 10 | int _bn_edz_eydz_cuda(int N, int C, int S, const float *z, const float *dz, const float *weight, const float *bias, 11 | float *edz, float *eydz, float eps, cudaStream_t stream); 12 | int _bn_backward_cuda(int N, int C, int S, const float *dz, const float *z, const float *var, const float *weight, 13 | const float *bias, const float *edz, const float *eydz, float *dx, float *dweight, float *dbias, 14 | float eps, cudaStream_t stream); 15 | int _leaky_relu_cuda(int N, float *x, float slope, cudaStream_t stream); 16 | int _leaky_relu_backward_cuda(int N, const float *x, float *dx, float slope, cudaStream_t stream); 17 | int _elu_cuda(int N, float *x, cudaStream_t stream); 18 | int _elu_backward_cuda(int N, const float *x, float *dx, cudaStream_t stream); 19 | int _elu_inv_cuda(int N, float *x, cudaStream_t stream); 20 | } 21 | 22 | extern THCState *state; 23 | 24 | void get_sizes(const THCudaTensor *t, int *N, int *C, int *S){ 25 | // Get sizes 26 | *S = 1; 27 | *N = THCudaTensor_size(state, t, 0); 28 | *C = THCudaTensor_size(state, t, 1); 29 | if (THCudaTensor_nDimension(state, t) > 2) { 30 | for (int i = 2; i < THCudaTensor_nDimension(state, t); ++i) { 31 | *S *= THCudaTensor_size(state, t, i); 32 | } 33 | } 34 | } 35 | 36 | extern "C" int bn_mean_var_cuda(const THCudaTensor *x, THCudaTensor *mean, THCudaTensor *var) { 37 | cudaStream_t stream = THCState_getCurrentStream(state); 38 | 39 | int S, N, C; 40 | get_sizes(x, &N, &C, &S); 41 | 42 | // Get pointers 43 | const float *x_data = THCudaTensor_data(state, x); 44 | float *mean_data = THCudaTensor_data(state, mean); 45 | float *var_data = THCudaTensor_data(state, var); 46 | 47 | return _bn_mean_var_cuda(N, C, S, x_data, mean_data, var_data, stream); 48 | } 49 | 50 | extern "C" int bn_forward_cuda(const THCudaTensor *x, const THCudaTensor *mean, const THCudaTensor *var, 51 | const THCudaTensor *weight, const THCudaTensor *bias, THCudaTensor *y, THCudaTensor *z, 52 | float eps) { 53 | cudaStream_t stream = THCState_getCurrentStream(state); 54 | 55 | int S, N, C; 56 | get_sizes(x, &N, &C, &S); 57 | 58 | // Get pointers 59 | const float *x_data = THCudaTensor_data(state, x); 60 | const float *mean_data = THCudaTensor_data(state, mean); 61 | const float *var_data = THCudaTensor_data(state, var); 62 | const float *weight_data = THCudaTensor_nDimension(state, weight) != 0 ? THCudaTensor_data(state, weight) : 0; 63 | const float *bias_data = THCudaTensor_nDimension(state, bias) != 0 ? 
THCudaTensor_data(state, bias) : 0; 64 | float *y_data = THCudaTensor_data(state, y); 65 | float *z_data = THCudaTensor_data(state, z); 66 | 67 | return _bn_forward_cuda(N, C, S, x_data, mean_data, var_data, weight_data, bias_data, y_data, z_data, eps, stream); 68 | } 69 | 70 | extern "C" int bn_edz_eydz_cuda(const THCudaTensor *z, const THCudaTensor *dz, const THCudaTensor *weight, 71 | const THCudaTensor *bias, THCudaTensor *edz, THCudaTensor *eydz, float eps) { 72 | cudaStream_t stream = THCState_getCurrentStream(state); 73 | 74 | int S, N, C; 75 | get_sizes(z, &N, &C, &S); 76 | 77 | // Get pointers 78 | const float *z_data = THCudaTensor_data(state, z); 79 | const float *dz_data = THCudaTensor_data(state, dz); 80 | const float *weight_data = THCudaTensor_nDimension(state, weight) != 0 ? THCudaTensor_data(state, weight) : 0; 81 | const float *bias_data = THCudaTensor_nDimension(state, bias) != 0 ? THCudaTensor_data(state, bias) : 0; 82 | float *edz_data = THCudaTensor_data(state, edz); 83 | float *eydz_data = THCudaTensor_data(state, eydz); 84 | 85 | return _bn_edz_eydz_cuda(N, C, S, z_data, dz_data, weight_data, bias_data, edz_data, eydz_data, eps, stream); 86 | } 87 | 88 | extern "C" int bn_backard_cuda(const THCudaTensor *dz, const THCudaTensor *z, const THCudaTensor *var, 89 | const THCudaTensor *weight, const THCudaTensor *bias, const THCudaTensor *edz, 90 | const THCudaTensor *eydz, THCudaTensor *dx, THCudaTensor *dweight, 91 | THCudaTensor *dbias, float eps) { 92 | cudaStream_t stream = THCState_getCurrentStream(state); 93 | 94 | int S, N, C; 95 | get_sizes(dz, &N, &C, &S); 96 | 97 | // Get pointers 98 | const float *dz_data = THCudaTensor_data(state, dz); 99 | const float *z_data = THCudaTensor_data(state, z); 100 | const float *var_data = THCudaTensor_data(state, var); 101 | const float *weight_data = THCudaTensor_nDimension(state, weight) != 0 ? THCudaTensor_data(state, weight) : 0; 102 | const float *bias_data = THCudaTensor_nDimension(state, bias) != 0 ? THCudaTensor_data(state, bias) : 0; 103 | const float *edz_data = THCudaTensor_data(state, edz); 104 | const float *eydz_data = THCudaTensor_data(state, eydz); 105 | float *dx_data = THCudaTensor_nDimension(state, dx) != 0 ? THCudaTensor_data(state, dx) : 0; 106 | float *dweight_data = THCudaTensor_nDimension(state, dweight) != 0 ? THCudaTensor_data(state, dweight) : 0; 107 | float *dbias_data = THCudaTensor_nDimension(state, dbias) != 0 ? 
THCudaTensor_data(state, dbias) : 0; 108 | 109 | return _bn_backward_cuda(N, C, S, dz_data, z_data, var_data, weight_data, bias_data, edz_data, eydz_data, dx_data, 110 | dweight_data, dbias_data, eps, stream); 111 | } 112 | 113 | extern "C" int leaky_relu_cuda(THCudaTensor *x, float slope) { 114 | cudaStream_t stream = THCState_getCurrentStream(state); 115 | 116 | int N = THCudaTensor_nElement(state, x); 117 | 118 | // Get pointers 119 | float *x_data = THCudaTensor_data(state, x); 120 | 121 | return _leaky_relu_cuda(N, x_data, slope, stream); 122 | } 123 | 124 | extern "C" int leaky_relu_backward_cuda(const THCudaTensor *x, THCudaTensor *dx, float slope) { 125 | cudaStream_t stream = THCState_getCurrentStream(state); 126 | 127 | int N = THCudaTensor_nElement(state, x); 128 | 129 | // Get pointers 130 | const float *x_data = THCudaTensor_data(state, x); 131 | float *dx_data = THCudaTensor_data(state, dx); 132 | 133 | return _leaky_relu_backward_cuda(N, x_data, dx_data, slope, stream); 134 | } 135 | 136 | extern "C" int elu_cuda(THCudaTensor *x) { 137 | cudaStream_t stream = THCState_getCurrentStream(state); 138 | 139 | int N = THCudaTensor_nElement(state, x); 140 | 141 | // Get pointers 142 | float *x_data = THCudaTensor_data(state, x); 143 | 144 | return _elu_cuda(N, x_data, stream); 145 | } 146 | 147 | extern "C" int elu_backward_cuda(const THCudaTensor *x, THCudaTensor *dx) { 148 | cudaStream_t stream = THCState_getCurrentStream(state); 149 | 150 | int N = THCudaTensor_nElement(state, x); 151 | 152 | // Get pointers 153 | const float *x_data = THCudaTensor_data(state, x); 154 | float *dx_data = THCudaTensor_data(state, dx); 155 | 156 | return _elu_backward_cuda(N, x_data, dx_data, stream); 157 | } 158 | 159 | extern "C" int elu_inv_cuda(THCudaTensor *x) { 160 | cudaStream_t stream = THCState_getCurrentStream(state); 161 | 162 | int N = THCudaTensor_nElement(state, x); 163 | 164 | // Get pointers 165 | float *x_data = THCudaTensor_data(state, x); 166 | 167 | return _elu_inv_cuda(N, x_data, stream); 168 | } 169 |
-------------------------------------------------------------------------------- /network/libs/inplace_abn/src/lib_cffi.h: -------------------------------------------------------------------------------- 1 | int bn_mean_var_cuda(const THCudaTensor *x, THCudaTensor *mean, THCudaTensor *var); 2 | int bn_forward_cuda(const THCudaTensor *x, const THCudaTensor *mean, const THCudaTensor *var, 3 | const THCudaTensor *weight, const THCudaTensor *bias, THCudaTensor *y, THCudaTensor *z, 4 | float eps); 5 | int bn_edz_eydz_cuda(const THCudaTensor *z, const THCudaTensor *dz, const THCudaTensor *weight, 6 | const THCudaTensor *bias, THCudaTensor *edz, THCudaTensor *eydz, float eps); 7 | int bn_backard_cuda(const THCudaTensor *dz, const THCudaTensor *z, const THCudaTensor *var, 8 | const THCudaTensor *weight, const THCudaTensor *bias, const THCudaTensor *edz, 9 | const THCudaTensor *eydz, THCudaTensor *dx, THCudaTensor *dweight, THCudaTensor *dbias, 10 | float eps); 11 | int leaky_relu_cuda(THCudaTensor *x, float slope); 12 | int leaky_relu_backward_cuda(const THCudaTensor *x, THCudaTensor *dx, float slope); 13 | int elu_cuda(THCudaTensor *x); 14 | int elu_backward_cuda(const THCudaTensor *x, THCudaTensor *dx); 15 | int elu_inv_cuda(THCudaTensor *x);
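To make the (N, C, S) flattening used by get_sizes concrete, here is a minimal Python sketch of the same bookkeeping (illustration only, not a file in this repository): dim 0 is the batch, dim 1 the channels, and every remaining dimension is folded into a single spatial extent S, which is how the kernels above treat 2-D, 3-D, or higher-dimensional feature maps uniformly.

import torch

def get_sizes(t: torch.Tensor):
    # Python analogue of get_sizes() in lib_cffi.cpp:
    # N = batch, C = channels, S = product of all remaining dims.
    n, c = t.size(0), t.size(1)
    s = 1
    for d in t.shape[2:]:  # e.g. H and W for a 4-D feature map
        s *= d
    return n, c, s

x = torch.randn(2, 64, 228, 304)
print(get_sizes(x))  # (2, 64, 69312)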
-------------------------------------------------------------------------------- /network/libs/post_process/CSPN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Feb 4 15:37:41 2018 4 | @author: Xinjing Cheng 5 | @email : chengxinjing@baidu.com 6 | """ 7 | import torch.nn as nn 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | 14 | class AffinityPropagate(nn.Module): 15 | 16 | def __init__(self, spn=False): 17 | super(AffinityPropagate, self).__init__() 18 | self.spn = spn 19 | 20 | def forward(self, guidance, blur_depth, sparse_depth): 21 | 22 | # normalize features 23 | gate1_w1_cmb = torch.abs(guidance.narrow(1, 0, 1)) 24 | gate2_w1_cmb = torch.abs(guidance.narrow(1, 1, 1)) 25 | gate3_w1_cmb = torch.abs(guidance.narrow(1, 2, 1)) 26 | gate4_w1_cmb = torch.abs(guidance.narrow(1, 3, 1)) 27 | gate5_w1_cmb = torch.abs(guidance.narrow(1, 4, 1)) 28 | gate6_w1_cmb = torch.abs(guidance.narrow(1, 5, 1)) 29 | gate7_w1_cmb = torch.abs(guidance.narrow(1, 6, 1)) 30 | gate8_w1_cmb = torch.abs(guidance.narrow(1, 7, 1)) 31 | 32 | sparse_mask = sparse_depth.sign() 33 | 34 | result_depth = (1 - sparse_mask) * blur_depth.clone() + sparse_mask * sparse_depth 35 | 36 | for i in range(16): 37 | # one propagation 38 | spn_kernel = 3 39 | elewise_max_gate1 = self.eight_way_propagation(gate1_w1_cmb, result_depth, spn_kernel) 40 | elewise_max_gate2 = self.eight_way_propagation(gate2_w1_cmb, result_depth, spn_kernel) 41 | elewise_max_gate3 = self.eight_way_propagation(gate3_w1_cmb, result_depth, spn_kernel) 42 | elewise_max_gate4 = self.eight_way_propagation(gate4_w1_cmb, result_depth, spn_kernel) 43 | elewise_max_gate5 = self.eight_way_propagation(gate5_w1_cmb, result_depth, spn_kernel) 44 | elewise_max_gate6 = self.eight_way_propagation(gate6_w1_cmb, result_depth, spn_kernel) 45 | elewise_max_gate7 = self.eight_way_propagation(gate7_w1_cmb, result_depth, spn_kernel) 46 | elewise_max_gate8 = self.eight_way_propagation(gate8_w1_cmb, result_depth, spn_kernel) 47 | 48 | result_depth = self.max_of_8_tensor(elewise_max_gate1, elewise_max_gate2, elewise_max_gate3, 49 | elewise_max_gate4, \ 50 | elewise_max_gate5, elewise_max_gate6, elewise_max_gate7, 51 | elewise_max_gate8) 52 | 53 | result_depth = (1 - sparse_mask) * result_depth.clone() + sparse_mask * sparse_depth 54 | 55 | return result_depth 56 | 57 | def eight_way_propagation_old(self, weight_matrix, blur_matrix, kernel): 58 | [batch_size, channels, height, width] = weight_matrix.size() 59 | self.avg_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 60 | padding=(kernel - 1) // 2, 61 | bias=False) 62 | weight = torch.ones(1, 1, kernel, kernel).cuda() 63 | weight[0, 0, (kernel - 1) // 2, (kernel - 1) // 2] = 0 # zero out the kernel's center element 64 | self.avg_conv.weight = nn.Parameter(weight) 65 | for param in self.avg_conv.parameters(): 66 | param.requires_grad = False 67 | 68 | self.sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 69 | padding=(kernel - 1) // 2, 70 | bias=False) 71 | sum_weight = torch.ones(1, 1, kernel, kernel).cuda() 72 | self.sum_conv.weight = nn.Parameter(sum_weight) 73 | for param in self.sum_conv.parameters(): 74 | param.requires_grad = False 75 | weight_sum = self.sum_conv(weight_matrix) 76 | avg_sum = self.avg_conv((weight_matrix * blur_matrix)) 77 | # combine the normalized center term with the normalized neighbor average 78 | out = (torch.div(weight_matrix, weight_sum)) * blur_matrix + torch.div(avg_sum, weight_sum) 79 | return out 80 | 81 | def eight_way_propagation(self, weight_matrix, blur_matrix, kernel): 82 | [batch_size, channels, height, width] = weight_matrix.size() 83 | self.avg_conv = nn.Conv2d(in_channels=1,
out_channels=1, kernel_size=kernel, stride=1, 84 | padding=(kernel - 1) // 2, 85 | bias=False) 86 | weight = torch.ones(1, 1, kernel, kernel).cuda() 87 | 88 | self.avg_conv.weight = nn.Parameter(weight) 89 | for param in self.avg_conv.parameters(): 90 | param.requires_grad = False 91 | 92 | self.sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 93 | padding=(kernel - 1) // 2, 94 | bias=False) 95 | sum_weight = torch.ones(1, 1, kernel, kernel).cuda() 96 | self.sum_conv.weight = nn.Parameter(sum_weight) 97 | for param in self.sum_conv.parameters(): 98 | param.requires_grad = False 99 | weight_sum = self.sum_conv(weight_matrix) 100 | avg_sum = self.avg_conv((weight_matrix * blur_matrix)) 101 | 102 | out = torch.div(avg_sum, weight_sum) 103 | return out 104 | 105 | def normalize_gate(self, guidance): 106 | gate1_x1_g1 = guidance.narrow(1, 0, 1) 107 | gate1_x1_g2 = guidance.narrow(1, 1, 1) 108 | gate1_x1_g1_abs = torch.abs(gate1_x1_g1) 109 | gate1_x1_g2_abs = torch.abs(gate1_x1_g2) 110 | elesum_gate1_x1 = torch.add(gate1_x1_g1_abs, gate1_x1_g2_abs) 111 | gate1_x1_g1_cmb = torch.div(gate1_x1_g1, elesum_gate1_x1) 112 | gate1_x1_g2_cmb = torch.div(gate1_x1_g2, elesum_gate1_x1) 113 | return gate1_x1_g1_cmb, gate1_x1_g2_cmb 114 | 115 | def max_of_4_tensor(self, element1, element2, element3, element4): 116 | max_element1_2 = torch.max(element1, element2) 117 | max_element3_4 = torch.max(element3, element4) 118 | return torch.max(max_element1_2, max_element3_4) 119 | 120 | def max_of_8_tensor(self, element1, element2, element3, element4, element5, element6, element7, element8): 121 | max_element1_2 = self.max_of_4_tensor(element1, element2, element3, element4) 122 | max_element3_4 = self.max_of_4_tensor(element5, element6, element7, element8) 123 | return torch.max(max_element1_2, max_element3_4) 124 | 125 | 126 | class AffinityPropagate_prediction(nn.Module): 127 | 128 | def __init__(self, spn=False): 129 | super(AffinityPropagate_prediction, self).__init__() 130 | self.spn = spn 131 | 132 | def forward(self, guidance, blur_depth): 133 | 134 | # normalize features 135 | gate1_w1_cmb = torch.abs(guidance.narrow(1, 0, 1)) 136 | gate2_w1_cmb = torch.abs(guidance.narrow(1, 1, 1)) 137 | gate3_w1_cmb = torch.abs(guidance.narrow(1, 2, 1)) 138 | gate4_w1_cmb = torch.abs(guidance.narrow(1, 3, 1)) 139 | gate5_w1_cmb = torch.abs(guidance.narrow(1, 4, 1)) 140 | gate6_w1_cmb = torch.abs(guidance.narrow(1, 5, 1)) 141 | gate7_w1_cmb = torch.abs(guidance.narrow(1, 6, 1)) 142 | gate8_w1_cmb = torch.abs(guidance.narrow(1, 7, 1)) 143 | 144 | result_depth = blur_depth 145 | 146 | for i in range(16): 147 | # one propagation 148 | spn_kernel = 3 149 | elewise_max_gate1 = self.eight_way_propagation(gate1_w1_cmb, result_depth, spn_kernel) 150 | elewise_max_gate2 = self.eight_way_propagation(gate2_w1_cmb, result_depth, spn_kernel) 151 | elewise_max_gate3 = self.eight_way_propagation(gate3_w1_cmb, result_depth, spn_kernel) 152 | elewise_max_gate4 = self.eight_way_propagation(gate4_w1_cmb, result_depth, spn_kernel) 153 | elewise_max_gate5 = self.eight_way_propagation(gate5_w1_cmb, result_depth, spn_kernel) 154 | elewise_max_gate6 = self.eight_way_propagation(gate6_w1_cmb, result_depth, spn_kernel) 155 | elewise_max_gate7 = self.eight_way_propagation(gate7_w1_cmb, result_depth, spn_kernel) 156 | elewise_max_gate8 = self.eight_way_propagation(gate8_w1_cmb, result_depth, spn_kernel) 157 | 158 | result_depth = self.max_of_8_tensor(elewise_max_gate1, elewise_max_gate2, elewise_max_gate3, 159 | 
elewise_max_gate4, \ 160 | elewise_max_gate5, elewise_max_gate6, elewise_max_gate7, 161 | elewise_max_gate8) 162 | 163 | return result_depth 164 | 165 | def eight_way_propagation_old(self, weight_matrix, blur_matrix, kernel): 166 | [batch_size, channels, height, width] = weight_matrix.size() 167 | self.avg_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 168 | padding=(kernel - 1) // 2, 169 | bias=False) 170 | weight = torch.ones(1, 1, kernel, kernel).cuda() 171 | weight[0, 0, (kernel - 1) // 2, (kernel - 1) // 2] = 0 # zero out the kernel's center element 172 | self.avg_conv.weight = nn.Parameter(weight) 173 | for param in self.avg_conv.parameters(): 174 | param.requires_grad = False 175 | 176 | self.sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 177 | padding=(kernel - 1) // 2, 178 | bias=False) 179 | sum_weight = torch.ones(1, 1, kernel, kernel).cuda() 180 | self.sum_conv.weight = nn.Parameter(sum_weight) 181 | for param in self.sum_conv.parameters(): 182 | param.requires_grad = False 183 | weight_sum = self.sum_conv(weight_matrix) 184 | avg_sum = self.avg_conv((weight_matrix * blur_matrix)) 185 | # combine the normalized center term with the normalized neighbor average 186 | out = (torch.div(weight_matrix, weight_sum)) * blur_matrix + torch.div(avg_sum, weight_sum) 187 | return out 188 | 189 | def eight_way_propagation(self, weight_matrix, blur_matrix, kernel): 190 | [batch_size, channels, height, width] = weight_matrix.size() 191 | self.avg_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 192 | padding=(kernel - 1) // 2, 193 | bias=False) 194 | weight = torch.ones(1, 1, kernel, kernel).cuda() 195 | 196 | self.avg_conv.weight = nn.Parameter(weight) 197 | for param in self.avg_conv.parameters(): 198 | param.requires_grad = False 199 | 200 | self.sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 201 | padding=(kernel - 1) // 2, 202 | bias=False) 203 | sum_weight = torch.ones(1, 1, kernel, kernel).cuda() 204 | self.sum_conv.weight = nn.Parameter(sum_weight) 205 | for param in self.sum_conv.parameters(): 206 | param.requires_grad = False 207 | weight_sum = self.sum_conv(weight_matrix) 208 | avg_sum = self.avg_conv((weight_matrix * blur_matrix)) 209 | 210 | out = torch.div(avg_sum, weight_sum) 211 | return out 212 | 213 | def normalize_gate(self, guidance): 214 | gate1_x1_g1 = guidance.narrow(1, 0, 1) 215 | gate1_x1_g2 = guidance.narrow(1, 1, 1) 216 | gate1_x1_g1_abs = torch.abs(gate1_x1_g1) 217 | gate1_x1_g2_abs = torch.abs(gate1_x1_g2) 218 | elesum_gate1_x1 = torch.add(gate1_x1_g1_abs, gate1_x1_g2_abs) 219 | gate1_x1_g1_cmb = torch.div(gate1_x1_g1, elesum_gate1_x1) 220 | gate1_x1_g2_cmb = torch.div(gate1_x1_g2, elesum_gate1_x1) 221 | return gate1_x1_g1_cmb, gate1_x1_g2_cmb 222 | 223 | def max_of_4_tensor(self, element1, element2, element3, element4): 224 | max_element1_2 = torch.max(element1, element2) 225 | max_element3_4 = torch.max(element3, element4) 226 | return torch.max(max_element1_2, max_element3_4) 227 | 228 | def max_of_8_tensor(self, element1, element2, element3, element4, element5, element6, element7, element8): 229 | max_element1_2 = self.max_of_4_tensor(element1, element2, element3, element4) 230 | max_element3_4 = self.max_of_4_tensor(element5, element6, element7, element8) 231 | return torch.max(max_element1_2, max_element3_4) 232 | 233 | 234 | def eight_way_propagation(weight_matrix, blur_matrix, kernel): 235 | [batch_size, channels, height, width] = weight_matrix.size() 236 | avg_conv = nn.Conv2d(in_channels=1, out_channels=1,
kernel_size=kernel, stride=1, 237 | padding=(kernel - 1) // 2, 238 | bias=False) 239 | weight = torch.ones(1, 1, kernel, kernel) 240 | weight[0, 0, (kernel - 1) // 2, (kernel - 1) // 2] = 0 # zero out the kernel's center element 241 | avg_conv.weight = nn.Parameter(weight) 242 | for param in avg_conv.parameters(): 243 | param.requires_grad = False 244 | 245 | sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 246 | padding=(kernel - 1) // 2, 247 | bias=False) 248 | sum_weight = torch.ones(1, 1, kernel, kernel) 249 | sum_conv.weight = nn.Parameter(sum_weight) 250 | for param in sum_conv.parameters(): 251 | param.requires_grad = False 252 | weight_sum = sum_conv(weight_matrix) 253 | avg_sum = avg_conv((weight_matrix * blur_matrix)) 254 | # combine the normalized center term with the normalized neighbor average 255 | out = (torch.div(weight_matrix, weight_sum)) * blur_matrix + torch.div(avg_sum, weight_sum) 256 | return out 257 | 258 | 259 | def eight_way_propagation_v2(weight_matrix, blur_matrix, kernel): 260 | [batch_size, channels, height, width] = weight_matrix.size() 261 | avg_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 262 | padding=(kernel - 1) // 2, 263 | bias=False) 264 | weight = torch.ones(1, 1, kernel, kernel) 265 | 266 | avg_conv.weight = nn.Parameter(weight) 267 | for param in avg_conv.parameters(): 268 | param.requires_grad = False 269 | 270 | sum_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel, stride=1, 271 | padding=(kernel - 1) // 2, 272 | bias=False) 273 | sum_weight = torch.ones(1, 1, kernel, kernel) 274 | sum_conv.weight = nn.Parameter(sum_weight) 275 | for param in sum_conv.parameters(): 276 | param.requires_grad = False 277 | weight_sum = sum_conv(weight_matrix) 278 | avg_sum = avg_conv((weight_matrix * blur_matrix)) 279 | 280 | out = torch.div(avg_sum, weight_sum) 281 | return out 282 | 283 | 284 | if __name__ == '__main__': 285 | weight_matrix = torch.randn(1, 1, 228, 304) 286 | blur_matrix = torch.randn(1, 1, 228, 304) 287 | spn_kernel = 3 288 | 289 | d0 = eight_way_propagation(weight_matrix, blur_matrix, spn_kernel) 290 | d1 = eight_way_propagation_v2(weight_matrix, blur_matrix, spn_kernel) 291 | 292 | print(d0) 293 | print(d1) 294 | 295 | if torch.allclose(d0, d1):  # `d0 == d1` is elementwise and ambiguous as a bool; test numerical equality instead 296 | print('Yes') 297 | else: 298 | print('no') 299 |
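One propagation step of eight_way_propagation_v2 above reduces to an affinity-weighted local average, out(p) = sum_q w(q) d(q) / sum_q w(q) over the 3 x 3 window around p. A small self-contained check (illustration only, not part of the module):

import torch
import torch.nn.functional as F

# Uniform affinities over a 3x3 window: the update is then just the
# window mean, which makes the normalization easy to verify by hand.
w = torch.ones(1, 1, 3, 3)                 # affinities
d = torch.arange(9.).reshape(1, 1, 3, 3)   # depths 0..8
box = torch.ones(1, 1, 3, 3)               # summing kernel
num = F.conv2d(w * d, box, padding=1)      # sum of w*d over the window
den = F.conv2d(w, box, padding=1)          # sum of w over the window
print((num / den)[0, 0, 1, 1])             # center pixel -> (0+...+8)/9 = 4.0

The v1 variant computes the same quantity by splitting out the center term, (w(p)/sum w) * d(p) plus the center-excluded average, so with the torch.allclose fix the __main__ check above is expected to print 'Yes' up to floating-point tolerance.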
-------------------------------------------------------------------------------- /network/libs/post_process/CSPN_new.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/6/20 16:05 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | """ 9 | @author: Xinjing Cheng, https://github.com/XinJCheng/CSPN/blob/master/models/cspn.py 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class AffinityPropagate(nn.Module): 18 | 19 | def __init__(self, prop_time, prop_kernel): 20 | super(AffinityPropagate, self).__init__() 21 | self.prop_time = prop_time 22 | self.prop_kernel = prop_kernel 23 | self.in_feature = 1 24 | self.out_feature = 1 25 | 26 | def forward(self, guidance, blur_depth, sparse_depth=None): 27 | 28 | # normalize features 29 | gate1_wb_cmb = torch.abs(guidance.narrow(1, 0, self.out_feature)) 30 | gate2_wb_cmb = torch.abs(guidance.narrow(1, 1 * self.out_feature, self.out_feature)) 31 | gate3_wb_cmb = torch.abs(guidance.narrow(1, 2 * self.out_feature, self.out_feature)) 32 | gate4_wb_cmb = torch.abs(guidance.narrow(1, 3 * self.out_feature, self.out_feature)) 33 | gate5_wb_cmb = torch.abs(guidance.narrow(1, 4 * self.out_feature, self.out_feature)) 34 | gate6_wb_cmb = torch.abs(guidance.narrow(1, 5 * self.out_feature, self.out_feature)) 35 | gate7_wb_cmb = torch.abs(guidance.narrow(1, 6 * self.out_feature, self.out_feature)) 36 | gate8_wb_cmb = torch.abs(guidance.narrow(1, 7 * self.out_feature, self.out_feature)) 37 | 38 | # gate1: left_top, gate2: center_top, gate3: right_top 39 | # gate4: left_center, gate5: right_center 40 | # gate6: left_bottom, gate7: center_bottom, gate8: right_bottom 41 | 42 | # top pad 43 | left_top_pad = nn.ZeroPad2d((0, 2, 0, 2)) 44 | gate1_wb_cmb = left_top_pad(gate1_wb_cmb).unsqueeze(1) 45 | 46 | center_top_pad = nn.ZeroPad2d((1, 1, 0, 2)) 47 | gate2_wb_cmb = center_top_pad(gate2_wb_cmb).unsqueeze(1) 48 | 49 | right_top_pad = nn.ZeroPad2d((2, 0, 0, 2)) 50 | gate3_wb_cmb = right_top_pad(gate3_wb_cmb).unsqueeze(1) 51 | 52 | # center pad 53 | left_center_pad = nn.ZeroPad2d((0, 2, 1, 1)) 54 | gate4_wb_cmb = left_center_pad(gate4_wb_cmb).unsqueeze(1) 55 | 56 | right_center_pad = nn.ZeroPad2d((2, 0, 1, 1)) 57 | gate5_wb_cmb = right_center_pad(gate5_wb_cmb).unsqueeze(1) 58 | 59 | # bottom pad 60 | left_bottom_pad = nn.ZeroPad2d((0, 2, 2, 0)) 61 | gate6_wb_cmb = left_bottom_pad(gate6_wb_cmb).unsqueeze(1) 62 | 63 | center_bottom_pad = nn.ZeroPad2d((1, 1, 2, 0)) 64 | gate7_wb_cmb = center_bottom_pad(gate7_wb_cmb).unsqueeze(1) 65 | 66 | right_bottom_pad = nn.ZeroPad2d((2, 0, 2, 0)) 67 | gate8_wb_cmb = right_bottom_pad(gate8_wb_cmb).unsqueeze(1) 68 | 69 | gate_wb = torch.cat((gate1_wb_cmb, gate2_wb_cmb, gate3_wb_cmb, gate4_wb_cmb, 70 | gate5_wb_cmb, gate6_wb_cmb, gate7_wb_cmb, gate8_wb_cmb), 1) 71 | 72 | # pad input and convert to 8 channel 3D features 73 | raw_depth_input = blur_depth 74 | # blur_depth_pad = nn.ZeroPad2d((1,1,1,1)) 75 | result_depth = blur_depth 76 | 77 | if sparse_depth is not None: 78 | sparse_mask = sparse_depth.sign() 79 | 80 | for i in range(self.prop_time): 81 | 82 | # one propagation 83 | spn_kernel = self.prop_kernel 84 | result_depth = self.pad_blur_depth(result_depth) 85 | neighbor_weighted_sum = self.eight_way_propagation(gate_wb, result_depth, spn_kernel) 86 | neighbor_weighted_sum = neighbor_weighted_sum.squeeze(1) 87 | neighbor_weighted_sum = neighbor_weighted_sum[:, :, 1:-1, 1:-1] 88 | result_depth = neighbor_weighted_sum 89 | if sparse_depth is not None: 90 | result_depth = (1 - sparse_mask) * result_depth + sparse_mask * raw_depth_input 91 | 92 | return result_depth 93 | 94 | def pad_blur_depth(self, blur_depth): 95 | # top pad 96 | left_top_pad = nn.ZeroPad2d((0, 2, 0, 2)) 97 | blur_depth_1 = left_top_pad(blur_depth).unsqueeze(1) 98 | center_top_pad = nn.ZeroPad2d((1, 1, 0, 2)) 99 | blur_depth_2 = center_top_pad(blur_depth).unsqueeze(1) 100 | right_top_pad = nn.ZeroPad2d((2, 0, 0, 2)) 101 | blur_depth_3 = right_top_pad(blur_depth).unsqueeze(1) 102 | 103 | # center pad 104 | left_center_pad = nn.ZeroPad2d((0, 2, 1, 1)) 105 | blur_depth_4 = left_center_pad(blur_depth).unsqueeze(1) 106 | right_center_pad = nn.ZeroPad2d((2, 0, 1, 1)) 107 | blur_depth_5 = right_center_pad(blur_depth).unsqueeze(1) 108 | 109 | # bottom pad 110 | left_bottom_pad = nn.ZeroPad2d((0, 2, 2, 0)) 111 | blur_depth_6 = left_bottom_pad(blur_depth).unsqueeze(1) 112 | center_bottom_pad = nn.ZeroPad2d((1, 1, 2, 0)) 113 | blur_depth_7 = center_bottom_pad(blur_depth).unsqueeze(1) 114 | right_bottom_pad = nn.ZeroPad2d((2, 0, 2, 0)) 115 | blur_depth_8 = right_bottom_pad(blur_depth).unsqueeze(1) 116 | 117 | result_depth = torch.cat((blur_depth_1, blur_depth_2, blur_depth_3, blur_depth_4, 118 | blur_depth_5, blur_depth_6, blur_depth_7, blur_depth_8), 1) 119 | return result_depth 120 | 121 | def eight_way_propagation(self, weight_matrix, blur_matrix, kernel): 122 | sum_conv_weight = torch.ones((1, 8, 1, kernel//2, kernel//2), device=weight_matrix.device)  # (1, 8, 1, 1, 1) for kernel=3: conv3d sums over the 8 neighbor channels 123 | 124 | _weight_sum = F.conv3d(weight_matrix, sum_conv_weight) 125 | _total_sum = F.conv3d(weight_matrix * blur_matrix, sum_conv_weight) 126 | 127 | out = torch.div(_total_sum, _weight_sum) 128 | return out
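The asymmetric ZeroPad2d calls above are what turn a plain channel sum into an eight-neighbor gather: padding by (left, right, top, bottom) = (0, 2, 0, 2) and later cropping one pixel on every side (the [:, :, 1:-1, 1:-1] in forward) shifts the map by one pixel, so each output location reads one particular neighbor, and the eight different pads cover all eight directions. A minimal sketch of the mechanism (illustration only):

import torch
import torch.nn as nn

x = torch.arange(9.).reshape(1, 1, 3, 3)
# Pad right/bottom by 2, then crop one pixel on every side: the result
# at (a, b) is x[a + 1, b + 1] (zero where that leaves the map), i.e.
# each pixel now sees one of its diagonal neighbors.
shifted = nn.ZeroPad2d((0, 2, 0, 2))(x)[:, :, 1:-1, 1:-1]
print(x[0, 0])        # [[0,1,2],[3,4,5],[6,7,8]]
print(shifted[0, 0])  # [[4,5,0],[7,8,0],[0,0,0]]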
-------------------------------------------------------------------------------- /network/libs/post_process/CSPN_ours.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-05-19 19:41 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : CSPN_ours.py 8 | """ 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import math 16 | 17 | 18 | class AffinityPropagate(nn.Module): 19 | 20 | def __init__(self, prop_time): 21 | super(AffinityPropagate, self).__init__() 22 | self.times = prop_time 23 | 24 | def forward(self, x, guided, sparse_depth=None): 25 | """ 26 | :param x: feature maps, N, C, H, W 27 | :param guided: guidance (affinity) map, N, K^2-1, H, W, where K is the kernel size 28 | :return: propagated feature map, N, C, H, W 29 | """ 30 | 31 | B, C, H, W = guided.size() 32 | K = int(math.sqrt(C + 1)) 33 | 34 | # normalize the affinities over the channel dimension 35 | guided = F.softmax(guided, dim=1) 36 | 37 | kernel = torch.zeros(B, C + 1, H, W, device=guided.device) 38 | kernel[:, 0:C // 2, :, :] = guided[:, 0:C // 2, :, :] 39 | kernel[:, C // 2 + 1:C + 1, :, :] = guided[:, C // 2:C, :, :]  # re-insert a zero weight at the kernel center 40 | 41 | kernel = kernel.unsqueeze(dim=1).reshape(B, 1, K, K, H, W) 42 | 43 | if sparse_depth is not None: 44 | sparse_mask = sparse_depth.sign() 45 | _x = x 46 | 47 | from network.libs.base.pac import conv2d  # pixel-adaptive convolution; hoisted out of the loop 48 | for _ in range(self.times): 49 | x = conv2d(x, kernel, kernel_size=K, stride=1, padding=K // 2, dilation=1) 50 | 51 | if sparse_depth is not None: 52 | no_sparse_mask = 1 - sparse_mask 53 | x = sparse_mask * _x + no_sparse_mask * x 54 | return x 55 |
-------------------------------------------------------------------------------- /network/libs/post_process/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2019-05-19 16:39 5 | @Author : Wang Xin 6 | @Email : wangxin_buaa@163.com 7 | @File : __init__.py 8 | """ 9 |
-------------------------------------------------------------------------------- /network/unet_cspn_nyu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sat Feb 3 15:32:49 2018 3 | @author: norbot 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from network.libs.post_process import CSPN_new as post_process 10 | 11 | # memory analysis 12 | import gc 13 | 14 | 15 | """ 16 | If using multiple GPUs, replace BatchNorm with inplace ABN. 17 | Note: when training on multiple GPUs, any newly added operation must not be in-place, 18 | e.g. "+=" or a module constructed with "inplace=True".
19 | """ 20 | if torch.cuda.device_count() > 1: 21 | from network.libs.inplace_abn import InPlaceABNSync 22 | affine_par = True 23 | import functools 24 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 25 | BatchNorm2d_relu = InPlaceABNSync 26 | else: 27 | affine_par = True 28 | BatchNorm2d = nn.BatchNorm2d 29 | from network.libs.base.operation import BatchNorm2d_Relu 30 | BatchNorm2d_relu = BatchNorm2d_Relu 31 | 32 | __all__ = ['ResNet', 'resnet18', 'resnet50'] 33 | 34 | model_urls = { 35 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 36 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 37 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 38 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 39 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 40 | } 41 | 42 | model_path = { 43 | 'resnet18': 'pretrained/resnet18.pth', 44 | 'resnet50': 'pretrained/resnet50.pth' 45 | } 46 | 47 | 48 | def conv3x3(in_planes, out_planes, stride=1): 49 | """3x3 convolution with padding""" 50 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | expansion = 1 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(BasicBlock, self).__init__() 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = BatchNorm2d(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = BatchNorm2d(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | expansion = 4 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None): 90 | super(Bottleneck, self).__init__() 91 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(planes) 93 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 94 | padding=1, bias=False) 95 | self.bn2 = nn.BatchNorm2d(planes) 96 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 97 | self.bn3 = nn.BatchNorm2d(planes * 4) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | residual = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.bn3(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(x) 118 | 119 | out += residual 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | 125 | class UpProj_Block(nn.Module): 126 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 127 | super(UpProj_Block, self).__init__() 128 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 129 | self.bn1 = nn.BatchNorm2d(out_channels) 130 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 131 | self.bn2 = 
nn.BatchNorm2d(out_channels) 132 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 133 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.oheight = oheight 136 | self.owidth = owidth 137 | 138 | def _up_pooling(self, x, scale): 139 | oheight = 0 140 | owidth = 0 141 | if self.oheight == 0 and self.owidth == 0: 142 | oheight = scale * x.size(2) 143 | owidth = scale * x.size(3) 144 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 145 | else: 146 | oheight = self.oheight 147 | owidth = self.owidth 148 | x = nn.Upsample(size=(oheight, owidth), mode='nearest')(x) 149 | mask = torch.zeros_like(x) 150 | for h in range(0, oheight, 2): 151 | for w in range(0, owidth, 2): 152 | mask[:, :, h, w] = 1 153 | x = torch.mul(mask, x) 154 | return x 155 | 156 | def forward(self, x): 157 | x = self._up_pooling(x, 2) 158 | out = self.relu(self.bn1(self.conv1(x))) 159 | out = self.bn2(self.conv2(out)) 160 | short_cut = self.sc_bn1(self.sc_conv1(x)) 161 | out += short_cut 162 | out = self.relu(out) 163 | return out 164 | 165 | 166 | class Simple_Gudi_UpConv_Block(nn.Module): 167 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 168 | super(Simple_Gudi_UpConv_Block, self).__init__() 169 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 170 | self.bn1 = nn.BatchNorm2d(out_channels) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.oheight = oheight 173 | self.owidth = owidth 174 | 175 | def _up_pooling(self, x, scale): 176 | 177 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 178 | if self.oheight != 0 and self.owidth != 0: 179 | x = x.narrow(2, 0, self.oheight) 180 | x = x.narrow(3, 0, self.owidth) 181 | # x = x[:,:,0:self.oheight, 0:self.owidth].clone() 182 | mask = torch.zeros_like(x) 183 | for h in range(0, self.oheight, 2): 184 | for w in range(0, self.owidth, 2): 185 | mask[:, :, h, w] = 1 186 | x = torch.mul(mask, x) 187 | return x 188 | 189 | def forward(self, x): 190 | x = self._up_pooling(x, 2) 191 | out = self.relu(self.bn1(self.conv1(x))) 192 | return out 193 | 194 | 195 | class Simple_Gudi_UpConv_Block_Last_Layer(nn.Module): 196 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 197 | super(Simple_Gudi_UpConv_Block_Last_Layer, self).__init__() 198 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 199 | self.oheight = oheight 200 | self.owidth = owidth 201 | 202 | def _up_pooling(self, x, scale): 203 | 204 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 205 | if self.oheight != 0 and self.owidth != 0: 206 | x = x.narrow(2, 0, self.oheight) 207 | x = x.narrow(3, 0, self.owidth) 208 | mask = torch.zeros_like(x) 209 | for h in range(0, self.oheight, 2): 210 | for w in range(0, self.owidth, 2): 211 | mask[:, :, h, w] = 1 212 | x = torch.mul(mask, x) 213 | return x 214 | 215 | def forward(self, x): 216 | x = self._up_pooling(x, 2) 217 | out = self.conv1(x) 218 | return out 219 | 220 | 221 | class Gudi_UpProj_Block(nn.Module): 222 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 223 | super(Gudi_UpProj_Block, self).__init__() 224 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 225 | self.bn1 = nn.BatchNorm2d(out_channels) 226 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 227 | self.bn2 = nn.BatchNorm2d(out_channels) 228 | 
self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 229 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 230 | self.relu = nn.ReLU(inplace=True) 231 | self.oheight = oheight 232 | self.owidth = owidth 233 | 234 | def _up_pooling(self, x, scale): 235 | 236 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 237 | if self.oheight != 0 and self.owidth != 0: 238 | x = x[:, :, 0:self.oheight, 0:self.owidth] 239 | mask = torch.zeros_like(x) 240 | for h in range(0, self.oheight, 2): 241 | for w in range(0, self.owidth, 2): 242 | mask[:, :, h, w] = 1 243 | x = torch.mul(mask, x) 244 | return x 245 | 246 | def forward(self, x): 247 | x = self._up_pooling(x, 2) 248 | out = self.relu(self.bn1(self.conv1(x))) 249 | out = self.bn2(self.conv2(out)) 250 | short_cut = self.sc_bn1(self.sc_conv1(x)) 251 | out += short_cut 252 | out = self.relu(out) 253 | return out 254 | 255 | 256 | class Gudi_UpProj_Block_Cat(nn.Module): 257 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 258 | super(Gudi_UpProj_Block_Cat, self).__init__() 259 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 260 | self.bn1 = nn.BatchNorm2d(out_channels) 261 | self.conv1_1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 262 | self.bn1_1 = nn.BatchNorm2d(out_channels) 263 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 264 | self.bn2 = nn.BatchNorm2d(out_channels) 265 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 266 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 267 | self.relu = nn.ReLU(inplace=True) 268 | self.oheight = oheight 269 | self.owidth = owidth 270 | 271 | def _up_pooling(self, x, scale): 272 | 273 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 274 | if self.oheight != 0 and self.owidth != 0: 275 | x = x[:, :, 0:self.oheight, 0:self.owidth] 276 | mask = torch.zeros_like(x) 277 | for h in range(0, self.oheight, 2): 278 | for w in range(0, self.owidth, 2): 279 | mask[:, :, h, w] = 1 280 | x = torch.mul(mask, x) 281 | return x 282 | 283 | def forward(self, x, side_input): 284 | x = self._up_pooling(x, 2) 285 | out = self.relu(self.bn1(self.conv1(x))) 286 | out = torch.cat((out, side_input), 1) 287 | out = self.relu(self.bn1_1(self.conv1_1(out))) 288 | out = self.bn2(self.conv2(out)) 289 | short_cut = self.sc_bn1(self.sc_conv1(x)) 290 | out += short_cut 291 | out = self.relu(out) 292 | return out 293 | 294 | 295 | class ResNet(nn.Module): 296 | 297 | def __init__(self, block, layers, up_proj_block): 298 | self.inplanes = 64 299 | super(ResNet, self).__init__() 300 | self.conv1_1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, 301 | bias=False) 302 | self.bn1 = nn.BatchNorm2d(64) 303 | self.relu = nn.ReLU(inplace=True) 304 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 305 | self.layer1 = self._make_layer(block, 64, layers[0]) 306 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 307 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 308 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 309 | self.mid_channel = 256 * block.expansion 310 | self.conv2 = nn.Conv2d(512 * block.expansion, 512 * block.expansion, kernel_size=3, 311 | stride=1, padding=1, bias=False) 312 | self.bn2 = nn.BatchNorm2d(512 * block.expansion) 313 | self.up_proj_layer1 = self._make_up_conv_layer(up_proj_block, 
314 | self.mid_channel, 315 | int(self.mid_channel / 2)) 316 | self.up_proj_layer2 = self._make_up_conv_layer(up_proj_block, 317 | int(self.mid_channel / 2), 318 | int(self.mid_channel / 4)) 319 | self.up_proj_layer3 = self._make_up_conv_layer(up_proj_block, 320 | int(self.mid_channel / 4), 321 | int(self.mid_channel / 8)) 322 | self.up_proj_layer4 = self._make_up_conv_layer(up_proj_block, 323 | int(self.mid_channel / 8), 324 | int(self.mid_channel / 16)) 325 | self.conv3 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1, bias=False) 326 | self.post_process_layer = self._make_post_process_layer() 327 | self.gud_up_proj_layer1 = self._make_gud_up_conv_layer(Gudi_UpProj_Block, 2048, 1024, 15, 19) 328 | self.gud_up_proj_layer2 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 1024, 512, 29, 38) 329 | self.gud_up_proj_layer3 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 512, 256, 57, 76) 330 | self.gud_up_proj_layer4 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 256, 64, 114, 152) 331 | self.gud_up_proj_layer5 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 1, 228, 304) 332 | self.gud_up_proj_layer6 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 12, 228, 304) 333 | 334 | def _make_layer(self, block, planes, blocks, stride=1): 335 | downsample = None 336 | if stride != 1 or self.inplanes != planes * block.expansion: 337 | downsample = nn.Sequential( 338 | nn.Conv2d(self.inplanes, planes * block.expansion, 339 | kernel_size=1, stride=stride, bias=False), 340 | nn.BatchNorm2d(planes * block.expansion), 341 | ) 342 | 343 | layers = [] 344 | layers.append(block(self.inplanes, planes, stride, downsample)) 345 | self.inplanes = planes * block.expansion 346 | for i in range(1, blocks): 347 | layers.append(block(self.inplanes, planes)) 348 | 349 | return nn.Sequential(*layers) 350 | 351 | def _make_up_conv_layer(self, up_proj_block, in_channels, out_channels): 352 | return up_proj_block(in_channels, out_channels) 353 | 354 | def _make_gud_up_conv_layer(self, up_proj_block, in_channels, out_channels, oheight, owidth): 355 | return up_proj_block(in_channels, out_channels, oheight, owidth) 356 | 357 | def _make_post_process_layer(self): 358 | return post_process.AffinityPropagate(24, 3) 359 | 360 | def forward(self, x): 361 | [batch_size, channel, height, width] = x.size() 362 | sparse_depth = x.narrow(1, 3, 1).clone() 363 | x = self.conv1_1(x) 364 | skip4 = x 365 | 366 | x = self.bn1(x) 367 | x = self.relu(x) 368 | x = self.maxpool(x) 369 | x = self.layer1(x) 370 | skip3 = x 371 | 372 | x = self.layer2(x) 373 | skip2 = x 374 | 375 | x = self.layer3(x) 376 | x = self.layer4(x) 377 | x = self.bn2(self.conv2(x)) 378 | x = self.gud_up_proj_layer1(x) 379 | x = self.gud_up_proj_layer2(x, skip2) 380 | x = self.gud_up_proj_layer3(x, skip3) 381 | x = self.gud_up_proj_layer4(x, skip4) 382 | 383 | guidance = self.gud_up_proj_layer6(x) 384 | x = self.gud_up_proj_layer5(x) 385 | 386 | x = self.post_process_layer(guidance, x, sparse_depth) 387 | return x 388 | 389 | 390 | def resnet18(pretrained=False, **kwargs): 391 | """Constructs a ResNet-18 model. 
392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | """ 395 | model = ResNet(BasicBlock, [2, 2, 2, 2], UpProj_Block, **kwargs) 396 | if pretrained: 397 | print('==> Loading pretrained model...') 398 | pretrained_dict = torch.load(model_path['resnet18']) 399 | import network.utils as utils 400 | model.load_state_dict(utils.load_model_dict(model, pretrained_dict)) 401 | return model 402 | 403 | 404 | def resnet50(pretrained=False, **kwargs): 405 | """Constructs a ResNet-50 model. 406 | Args: 407 | pretrained (bool): If True, returns a model pre-trained on ImageNet 408 | """ 409 | model = ResNet(Bottleneck, [3, 4, 6, 3], UpProj_Block, **kwargs) 410 | if pretrained: 411 | print('==> Loading pretrained model...') 412 | pretrained_dict = torch.load(model_path['resnet50']) 413 | import network.utils as utils 414 | model.load_state_dict(utils.load_model_dict(model, pretrained_dict)) 415 | return model 416 | 417 | 418 | if __name__ == '__main__': 419 | import torchsummary 420 | 421 | model = resnet50(pretrained=False) 422 | torchsummary.summary(model, input_size=(4, 228, 304)) 423 |
-------------------------------------------------------------------------------- /network/unet_ours.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/6/30 22:38 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | """ 9 | Modified by WangXin 10 | Updated on 16:58:37 19/05/19 11 | Replaced upsample with conv_transpose2d to implement up_pooling. 12 | """ 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from network.libs.post_process import CSPN_ours as post_process 17 | 18 | """ 19 | If using multiple GPUs, replace BatchNorm with inplace ABN. 20 | Note: when training on multiple GPUs, any newly added operation must not be in-place, 21 | e.g. "+=" or a module constructed with "inplace=True".
22 | """ 23 | if torch.cuda.device_count() > 1: 24 | from network.libs.inplace_abn import InPlaceABNSync 25 | affine_par = True 26 | import functools 27 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 28 | BatchNorm2d_relu = InPlaceABNSync 29 | else: 30 | affine_par = True 31 | BatchNorm2d = nn.BatchNorm2d 32 | from network.libs.base.operation import BatchNorm2d_Relu 33 | BatchNorm2d_relu = BatchNorm2d_Relu 34 | 35 | # memory analyze 36 | 37 | __all__ = ['ResNet', 'resnet18', 'resnet50'] 38 | 39 | model_urls = { 40 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 41 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 42 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 43 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 44 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 45 | } 46 | 47 | model_path = { 48 | 'resnet18': 'pretrained/resnet18.pth', 49 | 'resnet50': 'pretrained/resnet50.pth' 50 | } 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | expansion = 1 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(BasicBlock, self).__init__() 64 | self.conv1 = conv3x3(inplanes, planes, stride) 65 | self.bn1 = BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = conv3x3(planes, planes) 68 | self.bn2 = BatchNorm2d(planes) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class Bottleneck(nn.Module): 92 | expansion = 4 93 | 94 | def __init__(self, inplanes, planes, stride=1, downsample=None): 95 | super(Bottleneck, self).__init__() 96 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 97 | self.bn1 = BatchNorm2d(planes) 98 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 99 | padding=1, bias=False) 100 | self.bn2 = BatchNorm2d(planes) 101 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 102 | self.bn3 = BatchNorm2d(planes * 4) 103 | self.relu = nn.ReLU(inplace=False) 104 | self.relu_inplace = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | residual = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | residual = self.downsample(x) 124 | 125 | out = out + residual 126 | out = self.relu_inplace(out) 127 | 128 | return out 129 | 130 | 131 | class MyBlock(nn.Module): 132 | def __init__(self, oheight=0, owidth=0): 133 | super(MyBlock, self).__init__() 134 | 135 | self.oheight = oheight 136 | self.owidth = owidth 137 | 138 | def _up_pooling(self, x, scale): 139 | N, C, H, W = x.size() 140 | 141 | num_channels = C 142 | weights = torch.zeros(num_channels, 1, scale, scale, 
device=x.device) 143 | weights[:, :, 0, 0] = 1 144 | y = F.conv_transpose2d(x, weights, stride=scale, groups=num_channels) 145 | del weights 146 | 147 | if self.oheight != scale * H or self.owidth != scale * W: 148 | y = y[:, :, 0:self.oheight, 0:self.owidth] 149 | 150 | return y 151 | 152 | def init_weights(self): 153 | for m in self.modules(): 154 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 155 | nn.init.kaiming_normal_(m.weight.data) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias.data, 0) 158 | 159 | 160 | class UpProj_Block(MyBlock): 161 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 162 | super(UpProj_Block, self).__init__(oheight, owidth) 163 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 164 | self.bn1 = BatchNorm2d(out_channels) 165 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 166 | self.bn2 = BatchNorm2d(out_channels) 167 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 168 | self.sc_bn1 = BatchNorm2d(out_channels) 169 | self.relu = nn.ReLU(inplace=False) 170 | 171 | def forward(self, x): 172 | x = self._up_pooling(x, 2) 173 | out = self.relu(self.bn1(self.conv1(x))) 174 | out = self.bn2(self.conv2(out)) 175 | short_cut = self.sc_bn1(self.sc_conv1(x)) 176 | out = out + short_cut 177 | out = self.relu(out) 178 | return out 179 | 180 | 181 | class Simple_Gudi_UpConv_Block(MyBlock): 182 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 183 | super(Simple_Gudi_UpConv_Block, self).__init__(oheight, owidth) 184 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 185 | self.bn1 = BatchNorm2d(out_channels) 186 | self.relu = nn.ReLU(inplace=False) 187 | 188 | def forward(self, x): 189 | x = self._up_pooling(x, 2) 190 | out = self.relu(self.bn1(self.conv1(x))) 191 | return out 192 | 193 | 194 | class Simple_Gudi_UpConv_Block_Last_Layer(MyBlock): 195 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 196 | super(Simple_Gudi_UpConv_Block_Last_Layer, self).__init__(oheight, owidth) 197 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 198 | 199 | def forward(self, x): 200 | x = self._up_pooling(x, 2) 201 | out = self.conv1(x) 202 | return out 203 | 204 | 205 | class Gudi_UpProj_Block(MyBlock): 206 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 207 | super(Gudi_UpProj_Block, self).__init__(oheight, owidth) 208 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 209 | self.bn1 = BatchNorm2d(out_channels) 210 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 211 | self.bn2 = BatchNorm2d(out_channels) 212 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 213 | self.sc_bn1 = BatchNorm2d(out_channels) 214 | self.relu = nn.ReLU(inplace=False) 215 | 216 | def forward(self, x): 217 | x = self._up_pooling(x, 2) 218 | out = self.relu(self.bn1(self.conv1(x))) 219 | out = self.bn2(self.conv2(out)) 220 | short_cut = self.sc_bn1(self.sc_conv1(x)) 221 | out = out + short_cut 222 | out = self.relu(out) 223 | return out 224 | 225 | 226 | class Gudi_UpProj_Block_Cat(MyBlock): 227 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 228 | super(Gudi_UpProj_Block_Cat, 
self).__init__(oheight, owidth) 229 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 230 | self.bn1 = BatchNorm2d(out_channels) 231 | self.conv1_1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 232 | self.bn1_1 = BatchNorm2d(out_channels) 233 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 234 | self.bn2 = BatchNorm2d(out_channels) 235 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 236 | self.sc_bn1 = BatchNorm2d(out_channels) 237 | self.relu = nn.ReLU(inplace=False) 238 | 239 | def forward(self, x, side_input): 240 | x = self._up_pooling(x, 2) 241 | out = self.relu(self.bn1(self.conv1(x))) 242 | out = torch.cat((out, side_input), 1) 243 | out = self.relu(self.bn1_1(self.conv1_1(out))) 244 | out = self.bn2(self.conv2(out)) 245 | short_cut = self.sc_bn1(self.sc_conv1(x)) 246 | out = out + short_cut 247 | out = self.relu(out) 248 | return out 249 | 250 | 251 | class ResNet(nn.Module): 252 | 253 | def __init__(self, block, layers, up_proj_block): 254 | self.inplanes = 64 255 | super(ResNet, self).__init__() 256 | self.conv1_1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, 257 | bias=False) 258 | self.bn1 = BatchNorm2d(64) 259 | self.relu = nn.ReLU(inplace=False) 260 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 261 | self.layer1 = self._make_layer(block, 64, layers[0]) 262 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 263 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 264 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 265 | 266 | self.mid_channel = 256 * block.expansion 267 | self.conv2 = nn.Conv2d(512 * block.expansion, 512 * block.expansion, kernel_size=3, stride=1, padding=1, 268 | bias=False) 269 | self.bn2 = BatchNorm2d(512 * block.expansion) 270 | self.conv3 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1, bias=False) 271 | 272 | self.post_process_layer = self._make_post_process_layer() 273 | 274 | self.gud_up_proj_layer1 = self._make_gud_up_conv_layer(Gudi_UpProj_Block, 2048, 1024, 15, 19) 275 | self.gud_up_proj_layer2 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 1024, 512, 29, 38) 276 | self.gud_up_proj_layer3 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 512, 256, 57, 76) 277 | self.gud_up_proj_layer4 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 256, 64, 114, 152) 278 | self.gud_up_proj_layer5 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 1, 228, 304) 279 | self.gud_up_proj_layer6 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 8, 228, 304) 280 | 281 | def _make_layer(self, block, planes, blocks, stride=1): 282 | downsample = None 283 | if stride != 1 or self.inplanes != planes * block.expansion: 284 | downsample = nn.Sequential( 285 | nn.Conv2d(self.inplanes, planes * block.expansion, 286 | kernel_size=1, stride=stride, bias=False), 287 | BatchNorm2d(planes * block.expansion, affine=affine_par), 288 | ) 289 | 290 | layers = [] 291 | layers.append(block(self.inplanes, planes, stride, downsample)) 292 | self.inplanes = planes * block.expansion 293 | for i in range(1, blocks): 294 | layers.append(block(self.inplanes, planes)) 295 | 296 | return nn.Sequential(*layers) 297 | 298 | def _make_up_conv_layer(self, up_proj_block, in_channels, out_channels): 299 | return up_proj_block(in_channels, out_channels) 300 | 301 | def 
_make_gud_up_conv_layer(self, up_proj_block, in_channels, out_channels, oheight, owidth):
302 |         return up_proj_block(in_channels, out_channels, oheight, owidth)
303 | 
304 |     def _make_post_process_layer(self):
305 |         return post_process.AffinityPropagate(prop_time=24)
306 | 
307 |     def forward(self, x):
308 |         sparse_depth = x.narrow(1, 3, 1).clone()  # get sparse depth (4th input channel)
309 |         x = self.conv1_1(x)
310 |         skip4 = x
311 | 
312 |         x = self.bn1(x)
313 |         x = self.relu(x)
314 |         x = self.maxpool(x)
315 |         x = self.layer1(x)
316 |         skip3 = x
317 | 
318 |         x = self.layer2(x)
319 |         skip2 = x
320 | 
321 |         x = self.layer3(x)
322 |         x = self.layer4(x)
323 | 
324 |         x = self.bn2(self.conv2(x))
325 |         x = self.gud_up_proj_layer1(x)
326 | 
327 |         x = self.gud_up_proj_layer2(x, skip2)
328 |         x = self.gud_up_proj_layer3(x, skip3)
329 |         x = self.gud_up_proj_layer4(x, skip4)
330 | 
331 |         blur_depth = self.gud_up_proj_layer5(x)
332 |         guidance = self.gud_up_proj_layer6(x)
333 |         x = self.post_process_layer(blur_depth, guidance, sparse_depth=sparse_depth)
334 | 
335 |         return [x, guidance]
336 | 
337 | 
338 | def resnet18(pretrained=False, **kwargs):
339 |     """Constructs a ResNet-18 model.
340 |     Args:
341 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
342 |     """
343 |     model = ResNet(BasicBlock, [2, 2, 2, 2], UpProj_Block, **kwargs)
344 |     if pretrained:
345 |         print('==> Loading pretrained model...')
346 |         pretrained_dict = torch.load(model_path['resnet18'])
347 |         import network.utils as utils
348 |         model.load_state_dict(utils.load_model_dict(model, pretrained_dict))
349 |     return model
350 | 
351 | 
352 | def resnet50(pretrained=False, **kwargs):
353 |     """Constructs a ResNet-50 model.
354 |     Args:
355 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
356 |     """
357 |     model = ResNet(Bottleneck, [3, 4, 6, 3], UpProj_Block, **kwargs)
358 |     if pretrained:
359 |         print('==> Loading pretrained model...')
360 |         pretrained_dict = torch.load(model_path['resnet50'])
361 |         import network.utils as utils
362 |         model.load_state_dict(utils.load_model_dict(model, pretrained_dict))
363 |     return model
364 | 
365 | 
366 | if __name__ == '__main__':
367 |     img = torch.randn(1, 4, 228, 304)
368 |     model = resnet50(pretrained=False)
369 | 
370 |     y = model(img)
371 |     # import torchsummary
372 |     # torchsummary.summary(model, input_size=(4, 228, 304))
373 | 
374 |     # import time
375 |     # end = time.time()
376 |     # pred = model(img)
377 |     # end = time.time() - end
378 |     # print('pac implementation cost time = ', end)
379 | 
380 | 
381 | 
--------------------------------------------------------------------------------
/network/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Time : 2019-05-19 16:44
5 | @Author : Wang Xin
6 | @Email : wangxin_buaa@163.com
7 | @File : utils.py
8 | """
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.backends.cudnn as cudnn
13 | import string
14 | 
15 | 
16 | # Update this model's state dict with the pretrained params whose keys match.
17 | def load_model_dict(my_model, pretrained_dict):
18 |     my_model_dict = my_model.state_dict()
19 |     # 1. filter out unnecessary keys
20 |     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in my_model_dict}
21 |     # 2. overwrite entries in the existing state dict
22 |     my_model_dict.update(pretrained_dict)
23 | 
24 |     return my_model_dict
25 | 
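# Usage sketch for load_model_dict (mirrors the resnet50(pretrained=True)
# path in network/unet_ours.py, assuming the .pth file stores a plain state
# dict): keys absent from the target model are silently dropped -- e.g. the
# pretrained 3-channel 'conv1.*' entries, since the depth network names its
# 4-channel stem 'conv1_1'.
#
#     from network.unet_ours import resnet50
#     model = resnet50(pretrained=False)
#     pretrained_dict = torch.load('pretrained/resnet50.pth')
#     model.load_state_dict(load_model_dict(model, pretrained_dict))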
26 | 
27 | def update_conv_spn_model(out_dict, in_dict):
28 |     in_dict = {k: v for k, v in in_dict.items() if k in out_dict}
29 |     return in_dict
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/3/2 11:16
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | 
8 | 
9 | def parse_command():
10 |     modality_names = ['rgb', 'rgbd', 'd']
11 |     from dataloaders.nyu_dataloader.dense_to_sparse import UniformSampling, SimulatedStereo
12 |     sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
13 |     scheduler_names = ['poly_lr', 'reduce_lr']
14 |     loss_type_names = ['none', 'ms', 'dsn', 'all']
15 |     upsample_types = ['dgf', 'pac', 'djif', 'none']
16 |     pretrained_choices = ['imagenet', 'vkitti']
17 |     distance_types = ['si', 'sq', 'sl']
18 | 
19 |     import argparse
20 |     parser = argparse.ArgumentParser(description='MonoDepth')
21 | 
22 |     # model parameters
23 |     parser.add_argument('--arch', default='up', type=str)
24 |     parser.add_argument('--restore', default='',
25 |                         type=str, metavar='PATH',
26 |                         help='path to the checkpoint to restore (default: none)')
27 |     parser.add_argument('--pretrained', default='imagenet', type=str, choices=pretrained_choices,
28 |                         help='pretrained model: vkitti, imagenet')
29 |     parser.add_argument('--freeze', default=True, type=bool)
30 |     parser.add_argument('--upt', default='none', choices=upsample_types,
31 |                         help="upsample type; 'none' disables upsampling.")
32 | 
33 |     # training parameters
34 |     parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size (default: 8)')
35 | 
36 |     # criterion parameters
37 |     parser.add_argument('--criterion', default='l1', type=str)
38 |     parser.add_argument('--loss_wrapper', default='none', type=str, choices=loss_type_names,
39 |                         help='loss wrapper type (default: none)')
40 |     parser.add_argument('--distance', default='si', choices=distance_types)
41 | 
42 |     # lr scheduler parameters
43 |     parser.add_argument('--scheduler', default='reduce_lr', type=str, choices=scheduler_names)
44 |     parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
45 |                         metavar='LR', help='initial learning rate (default: 0.01)')
46 | 
47 |     parser.add_argument('--factor', default=0.2, type=float, help='factor in ReduceLROnPlateau.')
48 |     parser.add_argument('--lr_patience', default=2, type=int,
49 |                         help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.')
50 |     parser.add_argument('--max_iter', default=200000, type=int, metavar='N',
51 |                         help='number of total iterations to run (default: 200000)')
52 |     parser.add_argument('--decay_iter', default=10, type=int,
53 |                         help='decay iter in PolynomialLR.')
54 |     parser.add_argument('--gamma', default=0.9, type=float, help='gamma in PolynomialLR, MultiStepLR, ExponentialLR.')
55 | 
56 |     # optimizer parameters
57 |     parser.add_argument('--opt', default='sgd', type=str, choices=['adam', 'sgd'])
58 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
59 |     parser.add_argument('--weight_decay', '--wd', default=0.0001, type=float,
60 |                         metavar='W', help='weight decay (default: 1e-4)')
61 | 
62 |     # dataset
63 |     parser.add_argument('--dataset', default='nyu', type=str,
64 |                         help='dataset used for training; kitti and nyu are available')
65 |     parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
66 |                         help='number of data loading workers (default: 0)')
67 |     parser.add_argument('--jitter', type=float, default=0.1, help='color jitter for images')
68 |     parser.add_argument('--val_selection', type=bool, default=True)
69 |     parser.add_argument('--discretization', type=int, default=1, help='discretize depth using the given value.')
70 | 
71 |     # data sample strategy
72 |     parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgbd', choices=modality_names)
73 |     parser.add_argument('-s', '--num-samples', default=500, type=int, metavar='N',
74 |                         help='number of sparse depth samples (default: 500)')
75 |     parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
76 |                         help='cut-off depth of sparsifier; a negative value means infinity (default: inf [m])')
77 |     parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name, choices=sparsifier_names,
78 |                         help='sparsifier: ' + ' | '.join(sparsifier_names) + ' (default: ' + UniformSampling.name + ')')
79 | 
80 |     # others
81 |     parser.add_argument('--manual_seed', default=1, type=int, help='Manually set random seed')
82 |     parser.add_argument('--gpu', default=None, type=str, help='if not none, use a single GPU')
83 |     parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
84 |     args = parser.parse_args()
85 | 
86 |     if args.modality == 'rgb' and args.num_samples != 0:
87 |         print("number of samples is forced to be 0 when input modality is rgb")
88 |         args.num_samples = 0
89 |     if args.modality == 'rgb' and args.max_depth != 0.0:
90 |         print("max depth is forced to be 0.0 when input modality is rgb")
91 |         args.max_depth = 0.0
92 | 
93 |     return args
94 | 
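# Example invocation using the flags defined above (values illustrative,
# assuming main.py consumes parse_command()):
#
#     python main.py --arch up -b 8 --dataset nyu -m rgbd -s 500 \
#         --criterion l1 --scheduler reduce_lr --lr 0.01 --gpu 0
#
# Caveat: argparse's type=bool does not parse strings -- bool('False') is
# truthy, so '--freeze False' still yields freeze=True. Rely on the defaults
# for --freeze and --val_selection, or switch them to string choices.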
95 | 
96 | # TODO:
97 | class Options(object):
98 | 
99 |     def __init__(self):
100 |         # model parameters
101 |         self.arch = None
102 |         self.restore = None
103 |         self.pretrained = None
104 | 
105 |         self.upt = None
106 | 
107 |         # training parameters
108 |         self.batch_size = None
109 | 
110 |         # criterion parameters
111 |         self.criterion = None
112 |         self.loss_wrapper = None
113 |         self.distance = None
114 | 
115 |         # lr scheduler parameters
116 |         self.scheduler = None
117 |         self.lr = None
118 | 
119 |         self.factor = None
120 |         self.lr_patience = None
121 |         self.max_iter = None
122 |         self.decay_iter = None
123 |         self.gamma = None
124 | 
125 |         # optimizer parameters
126 |         self.opt = None
127 |         self.momentum = None
128 |         self.weight_decay = None
129 | 
130 |         # dataset parameters
131 |         self.dataset = None
132 |         self.workers = None
133 |         self.jitter = None
134 |         self.val_selection = None
135 |         self.discretization = None
136 | 
137 |         # data sample strategy
138 |         self.modality = None
139 |         self.num_samples = None
140 |         self.max_depth = None
141 |         self.sparsifier = None
142 | 
143 |         # others
144 |         self.manual_seed = None
145 |         self.gpu = None
146 |         self.print_freq = None
147 | 
148 |     def parse_command(self):
149 |         args = parse_command()
150 | 
151 |         self.arch = args.arch
152 |         self.restore = args.restore
153 |         self.pretrained = args.pretrained
154 | 
155 |         self.upt = args.upt
156 | 
157 |         # training parameters
158 |         self.batch_size = args.batch_size
159 | 
160 |         # criterion parameters
161 |         self.criterion = args.criterion
162 |         self.loss_wrapper = args.loss_wrapper
163 |         self.distance = args.distance
164 | 
165 |         # lr scheduler parameters
166 |         self.scheduler = args.scheduler
167 |         self.lr = args.lr
168 | 
169 |         self.factor = args.factor
170 |         self.lr_patience = args.lr_patience
171 |         self.max_iter = args.max_iter
172 |         self.decay_iter = args.decay_iter
173 |         self.gamma = args.gamma
174 | 
175 |         # optimizer parameters
176 |         self.opt = args.opt
177 |         self.momentum = args.momentum
178 |         self.weight_decay = args.weight_decay
179 | 
180 |         # dataset parameters
181 |         self.dataset = args.dataset
182 |         self.modality = args.modality
183 |         self.num_samples = args.num_samples
184 |         self.max_depth = args.max_depth
185 |         self.sparsifier = args.sparsifier
186 |         self.workers = args.workers
187 |         self.jitter = args.jitter
188 |         self.val_selection = args.val_selection
189 |         self.discretization = args.discretization
190 | 
191 |         # others
192 |         self.manual_seed = args.manual_seed
193 |         self.gpu = args.gpu
194 |         self.print_freq = args.print_freq
195 | 
196 |     def write_config(self, output_directory):
197 |         import os
198 |         config_txt = os.path.join(output_directory, 'options.txt')
199 | 
200 |         # write training parameters to config file
201 |         if not os.path.exists(config_txt):
202 |             with open(config_txt, 'w') as txtfile:
203 |                 out_str = self.__str__()
204 |                 txtfile.write(out_str)
205 | 
206 |     def print_items(self):
207 |         print(self.__str__())
208 | 
209 |     def __str__(self):
210 |         out_str = 'model parameters:\n'
211 |         out_str += ' arch:' + str(self.arch) + '\n'
212 |         out_str += ' restore model path:' + str(self.restore) + '\n'
213 |         out_str += ' pretrained model type:' + str(self.pretrained) + '\n'
214 |         out_str += ' upsample type:' + str(self.upt) + '\n'
215 | 
216 |         out_str += '\ntraining parameters:\n'
217 |         out_str += ' batch size:' + str(self.batch_size) + '\n'
218 | 
219 |         out_str += '\ncriterion parameters:\n'
220 |         out_str += ' criterion:' + str(self.criterion) + '\n'
221 |         out_str += ' loss wrapper:' + str(self.loss_wrapper) + '\n'
222 |         out_str += ' metric distance type:' + str(self.distance) + '\n'
223 | 
224 |         out_str += '\nlr scheduler parameters:\n'
225 |         out_str += ' lr:' + str(self.lr) + '\n'
226 |         out_str += ' scheduler:' + str(self.scheduler) + '\n'
227 |         out_str += ' factor:' + str(self.factor) + '\n'
228 |         out_str += ' lr patience:' + str(self.lr_patience) + '\n'
229 |         out_str += ' max iter:' + str(self.max_iter) + '\n'
230 |         out_str += ' decay iter:' + str(self.decay_iter) + '\n'
231 |         out_str += ' gamma:' + str(self.gamma) + '\n'
232 | 
233 |         out_str += '\noptimizer parameters:\n'
234 |         out_str += ' opt:' + str(self.opt) + '\n'
235 |         out_str += ' momentum:' + str(self.momentum) + '\n'
236 |         out_str += ' weight decay:' + str(self.weight_decay) + '\n'
237 | 
238 |         out_str += '\ndataset parameters:\n'
239 |         out_str += ' dataset:' + str(self.dataset) + '\n'
240 |         out_str += ' 
workers:' + str(self.workers) + '\n' 241 | out_str += ' jitter:' + str(self.jitter) + '\n' 242 | out_str += ' val selection:' + str(self.val_selection) + '\n' 243 | out_str += ' discretization:' + str(self.discretization) + '\n' 244 | 245 | out_str += '\ndata sample strategy:\n' 246 | out_str += ' modality:' + str(self.modality) + '\n' 247 | out_str += ' sparsifier:' + str(self.sparsifier) + '\n' 248 | out_str += ' num samples:' + str(self.num_samples) + '\n' 249 | out_str += ' max depth:' + str(self.max_depth) + '\n' 250 | 251 | out_str += '\nothers\n' 252 | out_str += ' manual seed:' + str(self.manual_seed) + '\n' 253 | out_str += ' gpu ids:' + str(self.gpu) + '\n' 254 | out_str += ' print freq:' + str(self.print_freq) + '\n' 255 | 256 | return out_str 257 | 258 | 259 | if __name__ == '__main__': 260 | import torch 261 | 262 | # print(torch.cuda.current_device()) 263 | 264 | opt = Options() 265 | opt.parse_command() 266 | import os 267 | 268 | opt.write_config(os.getcwd()) 269 | 270 | # args = vars(args) 271 | # args = sorted(args.items(), key=lambda x:x[0]) 272 | # print(args) 273 | 274 | # from dataloaders import create_loader 275 | # 276 | # # 277 | # train_loader = create_loader(args, mode='train') 278 | # val_loader = create_loader(args, mode='val') 279 | # test_loader = create_loader(args, mode='test') 280 | # # 281 | # # print('batch size:', args.batch_size) 282 | # # print('train nums:', len(train_loader)) 283 | # # print('val nums:', len(val_loader)) 284 | # # print('test nums:', len(test_loader)) 285 | # # 286 | # print('...train loader ...') 287 | # import torch 288 | # 289 | # for i, data in enumerate(train_loader): 290 | # img, depth = data 291 | # depth = depth.to() 292 | # print(img.shape, depth.shape) 293 | # print(img) 294 | # print(depth) 295 | # print('max depth:', torch.max(depth), ' min depth:', torch.min(depth)) 296 | # break 297 | # 298 | # print('... val loader ...') 299 | # for i, data in enumerate(val_loader): 300 | # img, depth = data 301 | # print(img.shape, depth.shape) 302 | # print('max depth:', torch.max(depth), ' min depth:', torch.min(depth)) 303 | # break 304 | # 305 | # print('... test loader ...') 306 | # for i, data in enumerate(test_loader): 307 | # img = data 308 | # print(img.shape) 309 | # break 310 | -------------------------------------------------------------------------------- /result/nyu.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/CSPN_monodepth/7363bf749b8df4ea29f1a4fa9eebddbf97cf3f4b/result/nyu.PNG --------------------------------------------------------------------------------
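For reference, the depth-completion networks above consume a 4-channel, NYU-sized input whose last channel is the sparse depth produced by the chosen sparsifier (forward() reads it back via x.narrow(1, 3, 1)). A minimal sketch of assembling such an input; tensor values are illustrative, the real pipeline builds this in the dataloaders, and the CSPN post-process layer may require the repo's compiled extensions:

    import torch
    from network.unet_ours import resnet50

    rgb = torch.rand(1, 3, 228, 304)            # normalized RGB
    sparse_depth = torch.zeros(1, 1, 228, 304)  # mostly empty depth map
    # e.g. 500 uniformly sampled depth points (cf. -s/--num-samples)
    idx = torch.randint(0, 228 * 304, (500,))
    sparse_depth.view(1, 1, -1)[0, 0, idx] = torch.rand(500) * 10.0

    x = torch.cat((rgb, sparse_depth), dim=1)   # (1, 4, 228, 304) rgbd input
    model = resnet50(pretrained=False)
    pred_depth, guidance = model(x)             # forward() returns [depth, guidance]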