├── .gitignore
├── LICENSE
├── README.md
├── dataloaders
│   ├── city_dataloader.py
│   ├── dataloader.py
│   ├── dense_to_sparse.py
│   ├── kitti_dataloader.py
│   ├── nyu_dataloader.py
│   └── transforms.py
├── evaluate.py
├── images
│   └── 500.gif
├── models
│   ├── DCCA_sparse_model.py
│   ├── DCCA_sparse_networks.py
│   ├── __init__.py
│   └── base_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   └── options.py
├── train_depth_complete.py
├── util
│   ├── __init__.py
│   ├── util.py
│   └── visualizer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # others
107 | .vis/
108 | .checkpoints/
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Cho Ying Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep RGB-D Canonical Correlation Analysis for Sparse Depth Completion
2 | This is the official PyTorch implementation of our NeurIPS 2019 paper by Yiqi Zhong\*, Cho-Ying Wu\*, Suya You, and Ulrich Neumann (\*equal contribution) at USC.
3 |
4 | Paper: [Arxiv].
5 |
6 |
7 |
8 | Check out the whole video demo [Youtube].
9 |
10 | **Also check out our newest work on depth estimation/completion using sensor fusion, SCADC!**
11 |
12 | # Prerequisites
13 | Linux
14 | Python 3
15 | PyTorch 1.0+ (originally developed under v1.0; testing on v1.5 also works fine)
16 | NVIDIA GPU + CUDA cuDNN
17 | Other common libraries: matplotlib, cv2, PIL
18 |
19 | # Getting Started
20 |
21 | Data Preparation:
22 | Please refer to [KITTI] or [NYU Depth V2] and process them into h5 files. Preprocessed data is also provided here.
23 |
24 | # Tutorial:
25 |
26 | 1. Create a folder with a subfolder 'checkpoint/kitti' (see the sketch after this list)
27 | 2. Download the pretrained weights: [NYU-Depth 500 points training] [KITTI 500 points training] and put the .pth file under 'checkpoint/kitti/'
28 | 3. Prepare the data as described in the "Getting Started" section above
29 | 4. Run "python3 evaluate.py --name kitti --checkpoints_dir ./checkpoint/ --test_path [path to the testing file]"
30 | 5. The completed depth maps are saved under 'vis/'
31 |
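After steps 1 and 2, the checkpoint folder should look roughly like this (a sketch; the .pth filename depends on the weights you downloaded):

    checkpoint/
    └── kitti/
        └── [downloaded weights].pth
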
32 | # Train/Evaluation:
33 |
34 | For training, please run
35 |
36 | python3 train_depth_complete.py --name kitti --checkpoints_dir [path to save_dir] --train_path [train_data_dir] --test_path [test_data_dir]
37 |
38 | If you use the preprocessed data from here, the train/test data paths should be ./kitti/train/ or ./kitti/val/ under your data directory.
39 |
40 | If you want to use your own data, please convert it into h5 datasets (see dataloaders/dataloader.py); a minimal sketch is given below.
41 |
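A rough sketch of the expected format, assuming one .h5 file per sample (this is not the official conversion script): `h5_loader` in dataloaders/dataloader.py reads an `rgb` array stored as (3, H, W) and a `depth` array stored as (H, W), and the dataloader expects the train/val directory to contain at least one subfolder holding the .h5 files. Folder and file names below are placeholders.

    # Illustrative sketch only: paths and array contents are placeholders.
    import os
    import h5py
    import numpy as np

    rgb = np.zeros((3, 480, 640), dtype=np.uint8)     # your RGB image, channels first
    depth = np.zeros((480, 640), dtype=np.float32)    # your depth map (e.g., in meters)

    os.makedirs('./kitti/train/sequence_00', exist_ok=True)
    with h5py.File('./kitti/train/sequence_00/00001.h5', 'w') as f:
        f.create_dataset('rgb', data=rgb)
        f.create_dataset('depth', data=depth)
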
42 | Other options: `--continue_train` loads the latest saved checkpoint; also set `--epoch_count` to specify the next epoch number, otherwise training starts from epoch 0. Set hyperparameters with `--lr`, `--batch_size`, `--weight_decay`, or others. Please refer to options/base_options.py and options/options.py.
43 |
44 | Note that the default batch size is 4 during training and only gpu:0 is used. You can set a larger batch size (--batch_size=xx) together with more GPUs (--gpu_ids="0,1,2,3") for larger-batch training.
45 |
46 | Example command:
47 |
48 | python3 train_depth_complete.py --name kitti --checkpoints_dir ./checkpoints --lr 0.001 --batch_size 4 --train_path './kitti/train/' --test_path './kitti/val/' --continue_train --epoch_count [next_epoch_number]
49 |
50 | For evaluation, please run
51 |
52 | python3 evaluate.py --name kitti --checkpoints_dir [path to save_dir to load ckpt] --test_path [test_data_dir] [--epoch [epoch number]]
53 |
54 | This will load the latest checkpoint for evaluation. Add `--epoch` to specify which epoch's checkpoint you want to load.
55 |
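For example, to evaluate the checkpoint saved at a specific epoch (the paths and epoch number here are placeholders):

python3 evaluate.py --name kitti --checkpoints_dir ./checkpoints --test_path './kitti/val/' --epoch 20
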
56 | # Update: 02/10/2020
57 |
58 | 1. Fixed several bugs and removed redundant options.
59 | 
60 | 2. Released the ORB sparsifier.
61 | 
62 | 3. Released pretrained models: [NYU-Depth 500 points training] [KITTI 500 points training]
63 |
64 |
65 | # Update: 04/19/2021
66 |
67 | 1. Revised the README and added a tutorial
68 | 2. Several minor revisions
69 |
70 |
71 | If you find our work useful, please consider citing it:
72 |
73 | @inproceedings{zhong2019deep,
74 | title={Deep rgb-d canonical correlation analysis for sparse depth completion},
75 | author={Zhong, Yiqi and Wu, Cho-Ying and You, Suya and Neumann, Ulrich},
76 | booktitle={Advances in Neural Information Processing Systems},
77 | pages={5332--5342},
78 | year={2019}
79 | }
80 |
81 |
--------------------------------------------------------------------------------
/dataloaders/city_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.transforms as transforms
3 | from dataloaders.dataloader import MyDataloader
4 |
5 | class CITY_SCAPESDataset(MyDataloader):
6 | def __init__(self, root, type, sparsifier=None, modality='rgb'):
7 | super(CITY_SCAPESDataset, self).__init__(root, type, sparsifier, modality)
8 | self.output_size = (228, 912)
9 |
10 | def train_transform(self, rgb, depth):
11 | s = np.random.uniform(1.0, 1.5) # random scaling
12 | depth_np = depth / s
13 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
14 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
15 |
16 | # perform 1st step of data augmentation
17 | transform = transforms.Compose([
18 | transforms.Crop(0, 20, 750, 2000),
19 | transforms.Resize(500 / 750),
20 | transforms.Rotate(angle),
21 | transforms.Resize(s),
22 | transforms.CenterCrop(self.output_size),
23 | transforms.HorizontalFlip(do_flip)
24 | ])
25 | rgb_np = transform(rgb)
26 | rgb_np = self.color_jitter(rgb_np) # random color jittering
27 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
28 | # Scipy affine_transform produced RuntimeError when the depth map was
29 | # given as a 'numpy.ndarray'
30 | depth_np = np.asfarray(depth_np, dtype='float32')
31 | depth_np = transform(depth_np)
32 |
33 | return rgb_np, depth_np
34 |
35 | def val_transform(self, rgb, depth):
36 | depth_np = depth
37 | transform = transforms.Compose([
38 | transforms.Crop(0, 20, 750, 2000),
39 | transforms.Resize(500 / 750),
40 | transforms.CenterCrop(self.output_size),
41 | ])
42 | rgb_np = transform(rgb)
43 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
44 | depth_np = np.asfarray(depth_np, dtype='float32')
45 | depth_np = transform(depth_np)
46 |
47 | return rgb_np, depth_np
48 |
49 |
--------------------------------------------------------------------------------
/dataloaders/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import numpy as np
4 | import torch.utils.data as data
5 | import h5py
6 | import dataloaders.transforms as transforms
7 | import torch
8 |
9 | IMG_EXTENSIONS = ['.h5',]
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | def find_classes(dir):
15 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
16 | classes.sort()
17 | class_to_idx = {classes[i]: i for i in range(len(classes))}
18 | return classes, class_to_idx
19 |
20 | def make_dataset(dir, class_to_idx):
21 | images = []
22 | dir = os.path.expanduser(dir)
23 | for target in sorted(os.listdir(dir)):
24 | d = os.path.join(dir, target)
25 | if not os.path.isdir(d):
26 | continue
27 | for root, _, fnames in sorted(os.walk(d)):
28 | for fname in sorted(fnames):
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | item = (path, class_to_idx[target])
32 | images.append(item)
33 | return images
34 |
35 | def h5_loader(path):
36 | h5f = h5py.File(path, "r")
37 | rgb = np.array(h5f['rgb'])
38 | rgb = np.transpose(rgb, (1, 2, 0))
39 | depth = np.array(h5f['depth'])
40 | return rgb, depth
41 |
42 | to_tensor = transforms.ToTensor()
43 |
44 | class MyDataloader(data.Dataset):
45 | modality_names = ['rgb', 'rgbd', 'd','rgbdm'] # , 'g', 'gd'
46 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
47 |
48 | def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
49 | classes, class_to_idx = find_classes(root)
50 | imgs = make_dataset(root, class_to_idx)
51 | assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"
52 | print("Found {} images in {} folder.".format(len(imgs), type))
53 | self.root = root
54 | self.imgs = imgs
55 | self.classes = classes
56 | self.class_to_idx = class_to_idx
57 | if type == 'train':
58 | self.transform = self.train_transform
59 | elif type == 'val':
60 | self.transform = self.val_transform
61 | else:
62 | raise (RuntimeError("Invalid dataset type: " + type + "\n"
63 | "Supported dataset types are: train, val"))
64 | self.loader = loader
65 | self.sparsifier = sparsifier
66 |
67 |         assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
68 |                                 "Supported modality types are: " + ', '.join(self.modality_names)
69 | self.modality = modality
70 |
71 | def train_transform(self, rgb, depth):
72 | raise (RuntimeError("train_transform() is not implemented. "))
73 |
74 |     def val_transform(self, rgb, depth):
75 | raise (RuntimeError("val_transform() is not implemented."))
76 |
77 | def create_sparse_depth(self, rgb, depth):
78 | if self.sparsifier is None:
79 | return depth
80 | else:
81 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
82 | sparse_depth = np.zeros(depth.shape)
83 | sparse_depth[mask_keep] = depth[mask_keep]
84 | return sparse_depth
85 |
86 | def create_sparse_depth_rgb(self, rgb, depth):
87 | if self.sparsifier is None:
88 | return depth
89 | else:
90 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
91 | sparse_depth = np.zeros(depth.shape)
92 | sparse_depth[mask_keep] = depth[mask_keep]
93 | sparse_rgb = np.zeros(rgb.shape)
94 | sparse_rgb[mask_keep,:] = rgb[mask_keep,:]
95 | sparse_mask = np.zeros(depth.shape)
96 | sparse_mask[mask_keep] = 1
97 | mask_keep = mask_keep.astype(np.uint8)
98 | return sparse_depth,sparse_rgb, mask_keep
99 |
100 | def create_rgbdm(self, rgb, depth):
101 | sparse_depth,sparse_rgb,mask = self.create_sparse_depth_rgb(rgb, depth)
102 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2),axis=2)
103 | rgbdm = np.append(rgbd, sparse_rgb, axis=2)
104 | rgbdm = np.append(rgbdm, np.expand_dims(mask, axis=2),axis=2)
105 |
106 | return rgbdm
107 |
108 | def create_rgbd(self, rgb, depth):
109 | sparse_depth = self.create_sparse_depth(rgb, depth)
110 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
111 |
112 | return rgbd
113 |
114 | def __getraw__(self, index):
115 | path, target = self.imgs[index]
116 | rgb, depth = self.loader(path)
117 | return rgb, depth
118 |
119 | def __getitem__(self, index):
120 | rgb, depth = self.__getraw__(index)
121 | if self.transform is not None:
122 | rgb_np, depth_np = self.transform(rgb, depth)
123 | else:
124 | raise(RuntimeError("transform not defined"))
125 |
126 | if self.modality == 'rgb':
127 | input_np = rgb_np
128 | elif self.modality == 'rgbd':
129 | input_np = self.create_rgbd(rgb_np, depth_np)
130 | elif self.modality == 'd':
131 | input_np = self.create_sparse_depth(rgb_np, depth_np)
132 | elif self.modality == 'rgbdm':
133 | input_np = self.create_rgbdm(rgb_np, depth_np)
134 |
135 | input_tensor = to_tensor(input_np)
136 | while input_tensor.dim() < 3:
137 | input_tensor = input_tensor.unsqueeze(0)
138 | depth_tensor = to_tensor(depth_np)
139 | depth_tensor = depth_tensor.unsqueeze(0)
140 |
141 | return input_tensor, depth_tensor
142 |
143 | def __len__(self):
144 | return len(self.imgs)
--------------------------------------------------------------------------------
/dataloaders/dense_to_sparse.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def rgb2grayscale(rgb):
6 | return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
7 |
8 |
9 | class DenseToSparse:
10 | def __init__(self):
11 | pass
12 |
13 | def dense_to_sparse(self, rgb, depth):
14 | pass
15 |
16 | def __repr__(self):
17 | pass
18 |
19 | class UniformSampling(DenseToSparse):
20 | name = "uar"
21 | def __init__(self, num_samples, max_depth=np.inf):
22 | DenseToSparse.__init__(self)
23 | self.num_samples = num_samples
24 | self.max_depth = max_depth
25 |
26 | def __repr__(self):
27 | return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth)
28 |
29 | def dense_to_sparse(self, rgb, depth):
30 | """
31 | Samples pixels with `num_samples`/#pixels probability in `depth`.
32 | Only pixels with a maximum depth of `max_depth` are considered.
33 | If no `max_depth` is given, samples in all pixels
34 | """
35 | mask_keep = depth > 0
36 | if self.max_depth is not np.inf:
37 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
38 | n_keep = np.count_nonzero(mask_keep)
39 | if n_keep == 0:
40 | return mask_keep
41 | else:
42 | prob = float(self.num_samples) / n_keep
43 | return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob)
44 |
45 |
46 |
47 | class SimulatedStereo(DenseToSparse):
48 | name = "sim_stereo"
49 |
50 | def __init__(self, num_samples, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1):
51 | DenseToSparse.__init__(self)
52 | self.num_samples = num_samples
53 | self.max_depth = max_depth
54 | self.dilate_kernel = dilate_kernel
55 | self.dilate_iterations = dilate_iterations
56 |
57 | def __repr__(self):
58 | return "%s{ns=%d,md=%f,dil=%d.%d}" % \
59 | (self.name, self.num_samples, self.max_depth, self.dilate_kernel, self.dilate_iterations)
60 |
61 |     # We do not use cv2.Canny, since that applies non-max suppression
62 |     # So we simply do:
63 |     # RGB to intensities
64 |     # Smooth with Gaussian
65 |     # Take simple Sobel gradients
66 |     # Threshold the edge gradient
67 |     # Dilate
68 | def dense_to_sparse(self, rgb, depth):
69 | gray = rgb2grayscale(rgb)
70 |
71 |
72 | blurred = cv2.GaussianBlur(gray, (5, 5), 0)
73 | gx = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=5)
74 | gy = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=5)
75 |
76 | depth_mask = np.bitwise_and(depth != 0.0, depth <= self.max_depth)
77 |
78 | edge_fraction = float(self.num_samples) / np.size(depth)
79 |
80 | mag = cv2.magnitude(gx, gy)
81 | min_mag = np.percentile(mag[depth_mask], 100 * (1.0 - edge_fraction))
82 | mag_mask = mag >= min_mag
83 |
84 |         if self.dilate_iterations >= 0:
85 |             kernel = np.ones((self.dilate_kernel, self.dilate_kernel), dtype=np.uint8)
86 |             mag_mask = cv2.dilate(mag_mask.astype(np.uint8), kernel, iterations=self.dilate_iterations).astype(bool)  # assign back so the dilation actually widens the mask
87 |
88 | mask = np.bitwise_and(mag_mask, depth_mask)
89 | return mask
90 |
91 |
92 | class ORBSampling(DenseToSparse):
93 | name = "ORB"
94 | def __init__(self,max_depth=np.inf):
95 | DenseToSparse.__init__(self)
96 | self.max_depth = max_depth
97 |
98 | def __repr__(self):
99 |         return "%s{md=%f}" % (self.name, self.max_depth)
100 |
101 | def dense_to_sparse(self, rgb, depth):
102 | """
103 |         Samples depth at ORB keypoints detected on the RGB image.
104 |         Only pixels with valid depth (and, if `max_depth` is given, depth <= `max_depth`)
105 |         are kept.
106 | """
107 | mask_keep = depth > 0
108 |
109 | orb = cv2.ORB_create()
110 | rgb_ori = (rgb.copy()*255).astype(np.uint8)
111 | kp = orb.detect(rgb_ori,None)
112 |
113 | mask_keep_orb = np.zeros(mask_keep.shape).astype(mask_keep.dtype)
114 | for marker in kp:
115 |             position = np.asarray(marker.pt).astype(int)  # keypoint (x, y) pixel coordinates; uint8 would wrap around above 255
116 | mask_keep_orb[position[1]][position[0]] = True
117 | if self.max_depth is not np.inf:
118 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
119 |
120 | mask_keep = np.bitwise_and(mask_keep, mask_keep_orb)
121 | n_keep = np.count_nonzero(mask_keep)
122 | return mask_keep
--------------------------------------------------------------------------------
/dataloaders/kitti_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.transforms as transforms
3 | from dataloaders.dataloader import MyDataloader
4 |
5 | class KITTIDataset(MyDataloader):
6 | def __init__(self, root, type, sparsifier=None, modality='rgb'):
7 | super(KITTIDataset, self).__init__(root, type, sparsifier, modality)
8 | self.output_size = (228, 912)
9 |
10 | def train_transform(self, rgb, depth):
11 | s = np.random.uniform(1.0, 1.5) # random scaling
12 | depth_np = depth / s
13 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
14 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
15 |
16 | # perform 1st step of data augmentation
17 | transform = transforms.Compose([
18 | transforms.Crop(130, 10, 240, 1200),
19 | transforms.Rotate(angle),
20 | transforms.Resize(s),
21 | transforms.CenterCrop(self.output_size),
22 | transforms.HorizontalFlip(do_flip)
23 | ])
24 | rgb_np = transform(rgb)
25 | rgb_np = self.color_jitter((rgb_np*255.0).astype(np.uint8)) # random color jittering
26 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
27 | # Scipy affine_transform produced RuntimeError when the depth map was
28 | # given as a 'numpy.ndarray'
29 | depth_np = np.asfarray(depth_np, dtype='float32')
30 | depth_np = transform(depth_np)
31 |
32 | return rgb_np, depth_np
33 |
34 | def val_transform(self, rgb, depth):
35 | depth_np = depth
36 | transform = transforms.Compose([
37 | #transforms.Crop(130, 10, 240, 1200),
38 | transforms.CenterCrop(self.output_size),
39 | ])
40 | rgb_np = transform(rgb)
41 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
42 | depth_np = np.asfarray(depth_np, dtype='float32')
43 | depth_np = transform(depth_np)
44 |
45 | return rgb_np, depth_np
46 |
47 |
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.transforms as transforms
3 | from dataloaders.dataloader import MyDataloader
4 |
5 | iheight, iwidth = 480, 640 # raw image size
6 |
7 | class NYUDataset(MyDataloader):
8 | def __init__(self, root, type, sparsifier=None, modality='rgbdm'):
9 | super(NYUDataset, self).__init__(root, type, sparsifier, modality)
10 | self.output_size = (224, 224)
11 |
12 | def train_transform(self, rgb, depth):
13 | s = np.random.uniform(1.0, 1.5) # random scaling
14 | depth_np = depth / s
15 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
16 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
17 |
18 | # perform 1st step of data augmentation
19 | transform = transforms.Compose([
20 | transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow
21 | transforms.Rotate(angle),
22 | transforms.Resize(s),
23 | transforms.CenterCrop(self.output_size),
24 | transforms.HorizontalFlip(do_flip)
25 | ])
26 | rgb_np = transform(rgb)
27 | rgb_np = self.color_jitter(rgb_np) # random color jittering
28 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
29 | depth_np = transform(depth_np)
30 |
31 | return rgb_np, depth_np
32 |
33 | def val_transform(self, rgb, depth):
34 | depth_np = depth
35 | transform = transforms.Compose([
36 | transforms.Resize(240.0 / iheight),
37 | transforms.CenterCrop(self.output_size),
38 | ])
39 | rgb_np = transform(rgb)
40 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
41 | depth_np = transform(depth_np)
42 |
43 | return rgb_np, depth_np
44 |
--------------------------------------------------------------------------------
/dataloaders/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import random
5 |
6 | from PIL import Image, ImageOps, ImageEnhance
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 |
12 | import numpy as np
13 | import numbers
14 | import types
15 | import collections
16 | import warnings
17 |
18 | import scipy.ndimage.interpolation as itpl
19 | import scipy.misc as misc
20 | import skimage.transform
21 |
22 |
23 | def _is_numpy_image(img):
24 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
25 |
26 | def _is_pil_image(img):
27 | if accimage is not None:
28 | return isinstance(img, (Image.Image, accimage.Image))
29 | else:
30 | return isinstance(img, Image.Image)
31 |
32 | def _is_tensor_image(img):
33 | return torch.is_tensor(img) and img.ndimension() == 3
34 |
35 | def adjust_brightness(img, brightness_factor):
36 | """Adjust brightness of an Image.
37 |
38 | Args:
39 | img (PIL Image): PIL Image to be adjusted.
40 | brightness_factor (float): How much to adjust the brightness. Can be
41 | any non negative number. 0 gives a black image, 1 gives the
42 | original image while 2 increases the brightness by a factor of 2.
43 |
44 | Returns:
45 | PIL Image: Brightness adjusted image.
46 | """
47 | if not _is_pil_image(img):
48 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
49 |
50 | enhancer = ImageEnhance.Brightness(img)
51 | img = enhancer.enhance(brightness_factor)
52 | return img
53 |
54 |
55 | def adjust_contrast(img, contrast_factor):
56 | """Adjust contrast of an Image.
57 |
58 | Args:
59 | img (PIL Image): PIL Image to be adjusted.
60 | contrast_factor (float): How much to adjust the contrast. Can be any
61 | non negative number. 0 gives a solid gray image, 1 gives the
62 | original image while 2 increases the contrast by a factor of 2.
63 |
64 | Returns:
65 | PIL Image: Contrast adjusted image.
66 | """
67 | if not _is_pil_image(img):
68 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
69 |
70 | enhancer = ImageEnhance.Contrast(img)
71 | img = enhancer.enhance(contrast_factor)
72 | return img
73 |
74 |
75 | def adjust_saturation(img, saturation_factor):
76 | """Adjust color saturation of an image.
77 |
78 | Args:
79 | img (PIL Image): PIL Image to be adjusted.
80 | saturation_factor (float): How much to adjust the saturation. 0 will
81 | give a black and white image, 1 will give the original image while
82 | 2 will enhance the saturation by a factor of 2.
83 |
84 | Returns:
85 | PIL Image: Saturation adjusted image.
86 | """
87 | if not _is_pil_image(img):
88 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
89 |
90 | enhancer = ImageEnhance.Color(img)
91 | img = enhancer.enhance(saturation_factor)
92 | return img
93 |
94 |
95 | def adjust_hue(img, hue_factor):
96 | """Adjust hue of an image.
97 |
98 | The image hue is adjusted by converting the image to HSV and
99 | cyclically shifting the intensities in the hue channel (H).
100 | The image is then converted back to original image mode.
101 |
102 | `hue_factor` is the amount of shift in H channel and must be in the
103 | interval `[-0.5, 0.5]`.
104 |
105 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
106 |
107 | Args:
108 | img (PIL Image): PIL Image to be adjusted.
109 | hue_factor (float): How much to shift the hue channel. Should be in
110 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
111 | HSV space in positive and negative direction respectively.
112 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
113 | with complementary colors while 0 gives the original image.
114 |
115 | Returns:
116 | PIL Image: Hue adjusted image.
117 | """
118 | if not(-0.5 <= hue_factor <= 0.5):
119 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
120 |
121 | if not _is_pil_image(img):
122 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
123 |
124 | input_mode = img.mode
125 | if input_mode in {'L', '1', 'I', 'F'}:
126 | return img
127 |
128 | h, s, v = img.convert('HSV').split()
129 |
130 | np_h = np.array(h, dtype=np.uint8)
131 |     # uint8 addition takes care of rotation across boundaries
132 | with np.errstate(over='ignore'):
133 | np_h += np.uint8(hue_factor * 255)
134 | h = Image.fromarray(np_h, 'L')
135 |
136 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
137 | return img
138 |
139 |
140 | def adjust_gamma(img, gamma, gain=1):
141 | """Perform gamma correction on an image.
142 |
143 | Also known as Power Law Transform. Intensities in RGB mode are adjusted
144 | based on the following equation:
145 |
146 | I_out = 255 * gain * ((I_in / 255) ** gamma)
147 |
148 | See https://en.wikipedia.org/wiki/Gamma_correction for more details.
149 |
150 | Args:
151 | img (PIL Image): PIL Image to be adjusted.
152 | gamma (float): Non negative real number. gamma larger than 1 make the
153 | shadows darker, while gamma smaller than 1 make dark regions
154 | lighter.
155 | gain (float): The constant multiplier.
156 | """
157 | if not _is_pil_image(img):
158 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
159 |
160 | if gamma < 0:
161 | raise ValueError('Gamma should be a non-negative real number')
162 |
163 | input_mode = img.mode
164 | img = img.convert('RGB')
165 |
166 | np_img = np.array(img, dtype=np.float32)
167 | np_img = 255 * gain * ((np_img / 255) ** gamma)
168 | np_img = np.uint8(np.clip(np_img, 0, 255))
169 |
170 | img = Image.fromarray(np_img, 'RGB').convert(input_mode)
171 | return img
172 |
173 |
174 | class Compose(object):
175 | """Composes several transforms together.
176 |
177 | Args:
178 | transforms (list of ``Transform`` objects): list of transforms to compose.
179 |
180 | Example:
181 | >>> transforms.Compose([
182 | >>> transforms.CenterCrop(10),
183 | >>> transforms.ToTensor(),
184 | >>> ])
185 | """
186 |
187 | def __init__(self, transforms):
188 | self.transforms = transforms
189 |
190 | def __call__(self, img):
191 | for t in self.transforms:
192 | img = t(img)
193 | return img
194 |
195 |
196 | class ToTensor(object):
197 | """Convert a ``numpy.ndarray`` to tensor.
198 |
199 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
200 | """
201 |
202 | def __call__(self, img):
203 | """Convert a ``numpy.ndarray`` to tensor.
204 |
205 | Args:
206 | img (numpy.ndarray): Image to be converted to tensor.
207 |
208 | Returns:
209 | Tensor: Converted image.
210 | """
211 | if not(_is_numpy_image(img)):
212 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
213 |
214 | if isinstance(img, np.ndarray):
215 | # handle numpy array
216 | if img.ndim == 3:
217 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
218 | elif img.ndim == 2:
219 | img = torch.from_numpy(img.copy())
220 | else:
221 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
222 |
223 | # backward compatibility
224 | # return img.float().div(255)
225 | return img.float()
226 |
227 |
228 | class NormalizeNumpyArray(object):
229 | """Normalize a ``numpy.ndarray`` with mean and standard deviation.
230 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
231 | will normalize each channel of the input ``numpy.ndarray`` i.e.
232 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
233 |
234 | Args:
235 | mean (sequence): Sequence of means for each channel.
236 | std (sequence): Sequence of standard deviations for each channel.
237 | """
238 |
239 | def __init__(self, mean, std):
240 | self.mean = mean
241 | self.std = std
242 |
243 | def __call__(self, img):
244 | """
245 | Args:
246 | img (numpy.ndarray): Image of size (H, W, C) to be normalized.
247 |
248 | Returns:
249 | Tensor: Normalized image.
250 | """
251 | if not(_is_numpy_image(img)):
252 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
253 | # TODO: make efficient
254 | print(img.shape)
255 | for i in range(3):
256 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i]
257 | return img
258 |
259 | class NormalizeTensor(object):
260 |     """Normalize a tensor image with mean and standard deviation.
261 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
262 | will normalize each channel of the input ``torch.*Tensor`` i.e.
263 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
264 |
265 | Args:
266 | mean (sequence): Sequence of means for each channel.
267 | std (sequence): Sequence of standard deviations for each channel.
268 | """
269 |
270 | def __init__(self, mean, std):
271 | self.mean = mean
272 | self.std = std
273 |
274 | def __call__(self, tensor):
275 | """
276 | Args:
277 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
278 |
279 | Returns:
280 | Tensor: Normalized Tensor image.
281 | """
282 | if not _is_tensor_image(tensor):
283 | raise TypeError('tensor is not a torch image.')
284 | # TODO: make efficient
285 | for t, m, s in zip(tensor, self.mean, self.std):
286 | t.sub_(m).div_(s)
287 | return tensor
288 |
289 | class Rotate(object):
290 | """Rotates the given ``numpy.ndarray``.
291 |
292 | Args:
293 | angle (float): The rotation angle in degrees.
294 | """
295 |
296 | def __init__(self, angle):
297 | self.angle = angle
298 |
299 | def __call__(self, img):
300 | """
301 | Args:
302 | img (numpy.ndarray (C x H x W)): Image to be rotated.
303 |
304 | Returns:
305 | img (numpy.ndarray (C x H x W)): Rotated image.
306 | """
307 |
308 | # order=0 means nearest-neighbor type interpolation
309 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0)
310 |
311 |
312 | class Resize(object):
313 |     """Resize the given ``numpy.ndarray`` to the given size.
314 | Args:
315 | size (sequence or int): Desired output size. If size is a sequence like
316 | (h, w), output size will be matched to this. If size is an int,
317 | smaller edge of the image will be matched to this number.
318 | i.e, if height > width, then image will be rescaled to
319 | (size * height / width, size)
320 |         interpolation (str, optional): Desired interpolation. Default is
321 |             ``'nearest'`` (resizing uses nearest-neighbor interpolation).
322 | """
323 |
324 | def __init__(self, size, interpolation='nearest'):
325 | assert isinstance(size, int) or isinstance(size, float) or \
326 | (isinstance(size, collections.Iterable) and len(size) == 2)
327 | self.size = size
328 | self.interpolation = interpolation
329 |
330 | def __call__(self, img):
331 | """
332 | Args:
333 |             img (numpy.ndarray): Image to be scaled.
334 |         Returns:
335 |             numpy.ndarray: Rescaled image.
336 | """
337 | if img.ndim == 3:
338 | return skimage.transform.rescale(img, self.size, order=0)
339 | elif img.ndim == 2:
340 | return skimage.transform.rescale(img, self.size, order=0)
341 | else:
342 |             raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
343 |
344 |
345 | class CenterCrop(object):
346 | """Crops the given ``numpy.ndarray`` at the center.
347 |
348 | Args:
349 | size (sequence or int): Desired output size of the crop. If size is an
350 | int instead of sequence like (h, w), a square crop (size, size) is
351 | made.
352 | """
353 |
354 | def __init__(self, size):
355 | if isinstance(size, numbers.Number):
356 | self.size = (int(size), int(size))
357 | else:
358 | self.size = size
359 |
360 | @staticmethod
361 | def get_params(img, output_size):
362 | """Get parameters for ``crop`` for center crop.
363 |
364 | Args:
365 | img (numpy.ndarray (C x H x W)): Image to be cropped.
366 | output_size (tuple): Expected output size of the crop.
367 |
368 | Returns:
369 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
370 | """
371 | h = img.shape[0]
372 | w = img.shape[1]
373 | th, tw = output_size
374 | i = int(round((h - th) / 2.))
375 | j = int(round((w - tw) / 2.))
376 |
377 | # # randomized cropping
378 | # i = np.random.randint(i-3, i+4)
379 | # j = np.random.randint(j-3, j+4)
380 |
381 | return i, j, th, tw
382 |
383 | def __call__(self, img):
384 | """
385 | Args:
386 | img (numpy.ndarray (C x H x W)): Image to be cropped.
387 |
388 | Returns:
389 | img (numpy.ndarray (C x H x W)): Cropped image.
390 | """
391 | i, j, h, w = self.get_params(img, self.size)
392 |
393 | """
394 | i: Upper pixel coordinate.
395 | j: Left pixel coordinate.
396 | h: Height of the cropped image.
397 | w: Width of the cropped image.
398 | """
399 | if not(_is_numpy_image(img)):
400 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
401 | if img.ndim == 3:
402 | return img[i:i+h, j:j+w, :]
403 | elif img.ndim == 2:
404 | return img[i:i + h, j:j + w]
405 | else:
406 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
407 |
408 |
409 | class Lambda(object):
410 | """Apply a user-defined lambda as a transform.
411 |
412 | Args:
413 | lambd (function): Lambda/function to be used for transform.
414 | """
415 |
416 | def __init__(self, lambd):
417 | assert isinstance(lambd, types.LambdaType)
418 | self.lambd = lambd
419 |
420 | def __call__(self, img):
421 | return self.lambd(img)
422 |
423 |
424 | class HorizontalFlip(object):
425 | """Horizontally flip the given ``numpy.ndarray``.
426 |
427 | Args:
428 | do_flip (boolean): whether or not do horizontal flip.
429 |
430 | """
431 |
432 | def __init__(self, do_flip):
433 | self.do_flip = do_flip
434 |
435 | def __call__(self, img):
436 | """
437 | Args:
438 | img (numpy.ndarray (C x H x W)): Image to be flipped.
439 |
440 | Returns:
441 | img (numpy.ndarray (C x H x W)): flipped image.
442 | """
443 | if not(_is_numpy_image(img)):
444 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
445 |
446 | if self.do_flip:
447 | return np.fliplr(img)
448 | else:
449 | return img
450 |
451 |
452 | class ColorJitter(object):
453 | """Randomly change the brightness, contrast and saturation of an image.
454 |
455 | Args:
456 | brightness (float): How much to jitter brightness. brightness_factor
457 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
458 | contrast (float): How much to jitter contrast. contrast_factor
459 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
460 | saturation (float): How much to jitter saturation. saturation_factor
461 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
462 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
463 | [-hue, hue]. Should be >=0 and <= 0.5.
464 | """
465 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
466 | self.brightness = brightness
467 | self.contrast = contrast
468 | self.saturation = saturation
469 | self.hue = hue
470 |
471 | @staticmethod
472 | def get_params(brightness, contrast, saturation, hue):
473 | """Get a randomized transform to be applied on image.
474 |
475 | Arguments are same as that of __init__.
476 |
477 | Returns:
478 | Transform which randomly adjusts brightness, contrast and
479 | saturation in a random order.
480 | """
481 | transforms = []
482 | if brightness > 0:
483 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
484 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))
485 |
486 | if contrast > 0:
487 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
488 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))
489 |
490 | if saturation > 0:
491 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
492 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))
493 |
494 | if hue > 0:
495 | hue_factor = np.random.uniform(-hue, hue)
496 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))
497 |
498 | np.random.shuffle(transforms)
499 | transform = Compose(transforms)
500 |
501 | return transform
502 |
503 | def __call__(self, img):
504 | """
505 | Args:
506 | img (numpy.ndarray (C x H x W)): Input image.
507 |
508 | Returns:
509 | img (numpy.ndarray (C x H x W)): Color jittered image.
510 | """
511 | if not(_is_numpy_image(img)):
512 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
513 | pil = Image.fromarray(img)
514 | transform = self.get_params(self.brightness, self.contrast,
515 | self.saturation, self.hue)
516 | return np.array(transform(pil))
517 |
518 | class Crop(object):
519 |     """Crops the given ``numpy.ndarray`` to a rectangular region based on a given
520 |     4-tuple defining the upper and left pixel coordinates, height and width.
521 | 
522 |     Args:
523 |         a tuple: (upper pixel coordinate, left pixel coordinate, height, width)-tuple
524 | """
525 |
526 | def __init__(self, i, j, h, w):
527 | """
528 | i: Upper pixel coordinate.
529 | j: Left pixel coordinate.
530 | h: Height of the cropped image.
531 | w: Width of the cropped image.
532 | """
533 | self.i = i
534 | self.j = j
535 | self.h = h
536 | self.w = w
537 |
538 | def __call__(self, img):
539 | """
540 | Args:
541 | img (numpy.ndarray (C x H x W)): Image to be cropped.
542 | Returns:
543 | img (numpy.ndarray (C x H x W)): Cropped image.
544 | """
545 |
546 | i, j, h, w = self.i, self.j, self.h, self.w
547 |
548 | if not(_is_numpy_image(img)):
549 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
550 | if img.ndim == 3:
551 | return img[i:i + h, j:j + w, :]
552 | elif img.ndim == 2:
553 | return img[i:i + h, j:j + w]
554 | else:
555 | raise RuntimeError(
556 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
557 |
558 | def __repr__(self):
559 | return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format(
560 | self.i, self.j, self.h, self.w)
561 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import time
3 | from options.options import AdvanceOptions
4 | from models import create_model
5 | from util.visualizer import Visualizer
6 | from dataloaders.nyu_dataloader import NYUDataset
7 | from dataloaders.kitti_dataloader import KITTIDataset
8 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
9 | import numpy as np
10 | import random
11 | import torch
12 | import cv2
13 | import utils
14 | import os
15 |
16 | # def colored_depthmap(depth, d_min=None, d_max=None):
17 | # if d_min is None:
18 | # d_min = np.min(depth)
19 | # if d_max is None:
20 | # d_max = np.max(depth)
21 | # depth_relative = (depth - d_min) / (d_max - d_min)
22 | # return 255 * plt.cm.viridis(depth_relative)[:,:,:3] # H, W, C
23 |
24 | # def merge_into_row_with_pred_visualize(input, depth_input, rgb_sparse,depth_target, depth_est):
25 | # rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
26 | # rgb_sparse = 255 * np.transpose(np.squeeze(rgb_sparse.cpu().numpy()), (1,2,0))
27 | # depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
28 | # depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
29 | # depth_pred_cpu = np.squeeze(depth_est.cpu().numpy())
30 |
31 | # d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
32 | # d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.min(depth_pred_cpu))
33 | # depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
34 | # depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
35 | # depth_pred_col = colored_depthmap(depth_target_cpu, d_min, d_max)
36 |
37 | # img_merge = np.hstack([rgb, rgb_sparse,depth_input_col, depth_target_col,depth_pred_col])
38 |
39 | # return img_merge
40 |
41 | if __name__ == '__main__':
42 | test_opt = AdvanceOptions().parse(False)
43 |
44 | sparsifier = UniformSampling(test_opt.nP, max_depth=np.inf)
45 | #sparsifier = SimulatedStereo(100, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1)
46 | test_dataset = KITTIDataset(test_opt.test_path, type='val',
47 | modality='rgbdm', sparsifier=sparsifier)
48 |
49 | ### Please use this dataloader if you want to use NYU
50 | # test_dataset = NYUDataset(test_opt.test_path, type='val',
51 | # modality='rgbdm', sparsifier=sparsifier)
52 |
53 |
54 | test_opt.phase = 'val'
55 | test_opt.batch_size = 1
56 | test_opt.num_threads = 1
57 | test_opt.serial_batches = True
58 | test_opt.no_flip = True
59 |
60 | test_data_loader = torch.utils.data.DataLoader(test_dataset,
61 | batch_size=test_opt.batch_size, shuffle=False, num_workers=test_opt.num_threads, pin_memory=True)
62 |
63 | test_dataset_size = len(test_data_loader)
64 | print('#test images = %d' % test_dataset_size)
65 |
66 | model = create_model(test_opt, test_dataset)
67 | model.eval()
68 | model.setup(test_opt)
69 | visualizer = Visualizer(test_opt)
70 | test_loss_iter = []
71 | gts = None
72 | preds = None
73 | epoch_iter = 0
74 | model.init_test_eval()
75 | epoch = 0
76 | num = 5 # How many images to save in an image
77 | if not os.path.exists('vis'):
78 | os.makedirs('vis')
79 | with torch.no_grad():
80 | iterator = iter(test_data_loader)
81 | i = 0
82 | while True:
83 | try: # Some images couldn't sample more than defined nP points under Stereo sampling
84 | next_batch = next(iterator)
85 | except IndexError:
86 | print("Catch and Skip!")
87 | continue
88 | except StopIteration:
89 | break
90 |
91 | data, target = next_batch[0], next_batch[1]
92 | model.set_new_input(data,target)
93 | model.forward()
94 | model.test_depth_evaluation()
95 | model.get_loss()
96 | epoch_iter += test_opt.batch_size
97 | losses = model.get_current_losses()
98 | test_loss_iter.append(model.loss_dcca.item())
99 |
100 | rgb_input = model.rgb_image
101 | depth_input = model.sparse_depth
102 | rgb_sparse = model.sparse_rgb
103 | depth_target = model.depth_image
104 | depth_est = model.depth_est
105 |
106 | ### This part saves images in the vis/ folder
107 | if i%num == 0:
108 | img_merge = utils.merge_into_row_with_pred_visualize(rgb_input, depth_input, rgb_sparse,depth_target, depth_est)
109 | elif i%num < num-1:
110 | row = utils.merge_into_row_with_pred_visualize(rgb_input, depth_input, rgb_sparse,depth_target, depth_est)
111 | img_merge = utils.add_row(img_merge, row)
112 | elif i%num == num-1:
113 | filename = 'vis/'+str(i)+'.png'
114 | utils.save_image(img_merge, filename)
115 |
116 | i += 1
117 |
118 | print('test epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(test_dataset) * test_opt.batch_size), end='\r')
119 | print(
120 | 'RMSE={result.rmse:.4f}({average.rmse:.4f}) '
121 | 'MSE={result.mse:.4f}({average.mse:.4f}) '
122 | 'MAE={result.mae:.4f}({average.mae:.4f}) '
123 | 'Delta1={result.delta1:.4f}({average.delta1:.4f}) '
124 | 'Delta2={result.delta2:.4f}({average.delta2:.4f}) '
125 | 'Delta3={result.delta3:.4f}({average.delta3:.4f}) '
126 | 'REL={result.absrel:.4f}({average.absrel:.4f}) '
127 | 'Lg10={result.lg10:.4f}({average.lg10:.4f}) '.format(
128 | result=model.test_result, average=model.test_average.average()))
129 | avg_test_loss = np.mean(np.asarray(test_loss_iter))
130 |
--------------------------------------------------------------------------------
/images/500.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/images/500.gif
--------------------------------------------------------------------------------
/models/DCCA_sparse_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import DCCA_sparse_networks
4 | import numpy as np
5 | import os
6 | import math
7 |
8 | class DCCASparseModel(BaseModel):
9 | def name(self):
10 | return 'DCCASparseNetModel'
11 |
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 |
15 | # changing the default values
16 | if is_train:
17 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
18 | return parser
19 |
20 | def initialize(self, opt, dataset):
21 | BaseModel.initialize(self, opt)
22 |
23 | self.x_dataview = None
24 | self.y_dataview = None
25 | self.depth_est = None
26 | self.loss_dcca = 0
27 | self.loss_l1 = 0
28 | self.loss_mse = None
29 | self.loss_smooth = None
30 | self.result = None
31 | self.test_result = None
32 | self.average = None
33 | self.test_average = None
34 |
35 | self.isTrain = opt.isTrain
36 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
37 | self.loss_names = ['mse','dcca','total','transform','smooth']
38 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
39 | self.visual_names = ['rgb_image','depth_image','mask','output']
40 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
41 | self.model_names = ['DCCASparseNet']
42 |
43 | # load/define networks
44 | self.netDCCASparseNet = DCCA_sparse_networks.define_DCCASparseNet(rgb_enc=True, depth_enc=True, depth_dec=True, norm=opt.norm, init_type=opt.init_type, init_gain= opt.init_gain, gpu_ids= self.gpu_ids)
45 | # define loss functions
46 | self.criterionDCCA = DCCA_sparse_networks.DCCA_2D_Loss(outdim_size = 60,use_all_singular_values = True, device=self.device).to(self.device)
47 | self.MSE = DCCA_sparse_networks.MaskedMSELoss()
48 | self.SMOOTH = DCCA_sparse_networks.SmoothLoss()
49 | self.TransformLoss = DCCA_sparse_networks.TransformLoss()
50 |
51 | if self.isTrain:
52 | # initialize optimizers
53 | self.optimizers = []
54 | self.optimizer_DCCASparseNet = torch.optim.SGD(self.netDCCASparseNet.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
55 | self.optimizers.append(self.optimizer_DCCASparseNet)
56 |
57 | def set_input(self, input):
58 | self.rgb_image = input['rgb_image'].to(self.device)
59 | self.depth_image = input['depth_image'].to(self.device)
60 | self.mask = input['mask'].to(self.device)
61 | self.image_paths = input['path']
62 |
63 | def set_new_input(self, input,target):
64 | self.rgb_image = input[:,:3,:,:].to(self.device)
65 | self.sparse_rgb = input[:,4:7,:,:].to(self.device)
66 | self.depth_image = target.to(self.device)
67 | self.sparse_depth = input[:,3,:,:].to(self.device).unsqueeze(1)
68 | self.mask = input[:,7,:,:].to(self.device).unsqueeze(1)
69 |
70 | def forward(self):
71 | self.x_dataview,self.y_dataview,self.x_trans,self.depth_est= self.netDCCASparseNet(self.sparse_rgb,self.sparse_depth,self.mask,self.rgb_image,self.depth_image)
72 |
73 | def get_loss(self):
74 | self.loss_dcca = self.criterionDCCA(self.x_dataview,self.y_dataview)
75 | self.loss_mse = self.MSE(self.depth_est,self.depth_image)
76 | self.loss_smooth = self.SMOOTH(self.depth_est)
77 | self.loss_transform = self.TransformLoss(self.x_trans, self.x_dataview)
78 | self.loss_total = self.loss_mse + self.loss_dcca + self.loss_transform + 0.1*self.loss_smooth
79 |
80 | def backward(self):
81 | self.loss_total.backward()
82 |
83 | def pure_backward(self):
84 | self.loss_dcca.backward()
85 |
86 | def init_test_eval(self):
87 | self.test_result = Result()
88 | self.test_average = AverageMeter()
89 |
90 | def init_eval(self):
91 | self.result = Result()
92 | self.average = AverageMeter()
93 |
94 | def depth_evaluation(self):
95 | self.result.evaluate(self.depth_est.data, self.depth_image.data)
96 | self.average.update(self.result, self.sparse_rgb.size(0))
97 |
98 | def test_depth_evaluation(self):
99 | self.test_result.evaluate(self.depth_est.data, self.depth_image.data)
100 | self.test_average.update(self.test_result, self.sparse_rgb.size(0))
101 | print()
102 |
103 | def print_test_depth_evaluation(self):
104 | message = 'RMSE={result.rmse:.4f}({average.rmse:.4f}) \
105 | MAE={result.mae:.4f}({average.mae:.4f}) \
106 | Delta1={result.delta1:.4f}({average.delta1:.4f}) \
107 | REL={result.absrel:.4f}({average.absrel:.4f}) \
108 | Lg10={result.lg10:.4f}({average.lg10:.4f})'.format(result=self.test_result, average=self.test_average.average())
109 | print(message)
110 | return message
111 |
112 | def print_depth_evaluation(self):
113 | message = 'RMSE={result.rmse:.4f}({average.rmse:.4f}) \
114 | MAE={result.mae:.4f}({average.mae:.4f}) \
115 | Delta1={result.delta1:.4f}({average.delta1:.4f}) \
116 | REL={result.absrel:.4f}({average.absrel:.4f}) \
117 | Lg10={result.lg10:.4f}({average.lg10:.4f})'.format(result=self.result, average=self.average.average())
118 | print(message)
119 | return message
120 |
121 | def optimize_parameters(self):
122 | self.forward()
123 | self.depth_evaluation()
124 | self.set_requires_grad(self.netDCCASparseNet, True)
125 | self.get_loss()
126 | self.optimizer_DCCASparseNet.zero_grad()
127 | # update DCCAnet
128 | self.backward()
129 | self.optimizer_DCCASparseNet.step()
130 |
131 |
132 | ####### Metrics ########
133 | def log10(x):
134 |     """Return a new tensor with the base-10 logarithm of the elements of x."""
135 | return torch.log(x) / math.log(10)
136 |
137 | class Result(object):
138 | def __init__(self):
139 | self.irmse, self.imae = 0, 0
140 | self.mse, self.rmse, self.mae = 0, 0, 0
141 | self.absrel, self.lg10 = 0, 0
142 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
143 | self.data_time, self.gpu_time = 0, 0
144 |
145 | def set_to_worst(self):
146 | self.irmse, self.imae = np.inf, np.inf
147 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
148 | self.absrel, self.lg10 = np.inf, np.inf
149 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
150 | self.data_time, self.gpu_time = 0, 0
151 |
152 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
153 | self.irmse, self.imae = irmse, imae
154 | self.mse, self.rmse, self.mae = mse, rmse, mae
155 | self.absrel, self.lg10 = absrel, lg10
156 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
157 | self.data_time, self.gpu_time = data_time, gpu_time
158 |
159 | def evaluate(self, output, target):
160 | valid_mask = target>0
161 | output = output[valid_mask]
162 | target = target[valid_mask]
163 |
164 | new_output = output[target<=50]
165 | new_target = target[target<=50]
166 | target = new_target
167 | output = new_output
168 |
169 | abs_diff = (output - target).abs()
170 |
171 | self.mse = float((torch.pow(abs_diff, 2)).mean())
172 | self.rmse = math.sqrt(self.mse)
173 | self.mae = float(abs_diff.mean())
174 | self.lg10 = float((log10(output) - log10(target)).abs().mean())
175 | self.absrel = float((abs_diff / target).mean())
176 |
177 | maxRatio = torch.max(output / target, target / output)
178 | self.delta1 = float((maxRatio < 1.25).float().mean())
179 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
180 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
181 | self.data_time = 0
182 | self.gpu_time = 0
183 |
184 | inv_output = 1 / output
185 | inv_target = 1 / target
186 | abs_inv_diff = (inv_output - inv_target).abs()
187 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
188 | self.imae = float(abs_inv_diff.mean())
189 |
190 |
191 | class AverageMeter(object):
192 | def __init__(self):
193 | self.reset()
194 |
195 | def reset(self):
196 | self.count = 0.0
197 | self.sum_irmse, self.sum_imae = 0, 0
198 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0
199 | self.sum_absrel, self.sum_lg10 = 0, 0
200 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0
201 | self.sum_data_time, self.sum_gpu_time = 0, 0
202 |
203 | def update(self, result, n=1):
204 | self.count += n
205 | self.sum_irmse += n*result.irmse
206 | self.sum_imae += n*result.imae
207 | self.sum_mse += n*result.mse
208 | self.sum_rmse += n*result.rmse
209 | self.sum_mae += n*result.mae
210 | self.sum_absrel += n*result.absrel
211 | self.sum_lg10 += n*result.lg10
212 | self.sum_delta1 += n*result.delta1
213 | self.sum_delta2 += n*result.delta2
214 | self.sum_delta3 += n*result.delta3
215 |
216 | def average(self):
217 | avg = Result()
218 | avg.update(
219 | self.sum_irmse / self.count, self.sum_imae / self.count,
220 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count,
221 | self.sum_absrel / self.count, self.sum_lg10 / self.count,
222 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count,
223 | self.sum_gpu_time / self.count, self.sum_data_time / self.count)
224 | return avg
225 |
--------------------------------------------------------------------------------
/models/DCCA_sparse_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torchvision
5 | import functools
6 | from torch.optim import lr_scheduler
7 | import torch.nn.functional as F
8 | from copy import deepcopy
9 | import numpy as np
10 | import cv2
11 | import collections
12 | import matplotlib.pyplot as plt
13 |
14 | def get_norm_layer(norm_type='instance'):
15 | if norm_type == 'batch':
16 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
17 | elif norm_type == 'instance':
18 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
19 | elif norm_type == 'none':
20 | norm_layer = None
21 | else:
22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 | return norm_layer
24 |
25 |
26 | def get_scheduler(optimizer, opt):
27 | if opt.lr_policy == 'lambda':
28 | lambda_rule = lambda epoch: opt.lr_gamma ** ((epoch+1) // opt.lr_decay_epochs)
29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30 | elif opt.lr_policy == 'step':
31 | scheduler = lr_scheduler.StepLR(optimizer,step_size=opt.lr_decay_iters, gamma=0.1)
32 | else:
33 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
34 | return scheduler
35 |
36 |
37 | def init_weights(net, init_type='normal', gain=0.02):
38 | net = net
39 | def init_func(m):
40 | classname = m.__class__.__name__
41 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
42 | if init_type == 'normal':
43 | init.normal_(m.weight.data, 0.0, gain)
44 | elif init_type == 'xavier':
45 | init.xavier_normal_(m.weight.data, gain=gain)
46 | elif init_type == 'kaiming':
47 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
48 | elif init_type == 'orthogonal':
49 | init.orthogonal_(m.weight.data, gain=gain)
50 | elif init_type == 'pretrained':
51 | pass
52 | else:
53 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
54 | if hasattr(m, 'bias') and m.bias is not None and init_type != 'pretrained':
55 | init.constant_(m.bias.data, 0.0)
56 | elif classname.find('BatchNorm2d') != -1:
57 | init.normal_(m.weight.data, 1.0, gain)
58 | init.constant_(m.bias.data, 0.0)
59 | print('initialize network with %s' % init_type)
60 | net.apply(init_func)
61 |
62 |
63 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
64 | if len(gpu_ids) > 0:
65 | assert(torch.cuda.is_available())
66 | net.to(gpu_ids[0])
67 | net = torch.nn.DataParallel(net, gpu_ids)
68 |
69 | for root_child in net.children():
70 | for children in root_child.children():
71 | if children in root_child.need_initialization:
72 | init_weights(children, init_type, gain=init_gain)
73 | return net
74 |
75 | def define_DCCASparseNet(rgb_enc=True, depth_enc=True, depth_dec=True, norm='batch', init_type='xavier', init_gain=0.02, gpu_ids=[]):
76 | net = None
77 | norm_layer = get_norm_layer(norm_type=norm)
78 | net = DCCASparsenetGenerator(rgb_enc=rgb_enc, depth_enc=depth_enc, depth_dec=depth_dec)
79 | return init_net(net, init_type, init_gain, gpu_ids)
80 |
81 | ##############################################################################
82 | # Classes
83 | ##############################################################################
84 | class SAConv(nn.Module):
85 | # Convolution layer for sparse data
86 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, bias=True):
87 | super(SAConv, self).__init__()
88 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
89 | self.if_bias = bias
90 | if self.if_bias:
91 | self.bias = nn.Parameter(torch.zeros(out_channels).float(), requires_grad=True)
92 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation)
93 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
94 |         self.pool.requires_grad = False
95 |
96 | def forward(self, input):
97 | x, m = input
98 | x = x * m
99 | x = self.conv(x)
100 | weights = torch.ones(torch.Size([1, 1, 3, 3])).cuda()
101 | mc = F.conv2d(m, weights, bias=None, stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation)
102 | mc = torch.clamp(mc, min=1e-5)
103 | mc = 1. / mc * 9
104 |
105 | if self.if_bias:
106 | x = x + self.bias.view(1, self.bias.size(0), 1, 1).expand_as(x)
107 | m = self.pool(m)
108 |
109 | return x, m
110 |
111 | class SAConvBlock(nn.Module):
112 |
113 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=1, dilation=1, bias=True):
114 | super(SAConvBlock, self).__init__()
115 | self.sparse_conv = SAConv(in_channel, out_channel, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True)
116 | self.relu = nn.ReLU(inplace=True)
117 |
118 | def forward(self, input):
119 | x, m = input
120 | x, m = self.sparse_conv((x, m))
121 | assert (m.size(1)==1)
122 | x = self.relu(x)
123 |
124 | return x, m
125 |
126 | class Decoder(nn.Module):
127 | # Decoder is the base class for all decoders
128 |
129 | def __init__(self):
130 | super(Decoder, self).__init__()
131 |
132 | self.layer1 = None
133 | self.layer2 = None
134 | self.layer3 = None
135 | self.layer4 = None
136 |
137 | def forward(self, x):
138 | x = self.layer1(x)
139 | x = self.layer2(x)
140 | x = self.layer3(x)
141 | x = self.layer4(x)
142 | return x
143 |
144 | class DeConv(Decoder):
145 | def __init__(self, in_channels, kernel_size):
146 | assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size)
147 | super(DeConv, self).__init__()
148 |
149 | def convt(in_channels):
150 | stride = 2
151 | padding = (kernel_size - 1) // 2
152 | output_padding = kernel_size % 2
153 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
154 |
155 | module_name = "deconv{}".format(kernel_size)
156 | return nn.Sequential(collections.OrderedDict([
157 | (module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size,
158 | stride,padding,output_padding,bias=False)),
159 | ('batchnorm', nn.BatchNorm2d(in_channels//2)),
160 | ('relu', nn.ReLU(inplace=True)),
161 | ]))
162 | self.layer1 = convt(in_channels)
163 | self.layer2 = convt(in_channels // 2)
164 | self.layer3 = convt(in_channels // (2 ** 2))
165 | self.layer4 = convt(in_channels // (2 ** 3))
166 |
167 | def make_layers_from_size(sizes):
168 | layers = []
169 | for size in sizes:
170 | layers += [nn.Conv2d(size[0], size[1], kernel_size=3, padding=1), nn.BatchNorm2d(size[1],momentum = 0.1), nn.ReLU(inplace=True)]
171 | return nn.Sequential(*layers)
172 |
173 | def make_blocks_from_names(names,in_dim,out_dim):
174 | layers = []
175 | if names[0] == "block1" or names[0] == "block2":
176 | layers += [SAConvBlock(in_dim, out_dim, 3,stride = 1)]
177 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)]
178 | else:
179 | layers += [SAConvBlock(in_dim, out_dim, 3,stride = 1)]
180 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)]
181 | layers += [SAConvBlock(out_dim, out_dim, 3,stride = 1)]
182 | return nn.Sequential(*layers)
183 |
184 | class DCCASparsenetGenerator(nn.Module):
185 | def __init__(self, rgb_enc=True, depth_enc=True, depth_dec=True):
186 | super(DCCASparsenetGenerator, self).__init__()
187 | #batchNorm_momentum = 0.1
188 | self.need_initialization = []
189 |
190 | if rgb_enc :
191 | ##### RGB ENCODER ####
192 | self.CBR1_RGB_ENC = make_blocks_from_names(["block1"], 3,64)
193 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
194 |
195 | self.CBR2_RGB_ENC = make_blocks_from_names(["block2"], 64, 128)
196 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
197 |
198 | self.CBR3_RGB_ENC = make_blocks_from_names(["block3"], 128, 256)
199 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
200 | self.dropout3 = nn.Dropout(p=0.4)
201 |
202 | self.CBR4_RGB_ENC = make_blocks_from_names(["block4"], 256, 512)
203 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
204 | self.dropout4 = nn.Dropout(p=0.4)
205 |
206 | self.CBR5_RGB_ENC = make_blocks_from_names(["block5"], 512, 512)
207 | self.dropout5 = nn.Dropout(p=0.4)
208 |
209 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
210 |
211 | if depth_enc :
212 |
213 | self.CBR1_DEPTH_ENC = make_blocks_from_names(["block1"], 1, 64)
214 | self.pool1_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
215 |
216 | self.CBR2_DEPTH_ENC = make_blocks_from_names(["block2"], 64, 128)
217 | self.pool2_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
218 |
219 | self.CBR3_DEPTH_ENC = make_blocks_from_names(["block3"], 128, 256)
220 | self.pool3_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
221 |
222 | self.CBR4_DEPTH_ENC = make_blocks_from_names(["block4"], 256, 512)
223 | self.pool4_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
224 |
225 | self.CBR5_DEPTH_ENC = make_blocks_from_names(["block5"], 512, 512)
226 |
227 | self.pool5_d = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
228 |
229 | if depth_dec :
230 | #### DECODER ####
231 | self.Transform = make_blocks_from_names(["block1"],512, 512)
232 | self.decoder = DeConv(1024, 3)
233 | self.conv3 = nn.Conv2d(64,1,kernel_size=3,stride=1,padding=1,bias=False)
234 | ## This size is for KITTI, use (224,224) for NYU
235 | self.bilinear = nn.Upsample((228,912), mode='bilinear', align_corners=True)
236 |
237 | self.need_initialization.append(self.decoder)
238 | self.need_initialization.append(self.conv3)
239 |
240 | def forward(self, sparse_rgb,sparse_d,mask,rgb,d):
241 |
242 | ######## DEPTH ENCODER ########
243 | x_1,m_d = self.CBR1_DEPTH_ENC((sparse_d,mask))
244 | x, id1_d = self.pool1_d(x_1)
245 | m_d,_ = self.pool1_d(m_d )
246 |
247 | x_2,m_d = self.CBR2_DEPTH_ENC((x,m_d ))
248 | x, id2_d = self.pool2_d(x_2)
249 | m_d,_ = self.pool2_d(m_d )
250 |
251 | x_3,m_d = self.CBR3_DEPTH_ENC((x,m_d ))
252 |         x, id3_d = self.pool3_d(x_3)
253 | m_d,_ = self.pool3_d(m_d )
254 |
255 | x_4,m_d = self.CBR4_DEPTH_ENC((x,m_d ))
256 | x, id4_d = self.pool4_d(x_4)
257 | m_d,_ = self.pool4_d(m_d )
258 |
259 | x_5,m_d = self.CBR5_DEPTH_ENC((x,m_d ))
260 | x_dataview, id5_d = self.pool5_d(x_5)
261 | m_d,_ = self.pool5_d(m_d )
262 |
263 | ######## RGB ENCODER ########
264 | y_1,m_r = self.CBR1_RGB_ENC((sparse_rgb,mask))
265 | y, id1 = self.pool1(y_1)
266 | m_r,_ = self.pool1(m_r)
267 |
268 | y_2,m_r = self.CBR2_RGB_ENC((y,m_r))
269 | y, id2 = self.pool2(y_2)
270 | m_r,_ = self.pool2(m_r)
271 |
272 | y_3,m_r = self.CBR3_RGB_ENC((y,m_r))
273 | y, id3 = self.pool3(y_3)
274 | m_r,_ = self.pool3(m_r)
275 |
276 | y_4,m_r = self.CBR4_RGB_ENC((y,m_r))
277 | y, id4 = self.pool4(y_4)
278 | m_r,_ = self.pool4(m_r)
279 |
280 | y_5,m_r = self.CBR5_RGB_ENC((y,m_r))
281 | y_dataview, id5 = self.pool5(y_5)
282 | m_r,_ = self.pool5(m_r)
283 |
284 | ######## MISSING DATA ENCODER ########
285 | inverse_mask = torch.ones_like(mask)-mask
286 | inverse_rgb = rgb*inverse_mask
287 |
288 | ym_1,m_m = self.CBR1_RGB_ENC((inverse_rgb,inverse_mask))
289 | ym, id1_m = self.pool1(ym_1)
290 | m_m,_ = self.pool1(m_m)
291 |
292 | ym_2,m_m = self.CBR2_RGB_ENC((ym,m_m ))
293 | ym, id2_m = self.pool2(ym_2)
294 | m_m,_ = self.pool2(m_m)
295 |
296 | ym_3,m_m = self.CBR3_RGB_ENC((ym,m_m ))
297 |         ym, id3_m = self.pool3(ym_3)
298 | m_m,_ = self.pool3(m_m)
299 |
300 | ym_4,m_m = self.CBR4_RGB_ENC((ym,m_m ))
301 | ym, id4_m = self.pool4(ym_4)
302 | m_m,_ = self.pool4(m_m)
303 |
304 | ym_5,m_m = self.CBR5_RGB_ENC((ym,m_m ))
305 | ym_dataview, id5_m = self.pool5(ym_5)
306 | m_m,_ = self.pool5(m_m)
307 |
308 | ######## Transformer ########
309 | x_trans, m_trans = self.Transform((y_dataview,m_r))
310 | xm_trans, mm_trans = self.Transform((ym_dataview,m_r))
311 |
312 | ######## DECODER ########
313 | x = self.decoder(torch.cat((x_dataview,xm_trans),1))
314 | x = self.conv3(x)
315 | depth_est = self.bilinear(x)
316 |
317 | return x_dataview, y_dataview, x_trans, depth_est
318 |
319 | class MaskedMSELoss(nn.Module):
320 | def __init__(self):
321 | super(MaskedMSELoss, self).__init__()
322 |
323 | def forward(self, pred, target):
324 | assert pred.dim() == target.dim(), "inconsistent dimensions"
325 | valid_mask = (target>0).detach()
326 | diff = target - pred
327 | diff = diff[valid_mask]
328 | self.loss = (diff ** 2).mean()
329 | return self.loss
330 |
331 | class TransformLoss(nn.Module):
332 | def __init__(self):
333 | super(TransformLoss, self).__init__()
334 |
335 | def forward(self, f_in, f_target):
336 | assert f_in.dim() == f_target.dim(), "inconsistent dimensions"
337 | diff = f_in - f_target
338 | self.loss = (diff ** 2).mean()
339 | return self.loss
340 |
341 | class SmoothLoss(nn.Module):
342 | def __init__(self):
343 | super(SmoothLoss, self).__init__()
344 |
345 | def forward(self, pred_map):
346 | def gradient(pred):
347 | D_dy = pred[:, :, 1:] - pred[:, :, :-1]
348 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
349 | return D_dx, D_dy
350 |
351 | if type(pred_map) not in [tuple, list]:
352 | pred_map = [pred_map]
353 |
354 | loss = 0
355 | weight = 1.
356 |
357 | for scaled_map in pred_map:
358 | dx, dy = gradient(scaled_map)
359 | dx2, dxdy = gradient(dx)
360 | dydx, dy2 = gradient(dy)
361 | loss += (dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean())*weight
362 | weight /= 2.3 # don't ask me why it works better
363 | return loss
364 |
365 | class DCCA_2D_Loss(nn.Module):
366 | def __init__(self,outdim_size, use_all_singular_values, device):
367 | super(DCCA_2D_Loss, self).__init__()
368 | self.outdim_size = outdim_size
369 | self.use_all_singular_values = use_all_singular_values
370 | self.device = device
371 |
372 | def __call__(self, data_view1, data_view2):
373 | H1 = data_view1.view(data_view1.size(0)*data_view1.size(1),data_view1.size(2),data_view1.size(3))
374 | H2 = data_view2.view(data_view2.size(0)*data_view2.size(1),data_view2.size(2),data_view2.size(3))
375 |
376 | r1 = 1e-4
377 | r2 = 1e-4
378 | eps = 1e-12
379 | corr_sum = 0
380 | o1 = o2 = H1.size(1)
381 |
382 | m = H1.size(0)
383 | n = H1.size(1)
384 |
385 | H1bar = H1 - (1.0 / m) * H1
386 | H2bar = H2 - (1.0 / m) * H2
387 | Hat12 = torch.zeros(m,n,n).cuda()
388 | Hat11 = torch.zeros(m,n,n).cuda()
389 | Hat22 = torch.zeros(m,n,n).cuda()
390 |
391 | for i in range(m):
392 | Hat11[i] = torch.matmul(H1bar[i],H1bar.transpose(1,2)[i])
393 | Hat12[i] = torch.matmul(H1bar[i],H2bar.transpose(1,2)[i])
394 | Hat22[i] = torch.matmul(H2bar[i],H2bar.transpose(1,2)[i])
395 |
396 | SigmaHat12 = (1.0 / (m - 1)) * torch.mean(Hat12,dim=0)
397 | SigmaHat11 = (1.0 / (m - 1)) * torch.mean(Hat11,dim=0)+ r1 * torch.eye(o1, device=self.device)
398 | SigmaHat22 = (1.0 / (m - 1)) * torch.mean(Hat22,dim=0) + r2 * torch.eye(o2, device=self.device)
399 |
400 | # Calculating the root inverse of covariance matrices by using eigen decomposition
401 | [D1, V1] = torch.symeig(SigmaHat11, eigenvectors=True)
402 | [D2, V2] = torch.symeig(SigmaHat22, eigenvectors=True)
403 |
404 | # Added to increase stability
405 | posInd1 = torch.gt(D1, eps).nonzero()[:, 0]
406 | D1 = D1[posInd1]
407 | V1 = V1[:, posInd1]
408 | posInd2 = torch.gt(D2, eps).nonzero()[:, 0]
409 | D2 = D2[posInd2]
410 | V2 = V2[:, posInd2]
411 | SigmaHat11RootInv = torch.matmul(
412 | torch.matmul(V1, torch.diag(D1 ** -0.5)), V1.t())
413 | SigmaHat22RootInv = torch.matmul(
414 | torch.matmul(V2, torch.diag(D2 ** -0.5)), V2.t())
415 |
416 | Tval = torch.matmul(torch.matmul(SigmaHat11RootInv,
417 | SigmaHat12), SigmaHat22RootInv)
418 |
419 | if self.use_all_singular_values:
420 | # all singular values are used to calculate the correlation
421 | corr = torch.sqrt(torch.trace(torch.matmul(Tval.t(), Tval)))
422 | else:
423 | # just the top self.outdim_size singular values are used
424 | U, V = torch.symeig(torch.matmul(Tval.t(), Tval), eigenvectors=True)
425 | U = U[torch.gt(U, eps).nonzero()[:, 0]]
426 | U = U.topk(self.outdim_size)[0]
427 | corr = torch.sum(torch.sqrt(U))
428 | return -corr
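A minimal sketch of the simpler loss modules above on random CPU tensors (DCCA_2D_Loss and SAConv are omitted because they call .cuda() internally). Assumes it is run from the repository root; the module also imports cv2 and matplotlib:

import torch
from models.DCCA_sparse_networks import MaskedMSELoss, TransformLoss, SmoothLoss

pred = torch.rand(2, 1, 32, 32)
target = torch.rand(2, 1, 32, 32)
target[target < 0.3] = 0                       # simulate missing depth measurements

mse = MaskedMSELoss()(pred, target)            # only pixels with target > 0 contribute
feat = TransformLoss()(torch.rand(2, 512, 7, 7), torch.rand(2, 512, 7, 7))
smooth = SmoothLoss()(pred)                    # second-order smoothness penalty
print(mse.item(), feat.item(), smooth.item())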
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 |
4 |
5 | def find_model_using_name(model_name):
6 | # Given the option --model [modelname],
7 | # the file "models/modelname_model.py"
8 | # will be imported.
9 | model_filename = "models." + model_name + "_model"
10 | modellib = importlib.import_module(model_filename)
11 |
12 | # In the file, the class called ModelNameModel() will
13 | # be instantiated. It has to be a subclass of BaseModel,
14 | # and it is case-insensitive.
15 | model = None
16 | target_model_name = model_name.replace('_', '') + 'model'
17 | for name, cls in modellib.__dict__.items():
18 | if name.lower() == target_model_name.lower() \
19 | and issubclass(cls, BaseModel):
20 | model = cls
21 |
22 | if model is None:
23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
24 | exit(0)
25 |
26 | return model
27 |
28 |
29 | def get_option_setter(model_name):
30 | model_class = find_model_using_name(model_name)
31 | return model_class.modify_commandline_options
32 |
33 |
34 | def create_model(opt, dataset):
35 | model = find_model_using_name(opt.model)
36 | instance = model()
37 | instance.initialize(opt, dataset)
38 | print("model [%s] was created" % (instance.name()))
39 | return instance
40 |
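For reference, the lookup above is purely name-based; this small standalone snippet illustrates how --model DCCA_sparse is mapped to a module path and a (case-insensitive) class name:

# Name resolution mirroring find_model_using_name (illustrative).
model_name = "DCCA_sparse"
model_filename = "models." + model_name + "_model"           # -> "models.DCCA_sparse_model"
target_model_name = model_name.replace('_', '') + 'model'    # -> "DCCAsparsemodel"
print(model_filename, target_model_name.lower())             # the class whose lowercased name matches is used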
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from torch.optim import lr_scheduler
5 |
6 | class BaseModel():
7 | @staticmethod
8 | def modify_commandline_options(parser, is_train):
9 | return parser
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, opt):
15 | self.opt = opt
16 | self.gpu_ids = opt.gpu_ids
17 | self.isTrain = opt.isTrain
18 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
20 | if opt.resize_or_crop != 'scale_width':
21 | torch.backends.cudnn.benchmark = True
22 | self.loss_names = []
23 | self.model_names = []
24 | self.visual_names = []
25 | self.image_paths = []
26 |
27 | def set_input(self, input):
28 | self.input = input
29 |
30 | def forward(self):
31 | pass
32 |
33 | def get_scheduler(self, optimizer, opt):
34 | if opt.lr_policy == 'lambda':
35 | lambda_rule = lambda epoch: opt.lr_gamma ** ((epoch+1) // opt.lr_decay_epochs)
36 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
37 | elif opt.lr_policy == 'step':
38 | scheduler = lr_scheduler.StepLR(optimizer,step_size=opt.lr_decay_iters, gamma=0.1)
39 | else:
40 |             raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
41 | return scheduler
42 |
43 | # load and print networks; create schedulers
44 | def setup(self, opt, parser=None):
45 | if self.isTrain:
46 | self.schedulers = [self.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
47 |
48 | if not self.isTrain or opt.continue_train:
49 | self.load_networks(opt.epoch)
50 | self.print_networks(opt.verbose)
51 |
52 | # make models eval mode during test time
53 | def eval(self):
54 | for name in self.model_names:
55 | if isinstance(name, str):
56 | net = getattr(self, 'net' + name)
57 | net.eval()
58 | def train(self):
59 | for name in self.model_names:
60 | if isinstance(name, str):
61 | net = getattr(self, 'net' + name)
62 | net.train()
63 |
64 | def test(self):
65 | with torch.no_grad():
66 | self.forward()
67 |
68 | # get image paths
69 | def get_image_paths(self):
70 | return self.image_paths
71 |
72 | def optimize_parameters(self):
73 | pass
74 |
75 | # update learning rate (called once every epoch)
76 | def update_learning_rate(self):
77 | for scheduler in self.schedulers:
78 | scheduler.step()
79 | lr = self.optimizers[0].param_groups[0]['lr']
80 | print('learning rate = %.7f' % lr)
81 |
82 |     # return visualization images. The training script can display these and save them to an HTML page
83 | def get_current_visuals(self):
84 | visual_ret = OrderedDict()
85 | for name in self.visual_names:
86 | if isinstance(name, str):
87 | visual_ret[name] = getattr(self, name)
88 | return visual_ret
89 |
90 |     # return training losses/errors. The training script prints these as debugging information
91 | def get_current_losses(self):
92 | errors_ret = OrderedDict()
93 | for name in self.loss_names:
94 | if isinstance(name, str):
95 | # float(...) works for both scalar tensor and float number
96 | errors_ret[name] = float(getattr(self, 'loss_' + name))
97 | return errors_ret
98 |
99 | # save models to the disk
100 | def save_networks(self, epoch):
101 | for name in self.model_names:
102 | if isinstance(name, str):
103 | save_filename = '%s_net_%s.pth' % (epoch, name)
104 | save_path = os.path.join(self.save_dir, save_filename)
105 | net = getattr(self, 'net' + name)
106 |
107 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
108 | torch.save(net.module.cpu().state_dict(), save_path)
109 | net.cuda(self.gpu_ids[0])
110 | else:
111 | torch.save(net.cpu().state_dict(), save_path)
112 |
113 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
114 | key = keys[i]
115 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
116 | if module.__class__.__name__.startswith('InstanceNorm') and \
117 | (key == 'running_mean' or key == 'running_var'):
118 | if getattr(module, key) is None:
119 | state_dict.pop('.'.join(keys))
120 | if module.__class__.__name__.startswith('InstanceNorm') and \
121 | (key == 'num_batches_tracked'):
122 | state_dict.pop('.'.join(keys))
123 | else:
124 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
125 |
126 | # load models from the disk
127 | def load_networks(self, epoch):
128 | for name in self.model_names:
129 | if isinstance(name, str):
130 | load_filename = '%s_net_%s.pth' % (epoch, name)
131 | load_path = os.path.join(self.save_dir, load_filename)
132 | net = getattr(self, 'net' + name)
133 | if isinstance(net, torch.nn.DataParallel):
134 | net = net.module
135 | print('loading the model from %s' % load_path)
136 | # if you are using PyTorch newer than 0.4 (e.g., built from
137 | # GitHub source), you can remove str() on self.device
138 | state_dict = torch.load(load_path, map_location=str(self.device))
139 | if hasattr(state_dict, '_metadata'):
140 | del state_dict._metadata
141 |
142 | # patch InstanceNorm checkpoints prior to 0.4
143 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
144 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
145 | net.load_state_dict(state_dict)
146 |
147 | # print network information
148 | def print_networks(self, verbose):
149 | print('---------- Networks initialized -------------')
150 | for name in self.model_names:
151 | if isinstance(name, str):
152 | net = getattr(self, 'net' + name)
153 | num_params = 0
154 | for param in net.parameters():
155 | num_params += param.numel()
156 | if verbose:
157 | print(net)
158 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
159 | print('-----------------------------------------------')
160 |
161 |     # set requires_grad=False to avoid unnecessary gradient computation
162 | def set_requires_grad(self, nets, requires_grad=False):
163 | if not isinstance(nets, list):
164 | nets = [nets]
165 | for net in nets:
166 | if net is not None:
167 | for param in net.parameters():
168 | param.requires_grad = requires_grad
169 |
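A bare-bones subclass sketch of the hooks BaseModel expects: loss_names/model_names drive logging and checkpointing, and self.optimizers drives the schedulers created in setup(). This is only an illustration of the pattern (ToyModel is hypothetical, and opt is assumed to come from AdvanceOptions); the repository's actual model lives in models/DCCA_sparse_model.py and additionally implements the depth-evaluation helpers called from train_depth_complete.py:

import torch
from models.base_model import BaseModel

class ToyModel(BaseModel):                     # hypothetical example, not part of the repo
    def name(self):
        return 'ToyModel'

    def initialize(self, opt, dataset):        # create_model() passes (opt, dataset)
        BaseModel.initialize(self, opt)
        self.loss_names = ['total']            # get_current_losses() reads self.loss_total
        self.model_names = ['G']               # save/load/print use self.netG
        self.netG = torch.nn.Conv2d(1, 1, 3, padding=1).to(self.device)
        self.optimizers = [torch.optim.SGD(self.netG.parameters(), lr=opt.lr)]

    def set_new_input(self, data, target):
        self.input = data.to(self.device)
        self.target = target.to(self.device)

    def forward(self):
        self.pred = self.netG(self.input)

    def optimize_parameters(self):
        self.forward()
        self.loss_total = torch.nn.functional.mse_loss(self.pred, self.target)
        self.optimizers[0].zero_grad()
        self.loss_total.backward()
        self.optimizers[0].step()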
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | #import data
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
15 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
16 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
17 | parser.add_argument('--model', type=str, default='DCCA_sparse',
18 | help='chooses which model to use. cycle_gan, pix2pix, test')
19 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
20 | parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
21 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
22 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
23 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
24 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
25 | parser.add_argument('--resize_or_crop', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')
26 | parser.add_argument('--no_flip', action='store_true', default=True, help='if specified, do not flip the images for data augmentation')
27 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
28 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
29 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
30 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
31 | parser.add_argument('--seed', type=int, default=0, help='seed for random generators')
32 | self.initialized = True
33 | return parser
34 |
35 | def gather_options(self, flag):
36 | # initialize parser with basic options
37 | if not self.initialized:
38 | parser = argparse.ArgumentParser(
39 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
40 | parser = self.initialize(parser, flag)
41 |
42 | # get the basic options
43 | opt, _ = parser.parse_known_args()
44 |
45 | # modify model-related parser options
46 | model_name = opt.model
47 | model_option_setter = models.get_option_setter(model_name)
48 | parser = model_option_setter(parser, self.isTrain)
49 | opt, _ = parser.parse_known_args() # parse again with the new defaults
50 |
51 | # modify dataset-related parser options
52 | # dataset_name = opt.dataset_mode
53 | # print(dataset_name)
54 | # dataset_option_setter = data.get_option_setter(dataset_name)
55 | # parser = dataset_option_setter(parser, self.isTrain)
56 |
57 | self.parser = parser
58 |
59 | return parser.parse_args()
60 |
61 | def print_options(self, opt):
62 | message = ''
63 | message += '----------------- Options ---------------\n'
64 | for k, v in sorted(vars(opt).items()):
65 | comment = ''
66 | default = self.parser.get_default(k)
67 | if v != default:
68 | comment = '\t[default: %s]' % str(default)
69 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
70 | message += '----------------- End -------------------'
71 | print(message)
72 |
73 | # save to the disk
74 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
75 | util.mkdirs(expr_dir)
76 | file_name = os.path.join(expr_dir, 'opt.txt')
77 | with open(file_name, 'wt') as opt_file:
78 | opt_file.write(message)
79 | opt_file.write('\n')
80 |
81 | def parse(self, flag):
82 |
83 | opt = self.gather_options(flag)
84 | opt.isTrain = self.isTrain # train or test
85 |
86 | # process opt.suffix
87 | if opt.suffix:
88 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
89 | opt.name = opt.name + suffix
90 |
91 | self.print_options(opt)
92 |
93 | # set gpu ids
94 | str_ids = opt.gpu_ids.split(',')
95 | opt.gpu_ids = []
96 | for str_id in str_ids:
97 | id = int(str_id)
98 | if id >= 0:
99 | opt.gpu_ids.append(id)
100 | if len(opt.gpu_ids) > 0:
101 | torch.cuda.set_device(opt.gpu_ids[0])
102 |
103 | self.opt = opt
104 | return self.opt
105 |
--------------------------------------------------------------------------------
/options/options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class AdvanceOptions(BaseOptions):
5 | def initialize(self, parser, flag):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--print_freq', type=int, default=1, help='frequency of showing training results on console')
8 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
9 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
10 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; models are saved at <epoch_count>, <epoch_count>+<save_epoch_freq>, ...')
11 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
12 | parser.add_argument('--niter', type=int, default=400, help='# of iter at starting learning rate')
13 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for optimizer')
14 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for SGD')
15 |         parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay (L2 penalty) for the optimizer')
16 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
17 | parser.add_argument('--lr_decay_iters', type=int, default=5000000, help='multiply by a gamma every lr_decay_iters iterations')
18 | parser.add_argument('--lr_decay_epochs', type=int, default=100, help='multiply by a gamma every lr_decay_epoch epochs')
19 | parser.add_argument('--lr_gamma', type=float, default=0.9, help='gamma factor for lr_scheduler')
20 | parser.add_argument('--nP', type=int, default=500, help='number of points')
21 | parser.add_argument('--train_path', help='path to the training dataset')
22 | parser.add_argument('--test_path', help='path to the testing dataset')
23 | self.isTrain = flag
24 | return parser
25 |
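A hypothetical way to exercise these options programmatically (the run name and data paths are placeholders; parse() also writes opt.txt under checkpoints_dir as a side effect):

import sys
from options.options import AdvanceOptions

# Simulated command line; --gpu_ids -1 selects CPU, paths are placeholders.
sys.argv = ['train_depth_complete.py', '--name', 'demo_run', '--batch_size', '4',
            '--gpu_ids', '-1', '--nP', '500',
            '--train_path', '/path/to/train', '--test_path', '/path/to/val']
opt = AdvanceOptions().parse(True)     # True -> training options (isTrain)
print(opt.lr, opt.lr_policy, opt.nP)   # 0.001 lambda 500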
--------------------------------------------------------------------------------
/train_depth_complete.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import time
3 | from options.options import AdvanceOptions
4 | from models import create_model
5 | from util.visualizer import Visualizer
6 | from dataloaders.nyu_dataloader import NYUDataset
7 | from dataloaders.kitti_dataloader import KITTIDataset
8 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
9 | import numpy as np
10 | import random
11 | import torch
12 | import cv2
13 |
14 | if __name__ == '__main__':
15 | train_opt = AdvanceOptions().parse(True)
16 |
17 |     # The SimulatedStereo class is also provided, to subsample points in a stereo-like pattern instead of uniformly
18 | sparsifier = UniformSampling(train_opt.nP, max_depth=np.inf)
19 |
20 | train_dataset = KITTIDataset(train_opt.train_path, type='train',
21 | modality='rgbdm', sparsifier=sparsifier)
22 | test_dataset = KITTIDataset(train_opt.test_path, type='val',
23 | modality='rgbdm', sparsifier=sparsifier)
24 |     ## Use this dataloader instead if you want to train on NYU
25 | # train_dataset = NYUDataset(train_opt.train_path, type='train',
26 | # modality='rgbdm', sparsifier=sparsifier)
27 |     ## Use this dataloader instead if you want to evaluate on NYU
28 | # test_dataset = NYUDataset(train_opt.test_path, type='val',
29 | # modality='rgbdm', sparsifier=sparsifier)
30 |
31 | train_data_loader = torch.utils.data.DataLoader(
32 | train_dataset, batch_size=train_opt.batch_size, shuffle=True,
33 | num_workers=train_opt.num_threads, pin_memory=True, sampler=None,
34 | worker_init_fn=lambda work_id:np.random.seed(train_opt.seed + work_id))
35 | test_opt = AdvanceOptions().parse(True)
36 | test_opt.phase = 'val'
37 | test_opt.batch_size = 1
38 | test_opt.num_threads = 1
39 | test_opt.serial_batches = True
40 | test_opt.no_flip = True
41 |
42 | test_data_loader = torch.utils.data.DataLoader(test_dataset,
43 | batch_size=test_opt.batch_size, shuffle=False, num_workers=test_opt.num_threads, pin_memory=True)
44 |
45 | train_dataset_size = len(train_data_loader)
46 | print('#training images = %d' % train_dataset_size)
47 | test_dataset_size = len(test_data_loader)
48 | print('#test images = %d' % test_dataset_size)
49 |
50 | model = create_model(train_opt, train_dataset)
51 | model.setup(train_opt)
52 | visualizer = Visualizer(train_opt)
53 | total_steps = 0
54 | for epoch in range(train_opt.epoch_count, train_opt.niter + 1):
55 | model.train()
56 | epoch_start_time = time.time()
57 | iter_data_time = time.time()
58 | epoch_iter = 0
59 | model.init_eval()
60 | iterator = iter(train_data_loader)
61 | while True:
62 |             try: # Some images cannot provide the requested nP points under stereo sampling
63 | next_batch = next(iterator)
64 | except IndexError:
65 | print("Catch and Skip!")
66 | continue
67 | except StopIteration:
68 | break
69 | data, target = next_batch[0], next_batch[1]
70 |
71 | iter_start_time = time.time()
72 | if total_steps % train_opt.print_freq == 0:
73 | t_data = iter_start_time - iter_data_time
74 | total_steps += train_opt.batch_size
75 | epoch_iter += train_opt.batch_size
76 | model.set_new_input(data,target)
77 | model.optimize_parameters()
78 |
79 | if total_steps % train_opt.print_freq == 0:
80 | losses = model.get_current_losses()
81 | t = (time.time() - iter_start_time) / train_opt.batch_size
82 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
83 | message = model.print_depth_evaluation()
84 | visualizer.print_current_depth_evaluation(message)
85 | print()
86 |
87 | iter_data_time = time.time()
88 |
89 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.niter, time.time() - epoch_start_time))
90 | model.update_learning_rate()
91 | if epoch and epoch % train_opt.save_epoch_freq == 0:
92 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
93 | model.save_networks('latest')
94 | model.save_networks(epoch)
95 |
96 | model.eval()
97 | test_loss_iter = []
98 | gts = None
99 | preds = None
100 | epoch_iter = 0
101 | model.init_test_eval()
102 | with torch.no_grad():
103 | iterator = iter(test_data_loader)
104 | while True:
105 |                 try: # Some images cannot provide the requested nP points under stereo sampling
106 | next_batch = next(iterator)
107 | except IndexError:
108 | print("Catch and Skip!")
109 | continue
110 | except StopIteration:
111 | break
112 |
113 | data, target = next_batch[0], next_batch[1]
114 |
115 | model.set_new_input(data,target)
116 | model.forward()
117 | model.test_depth_evaluation()
118 | model.get_loss()
119 | epoch_iter += test_opt.batch_size
120 | losses = model.get_current_losses()
121 | test_loss_iter.append(model.loss_dcca.item())
122 | print('test epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(test_dataset) * test_opt.batch_size), end='\r')
123 | message = model.print_test_depth_evaluation()
124 | visualizer.print_current_depth_evaluation(message)
125 | print(
126 | 'RMSE={result.rmse:.4f}({average.rmse:.4f}) '
127 | 'MAE={result.mae:.4f}({average.mae:.4f}) '
128 | 'Delta1={result.delta1:.4f}({average.delta1:.4f}) '
129 | 'REL={result.absrel:.4f}({average.absrel:.4f}) '
130 | 'Lg10={result.lg10:.4f}({average.lg10:.4f}) '.format(
131 | result=model.test_result, average=model.test_average.average()))
132 | avg_test_loss = np.mean(np.asarray(test_loss_iter))
133 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/choyingw/CFCNet/828e0c09c646a4669685b3d31b8aa0ae2a5cd351/util/__init__.py
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import cv2
7 |
8 | # Converts a Tensor into an image array (numpy)
9 | # |imtype|: the desired type of the converted numpy array
10 | def tensor2im(input_image, imtype=np.uint8):
11 | if isinstance(input_image, torch.Tensor):
12 | image_tensor = input_image.data
13 | else:
14 | return input_image
15 | image_numpy = image_tensor[0].cpu().float().numpy()
16 | if image_numpy.shape[0] == 1:
17 | image_numpy = np.tile(image_numpy, (3, 1, 1))
18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)))* 255.0
19 | return image_numpy.astype(imtype)
20 |
21 |
22 | def tensor2labelim(label_tensor, impalette, imtype=np.uint8):
23 | if len(label_tensor.shape) == 4:
24 | _, label_tensor = torch.max(label_tensor.data.cpu(), 1)
25 |
26 | label_numpy = label_tensor[0].cpu().float().detach().numpy()
27 | label_image = Image.fromarray(label_numpy.astype(np.uint8))
28 | label_image = label_image.convert("P")
29 | label_image.putpalette(impalette)
30 | label_image = label_image.convert("RGB")
31 | return np.array(label_image).astype(imtype)
32 |
33 | def diagnose_network(net, name='network'):
34 | mean = 0.0
35 | count = 0
36 | for param in net.parameters():
37 | if param.grad is not None:
38 | mean += torch.mean(torch.abs(param.grad.data))
39 | count += 1
40 | if count > 0:
41 | mean = mean / count
42 | print(name)
43 | print(mean)
44 |
45 |
46 | def save_image(image_numpy, image_path):
47 | image_pil = Image.fromarray(image_numpy)
48 | image_pil.save(image_path)
49 |
50 | def save_image_cv2(image_numpy, image_path):
51 | #image_pil = Image.fromarray(image_numpy)
52 | cv2.imwrite(image_path,image_numpy)
53 | #image_pil.save(image_path)
54 |
55 |
56 | def print_numpy(x, val=True, shp=False):
57 | x = x.astype(np.float64)
58 | if shp:
59 | print('shape,', x.shape)
60 | if val:
61 | x = x.flatten()
62 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
63 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
64 |
65 |
66 | def mkdirs(paths):
67 | if isinstance(paths, list) and not isinstance(paths, str):
68 | for path in paths:
69 | mkdir(path)
70 | else:
71 | mkdir(paths)
72 |
73 |
74 | def mkdir(path):
75 | if not os.path.exists(path):
76 | os.makedirs(path)
77 |
78 | def confusion_matrix(x , y, n, ignore_label=None, mask=None):
79 | if mask is None:
80 | mask = np.ones_like(x) == 1
81 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(np.bool))
82 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n**2).reshape(n, n)
83 |
84 | def getScores(conf_matrix):
85 | if conf_matrix.sum() == 0:
86 | return 0, 0, 0
87 | with np.errstate(divide='ignore',invalid='ignore'):
88 | overall = np.diag(conf_matrix).sum() / np.float(conf_matrix.sum())
89 | perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float)
90 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) - np.diag(conf_matrix)).astype(np.float)
91 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100.
92 |
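A quick sketch of tensor2im on a dummy tensor (values are expected in [0, 1], since the function only rescales by 255); run from the repository root:

import torch
from util.util import tensor2im

fake = torch.rand(1, 3, 4, 4)          # N x C x H x W in [0, 1]
img = tensor2im(fake)                  # -> H x W x 3 uint8 array
print(img.shape, img.dtype)            # (4, 4, 3) uint8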
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class Visualizer():
4 | def __init__(self, opt):
5 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
6 |
7 |     # losses: an OrderedDict of {loss name: float value}, as returned by the model's get_current_losses()
8 | def print_current_losses(self, epoch, i, losses, t, t_data):
9 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
10 | for k, v in losses.items():
11 | message += '%s: %.3f ' % (k, v)
12 | print(message)
13 | with open(self.log_name, "a") as log_file:
14 | log_file.write('%s\n' % message)
15 |
16 | def print_current_depth_evaluation(self, message):
17 | with open(self.log_name, "a") as log_file:
18 | log_file.write('%s\n' % message)
19 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from PIL import Image
7 | import cv2
8 |
9 | def colored_depthmap(depth, d_min=None, d_max=None):
10 | if d_min is None:
11 | d_min = np.min(depth)
12 | if d_max is None:
13 | d_max = np.max(depth)
14 | depth_relative = (depth - d_min) / (d_max - d_min)
15 | return 255 * plt.cm.jet(depth_relative)[:,:,:3] # H, W, C
16 |
17 | cmap = plt.cm.jet
18 | def depth_colorize_16(depth):
19 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
20 | depth = 255* 256 * cmap(depth)[:,:,:3] # H, W, C
21 | return depth.astype('uint16')
22 |
23 | def depth_colorize_8(depth):
24 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
25 | depth = 255* cmap(depth)[:,:,:3] # H, W, C
26 | return depth.astype('uint8')
27 |
28 | def Enlarge_pixel(sparse_depth):
29 | for i in range(2,sparse_depth.shape[0]-2):
30 | for j in range(2,sparse_depth.shape[1]-2):
31 | if np.sum(sparse_depth[i][j]) > 0:
32 | for w in range(-2,2):
33 | for h in range(-2,2):
34 | sparse_depth[i+w][j+h] = sparse_depth[i][j]
35 |
36 | return sparse_depth
37 |
38 | def merge_into_row_with_pred_visualize(input, depth_input, rgb_sparse, depth_target, depth_est):
39 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0))[:,:,(2,1,0)] # H, W, C
40 | rgb_sparse = 255 * np.transpose(np.squeeze(rgb_sparse.cpu().numpy()), (1,2,0))[:,:,(2,1,0)]
41 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
42 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
43 | depth_pred_cpu = np.squeeze(depth_est.cpu().numpy())
44 |
45 | d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
46 |     d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
47 | depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
48 | # depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
49 | # depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
50 | depth_input_col = (depth_colorize_8(depth_input_cpu))
51 | depth_target_col = (depth_colorize_8(depth_target_cpu))
52 | depth_pred_col = depth_colorize_8(depth_pred_cpu)
53 |
54 | img_merge = np.hstack([rgb, depth_pred_col])
55 | #img_merge = np.hstack([rgb,depth_input_col])
56 | #depth_merge = np.hstack([depth_pred_col,depth_target_col])
57 | #img_merge = np.vstack([img_merge,depth_merge])
58 |
59 | return img_merge
60 |
61 | def add_row(img_merge, row):
62 | return np.vstack([img_merge, row])
63 |
64 | def save_image(img_merge, filename):
65 | img_merge = Image.fromarray(img_merge.astype('uint8'))
66 | img_merge.save(filename)
67 |
68 | def save_image_cv2(image_numpy, image_path):
69 | #image_pil = Image.fromarray(image_numpy)
70 | cv2.imwrite(image_path,image_numpy)
71 | #image_pil.save(image_path)
72 |
73 |
74 |
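An illustrative use of the colorization helpers above on synthetic depth (the output filename is arbitrary); run from the repository root:

import numpy as np
from utils import colored_depthmap, depth_colorize_8, save_image

depth = np.random.uniform(1.0, 10.0, size=(32, 64)).astype(np.float32)
col_float = colored_depthmap(depth)      # float H x W x 3 in [0, 255]
col_uint8 = depth_colorize_8(depth)      # uint8 H x W x 3
save_image(col_uint8, 'depth_vis.png')   # saved via PIL
print(col_float.shape, col_uint8.dtype)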
--------------------------------------------------------------------------------