├── .gitignore ├── LICENSE ├── README.md ├── dataloaders ├── city_dataloader.py ├── dataloader.py ├── dense_to_sparse.py ├── kitti_dataloader.py ├── nyu_dataloader.py └── transforms.py ├── evaluate.py ├── images └── 500.gif ├── models ├── DCCA_sparse_model.py ├── DCCA_sparse_networks.py ├── __init__.py └── base_model.py ├── options ├── __init__.py ├── base_options.py └── options.py ├── train_depth_complete.py ├── util ├── __init__.py ├── util.py └── visualizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # others 107 | .vis/ 108 | .checkpoints/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Cho Ying Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Deep RGB-D Canonical Correlation Analysis for Sparse Depth Completion 2 | This is the official PyTorch implementation of our NeurIPS 2019 paper by Yiqi Zhong\*, Cho-Ying Wu\*, Suya You, Ulrich Neumann (\*Equal Contribution) at USC 3 | 4 | Paper: [Arxiv]. 5 | 6 | 7 | 8 | Check out the whole video demo [Youtube]. 9 | 10 | **Also check our newest work on depth estimation/completion using sensor fusion, SCADC!** 11 | 12 | # Prerequisites 13 | Linux 14 | Python 3 15 | PyTorch 1.0+ (originally developed under v1.0; testing on v1.5 is also fine) 16 | NVIDIA GPU + CUDA CuDNN 17 | Other common libraries: matplotlib, cv2, PIL 18 | 19 | # Getting Started 20 | 21 | Data Preparation: 22 | Please refer to [KITTI] or [NYU Depth V2] and process them into h5 files. Preprocessed data is also provided here. 23 | 24 | # Tutorial: 25 | 26 | 1. Create a folder with a subfolder 'checkpoint/kitti' 27 | 2. Download the pretrained weights: [NYU-Depth 500 points training] [KITTI 500 points training] and put the .pth under 'checkpoint/kitti/' 28 | 3. Prepare data as described in the previous "Getting Started" section 29 | 4. Run "python3 evaluate.py --name kitti --checkpoints_dir ./checkpoint/ --test_path [path to the testing file]" 30 | 5. You'll see the completed depth maps saved under 'vis/' 31 | 32 | # Train/Evaluation: 33 | 34 | For training, please run 35 | 36 | python3 train_depth_complete.py --name kitti --checkpoints_dir [path to save_dir] --train_path [train_data_dir] --test_path [test_data_dir] 37 | 38 | If you use the preprocessed data from here, the train/test data path should be ./kitti/train or ./kitti/val/ under your data directory. 39 | 40 | If you want to use your own data, please convert it into h5 datasets. (See dataloaders/dataloader.py) 41 | 42 | Other specifications: `--continue_train` loads the latest saved checkpoint. Also set `--epoch_count` to specify the next epoch number; otherwise training restarts from epoch 0. Set hyperparameters with `--lr`, `--batch_size`, `--weight_decay`, or others. Please refer to options/base_options.py and options/options.py 43 | 44 | Note that the default batch size is 4 during training, using gpu:0. You can set a larger batch size (--batch_size=xx) with more GPUs (--gpu_ids="0,1,2,3") for larger-batch training. 45 | 46 | Example command: 47 | 48 | python3 train_depth_complete.py --name kitti --checkpoints_dir ./checkpoints --lr 0.001 --batch_size 4 --train_path './kitti/train/' --test_path './kitti/val/' --continue_train --epoch_count [next_epoch_number] 49 | 50 | For evaluation, please run 51 | 52 | python3 evaluate.py --name kitti --checkpoints_dir [path to save_dir to load ckpt] --test_path [test_data_dir] [--epoch [epoch number]] 53 | 54 | This will load the latest checkpoint for evaluation. Add `--epoch` to specify which epoch's checkpoint you want to load. 55 | 56 | # Update: 02/10/2020 57 | 58 | 1.Fix several bugs and remove redundant options.
59 | 60 | 2.Release Orb sparsifier 61 | 62 | 3.Pretrain models release: [NYU-Depth 500 points training] [KITTI 500 points training] 63 | 64 | 65 | # Update: 04/19/2021 66 | 67 | 1. Revise README and add a tutorial 68 | 2. Several minor revisions 69 | 70 | 71 | If you find our work useful, please consider to cite our work. 72 | 73 | @inproceedings{zhong2019deep, 74 | title={Deep rgb-d canonical correlation analysis for sparse depth completion}, 75 | author={Zhong, Yiqi and Wu, Cho-Ying and You, Suya and Neumann, Ulrich}, 76 | booktitle={Advances in Neural Information Processing Systems}, 77 | pages={5332--5342}, 78 | year={2019} 79 | 80 | 81 | -------------------------------------------------------------------------------- /dataloaders/city_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from dataloaders.dataloader import MyDataloader 4 | 5 | class CITY_SCAPESDataset(MyDataloader): 6 | def __init__(self, root, type, sparsifier=None, modality='rgb'): 7 | super(CITY_SCAPESDataset, self).__init__(root, type, sparsifier, modality) 8 | self.output_size = (228, 912) 9 | 10 | def train_transform(self, rgb, depth): 11 | s = np.random.uniform(1.0, 1.5) # random scaling 12 | depth_np = depth / s 13 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 14 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 15 | 16 | # perform 1st step of data augmentation 17 | transform = transforms.Compose([ 18 | transforms.Crop(0, 20, 750, 2000), 19 | transforms.Resize(500 / 750), 20 | transforms.Rotate(angle), 21 | transforms.Resize(s), 22 | transforms.CenterCrop(self.output_size), 23 | transforms.HorizontalFlip(do_flip) 24 | ]) 25 | rgb_np = transform(rgb) 26 | rgb_np = self.color_jitter(rgb_np) # random color jittering 27 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 28 | # Scipy affine_transform produced RuntimeError when the depth map was 29 | # given as a 'numpy.ndarray' 30 | depth_np = np.asfarray(depth_np, dtype='float32') 31 | depth_np = transform(depth_np) 32 | 33 | return rgb_np, depth_np 34 | 35 | def val_transform(self, rgb, depth): 36 | depth_np = depth 37 | transform = transforms.Compose([ 38 | transforms.Crop(0, 20, 750, 2000), 39 | transforms.Resize(500 / 750), 40 | transforms.CenterCrop(self.output_size), 41 | ]) 42 | rgb_np = transform(rgb) 43 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 44 | depth_np = np.asfarray(depth_np, dtype='float32') 45 | depth_np = transform(depth_np) 46 | 47 | return rgb_np, depth_np 48 | 49 | -------------------------------------------------------------------------------- /dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch.utils.data as data 5 | import h5py 6 | import dataloaders.transforms as transforms 7 | import torch 8 | 9 | IMG_EXTENSIONS = ['.h5',] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def find_classes(dir): 15 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 16 | classes.sort() 17 | class_to_idx = {classes[i]: i for i in range(len(classes))} 18 | return classes, class_to_idx 19 | 20 | def make_dataset(dir, class_to_idx): 21 | images = [] 22 | dir = os.path.expanduser(dir) 23 | for target in sorted(os.listdir(dir)): 24 | d = os.path.join(dir, target) 25 | 
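        # Note on the expected layout: each immediate subfolder of `dir` is treated as one
        # class/scene, and every .h5 file found below it becomes a (path, class_index) sample,
        # i.e. the data directory should look like <root>/<scene_or_class>/*.h5.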
if not os.path.isdir(d): 26 | continue 27 | for root, _, fnames in sorted(os.walk(d)): 28 | for fname in sorted(fnames): 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | item = (path, class_to_idx[target]) 32 | images.append(item) 33 | return images 34 | 35 | def h5_loader(path): 36 | h5f = h5py.File(path, "r") 37 | rgb = np.array(h5f['rgb']) 38 | rgb = np.transpose(rgb, (1, 2, 0)) 39 | depth = np.array(h5f['depth']) 40 | return rgb, depth 41 | 42 | to_tensor = transforms.ToTensor() 43 | 44 | class MyDataloader(data.Dataset): 45 | modality_names = ['rgb', 'rgbd', 'd','rgbdm'] # , 'g', 'gd' 46 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4) 47 | 48 | def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader): 49 | classes, class_to_idx = find_classes(root) 50 | imgs = make_dataset(root, class_to_idx) 51 | assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n" 52 | print("Found {} images in {} folder.".format(len(imgs), type)) 53 | self.root = root 54 | self.imgs = imgs 55 | self.classes = classes 56 | self.class_to_idx = class_to_idx 57 | if type == 'train': 58 | self.transform = self.train_transform 59 | elif type == 'val': 60 | self.transform = self.val_transform 61 | else: 62 | raise (RuntimeError("Invalid dataset type: " + type + "\n" 63 | "Supported dataset types are: train, val")) 64 | self.loader = loader 65 | self.sparsifier = sparsifier 66 | 67 | assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \ 68 | "Supported dataset types are: " + ''.join(self.modality_names) 69 | self.modality = modality 70 | 71 | def train_transform(self, rgb, depth): 72 | raise (RuntimeError("train_transform() is not implemented. ")) 73 | 74 | def val_transform(rgb, depth): 75 | raise (RuntimeError("val_transform() is not implemented.")) 76 | 77 | def create_sparse_depth(self, rgb, depth): 78 | if self.sparsifier is None: 79 | return depth 80 | else: 81 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth) 82 | sparse_depth = np.zeros(depth.shape) 83 | sparse_depth[mask_keep] = depth[mask_keep] 84 | return sparse_depth 85 | 86 | def create_sparse_depth_rgb(self, rgb, depth): 87 | if self.sparsifier is None: 88 | return depth 89 | else: 90 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth) 91 | sparse_depth = np.zeros(depth.shape) 92 | sparse_depth[mask_keep] = depth[mask_keep] 93 | sparse_rgb = np.zeros(rgb.shape) 94 | sparse_rgb[mask_keep,:] = rgb[mask_keep,:] 95 | sparse_mask = np.zeros(depth.shape) 96 | sparse_mask[mask_keep] = 1 97 | mask_keep = mask_keep.astype(np.uint8) 98 | return sparse_depth,sparse_rgb, mask_keep 99 | 100 | def create_rgbdm(self, rgb, depth): 101 | sparse_depth,sparse_rgb,mask = self.create_sparse_depth_rgb(rgb, depth) 102 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2),axis=2) 103 | rgbdm = np.append(rgbd, sparse_rgb, axis=2) 104 | rgbdm = np.append(rgbdm, np.expand_dims(mask, axis=2),axis=2) 105 | 106 | return rgbdm 107 | 108 | def create_rgbd(self, rgb, depth): 109 | sparse_depth = self.create_sparse_depth(rgb, depth) 110 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2) 111 | 112 | return rgbd 113 | 114 | def __getraw__(self, index): 115 | path, target = self.imgs[index] 116 | rgb, depth = self.loader(path) 117 | return rgb, depth 118 | 119 | def __getitem__(self, index): 120 | rgb, depth = self.__getraw__(index) 121 | if self.transform is not None: 122 | rgb_np, depth_np = self.transform(rgb, depth) 123 | else: 124 | 
raise(RuntimeError("transform not defined")) 125 | 126 | if self.modality == 'rgb': 127 | input_np = rgb_np 128 | elif self.modality == 'rgbd': 129 | input_np = self.create_rgbd(rgb_np, depth_np) 130 | elif self.modality == 'd': 131 | input_np = self.create_sparse_depth(rgb_np, depth_np) 132 | elif self.modality == 'rgbdm': 133 | input_np = self.create_rgbdm(rgb_np, depth_np) 134 | 135 | input_tensor = to_tensor(input_np) 136 | while input_tensor.dim() < 3: 137 | input_tensor = input_tensor.unsqueeze(0) 138 | depth_tensor = to_tensor(depth_np) 139 | depth_tensor = depth_tensor.unsqueeze(0) 140 | 141 | return input_tensor, depth_tensor 142 | 143 | def __len__(self): 144 | return len(self.imgs) -------------------------------------------------------------------------------- /dataloaders/dense_to_sparse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def rgb2grayscale(rgb): 6 | return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114 7 | 8 | 9 | class DenseToSparse: 10 | def __init__(self): 11 | pass 12 | 13 | def dense_to_sparse(self, rgb, depth): 14 | pass 15 | 16 | def __repr__(self): 17 | pass 18 | 19 | class UniformSampling(DenseToSparse): 20 | name = "uar" 21 | def __init__(self, num_samples, max_depth=np.inf): 22 | DenseToSparse.__init__(self) 23 | self.num_samples = num_samples 24 | self.max_depth = max_depth 25 | 26 | def __repr__(self): 27 | return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth) 28 | 29 | def dense_to_sparse(self, rgb, depth): 30 | """ 31 | Samples pixels with `num_samples`/#pixels probability in `depth`. 32 | Only pixels with a maximum depth of `max_depth` are considered. 33 | If no `max_depth` is given, samples in all pixels 34 | """ 35 | mask_keep = depth > 0 36 | if self.max_depth is not np.inf: 37 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth) 38 | n_keep = np.count_nonzero(mask_keep) 39 | if n_keep == 0: 40 | return mask_keep 41 | else: 42 | prob = float(self.num_samples) / n_keep 43 | return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob) 44 | 45 | 46 | 47 | class SimulatedStereo(DenseToSparse): 48 | name = "sim_stereo" 49 | 50 | def __init__(self, num_samples, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1): 51 | DenseToSparse.__init__(self) 52 | self.num_samples = num_samples 53 | self.max_depth = max_depth 54 | self.dilate_kernel = dilate_kernel 55 | self.dilate_iterations = dilate_iterations 56 | 57 | def __repr__(self): 58 | return "%s{ns=%d,md=%f,dil=%d.%d}" % \ 59 | (self.name, self.num_samples, self.max_depth, self.dilate_kernel, self.dilate_iterations) 60 | 61 | # We do not use cv2.Canny, since that applies non max suppression 62 | # So we simply do 63 | # RGB to intensitities 64 | # Smooth with gaussian 65 | # Take simple sobel gradients 66 | # Threshold the edge gradient 67 | # Dilatate 68 | def dense_to_sparse(self, rgb, depth): 69 | gray = rgb2grayscale(rgb) 70 | 71 | 72 | blurred = cv2.GaussianBlur(gray, (5, 5), 0) 73 | gx = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=5) 74 | gy = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=5) 75 | 76 | depth_mask = np.bitwise_and(depth != 0.0, depth <= self.max_depth) 77 | 78 | edge_fraction = float(self.num_samples) / np.size(depth) 79 | 80 | mag = cv2.magnitude(gx, gy) 81 | min_mag = np.percentile(mag[depth_mask], 100 * (1.0 - edge_fraction)) 82 | mag_mask = mag >= min_mag 83 | 84 | if self.dilate_iterations >= 0: 85 | kernel = 
np.ones((self.dilate_kernel, self.dilate_kernel), dtype=np.uint8) 86 | cv2.dilate(mag_mask.astype(np.uint8), kernel, iterations=self.dilate_iterations) 87 | 88 | mask = np.bitwise_and(mag_mask, depth_mask) 89 | return mask 90 | 91 | 92 | class ORBSampling(DenseToSparse): 93 | name = "ORB" 94 | def __init__(self,max_depth=np.inf): 95 | DenseToSparse.__init__(self) 96 | self.max_depth = max_depth 97 | 98 | def __repr__(self): 99 | return "%s{ns=%d,md=%f}" % (self.name, self.max_depth) 100 | 101 | def dense_to_sparse(self, rgb, depth): 102 | """ 103 | Samples pixels with `num_samples`/#pixels probability in `depth`. 104 | Only pixels with a maximum depth of `max_depth` are considered. 105 | If no `max_depth` is given, samples in all pixels 106 | """ 107 | mask_keep = depth > 0 108 | 109 | orb = cv2.ORB_create() 110 | rgb_ori = (rgb.copy()*255).astype(np.uint8) 111 | kp = orb.detect(rgb_ori,None) 112 | 113 | mask_keep_orb = np.zeros(mask_keep.shape).astype(mask_keep.dtype) 114 | for marker in kp: 115 | position = np.asarray(marker.pt).astype(np.uint8) 116 | mask_keep_orb[position[1]][position[0]] = True 117 | if self.max_depth is not np.inf: 118 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth) 119 | 120 | mask_keep = np.bitwise_and(mask_keep, mask_keep_orb) 121 | n_keep = np.count_nonzero(mask_keep) 122 | return mask_keep -------------------------------------------------------------------------------- /dataloaders/kitti_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from dataloaders.dataloader import MyDataloader 4 | 5 | class KITTIDataset(MyDataloader): 6 | def __init__(self, root, type, sparsifier=None, modality='rgb'): 7 | super(KITTIDataset, self).__init__(root, type, sparsifier, modality) 8 | self.output_size = (228, 912) 9 | 10 | def train_transform(self, rgb, depth): 11 | s = np.random.uniform(1.0, 1.5) # random scaling 12 | depth_np = depth / s 13 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 14 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 15 | 16 | # perform 1st step of data augmentation 17 | transform = transforms.Compose([ 18 | transforms.Crop(130, 10, 240, 1200), 19 | transforms.Rotate(angle), 20 | transforms.Resize(s), 21 | transforms.CenterCrop(self.output_size), 22 | transforms.HorizontalFlip(do_flip) 23 | ]) 24 | rgb_np = transform(rgb) 25 | rgb_np = self.color_jitter((rgb_np*255.0).astype(np.uint8)) # random color jittering 26 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 27 | # Scipy affine_transform produced RuntimeError when the depth map was 28 | # given as a 'numpy.ndarray' 29 | depth_np = np.asfarray(depth_np, dtype='float32') 30 | depth_np = transform(depth_np) 31 | 32 | return rgb_np, depth_np 33 | 34 | def val_transform(self, rgb, depth): 35 | depth_np = depth 36 | transform = transforms.Compose([ 37 | #transforms.Crop(130, 10, 240, 1200), 38 | transforms.CenterCrop(self.output_size), 39 | ]) 40 | rgb_np = transform(rgb) 41 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 42 | depth_np = np.asfarray(depth_np, dtype='float32') 43 | depth_np = transform(depth_np) 44 | 45 | return rgb_np, depth_np 46 | 47 | -------------------------------------------------------------------------------- /dataloaders/nyu_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from 
dataloaders.dataloader import MyDataloader 4 | 5 | iheight, iwidth = 480, 640 # raw image size 6 | 7 | class NYUDataset(MyDataloader): 8 | def __init__(self, root, type, sparsifier=None, modality='rgbdm'): 9 | super(NYUDataset, self).__init__(root, type, sparsifier, modality) 10 | self.output_size = (224, 224) 11 | 12 | def train_transform(self, rgb, depth): 13 | s = np.random.uniform(1.0, 1.5) # random scaling 14 | depth_np = depth / s 15 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 16 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 17 | 18 | # perform 1st step of data augmentation 19 | transform = transforms.Compose([ 20 | transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow 21 | transforms.Rotate(angle), 22 | transforms.Resize(s), 23 | transforms.CenterCrop(self.output_size), 24 | transforms.HorizontalFlip(do_flip) 25 | ]) 26 | rgb_np = transform(rgb) 27 | rgb_np = self.color_jitter(rgb_np) # random color jittering 28 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 29 | depth_np = transform(depth_np) 30 | 31 | return rgb_np, depth_np 32 | 33 | def val_transform(self, rgb, depth): 34 | depth_np = depth 35 | transform = transforms.Compose([ 36 | transforms.Resize(240.0 / iheight), 37 | transforms.CenterCrop(self.output_size), 38 | ]) 39 | rgb_np = transform(rgb) 40 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 41 | depth_np = transform(depth_np) 42 | 43 | return rgb_np, depth_np 44 | -------------------------------------------------------------------------------- /dataloaders/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | 6 | from PIL import Image, ImageOps, ImageEnhance 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | import scipy.ndimage.interpolation as itpl 19 | import scipy.misc as misc 20 | import skimage.transform 21 | 22 | 23 | def _is_numpy_image(img): 24 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 25 | 26 | def _is_pil_image(img): 27 | if accimage is not None: 28 | return isinstance(img, (Image.Image, accimage.Image)) 29 | else: 30 | return isinstance(img, Image.Image) 31 | 32 | def _is_tensor_image(img): 33 | return torch.is_tensor(img) and img.ndimension() == 3 34 | 35 | def adjust_brightness(img, brightness_factor): 36 | """Adjust brightness of an Image. 37 | 38 | Args: 39 | img (PIL Image): PIL Image to be adjusted. 40 | brightness_factor (float): How much to adjust the brightness. Can be 41 | any non negative number. 0 gives a black image, 1 gives the 42 | original image while 2 increases the brightness by a factor of 2. 43 | 44 | Returns: 45 | PIL Image: Brightness adjusted image. 46 | """ 47 | if not _is_pil_image(img): 48 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 49 | 50 | enhancer = ImageEnhance.Brightness(img) 51 | img = enhancer.enhance(brightness_factor) 52 | return img 53 | 54 | 55 | def adjust_contrast(img, contrast_factor): 56 | """Adjust contrast of an Image. 57 | 58 | Args: 59 | img (PIL Image): PIL Image to be adjusted. 60 | contrast_factor (float): How much to adjust the contrast. Can be any 61 | non negative number. 0 gives a solid gray image, 1 gives the 62 | original image while 2 increases the contrast by a factor of 2. 
63 | 64 | Returns: 65 | PIL Image: Contrast adjusted image. 66 | """ 67 | if not _is_pil_image(img): 68 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 69 | 70 | enhancer = ImageEnhance.Contrast(img) 71 | img = enhancer.enhance(contrast_factor) 72 | return img 73 | 74 | 75 | def adjust_saturation(img, saturation_factor): 76 | """Adjust color saturation of an image. 77 | 78 | Args: 79 | img (PIL Image): PIL Image to be adjusted. 80 | saturation_factor (float): How much to adjust the saturation. 0 will 81 | give a black and white image, 1 will give the original image while 82 | 2 will enhance the saturation by a factor of 2. 83 | 84 | Returns: 85 | PIL Image: Saturation adjusted image. 86 | """ 87 | if not _is_pil_image(img): 88 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 89 | 90 | enhancer = ImageEnhance.Color(img) 91 | img = enhancer.enhance(saturation_factor) 92 | return img 93 | 94 | 95 | def adjust_hue(img, hue_factor): 96 | """Adjust hue of an image. 97 | 98 | The image hue is adjusted by converting the image to HSV and 99 | cyclically shifting the intensities in the hue channel (H). 100 | The image is then converted back to original image mode. 101 | 102 | `hue_factor` is the amount of shift in H channel and must be in the 103 | interval `[-0.5, 0.5]`. 104 | 105 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 106 | 107 | Args: 108 | img (PIL Image): PIL Image to be adjusted. 109 | hue_factor (float): How much to shift the hue channel. Should be in 110 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 111 | HSV space in positive and negative direction respectively. 112 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 113 | with complementary colors while 0 gives the original image. 114 | 115 | Returns: 116 | PIL Image: Hue adjusted image. 117 | """ 118 | if not(-0.5 <= hue_factor <= 0.5): 119 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 120 | 121 | if not _is_pil_image(img): 122 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 123 | 124 | input_mode = img.mode 125 | if input_mode in {'L', '1', 'I', 'F'}: 126 | return img 127 | 128 | h, s, v = img.convert('HSV').split() 129 | 130 | np_h = np.array(h, dtype=np.uint8) 131 | # uint8 addition take cares of rotation across boundaries 132 | with np.errstate(over='ignore'): 133 | np_h += np.uint8(hue_factor * 255) 134 | h = Image.fromarray(np_h, 'L') 135 | 136 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 137 | return img 138 | 139 | 140 | def adjust_gamma(img, gamma, gain=1): 141 | """Perform gamma correction on an image. 142 | 143 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 144 | based on the following equation: 145 | 146 | I_out = 255 * gain * ((I_in / 255) ** gamma) 147 | 148 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 149 | 150 | Args: 151 | img (PIL Image): PIL Image to be adjusted. 152 | gamma (float): Non negative real number. gamma larger than 1 make the 153 | shadows darker, while gamma smaller than 1 make dark regions 154 | lighter. 155 | gain (float): The constant multiplier. 156 | """ 157 | if not _is_pil_image(img): 158 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 159 | 160 | if gamma < 0: 161 | raise ValueError('Gamma should be a non-negative real number') 162 | 163 | input_mode = img.mode 164 | img = img.convert('RGB') 165 | 166 | np_img = np.array(img, dtype=np.float32) 167 | np_img = 255 * gain * ((np_img / 255) ** gamma) 168 | np_img = np.uint8(np.clip(np_img, 0, 255)) 169 | 170 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 171 | return img 172 | 173 | 174 | class Compose(object): 175 | """Composes several transforms together. 176 | 177 | Args: 178 | transforms (list of ``Transform`` objects): list of transforms to compose. 179 | 180 | Example: 181 | >>> transforms.Compose([ 182 | >>> transforms.CenterCrop(10), 183 | >>> transforms.ToTensor(), 184 | >>> ]) 185 | """ 186 | 187 | def __init__(self, transforms): 188 | self.transforms = transforms 189 | 190 | def __call__(self, img): 191 | for t in self.transforms: 192 | img = t(img) 193 | return img 194 | 195 | 196 | class ToTensor(object): 197 | """Convert a ``numpy.ndarray`` to tensor. 198 | 199 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 200 | """ 201 | 202 | def __call__(self, img): 203 | """Convert a ``numpy.ndarray`` to tensor. 204 | 205 | Args: 206 | img (numpy.ndarray): Image to be converted to tensor. 207 | 208 | Returns: 209 | Tensor: Converted image. 210 | """ 211 | if not(_is_numpy_image(img)): 212 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 213 | 214 | if isinstance(img, np.ndarray): 215 | # handle numpy array 216 | if img.ndim == 3: 217 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) 218 | elif img.ndim == 2: 219 | img = torch.from_numpy(img.copy()) 220 | else: 221 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 222 | 223 | # backward compatibility 224 | # return img.float().div(255) 225 | return img.float() 226 | 227 | 228 | class NormalizeNumpyArray(object): 229 | """Normalize a ``numpy.ndarray`` with mean and standard deviation. 230 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 231 | will normalize each channel of the input ``numpy.ndarray`` i.e. 232 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 233 | 234 | Args: 235 | mean (sequence): Sequence of means for each channel. 236 | std (sequence): Sequence of standard deviations for each channel. 237 | """ 238 | 239 | def __init__(self, mean, std): 240 | self.mean = mean 241 | self.std = std 242 | 243 | def __call__(self, img): 244 | """ 245 | Args: 246 | img (numpy.ndarray): Image of size (H, W, C) to be normalized. 247 | 248 | Returns: 249 | Tensor: Normalized image. 250 | """ 251 | if not(_is_numpy_image(img)): 252 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 253 | # TODO: make efficient 254 | print(img.shape) 255 | for i in range(3): 256 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i] 257 | return img 258 | 259 | class NormalizeTensor(object): 260 | """Normalize an tensor image with mean and standard deviation. 261 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 262 | will normalize each channel of the input ``torch.*Tensor`` i.e. 263 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 264 | 265 | Args: 266 | mean (sequence): Sequence of means for each channel. 267 | std (sequence): Sequence of standard deviations for each channel. 
268 | """ 269 | 270 | def __init__(self, mean, std): 271 | self.mean = mean 272 | self.std = std 273 | 274 | def __call__(self, tensor): 275 | """ 276 | Args: 277 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 278 | 279 | Returns: 280 | Tensor: Normalized Tensor image. 281 | """ 282 | if not _is_tensor_image(tensor): 283 | raise TypeError('tensor is not a torch image.') 284 | # TODO: make efficient 285 | for t, m, s in zip(tensor, self.mean, self.std): 286 | t.sub_(m).div_(s) 287 | return tensor 288 | 289 | class Rotate(object): 290 | """Rotates the given ``numpy.ndarray``. 291 | 292 | Args: 293 | angle (float): The rotation angle in degrees. 294 | """ 295 | 296 | def __init__(self, angle): 297 | self.angle = angle 298 | 299 | def __call__(self, img): 300 | """ 301 | Args: 302 | img (numpy.ndarray (C x H x W)): Image to be rotated. 303 | 304 | Returns: 305 | img (numpy.ndarray (C x H x W)): Rotated image. 306 | """ 307 | 308 | # order=0 means nearest-neighbor type interpolation 309 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0) 310 | 311 | 312 | class Resize(object): 313 | """Resize the the given ``numpy.ndarray`` to the given size. 314 | Args: 315 | size (sequence or int): Desired output size. If size is a sequence like 316 | (h, w), output size will be matched to this. If size is an int, 317 | smaller edge of the image will be matched to this number. 318 | i.e, if height > width, then image will be rescaled to 319 | (size * height / width, size) 320 | interpolation (int, optional): Desired interpolation. Default is 321 | ``PIL.Image.BILINEAR`` 322 | """ 323 | 324 | def __init__(self, size, interpolation='nearest'): 325 | assert isinstance(size, int) or isinstance(size, float) or \ 326 | (isinstance(size, collections.Iterable) and len(size) == 2) 327 | self.size = size 328 | self.interpolation = interpolation 329 | 330 | def __call__(self, img): 331 | """ 332 | Args: 333 | img (PIL Image): Image to be scaled. 334 | Returns: 335 | PIL Image: Rescaled image. 336 | """ 337 | if img.ndim == 3: 338 | return skimage.transform.rescale(img, self.size, order=0) 339 | elif img.ndim == 2: 340 | return skimage.transform.rescale(img, self.size, order=0) 341 | else: 342 | RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 343 | 344 | 345 | class CenterCrop(object): 346 | """Crops the given ``numpy.ndarray`` at the center. 347 | 348 | Args: 349 | size (sequence or int): Desired output size of the crop. If size is an 350 | int instead of sequence like (h, w), a square crop (size, size) is 351 | made. 352 | """ 353 | 354 | def __init__(self, size): 355 | if isinstance(size, numbers.Number): 356 | self.size = (int(size), int(size)) 357 | else: 358 | self.size = size 359 | 360 | @staticmethod 361 | def get_params(img, output_size): 362 | """Get parameters for ``crop`` for center crop. 363 | 364 | Args: 365 | img (numpy.ndarray (C x H x W)): Image to be cropped. 366 | output_size (tuple): Expected output size of the crop. 367 | 368 | Returns: 369 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 
370 | """ 371 | h = img.shape[0] 372 | w = img.shape[1] 373 | th, tw = output_size 374 | i = int(round((h - th) / 2.)) 375 | j = int(round((w - tw) / 2.)) 376 | 377 | # # randomized cropping 378 | # i = np.random.randint(i-3, i+4) 379 | # j = np.random.randint(j-3, j+4) 380 | 381 | return i, j, th, tw 382 | 383 | def __call__(self, img): 384 | """ 385 | Args: 386 | img (numpy.ndarray (C x H x W)): Image to be cropped. 387 | 388 | Returns: 389 | img (numpy.ndarray (C x H x W)): Cropped image. 390 | """ 391 | i, j, h, w = self.get_params(img, self.size) 392 | 393 | """ 394 | i: Upper pixel coordinate. 395 | j: Left pixel coordinate. 396 | h: Height of the cropped image. 397 | w: Width of the cropped image. 398 | """ 399 | if not(_is_numpy_image(img)): 400 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 401 | if img.ndim == 3: 402 | return img[i:i+h, j:j+w, :] 403 | elif img.ndim == 2: 404 | return img[i:i + h, j:j + w] 405 | else: 406 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 407 | 408 | 409 | class Lambda(object): 410 | """Apply a user-defined lambda as a transform. 411 | 412 | Args: 413 | lambd (function): Lambda/function to be used for transform. 414 | """ 415 | 416 | def __init__(self, lambd): 417 | assert isinstance(lambd, types.LambdaType) 418 | self.lambd = lambd 419 | 420 | def __call__(self, img): 421 | return self.lambd(img) 422 | 423 | 424 | class HorizontalFlip(object): 425 | """Horizontally flip the given ``numpy.ndarray``. 426 | 427 | Args: 428 | do_flip (boolean): whether or not do horizontal flip. 429 | 430 | """ 431 | 432 | def __init__(self, do_flip): 433 | self.do_flip = do_flip 434 | 435 | def __call__(self, img): 436 | """ 437 | Args: 438 | img (numpy.ndarray (C x H x W)): Image to be flipped. 439 | 440 | Returns: 441 | img (numpy.ndarray (C x H x W)): flipped image. 442 | """ 443 | if not(_is_numpy_image(img)): 444 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 445 | 446 | if self.do_flip: 447 | return np.fliplr(img) 448 | else: 449 | return img 450 | 451 | 452 | class ColorJitter(object): 453 | """Randomly change the brightness, contrast and saturation of an image. 454 | 455 | Args: 456 | brightness (float): How much to jitter brightness. brightness_factor 457 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 458 | contrast (float): How much to jitter contrast. contrast_factor 459 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 460 | saturation (float): How much to jitter saturation. saturation_factor 461 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 462 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 463 | [-hue, hue]. Should be >=0 and <= 0.5. 464 | """ 465 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 466 | self.brightness = brightness 467 | self.contrast = contrast 468 | self.saturation = saturation 469 | self.hue = hue 470 | 471 | @staticmethod 472 | def get_params(brightness, contrast, saturation, hue): 473 | """Get a randomized transform to be applied on image. 474 | 475 | Arguments are same as that of __init__. 476 | 477 | Returns: 478 | Transform which randomly adjusts brightness, contrast and 479 | saturation in a random order. 
480 | """ 481 | transforms = [] 482 | if brightness > 0: 483 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 484 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 485 | 486 | if contrast > 0: 487 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 488 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 489 | 490 | if saturation > 0: 491 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 492 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 493 | 494 | if hue > 0: 495 | hue_factor = np.random.uniform(-hue, hue) 496 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 497 | 498 | np.random.shuffle(transforms) 499 | transform = Compose(transforms) 500 | 501 | return transform 502 | 503 | def __call__(self, img): 504 | """ 505 | Args: 506 | img (numpy.ndarray (C x H x W)): Input image. 507 | 508 | Returns: 509 | img (numpy.ndarray (C x H x W)): Color jittered image. 510 | """ 511 | if not(_is_numpy_image(img)): 512 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 513 | pil = Image.fromarray(img) 514 | transform = self.get_params(self.brightness, self.contrast, 515 | self.saturation, self.hue) 516 | return np.array(transform(pil)) 517 | 518 | class Crop(object): 519 | """Crops the given PIL Image to a rectangular region based on a given 520 | 4-tuple defining the left, upper pixel coordinated, hight and width size. 521 | 522 | Args: 523 | a tuple: (upper pixel coordinate, left pixel coordinate, hight, width)-tuple 524 | """ 525 | 526 | def __init__(self, i, j, h, w): 527 | """ 528 | i: Upper pixel coordinate. 529 | j: Left pixel coordinate. 530 | h: Height of the cropped image. 531 | w: Width of the cropped image. 532 | """ 533 | self.i = i 534 | self.j = j 535 | self.h = h 536 | self.w = w 537 | 538 | def __call__(self, img): 539 | """ 540 | Args: 541 | img (numpy.ndarray (C x H x W)): Image to be cropped. 542 | Returns: 543 | img (numpy.ndarray (C x H x W)): Cropped image. 544 | """ 545 | 546 | i, j, h, w = self.i, self.j, self.h, self.w 547 | 548 | if not(_is_numpy_image(img)): 549 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 550 | if img.ndim == 3: 551 | return img[i:i + h, j:j + w, :] 552 | elif img.ndim == 2: 553 | return img[i:i + h, j:j + w] 554 | else: 555 | raise RuntimeError( 556 | 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format(img.ndim)) 557 | 558 | def __repr__(self): 559 | return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format( 560 | self.i, self.j, self.h, self.w) 561 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | from options.options import AdvanceOptions 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | from dataloaders.nyu_dataloader import NYUDataset 7 | from dataloaders.kitti_dataloader import KITTIDataset 8 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo 9 | import numpy as np 10 | import random 11 | import torch 12 | import cv2 13 | import utils 14 | import os 15 | 16 | # def colored_depthmap(depth, d_min=None, d_max=None): 17 | # if d_min is None: 18 | # d_min = np.min(depth) 19 | # if d_max is None: 20 | # d_max = np.max(depth) 21 | # depth_relative = (depth - d_min) / (d_max - d_min) 22 | # return 255 * plt.cm.viridis(depth_relative)[:,:,:3] # H, W, C 23 | 24 | # def merge_into_row_with_pred_visualize(input, depth_input, rgb_sparse,depth_target, depth_est): 25 | # rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C 26 | # rgb_sparse = 255 * np.transpose(np.squeeze(rgb_sparse.cpu().numpy()), (1,2,0)) 27 | # depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) 28 | # depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 29 | # depth_pred_cpu = np.squeeze(depth_est.cpu().numpy()) 30 | 31 | # d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu)) 32 | # d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.min(depth_pred_cpu)) 33 | # depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max) 34 | # depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 35 | # depth_pred_col = colored_depthmap(depth_target_cpu, d_min, d_max) 36 | 37 | # img_merge = np.hstack([rgb, rgb_sparse,depth_input_col, depth_target_col,depth_pred_col]) 38 | 39 | # return img_merge 40 | 41 | if __name__ == '__main__': 42 | test_opt = AdvanceOptions().parse(False) 43 | 44 | sparsifier = UniformSampling(test_opt.nP, max_depth=np.inf) 45 | #sparsifier = SimulatedStereo(100, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1) 46 | test_dataset = KITTIDataset(test_opt.test_path, type='val', 47 | modality='rgbdm', sparsifier=sparsifier) 48 | 49 | ### Please use this dataloder if you want to use NYU 50 | # test_dataset = NYUDataset(test_opt.test_path, type='val', 51 | # modality='rgbdm', sparsifier=sparsifier) 52 | 53 | 54 | test_opt.phase = 'val' 55 | test_opt.batch_size = 1 56 | test_opt.num_threads = 1 57 | test_opt.serial_batches = True 58 | test_opt.no_flip = True 59 | 60 | test_data_loader = torch.utils.data.DataLoader(test_dataset, 61 | batch_size=test_opt.batch_size, shuffle=False, num_workers=test_opt.num_threads, pin_memory=True) 62 | 63 | test_dataset_size = len(test_data_loader) 64 | print('#test images = %d' % test_dataset_size) 65 | 66 | model = create_model(test_opt, test_dataset) 67 | model.eval() 68 | model.setup(test_opt) 69 | visualizer = Visualizer(test_opt) 70 | test_loss_iter = [] 71 | gts = None 72 | preds = None 73 | epoch_iter = 0 74 | model.init_test_eval() 75 | epoch = 0 76 | num = 5 # How many images to save in an image 77 | if not os.path.exists('vis'): 78 | os.makedirs('vis') 79 | with torch.no_grad(): 80 | iterator = iter(test_data_loader) 81 | i = 0 
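        # Batches are drawn manually from the iterator so that frames where the sparsifier
        # cannot return the requested number of points (raising IndexError) are skipped with
        # a warning instead of stopping the run; StopIteration ends the loop. Groups of
        # completed depth maps are stacked row by row and periodically written under vis/.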
82 | while True: 83 | try: # Some images couldn't sample more than defined nP points under Stereo sampling 84 | next_batch = next(iterator) 85 | except IndexError: 86 | print("Catch and Skip!") 87 | continue 88 | except StopIteration: 89 | break 90 | 91 | data, target = next_batch[0], next_batch[1] 92 | model.set_new_input(data,target) 93 | model.forward() 94 | model.test_depth_evaluation() 95 | model.get_loss() 96 | epoch_iter += test_opt.batch_size 97 | losses = model.get_current_losses() 98 | test_loss_iter.append(model.loss_dcca.item()) 99 | 100 | rgb_input = model.rgb_image 101 | depth_input = model.sparse_depth 102 | rgb_sparse = model.sparse_rgb 103 | depth_target = model.depth_image 104 | depth_est = model.depth_est 105 | 106 | ### These part save image in vis/ folder 107 | if i%num == 0: 108 | img_merge = utils.merge_into_row_with_pred_visualize(rgb_input, depth_input, rgb_sparse,depth_target, depth_est) 109 | elif i%num < num-1: 110 | row = utils.merge_into_row_with_pred_visualize(rgb_input, depth_input, rgb_sparse,depth_target, depth_est) 111 | img_merge = utils.add_row(img_merge, row) 112 | elif i%num == num-1: 113 | filename = 'vis/'+str(i)+'.png' 114 | utils.save_image(img_merge, filename) 115 | 116 | i += 1 117 | 118 | print('test epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(test_dataset) * test_opt.batch_size), end='\r') 119 | print( 120 | 'RMSE={result.rmse:.4f}({average.rmse:.4f}) ' 121 | 'MSE={result.mse:.4f}({average.mse:.4f}) ' 122 | 'MAE={result.mae:.4f}({average.mae:.4f}) ' 123 | 'Delta1={result.delta1:.4f}({average.delta1:.4f}) ' 124 | 'Delta2={result.delta2:.4f}({average.delta2:.4f}) ' 125 | 'Delta3={result.delta3:.4f}({average.delta3:.4f}) ' 126 | 'REL={result.absrel:.4f}({average.absrel:.4f}) ' 127 | 'Lg10={result.lg10:.4f}({average.lg10:.4f}) '.format( 128 | result=model.test_result, average=model.test_average.average())) 129 | avg_test_loss = np.mean(np.asarray(test_loss_iter)) 130 | -------------------------------------------------------------------------------- /images/500.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/images/500.gif -------------------------------------------------------------------------------- /models/DCCA_sparse_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import DCCA_sparse_networks 4 | import numpy as np 5 | import os 6 | import math 7 | 8 | class DCCASparseModel(BaseModel): 9 | def name(self): 10 | return 'DCCASparseNetModel' 11 | 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | 15 | # changing the default values 16 | if is_train: 17 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 18 | return parser 19 | 20 | def initialize(self, opt, dataset): 21 | BaseModel.initialize(self, opt) 22 | 23 | self.x_dataview = None 24 | self.y_dataview = None 25 | self.depth_est = None 26 | self.loss_dcca = 0 27 | self.loss_l1 = 0 28 | self.loss_mse = None 29 | self.loss_smooth = None 30 | self.result = None 31 | self.test_result = None 32 | self.average = None 33 | self.test_average = None 34 | 35 | self.isTrain = opt.isTrain 36 | # specify the training losses you want to print out. 
The program will call base_model.get_current_losses 37 | self.loss_names = ['mse','dcca','total','transform','smooth'] 38 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 39 | self.visual_names = ['rgb_image','depth_image','mask','output'] 40 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 41 | self.model_names = ['DCCASparseNet'] 42 | 43 | # load/define networks 44 | self.netDCCASparseNet = DCCA_sparse_networks.define_DCCASparseNet(rgb_enc=True, depth_enc=True, depth_dec=True, norm=opt.norm, init_type=opt.init_type, init_gain= opt.init_gain, gpu_ids= self.gpu_ids) 45 | # define loss functions 46 | self.criterionDCCA = DCCA_sparse_networks.DCCA_2D_Loss(outdim_size = 60,use_all_singular_values = True, device=self.device).to(self.device) 47 | self.MSE = DCCA_sparse_networks.MaskedMSELoss() 48 | self.SMOOTH = DCCA_sparse_networks.SmoothLoss() 49 | self.TransformLoss = DCCA_sparse_networks.TransformLoss() 50 | 51 | if self.isTrain: 52 | # initialize optimizers 53 | self.optimizers = [] 54 | self.optimizer_DCCASparseNet = torch.optim.SGD(self.netDCCASparseNet.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) 55 | self.optimizers.append(self.optimizer_DCCASparseNet) 56 | 57 | def set_input(self, input): 58 | self.rgb_image = input['rgb_image'].to(self.device) 59 | self.depth_image = input['depth_image'].to(self.device) 60 | self.mask = input['mask'].to(self.device) 61 | self.image_paths = input['path'] 62 | 63 | def set_new_input(self, input,target): 64 | self.rgb_image = input[:,:3,:,:].to(self.device) 65 | self.sparse_rgb = input[:,4:7,:,:].to(self.device) 66 | self.depth_image = target.to(self.device) 67 | self.sparse_depth = input[:,3,:,:].to(self.device).unsqueeze(1) 68 | self.mask = input[:,7,:,:].to(self.device).unsqueeze(1) 69 | 70 | def forward(self): 71 | self.x_dataview,self.y_dataview,self.x_trans,self.depth_est= self.netDCCASparseNet(self.sparse_rgb,self.sparse_depth,self.mask,self.rgb_image,self.depth_image) 72 | 73 | def get_loss(self): 74 | self.loss_dcca = self.criterionDCCA(self.x_dataview,self.y_dataview) 75 | self.loss_mse = self.MSE(self.depth_est,self.depth_image) 76 | self.loss_smooth = self.SMOOTH(self.depth_est) 77 | self.loss_transform = self.TransformLoss(self.x_trans, self.x_dataview) 78 | self.loss_total = self.loss_mse + self.loss_dcca + self.loss_transform + 0.1*self.loss_smooth 79 | 80 | def backward(self): 81 | self.loss_total.backward() 82 | 83 | def pure_backward(self): 84 | self.loss_dcca.backward() 85 | 86 | def init_test_eval(self): 87 | self.test_result = Result() 88 | self.test_average = AverageMeter() 89 | 90 | def init_eval(self): 91 | self.result = Result() 92 | self.average = AverageMeter() 93 | 94 | def depth_evaluation(self): 95 | self.result.evaluate(self.depth_est.data, self.depth_image.data) 96 | self.average.update(self.result, self.sparse_rgb.size(0)) 97 | 98 | def test_depth_evaluation(self): 99 | self.test_result.evaluate(self.depth_est.data, self.depth_image.data) 100 | self.test_average.update(self.test_result, self.sparse_rgb.size(0)) 101 | print() 102 | 103 | def print_test_depth_evaluation(self): 104 | message = 'RMSE={result.rmse:.4f}({average.rmse:.4f}) \ 105 | MAE={result.mae:.4f}({average.mae:.4f}) \ 106 | Delta1={result.delta1:.4f}({average.delta1:.4f}) \ 107 | REL={result.absrel:.4f}({average.absrel:.4f}) \ 108 | 
Lg10={result.lg10:.4f}({average.lg10:.4f})'.format(result=self.test_result, average=self.test_average.average()) 109 | print(message) 110 | return message 111 | 112 | def print_depth_evaluation(self): 113 | message = 'RMSE={result.rmse:.4f}({average.rmse:.4f}) \ 114 | MAE={result.mae:.4f}({average.mae:.4f}) \ 115 | Delta1={result.delta1:.4f}({average.delta1:.4f}) \ 116 | REL={result.absrel:.4f}({average.absrel:.4f}) \ 117 | Lg10={result.lg10:.4f}({average.lg10:.4f})'.format(result=self.result, average=self.average.average()) 118 | print(message) 119 | return message 120 | 121 | def optimize_parameters(self): 122 | self.forward() 123 | self.depth_evaluation() 124 | self.set_requires_grad(self.netDCCASparseNet, True) 125 | self.get_loss() 126 | self.optimizer_DCCASparseNet.zero_grad() 127 | # update DCCAnet 128 | self.backward() 129 | self.optimizer_DCCASparseNet.step() 130 | 131 | 132 | ####### Metrics ######## 133 | def log10(x): 134 | """Convert a new tensor with the base-10 logarithm of the elements of x. """ 135 | return torch.log(x) / math.log(10) 136 | 137 | class Result(object): 138 | def __init__(self): 139 | self.irmse, self.imae = 0, 0 140 | self.mse, self.rmse, self.mae = 0, 0, 0 141 | self.absrel, self.lg10 = 0, 0 142 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 143 | self.data_time, self.gpu_time = 0, 0 144 | 145 | def set_to_worst(self): 146 | self.irmse, self.imae = np.inf, np.inf 147 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf 148 | self.absrel, self.lg10 = np.inf, np.inf 149 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 150 | self.data_time, self.gpu_time = 0, 0 151 | 152 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time): 153 | self.irmse, self.imae = irmse, imae 154 | self.mse, self.rmse, self.mae = mse, rmse, mae 155 | self.absrel, self.lg10 = absrel, lg10 156 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 157 | self.data_time, self.gpu_time = data_time, gpu_time 158 | 159 | def evaluate(self, output, target): 160 | valid_mask = target>0 161 | output = output[valid_mask] 162 | target = target[valid_mask] 163 | 164 | new_output = output[target<=50] 165 | new_target = target[target<=50] 166 | target = new_target 167 | output = new_output 168 | 169 | abs_diff = (output - target).abs() 170 | 171 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 172 | self.rmse = math.sqrt(self.mse) 173 | self.mae = float(abs_diff.mean()) 174 | self.lg10 = float((log10(output) - log10(target)).abs().mean()) 175 | self.absrel = float((abs_diff / target).mean()) 176 | 177 | maxRatio = torch.max(output / target, target / output) 178 | self.delta1 = float((maxRatio < 1.25).float().mean()) 179 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 180 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 181 | self.data_time = 0 182 | self.gpu_time = 0 183 | 184 | inv_output = 1 / output 185 | inv_target = 1 / target 186 | abs_inv_diff = (inv_output - inv_target).abs() 187 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 188 | self.imae = float(abs_inv_diff.mean()) 189 | 190 | 191 | class AverageMeter(object): 192 | def __init__(self): 193 | self.reset() 194 | 195 | def reset(self): 196 | self.count = 0.0 197 | self.sum_irmse, self.sum_imae = 0, 0 198 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 199 | self.sum_absrel, self.sum_lg10 = 0, 0 200 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 201 | self.sum_data_time, self.sum_gpu_time = 0, 0 202 | 203 | 
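    # update() accumulates sample-weighted sums of each metric; average() later divides by
    # the running count, so the reported numbers are per-sample means over the whole set
    # rather than unweighted means of per-batch values.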
def update(self, result, n=1): 204 | self.count += n 205 | self.sum_irmse += n*result.irmse 206 | self.sum_imae += n*result.imae 207 | self.sum_mse += n*result.mse 208 | self.sum_rmse += n*result.rmse 209 | self.sum_mae += n*result.mae 210 | self.sum_absrel += n*result.absrel 211 | self.sum_lg10 += n*result.lg10 212 | self.sum_delta1 += n*result.delta1 213 | self.sum_delta2 += n*result.delta2 214 | self.sum_delta3 += n*result.delta3 215 | 216 | def average(self): 217 | avg = Result() 218 | avg.update( 219 | self.sum_irmse / self.count, self.sum_imae / self.count, 220 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count, 221 | self.sum_absrel / self.count, self.sum_lg10 / self.count, 222 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count, 223 | self.sum_gpu_time / self.count, self.sum_data_time / self.count) 224 | return avg 225 | -------------------------------------------------------------------------------- /models/DCCA_sparse_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torchvision 5 | import functools 6 | from torch.optim import lr_scheduler 7 | import torch.nn.functional as F 8 | from copy import deepcopy 9 | import numpy as np 10 | import cv2 11 | import collections 12 | import matplotlib.pyplot as plt 13 | 14 | def get_norm_layer(norm_type='instance'): 15 | if norm_type == 'batch': 16 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 17 | elif norm_type == 'instance': 18 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) 19 | elif norm_type == 'none': 20 | norm_layer = None 21 | else: 22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 23 | return norm_layer 24 | 25 | 26 | def get_scheduler(optimizer, opt): 27 | if opt.lr_policy == 'lambda': 28 | lambda_rule = lambda epoch: opt.lr_gamma ** ((epoch+1) // opt.lr_decay_epochs) 29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 30 | elif opt.lr_policy == 'step': 31 | scheduler = lr_scheduler.StepLR(optimizer,step_size=opt.lr_decay_iters, gamma=0.1) 32 | else: 33 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 34 | return scheduler 35 | 36 | 37 | def init_weights(net, init_type='normal', gain=0.02): 38 | net = net 39 | def init_func(m): 40 | classname = m.__class__.__name__ 41 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 42 | if init_type == 'normal': 43 | init.normal_(m.weight.data, 0.0, gain) 44 | elif init_type == 'xavier': 45 | init.xavier_normal_(m.weight.data, gain=gain) 46 | elif init_type == 'kaiming': 47 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 48 | elif init_type == 'orthogonal': 49 | init.orthogonal_(m.weight.data, gain=gain) 50 | elif init_type == 'pretrained': 51 | pass 52 | else: 53 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 54 | if hasattr(m, 'bias') and m.bias is not None and init_type != 'pretrained': 55 | init.constant_(m.bias.data, 0.0) 56 | elif classname.find('BatchNorm2d') != -1: 57 | init.normal_(m.weight.data, 1.0, gain) 58 | init.constant_(m.bias.data, 0.0) 59 | print('initialize network with %s' % init_type) 60 | net.apply(init_func) 61 | 62 | 63 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 64 | if len(gpu_ids) > 0: 65 | 
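        # When GPU ids are given, the network is moved to gpu_ids[0] and wrapped in
        # DataParallel over all listed devices before the weight initialization below runs.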
assert(torch.cuda.is_available()) 66 | net.to(gpu_ids[0]) 67 | net = torch.nn.DataParallel(net, gpu_ids) 68 | 69 | for root_child in net.children(): 70 | for children in root_child.children(): 71 | if children in root_child.need_initialization: 72 | init_weights(children, init_type, gain=init_gain) 73 | return net 74 | 75 | def define_DCCASparseNet(rgb_enc=True, depth_enc=True, depth_dec=True, norm='batch', init_type='xavier', init_gain=0.02, gpu_ids=[]): 76 | net = None 77 | norm_layer = get_norm_layer(norm_type=norm) 78 | net = DCCASparsenetGenerator(rgb_enc=rgb_enc, depth_enc=depth_enc, depth_dec=depth_dec) 79 | return init_net(net, init_type, init_gain, gpu_ids) 80 | 81 | ############################################################################## 82 | # Classes 83 | ############################################################################## 84 | class SAConv(nn.Module): 85 | # Convolution layer for sparse data 86 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, bias=True): 87 | super(SAConv, self).__init__() 88 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False) 89 | self.if_bias = bias 90 | if self.if_bias: 91 | self.bias = nn.Parameter(torch.zeros(out_channels).float(), requires_grad=True) 92 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation) 93 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 94 | self.pool.require_grad = False 95 | 96 | def forward(self, input): 97 | x, m = input 98 | x = x * m 99 | x = self.conv(x) 100 | weights = torch.ones(torch.Size([1, 1, 3, 3])).cuda() 101 | mc = F.conv2d(m, weights, bias=None, stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation) 102 | mc = torch.clamp(mc, min=1e-5) 103 | mc = 1. 
/ mc * 9 104 | 105 | if self.if_bias: 106 | x = x + self.bias.view(1, self.bias.size(0), 1, 1).expand_as(x) 107 | m = self.pool(m) 108 | 109 | return x, m 110 | 111 | class SAConvBlock(nn.Module): 112 | 113 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=1, dilation=1, bias=True): 114 | super(SAConvBlock, self).__init__() 115 | self.sparse_conv = SAConv(in_channel, out_channel, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | def forward(self, input): 119 | x, m = input 120 | x, m = self.sparse_conv((x, m)) 121 | assert (m.size(1)==1) 122 | x = self.relu(x) 123 | 124 | return x, m 125 | 126 | class Decoder(nn.Module): 127 | # Decoder is the base class for all decoders 128 | 129 | def __init__(self): 130 | super(Decoder, self).__init__() 131 | 132 | self.layer1 = None 133 | self.layer2 = None 134 | self.layer3 = None 135 | self.layer4 = None 136 | 137 | def forward(self, x): 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | return x 143 | 144 | class DeConv(Decoder): 145 | def __init__(self, in_channels, kernel_size): 146 | assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size) 147 | super(DeConv, self).__init__() 148 | 149 | def convt(in_channels): 150 | stride = 2 151 | padding = (kernel_size - 1) // 2 152 | output_padding = kernel_size % 2 153 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect" 154 | 155 | module_name = "deconv{}".format(kernel_size) 156 | return nn.Sequential(collections.OrderedDict([ 157 | (module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size, 158 | stride,padding,output_padding,bias=False)), 159 | ('batchnorm', nn.BatchNorm2d(in_channels//2)), 160 | ('relu', nn.ReLU(inplace=True)), 161 | ])) 162 | self.layer1 = convt(in_channels) 163 | self.layer2 = convt(in_channels // 2) 164 | self.layer3 = convt(in_channels // (2 ** 2)) 165 | self.layer4 = convt(in_channels // (2 ** 3)) 166 | 167 | def make_layers_from_size(sizes): 168 | layers = [] 169 | for size in sizes: 170 | layers += [nn.Conv2d(size[0], size[1], kernel_size=3, padding=1), nn.BatchNorm2d(size[1],momentum = 0.1), nn.ReLU(inplace=True)] 171 | return nn.Sequential(*layers) 172 | 173 | def make_blocks_from_names(names,in_dim,out_dim): 174 | layers = [] 175 | if names[0] == "block1" or names[0] == "block2": 176 | layers += [SAConvBlock(in_dim, out_dim, 3,stride = 1)] 177 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)] 178 | else: 179 | layers += [SAConvBlock(in_dim, out_dim, 3,stride = 1)] 180 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)] 181 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)] 182 | return nn.Sequential(*layers) 183 | 184 | class DCCASparsenetGenerator(nn.Module): 185 | def __init__(self, rgb_enc=True, depth_enc=True, depth_dec=True): 186 | super(DCCASparsenetGenerator, self).__init__() 187 | #batchNorm_momentum = 0.1 188 | self.need_initialization = [] 189 | 190 | if rgb_enc : 191 | ##### RGB ENCODER #### 192 | self.CBR1_RGB_ENC = make_blocks_from_names(["block1"], 3,64) 193 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 194 | 195 | self.CBR2_RGB_ENC = make_blocks_from_names(["block2"], 64, 128) 196 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 197 | 198 | self.CBR3_RGB_ENC = make_blocks_from_names(["block3"], 128, 256) 199 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, 
return_indices=True) 200 | self.dropout3 = nn.Dropout(p=0.4) 201 | 202 | self.CBR4_RGB_ENC = make_blocks_from_names(["block4"], 256, 512) 203 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 204 | self.dropout4 = nn.Dropout(p=0.4) 205 | 206 | self.CBR5_RGB_ENC = make_blocks_from_names(["block5"], 512, 512) 207 | self.dropout5 = nn.Dropout(p=0.4) 208 | 209 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 210 | 211 | if depth_enc : 212 | 213 | self.CBR1_DEPTH_ENC = make_blocks_from_names(["block1"], 1, 64) 214 | self.pool1_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 215 | 216 | self.CBR2_DEPTH_ENC = make_blocks_from_names(["block2"], 64, 128) 217 | self.pool2_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 218 | 219 | self.CBR3_DEPTH_ENC = make_blocks_from_names(["block3"], 128, 256) 220 | self.pool3_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 221 | 222 | self.CBR4_DEPTH_ENC = make_blocks_from_names(["block4"], 256, 512) 223 | self.pool4_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 224 | 225 | self.CBR5_DEPTH_ENC = make_blocks_from_names(["block5"], 512, 512) 226 | 227 | self.pool5_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 228 | 229 | if depth_dec : 230 | #### DECODER #### 231 | self.Transform = make_blocks_from_names(["block1"],512, 512) 232 | self.decoder = DeConv(1024, 3) 233 | self.conv3 = nn.Conv2d(64,1,kernel_size=3,stride=1,padding=1,bias=False) 234 | ## This size is for KITTI, use (224,224) for NYU 235 | self.bilinear = nn.Upsample((228,912), mode='bilinear', align_corners=True) 236 | 237 | self.need_initialization.append(self.decoder) 238 | self.need_initialization.append(self.conv3) 239 | 240 | def forward(self, sparse_rgb,sparse_d,mask,rgb,d): 241 | 242 | ######## DEPTH ENCODER ######## 243 | x_1,m_d = self.CBR1_DEPTH_ENC((sparse_d,mask)) 244 | x, id1_d = self.pool1_d(x_1) 245 | m_d,_ = self.pool1_d(m_d ) 246 | 247 | x_2,m_d = self.CBR2_DEPTH_ENC((x,m_d )) 248 | x, id2_d = self.pool2_d(x_2) 249 | m_d,_ = self.pool2_d(m_d ) 250 | 251 | x_3,m_d = self.CBR3_DEPTH_ENC((x,m_d )) 252 | x, id3_d = self.pool4_d(x_3) 253 | m_d,_ = self.pool3_d(m_d ) 254 | 255 | x_4,m_d = self.CBR4_DEPTH_ENC((x,m_d )) 256 | x, id4_d = self.pool4_d(x_4) 257 | m_d,_ = self.pool4_d(m_d ) 258 | 259 | x_5,m_d = self.CBR5_DEPTH_ENC((x,m_d )) 260 | x_dataview, id5_d = self.pool5_d(x_5) 261 | m_d,_ = self.pool5_d(m_d ) 262 | 263 | ######## RGB ENCODER ######## 264 | y_1,m_r = self.CBR1_RGB_ENC((sparse_rgb,mask)) 265 | y, id1 = self.pool1(y_1) 266 | m_r,_ = self.pool1(m_r) 267 | 268 | y_2,m_r = self.CBR2_RGB_ENC((y,m_r)) 269 | y, id2 = self.pool2(y_2) 270 | m_r,_ = self.pool2(m_r) 271 | 272 | y_3,m_r = self.CBR3_RGB_ENC((y,m_r)) 273 | y, id3 = self.pool3(y_3) 274 | m_r,_ = self.pool3(m_r) 275 | 276 | y_4,m_r = self.CBR4_RGB_ENC((y,m_r)) 277 | y, id4 = self.pool4(y_4) 278 | m_r,_ = self.pool4(m_r) 279 | 280 | y_5,m_r = self.CBR5_RGB_ENC((y,m_r)) 281 | y_dataview, id5 = self.pool5(y_5) 282 | m_r,_ = self.pool5(m_r) 283 | 284 | ######## MISSING DATA ENCODER ######## 285 | inverse_mask = torch.ones_like(mask)-mask 286 | inverse_rgb = rgb*inverse_mask 287 | 288 | ym_1,m_m = self.CBR1_RGB_ENC((inverse_rgb,inverse_mask)) 289 | ym, id1_m = self.pool1(ym_1) 290 | m_m,_ = self.pool1(m_m) 291 | 292 | ym_2,m_m = self.CBR2_RGB_ENC((ym,m_m )) 293 | ym, id2_m = self.pool2(ym_2) 294 | m_m,_ = self.pool2(m_m) 295 | 296 | ym_3,m_m = self.CBR3_RGB_ENC((ym,m_m )) 297 | ym, id3_m = 
self.pool4(ym_3) 298 | m_m,_ = self.pool3(m_m) 299 | 300 | ym_4,m_m = self.CBR4_RGB_ENC((ym,m_m )) 301 | ym, id4_m = self.pool4(ym_4) 302 | m_m,_ = self.pool4(m_m) 303 | 304 | ym_5,m_m = self.CBR5_RGB_ENC((ym,m_m )) 305 | ym_dataview, id5_m = self.pool5(ym_5) 306 | m_m,_ = self.pool5(m_m) 307 | 308 | ######## Transformer ######## 309 | x_trans, m_trans = self.Transform((y_dataview,m_r)) 310 | xm_trans, mm_trans = self.Transform((ym_dataview,m_r)) 311 | 312 | ######## DECODER ######## 313 | x = self.decoder(torch.cat((x_dataview,xm_trans),1)) 314 | x = self.conv3(x) 315 | depth_est = self.bilinear(x) 316 | 317 | return x_dataview, y_dataview, x_trans, depth_est 318 | 319 | class MaskedMSELoss(nn.Module): 320 | def __init__(self): 321 | super(MaskedMSELoss, self).__init__() 322 | 323 | def forward(self, pred, target): 324 | assert pred.dim() == target.dim(), "inconsistent dimensions" 325 | valid_mask = (target>0).detach() 326 | diff = target - pred 327 | diff = diff[valid_mask] 328 | self.loss = (diff ** 2).mean() 329 | return self.loss 330 | 331 | class TransformLoss(nn.Module): 332 | def __init__(self): 333 | super(TransformLoss, self).__init__() 334 | 335 | def forward(self, f_in, f_target): 336 | assert f_in.dim() == f_target.dim(), "inconsistent dimensions" 337 | diff = f_in - f_target 338 | self.loss = (diff ** 2).mean() 339 | return self.loss 340 | 341 | class SmoothLoss(nn.Module): 342 | def __init__(self): 343 | super(SmoothLoss, self).__init__() 344 | 345 | def forward(self, pred_map): 346 | def gradient(pred): 347 | D_dy = pred[:, :, 1:] - pred[:, :, :-1] 348 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] 349 | return D_dx, D_dy 350 | 351 | if type(pred_map) not in [tuple, list]: 352 | pred_map = [pred_map] 353 | 354 | loss = 0 355 | weight = 1. 
356 | 357 | for scaled_map in pred_map: 358 | dx, dy = gradient(scaled_map) 359 | dx2, dxdy = gradient(dx) 360 | dydx, dy2 = gradient(dy) 361 | loss += (dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean())*weight 362 | weight /= 2.3 # don't ask me why it works better 363 | return loss 364 | 365 | class DCCA_2D_Loss(nn.Module): 366 | def __init__(self,outdim_size, use_all_singular_values, device): 367 | super(DCCA_2D_Loss, self).__init__() 368 | self.outdim_size = outdim_size 369 | self.use_all_singular_values = use_all_singular_values 370 | self.device = device 371 | 372 | def __call__(self, data_view1, data_view2): 373 | H1 = data_view1.view(data_view1.size(0)*data_view1.size(1),data_view1.size(2),data_view1.size(3)) 374 | H2 = data_view2.view(data_view2.size(0)*data_view2.size(1),data_view2.size(2),data_view2.size(3)) 375 | 376 | r1 = 1e-4 377 | r2 = 1e-4 378 | eps = 1e-12 379 | corr_sum = 0 380 | o1 = o2 = H1.size(1) 381 | 382 | m = H1.size(0) 383 | n = H1.size(1) 384 | 385 | H1bar = H1 - (1.0 / m) * H1 386 | H2bar = H2 - (1.0 / m) * H2 387 | Hat12 = torch.zeros(m,n,n).cuda() 388 | Hat11 = torch.zeros(m,n,n).cuda() 389 | Hat22 = torch.zeros(m,n,n).cuda() 390 | 391 | for i in range(m): 392 | Hat11[i] = torch.matmul(H1bar[i],H1bar.transpose(1,2)[i]) 393 | Hat12[i] = torch.matmul(H1bar[i],H2bar.transpose(1,2)[i]) 394 | Hat22[i] = torch.matmul(H2bar[i],H2bar.transpose(1,2)[i]) 395 | 396 | SigmaHat12 = (1.0 / (m - 1)) * torch.mean(Hat12,dim=0) 397 | SigmaHat11 = (1.0 / (m - 1)) * torch.mean(Hat11,dim=0)+ r1 * torch.eye(o1, device=self.device) 398 | SigmaHat22 = (1.0 / (m - 1)) * torch.mean(Hat22,dim=0) + r2 * torch.eye(o2, device=self.device) 399 | 400 | # Calculating the root inverse of covariance matrices by using eigen decomposition 401 | [D1, V1] = torch.symeig(SigmaHat11, eigenvectors=True) 402 | [D2, V2] = torch.symeig(SigmaHat22, eigenvectors=True) 403 | 404 | # Added to increase stability 405 | posInd1 = torch.gt(D1, eps).nonzero()[:, 0] 406 | D1 = D1[posInd1] 407 | V1 = V1[:, posInd1] 408 | posInd2 = torch.gt(D2, eps).nonzero()[:, 0] 409 | D2 = D2[posInd2] 410 | V2 = V2[:, posInd2] 411 | SigmaHat11RootInv = torch.matmul( 412 | torch.matmul(V1, torch.diag(D1 ** -0.5)), V1.t()) 413 | SigmaHat22RootInv = torch.matmul( 414 | torch.matmul(V2, torch.diag(D2 ** -0.5)), V2.t()) 415 | 416 | Tval = torch.matmul(torch.matmul(SigmaHat11RootInv, 417 | SigmaHat12), SigmaHat22RootInv) 418 | 419 | if self.use_all_singular_values: 420 | # all singular values are used to calculate the correlation 421 | corr = torch.sqrt(torch.trace(torch.matmul(Tval.t(), Tval))) 422 | else: 423 | # just the top self.outdim_size singular values are used 424 | U, V = torch.symeig(torch.matmul(Tval.t(), Tval), eigenvectors=True) 425 | U = U[torch.gt(U, eps).nonzero()[:, 0]] 426 | U = U.topk(self.outdim_size)[0] 427 | corr = torch.sum(torch.sqrt(U)) 428 | return -corr -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. 
It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt, dataset): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt, dataset) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch.optim import lr_scheduler 5 | 6 | class BaseModel(): 7 | @staticmethod 8 | def modify_commandline_options(parser, is_train): 9 | return parser 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = opt.gpu_ids 17 | self.isTrain = opt.isTrain 18 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 20 | if opt.resize_or_crop != 'scale_width': 21 | torch.backends.cudnn.benchmark = True 22 | self.loss_names = [] 23 | self.model_names = [] 24 | self.visual_names = [] 25 | self.image_paths = [] 26 | 27 | def set_input(self, input): 28 | self.input = input 29 | 30 | def forward(self): 31 | pass 32 | 33 | def get_scheduler(self, optimizer, opt): 34 | if opt.lr_policy == 'lambda': 35 | lambda_rule = lambda epoch: opt.lr_gamma ** ((epoch+1) // opt.lr_decay_epochs) 36 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 37 | elif opt.lr_policy == 'step': 38 | scheduler = lr_scheduler.StepLR(optimizer,step_size=opt.lr_decay_iters, gamma=0.1) 39 | else: 40 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 41 | return scheduler 42 | 43 | # load and print networks; create schedulers 44 | def setup(self, opt, parser=None): 45 | if self.isTrain: 46 | self.schedulers = [self.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 47 | 48 | if not self.isTrain or opt.continue_train: 49 | self.load_networks(opt.epoch) 50 | self.print_networks(opt.verbose) 51 | 52 | # make models eval mode during test time 53 | def eval(self): 54 | for name in self.model_names: 55 | if isinstance(name, str): 56 | net = getattr(self, 'net' + name) 57 | net.eval() 58 | def train(self): 59 | for name in self.model_names: 60 | if isinstance(name, str): 61 | net = getattr(self, 'net' + name) 62 | net.train() 63 | 64 | def test(self): 65 | with torch.no_grad(): 66 | self.forward() 67 | 68 | # get image paths 69 | def get_image_paths(self): 70 | return self.image_paths 71 | 72 | def optimize_parameters(self): 73 | pass 74 | 75 | # update learning rate (called once every epoch) 76 | def update_learning_rate(self): 77 | for scheduler in self.schedulers: 78 | scheduler.step() 79 | lr = 
self.optimizers[0].param_groups[0]['lr'] 80 | print('learning rate = %.7f' % lr) 81 | 82 | # return visualization images. train.py will display these images, and save the images to a html 83 | def get_current_visuals(self): 84 | visual_ret = OrderedDict() 85 | for name in self.visual_names: 86 | if isinstance(name, str): 87 | visual_ret[name] = getattr(self, name) 88 | return visual_ret 89 | 90 | # return traning losses/errors. train.py will print out these errors as debugging information 91 | def get_current_losses(self): 92 | errors_ret = OrderedDict() 93 | for name in self.loss_names: 94 | if isinstance(name, str): 95 | # float(...) works for both scalar tensor and float number 96 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 97 | return errors_ret 98 | 99 | # save models to the disk 100 | def save_networks(self, epoch): 101 | for name in self.model_names: 102 | if isinstance(name, str): 103 | save_filename = '%s_net_%s.pth' % (epoch, name) 104 | save_path = os.path.join(self.save_dir, save_filename) 105 | net = getattr(self, 'net' + name) 106 | 107 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 108 | torch.save(net.module.cpu().state_dict(), save_path) 109 | net.cuda(self.gpu_ids[0]) 110 | else: 111 | torch.save(net.cpu().state_dict(), save_path) 112 | 113 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 114 | key = keys[i] 115 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 116 | if module.__class__.__name__.startswith('InstanceNorm') and \ 117 | (key == 'running_mean' or key == 'running_var'): 118 | if getattr(module, key) is None: 119 | state_dict.pop('.'.join(keys)) 120 | if module.__class__.__name__.startswith('InstanceNorm') and \ 121 | (key == 'num_batches_tracked'): 122 | state_dict.pop('.'.join(keys)) 123 | else: 124 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 125 | 126 | # load models from the disk 127 | def load_networks(self, epoch): 128 | for name in self.model_names: 129 | if isinstance(name, str): 130 | load_filename = '%s_net_%s.pth' % (epoch, name) 131 | load_path = os.path.join(self.save_dir, load_filename) 132 | net = getattr(self, 'net' + name) 133 | if isinstance(net, torch.nn.DataParallel): 134 | net = net.module 135 | print('loading the model from %s' % load_path) 136 | # if you are using PyTorch newer than 0.4 (e.g., built from 137 | # GitHub source), you can remove str() on self.device 138 | state_dict = torch.load(load_path, map_location=str(self.device)) 139 | if hasattr(state_dict, '_metadata'): 140 | del state_dict._metadata 141 | 142 | # patch InstanceNorm checkpoints prior to 0.4 143 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 144 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 145 | net.load_state_dict(state_dict) 146 | 147 | # print network information 148 | def print_networks(self, verbose): 149 | print('---------- Networks initialized -------------') 150 | for name in self.model_names: 151 | if isinstance(name, str): 152 | net = getattr(self, 'net' + name) 153 | num_params = 0 154 | for param in net.parameters(): 155 | num_params += param.numel() 156 | if verbose: 157 | print(net) 158 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 159 | print('-----------------------------------------------') 160 | 161 | # set requies_grad=Fasle to avoid computation 162 | def set_requires_grad(self, nets, requires_grad=False): 163 | if not 
isinstance(nets, list): 164 | nets = [nets] 165 | for net in nets: 166 | if net is not None: 167 | for param in net.parameters(): 168 | param.requires_grad = requires_grad 169 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | #import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 15 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 16 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 17 | parser.add_argument('--model', type=str, default='DCCA_sparse', 18 | help='chooses which model to use. cycle_gan, pix2pix, test') 19 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 20 | parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data') 21 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 22 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 23 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 24 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 25 | parser.add_argument('--resize_or_crop', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') 26 | parser.add_argument('--no_flip', action='store_true', default=True, help='if specified, do not flip the images for data augmentation') 27 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') 28 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 29 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 30 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 31 | parser.add_argument('--seed', type=int, default=0, help='seed for random generators') 32 | self.initialized = True 33 | return parser 34 | 35 | def gather_options(self, flag): 36 | # initialize parser with basic options 37 | if not self.initialized: 38 | parser = argparse.ArgumentParser( 39 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 40 | parser = self.initialize(parser, flag) 41 | 42 | # get the basic options 43 | opt, _ = parser.parse_known_args() 44 | 45 | # modify model-related parser options 46 | model_name = opt.model 47 | model_option_setter = models.get_option_setter(model_name) 48 | parser = model_option_setter(parser, self.isTrain) 49 | opt, _ = parser.parse_known_args() # parse again with the new defaults 50 | 51 | # modify dataset-related parser options 52 | # dataset_name = opt.dataset_mode 53 | # print(dataset_name) 54 | # dataset_option_setter = data.get_option_setter(dataset_name) 55 | # parser = dataset_option_setter(parser, self.isTrain) 56 | 57 | self.parser = parser 58 | 59 | return parser.parse_args() 60 | 61 | def print_options(self, opt): 62 | message = '' 63 | message += '----------------- Options ---------------\n' 64 | for k, v in sorted(vars(opt).items()): 65 | comment = '' 66 | default = self.parser.get_default(k) 67 | if v != default: 68 | comment = '\t[default: %s]' % str(default) 69 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 70 | message += '----------------- End -------------------' 71 | print(message) 72 | 73 | # save to the disk 74 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 75 | util.mkdirs(expr_dir) 76 | file_name = os.path.join(expr_dir, 'opt.txt') 77 | with open(file_name, 'wt') as opt_file: 78 | opt_file.write(message) 79 | opt_file.write('\n') 80 | 81 | def parse(self, flag): 82 | 83 | opt = self.gather_options(flag) 84 | opt.isTrain = self.isTrain # train or test 85 | 86 | # process opt.suffix 87 | if opt.suffix: 88 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 89 | opt.name = opt.name + suffix 90 | 91 | self.print_options(opt) 92 | 93 | # set gpu ids 94 | str_ids = opt.gpu_ids.split(',') 95 | opt.gpu_ids = [] 96 | for str_id in str_ids: 97 | id = int(str_id) 98 | if id >= 0: 99 | opt.gpu_ids.append(id) 100 | if len(opt.gpu_ids) > 0: 101 | torch.cuda.set_device(opt.gpu_ids[0]) 102 | 103 | self.opt = opt 104 | return self.opt 105 | -------------------------------------------------------------------------------- /options/options.py: -------------------------------------------------------------------------------- 1 | from .base_options import 
BaseOptions 2 | 3 | 4 | class AdvanceOptions(BaseOptions): 5 | def initialize(self, parser, flag): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--print_freq', type=int, default=1, help='frequency of showing training results on console') 8 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 9 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 10 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 11 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 12 | parser.add_argument('--niter', type=int, default=400, help='# of iter at starting learning rate') 13 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for optimizer') 14 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for SGD') 15 | parser.add_argument('--weight_decay', type=float, default=0.0005, help='momentum factor for optimizer') 16 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 17 | parser.add_argument('--lr_decay_iters', type=int, default=5000000, help='multiply by a gamma every lr_decay_iters iterations') 18 | parser.add_argument('--lr_decay_epochs', type=int, default=100, help='multiply by a gamma every lr_decay_epoch epochs') 19 | parser.add_argument('--lr_gamma', type=float, default=0.9, help='gamma factor for lr_scheduler') 20 | parser.add_argument('--nP', type=int, default=500, help='number of points') 21 | parser.add_argument('--train_path', help='path to the training dataset') 22 | parser.add_argument('--test_path', help='path to the testing dataset') 23 | self.isTrain = flag 24 | return parser 25 | -------------------------------------------------------------------------------- /train_depth_complete.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | from options.options import AdvanceOptions 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | from dataloaders.nyu_dataloader import NYUDataset 7 | from dataloaders.kitti_dataloader import KITTIDataset 8 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo 9 | import numpy as np 10 | import random 11 | import torch 12 | import cv2 13 | 14 | if __name__ == '__main__': 15 | train_opt = AdvanceOptions().parse(True) 16 | 17 | # The SimulatedStereo class is also provided to subsample to stereo points 18 | sparsifier = UniformSampling(train_opt.nP, max_depth=np.inf) 19 | 20 | train_dataset = KITTIDataset(train_opt.train_path, type='train', 21 | modality='rgbdm', sparsifier=sparsifier) 22 | test_dataset = KITTIDataset(train_opt.test_path, type='val', 23 | modality='rgbdm', sparsifier=sparsifier) 24 | ## Please use this dataloder if you want to use NYU 25 | # train_dataset = NYUDataset(train_opt.train_path, type='train', 26 | # modality='rgbdm', sparsifier=sparsifier) 27 | ## Please use this dataloder if you want to use NYU 28 | # test_dataset = NYUDataset(train_opt.test_path, type='val', 29 | # modality='rgbdm', sparsifier=sparsifier) 30 | 31 | train_data_loader = torch.utils.data.DataLoader( 32 | train_dataset, batch_size=train_opt.batch_size, shuffle=True, 33 | num_workers=train_opt.num_threads, pin_memory=True, sampler=None, 34 | 
worker_init_fn=lambda work_id:np.random.seed(train_opt.seed + work_id)) 35 | test_opt = AdvanceOptions().parse(True) 36 | test_opt.phase = 'val' 37 | test_opt.batch_size = 1 38 | test_opt.num_threads = 1 39 | test_opt.serial_batches = True 40 | test_opt.no_flip = True 41 | 42 | test_data_loader = torch.utils.data.DataLoader(test_dataset, 43 | batch_size=test_opt.batch_size, shuffle=False, num_workers=test_opt.num_threads, pin_memory=True) 44 | 45 | train_dataset_size = len(train_data_loader) 46 | print('#training images = %d' % train_dataset_size) 47 | test_dataset_size = len(test_data_loader) 48 | print('#test images = %d' % test_dataset_size) 49 | 50 | model = create_model(train_opt, train_dataset) 51 | model.setup(train_opt) 52 | visualizer = Visualizer(train_opt) 53 | total_steps = 0 54 | for epoch in range(train_opt.epoch_count, train_opt.niter + 1): 55 | model.train() 56 | epoch_start_time = time.time() 57 | iter_data_time = time.time() 58 | epoch_iter = 0 59 | model.init_eval() 60 | iterator = iter(train_data_loader) 61 | while True: 62 | try: # Some images couldn't sample more than defined nP points under Stereo sampling 63 | next_batch = next(iterator) 64 | except IndexError: 65 | print("Catch and Skip!") 66 | continue 67 | except StopIteration: 68 | break 69 | data, target = next_batch[0], next_batch[1] 70 | 71 | iter_start_time = time.time() 72 | if total_steps % train_opt.print_freq == 0: 73 | t_data = iter_start_time - iter_data_time 74 | total_steps += train_opt.batch_size 75 | epoch_iter += train_opt.batch_size 76 | model.set_new_input(data,target) 77 | model.optimize_parameters() 78 | 79 | if total_steps % train_opt.print_freq == 0: 80 | losses = model.get_current_losses() 81 | t = (time.time() - iter_start_time) / train_opt.batch_size 82 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 83 | message = model.print_depth_evaluation() 84 | visualizer.print_current_depth_evaluation(message) 85 | print() 86 | 87 | iter_data_time = time.time() 88 | 89 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.niter, time.time() - epoch_start_time)) 90 | model.update_learning_rate() 91 | if epoch and epoch % train_opt.save_epoch_freq == 0: 92 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 93 | model.save_networks('latest') 94 | model.save_networks(epoch) 95 | 96 | model.eval() 97 | test_loss_iter = [] 98 | gts = None 99 | preds = None 100 | epoch_iter = 0 101 | model.init_test_eval() 102 | with torch.no_grad(): 103 | iterator = iter(test_data_loader) 104 | while True: 105 | try: # Some images couldn't sample more than defined nP points under Stereo sampling 106 | next_batch = next(iterator) 107 | except IndexError: 108 | print("Catch and Skip!") 109 | continue 110 | except StopIteration: 111 | break 112 | 113 | data, target = next_batch[0], next_batch[1] 114 | 115 | model.set_new_input(data,target) 116 | model.forward() 117 | model.test_depth_evaluation() 118 | model.get_loss() 119 | epoch_iter += test_opt.batch_size 120 | losses = model.get_current_losses() 121 | test_loss_iter.append(model.loss_dcca.item()) 122 | print('test epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(test_dataset) * test_opt.batch_size), end='\r') 123 | message = model.print_test_depth_evaluation() 124 | visualizer.print_current_depth_evaluation(message) 125 | print( 126 | 'RMSE={result.rmse:.4f}({average.rmse:.4f}) ' 127 | 'MAE={result.mae:.4f}({average.mae:.4f}) ' 128 | 'Delta1={result.delta1:.4f}({average.delta1:.4f}) 
' 129 | 'REL={result.absrel:.4f}({average.absrel:.4f}) ' 130 | 'Lg10={result.lg10:.4f}({average.lg10:.4f}) '.format( 131 | result=model.test_result, average=model.test_average.average())) 132 | avg_test_loss = np.mean(np.asarray(test_loss_iter)) 133 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/util/__init__.py -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)))* 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | 22 | def tensor2labelim(label_tensor, impalette, imtype=np.uint8): 23 | if len(label_tensor.shape) == 4: 24 | _, label_tensor = torch.max(label_tensor.data.cpu(), 1) 25 | 26 | label_numpy = label_tensor[0].cpu().float().detach().numpy() 27 | label_image = Image.fromarray(label_numpy.astype(np.uint8)) 28 | label_image = label_image.convert("P") 29 | label_image.putpalette(impalette) 30 | label_image = label_image.convert("RGB") 31 | return np.array(label_image).astype(imtype) 32 | 33 | def diagnose_network(net, name='network'): 34 | mean = 0.0 35 | count = 0 36 | for param in net.parameters(): 37 | if param.grad is not None: 38 | mean += torch.mean(torch.abs(param.grad.data)) 39 | count += 1 40 | if count > 0: 41 | mean = mean / count 42 | print(name) 43 | print(mean) 44 | 45 | 46 | def save_image(image_numpy, image_path): 47 | image_pil = Image.fromarray(image_numpy) 48 | image_pil.save(image_path) 49 | 50 | def save_image_cv2(image_numpy, image_path): 51 | #image_pil = Image.fromarray(image_numpy) 52 | cv2.imwrite(image_path,image_numpy) 53 | #image_pil.save(image_path) 54 | 55 | 56 | def print_numpy(x, val=True, shp=False): 57 | x = x.astype(np.float64) 58 | if shp: 59 | print('shape,', x.shape) 60 | if val: 61 | x = x.flatten() 62 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 63 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 64 | 65 | 66 | def mkdirs(paths): 67 | if isinstance(paths, list) and not isinstance(paths, str): 68 | for path in paths: 69 | mkdir(path) 70 | else: 71 | mkdir(paths) 72 | 73 | 74 | def mkdir(path): 75 | if not os.path.exists(path): 76 | os.makedirs(path) 77 | 78 | def confusion_matrix(x , y, n, ignore_label=None, mask=None): 79 | if mask is None: 80 | mask = np.ones_like(x) == 1 81 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(np.bool)) 82 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n**2).reshape(n, n) 83 | 84 | def getScores(conf_matrix): 85 | if conf_matrix.sum() == 0: 86 | return 0, 0, 0 87 | with np.errstate(divide='ignore',invalid='ignore'): 88 | overall = np.diag(conf_matrix).sum() / np.float(conf_matrix.sum()) 89 | 
perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float) 90 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) - np.diag(conf_matrix)).astype(np.float) 91 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100. 92 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Visualizer(): 4 | def __init__(self, opt): 5 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 6 | 7 | # losses: same format as |losses| of plot_current_losses 8 | def print_current_losses(self, epoch, i, losses, t, t_data): 9 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 10 | for k, v in losses.items(): 11 | message += '%s: %.3f ' % (k, v) 12 | print(message) 13 | with open(self.log_name, "a") as log_file: 14 | log_file.write('%s\n' % message) 15 | 16 | def print_current_depth_evaluation(self, message): 17 | with open(self.log_name, "a") as log_file: 18 | log_file.write('%s\n' % message) 19 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import cv2 8 | 9 | def colored_depthmap(depth, d_min=None, d_max=None): 10 | if d_min is None: 11 | d_min = np.min(depth) 12 | if d_max is None: 13 | d_max = np.max(depth) 14 | depth_relative = (depth - d_min) / (d_max - d_min) 15 | return 255 * plt.cm.jet(depth_relative)[:,:,:3] # H, W, C 16 | 17 | cmap = plt.cm.jet 18 | def depth_colorize_16(depth): 19 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth)) 20 | depth = 255* 256 * cmap(depth)[:,:,:3] # H, W, C 21 | return depth.astype('uint16') 22 | 23 | def depth_colorize_8(depth): 24 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth)) 25 | depth = 255* cmap(depth)[:,:,:3] # H, W, C 26 | return depth.astype('uint8') 27 | 28 | def Enlarge_pixel(sparse_depth): 29 | for i in range(2,sparse_depth.shape[0]-2): 30 | for j in range(2,sparse_depth.shape[1]-2): 31 | if np.sum(sparse_depth[i][j]) > 0: 32 | for w in range(-2,2): 33 | for h in range(-2,2): 34 | sparse_depth[i+w][j+h] = sparse_depth[i][j] 35 | 36 | return sparse_depth 37 | 38 | def merge_into_row_with_pred_visualize(input, depth_input, rgb_sparse, depth_target, depth_est): 39 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0))[:,:,(2,1,0)] # H, W, C 40 | rgb_sparse = 255 * np.transpose(np.squeeze(rgb_sparse.cpu().numpy()), (1,2,0))[:,:,(2,1,0)] 41 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) 42 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 43 | depth_pred_cpu = np.squeeze(depth_est.cpu().numpy()) 44 | 45 | d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu)) 46 | d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.min(depth_pred_cpu)) 47 | depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max) 48 | # depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 49 | # depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 50 | depth_input_col = (depth_colorize_8(depth_input_cpu)) 51 | depth_target_col = (depth_colorize_8(depth_target_cpu)) 52 | depth_pred_col = depth_colorize_8(depth_pred_cpu) 53 | 54 | 
img_merge = np.hstack([rgb, depth_pred_col]) 55 | #img_merge = np.hstack([rgb,depth_input_col]) 56 | #depth_merge = np.hstack([depth_pred_col,depth_target_col]) 57 | #img_merge = np.vstack([img_merge,depth_merge]) 58 | 59 | return img_merge 60 | 61 | def add_row(img_merge, row): 62 | return np.vstack([img_merge, row]) 63 | 64 | def save_image(img_merge, filename): 65 | img_merge = Image.fromarray(img_merge.astype('uint8')) 66 | img_merge.save(filename) 67 | 68 | def save_image_cv2(image_numpy, image_path): 69 | #image_pil = Image.fromarray(image_numpy) 70 | cv2.imwrite(image_path,image_numpy) 71 | #image_pil.save(image_path) 72 | 73 | 74 | --------------------------------------------------------------------------------
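Usage note: the `Result` and `AverageMeter` classes listed above in models/DCCA_sparse_model.py implement the depth-completion metrics (RMSE, MAE, REL, lg10, the delta thresholds, and their inverse-depth counterparts) and can be exercised on their own. The sketch below is a minimal, hypothetical example, not part of the repository: the random tensors stand in for a completed depth map and its ground truth, which in the real pipeline come from the KITTI/NYU dataloaders, and the import path is assumed from the listing above. `Result.evaluate()` only scores pixels with target > 0 and additionally drops pixels with target > 50 (a KITTI-style depth cap).

    import torch
    # Assumed import path: Result and AverageMeter are defined at module level in models/DCCA_sparse_model.py
    from models.DCCA_sparse_model import Result, AverageMeter

    # Stand-in prediction / ground truth; real values come from the dataloaders.
    pred = torch.rand(1, 1, 228, 912) * 50.0 + 0.5   # keep values > 0 so log10 and 1/x stay finite
    gt   = torch.rand(1, 1, 228, 912) * 50.0 + 0.5
    gt[gt < 5.0] = 0.0                               # simulate missing depth returns; evaluate() ignores target <= 0

    meter = AverageMeter()
    result = Result()
    result.evaluate(pred, gt)      # fills rmse, mae, absrel, lg10, delta1/2/3, irmse, imae
    meter.update(result, n=1)      # accumulate per batch over a test set

    avg = meter.average()
    print('RMSE={:.4f} MAE={:.4f} REL={:.4f} Delta1={:.4f}'.format(
        avg.rmse, avg.mae, avg.absrel, avg.delta1))

This mirrors the 'RMSE=...(...) MAE=...(...)' lines that print_depth_evaluation and print_test_depth_evaluation write to the console during training and evaluation, where the first number is the per-batch result and the number in parentheses is the running average from the meter.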