├── datasets └── dataset.txt ├── weights └── checkpoint.txt ├── assets ├── Banner.png └── framework.png ├── phasecam-psfs-27.npy ├── scripts ├── __pycache__ │ ├── data.cpython-311.pyc │ ├── unet.cpython-311.pyc │ └── config.cpython-311.pyc ├── data.py ├── unet.py ├── evaluate.py ├── trainer.py ├── config.py └── coded_generator.py ├── LICENSE ├── README.md └── environment.yml /datasets/dataset.txt: -------------------------------------------------------------------------------- 1 | place datasets in this folder 2 | -------------------------------------------------------------------------------- /weights/checkpoint.txt: -------------------------------------------------------------------------------- 1 | place weight file in this folder 2 | -------------------------------------------------------------------------------- /assets/Banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/assets/Banner.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/assets/framework.png -------------------------------------------------------------------------------- /phasecam-psfs-27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/phasecam-psfs-27.npy -------------------------------------------------------------------------------- /scripts/__pycache__/data.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/scripts/__pycache__/data.cpython-311.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/unet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/scripts/__pycache__/unet.cpython-311.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/config.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naitri/CodedVO/HEAD/scripts/__pycache__/config.cpython-311.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Naitri_Rajyaguru 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as transforms 6 | import natsort 7 | 8 | # Set the environment variable for OpenEXR support in OpenCV 9 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 10 | 11 | # Check if CUDA is available and set the device accordingly 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | class ImageDepthDataset(Dataset): 15 | def __init__(self, base: str, path: str, codedDir: str = "Coded", cache: bool = True, is_blender: bool = False, image_size=(480, 640), scale_factor: float = 5000): 16 | """ 17 | Initialize the dataset. 18 | 19 | Args: 20 | base (str): Path to the base directory. 21 | path (str): Path to the dataset. 22 | codedDir (str): Directory containing coded images. 23 | cache (bool): Whether to cache the dataset in memory. 24 | is_blender (bool): Whether the dataset uses Blender EXR files. 25 | image_size (tuple): Size of the images. 26 | scale_factor (float): Scale factor for depth normalization. 27 | """ 28 | self.path = path 29 | self.is_blender = is_blender 30 | self.transform = transforms.Compose([transforms.CenterCrop(image_size)]) 31 | self.data = [] 32 | self.scale_factor = scale_factor 33 | 34 | # Directory path for the coded images 35 | dir_path = os.path.join(base, path, codedDir) 36 | 37 | # Get list of PNG files in the coded directory 38 | files = natsort.natsorted([p for p in os.listdir(dir_path) if p.endswith(".png")]) 39 | 40 | # Process each file 41 | for file in files: 42 | coded_file = os.path.join(base, path, codedDir, file) 43 | depth_file = os.path.join(base, path, "depth", file.replace(".png", ".exr") if is_blender else file) 44 | # Cache the processed data or store file paths 45 | if cache: 46 | self.data.append(self.process(coded_file, depth_file)) 47 | else: 48 | self.data.append((coded_file, depth_file)) 49 | 50 | self.cache = cache 51 | self.len = len(self.data) 52 | 53 | def process(self, coded_file: str, depth_file: str): 54 | """ 55 | Process a single pair of coded and depth images. 56 | 57 | Args: 58 | coded_file (str): Path to the coded image file. 59 | depth_file (str): Path to the depth image file. 60 | 61 | Returns: 62 | dict: Processed images in a dictionary.
63 | """ 64 | # Read the coded image and convert to a tensor 65 | coded = torch.from_numpy(cv2.imread(coded_file)).permute(2, 0, 1) 66 | 67 | # Read the depth image and convert to a tensor 68 | if self.is_blender: 69 | raw_depth = cv2.imread(depth_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 70 | metric_depth = torch.from_numpy(raw_depth[:, :, 0]) 71 | else: 72 | metric_depth = torch.from_numpy(cv2.imread(depth_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / self.scale_factor) 73 | 74 | return { 75 | "Coded": self.transform(coded.to(torch.float32)) / 255.0, 76 | "Depth": self.transform(metric_depth.to(torch.float32)), 77 | } 78 | 79 | def __len__(self): 80 | return self.len 81 | 82 | def __getitem__(self, idx): 83 | if self.cache: 84 | return self.data[idx] 85 | else: 86 | return self.process(*self.data[idx]) 87 | 88 | def __repr__(self): 89 | dataset_type = "Blender" if self.is_blender else "ICL" 90 | return f"{dataset_type}Dataset(path={self.path}, n={self.len}, scale_factor={self.scale_factor})" 91 | -------------------------------------------------------------------------------- /scripts/unet.py: -------------------------------------------------------------------------------- 1 | # mostly from https://github.com/LeeJunHyun/Image_Segmentation 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | 6 | 7 | def init_weights(net, init_type="normal", gain=0.02): 8 | def init_func(m): 9 | classname = m.__class__.__name__ 10 | if hasattr(m, "weight") and ( 11 | classname.find("Conv") != -1 or classname.find("Linear") != -1 12 | ): 13 | if init_type == "normal": 14 | init.normal_(m.weight.data, 0.0, gain) 15 | elif init_type == "xavier": 16 | init.xavier_normal_(m.weight.data, gain=gain) 17 | elif init_type == "kaiming": 18 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 19 | elif init_type == "orthogonal": 20 | init.orthogonal_(m.weight.data, gain=gain) 21 | else: 22 | raise NotImplementedError( 23 | "initialization method [%s] is not implemented" % init_type 24 | ) 25 | if hasattr(m, "bias") and m.bias is not None: 26 | init.constant_(m.bias.data, 0.0) 27 | elif classname.find("BatchNorm2d") != -1: 28 | init.normal_(m.weight.data, 1.0, gain) 29 | init.constant_(m.bias.data, 0.0) 30 | 31 | print("initialize network with %s" % init_type) 32 | net.apply(init_func) 33 | 34 | 35 | def count_parameters(model): 36 | pp = 0 37 | for p in list(model.parameters()): 38 | nn = 1 39 | for s in list(p.size()): 40 | nn = nn * s 41 | pp += nn 42 | return pp 43 | 44 | 45 | class conv_block(nn.Module): 46 | def __init__(self, ch_in, ch_out): 47 | super(conv_block, self).__init__() 48 | self.conv = nn.Sequential( 49 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 50 | nn.BatchNorm2d(ch_out), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 53 | nn.BatchNorm2d(ch_out), 54 | nn.ReLU(inplace=True), 55 | ) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | return x 60 | 61 | 62 | class up_conv(nn.Module): 63 | def __init__(self, ch_in, ch_out): 64 | super(up_conv, self).__init__() 65 | self.up = nn.Sequential( 66 | nn.Upsample(scale_factor=2), 67 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 68 | nn.BatchNorm2d(ch_out), 69 | nn.ReLU(inplace=True), 70 | ) 71 | 72 | def forward(self, x): 73 | x = self.up(x) 74 | return x 75 | 76 | 77 | class U_Net(nn.Module): 78 | def __init__(self, img_ch=3, output_ch=1): 79 | super(U_Net, self).__init__() 80 | 81 | 
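        # Encoder: five conv_block stages (64 -> 1024 channels) separated by 2x2 max-pooling.
        # Decoder: up_conv stages whose outputs are concatenated with the matching encoder
        # features (skip connections) before a final 1x1 convolution maps 64 channels to output_ch.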
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 82 | 83 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=64) 84 | self.Conv2 = conv_block(ch_in=64, ch_out=128) 85 | self.Conv3 = conv_block(ch_in=128, ch_out=256) 86 | self.Conv4 = conv_block(ch_in=256, ch_out=512) 87 | self.Conv5 = conv_block(ch_in=512, ch_out=1024) 88 | 89 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 90 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 91 | 92 | self.Up4 = up_conv(ch_in=512, ch_out=256) 93 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 94 | 95 | self.Up3 = up_conv(ch_in=256, ch_out=128) 96 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 97 | 98 | self.Up2 = up_conv(ch_in=128, ch_out=64) 99 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 100 | 101 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 102 | 103 | def forward(self, x): 104 | # encoding path 105 | x1 = self.Conv1(x) 106 | 107 | x2 = self.Maxpool(x1) 108 | x2 = self.Conv2(x2) 109 | 110 | x3 = self.Maxpool(x2) 111 | x3 = self.Conv3(x3) 112 | 113 | x4 = self.Maxpool(x3) 114 | x4 = self.Conv4(x4) 115 | 116 | x5 = self.Maxpool(x4) 117 | x5 = self.Conv5(x5) 118 | 119 | # decoding + concat path 120 | d5 = self.Up5(x5) 121 | d5 = torch.cat((x4, d5), dim=1) 122 | 123 | d5 = self.Up_conv5(d5) 124 | 125 | d4 = self.Up4(d5) 126 | d4 = torch.cat((x3, d4), dim=1) 127 | d4 = self.Up_conv4(d4) 128 | 129 | d3 = self.Up3(d4) 130 | d3 = torch.cat((x2, d3), dim=1) 131 | d3 = self.Up_conv3(d3) 132 | 133 | d2 = self.Up2(d3) 134 | d2 = torch.cat((x1, d2), dim=1) 135 | d2 = self.Up_conv2(d2) 136 | 137 | d1 = self.Conv_1x1(d2) 138 | 139 | return d1 140 | 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CodedVO: Coded Visual Odometry 2 | 3 | Official implementation of **"CodedVO: Coded Visual Odometry"** accepted in IEEE Robotics and Automation Letters, 2024. 4 | 5 | [**Project page**](http://prg.cs.umd.edu/CodedVO) | [**IEEE Xplore**](https://ieeexplore.ieee.org/abstract/document/10564186) | [**arXiv**](https://arxiv.org/pdf/2407.18240) 6 | 7 | ![Example of coded aperture setup](assets/Banner.png) 8 | 9 | ## Video 10 | 11 | [![CodedVO Video](https://img.youtube.com/vi/5MOlGOLvtS4/0.jpg)](https://www.youtube.com/watch?v=5MOlGOLvtS4) 12 | 13 | ## Citation 14 | 15 | If you use this code in your research, please cite: 16 | 17 | ```bibtex 18 | @ARTICLE{codedvo2024, 19 | author={Shah, Sachin and Rajyaguru, Naitri and Singh, Chahat Deep and Metzler, Christopher and Aloimonos, Yiannis}, 20 | journal={IEEE Robotics and Automation Letters}, 21 | title={CodedVO: Coded Visual Odometry}, 22 | year={2024}, 23 | doi={10.1109/LRA.2024.3416788}} 24 | ``` 25 | ## Table of Contents 26 | 27 | 1. [Introduction](#introduction) 28 | 2. [Installation](#installation) 29 | - [Clone Repository](#clone-repository) 30 | - [Environment Setup](#environment-setup) 31 | 3. [Models](#models) 32 | - [Download Pre-trained Models](#download-pre-trained-models) 33 | 4. [Dataset](#dataset) 34 | - [Download and Setup](#download-and-setup) 35 | - [Dataset Structure](#dataset-structure) 36 | - [Generate Coded Images](#generate-coded-images) 37 | 5. [Training](#training) 38 | - [Train from Scratch](#train-from-scratch) 39 | 6. [Evaluation](#evaluation) 40 | 7. [Usage](#usage) 41 | - [Run Visual Odometry](#run-visual-odometry) 42 | 8.
[Contributions](#contributions) 43 | 44 | ## Introduction 45 | - A novel monocular visual odometry method that leverages RGB and metric depth estimates obtained through a phase mask on a standard 1-inch camera sensor. 46 | - A depth-weighted loss function designed to prioritize learning depth maps at closer distances. 47 | - Zero-shot evaluation on indoor scenes without requiring scale alignment. 48 | 49 | ## Installation 50 | 51 | ### Clone Repository 52 | 53 | ```bash 54 | git clone https://github.com/naitri/CodedVO 55 | cd CodedVO 56 | ``` 57 | 58 | ### Environment Setup 59 | 60 | ```bash 61 | conda env create -f environment.yml 62 | ``` 63 | 64 | ## Models 65 | 66 | ### Download Pre-trained Models 67 | We provide our metric depth-weighted loss pre-trained model, which has been benchmarked on various indoor datasets. 68 | [Download Pre-trained Models](https://drive.google.com/drive/folders/1N8GyIXZe1DBrKiHNpmL3U363nQy-Rwi8?usp=sharing) 69 | 70 | ## Dataset 71 | 72 | ### Download and Setup 73 | We provide the training dataset, which includes the UMD-CodedVO LivingRoom scene and NYU data, each containing 1000 images. The dataset also includes the coded blur RGB images. 74 | - [Training data](https://drive.google.com/drive/folders/12GrDxTBMaSlGeMRWycxmCQl01BHnC5-O?usp=sharing) 75 | 76 | Additionally, we provide the UMD-CodedVO dataset, which includes ground-truth depth, RGB images, coded blur RGB images, and trajectory information. 77 | - [UMD-CodedVO Dataset](https://drive.google.com/drive/folders/12U8BH-AWUA4DgbOValO-_hNI_Z9RgMXr?usp=sharing) 78 | 79 | ### Dataset Structure 80 | ``` 81 | ├── README.md 82 | ├── datasets 83 | │ └── nyu_data 84 | │ ├── rgb 85 | │ ├── depth 86 | │ └── Codedphasecam-27Linear 87 | │ └── ... 88 | ├── scripts 89 | │ └── ... 90 | ├── weights 91 | │ └── ... 92 | ``` 93 | 94 | ### Generate Coded Images 95 | To generate coded blur RGB images from your own data, use the script `coded_generator.py`. 96 | 97 | ```bash 98 | cd scripts 99 | python coded_generator.py --root /path/to/your/data --scale_factor YOUR_SCALE_FACTOR 100 | ``` 101 | - The scale factor is 1000 for the NYUv2 dataset, 1 for the UMD-CodedVO dataset, and 5000 for the ICL-NUIM dataset. 102 | - The `--root` path should be the folder containing the `rgb` and `depth` directories, e.g., `./datasets/nyu_data`; the coded images are written to `Codedphasecam-27Linear` inside it. 103 | 104 | Note: Our Point Spread Functions (PSFs) correspond to discrete depth layers of a 23×23 Zernike-parameterized phase mask: the depth range is discretized into 27 bins over the interval [0.5, 6] meters, with a focus distance of 85 cm. 105 | 106 | ## Training 107 | 108 | ### Train from Scratch 109 | To train on your own data or on the provided dataset: 110 | ```bash 111 | python trainer.py --CONFIG MetricWeightedLossBlenderNYU --DATASET /path/to/dataset/folder 112 | ``` 113 | - You can add different configurations for loss and depth space in config.py and use those configurations for training. In this example, we use MetricWeightedLossBlenderNYU, which corresponds to our pre-trained weight file. 114 | - You can also change the training or test datasets in config.py by modifying the DatasetName Enum definitions. 115 | 116 | ## Evaluation 117 | The evaluation script can be executed as follows: 118 | ```bash 119 | python evaluate.py --CONFIG MetricWeightedLossBlenderNYU --DATASET /path/to/dataset/folder --OUTPUT /path/to/output/folder --CHECKPOINT /path/to/checkpoint/file 120 | ``` 121 | 122 | ## Usage 123 | 124 | ### Run Visual Odometry 125 | We use ORB-SLAM2 with loop closure disabled. Predicted depth maps from the above models are used to compute the odometry. Follow the [ORB-SLAM2](https://github.com/raulmur/ORB_SLAM2) RGBD execution instructions. Note that we do not use the coded blur RGB images directly: as mentioned in the paper, we apply [unsharp masking](https://www.mathworks.com/help/images/ref/imsharpen.html) to them before computing odometry (see the sketch below).
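The repository links MATLAB's `imsharpen` for this step; a minimal OpenCV equivalent in Python is sketched below. The kernel size, `sigma`, and `amount` values, as well as the file paths, are illustrative assumptions rather than the exact settings used in the paper.

```python
import cv2

def unsharp_mask(image, ksize=(5, 5), sigma=1.0, amount=1.0):
    """Classic unsharp masking: out = image + amount * (image - Gaussian-blurred image)."""
    blurred = cv2.GaussianBlur(image, ksize, sigma)
    return cv2.addWeighted(image, 1.0 + amount, blurred, -amount, 0)

# Example: sharpen a coded-blur frame before passing it to ORB-SLAM2 (placeholder paths).
frame = cv2.imread("datasets/nyu_data/Codedphasecam-27Linear/0.png")
cv2.imwrite("sharpened.png", unsharp_mask(frame))
```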
126 | 127 | ## Acknowledgements 128 | We would like to thank the authors of [Phasecam3D](https://github.com/YichengWu/PhaseCam3D) and [ORB-SLAM2](https://github.com/raulmur/ORB_SLAM2) for open-sourcing their codebases. 129 | 130 | ## Contributions 131 | If you have any questions, comments, or bug reports, feel free to open a GitHub issue, submit a pull request, or e-mail the authors: [Naitri Rajyaguru](mailto:nrajyagu@umd.edu) or [Sachin Shah](mailto:shah2022@umd.edu). 132 | 133 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | from config import get_experiment, DatasetName_blender 12 | from data import ImageDepthDataset 13 | 14 | # Ensure OpenEXR support in OpenCV 15 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 16 | 17 | # Select device 18 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 19 | 20 | # Argument parser setup 21 | parser = argparse.ArgumentParser(description="Evaluate depth prediction models.") 22 | parser.add_argument("--CONFIG", "-c", type=str, required=True, help="Path to the configuration file.") 23 | parser.add_argument("--DATASET", "-d", type=str, default="datasets", required=False, help="Path to the dataset directory.") 24 | parser.add_argument("--CHECKPOINT", "-s", type=str, required=True, help="Path to the model checkpoint.") 25 | parser.add_argument("--OUTPUT", "-o", type=str, required=True, help="Path to the output directory.") 26 | parser.add_argument("--is_blender", action="store_true", help="Use Blender's EXR depth format.") 27 | args = parser.parse_args(sys.argv[1:]) 28 | 29 | # Load experiment configuration 30 | DATASET_DIR = args.DATASET 31 | EXPERIMENT = get_experiment(args.CONFIG) 32 | 33 | # Load datasets 34 | ## NOTE: If evaluating on other datasets, do not use the is_blender flag, and set the scale factor accordingly. 35 | blender_datasets = { 36 | name.name: ImageDepthDataset( 37 | base=DATASET_DIR, 38 | path=name.name, 39 | codedDir=EXPERIMENT.coded, 40 | cache=True, 41 | is_blender=args.is_blender, 42 | scale_factor=5000 43 | ) 44 | for name in DatasetName_blender 45 | } 46 | 47 | # Create DataLoader for the test datasets 48 | test_loaders = {name: DataLoader(dataset, batch_size=1, shuffle=False) for name, dataset in blender_datasets.items()} 49 | 50 | # Initialize the model 51 | model = EXPERIMENT.model().to(device).eval() 52 | model.load_state_dict(torch.load(args.CHECKPOINT, map_location=device)) 53 | 54 | # Loss function 55 | L1 = nn.L1Loss() 56 | 57 | def to_numpy(img: torch.Tensor): 58 | """ 59 | Convert PyTorch tensor to NumPy array. 60 | """ 61 | return np.clip(img.detach().cpu().numpy(), 0, None) 62 | 63 | def sigma_metric(estimated_depth, ground_truth_depth, threshold): 64 | """ 65 | Compute the sigma accuracy metrics (as per standard depth estimation).
66 | """ 67 | ratio = torch.max(estimated_depth / ground_truth_depth, ground_truth_depth / estimated_depth) 68 | return torch.mean((ratio < threshold).float()) 69 | 70 | def evaluate(dataloader: DataLoader, output_dir: str): 71 | """ 72 | Evaluate the model and compute various depth metrics. 73 | 74 | Args: 75 | dataloader (DataLoader): Dataloader for the test dataset. 76 | output_dir (str): Directory to save the output depth images. 77 | 78 | Returns: 79 | dict: Aggregated evaluation metrics. 80 | """ 81 | model.eval() 82 | metrics = { 83 | "metric_depth_error": 0, 84 | "metric_depth_error_under3": 0, 85 | "abs_rel": 0, 86 | "sq_rel": 0, 87 | "rmse": 0, 88 | "rmse_log": 0, 89 | "sigma_1_25": 0, 90 | "sigma_1_25_2": 0, 91 | "sigma_1_25_3": 0, 92 | "sample_count": 0, 93 | "total_inference_time": 0 94 | } 95 | 96 | with torch.no_grad(): 97 | for idx, batch in enumerate(dataloader): 98 | start_time = time.time() 99 | recon = EXPERIMENT.post_forward(model(batch["Coded"].to(device))) 100 | end_time = time.time() 101 | batch_inference_time = end_time - start_time 102 | metrics["total_inference_time"] += batch_inference_time 103 | 104 | # Convert depth values to metric depth 105 | metric_gt = EXPERIMENT.depth.output_to_metric(batch["Depth"]).squeeze(1) 106 | metric_re = EXPERIMENT.depth.output_to_metric(recon).squeeze(1) 107 | 108 | # Apply valid mask for depth values greater than 0 109 | valid_mask = metric_gt > 0 110 | gt = torch.clamp(metric_gt[valid_mask], 0, 6).to(device) 111 | pred = torch.clamp(metric_re[valid_mask], 0, 6) 112 | 113 | log_diff = torch.log(pred) - torch.log(gt) 114 | metrics["rmse_log"] += torch.sqrt(torch.mean(log_diff ** 2)) 115 | 116 | # Calculate metrics only for valid depth predictions 117 | mask = gt < 3 118 | if torch.any(mask).item(): 119 | metrics["metric_depth_error_under3"] += L1(pred[mask], gt[mask]).item() * len(batch) 120 | 121 | metrics["abs_rel"] += torch.mean(torch.abs(pred - gt) / gt).item() * len(batch) 122 | metrics["sq_rel"] += torch.mean(((pred - gt) ** 2) / gt).item() * len(batch) 123 | metrics["rmse"] += torch.sqrt(torch.mean(((pred - gt) ** 2))).item() * len(batch) 124 | metrics["sigma_1_25"] += sigma_metric(pred, gt, 1.25) * len(batch) 125 | metrics["sigma_1_25_2"] += sigma_metric(pred, gt, 1.25 ** 2) * len(batch) 126 | metrics["sigma_1_25_3"] += sigma_metric(pred, gt, 1.25 ** 3) * len(batch) 127 | metrics["metric_depth_error"] += L1(pred, gt).item() * len(batch) 128 | metrics["sample_count"] += len(batch) 129 | 130 | # Save the predicted depth image 131 | prediction = (metric_re.squeeze(0).cpu().numpy() * 5000).astype(np.uint16) 132 | output_path_pred = os.path.join(output_dir, "pred_depth", f"{idx}.png") 133 | cv2.imwrite(output_path_pred, prediction) 134 | 135 | # Compute average inference time per batch 136 | avg_inference_speed = metrics["total_inference_time"] / len(dataloader) 137 | avg_fps = 1 / avg_inference_speed 138 | 139 | print(f"Average Inference Speed: {avg_inference_speed:.4f} seconds per batch") 140 | print(f"Average FPS: {avg_fps:.2f} frames per second") 141 | 142 | # Average the metrics 143 | avg_metrics = {k: v / metrics["sample_count"] for k, v in metrics.items() if k != "total_inference_time"} 144 | 145 | return avg_metrics 146 | 147 | # Ensure output directories exist 148 | output_dir = args.OUTPUT 149 | os.makedirs(os.path.join(output_dir, "pred_depth"), exist_ok=True) 150 | 151 | # Evaluate the model on each test dataset 152 | for name, dataloader in test_loaders.items(): 153 | metrics = evaluate(dataloader, output_dir) 154 
| 155 | if name in [item.name for item in EXPERIMENT.train]: 156 | print(f"[train] {name}") 157 | else: 158 | print(f"{name}") 159 | 160 | print(f"| L1 : {metrics['metric_depth_error']:.3f}") 161 | print(f"| L1 <3m : {metrics['metric_depth_error_under3']:.3f}") 162 | print(f"| Mean Absolute Relative Error (abs_rel): {metrics['abs_rel']:.3f}") 163 | print(f"| Mean Squared Relative Error (sq_rel): {metrics['sq_rel']:.3f}") 164 | print(f"| Root Mean Squared Error (RMSE): {metrics['rmse']:.3f}") 165 | print(f"| RMSE Log: {metrics['rmse_log']:.3f}") 166 | print(f"| Sigma 1.25: {metrics['sigma_1_25']:.3f}") 167 | print(f"| Sigma 1.25^2: {metrics['sigma_1_25_2']:.3f}") 168 | print(f"| Sigma 1.25^3: {metrics['sigma_1_25_3']:.3f}") 169 | print() 170 | -------------------------------------------------------------------------------- /scripts/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader, ConcatDataset 9 | import wandb 10 | 11 | from config import get_experiment, DatasetName, DatasetName_train_icl, DatasetName_train_blender 12 | from data import ImageDepthDataset 13 | from unet import init_weights, count_parameters 14 | 15 | # Set the environment variable for OpenEXR support in OpenCV 16 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 17 | 18 | # Select the device to use for computation 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | # Argument parser setup 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--CONFIG", "-c", type=str, required=True, help="Path to the configuration file.") 24 | parser.add_argument("--SAVE_EVERY", "-e", type=int, default=10, required=False, help="Save checkpoint every N epochs.") 25 | parser.add_argument("--DATASET", "-d", type=str, default="datasets", required=False, help="Path to the dataset directory.") 26 | parser.add_argument("--SAVEDIR", "-s", type=str, default="checkpoints", required=False, help="Directory to save checkpoints.") 27 | args = parser.parse_args(sys.argv[1:]) 28 | 29 | # Load the experiment configuration 30 | DATASET_DIR = args.DATASET 31 | EXPERIMENT = get_experiment(args.CONFIG) 32 | experiment_name = EXPERIMENT.__class__.__name__.split("/")[-1] 33 | CHECKPOINT_DIR = os.path.join(args.SAVEDIR, experiment_name) 34 | os.makedirs(CHECKPOINT_DIR, exist_ok=True) 35 | 36 | # Load datasets 37 | ##NOTE: Scale factor must be changed depending on the training and test data. Blender EXR scale factor remains 1.
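# For reference (values taken from the README; treat them as assumptions if your data differs):
#   NYUv2 depth PNGs use scale_factor=1000, ICL-NUIM depth PNGs use scale_factor=5000,
#   and UMD-CodedVO Blender EXR depth is already metric, so its scale factor stays 1.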
38 | test_datasets = {name: ImageDepthDataset(DATASET_DIR, name.name, codedDir="Codedphasecam-27Linear", cache=True, scale_factor=1000) for name in DatasetName} 39 | nyu_datasets_train = {name: ImageDepthDataset(DATASET_DIR, name.name, codedDir="Codedphasecam-27Linear", cache=True, scale_factor=1000) for name in DatasetName_train_icl} 40 | blender_datasets_train = {name: ImageDepthDataset(DATASET_DIR, name.name, codedDir="Codedphasecam-27Linear", cache=True, is_blender=True, scale_factor=1) for name in DatasetName_train_blender} 41 | 42 | # Combine training datasets 43 | train_datasets = [nyu_datasets_train[name] for name in EXPERIMENT.train if name in nyu_datasets_train] + \ 44 | [blender_datasets_train[name] for name in EXPERIMENT.train if name in blender_datasets_train] 45 | 46 | # Create DataLoader for training 47 | train_dataset = ConcatDataset(train_datasets) 48 | train_loader = DataLoader(train_dataset, batch_size=EXPERIMENT.batch_size, shuffle=True, num_workers=4) 49 | 50 | # Create DataLoader for testing 51 | test_loaders = {name: DataLoader(dataset, batch_size=EXPERIMENT.batch_size, shuffle=True) for name, dataset in test_datasets.items()} 52 | 53 | # Print the loaded testing datasets 54 | for name in test_loaders.keys(): 55 | print(f"Testing dataset: {name.name}") 56 | 57 | # Initialize the model and optimizer 58 | model = EXPERIMENT.model().to(device) 59 | init_weights(model) 60 | model_optimizer = torch.optim.Adam(model.parameters(), lr=EXPERIMENT.learning_rate) 61 | L1 = nn.L1Loss() 62 | 63 | # Initialize Weights and Biases for logging 64 | wandb.init(project="coded_losses", name=experiment_name, config={**vars(args)}) 65 | print(repr(EXPERIMENT)) 66 | print(f"Training model with {count_parameters(model)} parameters") 67 | 68 | def wandbimg(img: torch.Tensor, vmax=6.5): 69 | """ 70 | Convert tensor image to WandB image format. 71 | 72 | Args: 73 | img (torch.Tensor): Input image tensor. 74 | vmax (float): Maximum value for clipping. 75 | 76 | Returns: 77 | wandb.Image: Image in WandB format. 78 | """ 79 | out = np.clip(img.detach().cpu().numpy(), 0, vmax) / vmax * 255 80 | return wandb.Image(out.astype(np.uint8)) 81 | 82 | def evaluate(dataloader: DataLoader): 83 | """ 84 | Evaluate the model on the given dataloader. 85 | 86 | Args: 87 | dataloader (DataLoader): DataLoader for evaluation. 88 | 89 | Returns: 90 | tuple: Average L1 error, L1 error for depth < 3m, ground truth depth map, reconstructed depth map. 
91 | """ 92 | model.eval() 93 | with torch.no_grad(): 94 | metric_depth_error = 0 95 | metric_depth_error_under3 = 0 96 | sample_count = 0 97 | 98 | for batch in dataloader: 99 | recon = EXPERIMENT.post_forward(model(batch["Coded"].to(device))) 100 | metric_gt = batch["Depth"].to(device) 101 | metric_re = EXPERIMENT.depth.output_to_metric(recon) 102 | 103 | mask = metric_gt < 3 104 | if torch.any(mask).item(): 105 | metric_depth_error_under3 += L1(metric_re[mask, 0], metric_gt[mask]).item() * len(batch) 106 | 107 | metric_depth_error += L1(metric_re[:, 0], metric_gt).item() * len(batch) 108 | sample_count += len(batch) 109 | 110 | model.train() 111 | ground_truth_depth_map = wandbimg(metric_gt[0]) 112 | reconstructed_depth_map = wandbimg(metric_re[0, 0]) 113 | 114 | return ( 115 | metric_depth_error / sample_count, 116 | metric_depth_error_under3 / sample_count, 117 | ground_truth_depth_map, 118 | reconstructed_depth_map, 119 | ) 120 | 121 | # Training loop 122 | NUM_TEST_SETS = len(DatasetName) 123 | previous_test_error = float('inf') 124 | 125 | for epoch in range(EXPERIMENT.epochs): 126 | start_time = time.monotonic() 127 | 128 | total_error = 0 129 | for batch in train_loader: 130 | model_optimizer.zero_grad() 131 | reconstruction = EXPERIMENT.post_forward(model(batch["Coded"].to(device))) 132 | metric_gt = batch["Depth"].to(device) 133 | ground_truth = EXPERIMENT.depth.metric_to_output(metric_gt) 134 | error = EXPERIMENT.compute_loss(ground_truth, reconstruction, epoch) 135 | error.backward() 136 | model_optimizer.step() 137 | total_error += error.item() 138 | 139 | # Evaluate the model on the test datasets 140 | test_artifacts = {} 141 | total_avg_l1 = 0 142 | total_u3_l1 = 0 143 | 144 | for name, dataloader in test_loaders.items(): 145 | avg_l1, u3_l1, gt_depth_map, re_depth_map = evaluate(dataloader) 146 | test_artifacts[f"{name.name}: L1"] = avg_l1 147 | test_artifacts[f"{name.name}: L1 <3m"] = u3_l1 148 | 149 | if epoch % 5 == 0: 150 | test_artifacts[f"{name.name}: ground truth"] = gt_depth_map 151 | test_artifacts[f"{name.name}: reconstructed"] = re_depth_map 152 | 153 | if name not in EXPERIMENT.train: 154 | total_avg_l1 += avg_l1 155 | total_u3_l1 += u3_l1 156 | 157 | # Log metrics and save the best model 158 | iterate_values = { 159 | "train_error": total_error / len(train_loader), 160 | "test_L1": total_avg_l1 / NUM_TEST_SETS, 161 | "test_L1_under3": total_u3_l1 / NUM_TEST_SETS, 162 | **test_artifacts, 163 | } 164 | 165 | if total_u3_l1 / NUM_TEST_SETS < previous_test_error: 166 | previous_test_error = total_u3_l1 / NUM_TEST_SETS 167 | torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/best.pt") 168 | 169 | wandb.log(iterate_values) 170 | 171 | if epoch % args.SAVE_EVERY == 0: 172 | torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/recon_{epoch}.pt") 173 | 174 | end_time = time.monotonic() 175 | print(f"Epoch={epoch}: loss={total_error / len(train_loader)} :: {end_time - start_time:.3f}s") 176 | 177 | # Save the final model 178 | torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/recon_end.pt") 179 | wandb.finish() 180 | -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from enum import Enum 3 | from typing import List, Type 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from unet import U_Net 9 | 10 | L1 = nn.L1Loss() 11 | L2 = nn.MSELoss() 12 | 13 | # Custom Loss Functions 14 | def 
weighted_mse_loss(input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor): 15 | return torch.sum(weight * (input - target) ** 2) / torch.sum(weight) 16 | 17 | def weighted_depth_loss(input: torch.Tensor, target: torch.Tensor, weight: float): 18 | diff = torch.abs(input - target) 19 | weights = torch.where(target <= 3, torch.exp(weight * (3 - target)), torch.exp(-weight * (target - 3))) 20 | loss = weights * diff 21 | return loss.mean() 22 | 23 | def weighted_l1_loss(input: torch.Tensor, target: torch.Tensor): 24 | mask = target <= 3.0 25 | weights = torch.where(mask, torch.ones_like(target), (3.0 / target)) 26 | diff = torch.abs(input - target)  # element-wise absolute error so the per-pixel weights apply 27 | loss = weights * diff 28 | return loss.mean() 29 | 30 | def silog_loss(input: torch.Tensor, target: torch.Tensor, variance_focus: float = 0.85): 31 | # Only compute the loss on non-null pixels from the ground-truth depth-map 32 | non_zero_mask = (target > 0) & (input > 0) 33 | d = torch.log(input[non_zero_mask]) - torch.log(target[non_zero_mask]) 34 | return torch.sqrt((d ** 2).mean() - variance_focus * (d.mean() ** 2)) * 10.0 35 | 36 | 37 | # Enum definitions for datasets 38 | ## NOTE: Edit these as per dataset usage. For example, for testing on ICL-NUIM office trajectory 1, set DatasetName = Enum("DatasetName", ["office_traj1"]) 39 | ## and change the scale factor accordingly 40 | 41 | DepthStyle = Enum("DepthStyle", ["metric", "phi", "normalized"]) 42 | DatasetName = Enum("DatasetName", ["nyu_data"]) 43 | DatasetName_blender = Enum("DatasetName", ["DiningRoom"]) 44 | DatasetName_train_icl = Enum("DatasetName", ["nyu_data"]) 45 | DatasetName_train_blender = Enum("DatasetName", ["LivingRoom1"]) 46 | 47 | # Metric units for convenience 48 | cm = 1e-2 49 | mm = 1e-3 50 | 51 | # Abstract base class for depth space representation 52 | class DepthSpace(ABC): 53 | """Convert metric depth maps to and from model output maps""" 54 | 55 | @abstractmethod 56 | def output_to_metric(self, out: torch.Tensor) -> torch.Tensor: 57 | ... 58 | 59 | @abstractmethod 60 | def metric_to_output(self, met: torch.Tensor) -> torch.Tensor: 61 | ...
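    # Illustration: a NormalizedDepth(max_value=7) space maps a 3.5 m ground-truth depth to 0.5
    # for the network (metric_to_output) and scales a 0.5 prediction back to 3.5 m (output_to_metric).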
62 | 63 | def __repr__(self): 64 | return self.__class__.__name__ 65 | 66 | # Depth representation classes 67 | class MetricDepth(DepthSpace): 68 | """Basic metric depth representation""" 69 | def output_to_metric(self, out: torch.Tensor): 70 | return out 71 | 72 | def metric_to_output(self, met: torch.Tensor): 73 | return met 74 | 75 | class NormalizedDepth(DepthSpace): 76 | """Represent depth as 0-1 based on a maximum distance.""" 77 | def __init__(self, max_value: float = 7): 78 | self.max_value = max_value 79 | 80 | def output_to_metric(self, out: torch.Tensor): 81 | return out * self.max_value 82 | 83 | def metric_to_output(self, met: torch.Tensor): 84 | return met / self.max_value 85 | 86 | def __repr__(self): 87 | return f"{self.__class__.__name__}(max_value={self.max_value})" 88 | 89 | class DiopterDepth(DepthSpace): 90 | """Represent depth as defocus terms away from the focal plane.""" 91 | def __init__(self, f_number: float = 17, focal_length: float = 50 * mm, focus_distance: float = 85 * cm, max_diopter: int = 15): 92 | self.f_number = f_number 93 | self.focal_length = focal_length 94 | self.R = focal_length / (2 * f_number) 95 | self.focus_distance = focus_distance 96 | self.k = 2 * torch.pi / 530e-9 97 | self.max_diopter = max_diopter 98 | 99 | def output_to_metric(self, out: torch.Tensor): 100 | out2 = out * self.max_diopter * 2 - self.max_diopter 101 | Wm = out2 / self.k 102 | depth = 1 / (1 / self.focus_distance + 2 * Wm / self.R**2) 103 | return depth 104 | 105 | def metric_to_output(self, met: torch.Tensor): 106 | depth = torch.clamp(met, 1 * mm, None) # prevent 0 depth 107 | inv = 1 / depth 108 | sub = inv - 1 / self.focus_distance 109 | div = sub / 2 110 | Wm = div * self.R**2 111 | Phi2 = Wm * self.k 112 | return (Phi2 + self.max_diopter) / (2 * self.max_diopter) 113 | 114 | def __repr__(self): 115 | return f"{self.__class__.__name__}(f_number={self.f_number}, focal_length={self.focal_length*100:.1f}cm, focus_distance={self.focus_distance*100:.1f}cm, max_diopter={self.max_diopter})" 116 | 117 | # Experiment class to handle different models and datasets 118 | def get_experiment(class_name: str): 119 | return [cls for cls in Experiment.__subclasses__() if cls.__name__ == class_name and cls != Experiment][0]() 120 | 121 | class Experiment(ABC): 122 | name: str 123 | model: Type[nn.Module] 124 | depth: DepthSpace 125 | epochs: int 126 | batch_size: int 127 | learning_rate: float 128 | train: List[DatasetName] 129 | coded: str 130 | 131 | @abstractmethod 132 | def compute_loss(self, ground_truth: torch.Tensor, reconstruction: torch.Tensor, idx: int = 0): 133 | ... 134 | 135 | @abstractmethod 136 | def post_forward(self, reconstruction: torch.Tensor): 137 | ... 
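    # To add a new training configuration (as described in the README), define another Experiment
    # subclass below with model, depth, epochs, batch_size, learning_rate, train, and coded set, and
    # implement compute_loss and post_forward; get_experiment() looks it up by class name, so a
    # hypothetical class MyNewConfig would be selected with `python trainer.py --CONFIG MyNewConfig`.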
138 | 139 | def __repr__(self): 140 | out = f"{self.__class__.__name__}(\n" 141 | out += f"\tmodel={self.model.__name__}\n" 142 | out += f"\tdepth={self.depth!r}\n" 143 | out += f"\tLR={self.learning_rate}\n" 144 | out += f"\tepochs={self.epochs}\n" 145 | out += f"\tbatch-size={self.batch_size}\n" 146 | out += f"\ttrain-set={[item.name for item in self.train]}\n" 147 | out += f"\tcoded={self.coded!r}\n" 148 | out += ")" 149 | return out 150 | 151 | 152 | class SimpleDiopter(Experiment): 153 | model = U_Net 154 | depth = DiopterDepth(max_diopter=13) 155 | epochs = 200 156 | batch_size = 8 157 | learning_rate = 1e-4 158 | coded = "Codedphasecam-27Linear" 159 | 160 | def post_forward(self, reconstruction: torch.Tensor) -> torch.Tensor: 161 | return torch.sigmoid(reconstruction) 162 | 163 | def compute_loss(self, ground_truth: torch.Tensor, reconstruction: torch.Tensor, idx: int = 0) -> torch.Tensor: 164 | return L2(reconstruction[:, 0], ground_truth) 165 | 166 | class L1LossBlenderNYU(Experiment): 167 | model = U_Net 168 | depth = MetricDepth() 169 | epochs = 80 170 | batch_size = 3 171 | learning_rate = 1e-4 172 | train = [DatasetName_train_icl.nyu_data, DatasetName_train_blender.LivingRoom1] 173 | coded = "Codedphasecam-27Linear" 174 | 175 | def post_forward(self, reconstruction: torch.Tensor): 176 | return reconstruction 177 | 178 | def compute_loss(self, ground_truth: torch.Tensor, reconstruction: torch.Tensor, idx: int = 0): 179 | reconstruction_metric = self.depth.output_to_metric(reconstruction) 180 | ground_truth_metric = self.depth.output_to_metric(ground_truth) 181 | return L1(reconstruction_metric[:, 0], ground_truth_metric) 182 | 183 | class MetricWeightedLossBlenderNYU(Experiment): 184 | model = U_Net 185 | depth = MetricDepth() 186 | epochs = 80 187 | batch_size = 3 188 | learning_rate = 1e-4 189 | train = [DatasetName_train_blender.LivingRoom1, DatasetName_train_icl.nyu_data] 190 | coded = "Codedphasecam-27Linear" 191 | 192 | def post_forward(self, reconstruction: torch.Tensor): 193 | return reconstruction 194 | 195 | def compute_loss(self, ground_truth: torch.Tensor, reconstruction: torch.Tensor, idx: int = 0): 196 | reconstruction_metric = self.depth.output_to_metric(reconstruction) 197 | ground_truth_metric = self.depth.output_to_metric(ground_truth) 198 | return weighted_mse_loss(reconstruction_metric[:, 0], ground_truth_metric, 2 ** (-0.3 * ground_truth_metric)) 199 | 200 | # Example class using SILog loss 201 | class SILossLiving_Office(Experiment): 202 | model = U_Net 203 | depth = MetricDepth() 204 | epochs = 80 205 | batch_size = 1 206 | learning_rate = 1e-4 207 | coded = "Codedphasecam-27Linear" 208 | 209 | def post_forward(self, reconstruction: torch.Tensor): 210 | return reconstruction 211 | 212 | def compute_loss(self, ground_truth: torch.Tensor, reconstruction: torch.Tensor, idx: int = 0): 213 | reconstruction_metric = self.depth.output_to_metric(reconstruction) 214 | ground_truth_metric = self.depth.output_to_metric(ground_truth) 215 | return silog_loss(reconstruction_metric[:, 0], ground_truth_metric) 216 | 217 | -------------------------------------------------------------------------------- /scripts/coded_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import tqdm 6 | import argparse 7 | 8 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | def 
tensor_to_numpy_image(tensor: torch.Tensor): 12 | """ 13 | Convert a PyTorch tensor to a NumPy image. 14 | 15 | Args: 16 | tensor (torch.Tensor): Input tensor with shape (C, H, W). 17 | 18 | Returns: 19 | np.ndarray: Output image in NumPy format. 20 | """ 21 | out = tensor.moveaxis(0, -1) 22 | return torch.clamp(out, 0, 255).to(torch.uint8).cpu().numpy() 23 | 24 | class Camera: 25 | def __init__(self, name: str, psfs: torch.Tensor, depth_layers: torch.Tensor, use_nonlinear: bool = False): 26 | """ 27 | Initialize the Camera object. 28 | 29 | Args: 30 | name (str): Name of the camera. 31 | psfs (torch.Tensor): Point spread functions (PSFs) tensor. 32 | depth_layers (torch.Tensor): Depth layers tensor. 33 | use_nonlinear (bool): Whether to use non-linear processing. 34 | """ 35 | self.name = f"Coded{name}NonLinear" if use_nonlinear else f"Coded{name}Linear" 36 | n_depths, n_channels, height, width = psfs.shape 37 | 38 | if n_channels != 3 or n_depths != depth_layers.numel() or width != height or width % 2 == 0: 39 | raise ValueError(f"PSF has wrong shape: {psfs.shape}") 40 | 41 | self.psfs = psfs.to(device) 42 | self.depth_layers = depth_layers.to(device) 43 | self.use_nonlinear = use_nonlinear 44 | self.padding = width // 2 45 | 46 | def capture(self, img: np.ndarray, metric_depth: np.ndarray) -> np.ndarray: 47 | """ 48 | Capture an image with depth-dependent processing. 49 | 50 | Args: 51 | img (np.ndarray): Input RGB image. 52 | metric_depth (np.ndarray): Metric depth map. 53 | 54 | Returns: 55 | np.ndarray: Processed image. 56 | """ 57 | image = torch.from_numpy(img).moveaxis(-1, 0).to(torch.float32).to(device) 58 | depth = torch.from_numpy(metric_depth).to(device) 59 | 60 | if self.use_nonlinear: 61 | coded = self.nonlinear(image, depth) 62 | else: 63 | coded = self.linear(image, depth) 64 | return tensor_to_numpy_image(coded) 65 | 66 | def get_depth_layers(self, depth_map: torch.Tensor) -> torch.Tensor: 67 | """ 68 | Get depth layers from a depth map. 69 | 70 | Args: 71 | depth_map (torch.Tensor): Input depth map. 72 | 73 | Returns: 74 | torch.Tensor: Quantized depth layers. 75 | """ 76 | quantized_depth = torch.bucketize(depth_map, self.depth_layers) 77 | return torch.stack([quantized_depth == j for j in range(len(self.depth_layers))]) 78 | 79 | def linear(self, image: torch.Tensor, depth_map: torch.Tensor) -> torch.Tensor: 80 | """ 81 | Perform linear depth-dependent convolution. 82 | 83 | Args: 84 | image (torch.Tensor): Input image. 85 | depth_map (torch.Tensor): Depth map. 86 | 87 | Returns: 88 | torch.Tensor: Convolved image. 89 | """ 90 | depth_mask = self.get_depth_layers(depth_map) 91 | 92 | return torch.stack( 93 | [ 94 | torch.sum( 95 | torch.nn.functional.conv2d( 96 | image[None, channel:channel+1], 97 | self.psfs[:, channel:channel+1], 98 | stride=1, 99 | padding=self.padding, 100 | ) * depth_mask, dim=1 101 | ) 102 | for channel in range(3) 103 | ], dim=1 104 | )[0] 105 | 106 | def single_psf_convolution(self, image: torch.Tensor, depth_idx: int, channel_idx: int) -> torch.Tensor: 107 | """ 108 | Convolve image with a single PSF. 109 | 110 | Args: 111 | image (torch.Tensor): Input image. 112 | depth_idx (int): Depth index. 113 | channel_idx (int): Channel index. 114 | 115 | Returns: 116 | torch.Tensor: Convolved image. 
117 | """ 118 | return torch.nn.functional.conv2d( 119 | image, 120 | self.psfs[depth_idx:depth_idx+1, channel_idx:channel_idx+1], 121 | stride=1, 122 | padding=self.padding, 123 | ) 124 | 125 | def nonlinear(self, img: torch.Tensor, depth_map: torch.Tensor, eps=1e-6) -> torch.Tensor: 126 | """ 127 | Perform non-linear blurring based on Ikoma et al. 2021 equation 5. 128 | 129 | Args: 130 | img (torch.Tensor): Input image. 131 | depth_map (torch.Tensor): Depth map. 132 | eps (float): Small epsilon value to prevent division by zero. 133 | 134 | Returns: 135 | torch.Tensor: Blurred image. 136 | """ 137 | depth_mask = self.get_depth_layers(depth_map) 138 | K, _, _ = depth_mask.shape 139 | depth_mask = depth_mask.to(torch.float) 140 | depth_mask = torch.flip(depth_mask, dims=(0,)) 141 | 142 | out = torch.zeros_like(img) 143 | img = img.to(torch.float) / 255.0 144 | depth_sum = torch.cumsum(depth_mask, dim=0) 145 | 146 | for channel in range(3): 147 | layered = img[channel:channel+1] * depth_mask 148 | for k in range(K): 149 | E_k = self.single_psf_convolution(depth_sum[k][None, None], k, channel) 150 | l_k = self.single_psf_convolution(layered[k][None, None], k, channel) / (E_k + eps) 151 | for kp in range(k + 1, K): 152 | E_kp = self.single_psf_convolution(depth_sum[kp][None, None], kp, channel) 153 | a_kp = 1 - self.single_psf_convolution(depth_mask[kp][None, None], kp, channel) / (E_kp + eps) 154 | l_k = l_k * a_kp 155 | out[channel] = out[channel] + l_k 156 | 157 | return torch.clamp(out * 255, 0, 255) 158 | 159 | def process_folder(self, root: str, is_blender: bool = False, scale_factor: float = 5000): 160 | """ 161 | Process a folder of images and depths. 162 | 163 | Args: 164 | root (str): Root directory. 165 | is_blender (bool): Whether the depth images are in Blender's EXR format. 166 | scale_factor (float): Scale factor for depth normalization. 
167 | """ 168 | depth_folder = os.path.join(root, "depth") 169 | image_folder = os.path.join(root, "rgb") 170 | output_folder = os.path.join(root, self.name) 171 | os.makedirs(output_folder, exist_ok=True) 172 | 173 | files = os.listdir(image_folder) 174 | max_depth_value = 0 175 | 176 | for idx, file in tqdm.tqdm(enumerate(files), total=len(files), desc=root): 177 | 178 | image_file = os.path.join(image_folder, file) 179 | depth_file = os.path.join(depth_folder, file).replace(".png", ".exr") if is_blender else os.path.join(depth_folder, file) 180 | 181 | image_bgr = cv2.imread(image_file) 182 | image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) 183 | raw_depth = cv2.imread(depth_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 184 | 185 | if raw_depth is None: 186 | print(f"{file} is missing a depth file") 187 | continue 188 | 189 | metric_depth = raw_depth[:, :, 0] if is_blender else raw_depth / scale_factor 190 | 191 | coded_image_rgb = self.capture(image.astype(np.float32), metric_depth) 192 | coded_image_bgr = cv2.cvtColor(coded_image_rgb, cv2.COLOR_RGB2BGR) 193 | cv2.imwrite(os.path.join(output_folder, file), coded_image_bgr) 194 | max_depth_value = max(max_depth_value, np.max(metric_depth)) 195 | 196 | print(f"Max Depth Value in the folder: {max_depth_value}") 197 | 198 | def main(): 199 | parser = argparse.ArgumentParser(description="Process images with depth-dependent processing.") 200 | parser.add_argument("--root", type=str, required=True, help="Root directory containing the datasets.") 201 | parser.add_argument("--is_blender", action="store_true", help="Use Blender's EXR depth format.") 202 | parser.add_argument("--scale_factor", type=float, default=5000, help="Scale factor for depth normalization.") 203 | args = parser.parse_args() 204 | 205 | # Fixed paths and parameters 206 | psf_path = os.path.join("..", "phasecam-psfs-27.npy") 207 | depth_layers = torch.linspace(0.5, 6, 27) 208 | 209 | camera = Camera( 210 | "phasecam-27", 211 | torch.from_numpy(np.moveaxis(np.load(psf_path), -1, 1)), 212 | depth_layers, 213 | not args.is_blender, 214 | ) 215 | 216 | camera.process_folder( 217 | args.root, 218 | is_blender=args.is_blender, 219 | scale_factor=args.scale_factor, 220 | ) 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: codedvo 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - alsa-lib=1.2.12=h4ab18f5_0 9 | - aom=3.9.1=hac33072_0 10 | - attr=2.5.1=h166bdaf_1 11 | - blosc=1.21.6=hef167b5_0 12 | - brotli=1.1.0=hd590300_1 13 | - brotli-bin=1.1.0=hd590300_1 14 | - brunsli=0.1=h9c3ff4c_0 15 | - bzip2=1.0.8=h4bc722e_7 16 | - c-ares=1.32.2=h4bc722e_0 17 | - c-blosc2=2.15.0=h6d6b9e4_1 18 | - ca-certificates=2024.7.4=hbcca054_0 19 | - cairo=1.18.0=hbb29018_2 20 | - certifi=2024.7.4=pyhd8ed1ab_0 21 | - charls=2.4.2=h59595ed_0 22 | - colorama=0.4.6=pyhd8ed1ab_0 23 | - contourpy=1.2.1=py311h9547e67_0 24 | - cycler=0.12.1=pyhd8ed1ab_0 25 | - dav1d=1.2.1=hd590300_0 26 | - dbus=1.13.6=h5008d03_3 27 | - double-conversion=3.3.0=h59595ed_0 28 | - expat=2.6.2=h59595ed_0 29 | - ffmpeg=6.1.1=gpl_he2f97eb_114 30 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 31 | - font-ttf-inconsolata=3.000=h77eed37_0 32 | - font-ttf-source-code-pro=2.038=h77eed37_0 33 | - font-ttf-ubuntu=0.83=h77eed37_2 34 | - 
fontconfig=2.14.2=h14ed4e7_0 35 | - fonts-conda-ecosystem=1=0 36 | - fonts-conda-forge=1=0 37 | - fonttools=4.53.1=py311h61187de_0 38 | - freeglut=3.2.2=ha6d2627_3 39 | - freetype=2.12.1=h267a509_2 40 | - fribidi=1.0.10=h36c2ea0_0 41 | - gettext=0.22.5=h59595ed_2 42 | - gettext-tools=0.22.5=h59595ed_2 43 | - giflib=5.2.2=hd590300_0 44 | - glib=2.80.2=hf974151_0 45 | - glib-tools=2.80.2=hb6ce0ca_0 46 | - gmp=6.3.0=hac33072_2 47 | - gnutls=3.7.9=hb077bed_0 48 | - graphite2=1.3.13=h59595ed_1003 49 | - gst-plugins-base=1.24.5=hbaaba92_0 50 | - gstreamer=1.24.5=haf2f30d_0 51 | - harfbuzz=8.5.0=hfac3d4d_0 52 | - hdf5=1.14.3=nompi_hdf9ad27_105 53 | - icu=73.2=h59595ed_0 54 | - imagecodecs=2024.6.1=py311hb8791aa_2 55 | - imageio=2.34.2=pyh12aca89_0 56 | - imath=3.1.11=hfc55251_0 57 | - importlib-metadata=8.0.0=pyha770c72_0 58 | - jasper=4.2.4=h536e39c_0 59 | - jxrlib=1.1=hd590300_3 60 | - keyutils=1.6.1=h166bdaf_0 61 | - kiwisolver=1.4.5=py311h9547e67_1 62 | - krb5=1.21.3=h659f571_0 63 | - lame=3.100=h166bdaf_1003 64 | - lazy_loader=0.4=pyhd8ed1ab_0 65 | - lcms2=2.16=hb7c19ff_0 66 | - ld_impl_linux-64=2.40=hf3520f5_7 67 | - lerc=4.0.0=h27087fc_0 68 | - libabseil=20240116.2=cxx17_he02047a_1 69 | - libaec=1.1.3=h59595ed_0 70 | - libasprintf=0.22.5=h661eb56_2 71 | - libasprintf-devel=0.22.5=h661eb56_2 72 | - libass=0.17.1=h8fe9dca_1 73 | - libavif16=1.0.4=hd2f8ffe_4 74 | - libblas=3.9.0=22_linux64_openblas 75 | - libbrotlicommon=1.1.0=hd590300_1 76 | - libbrotlidec=1.1.0=hd590300_1 77 | - libbrotlienc=1.1.0=hd590300_1 78 | - libcap=2.69=h0f662aa_0 79 | - libcblas=3.9.0=22_linux64_openblas 80 | - libclang-cpp15=15.0.7=default_h127d8a8_5 81 | - libclang-cpp18.1=18.1.8=default_h36b48a3_0 82 | - libclang13=18.1.8=default_h6ae225f_0 83 | - libcups=2.3.3=h4637d8d_4 84 | - libcurl=8.8.0=hca28451_1 85 | - libdeflate=1.20=hd590300_0 86 | - libdrm=2.4.122=h4ab18f5_0 87 | - libedit=3.1.20191231=he28a2e2_2 88 | - libev=4.33=hd590300_2 89 | - libevent=2.1.12=hf998b51_1 90 | - libexpat=2.6.2=h59595ed_0 91 | - libffi=3.4.2=h7f98852_5 92 | - libflac=1.4.3=h59595ed_0 93 | - libgcc-ng=14.1.0=h77fa898_0 94 | - libgcrypt=1.11.0=h4ab18f5_0 95 | - libgettextpo=0.22.5=h59595ed_2 96 | - libgettextpo-devel=0.22.5=h59595ed_2 97 | - libgfortran-ng=14.1.0=h69a702a_0 98 | - libgfortran5=14.1.0=hc5f4f2c_0 99 | - libglib=2.80.2=hf974151_0 100 | - libglu=9.0.0=ha6d2627_1004 101 | - libgomp=14.1.0=h77fa898_0 102 | - libgpg-error=1.50=h4f305b6_0 103 | - libhwloc=2.11.1=default_hecaa2ac_1000 104 | - libhwy=1.1.0=h00ab1b0_0 105 | - libiconv=1.17=hd590300_2 106 | - libidn2=2.3.7=hd590300_0 107 | - libjpeg-turbo=3.0.0=hd590300_1 108 | - libjxl=0.10.3=h66b40c8_0 109 | - liblapack=3.9.0=22_linux64_openblas 110 | - liblapacke=3.9.0=22_linux64_openblas 111 | - libllvm15=15.0.7=hb3ce162_4 112 | - libllvm18=18.1.8=hc9dba70_0 113 | - libnghttp2=1.58.0=h47da74e_1 114 | - libnsl=2.0.1=hd590300_0 115 | - libogg=1.3.5=h4ab18f5_0 116 | - libopenblas=0.3.27=pthreads_hac2b453_1 117 | - libopencv=4.10.0=qt6_py311h5a6cdeb_601 118 | - libopenvino=2024.2.0=h2da1b83_1 119 | - libopenvino-auto-batch-plugin=2024.2.0=hb045406_1 120 | - libopenvino-auto-plugin=2024.2.0=hb045406_1 121 | - libopenvino-hetero-plugin=2024.2.0=h5c03a75_1 122 | - libopenvino-intel-cpu-plugin=2024.2.0=h2da1b83_1 123 | - libopenvino-intel-gpu-plugin=2024.2.0=h2da1b83_1 124 | - libopenvino-intel-npu-plugin=2024.2.0=he02047a_1 125 | - libopenvino-ir-frontend=2024.2.0=h5c03a75_1 126 | - libopenvino-onnx-frontend=2024.2.0=h07e8aee_1 127 | - libopenvino-paddle-frontend=2024.2.0=h07e8aee_1 
128 | - libopenvino-pytorch-frontend=2024.2.0=he02047a_1 129 | - libopenvino-tensorflow-frontend=2024.2.0=h39126c6_1 130 | - libopenvino-tensorflow-lite-frontend=2024.2.0=he02047a_1 131 | - libopus=1.3.1=h7f98852_1 132 | - libpciaccess=0.18=hd590300_0 133 | - libpng=1.6.43=h2797004_0 134 | - libpq=16.3=ha72fbe1_0 135 | - libprotobuf=4.25.3=h08a7969_0 136 | - libsndfile=1.2.2=hc60ed4a_1 137 | - libsqlite=3.46.0=hde9e2c9_0 138 | - libssh2=1.11.0=h0841786_0 139 | - libstdcxx-ng=14.1.0=hc0a3c3a_0 140 | - libsystemd0=255=h3516f8a_1 141 | - libtasn1=4.19.0=h166bdaf_0 142 | - libtiff=4.6.0=h1dd3fc0_3 143 | - libunistring=0.9.10=h7f98852_0 144 | - libuuid=2.38.1=h0b41bf4_0 145 | - libva=2.22.0=hb711507_0 146 | - libvorbis=1.3.7=h9c3ff4c_0 147 | - libvpx=1.14.1=hac33072_0 148 | - libwebp-base=1.4.0=hd590300_0 149 | - libxcb=1.16=hd590300_0 150 | - libxcrypt=4.4.36=hd590300_1 151 | - libxkbcommon=1.7.0=h2c5496b_1 152 | - libxml2=2.12.7=h4c95cb1_3 153 | - libzlib=1.3.1=h4ab18f5_1 154 | - libzopfli=1.0.3=h9c3ff4c_0 155 | - lz4-c=1.9.4=hcb278e6_0 156 | - matplotlib=3.9.1=py311h38be061_0 157 | - matplotlib-base=3.9.1=py311hffb96ce_0 158 | - mpg123=1.32.6=h59595ed_0 159 | - munkres=1.1.4=pyh9f0ad1d_0 160 | - mysql-common=8.3.0=hf1915f5_4 161 | - mysql-libs=8.3.0=hca2cd23_4 162 | - ncurses=6.5=h59595ed_0 163 | - nettle=3.9.1=h7ab15ed_0 164 | - networkx=3.3=pyhd8ed1ab_1 165 | - nspr=4.35=h27087fc_0 166 | - nss=3.102=h593d115_0 167 | - numpy=2.0.0=py311h1461c94_0 168 | - ocl-icd=2.3.2=hd590300_1 169 | - opencv=4.10.0=qt6_py311hc414901_601 170 | - openexr=3.2.2=haf962dd_1 171 | - openh264=2.4.1=h59595ed_0 172 | - openjpeg=2.5.2=h488ebb8_0 173 | - openssl=3.3.1=h4ab18f5_1 174 | - p11-kit=0.24.1=hc5aa10d_0 175 | - packaging=24.1=pyhd8ed1ab_0 176 | - pcre2=10.43=hcad00b1_0 177 | - pillow=10.4.0=py311h82a398c_0 178 | - pip=24.0=pyhd8ed1ab_0 179 | - pixman=0.43.2=h59595ed_0 180 | - ply=3.11=pyhd8ed1ab_2 181 | - pthread-stubs=0.4=h36c2ea0_1001 182 | - pugixml=1.14=h59595ed_0 183 | - pulseaudio-client=17.0=hb77b528_0 184 | - py-opencv=4.10.0=qt6_py311h074fb97_601 185 | - pyparsing=3.1.2=pyhd8ed1ab_0 186 | - pyqt=5.15.9=py311hf0fb5b6_5 187 | - pyqt5-sip=12.12.2=py311hb755f60_5 188 | - python=3.11.9=hb806964_0_cpython 189 | - python-dateutil=2.9.0=pyhd8ed1ab_0 190 | - python_abi=3.11=4_cp311 191 | - pywavelets=1.6.0=py311h18e1886_0 192 | - qhull=2020.2=h434a139_5 193 | - qt-main=5.15.8=ha2b5568_22 194 | - qt6-main=6.7.2=h0f8cd61_2 195 | - rav1e=0.6.6=he8a937b_2 196 | - readline=8.2=h8228510_1 197 | - scikit-image=0.24.0=py311h14de704_1 198 | - scipy=1.14.0=py311h517d4fd_1 199 | - setuptools=70.3.0=pyhd8ed1ab_0 200 | - sip=6.7.12=py311hb755f60_0 201 | - six=1.16.0=pyh6c4a22f_0 202 | - snappy=1.2.1=ha2e4443_0 203 | - svt-av1=2.1.0=hac33072_0 204 | - tbb=2021.12.0=h434a139_3 205 | - tifffile=2024.7.2=pyhd8ed1ab_0 206 | - tk=8.6.13=noxft_h4845f30_101 207 | - toml=0.10.2=pyhd8ed1ab_0 208 | - tomli=2.0.1=pyhd8ed1ab_0 209 | - tornado=6.4.1=py311h331c9d8_0 210 | - tqdm=4.66.4=pyhd8ed1ab_0 211 | - tzdata=2024a=h0c530f3_0 212 | - wayland=1.23.0=h5291e77_0 213 | - wayland-protocols=1.36=hd8ed1ab_0 214 | - wheel=0.43.0=pyhd8ed1ab_1 215 | - x264=1!164.3095=h166bdaf_2 216 | - x265=3.5=h924138e_3 217 | - xcb-util=0.4.1=hb711507_2 218 | - xcb-util-cursor=0.1.4=h4ab18f5_2 219 | - xcb-util-image=0.4.0=hb711507_2 220 | - xcb-util-keysyms=0.4.1=hb711507_0 221 | - xcb-util-renderutil=0.3.10=hb711507_0 222 | - xcb-util-wm=0.4.2=hb711507_0 223 | - xkeyboard-config=2.42=h4ab18f5_0 224 | - xorg-fixesproto=5.0=h7f98852_1002 225 | - 
xorg-inputproto=2.3.2=h7f98852_1002 226 | - xorg-kbproto=1.0.7=h7f98852_1002 227 | - xorg-libice=1.1.1=hd590300_0 228 | - xorg-libsm=1.2.4=h7391055_0 229 | - xorg-libx11=1.8.9=hb711507_1 230 | - xorg-libxau=1.0.11=hd590300_0 231 | - xorg-libxdmcp=1.1.3=h7f98852_0 232 | - xorg-libxext=1.3.4=h0b41bf4_2 233 | - xorg-libxfixes=5.0.3=h7f98852_1004 234 | - xorg-libxi=1.7.10=h7f98852_0 235 | - xorg-libxrender=0.9.11=hd590300_0 236 | - xorg-renderproto=0.11.1=h7f98852_1002 237 | - xorg-xextproto=7.3.0=h0b41bf4_1003 238 | - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002 239 | - xorg-xproto=7.0.31=h7f98852_1007 240 | - xz=5.2.6=h166bdaf_0 241 | - zfp=1.0.1=hac33072_1 242 | - zipp=3.19.2=pyhd8ed1ab_0 243 | - zlib=1.3.1=h4ab18f5_1 244 | - zlib-ng=2.2.1=he02047a_0 245 | - zstd=1.5.6=ha6fb4c9_0 246 | - pip: 247 | - charset-normalizer==3.3.2 248 | - click==8.1.7 249 | - docker-pycreds==0.4.0 250 | - filelock==3.13.1 251 | - fsspec==2024.2.0 252 | - gitdb==4.0.11 253 | - gitpython==3.1.43 254 | - idna==3.7 255 | - jinja2==3.1.3 256 | - markupsafe==2.1.5 257 | - mpmath==1.3.0 258 | - nvidia-cublas-cu11==11.11.3.6 259 | - nvidia-cuda-cupti-cu11==11.8.87 260 | - nvidia-cuda-nvrtc-cu11==11.8.89 261 | - nvidia-cuda-runtime-cu11==11.8.89 262 | - nvidia-cudnn-cu11==8.7.0.84 263 | - nvidia-cufft-cu11==10.9.0.58 264 | - nvidia-curand-cu11==10.3.0.86 265 | - nvidia-cusolver-cu11==11.4.1.48 266 | - nvidia-cusparse-cu11==11.7.5.86 267 | - nvidia-nccl-cu11==2.20.5 268 | - nvidia-nvtx-cu11==11.8.86 269 | - platformdirs==4.2.2 270 | - protobuf==5.27.2 271 | - psutil==6.0.0 272 | - pyyaml==6.0.1 273 | - requests==2.32.3 274 | - sentry-sdk==2.10.0 275 | - setproctitle==1.3.3 276 | - smmap==5.0.1 277 | - sympy==1.12 278 | - torch==2.3.1+cu118 279 | - torchaudio==2.3.1+cu118 280 | - torchvision==0.18.1+cu118 281 | - triton==2.3.1 282 | - typing-extensions==4.9.0 283 | - urllib3==2.2.2 284 | - wandb==0.17.4 285 | --------------------------------------------------------------------------------