├── LICENSE.md
├── README.md
├── checkpoints
│   └── __init__.py
├── conda_env.txt
├── data
│   └── __init__.py
├── dataloader
│   ├── dataloader.py
│   └── transforms.py
├── debug.sh
├── filenames
│   ├── 4DLFB_test.txt
│   ├── 4DLFB_train.txt
│   └── 4DLFB_val.txt
├── model.py
├── nets
│   ├── MaskToFNet.py
│   ├── pretrained
│   │   └── __init__.py
│   └── refinement
│       ├── Refinement.py
│       ├── RefinementBig.py
│       ├── RefinementGlobal.py
│       └── RefinementHalf.py
├── reconstruction.ipynb
├── requirements.txt
├── train.py
└── utils
    ├── barcode_masks
    │   └── __init__.py
    ├── chamfer_distance
    │   ├── __init__.py
    │   ├── chamfer_distance.cpp
    │   ├── chamfer_distance.cu
    │   └── chamfer_distance.py
    ├── file_io.py
    ├── tof.py
    └── utils.py
/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ilya Chugunov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mask-ToF 2 | 3 | This is the code for the CVPR 2021 work: [Mask-ToF: Learning Microlens Masks for Flying Pixel Correction in Time-of-Flight Imaging](https://light.princeton.edu/publication/mask-tof/) 4 | 5 | If you use bits and bobs of this code, or find inspiration from it, consider citing the paper: 6 | 7 | ``` 8 | @article{chugunov2021masktof, 9 | title={Mask-ToF: Learning Microlens Masks for Flying Pixel Correction in Time-of-Flight Imaging}, 10 | author={Chugunov, Ilya and Baek, Seung-Hwan and Fu, Qiang and Heidrich, Wolfgang and Heide, Felix}, 11 | journal={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 12 | year={2021} 13 | } 14 | ``` 15 | ### Requirements: 16 | * Developed using PyTorch 1.7.0 on an Ubuntu x64 machine 17 | * Condensed requirements in `/requirements.txt` 18 | * Full frozen environment can be found in `/conda_env.txt`, but many of these libraries are not necessary to run this code 19 | 20 | ### Data: 21 | * Download `test` and `additional` from https://lightfield-analysis.uni-konstanz.de/ 22 | * Place samples directly in `/data` 23 | * Confirm that the .txts in `/filenames` match the data structure (update these if you add/reorganize data) 24 | * *Optional:* Download pretrained network checkpoints from [this drive link](https://drive.google.com/file/d/1y6jvOHeZ0483NNbW1-Kks5GZqluS2mZA/view?usp=sharing) 25 | * Extract `pretrained.zip` into `/nets/pretrained` 26 | * *Optional:* Download 'barcode' masks from [this drive link](https://drive.google.com/file/d/1-aYWfIilACarQkAqw6GFSHDw4KpYs-Ry/view?usp=sharing) 27 | * Extract `barcode_masks.zip` into `/utils/barcode_masks` 28 | 29 | ### Project Structure: 30 | ```cpp 31 | MaskToF 32 | ├── checkpoints 33 | │   └── // folder for network checkpoints 34 | ├── data 35 | │   └── // folder for training/test data 36 | ├── dataloader 37 | │   ├── dataloader.py // pytorch dataloader for lightfields + depth 38 | │   └── transforms.py // data augmentations and code to generate image/depth patches 39 | ├── filenames 40 | │   └── // .txt files pointing to data locations 41 | ├── model.py // wrapper class for training the network: 42 | │ // -> load data, calculate loss, print to tensorboard, save network state 43 | ├── nets 44 | │   ├── MaskToFNet.py // the meat of MaskToF, class for data simulation and learning a mask: 45 | │ │ // -> simulate measurements, pass to refinement network, return depth 46 | │ ├── pretrained 47 | │ │   └── // folder for pretrained networks 48 | │   └── refinement 49 | │   └── // refinement network architectures 50 | ├── train.py // wrapper class for arg parsing and setting up training loop 51 | └── utils 52 | ├── chamfer_distance 53 | │   └── // pytorch implementation of chamfer distance metric 54 | ├── file_io.py // utils for loading light field data 55 | ├── tof.py // utils for simulating time-of-flight measurements 56 | └── utils.py // miscellaneous helper functions (e.g. saving network state) 57 | ``` 58 | ### Training: 59 | * Should be as simple as running the debug example with `bash debug.sh`, if all the prerequisite libraries play nice 60 | * Outputs will be saved to `checkpoint_dir=checkpoints/debug/` 61 | * `--init mask_pattern_name` sets the initial mask iterate, in this case `gaussian_circles1.5,0.75`, which is the Gaussian Circle pattern with mean `1.5` and standard deviation `0.75` (see the short parsing sketch after this README) 62 | * `--use_net` is a flag to jointly train a refinement/reconstruction network, as outlined in the paper 63 | * Additional arguments and descriptions can be found at the top of `/train.py` 64 | 65 | ### Reconstruction: 66 | The notebook `reconstruction.ipynb` includes an interactive demo for loading a network from a checkpoint folder, visualizing mask structure, simulating amplitude measurements, and performing depth reconstruction. 67 | 68 | ### Experimental Setup: 69 | If you're building your own experimental prototype, you can reach out to me at `chugunov[at]princeton[dot]edu` for information and advice. 70 | 71 | --- 72 | 73 | Best of luck, 74 | Ilya 75 | --------------------------------------------------------------------------------
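For reference, here is how an init string like `gaussian_circles1.5,0.75` is interpreted. This is a minimal sketch that mirrors the parsing logic in `nets/MaskToFNet.py` (the `AmplitudeMask` constructor); the standalone variable names are illustrative:

```python
# Sketch of --init string parsing, mirroring nets/MaskToFNet.py (AmplitudeMask).
# "gaussian_circlesM,S" -> Gaussian Circle pattern with mean M and std S.
init = "gaussian_circles1.5,0.75".lower().replace("gaussian_circles", "")
if "," in init:
    mean, sigma = [float(el) for el in init.split(",")]  # -> (1.5, 0.75)
else:
    mean, sigma = 1.5, 1  # defaults used when no parameters are given
```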
/checkpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/MaskToF/d34c453631e97424a5dff20f155c509091a75026/checkpoints/__init__.py -------------------------------------------------------------------------------- /conda_env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | absl-py=0.11.0=pypi_0 6 | addict=2.4.0=pypi_0 7 | anyio=2.1.0=pypi_0 8 | argon2-cffi=20.1.0=py38h27cfd23_1 9 | array2gif=1.0.4=pypi_0 10 | async_generator=1.10=pyhd3eb1b0_0 11 | attrs=20.3.0=pyhd3eb1b0_0 12 | babel=2.9.0=pypi_0 13 | backcall=0.2.0=pyhd3eb1b0_0 14 | blas=1.0=mkl 15 | bleach=3.3.0=pyhd3eb1b0_0 16 | ca-certificates=2021.1.19=h06a4308_0 17 | cachetools=4.2.1=pypi_0 18 | certifi=2020.12.5=py38h06a4308_0 19 | cffi=1.14.4=pypi_0 20 | chardet=4.0.0=pypi_0 21 | cudatoolkit=10.2.89=hfd86e86_1 22 | cycler=0.10.0=pypi_0 23 | dataclasses=0.6=pypi_0 24 | decorator=4.4.2=pyhd3eb1b0_0 25 | defusedxml=0.6.0=pyhd3eb1b0_0 26 | entrypoints=0.3=pypi_0 27 | freetype=2.10.4=h5ab3b9f_0 28 | future=0.18.2=pypi_0 29 | fvcore=0.1.3.post20210204=pypi_0 30 | google-auth=1.24.0=pypi_0 31 | google-auth-oauthlib=0.4.2=pypi_0 32 | grpcio=1.35.0=pypi_0 33 | idna=2.10=pypi_0 34 | imageio=2.9.0=pypi_0 35 | importlib-metadata=2.0.0=py_1 36 | importlib_metadata=2.0.0=1 37 | intel-openmp=2020.2=254 38 | iopath=0.1.3=pypi_0 39 | ipykernel=5.4.3=pypi_0 40 | ipython=7.20.0=pypi_0 41 | ipython_genutils=0.2.0=pyhd3eb1b0_1 42 | ipywidgets=7.6.3=pyhd3eb1b0_1 43 | jedi=0.18.0=pypi_0 44 | jinja2=2.11.3=pyhd3eb1b0_0 45 | joblib=1.0.1=pypi_0 46 | jpeg=9b=h024ee3a_2 47 | json5=0.9.5=pypi_0 48 | jsonschema=3.2.0=py_2 49 | jupyter-client=6.1.11=pypi_0 50 | jupyter-server=1.3.0=pypi_0 51 | jupyter_client=6.1.7=py_0 52 | jupyter_core=4.7.1=py38h06a4308_0 53 | jupyterlab=3.0.7=pypi_0 54 | jupyterlab-server=2.2.0=pypi_0 55 | jupyterlab_pygments=0.1.2=py_0 56 | jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 57 | kiwisolver=1.3.1=pypi_0 58 | lcms2=2.11=h396b838_0 59 | ld_impl_linux-64=2.33.1=h53a641e_7 60 | libedit=3.1.20191231=h14c3975_1 61 | libffi=3.3=he6710b0_2 62 | libgcc-ng=9.1.0=hdf63c60_0 63 |
libpng=1.6.37=hbc83047_0 64 | libsodium=1.0.18=h7b6447c_0 65 | libstdcxx-ng=9.1.0=hdf63c60_0 66 | libtiff=4.1.0=h2733197_1 67 | libuv=1.40.0=h7b6447c_0 68 | lz4-c=1.9.3=h2531618_0 69 | markdown=3.3.3=pypi_0 70 | markupsafe=1.1.1=py38h7b6447c_0 71 | matplotlib=3.3.4=pypi_0 72 | mistune=0.8.4=py38h7b6447c_1000 73 | mkl=2020.2=256 74 | mkl-service=2.3.0=py38he904b0f_0 75 | mkl_fft=1.2.0=py38h23d657b_0 76 | mkl_random=1.1.1=py38h0573a6f_0 77 | nbclassic=0.2.6=pypi_0 78 | nbclient=0.5.1=pypi_0 79 | nbconvert=6.0.7=py38_0 80 | nbformat=5.1.2=pyhd3eb1b0_1 81 | ncurses=6.2=he6710b0_1 82 | nest-asyncio=1.5.1=pyhd3eb1b0_0 83 | ninja=1.10.0.post2=pypi_0 84 | notebook=6.2.0=py38h06a4308_0 85 | numpy=1.19.2=py38h54aff64_0 86 | numpy-base=1.19.2=py38hfa32c7d_0 87 | nvidiacub=1.10.0=0 88 | oauthlib=3.1.0=pypi_0 89 | olefile=0.46=py_0 90 | opencv-python=4.5.1.48=pypi_0 91 | openssl=1.1.1j=h27cfd23_0 92 | packaging=20.9=pyhd3eb1b0_0 93 | pandas=1.2.3=pypi_0 94 | pandoc=2.11=hb0f4dca_0 95 | pandocfilters=1.4.3=py38h06a4308_1 96 | parso=0.8.1=pyhd3eb1b0_0 97 | pexpect=4.8.0=pyhd3eb1b0_3 98 | pickleshare=0.7.5=pyhd3eb1b0_1003 99 | pillow=8.1.0=py38he98fc37_0 100 | pip=20.3.3=py38h06a4308_0 101 | plyfile=0.7.3=pypi_0 102 | portalocker=2.2.1=pypi_0 103 | prometheus_client=0.9.0=pyhd3eb1b0_0 104 | prompt-toolkit=3.0.14=pypi_0 105 | protobuf=3.14.0=pypi_0 106 | ptyprocess=0.7.0=pyhd3eb1b0_2 107 | pyasn1=0.4.8=pypi_0 108 | pyasn1-modules=0.2.8=pypi_0 109 | pycparser=2.20=py_2 110 | pygments=2.7.4=pypi_0 111 | pyntcloud=0.1.4=pypi_0 112 | pyparsing=2.4.7=pyhd3eb1b0_0 113 | pyrsistent=0.17.3=py38h7b6447c_0 114 | python=3.8.3=hcff3b4d_2 115 | python-dateutil=2.8.1=pyhd3eb1b0_0 116 | pytorch=1.7.0=py3.8_cuda10.2.89_cudnn7.6.5_0 117 | pytorch3d=0.3.0=pypi_0 118 | pytz=2021.1=pypi_0 119 | pyyaml=5.4.1=pypi_0 120 | pyzmq=22.0.2=pypi_0 121 | readline=8.1=h27cfd23_0 122 | requests=2.25.1=pypi_0 123 | requests-oauthlib=1.3.0=pypi_0 124 | rsa=4.7=pypi_0 125 | scikit-learn=0.24.1=pypi_0 126 | scipy=1.6.0=pypi_0 127 | send2trash=1.5.0=pyhd3eb1b0_1 128 | setuptools=52.0.0=py38h06a4308_0 129 | six=1.15.0=py38h06a4308_0 130 | sklearn=0.0=pypi_0 131 | sniffio=1.2.0=pypi_0 132 | sqlite=3.33.0=h62c20be_0 133 | tabulate=0.8.7=pypi_0 134 | tensorboard=2.4.1=pypi_0 135 | tensorboard-plugin-wit=1.8.0=pypi_0 136 | termcolor=1.1.0=pypi_0 137 | terminado=0.9.2=pypi_0 138 | testpath=0.4.4=pyhd3eb1b0_0 139 | threadpoolctl=2.1.0=pypi_0 140 | tk=8.6.10=hbc83047_0 141 | torchvision=0.8.1=py38_cu102 142 | tornado=6.1=py38h27cfd23_0 143 | tqdm=4.56.0=pypi_0 144 | traitlets=5.0.5=pyhd3eb1b0_0 145 | typing_extensions=3.7.4.3=pyh06a4308_0 146 | urllib3=1.26.3=pypi_0 147 | wcwidth=0.2.5=py_0 148 | webencodings=0.5.1=pypi_0 149 | werkzeug=1.0.1=pypi_0 150 | wheel=0.36.2=pyhd3eb1b0_0 151 | widgetsnbextension=3.5.1=py38_0 152 | xz=5.2.5=h7b6447c_0 153 | yacs=0.1.8=pypi_0 154 | zeromq=4.3.3=he6710b0_3 155 | zipp=3.4.0=pyhd3eb1b0_0 156 | zlib=1.2.11=h7b6447c_3 157 | zstd=1.4.5=h9ceee32_0 158 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/MaskToF/d34c453631e97424a5dff20f155c509091a75026/data/__init__.py -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import 
numpy as np 4 | from utils.file_io import read_lightfield, read_parameters, read_depth, read_disparity, read_depth_all_view 5 | from utils.utils import read_text_lines 6 | 7 | class LightFieldDataset(Dataset): 8 | def __init__(self, 9 | dataset_name='4DLFB', 10 | mode='train', 11 | transform=None, 12 | augmentation=None): 13 | super(LightFieldDataset, self).__init__() 14 | 15 | self.dataset_name = dataset_name 16 | self.mode = mode 17 | self.transform = transform 18 | self.augmentation = augmentation 19 | 20 | synthetic_dict = { 21 | 'train': 'filenames/4DLFB_train.txt', 22 | 'val': 'filenames/4DLFB_val.txt', 23 | 'test': 'filenames/4DLFB_test.txt' 24 | } 25 | 26 | dataset_name_dict = { 27 | '4DLFB': synthetic_dict 28 | } 29 | 30 | assert dataset_name in dataset_name_dict.keys() 31 | self.dataset_name = dataset_name 32 | self.samples = [] 33 | data_filename = dataset_name_dict[dataset_name][mode] 34 | lines = read_text_lines(data_filename) 35 | 36 | for line in lines: 37 | splits = line.split() 38 | data_folder = splits[0] 39 | sample = dict() 40 | sample["data_folder"] = data_folder 41 | sample["lightfield"] = read_lightfield(data_folder)/255. # convert to 0-1 range 42 | sample["parameters"] = read_parameters(data_folder) 43 | sample["depth"] = read_depth_all_view(data_folder, N=81)*1000. # convert to mm 44 | sample["depth_gt"] = read_depth(data_folder)*1000. # convert to mm 45 | 46 | scale = min(1000/np.percentile(sample["depth_gt"], 99.9), 1) # 1000mm max depth 47 | sample["depth"] = scale * sample["depth"] 48 | sample["depth_gt"] = scale * sample["depth_gt"] 49 | 50 | if self.transform is not None: 51 | sample = self.transform(sample) 52 | 53 | self.samples.append(sample) 54 | 55 | def __getitem__(self, index): 56 | sample = self.samples[index] 57 | 58 | if self.augmentation is not None: 59 | sample = self.augmentation(sample) 60 | 61 | return sample 62 | 63 | def __len__(self): 64 | return len(self.samples) -------------------------------------------------------------------------------- /dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # from PIL import Image 4 | import random 5 | 6 | 7 | class Compose(object): 8 | def __init__(self, transforms): 9 | self.transforms = transforms 10 | 11 | def __call__(self, sample): 12 | for t in self.transforms: 13 | sample = t(sample) 14 | return sample 15 | 16 | 17 | class ToTensor(object): 18 | """Convert numpy array to torch tensor""" 19 | 20 | def __call__(self, sample): 21 | sample["lightfield"] = torch.from_numpy(sample["lightfield"]) # [9, 9, H, W] 22 | J,K,H,W = sample["lightfield"].shape 23 | sample["lightfield"] = sample["lightfield"].reshape(J*K,H,W) 24 | 25 | if "depth_gt" in sample.keys(): 26 | sample["depth_gt"] = torch.from_numpy(sample["depth_gt"]) 27 | sample["depth_gt"] = sample["depth_gt"].unsqueeze(0) 28 | 29 | if "depth" in sample.keys(): 30 | sample["depth"] = torch.from_numpy(sample["depth"]) 31 | 32 | 33 | return sample 34 | 35 | class RGBtoGray(object): 36 | """Convert lightfield array to grayscale""" 37 | 38 | def __call__(self, sample, R_weight=0.2125, G_weight=0.7154, B_weight=0.0721): 39 | sample["lightfield"] = R_weight*sample["lightfield"][...,0] + \ 40 | G_weight*sample["lightfield"][...,1] + \ 41 | B_weight*sample["lightfield"][...,2] 42 | return sample 43 | 44 | class RGBtoNIR(object): 45 | """Convert lightfield array to near infra-red""" 46 | 47 | def __call__(self, sample): 48 | interm = 
np.maximum(sample["lightfield"], 1-sample["lightfield"])[...,::-1] 49 | nir = (interm[..., 0]*0.229 + interm[..., 1]*0.587 + interm[..., 2]*0.114)**(1/0.25) 50 | sample["lightfield"] = nir 51 | return sample 52 | 53 | class ToRandomPatches(object): 54 | """Convert full image tensors to random patches""" 55 | 56 | def __init__(self, num_patches, patch_width, patch_height, random_rotation=True, random_flip=True): 57 | self.num_patches = num_patches 58 | self.patch_width = patch_width 59 | self.patch_height = patch_height 60 | self.random_rotation = random_rotation # if true apply 0-3 90 degree rotations 61 | self.random_flip = random_flip # if true apply random vertical flip 62 | 63 | def __call__(self, sample): 64 | lightfield = sample["lightfield"] 65 | depth = sample["depth"] 66 | depth_gt = sample["depth_gt"] 67 | 68 | C, H, W = lightfield.shape 69 | patch_x_coords = torch.randint(0, W - self.patch_width, (1,self.num_patches)).squeeze() 70 | patch_y_coords = torch.randint(0, H - self.patch_height, (1,self.num_patches)).squeeze() 71 | 72 | lightfield_patches = [] 73 | depth_patches = [] 74 | depth_gt_patches = [] 75 | 76 | for i in range(self.num_patches): 77 | x1, x2 = patch_x_coords[i], patch_x_coords[i] + self.patch_width 78 | y1, y2 = patch_y_coords[i], patch_y_coords[i] + self.patch_height 79 | lightfield_patch = lightfield[:,y1:y2,x1:x2] 80 | depth_patch = depth[:,y1:y2,x1:x2] 81 | depth_gt_patch = depth_gt[:,y1:y2,x1:x2] 82 | if self.random_rotation: 83 | rot = np.random.randint(0,4) 84 | lightfield_patch = torch.rot90(lightfield_patch, rot, dims=(1,2)) 85 | depth_patch = torch.rot90(depth_patch, rot, dims=(1,2)) 86 | depth_gt_patch = torch.rot90(depth_gt_patch, rot, dims=(1,2)) 87 | if self.random_flip: 88 | if np.random.randint(0,2) == 0: 89 | lightfield_patch = lightfield_patch.flip(1) 90 | depth_patch = depth_patch.flip(1) 91 | depth_gt_patch = depth_gt_patch.flip(1) 92 | 93 | lightfield_patches.append(lightfield_patch) 94 | depth_patches.append(depth_patch) 95 | depth_gt_patches.append(depth_gt_patch) 96 | 97 | sample["lightfield_patches"] = torch.stack(lightfield_patches) 98 | sample["depth_patches"] = torch.stack(depth_patches) 99 | sample["depth_gt_patches"] = torch.stack(depth_gt_patches) 100 | 101 | return sample -------------------------------------------------------------------------------- /debug.sh: -------------------------------------------------------------------------------- 1 | # example training script, put argparse arguments here 2 | 3 | checkpoint_dir=checkpoints/debug/ 4 | 5 | rm -r ${checkpoint_dir} # remove this line if you don't want to wipe the folder 6 | python3 train.py \ 7 | --checkpoint_dir ${checkpoint_dir} \ 8 | --mode train \ 9 | --init gaussian_circles1.5,0.75 \ 10 | --use_net \ 11 | --batch_size 2 \ 12 | -------------------------------------------------------------------------------- /filenames/4DLFB_test.txt: -------------------------------------------------------------------------------- 1 | data/bicycle 2 | data/herbs 3 | data/origami 4 | data/bedroom -------------------------------------------------------------------------------- /filenames/4DLFB_train.txt: -------------------------------------------------------------------------------- 1 | data/boardgames 2 | data/kitchen 3 | data/vinyl 4 | data/dishes 5 | data/table 6 | data/antinous 7 | data/town 8 | data/platonic 9 | data/museum 10 | data/tomb 11 | data/medieval2 12 | data/rosemary -------------------------------------------------------------------------------- /filenames/4DLFB_val.txt:
-------------------------------------------------------------------------------- 1 | data/greek 2 | data/pillows 3 | data/tower 4 | data/pens -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from torch.utils.tensorboard import SummaryWriter 4 | from utils import utils 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import os 9 | from utils.chamfer_distance import ChamferDistance 10 | 11 | class Model(object): 12 | def __init__(self, args, optimizer, net, device, start_iter=0, start_epoch=0, views_x=9, views_y=9, img_height=512, img_width=512): 13 | self.args = args 14 | self.optimizer = optimizer 15 | self.net = net 16 | self.device = device 17 | self.num_iter = start_iter 18 | self.epoch = start_epoch 19 | self.train_writer = SummaryWriter(self.args.checkpoint_dir) 20 | self.views_x, self.views_y = views_x, views_y 21 | self.img_height, self.img_width = img_height, img_width 22 | self.l1_loss = torch.nn.SmoothL1Loss() 23 | self.l2_loss = torch.nn.MSELoss() 24 | self.chamfer_loss = ChamferDistance() 25 | 26 | # save initial mask 27 | utils.check_path(os.path.join(args.checkpoint_dir, "masks/")) 28 | if self.epoch == 0 and self.epoch in args.mask_checkpoints: 29 | torch.save(self.net.amplitude_mask.mask, os.path.join(args.checkpoint_dir, "masks/mask_epoch_{0}.pt".format(self.epoch))) 30 | 31 | def train(self, train_loader): 32 | 33 | args = self.args 34 | steps_per_epoch = len(train_loader) 35 | device = self.device 36 | self.net.train() # init Module 37 | 38 | # Learning rate summary 39 | lr = self.optimizer.param_groups[0]["lr"] 40 | self.train_writer.add_scalar("base_lr", lr, self.epoch + 1) 41 | 42 | if self.epoch == self.args.mask_start_epoch: 43 | # Start gradient 44 | self.net.amplitude_mask.mask.requires_grad = True 45 | 46 | last_print_time = time.time() 47 | 48 | 49 | for i, sample in enumerate(train_loader): 50 | # increase softmax gamma 51 | softmax_weight = (1 + (args.softmax_gamma*self.num_iter)**2) 52 | self.net.amplitude_mask.softmax_weight = softmax_weight 53 | 54 | lightfield = sample["lightfield_patches"].to(device) # [B, num_patches, 81, H, W] 55 | parameters = sample["parameters"] 56 | depth = sample["depth_patches"].to(device) # [B, num_patches, 81, H, W] 57 | depth_gt = sample["depth_gt_patches"].to(device) # [B, num_patches 1, H, W] 58 | 59 | # reshape patches -> batches * patches 60 | lightfield = lightfield.view(-1, *lightfield.shape[2:]) 61 | depth = depth.view(-1, *depth.shape[2:]) # [B * num_patches, 81, H, W] 62 | depth_gt = depth_gt.view(-1, *depth_gt.shape[2:]) 63 | 64 | depth_pred = self.net(lightfield, depth, args, parameters, patch=True) 65 | pcd_pred, pcd_gt = utils.tensor_to_pcd(depth_pred), utils.tensor_to_pcd(depth_gt) 66 | 67 | l1_loss = args.l1_weight*self.l1_loss(depth_pred, depth_gt) 68 | dist1, dist2 = self.chamfer_loss(pcd_pred, pcd_gt) 69 | chamfer_loss = args.chamfer_weight*(torch.mean(dist1) + torch.mean(dist2)) 70 | 71 | if args.refinement == "RefinementUnetResnet": # tofnet 72 | B,C,H,W = depth_pred.shape 73 | tv_h = torch.pow(depth_pred[:,:,1:,:]-depth_pred[:,:,:-1,:], 2).sum() 74 | tv_w = torch.pow(depth_pred[:,:,:,1:]-depth_pred[:,:,:,:-1], 2).sum() 75 | tv_loss = (tv_h + tv_w)/(B*C*H*W) 76 | loss = F.l1_loss(depth_pred, depth_gt) + 0.0001*tv_loss 77 | 78 | else: 79 | if chamfer_loss != chamfer_loss or self.epoch < 
self.args.mask_start_epoch: # if is nan 80 | loss = l1_loss 81 | else: 82 | loss = l1_loss + chamfer_loss 83 | # print(depth.mean(), depth_pred.mean()) 84 | 85 | self.optimizer.zero_grad() 86 | loss.backward() 87 | self.optimizer.step() 88 | 89 | # print log 90 | if self.num_iter % args.print_freq == 0: 91 | this_cycle = time.time() - last_print_time 92 | last_print_time += this_cycle 93 | grad = self.net.amplitude_mask.mask.grad.norm().item() if self.net.amplitude_mask.mask.grad is not None else 0 94 | print("Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs l1 loss: %.3f chamfer loss: %.3f grad norm: %.8f" % 95 | (self.epoch + 1, args.max_epoch, i + 1, 96 | steps_per_epoch, this_cycle, l1_loss.item(), chamfer_loss.item(), 97 | grad)) 98 | 99 | if self.num_iter % args.summary_freq == 0: 100 | img_summary = dict() 101 | img_summary["lightfield_0"] = lightfield[:,0,:,:] 102 | img_summary["depth_gt"] = depth_gt 103 | img_summary["depth_pred"] = depth_pred 104 | img_summary["absolute_error"] = torch.abs(depth_pred - depth_gt) 105 | mask = utils.combine_masks(self.net.amplitude_mask.get_mask()) 106 | img_summary["mask_center_256"] = mask[None,None,2176:2432,2176:2432] # [1, 1, H, W] 107 | img_summary["mask_center_512"] = mask[None,None,2048:2560,2048:2560] # [1, 1, H, W] 108 | 109 | utils.save_images(self.train_writer, "train", img_summary, self.num_iter) 110 | 111 | img_summary = dict() 112 | 113 | self.train_writer.add_scalar("train/loss", loss.item(), self.num_iter) 114 | self.train_writer.add_scalar("train/l1_loss", l1_loss.item(), self.num_iter) 115 | self.train_writer.add_scalar("train/chamfer_loss", chamfer_loss.item(), self.num_iter) 116 | # self.train_writer.add_scalar("train/softmax_weight", softmax_weight, self.num_iter) 117 | 118 | self.num_iter += 1 119 | 120 | self.epoch += 1 121 | 122 | # save mask if checkpoint 123 | if self.epoch in args.mask_checkpoints: 124 | torch.save(self.net.amplitude_mask.mask, os.path.join(args.checkpoint_dir, "masks/mask_epoch_{0}.pt".format(self.epoch))) 125 | # Always save the latest model for resuming training 126 | if args.no_validate: 127 | torch.save(self.net, os.path.join(args.checkpoint_dir, "full_net_latest.pt")) 128 | utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.net, 129 | epoch=self.epoch, num_iter=self.num_iter, 130 | loss=-1, mask=self.net.amplitude_mask.mask, filename="net_latest.pt") 131 | 132 | # Save checkpoint of specific epoch 133 | if self.epoch % args.save_ckpt_freq == 0: 134 | model_dir = os.path.join(args.checkpoint_dir, "models") 135 | utils.check_path(model_dir) 136 | utils.save_checkpoint(model_dir, self.optimizer, self.net, 137 | epoch=self.epoch, num_iter=self.num_iter, 138 | loss=-1, mask=self.net.amplitude_mask.mask, save_optimizer=False) 139 | 140 | def validate(self, val_loader): 141 | args = self.args 142 | device = self.device 143 | print("=> Start validation...") 144 | 145 | self.net.eval() 146 | 147 | num_samples = len(val_loader) 148 | print("=> %d samples found in the validation set" % num_samples) 149 | 150 | val_file = os.path.join(args.checkpoint_dir, "val_results.txt") 151 | val_loss_chamfer = 0 152 | val_loss_l1 = 0 153 | valid_samples = 0 154 | 155 | for i, sample in enumerate(val_loader): 156 | lightfield = sample["lightfield"].to(device) # [B, 81, H, W] 157 | parameters = sample["parameters"] 158 | depth = sample["depth"].to(device) # [B, 81, H, W] 159 | depth_gt = sample["depth_gt"].to(device) # [B, 1, H, W] 160 | 161 | valid_samples += 1 162 | 163 | with torch.no_grad(): 164 | depth_pred = self.net(lightfield, depth, args, parameters, patch=False) 165 | pcd_pred, pcd_gt = utils.tensor_to_pcd(depth_pred), utils.tensor_to_pcd(depth_gt) 166 | 167 | val_loss_l1 += args.l1_weight*self.l1_loss(depth_pred, depth_gt) 168 | dist1, dist2 = self.chamfer_loss(pcd_pred, pcd_gt) 169 | val_loss_chamfer += args.chamfer_weight*(torch.mean(dist1) + torch.mean(dist2)) 170 | 171 | # Save 3 images for visualization 172 | if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]: 173 | img_summary = dict() 174 | img_summary["lightfield"] = lightfield[:,0,:,:] 175 | img_summary["depth_gt"] = depth_gt 176 | img_summary["depth_pred"] = depth_pred 177 | img_summary["absolute_error"] = torch.abs(depth_pred - depth_gt) 178 | utils.save_images(self.train_writer, "val" + str(i), img_summary, self.epoch) 179 | 180 | print("=> Validation done!") 181 | 182 | val_loss_chamfer = val_loss_chamfer / valid_samples 183 | val_loss_l1 = val_loss_l1 / valid_samples 184 | loss = val_loss_chamfer + val_loss_l1 185 | # Save validation results 186 | with open(val_file, "a") as f: 187 | f.write("epoch: %03d\t" % self.epoch) 188 | f.write("val_loss_l1: %.3f\t" % val_loss_l1) 189 | f.write("val_loss_chamfer: %.3f\n" % val_loss_chamfer) 190 | 191 | print("=> Mean validation loss of epoch %d: l1: %.6f chamfer: %.6f" % (self.epoch, val_loss_l1, val_loss_chamfer)) 192 | self.train_writer.add_scalar("val/loss_l1", val_loss_l1, self.num_iter) 193 | self.train_writer.add_scalar("val/loss_chamfer", val_loss_chamfer, self.num_iter) 194 | self.train_writer.add_scalar("val/unweighted_loss_l1", val_loss_l1/args.l1_weight, self.num_iter) 195 | self.train_writer.add_scalar("val/unweighted_loss_chamfer", val_loss_chamfer/args.chamfer_weight, self.num_iter) 196 | 197 | # Always save the latest model for resuming training 198 | torch.save(self.net, os.path.join(args.checkpoint_dir, "full_net_latest.pt")) 199 | utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.net, 200 | epoch=self.epoch, num_iter=self.num_iter, 201 | loss=loss, mask=self.net.amplitude_mask.mask, filename="net_latest.pt") 202 | 203 | # Save checkpoint of specific epoch 204 | if self.epoch % args.save_ckpt_freq == 0: 205 | model_dir = os.path.join(args.checkpoint_dir, "models") 206 | utils.check_path(model_dir) 207 | utils.save_checkpoint(model_dir, self.optimizer, self.net, 208 | epoch=self.epoch, num_iter=self.num_iter, 209 | loss=loss, mask=self.net.amplitude_mask.mask, save_optimizer=False) --------------------------------------------------------------------------------
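The training objective in `model.py` combines a weighted SmoothL1 depth loss with a symmetric Chamfer distance between back-projected point clouds, falling back to the L1 term alone while the Chamfer term is NaN or the mask is still frozen. A minimal sketch of that combination (illustrative only; `masktof_loss` is a hypothetical helper, `chamfer_fn` stands in for `utils.chamfer_distance.ChamferDistance()`, and the default weights are the `train.py` defaults):

```python
import torch

smooth_l1 = torch.nn.SmoothL1Loss()

def masktof_loss(depth_pred, depth_gt, pcd_pred, pcd_gt, chamfer_fn,
                 l1_weight=100.0, chamfer_weight=0.08, use_chamfer=True):
    # weighted SmoothL1 on the depth maps, as in Model.train()
    l1_loss = l1_weight * smooth_l1(depth_pred, depth_gt)
    # symmetric Chamfer term on the point clouds
    dist1, dist2 = chamfer_fn(pcd_pred, pcd_gt)
    chamfer_loss = chamfer_weight * (torch.mean(dist1) + torch.mean(dist2))
    # fall back to the L1 term alone if Chamfer is NaN or disabled
    if torch.isnan(chamfer_loss) or not use_chamfer:
        return l1_loss
    return l1_loss + chamfer_loss
```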
/nets/MaskToFNet.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import utils.utils 7 | from utils.tof import * 8 | import importlib 9 | import torch 10 | 11 | class AmplitudeMask(nn.Module): 12 | def __init__(self, args, device): 13 | super(AmplitudeMask, self).__init__() 14 | 15 | if args.init.lower() == "zeros": 16 | # All Zeroes 17 | mask = torch.cat([torch.ones((args.views_x*args.views_y, 1, args.patch_height, args.patch_width), device=device), 18 | torch.zeros((args.views_x*args.views_y, 1, args.patch_height, args.patch_width), device=device)], dim=1) 19 | elif args.init.lower() == "ones": 20 | # All Ones 21 | mask = torch.cat([torch.zeros((args.views_x*args.views_y, 1, args.patch_height, args.patch_width), device=device), 22 | torch.ones((args.views_x*args.views_y, 1, args.patch_height, args.patch_width),
device=device)], dim=1) 23 | elif args.init.lower() == "uniform": 24 | # Uniform Random Mask 25 | mask = torch.empty(args.views_x*args.views_y, 2, args.patch_height, args.patch_width, device=device).uniform_(0,1) 26 | elif args.init.lower() == "bernoulli": 27 | # Bernoulli Random Mask 28 | mask = torch.empty(args.views_x*args.views_y, 1, args.patch_height, args.patch_width, device=device).uniform_(0,1) 29 | mask = torch.bernoulli(mask) 30 | mask = torch.cat([1 - mask, mask], dim=1) 31 | 32 | elif args.init.lower() == "custom": 33 | # Design your own 34 | load = torch.tensor([[0.,0,0,0,0,0,0,0,0], 35 | [0,0,0,0,0,0,0,0,0], 36 | [0,0,0,0,1,0,0,0,0], 37 | [0,0,1,0,1,0,1,0,0], 38 | [0,0,1,0,1,0,1,0,0], 39 | [0,0,1,0,1,0,1,0,0], 40 | [0,0,0,0,1,0,0,0,0], 41 | [0,0,0,0,0,0,0,0,0], 42 | [0,0,0,0,0,0,0,0,0]]) 43 | load = load[:,:,None,None] 44 | m = torch.ones(9,9,args.patch_height,args.patch_width)*load 45 | m = m.reshape(81,args.patch_height,args.patch_width) 46 | mask = torch.zeros(81,2,args.patch_height,args.patch_width, device=device) 47 | mask[:,0,:,:] = 10 - m*10 48 | mask[:,1,:,:] = m*10 49 | 50 | elif "barcode" in args.init.lower() and "/" not in args.init: 51 | mask = torch.zeros(81,2,512,512, device=device) 52 | load = torch.from_numpy(np.load("utils/barcode_masks/{0}.npy".format(args.init.lower()))).to(device).reshape(81,512,512) 53 | mask[:,0,:,:][torch.where(load <= 0)] = 10 54 | mask[:,1,:,:][torch.where(load > 0)] = 10 55 | mask = mask[:,:,:args.patch_height,:args.patch_width] 56 | 57 | elif "gaussian_circles" in args.init.lower() and "/" not in args.init: 58 | init = args.init.lower().replace("gaussian_circles","") 59 | if "," in init: 60 | mean, sigma = [float(el) for el in init.split(",")] 61 | else: 62 | mean, sigma = 1.5, 1 63 | shape = (args.views_y, args.views_x, args.patch_height, args.patch_width) 64 | mask = utils.utils.gkern_mask(mean, sigma, shape) 65 | mask = utils.utils.un_combine_masks(mask, shape)[:,None,:,:]*10 # scale for softmax 66 | mask = torch.cat([10 - mask, mask], dim=1).float().to(device) 67 | 68 | elif "/" in args.init: # load path 69 | mask = torch.load(args.init, map_location=device) 70 | 71 | else: 72 | raise Exception("Not implemented.") 73 | 74 | self.softmax_weight = 1 # softmax temperature 75 | self.softmax = nn.Softmax(dim=2) 76 | self.mask = nn.Parameter(data=mask, requires_grad=(args.mask_start_epoch == 0)) 77 | self.mask = self.mask.to(device) 78 | assert args.img_height % (args.patch_height - args.pad_y*2) == 0 79 | assert args.img_width % (args.patch_width - args.pad_x*2) == 0 80 | self.y_repeat = args.img_height//(args.patch_height - args.pad_y*2) 81 | self.x_repeat = args.img_width//(args.patch_width - args.pad_x*2) 82 | self.pad_y = args.pad_y 83 | self.pad_x = args.pad_x 84 | 85 | def get_mask_internal(self, patch=True): 86 | if patch: 87 | mask = self.mask 88 | else: 89 | if self.pad_x > 0 or self.pad_y > 0: 90 | mask = self.mask[:,:,self.pad_y:-self.pad_y,self.pad_x:-self.pad_x] 91 | else: 92 | mask = self.mask 93 | mask = mask.repeat(1,1,self.y_repeat, self.x_repeat) 94 | mask = utils.utils.combine_masks(mask[:,1,:,:])[None,None,:,:] # [1,1,9H,9W] 95 | return mask 96 | 97 | def get_mask(self): 98 | if self.pad_x > 0 or self.pad_y > 0: 99 | mask = self.mask[:,:,self.pad_y:-self.pad_y,self.pad_x:-self.pad_x] 100 | else: 101 | mask = self.mask 102 | mask = mask.repeat(1,1,self.y_repeat, self.x_repeat) 103 | return F.softmax(self.softmax_weight * mask, dim=1)[:,1,:,:] # softmax over the ON/OFF channel 104 | 105 | def forward(self, amplitudes, patch=False): 106 | if patch: 107
| mask = self.mask.unsqueeze(0) 108 | else: 109 | if self.pad_x > 0 or self.pad_y > 0: 110 | mask = self.mask[:,:,self.pad_y:-self.pad_y,self.pad_x:-self.pad_x] 111 | else: 112 | mask = self.mask 113 | mask = mask.repeat(1,1,self.y_repeat, self.x_repeat).unsqueeze(0) # [1, C, 2, H, W] 114 | mask = self.softmax(self.softmax_weight * mask) # threshold 0-1 115 | mask = mask[:,:,1,:,:] # select 'ON' mask, [B*num_patches, C, H, W] 116 | mask = mask.unsqueeze(1) # [B*num_patches, 1, C, H, W] 117 | return mask * amplitudes 118 | 119 | class MaskToFNet(nn.Module): 120 | def __init__(self, args, device): 121 | super(MaskToFNet, self).__init__() 122 | self.views_x, self.views_y = args.views_x, args.views_y 123 | self.img_height, self.img_width = args.img_height, args.img_width 124 | self.amplitude_mask = AmplitudeMask(args, device) 125 | 126 | if args.use_net: 127 | HourglassRefinement = importlib.import_module('nets.refinement.{0}'.format(args.refinement)).HourglassRefinement 128 | self.refinement = HourglassRefinement() 129 | 130 | def forward(self, lightfield, depth, args, parameters, patch=False): 131 | B, C, H, W = lightfield.shape 132 | phi_list = [] 133 | 134 | # main loop 135 | for f in args.f_list: 136 | amplitudes = sim_quad(depth, f, args.T, args.g, lightfield) 137 | amplitudes = self.amplitude_mask(amplitudes, patch) 138 | amplitudes = amplitudes.mean(dim=2, dtype=torch.float32) # [B*num_patch, 4, patch_height, patch_width] 139 | # white gaussian noise 140 | noise_scale = torch.zeros(amplitudes.shape[0], device=amplitudes.device).uniform_(0.75,1.25)[:,None,None,None] # [B*num_patch, 1,1,1] 141 | noise = torch.normal(std=args.AWGN_sigma, mean=0, size=amplitudes.shape, 142 | device=amplitudes.device, dtype=torch.float32) 143 | if patch: 144 | noise = noise * torch.sqrt(noise_scale) # random scale for training 145 | amplitudes += noise 146 | phi_est, _, _ = decode_quad(amplitudes, args.T, args.mT) 147 | phi_list.append(phi_est.squeeze(1)) 148 | 149 | if len(args.f_list) == 1: 150 | depth_recon = phase2depth(phi_list[0], args.f_list[0]) # [B, H, W] 151 | else: # phase unwrapping 152 | depth_recon = unwrap_ranking(phi_list, args.f_list, min_depth=0, max_depth=6000) 153 | 154 | depth_recon = depth_recon.unsqueeze(1) # [B, 1, H, W] 155 | 156 | if args.use_net: 157 | mask = self.amplitude_mask.get_mask_internal(patch=patch) 158 | depth_recon = self.refinement(depth_recon, mask) 159 | 160 | return depth_recon # [B, 1, H, W] 161 | 162 | def process_amplitudes(self, amplitudes, args, patch=False): #phi_est [B, 4, H, W] 163 | phi_est, _, _ = decode_quad(amplitudes, args.T, args.mT) 164 | phi_est = phi_est.squeeze(1) 165 | depth_recon = phase2depth(phi_est, args.f_list[0]) # [B, H, W] 166 | depth_recon = depth_recon.unsqueeze(1) # [B, 1, H, W] 167 | 168 | if args.use_net: 169 | mask = self.amplitude_mask.get_mask_internal(patch=patch) 170 | depth_recon = self.refinement(depth_recon, mask) 171 | 172 | return depth_recon # [B, 1, H, W] 173 | 174 | 175 | def process_depth(self, depth, patch=False): 176 | mask = self.amplitude_mask.get_mask_internal(patch=patch) 177 | depth_recon = self.refinement(depth, mask) 178 | 179 | return depth_recon # [B, 1, H, W] 180 | -------------------------------------------------------------------------------- /nets/pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/MaskToF/d34c453631e97424a5dff20f155c509091a75026/nets/pretrained/__init__.py 
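Before the refinement networks below, a note on how the learned mask binarizes: `AmplitudeMask` holds two-channel logits, and `Model.train()` scales them by `softmax_weight = (1 + (softmax_gamma * num_iter)^2)` before the channel-wise softmax, so the 'ON' channel saturates toward {0, 1} as training progresses. A minimal self-contained sketch of that relaxation (the tensor size and `gamma` value below are illustrative, not repo defaults):

```python
import torch
import torch.nn.functional as F

# two-channel mask logits, [views, {OFF, ON}, H, W], like AmplitudeMask.mask
logits = torch.randn(81, 2, 8, 8)

gamma = 0.001  # plays the role of --softmax_gamma
for t in [0, 1_000, 10_000]:
    w = 1 + (gamma * t) ** 2                 # schedule from Model.train()
    on = F.softmax(w * logits, dim=1)[:, 1]  # 'ON' channel, as in AmplitudeMask
    # fraction of mask entries that are effectively binary
    frac = ((on < 0.01) | (on > 0.99)).float().mean().item()
    print(f"iter {t}: {frac:.1%} of mask entries saturated")
```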
-------------------------------------------------------------------------------- /nets/refinement/Refinement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, groups=1): 6 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 7 | stride=stride, padding=dilation, dilation=dilation, 8 | bias=False, groups=groups), 9 | nn.BatchNorm2d(out_channels), 10 | nn.ReLU(inplace=True)) 11 | 12 | class HourglassRefinement(nn.Module): 13 | """Height and width must be divisible by 16""" 14 | 15 | def __init__(self): 16 | super(HourglassRefinement, self).__init__() 17 | 18 | # mask conv 19 | self.mask_conv1 = BasicConv(1, 16, kernel_size=3, stride=3, dilation=3, padding=2) 20 | self.mask_conv2 = BasicConv(16, 1, kernel_size=3, stride=3, dilation=3, padding=2) 21 | 22 | # depth conv 23 | self.conv1 = conv2d(2, 16) 24 | 25 | self.conv_start = BasicConv(16, 32, kernel_size=1) 26 | 27 | self.conv1a = BasicConv(32, 64, kernel_size=3, stride=2, padding=1) 28 | self.conv2a = BasicConv(64, 128, kernel_size=3, stride=2, padding=1) 29 | self.conv3a = BasicConv(128, 256, kernel_size=3, stride=2, dilation=2, padding=2) 30 | self.conv4a = BasicConv(256, 512, kernel_size=3, stride=2, dilation=2, padding=2) 31 | 32 | self.deconv4a = Conv2x(512, 256, deconv=True) 33 | self.deconv3a = Conv2x(256, 128, deconv=True) 34 | self.deconv2a = Conv2x(128, 64, deconv=True) 35 | self.deconv1a = Conv2x(64, 32, deconv=True) 36 | 37 | self.conv1b = Conv2x(32, 64) 38 | self.conv2b = Conv2x(64, 128) 39 | self.conv3b = Conv2x(128, 256) 40 | self.conv4b = Conv2x(256, 512) 41 | 42 | self.deconv4b = Conv2x(512, 256, deconv=True) 43 | self.deconv3b = Conv2x(256, 128, deconv=True) 44 | self.deconv2b = Conv2x(128, 64, deconv=True) 45 | self.deconv1b = Conv2x(64, 32, deconv=True) 46 | 47 | self.final_conv = nn.Conv2d(32, 1, 3, 1, 1) 48 | 49 | def forward(self, depth, mask): 50 | B = depth.shape[0] 51 | mask = self.mask_conv1(mask) 52 | mask = self.mask_conv2(mask) 53 | mask = mask.repeat(B,1,1,1) # [B, 1, H, W] 54 | 55 | x = torch.cat((depth, mask), dim=1) 56 | 57 | conv1 = self.conv1(x) # [B, 16, H, W] 58 | x = self.conv_start(conv1) 59 | rem0 = x 60 | x = self.conv1a(x) 61 | rem1 = x 62 | x = self.conv2a(x) 63 | rem2 = x 64 | x = self.conv3a(x) 65 | rem3 = x 66 | x = self.conv4a(x) 67 | rem4 = x 68 | 69 | x = self.deconv4a(x, rem3) 70 | rem3 = x 71 | x = self.deconv3a(x, rem2) 72 | rem2 = x 73 | x = self.deconv2a(x, rem1) 74 | rem1 = x 75 | x = self.deconv1a(x, rem0) 76 | rem0 = x 77 | 78 | x = self.conv1b(x, rem1) 79 | rem1 = x 80 | x = self.conv2b(x, rem2) 81 | rem2 = x 82 | x = self.conv3b(x, rem3) 83 | rem3 = x 84 | x = self.conv4b(x, rem4) 85 | 86 | x = self.deconv4b(x, rem3) 87 | x = self.deconv3b(x, rem2) 88 | x = self.deconv2b(x, rem1) 89 | x = self.deconv1b(x, rem0) # [B, 32, H, W] 90 | 91 | residual_depth = self.final_conv(x) # [B, 1, H, W] 92 | 93 | depth = F.relu(depth + residual_depth, inplace=True) # [B, 1, H, W] 94 | return depth 95 | 96 | class Conv2x(nn.Module): 97 | 98 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, bn=True, relu=True, 99 | mdconv=False): 100 | super(Conv2x, self).__init__() 101 | self.concat = concat 102 | 103 | if deconv and is_3d: 104 | kernel = (3, 4, 4) 105 | elif deconv: 106 | kernel = 4 107 | else: 108 | kernel = 3 109 | self.conv1 = 
BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, 110 | stride=2, padding=1) 111 | 112 | if self.concat: 113 | self.conv2 = BasicConv(out_channels * 2, out_channels, False, is_3d, bn, relu, kernel_size=3, 114 | stride=1, padding=1) 115 | else: 116 | self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, 117 | padding=1) 118 | 119 | def forward(self, x, rem): 120 | x = self.conv1(x) 121 | assert (x.size() == rem.size()) 122 | if self.concat: 123 | x = torch.cat((x, rem), 1) 124 | else: 125 | x = x + rem 126 | x = self.conv2(x) 127 | return x 128 | 129 | class BasicConv(nn.Module): 130 | 131 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs): 132 | super(BasicConv, self).__init__() 133 | self.relu = relu 134 | self.use_bn = bn 135 | if is_3d: 136 | if deconv: 137 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) 138 | else: 139 | self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) 140 | self.bn = nn.BatchNorm3d(out_channels) 141 | else: 142 | if deconv: 143 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) 144 | else: 145 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 146 | self.bn = nn.BatchNorm2d(out_channels) 147 | 148 | def forward(self, x): 149 | x = self.conv(x) 150 | if self.use_bn: 151 | x = self.bn(x) 152 | if self.relu: 153 | x = F.relu(x, inplace=True) 154 | return x -------------------------------------------------------------------------------- /nets/refinement/RefinementBig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, groups=1): 6 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 7 | stride=stride, padding=dilation, dilation=dilation, 8 | bias=False, groups=groups), 9 | nn.BatchNorm2d(out_channels), 10 | nn.ReLU(inplace=True)) 11 | 12 | class HourglassRefinement(nn.Module): 13 | """Height and width must be divisible by 16""" 14 | 15 | def __init__(self): 16 | super(HourglassRefinement, self).__init__() 17 | 18 | # mask conv 19 | self.mask_conv1 = BasicConv(1, 16, kernel_size=3, stride=3, dilation=3, padding=2) 20 | self.mask_conv2 = BasicConv(16, 1, kernel_size=3, stride=3, dilation=3, padding=2) 21 | 22 | # depth conv 23 | self.conv1 = conv2d(2, 16) 24 | 25 | self.conv_start = BasicConv(16, 64, kernel_size=1) 26 | 27 | self.conv1a = BasicConv(64, 128, kernel_size=3, stride=2, padding=1) 28 | self.conv2a = BasicConv(128, 256, kernel_size=3, stride=2, padding=1) 29 | self.conv3a = BasicConv(256, 512, kernel_size=3, stride=2, dilation=2, padding=2) 30 | self.conv4a = BasicConv(512, 1024, kernel_size=3, stride=2, dilation=2, padding=2) 31 | 32 | self.deconv4a = Conv2x(1024, 512, deconv=True) 33 | self.deconv3a = Conv2x(512, 256, deconv=True) 34 | self.deconv2a = Conv2x(256, 128, deconv=True) 35 | self.deconv1a = Conv2x(128, 64, deconv=True) 36 | 37 | self.conv1b = Conv2x(64, 128) 38 | self.conv2b = Conv2x(128, 256) 39 | self.conv3b = Conv2x(256, 512) 40 | self.conv4b = Conv2x(512, 1024) 41 | 42 | self.deconv4b = Conv2x(1024, 512, deconv=True) 43 | self.deconv3b = Conv2x(512, 256, deconv=True) 44 | self.deconv2b = Conv2x(256, 128, deconv=True) 45 | self.deconv1b = Conv2x(128, 64, deconv=True) 46 | 47 | 
self.final_conv = nn.Conv2d(64, 1, 3, 1, 1) 48 | 49 | def forward(self, depth, mask): 50 | B = depth.shape[0] 51 | mask = self.mask_conv1(mask) 52 | mask = self.mask_conv2(mask) 53 | mask = mask.repeat(B,1,1,1) # [B, 1, H, W] 54 | 55 | x = torch.cat((depth, mask), dim=1) 56 | 57 | conv1 = self.conv1(x) # [B, 16, H, W] 58 | x = self.conv_start(conv1) 59 | rem0 = x 60 | x = self.conv1a(x) 61 | rem1 = x 62 | x = self.conv2a(x) 63 | rem2 = x 64 | x = self.conv3a(x) 65 | rem3 = x 66 | x = self.conv4a(x) 67 | rem4 = x 68 | 69 | x = self.deconv4a(x, rem3) 70 | rem3 = x 71 | x = self.deconv3a(x, rem2) 72 | rem2 = x 73 | x = self.deconv2a(x, rem1) 74 | rem1 = x 75 | x = self.deconv1a(x, rem0) 76 | rem0 = x 77 | 78 | x = self.conv1b(x, rem1) 79 | rem1 = x 80 | x = self.conv2b(x, rem2) 81 | rem2 = x 82 | x = self.conv3b(x, rem3) 83 | rem3 = x 84 | x = self.conv4b(x, rem4) 85 | 86 | x = self.deconv4b(x, rem3) 87 | x = self.deconv3b(x, rem2) 88 | x = self.deconv2b(x, rem1) 89 | x = self.deconv1b(x, rem0) # [B, 32, H, W] 90 | 91 | residual_depth = self.final_conv(x) # [B, 1, H, W] 92 | 93 | depth = F.relu(depth + residual_depth, inplace=True) # [B, 1, H, W] 94 | return depth 95 | 96 | class Conv2x(nn.Module): 97 | 98 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, bn=True, relu=True, 99 | mdconv=False): 100 | super(Conv2x, self).__init__() 101 | self.concat = concat 102 | 103 | if deconv and is_3d: 104 | kernel = (3, 4, 4) 105 | elif deconv: 106 | kernel = 4 107 | else: 108 | kernel = 3 109 | self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, 110 | stride=2, padding=1) 111 | 112 | if self.concat: 113 | self.conv2 = BasicConv(out_channels * 2, out_channels, False, is_3d, bn, relu, kernel_size=3, 114 | stride=1, padding=1) 115 | else: 116 | self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, 117 | padding=1) 118 | 119 | def forward(self, x, rem): 120 | x = self.conv1(x) 121 | assert (x.size() == rem.size()) 122 | if self.concat: 123 | x = torch.cat((x, rem), 1) 124 | else: 125 | x = x + rem 126 | x = self.conv2(x) 127 | return x 128 | 129 | class BasicConv(nn.Module): 130 | 131 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs): 132 | super(BasicConv, self).__init__() 133 | self.relu = relu 134 | self.use_bn = bn 135 | if is_3d: 136 | if deconv: 137 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) 138 | else: 139 | self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) 140 | self.bn = nn.BatchNorm3d(out_channels) 141 | else: 142 | if deconv: 143 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) 144 | else: 145 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 146 | self.bn = nn.BatchNorm2d(out_channels) 147 | 148 | def forward(self, x): 149 | x = self.conv(x) 150 | if self.use_bn: 151 | x = self.bn(x) 152 | if self.relu: 153 | x = F.relu(x, inplace=True) 154 | return x -------------------------------------------------------------------------------- /nets/refinement/RefinementGlobal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, groups=1): 6 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, 
kernel_size=kernel_size, 7 | stride=stride, padding=dilation, dilation=dilation, 8 | bias=False, groups=groups), 9 | nn.BatchNorm2d(out_channels), 10 | nn.ReLU(inplace=True)) 11 | 12 | class HourglassRefinement(nn.Module): 13 | """Height and width must be divisible by 16""" 14 | 15 | def __init__(self): 16 | super(HourglassRefinement, self).__init__() 17 | 18 | # mask conv 19 | self.mask_conv1 = BasicConv(1, 16, kernel_size=3, stride=3, dilation=3, padding=2) 20 | self.mask_conv2 = BasicConv(16, 1, kernel_size=3, stride=3, dilation=3, padding=2) 21 | 22 | # depth conv 23 | self.conv1 = conv2d(2, 16) 24 | 25 | self.conv_start = BasicConv(16, 32, kernel_size=1) 26 | 27 | self.conv1a = BasicConv(32, 64, kernel_size=3, stride=2, padding=1) 28 | self.conv2a = BasicConv(64, 128, kernel_size=3, stride=2, padding=1) 29 | self.conv3a = BasicConv(128, 256, kernel_size=3, stride=2, dilation=2, padding=2) 30 | self.conv4a = BasicConv(256, 512, kernel_size=3, stride=2, dilation=2, padding=2) 31 | 32 | self.conv1g = BasicConv(32, 64, kernel_size=3, stride=2, dilation=2, padding=0) 33 | self.conv2g = BasicConv(64, 128, kernel_size=3, stride=2, dilation=2, padding=0) 34 | self.conv3g = BasicConv(128, 256, kernel_size=3, stride=2, dilation=2, padding=0) 35 | self.conv4g = BasicConv(256, 512, kernel_size=3, stride=3, dilation=2, padding=0) 36 | 37 | self.deconv4a = Conv2x(1024, 256, deconv=True) 38 | self.deconv3a = Conv2x(256, 128, deconv=True) 39 | self.deconv2a = Conv2x(128, 64, deconv=True) 40 | self.deconv1a = Conv2x(64, 32, deconv=True) 41 | 42 | self.conv1b = Conv2x(32, 64) 43 | self.conv2b = Conv2x(64, 128) 44 | self.conv3b = Conv2x(128, 256) 45 | self.conv4b = Conv2x(256, 512) 46 | 47 | self.deconv4b = Conv2x(512, 256, deconv=True) 48 | self.deconv3b = Conv2x(256, 128, deconv=True) 49 | self.deconv2b = Conv2x(128, 64, deconv=True) 50 | self.deconv1b = Conv2x(64, 32, deconv=True) 51 | 52 | self.final_conv = nn.Conv2d(32, 1, 3, 1, 1) 53 | 54 | 55 | def forward(self, depth, mask): 56 | B = depth.shape[0] 57 | mask = self.mask_conv1(mask) 58 | mask = self.mask_conv2(mask) 59 | mask = mask.repeat(B,1,1,1) # [B, 1, H, W] 60 | 61 | x = torch.cat((depth, mask), dim=1) 62 | 63 | conv1 = self.conv1(x) # [B, 16, H, W] 64 | x = self.conv_start(conv1) 65 | rem0 = x 66 | x = self.conv1a(x) 67 | rem1 = x 68 | x = self.conv2a(x) 69 | rem2 = x 70 | x = self.conv3a(x) 71 | rem3 = x 72 | x = self.conv4a(x) 73 | rem4 = x 74 | 75 | # global branch 76 | g = self.conv1g(rem0) 77 | g = self.conv2g(g) 78 | g = self.conv3g(g) 79 | g = self.conv4g(g) 80 | g = F.interpolate(g, size=x.shape[-1]) # reshape for cat size 81 | 82 | x = torch.cat((x,g), dim=1) 83 | 84 | x = self.deconv4a(x, rem3) 85 | rem3 = x 86 | x = self.deconv3a(x, rem2) 87 | rem2 = x 88 | x = self.deconv2a(x, rem1) 89 | rem1 = x 90 | x = self.deconv1a(x, rem0) 91 | rem0 = x 92 | 93 | x = self.conv1b(x, rem1) 94 | rem1 = x 95 | x = self.conv2b(x, rem2) 96 | rem2 = x 97 | x = self.conv3b(x, rem3) 98 | rem3 = x 99 | x = self.conv4b(x, rem4) 100 | 101 | x = self.deconv4b(x, rem3) 102 | x = self.deconv3b(x, rem2) 103 | x = self.deconv2b(x, rem1) 104 | x = self.deconv1b(x, rem0) # [B, 32, H, W] 105 | 106 | residual_depth = self.final_conv(x) # [B, 1, H, W] 107 | 108 | depth = F.relu(depth + residual_depth, inplace=True) # [B, 1, H, W] 109 | return depth 110 | 111 | class Conv2x(nn.Module): 112 | 113 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, bn=True, relu=True, 114 | mdconv=False): 115 | super(Conv2x, self).__init__() 116 
| self.concat = concat 117 | 118 | if deconv and is_3d: 119 | kernel = (3, 4, 4) 120 | elif deconv: 121 | kernel = 4 122 | else: 123 | kernel = 3 124 | self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, 125 | stride=2, padding=1) 126 | 127 | if self.concat: 128 | self.conv2 = BasicConv(out_channels * 2, out_channels, False, is_3d, bn, relu, kernel_size=3, 129 | stride=1, padding=1) 130 | else: 131 | self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, 132 | padding=1) 133 | 134 | def forward(self, x, rem): 135 | x = self.conv1(x) 136 | assert (x.size() == rem.size()) 137 | if self.concat: 138 | x = torch.cat((x, rem), 1) 139 | else: 140 | x = x + rem 141 | x = self.conv2(x) 142 | return x 143 | 144 | class BasicConv(nn.Module): 145 | 146 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs): 147 | super(BasicConv, self).__init__() 148 | self.relu = relu 149 | self.use_bn = bn 150 | if is_3d: 151 | if deconv: 152 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) 153 | else: 154 | self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) 155 | self.bn = nn.BatchNorm3d(out_channels) 156 | else: 157 | if deconv: 158 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) 159 | else: 160 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 161 | self.bn = nn.BatchNorm2d(out_channels) 162 | 163 | def forward(self, x): 164 | x = self.conv(x) 165 | if self.use_bn: 166 | x = self.bn(x) 167 | if self.relu: 168 | x = F.relu(x, inplace=True) 169 | return x -------------------------------------------------------------------------------- /nets/refinement/RefinementHalf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, groups=1): 6 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 7 | stride=stride, padding=dilation, dilation=dilation, 8 | bias=False, groups=groups), 9 | nn.BatchNorm2d(out_channels), 10 | nn.ReLU(inplace=True)) 11 | 12 | class HourglassRefinement(nn.Module): 13 | """Height and width must be divisible by 16""" 14 | 15 | def __init__(self): 16 | super(HourglassRefinement, self).__init__() 17 | 18 | # mask conv 19 | self.mask_conv1 = BasicConv(1, 16, kernel_size=3, stride=3, dilation=3, padding=2) 20 | self.mask_conv2 = BasicConv(16, 1, kernel_size=3, stride=3, dilation=3, padding=2) 21 | 22 | # depth conv 23 | self.conv1 = conv2d(2, 16) 24 | 25 | self.conv_start = BasicConv(16, 32, kernel_size=1) 26 | 27 | self.conv1a = BasicConv(32, 64, kernel_size=3, stride=2, padding=1) 28 | self.conv2a = BasicConv(64, 128, kernel_size=3, stride=2, padding=1) 29 | self.conv3a = BasicConv(128, 256, kernel_size=3, stride=2, dilation=2, padding=2) 30 | self.conv4a = BasicConv(256, 512, kernel_size=3, stride=2, dilation=2, padding=2) 31 | 32 | self.deconv4a = Conv2x(512, 256, deconv=True) 33 | self.deconv3a = Conv2x(256, 128, deconv=True) 34 | self.deconv2a = Conv2x(128, 64, deconv=True) 35 | self.deconv1a = Conv2x(64, 32, deconv=True) 36 | 37 | self.final_conv = nn.Conv2d(32, 1, 3, 1, 1) 38 | 39 | def forward(self, depth, mask): 40 | B = depth.shape[0] 41 | mask = self.mask_conv1(mask) 42 | mask = self.mask_conv2(mask) 43 | mask = 
mask.repeat(B,1,1,1) # [B, 1, H, W] 44 | 45 | x = torch.cat((depth, mask), dim=1) 46 | 47 | conv1 = self.conv1(x) # [B, 16, H, W] 48 | x = self.conv_start(conv1) 49 | rem0 = x 50 | x = self.conv1a(x) 51 | rem1 = x 52 | x = self.conv2a(x) 53 | rem2 = x 54 | x = self.conv3a(x) 55 | rem3 = x 56 | x = self.conv4a(x) 57 | rem4 = x 58 | 59 | x = self.deconv4a(x, rem3) 60 | rem3 = x 61 | x = self.deconv3a(x, rem2) 62 | rem2 = x 63 | x = self.deconv2a(x, rem1) 64 | rem1 = x 65 | x = self.deconv1a(x, rem0) 66 | 67 | residual_depth = self.final_conv(x) # [B, 1, H, W] 68 | 69 | depth = F.relu(depth + residual_depth, inplace=True) # [B, 1, H, W] 70 | return depth 71 | 72 | class Conv2x(nn.Module): 73 | 74 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, bn=True, relu=True, 75 | mdconv=False): 76 | super(Conv2x, self).__init__() 77 | self.concat = concat 78 | 79 | if deconv and is_3d: 80 | kernel = (3, 4, 4) 81 | elif deconv: 82 | kernel = 4 83 | else: 84 | kernel = 3 85 | self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, 86 | stride=2, padding=1) 87 | 88 | if self.concat: 89 | self.conv2 = BasicConv(out_channels * 2, out_channels, False, is_3d, bn, relu, kernel_size=3, 90 | stride=1, padding=1) 91 | else: 92 | self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, 93 | padding=1) 94 | 95 | def forward(self, x, rem): 96 | x = self.conv1(x) 97 | assert (x.size() == rem.size()) 98 | if self.concat: 99 | x = torch.cat((x, rem), 1) 100 | else: 101 | x = x + rem 102 | x = self.conv2(x) 103 | return x 104 | 105 | class BasicConv(nn.Module): 106 | 107 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs): 108 | super(BasicConv, self).__init__() 109 | self.relu = relu 110 | self.use_bn = bn 111 | if is_3d: 112 | if deconv: 113 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) 114 | else: 115 | self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) 116 | self.bn = nn.BatchNorm3d(out_channels) 117 | else: 118 | if deconv: 119 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) 120 | else: 121 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 122 | self.bn = nn.BatchNorm2d(out_channels) 123 | 124 | def forward(self, x): 125 | x = self.conv(x) 126 | if self.use_bn: 127 | x = self.bn(x) 128 | if self.relu: 129 | x = F.relu(x, inplace=True) 130 | return x -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio>=2.9.0 2 | h5py>=2.10.0 3 | numpy>=1.20.3 4 | torch>=1.7.0 5 | torchvision>=0.10.0 6 | scipy>=1.6.3 7 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | import argparse 5 | import numpy as np 6 | import os 7 | 8 | from dataloader.dataloader import LightFieldDataset 9 | from dataloader import transforms 10 | from utils import utils 11 | from utils import tof 12 | import model 13 | from nets.MaskToFNet import MaskToFNet 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # Training args 18 | parser.add_argument("--mode", default="train", type=str, help="Network mode [train, val, test]") 19 |
parser.add_argument("--checkpoint_dir", default="checkpoints/4DLFB", type=str, required=True, help="Directory to save model checkpoints and logs") 20 | parser.add_argument("--dataset_name", default="4DLFB", type=str, help="Dataset name [4DLFB]") 21 | parser.add_argument('--pretrained_net', default=None, type=str, help='Pretrained network') 22 | 23 | parser.add_argument("--batch_size", default=12, type=int, help="Batch size for training") 24 | parser.add_argument("--num_workers", default=2, type=int, help="Number of workers for data loading") 25 | parser.add_argument("--seed", default=42, type=int, help="Seed for PyTorch/NumPy.") 26 | parser.add_argument("--weight_decay", default=0, type=float, help="Weight decay for optimizer") 27 | parser.add_argument("--max_epoch", default=540, type=int, help="Maximum number of epochs for training") 28 | parser.add_argument("--milestones", default="80,160,240,320,400,480", type=str, help="Milestones for MultiStepLR") 29 | parser.add_argument("--resume", action="store_true", help="Resume training from latest checkpoint") 30 | parser.add_argument("--no_validate", action="store_true", help="No validation") 31 | 32 | parser.add_argument('--print_freq', default=1, type=int, help='Print frequency to screen (# of iterations)') 33 | parser.add_argument('--summary_freq', default=5, type=int, help='Summary frequency to tensorboard (# of iterations)') 34 | parser.add_argument('--val_freq', default=5, type=int, help='Validation frequency (# of epochs)') 35 | parser.add_argument('--save_ckpt_freq', default=10, type=int, help='Save checkpoint frequency (# of epochs)') 36 | parser.add_argument("--mask_checkpoints", default="999999", type=str, help="Epochs at which to save mask to file. (# of epochs)") 37 | 38 | # Image-specific 39 | parser.add_argument('--views_x', default=9, type=int, help='Lightfield dimension 0') 40 | parser.add_argument('--views_y', default=9, type=int, help='Lightfield dimension 1') 41 | parser.add_argument('--img_height', default=512, type=int, help='Sub-aperture view height.') 42 | parser.add_argument('--img_width', default=512, type=int, help='Sub-aperture view width.') 43 | parser.add_argument('--num_patches', default=9, type=int, help='Number of patches for patch-based training.') 44 | parser.add_argument('--patch_height', default=80, type=int, help='Image patch height for patch-based training.') 45 | parser.add_argument('--patch_width', default=80, type=int, help='Image patch width for patch-based training.') 46 | parser.add_argument('--pad_x', default=8, type=int, help='Patch padding in width.') 47 | parser.add_argument('--pad_y', default=8, type=int, help='Patch padding in height.') 48 | 49 | 50 | # Network-specific 51 | parser.add_argument("--f_list", default="100e6", type=str, help="List of modulation frequencies for phase unwrapping.") 52 | parser.add_argument("--g", default=20, type=float, help="Gain of the sensor. Metric not defined.") 53 | parser.add_argument("--T", default=1000, type=float, help="Integration time. Metric not defined.") 54 | parser.add_argument("--mT", default=2000, type=float, help="Modulation period. 
55 | parser.add_argument("--AWGN_sigma", default=3, type=float, help="Standard deviation of additive white Gaussian noise.")
56 | parser.add_argument("--init", default="uniform", type=str, help="Mask initialization [ones, zeros, uniform, bernoulli, barcodeX, custom, gaussian_circlesX,Y]")
57 | parser.add_argument("--use_net", action="store_true", help="Add encoder-decoder unet for reconstruction.")
58 | parser.add_argument("--net_lr", default=0.004, type=float, help="Network learning rate")
59 | parser.add_argument("--mask_lr", default=0.1, type=float, help="Mask learning rate")
60 | parser.add_argument("--mask_start_epoch", default=70, type=int, help="Epoch at which to begin updating mask.")
61 | parser.add_argument("--softmax_gamma", default=0, type=float, help="Gamma for mask sigmoid scaling factor. scale = (1 + (gamma*t)^2)")
62 | parser.add_argument("--l1_weight", default=100, type=float, help="Weight for L1 loss term.")
63 | parser.add_argument("--chamfer_weight", default=0.08, type=float, help="Weight for chamfer loss term.")
64 | parser.add_argument("--refinement", default="Refinement", type=str, help="Filename of refinement augmentation (for import).")
65 | 
66 | 
67 | args = parser.parse_args()
68 | 
69 | utils.check_path(args.checkpoint_dir)
70 | utils.save_args(args)
71 | 
72 | def main():
73 |     # Seed for reproducibility
74 |     torch.manual_seed(args.seed)
75 |     torch.cuda.manual_seed(args.seed)
76 |     np.random.seed(args.seed)
77 | 
78 |     # cudnn autotuner: speedup when input sizes are constant
79 |     torch.backends.cudnn.benchmark = True
80 | 
81 |     print("=> Training args: {0}".format(args))
82 | 
83 |     if torch.cuda.is_available():
84 |         device = torch.device("cuda")
85 |         print("=> Training on {0} GPU(s)".format(torch.cuda.device_count()))
86 |     else:
87 |         device = torch.device("cpu")
88 |         print("=> Training on CPU")
89 | 
90 |     # Train loader
91 |     train_transform_list = [transforms.RGBtoNIR(),
92 |                             transforms.ToTensor()
93 |                             ]
94 |     train_augmentation_list = [transforms.ToRandomPatches(args.num_patches, args.patch_width, args.patch_height)]
95 |     train_transform = transforms.Compose(train_transform_list)
96 |     train_augmentation = transforms.Compose(train_augmentation_list)
97 |     train_data = LightFieldDataset(dataset_name=args.dataset_name,
98 |                                    mode=args.mode,
99 |                                    transform=train_transform,
100 |                                    augmentation=train_augmentation)
101 |     train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True,
102 |                               num_workers=args.num_workers, pin_memory=True, drop_last=False)
103 | 
104 |     # Validation loader
105 |     val_transform_list = [transforms.RGBtoNIR(),
106 |                           transforms.ToTensor()
107 |                           ]
108 |     val_transform = transforms.Compose(val_transform_list)
109 |     val_data = LightFieldDataset(dataset_name=args.dataset_name,
110 |                                  mode="val",
111 |                                  transform=val_transform)
112 | 
113 |     val_loader = DataLoader(dataset=val_data, batch_size=4, shuffle=False,
114 |                             num_workers=args.num_workers, pin_memory=True, drop_last=False)
115 | 
116 |     print("=> {} training samples found".format(len(train_data)))
117 | 
118 |     # Network
119 |     net = MaskToFNet(args, device).to(device)
120 | 
121 |     # Split parameters: the learnable amplitude mask gets its own learning rate
122 |     net_params = list(filter(lambda kv: kv[0] != "amplitude_mask.mask", net.named_parameters()))
123 |     mask_params = list(filter(lambda kv: kv[0] == "amplitude_mask.mask", net.named_parameters()))
124 |     net_params = [kv[1] for kv in net_params]  # kv is a tuple (key, value)
125 |     mask_params = [kv[1] for kv in mask_params]
126 |     params_group = [{'params': net_params, 'lr': args.net_lr},
127 |                     {'params': mask_params, 'lr': args.mask_lr}, ]
128 |     optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)
129 | 
130 |     print("%s" % net)
131 | 
132 |     if args.pretrained_net is not None:
133 |         print("=> Loading pretrained network: %s" % args.pretrained_net)
134 |         # Enable training from a partially pretrained model
135 |         utils.load_checkpoint(net, args.pretrained_net)
136 | 
137 |     # Parameters
138 |     num_params = utils.count_parameters(net)
139 |     print("=> Number of trainable parameters: %d" % num_params)
140 | 
141 |     # Resume training
142 |     if args.resume:
143 |         # Load network
144 |         start_epoch, start_iter = utils.resume_latest_ckpt(args.checkpoint_dir, net, "net")
145 |         # Load optimizer
146 |         utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, "optimizer")
147 |     else:
148 |         start_epoch = 0
149 |         start_iter = 0
150 | 
151 |     args.f_list = [float(f) for f in args.f_list.split(",")]
152 |     args.milestones = [int(step) for step in args.milestones.split(",")]
153 |     args.mask_checkpoints = [int(i) for i in args.mask_checkpoints.split(",")]
154 |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.5)
155 |     train_model = model.Model(args, optimizer, net, device, start_iter, start_epoch,
156 |                               views_x=args.views_x, views_y=args.views_y, img_height=args.img_height, img_width=args.img_width)
157 | 
158 |     print("=> Start training...")
159 | 
160 |     for epoch in range(start_epoch, args.max_epoch):
161 |         train_model.train(train_loader)
162 |         if not args.no_validate:
163 |             if epoch % args.val_freq == 0 or epoch == (args.max_epoch - 1):
164 |                 train_model.validate(val_loader)
165 |         lr_scheduler.step()
166 | 
167 |     print("=> End training\n\n")
168 | 
169 | 
170 | if __name__ == "__main__":
171 |     main()
--------------------------------------------------------------------------------
/utils/barcode_masks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/MaskToF/d34c453631e97424a5dff20f155c509091a75026/utils/barcode_masks/__init__.py
--------------------------------------------------------------------------------
/utils/chamfer_distance/__init__.py:
--------------------------------------------------------------------------------
1 | from .chamfer_distance import ChamferDistance
2 | 
--------------------------------------------------------------------------------
/utils/chamfer_distance/chamfer_distance.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | // CUDA forward declarations
4 | void ChamferDistanceKernelLauncher(
5 |     const int b, const int n,
6 |     const float* xyz,
7 |     const int m,
8 |     const float* xyz2,
9 |     float* result,
10 |     int* result_i,
11 |     float* result2,
12 |     int* result2_i);
13 | 
14 | void ChamferDistanceGradKernelLauncher(
15 |     const int b, const int n,
16 |     const float* xyz1,
17 |     const int m,
18 |     const float* xyz2,
19 |     const float* grad_dist1,
20 |     const int* idx1,
21 |     const float* grad_dist2,
22 |     const int* idx2,
23 |     float* grad_xyz1,
24 |     float* grad_xyz2);
25 | 
26 | 
27 | void chamfer_distance_forward_cuda(
28 |     const at::Tensor xyz1,
29 |     const at::Tensor xyz2,
30 |     const at::Tensor dist1,
31 |     const at::Tensor dist2,
32 |     const at::Tensor idx1,
33 |     const at::Tensor idx2)
34 | {
35 |     ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
36 |                                   xyz2.size(1), xyz2.data<float>(),
37 |                                   dist1.data<float>(), idx1.data<int>(),
38 |                                   dist2.data<float>(), idx2.data<int>());
39 | }
40 | 
41 | void chamfer_distance_backward_cuda(
42 |     const at::Tensor xyz1,
43 |     const at::Tensor xyz2,
44 |     at::Tensor gradxyz1,
45 |     at::Tensor gradxyz2,
46 |     at::Tensor graddist1,
47 |     at::Tensor graddist2,
48 |     at::Tensor idx1,
49 |     at::Tensor idx2)
50 | {
51 |     ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
52 |                                       xyz2.size(1), xyz2.data<float>(),
53 |                                       graddist1.data<float>(), idx1.data<int>(),
54 |                                       graddist2.data<float>(), idx2.data<int>(),
55 |                                       gradxyz1.data<float>(), gradxyz2.data<float>());
56 | }
57 | 
58 | 
59 | void nnsearch(
60 |     const int b, const int n, const int m,
61 |     const float* xyz1,
62 |     const float* xyz2,
63 |     float* dist,
64 |     int* idx)
65 | {
66 |     for (int i = 0; i < b; i++) {
67 |         for (int j = 0; j < n; j++) {
68 |             const float x1 = xyz1[(i*n+j)*3+0];
69 |             const float y1 = xyz1[(i*n+j)*3+1];
70 |             const float z1 = xyz1[(i*n+j)*3+2];
71 |             double best = 0;
72 |             int besti = 0;
73 |             for (int k = 0; k < m; k++) {
74 |                 const float x2 = xyz2[(i*m+k)*3+0] - x1;
75 |                 const float y2 = xyz2[(i*m+k)*3+1] - y1;
76 |                 const float z2 = xyz2[(i*m+k)*3+2] - z1;
77 |                 const double d=x2*x2+y2*y2+z2*z2;
78 |                 if (k==0 || d < best){
79 |                     best = d;
80 |                     besti = k;
81 |                 }
82 |             }
83 |             dist[i*n+j] = best;
84 |             idx[i*n+j] = besti;
85 |         }
86 |     }
87 | }
88 | 
89 | 
90 | void chamfer_distance_forward(
91 |     const at::Tensor xyz1,
92 |     const at::Tensor xyz2,
93 |     const at::Tensor dist1,
94 |     const at::Tensor dist2,
95 |     const at::Tensor idx1,
96 |     const at::Tensor idx2)
97 | {
98 |     const int batchsize = xyz1.size(0);
99 |     const int n = xyz1.size(1);
100 |     const int m = xyz2.size(1);
101 | 
102 |     const float* xyz1_data = xyz1.data<float>();
103 |     const float* xyz2_data = xyz2.data<float>();
104 |     float* dist1_data = dist1.data<float>();
105 |     float* dist2_data = dist2.data<float>();
106 |     int* idx1_data = idx1.data<int>();
107 |     int* idx2_data = idx2.data<int>();
108 | 
109 |     nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
110 |     nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
111 | }
112 | 
113 | 
114 | void chamfer_distance_backward(
115 |     const at::Tensor xyz1,
116 |     const at::Tensor xyz2,
117 |     at::Tensor gradxyz1,
118 |     at::Tensor gradxyz2,
119 |     at::Tensor graddist1,
120 |     at::Tensor graddist2,
121 |     at::Tensor idx1,
122 |     at::Tensor idx2)
123 | {
124 |     const int b = xyz1.size(0);
125 |     const int n = xyz1.size(1);
126 |     const int m = xyz2.size(1);
127 | 
128 |     const float* xyz1_data = xyz1.data<float>();
129 |     const float* xyz2_data = xyz2.data<float>();
130 |     float* gradxyz1_data = gradxyz1.data<float>();
131 |     float* gradxyz2_data = gradxyz2.data<float>();
132 |     float* graddist1_data = graddist1.data<float>();
133 |     float* graddist2_data = graddist2.data<float>();
134 |     const int* idx1_data = idx1.data<int>();
135 |     const int* idx2_data = idx2.data<int>();
136 | 
137 |     for (int i = 0; i < b*n*3; i++)
138 |         gradxyz1_data[i] = 0;
139 |     for (int i = 0; i < b*m*3; i++)
140 |         gradxyz2_data[i] = 0;
141 |     for (int i = 0; i < b; i++) {
142 |         for (int j = 0; j < n; j++) {
143 |             const float x1 = xyz1_data[(i*n+j)*3+0];
144 |             const float y1 = xyz1_data[(i*n+j)*3+1];
145 |             const float z1 = xyz1_data[(i*n+j)*3+2];
146 |             const int j2 = idx1_data[i*n+j];
147 | 
148 |             const float x2 = xyz2_data[(i*m+j2)*3+0];
149 |             const float y2 = xyz2_data[(i*m+j2)*3+1];
150 |             const float z2 = xyz2_data[(i*m+j2)*3+2];
151 |             const float g = graddist1_data[i*n+j]*2;
152 | 
153 |             gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
154 |             gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
155 |             gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
156 |             gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
157 |             gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
158 |             gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
159 |         }
160 |         for (int j = 0; j < m; j++) {
161 |             const float x1 = xyz2_data[(i*m+j)*3+0];
162 |             const float y1 = xyz2_data[(i*m+j)*3+1];
163 |             const float z1 = xyz2_data[(i*m+j)*3+2];
164 |             const int j2 = idx2_data[i*m+j];
165 |             const float x2 = xyz1_data[(i*n+j2)*3+0];
166 |             const float y2 = xyz1_data[(i*n+j2)*3+1];
167 |             const float z2 = xyz1_data[(i*n+j2)*3+2];
168 |             const float g = graddist2_data[i*m+j]*2;
169 |             gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
170 |             gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
171 |             gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
172 |             gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
173 |             gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
174 |             gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
175 |         }
176 |     }
177 | }
178 | 
179 | 
180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
181 |     m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
182 |     m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
183 |     m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
184 |     m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
185 | }
186 | 
--------------------------------------------------------------------------------
/utils/chamfer_distance/chamfer_distance.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 | 
6 | __global__
7 | void ChamferDistanceKernel(
8 |     int b,
9 |     int n,
10 |     const float* xyz,
11 |     int m,
12 |     const float* xyz2,
13 |     float* result,
14 |     int* result_i)
15 | {
16 |     const int batch=512;
17 |     __shared__ float buf[batch*3];
18 |     for (int i=blockIdx.x;i<b;i+=gridDim.x){
19 |         for (int k2=0;k2<m;k2+=batch){
20 |             // stage a tile of the second point cloud in shared memory
21 |             int end_k=min(m,k2+batch)-k2;
22 |             for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
23 |                 buf[j]=xyz2[(i*m+k2)*3+j];
24 |             }
25 |             __syncthreads();
26 |             for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
27 |                 const float x1=xyz[(i*n+j)*3+0];
28 |                 const float y1=xyz[(i*n+j)*3+1];
29 |                 const float z1=xyz[(i*n+j)*3+2];
30 |                 float best=0;
31 |                 int best_i=0;
32 |                 // nearest neighbor within the staged tile
33 |                 for (int k=0;k<end_k;k++){
34 |                     const float x2=buf[k*3+0]-x1;
35 |                     const float y2=buf[k*3+1]-y1;
36 |                     const float z2=buf[k*3+2]-z1;
37 |                     const float d=x2*x2+y2*y2+z2*z2;
38 |                     if (k==0 || d<best){
39 |                         best=d;
40 |                         best_i=k+k2;
41 |                     }
42 |                 }
43 |                 if (k2==0 || result[(i*n+j)]>best){
44 |                     result[(i*n+j)]=best;
45 |                     result_i[(i*n+j)]=best_i;
46 |                 }
47 |             }
48 |             __syncthreads();
49 |         }
50 |     }
51 | }
52 | 
53 | void ChamferDistanceKernelLauncher(
54 |     const int b, const int n,
55 |     const float* xyz,
56 |     const int m,
57 |     const float* xyz2,
58 |     float* result,
59 |     int* result_i,
60 |     float* result2,
61 |     int* result2_i)
62 | {
63 |     ChamferDistanceKernel<<<dim3(32,16,1), 512>>>(b, n, xyz, m, xyz2, result, result_i);
64 |     ChamferDistanceKernel<<<dim3(32,16,1), 512>>>(b, m, xyz2, n, xyz, result2, result2_i);
65 | 
66 |     cudaError_t err = cudaGetLastError();
67 |     if (err != cudaSuccess)
68 |         printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
69 | }
70 | 
71 | 
72 | __global__
73 | void ChamferDistanceGradKernel(
74 |     int b, int n,
75 |     const float* xyz1,
76 |     int m,
77 |     const float* xyz2,
78 |     const float* grad_dist1,
79 |     const int* idx1,
80 |     float* grad_xyz1,
81 |     float* grad_xyz2)
82 | {
83 |     for (int i = blockIdx.x; i < b; i += gridDim.x) {
84 |         for (int j = threadIdx.x+blockIdx.y*blockDim.x; j < n; j += blockDim.x*gridDim.y) {
85 |             const float x1 = xyz1[(i*n+j)*3+0];
86 |             const float y1 = xyz1[(i*n+j)*3+1];
87 |             const float z1 = xyz1[(i*n+j)*3+2];
88 |             const int j2 = idx1[i*n+j];
89 | 
90 |             const float x2 = xyz2[(i*m+j2)*3+0];
91 |             const float y2 = xyz2[(i*m+j2)*3+1];
92 |             const float z2 = xyz2[(i*m+j2)*3+2];
93 |             const float g = grad_dist1[i*n+j]*2;
94 | 
95 |             // scatter into the second cloud's gradient, so atomics are required
96 |             atomicAdd(grad_xyz1+(i*n+j)*3+0, g*(x1-x2));
97 |             atomicAdd(grad_xyz1+(i*n+j)*3+1, g*(y1-y2));
98 |             atomicAdd(grad_xyz1+(i*n+j)*3+2, g*(z1-z2));
99 |             atomicAdd(grad_xyz2+(i*m+j2)*3+0, -(g*(x1-x2)));
100 |             atomicAdd(grad_xyz2+(i*m+j2)*3+1, -(g*(y1-y2)));
101 |             atomicAdd(grad_xyz2+(i*m+j2)*3+2, -(g*(z1-z2)));
102 |         }
103 |     }
104 | }
105 | 
106 | void ChamferDistanceGradKernelLauncher(
107 |     const int b, const int n,
108 |     const float* xyz1,
109 |     const int m,
110 |     const float* xyz2,
111 |     const float* grad_dist1,
112 |     const int* idx1,
113 |     const float* grad_dist2,
114 |     const int* idx2,
115 |     float* grad_xyz1,
116 |     float* grad_xyz2)
117 | {
118 |     cudaMemset(grad_xyz1, 0, b*n*3*sizeof(float));
119 |     cudaMemset(grad_xyz2, 0, b*m*3*sizeof(float));
120 |     ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
121 |     ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
122 | 
123 |     cudaError_t err = cudaGetLastError();
124 |     if (err != cudaSuccess)
125 |         printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
126 | }
127 | 
--------------------------------------------------------------------------------
/utils/chamfer_distance/chamfer_distance.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/chrdiller/pyTorchChamferDistance
2 | 
3 | import torch
4 | 
5 | from torch.utils.cpp_extension import load
6 | cd = load(name="cd",
7 |           sources=["utils/chamfer_distance/chamfer_distance.cpp",
8 |                    "utils/chamfer_distance/chamfer_distance.cu"])
"utils/chamfer_distance/chamfer_distance.cu"]) 9 | 10 | class ChamferDistanceFunction(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, xyz1, xyz2): 13 | batchsize, n, _ = xyz1.size() 14 | _, m, _ = xyz2.size() 15 | xyz1 = xyz1.contiguous() 16 | xyz2 = xyz2.contiguous() 17 | dist1 = torch.zeros(batchsize, n) 18 | dist2 = torch.zeros(batchsize, m) 19 | 20 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 21 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 22 | 23 | if not xyz1.is_cuda: 24 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 25 | else: 26 | dist1 = dist1.cuda() 27 | dist2 = dist2.cuda() 28 | idx1 = idx1.cuda() 29 | idx2 = idx2.cuda() 30 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 31 | 32 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 33 | 34 | return dist1, dist2 35 | 36 | @staticmethod 37 | def backward(ctx, graddist1, graddist2): 38 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 39 | 40 | graddist1 = graddist1.contiguous() 41 | graddist2 = graddist2.contiguous() 42 | 43 | gradxyz1 = torch.zeros(xyz1.size()) 44 | gradxyz2 = torch.zeros(xyz2.size()) 45 | 46 | if not graddist1.is_cuda: 47 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 48 | else: 49 | gradxyz1 = gradxyz1.cuda() 50 | gradxyz2 = gradxyz2.cuda() 51 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 52 | 53 | return gradxyz1, gradxyz2 54 | 55 | 56 | class ChamferDistance(torch.nn.Module): 57 | def forward(self, xyz1, xyz2): 58 | return ChamferDistanceFunction.apply(xyz1, xyz2) -------------------------------------------------------------------------------- /utils/file_io.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/lightfield-analysis/python-tools 2 | 3 | import configparser 4 | import os 5 | import sys 6 | import re 7 | 8 | import numpy as np 9 | 10 | 11 | def read_lightfield(data_folder): 12 | params = read_parameters(data_folder) 13 | light_field = np.zeros((params["num_cams_x"], params["num_cams_y"], params["height"], params["width"], 3), dtype=np.uint8) 14 | 15 | views = sorted([f for f in os.listdir(data_folder) if f.startswith("input_") and f.endswith(".png")]) 16 | 17 | for idx, view in enumerate(views): 18 | fpath = os.path.join(data_folder, view) 19 | try: 20 | img = read_img(fpath) 21 | light_field[int(idx / params["num_cams_x"]), int(idx % params["num_cams_y"]), :, :, :] = img 22 | except IOError: 23 | print("Could not read input file: %s" % fpath) 24 | sys.exit() 25 | 26 | return light_field 27 | 28 | 29 | def read_parameters(data_folder): 30 | params = dict() 31 | 32 | with open(os.path.join(data_folder, "parameters.cfg"), "r") as f: 33 | parser = configparser.ConfigParser() 34 | parser.readfp(f) 35 | 36 | section = "intrinsics" 37 | params["width"] = int(parser.get(section, 'image_resolution_x_px')) 38 | params["height"] = int(parser.get(section, 'image_resolution_y_px')) 39 | params["focal_length_mm"] = float(parser.get(section, 'focal_length_mm')) 40 | params["sensor_size_mm"] = float(parser.get(section, 'sensor_size_mm')) 41 | params["fstop"] = float(parser.get(section, 'fstop')) 42 | 43 | section = "extrinsics" 44 | params["num_cams_x"] = int(parser.get(section, 'num_cams_x')) 45 | params["num_cams_y"] = int(parser.get(section, 'num_cams_y')) 46 | params["baseline_mm"] = float(parser.get(section, 'baseline_mm')) 47 | params["focus_distance_m"] = float(parser.get(section, 'focus_distance_m')) 48 | params["center_cam_x_m"] = 
49 |     params["center_cam_y_m"] = float(parser.get(section, 'center_cam_y_m'))
50 |     params["center_cam_z_m"] = float(parser.get(section, 'center_cam_z_m'))
51 |     params["center_cam_rx_rad"] = float(parser.get(section, 'center_cam_rx_rad'))
52 |     params["center_cam_ry_rad"] = float(parser.get(section, 'center_cam_ry_rad'))
53 |     params["center_cam_rz_rad"] = float(parser.get(section, 'center_cam_rz_rad'))
54 | 
55 |     section = "meta"
56 |     params["disp_min"] = float(parser.get(section, 'disp_min'))
57 |     params["disp_max"] = float(parser.get(section, 'disp_max'))
58 |     params["frustum_disp_min"] = float(parser.get(section, 'frustum_disp_min'))
59 |     params["frustum_disp_max"] = float(parser.get(section, 'frustum_disp_max'))
60 |     params["depth_map_scale"] = float(parser.get(section, 'depth_map_scale'))
61 | 
62 |     params["scene"] = parser.get(section, 'scene')
63 |     params["category"] = parser.get(section, 'category')
64 |     params["date"] = parser.get(section, 'date')
65 |     params["version"] = parser.get(section, 'version')
66 |     params["authors"] = parser.get(section, 'authors').split(", ")
67 |     params["contact"] = parser.get(section, 'contact')
68 | 
69 |     return params
70 | 
71 | 
72 | def read_depth(data_folder, highres=False):
73 |     fpath = os.path.join(data_folder, "gt_depth_%s.pfm" % ("highres" if highres else "lowres"))
74 |     try:
75 |         data = read_pfm(fpath)
76 |     except IOError:
77 |         # print("Could not read depth file: %s" % fpath)
78 |         return None
79 |     return np.ascontiguousarray(data)
80 | 
81 | 
82 | def read_depth_all_view(data_folder, N=81):
83 |     data = []
84 |     for i in range(N):
85 |         fpath = os.path.join(data_folder, "gt_depth_lowres_Cam%03d.pfm" % i)
86 |         try:
87 |             data_i = read_pfm(fpath)
88 |         except IOError:
89 |             print("Could not read depth file: %s" % fpath)
90 |             sys.exit()
91 | 
92 |         data.append(data_i)
93 | 
94 |     data = np.array(data)
95 |     return data
96 | 
97 | def read_disparity(data_folder, highres=False):
98 |     fpath = os.path.join(data_folder, "gt_disp_%s.pfm" % ("highres" if highres else "lowres"))
99 |     try:
100 |         data = read_pfm(fpath)
101 |     except IOError:
102 |         # print("Could not read disparity file: %s" % fpath)
103 |         return None
104 |     return data
105 | 
106 | def read_disparity_all_view(data_folder, N=81):
107 |     data = []
108 |     for i in range(N):
109 |         fpath = os.path.join(data_folder, "gt_disp_lowres_Cam%03d.pfm" % i)
110 |         try:
111 |             data_i = read_pfm(fpath)
112 |         except IOError:
113 |             print("Could not read disparity file: %s" % fpath)
114 |             sys.exit()
115 |         data.append(data_i)
116 |     data = np.array(data)
117 |     return data
118 | 
119 | 
120 | def read_img(fpath):
121 |     import imageio
122 | 
123 | 
124 |     data = imageio.imread(fpath)
125 |     return data
126 | 
127 | 
128 | def write_hdf5(data, fpath):
129 |     import h5py
130 |     h = h5py.File(fpath, 'w')
131 |     for key, value in data.items():
132 |         h.create_dataset(key, data=value)
133 |     h.close()
134 | 
135 | 
136 | def write_pfm(data, fpath, scale=1, file_identifier="Pf", dtype="float32"):
137 |     # PFM format definition: http://netpbm.sourceforge.net/doc/pfm.html
138 | 
139 |     data = np.flipud(data)
140 |     height, width = np.shape(data)[:2]
141 |     values = np.ndarray.flatten(np.asarray(data, dtype=dtype))
142 |     endianess = data.dtype.byteorder
143 | 
144 | 
145 |     if endianess == '<' or (endianess == '=' and sys.byteorder == 'little'):
146 |         scale *= -1
147 | 
148 |     with open(fpath, 'wb') as file:
149 |         file.write((file_identifier + '\n').encode())
150 |         file.write(('%d %d\n' % (width, height)).encode())
151 |         file.write(('%d\n' % scale).encode())
152 |         file.write(values)
153 | 
154 | def read_pfm(file):
155 |     file = open(file, 'rb')
156 |     color = None
157 |     width = None
158 |     height = None
159 |     scale = None
160 |     endian = None
161 | 
162 |     header = file.readline().rstrip()
163 | 
164 |     if header.decode("ascii") == 'PF':
165 |         color = True
166 | 
167 |     elif header.decode("ascii") == 'Pf':
168 |         color = False
169 | 
170 |     else:
171 |         raise Exception('Not a PFM file.')
172 | 
173 | 
174 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
175 | 
176 |     if dim_match:
177 |         width, height = list(map(int, dim_match.groups()))
178 | 
179 |     else:
180 |         raise Exception('Malformed PFM header.')
181 | 
182 |     scale = float(file.readline().decode("ascii").rstrip())
183 | 
184 |     if scale < 0:  # little-endian
185 |         endian = '<'
186 |         scale = -scale
187 | 
188 |     else:
189 |         endian = '>'  # big-endian
190 | 
191 |     data = np.fromfile(file, endian + 'f')
192 |     shape = (height, width, 3) if color else (height, width)
193 | 
194 |     data = np.reshape(data, shape)
195 |     data = np.flipud(data)
196 |     return data
197 | 
198 | 
199 | def _get_next_line(f):
200 |     next_line = f.readline().rstrip()
201 |     # ignore comments
202 |     while next_line.startswith('#'):
203 |         next_line = f.readline().rstrip()
204 |     return next_line
205 | 
206 | 
207 | 
--------------------------------------------------------------------------------
/utils/tof.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.constants import speed_of_light
4 | from itertools import product, combinations
5 | 
6 | 
7 | def sim_quad(depth, f, T, g, e):  # convert
8 |     """Simulate quad amplitude for 3D time-of-flight cameras
9 |     Args:
10 |         depth (tensor [B,1,H,W]): [depth map]
11 |         f (scalar): [frequency in Hz]
12 |         T (scalar): [integration time. metric not defined]
13 |         g (scalar): [gain of the sensor. metric not defined]
14 |         e (tensor [B,1,H,W]): [number of electrons created by a photon incident to the sensor. metric not defined]
15 | 
16 |     Returns:
17 |         amplitudes (tensor [B,4,H,W]): [Signal amplitudes at 4 phase offsets.]
18 |     """
19 | 
20 |     tru_phi = depth2phase(depth, f)
21 | 
22 |     A0 = g*e*(0.5+torch.cos(tru_phi)/np.pi)*T
23 |     A1 = g*e*(0.5-torch.sin(tru_phi)/np.pi)*T
24 |     A2 = g*e*(0.5-torch.cos(tru_phi)/np.pi)*T
25 |     A3 = g*e*(0.5+torch.sin(tru_phi)/np.pi)*T
26 | 
27 |     return torch.stack([A0, A1, A2, A3], dim=1)
28 | 
29 | 
30 | def decode_quad(amplitudes, T, mT):  # convert
31 |     """Simulate solid-state time-of-flight range camera.
32 |     Args:
33 |         amplitudes (tensor [B,4,H,W]): [Signal amplitudes at 4 phase offsets. See sim_quad().]
34 |         T (scalar): [Integration time. Metric not defined.]
35 |         mT (scalar): [Modulation period]
36 | 
37 |     Returns:
38 |         phi_est, amplitude_est, offset_est (tuple(tensor [B,1,H,W])): [Estimated phi, amplitude, and offset]
39 |     """
40 |     assert amplitudes.shape[1] % 4 == 0
41 | 
42 |     A0, A1, A2, A3 = amplitudes[:,0::4,...], amplitudes[:,1::4,...], amplitudes[:,2::4,...], amplitudes[:,3::4,...]
43 |     sigma = np.pi * T / mT
44 | 
45 |     phi_est = torch.atan2((A3-A1),(A0-A2))
46 |     phi_est[phi_est<0] = phi_est[phi_est<0] + 2*np.pi
47 | 
48 |     amplitude_est = (sigma/T*np.sin(sigma)) * (( (A3-A1)**2 + (A0-A2)**2 )**0.5)/2
49 |     offset_est = (A0+A1+A2+A3)/(4*T)
50 | 
51 |     return phi_est, amplitude_est, offset_est
52 | 
53 | def depth2phase(depth, freq):  # convert
54 |     """Convert depth map to phase map.
55 |     Args:
56 |         depth (tensor [B,1,H,W]): Depth map (mm)
57 |         freq (scalar): Frequency (Hz)
58 | 
59 |     Returns:
60 |         phase (tensor [B,1,H,W]): Phase map (radian)
61 |     """
62 | 
63 |     tru_phi = (4*np.pi*depth*freq)/(1000*speed_of_light)
64 |     return tru_phi
65 | 
66 | def phase2depth(phase, freq):  # convert
67 |     """Convert phase map to depth map.
68 |     Args:
69 |         phase (tensor [B,1,H,W]): Phase map (radian)
70 |         freq (scalar): Frequency (Hz)
71 | 
72 |     Returns:
73 |         depth (tensor [B,1,H,W]): Depth map (mm)
74 |     """
75 | 
76 |     depth = (1000*speed_of_light*phase)/(4*np.pi*freq)
77 |     return depth
78 | 
79 | def unwrap_ranking(phiList, f_list, min_depth=5000, max_depth=10000):
80 | 
81 |     """Efficient Multi-Frequency Phase Unwrapping using Kernel Density Estimation
82 |     Args:
83 |         phiList (list(K x tensor [B,H,W])): K different wrapped phases measured at the given frequencies f_list
84 |         f_list (list [K]): K different frequencies in Hz.
85 |         min_depth (scalar): min depth in mm
86 |         max_depth (scalar): max depth in mm
87 |     Returns:
88 |         depth (tensor [B,H,W]): Unwrapped depth map (mm)
89 |     """
90 | 
91 |     B,H,W = phiList[0].shape
92 |     device = phiList[0].device
93 | 
94 |     # Compute the range of potential n (wraps)
95 |     N = len(f_list)
96 |     min_n = torch.zeros((N), dtype=torch.int64)
97 |     max_n = torch.zeros((N), dtype=torch.int64)
98 |     for i in range(N):
99 |         min_phase = depth2phase(min_depth, f_list[i])
100 |         max_phase = depth2phase(max_depth, f_list[i])
101 |         min_n[i] = np.floor(min_phase/(2*np.pi))
102 |         max_n[i] = np.floor(max_phase/(2*np.pi))
103 | 
104 |     n_list = []
105 |     for i in range(N):
106 |         n_list_ = []
107 |         for n in range(min_n[i], max_n[i]+1, 1):
108 |             n_list_.append(n)
109 |         n_list.append(n_list_)
110 | 
111 | 
112 |     prod = list(product(*n_list))
113 |     M = len(prod)  # number of potential combinations
114 | 
115 |     k = np.lcm.reduce(np.array(f_list).astype(np.int64))/f_list
116 | 
117 |     t = []
118 |     for i in range(N):
119 |         t.append( (phiList[i]/(2*np.pi))*k[i] )
120 |     t = torch.stack(t)  # [N, B, H, W]
121 | 
122 |     err_list = torch.zeros((M,B,H,W), device=device)
123 |     for i in range(M):
124 |         n = prod[i]
125 | 
126 |         pairs = list(combinations(n, 2))
127 |         kpairs = list(combinations((np.linspace(0,N-1,N)).astype(np.int64), 2))
128 |         for j, pair in enumerate(pairs):
129 |             kpair = kpairs[j]
130 |             k1 = k[kpair[0]]
131 |             k2 = k[kpair[1]]
132 |             n1 = pair[0]
133 |             n2 = pair[1]
134 | 
135 |             err = (k1*n1 - k2*n2 - (t[kpair[1]] - t[kpair[0]]))**2
136 |             w = min(1/(((k1/(2*np.pi))**2 + (k2/(2*np.pi))**2)*0.16871118634340782), 10.0)
137 |             err_list[i,:,:,:] += err * w
138 | 
139 |     # sort and select the top three
140 |     ind_sorted = torch.argsort(err_list, axis=0)
141 |     ind_min = ind_sorted[0]
142 |     ind_second = ind_sorted[1]
143 |     ind_third = ind_sorted[2]
144 | 
145 |     # get the best one
146 |     n_list_best = torch.zeros((N,B,H,W), device=device)
147 | 
148 |     for i in range(M):
149 |         m = (ind_min == i)
150 |         for j in range(N):
151 |             n_list_best_i = torch.zeros((B,H,W), device=device)
152 |             n_list_best_i[m] = prod[i][j]
153 |             n_list_best[j,...] += n_list_best_i
154 | 
155 |     # simple phase unwrapping with the ranking
156 |     up = []
157 |     for i in range(N):
158 |         up_i = phiList[i]+2*np.pi*n_list_best[i,...]
159 |         up.append(up_i)
160 |     up = torch.stack(up)
161 | 
162 |     depth = []
163 |     for i in range(N):
164 |         depth_i = phase2depth(up[i], f_list[i])
165 |         depth.append(depth_i)
166 | 
167 |     depth = torch.stack(depth)
168 | 
169 |     depth = depth.mean(axis=0)
170 |     return depth
171 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import time
4 | import torch
5 | import numpy as np
6 | import torchvision.utils as vutils
7 | from glob import glob
8 | import logging
9 | import optparse
10 | import os
11 | 
12 | 
13 | def tensor_to_pcd(arr):
14 |     B, C, H, W = arr.shape
15 |     arr = arr.squeeze().reshape(B,H*W,1)
16 |     x = torch.arange(0,W, device=arr.device)
17 |     x = x[None,:,None].repeat(B,H,1)
18 |     y = torch.arange(0,H, device=arr.device)
19 |     y = y[None,:,None].repeat(B,1,1).repeat_interleave(W, dim=1)
20 |     pcd = torch.cat([x,y,arr], dim=2)
21 |     return pcd
22 | 
23 | def read_text_lines(filepath):
24 |     with open(filepath, "r") as f:
25 |         lines = f.readlines()
26 |     lines = [l.rstrip() for l in lines]
27 |     return lines
28 | 
29 | 
30 | def check_path(path):
31 |     if not os.path.exists(path):
32 |         os.makedirs(path, exist_ok=True)  # explicitly set exist_ok when multi-processing
33 | 
34 | def save_args(args, filename="args.json"):
35 |     args_dict = vars(args)
36 |     check_path(args.checkpoint_dir)
37 |     save_path = os.path.join(args.checkpoint_dir, filename)
38 | 
39 |     with open(save_path, "w") as f:
40 |         json.dump(args_dict, f, indent=4, sort_keys=False)
41 | 
42 | def count_parameters(net):
43 |     num = sum(p.numel() for p in net.parameters() if p.requires_grad)
44 |     return num
45 | 
46 | def save_checkpoint(save_path, optimizer, net, epoch, num_iter,
47 |                     loss, mask, filename=None, save_optimizer=True):
48 |     # Network
49 |     net_state = {
50 |         "epoch": epoch,
51 |         "num_iter": num_iter,
52 |         "loss": loss,
53 |         "state_dict": net.state_dict()
54 |     }
55 |     net_filename = "net_epoch_{:0>3d}.pt".format(epoch) if filename is None else filename
56 |     net_save_path = os.path.join(save_path, net_filename)
57 |     torch.save(net_state, net_save_path)
58 | 
59 |     mask_name = net_filename.replace("net", "mask")
60 |     mask_save_path = os.path.join(save_path, mask_name)
61 |     torch.save(mask, mask_save_path)
62 | 
63 |     # Optimizer
64 |     if save_optimizer:
65 |         optimizer_state = {
66 |             "epoch": epoch,
67 |             "num_iter": num_iter,
68 |             "loss": loss,
69 |             "state_dict": optimizer.state_dict()
70 |         }
71 |         optimizer_name = net_filename.replace("net", "optimizer")
72 |         optimizer_save_path = os.path.join(save_path, optimizer_name)
73 |         torch.save(optimizer_state, optimizer_save_path)
74 | 
75 | 
76 | def load_checkpoint(net, pretrained_path, return_epoch_iter=False, resume=False, no_strict=False):
77 |     if pretrained_path is not None:
78 |         if torch.cuda.is_available():
79 |             state = torch.load(pretrained_path, map_location="cuda")
80 |         else:
81 |             state = torch.load(pretrained_path, map_location="cpu")
82 | 
83 |         net.load_state_dict(state["state_dict"])  # optimizer has no argument `strict`
84 | 
85 |         if return_epoch_iter:
86 |             epoch = state["epoch"] if "epoch" in state.keys() else None
87 |             num_iter = state["num_iter"] if "num_iter" in state.keys() else None
88 |             return epoch, num_iter
89 | 
90 | 
91 | def resume_latest_ckpt(checkpoint_dir, net, net_name):
92 |     ckpts = sorted(glob(checkpoint_dir + "/" + net_name + "*.pt"))
93 | 
94 |     if len(ckpts) == 0:
95 |         raise RuntimeError("=> No checkpoint found while resuming training")
resuming training") 96 | 97 | latest_ckpt = ckpts[-1] 98 | print("=> Resume latest {0} checkpoint: {1}".format(net_name, os.path.basename(latest_ckpt))) 99 | epoch, num_iter = load_checkpoint(net, latest_ckpt, True, True) 100 | 101 | return epoch, num_iter 102 | 103 | def save_images(logger, mode_tag, images_dict, global_step): 104 | images_dict = tensor2numpy(images_dict) 105 | for tag, values in images_dict.items(): 106 | if not isinstance(values, list) and not isinstance(values, tuple): 107 | values = [values] 108 | for idx, value in enumerate(values): 109 | if len(value.shape) == 3: 110 | value = value[:, np.newaxis, :, :] 111 | value = value[:1] 112 | value = torch.from_numpy(value) 113 | 114 | image_name = "{}/{}".format(mode_tag, tag) 115 | if len(values) > 1: 116 | image_name = image_name + "_" + str(idx) 117 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), 118 | global_step) 119 | 120 | def tensor2numpy(var_dict): 121 | for key, vars in var_dict.items(): 122 | if isinstance(vars, np.ndarray): 123 | var_dict[key] = vars 124 | elif isinstance(vars, torch.Tensor): 125 | var_dict[key] = vars.data.cpu().numpy() 126 | else: 127 | raise NotImplementedError("invalid input type for tensor2numpy") 128 | 129 | return var_dict 130 | 131 | def get_all_data_folders(base_dir=None): 132 | if base_dir is None: 133 | base_dir = os.getcwd() 134 | 135 | data_folders = [] 136 | categories = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))] 137 | for category in categories: 138 | for scene in os.listdir(os.path.join(base_dir, category)): 139 | data_folder = os.path.join(*[base_dir, category, scene]) 140 | if os.path.isdir(data_folder): 141 | data_folders.append(data_folder) 142 | 143 | return data_folders 144 | 145 | 146 | def get_comma_separated_args(option, opt, value, parser): 147 | values = [v.strip() for v in value.split(",")] 148 | setattr(parser.values, option.dest, values) 149 | 150 | 151 | def parse_options(): 152 | parser = optparse.OptionParser() 153 | parser.add_option("-d", "--date_folder", type="string", action="callback", callback=get_comma_separated_args, 154 | dest="data_folders", help="e.g. 
155 |     options, remainder = parser.parse_args()
156 | 
157 |     if options.data_folders is None:
158 |         options.data_folders = get_all_data_folders(os.getcwd())
159 |     else:
160 |         options.data_folders = [os.path.abspath("%s" % d) for d in options.data_folders]
161 |         for f in options.data_folders:
162 |             print(f)
163 | 
164 |     return options.data_folders
165 | 
166 | 
167 | ############################ MASK OPERATIONS ############################
168 | def combine_masks(masks):
169 |     if len(masks.shape) == 3:
170 |         masks = masks.reshape(9,9,*masks.shape[1:])
171 |     C1, C2, H, W = masks.shape
172 |     combined_mask = torch.zeros(C1*H,C2*W, device=masks.device)
173 |     for i in range(C1):
174 |         for j in range(C2):
175 |             combined_mask[i::C1,j::C2] = masks[i,j]  # optionally normalize: /masks[i,j].norm()
176 |     return combined_mask
177 | 
178 | def un_combine_masks(combined_mask, shape):
179 |     C1, C2, H, W = shape
180 |     masks = []
181 |     for i in range(C1):
182 |         for j in range(C2):
183 |             masks.append(combined_mask[i::C1,j::C2])
184 |     return torch.stack(masks)
185 | 
186 | def gkern(l=5, sig=1.):
187 |     """\
188 |     creates gaussian kernel with side length l and a sigma of sig
189 |     """
190 | 
191 |     ax = np.linspace(-(l - 1) / 2., (l - 1) / 2., l)
192 |     xx, yy = np.meshgrid(ax, ax)
193 |     kernel = np.exp(-0.5 * (np.square(xx) + np.square(yy)) / np.square(sig))
194 |     return kernel
195 | 
196 | def gkern_mask(kernel_mean, kernel_sigma, shape=(9,9,64,64)):
197 |     C1, C2, H, W = shape
198 |     mask = np.zeros((C1*H, C2*W))
199 |     assert C1 == C2
200 |     for i in range(H):
201 |         for j in range(W):
202 |             sig = np.random.normal(kernel_mean, kernel_sigma)
203 |             kernel = gkern(C1, sig)
204 |             mask[i*C1:(i+1)*C1, j*C2:(j+1)*C2] = kernel
205 |     return torch.from_numpy(mask)
206 | 
207 | 
--------------------------------------------------------------------------------
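
The ToF utilities in `utils/tof.py` form a self-inverting pipeline, which makes for a quick sanity check: `sim_quad` renders four correlation samples from a depth map, `decode_quad` recovers the wrapped phase, and `phase2depth` maps that phase back to millimeters. A minimal sketch, not part of the repository (the frequency/gain/integration values mirror the `train.py` defaults, and 1000 mm sits inside the ~1.5 m unambiguous range at 100 MHz, so no unwrapping is needed):

```python
import torch
from utils.tof import sim_quad, decode_quad, phase2depth

depth = torch.full((1, 8, 8), 1000.0)      # toy scene: 1000 mm everywhere, [B, H, W]
f, T, g, mT = 100e6, 1000.0, 20.0, 2000.0  # mirrors the train.py defaults
e = torch.ones_like(depth)                 # unit photoelectron count, noise-free

amps = sim_quad(depth, f, T, g, e)         # [1, 4, 8, 8] correlation samples
phi, amp, offset = decode_quad(amps, T, mT)
recovered = phase2depth(phi, f)            # ~1000 mm everywhere
```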
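
The chamfer loss path composes similarly: `tensor_to_pcd` lifts a `[B, 1, H, W]` depth map into a `[B, H*W, 3]` cloud of `(x, y, depth)` points, which `ChamferDistance` compares bidirectionally. A hedged sketch of that wiring, illustrative rather than lifted from `model.py` (importing the module JIT-compiles the C++/CUDA extension, so run from the repository root with a compiler toolchain available; the `.float()` casts and mean-reduction are illustrative choices):

```python
import torch
from utils.utils import tensor_to_pcd
from utils.chamfer_distance import ChamferDistance  # compiles the extension on import

pred_depth = torch.rand(2, 1, 16, 16)         # predicted depth, [B, 1, H, W]
gt_depth = torch.rand(2, 1, 16, 16)           # ground-truth depth, same shape

pred_pcd = tensor_to_pcd(pred_depth).float()  # [B, H*W, 3] points (x, y, depth)
gt_pcd = tensor_to_pcd(gt_depth).float()

chamfer = ChamferDistance()
dist1, dist2 = chamfer(pred_pcd, gt_pcd)      # per-point squared NN distances, both directions
loss = dist1.mean() + dist2.mean()            # symmetric chamfer term, cf. --chamfer_weight
```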