├── .gitignore ├── ldm ├── modules │ ├── losses │ │ ├── __init__.py │ │ └── contperceptual.py │ ├── ema.py │ ├── distributions │ │ └── distributions.py │ ├── attention.py │ └── diffusionmodules │ │ └── util.py ├── lr_scheduler.py ├── util.py └── models │ └── diffusion │ └── ddim.py ├── data ├── sample │ ├── image.exr │ ├── mask.png │ └── normal.npy ├── sample.obj └── datalists │ ├── mvs_ortho_synth_refmap │ ├── sparsemaskobjects_val.txt │ └── sparsemaskobjects_test.txt │ ├── DeepRelighting_shape5000 │ ├── val_idx_mvs.txt │ ├── test_idx_mvs.txt │ └── shapes_val.txt │ └── LavalIndoor+PolyHaven_2k │ ├── envs_val.txt │ └── envs_test.txt ├── assets └── drmnet_overall.gif ├── utils ├── tonemap.py ├── img2refmap.py ├── file_io.py └── transform.py ├── models └── lr_scheduler.py ├── LICENSE ├── configs ├── obsnet │ ├── eval_obsnet.yaml │ ├── train_obsnet.yaml │ └── finetune_obsnet.yaml └── drmnet │ ├── eval_drmnet.yaml │ └── train_drmnet.yaml ├── scripts ├── preprocess_envmap.py ├── preprocess_shape.py └── estimate.py ├── environment ├── pip_freeze.txt └── drmnet_release.def ├── dataset ├── basedataset.py ├── parametricrefmap.py └── parametric_img2refmap.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *cache 3 | *.pt 4 | *.ckpt 5 | *.binary 6 | /logs -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /data/sample/image.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyotovision-public/DRMNet/HEAD/data/sample/image.exr -------------------------------------------------------------------------------- /data/sample/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyotovision-public/DRMNet/HEAD/data/sample/mask.png -------------------------------------------------------------------------------- /data/sample/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyotovision-public/DRMNet/HEAD/data/sample/normal.npy -------------------------------------------------------------------------------- /assets/drmnet_overall.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyotovision-public/DRMNet/HEAD/assets/drmnet_overall.gif -------------------------------------------------------------------------------- /data/sample.obj: -------------------------------------------------------------------------------- 1 | # Blender 3.6.0 2 | # www.blender.org 3 | o Cube 4 | v 1.000000 1.000000 -1.000000 5 | v 1.000000 -1.000000 -1.000000 6 | v 1.000000 1.000000 1.000000 7 | v 1.000000 -1.000000 1.000000 8 | v -1.000000 1.000000 -1.000000 9 | v -1.000000 -1.000000 -1.000000 10 | v -1.000000 1.000000 1.000000 11 | v -1.000000 -1.000000 1.000000 12 | s 0 13 | f 5 3 1 14 | f 3 8 4 15 | f 7 6 8 16 | f 2 8 6 17 | f 1 4 2 18 | f 5 2 6 19 | f 5 7 3 20 | f 3 7 8 21 | f 7 5 6 22 | f 2 4 8 23 | f 1 3 4 24 | f 5 1 2 25 | -------------------------------------------------------------------------------- /utils/tonemap.py: 
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def hdr2ldr(x: np.ndarray, mask: np.ndarray = None, alpha=0.18, gamma=2.2) -> np.ndarray:
5 |     L = 0.212671 * x[:, :, 0] + 0.715160 * x[:, :, 1] + 0.072169 * x[:, :, 2]
6 |     mask = np.logical_and(mask, L > 5e-5) if mask is not None else L > 5e-5
7 |     assert mask.ndim == 2
8 |     coeff = alpha / np.exp((np.log(L.clip(0) + 1e-7) * mask).sum() / mask.sum())
9 |     return (x * coeff).clip(0, 1) ** (1 / gamma)
10 | 
11 | 
12 | if __name__ == "__main__":
13 |     import sys
14 |     from pathlib import Path
15 | 
16 |     from file_io import load_exr, save_png
17 | 
18 |     work_dir = Path(sys.argv[1])
19 | 
20 |     for path in work_dir.glob("*.hdr"):
21 |         img = load_exr(path)
22 |         img = hdr2ldr(img)
23 |         # save_png rescales the [0, 1] floats to 8-bit; writing floats straight to .png would produce a black image
24 |         save_png(path.with_suffix(".png"), img)
25 | 
--------------------------------------------------------------------------------
/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | class LambdaWarmUpScheduler:
2 |     """
3 |     note: use with a base_lr of 1.0
4 |     """
5 | 
6 |     def __init__(self, warm_up_steps, lr_start, lr_end, verbosity_interval=0):
7 |         self.lr_warm_up_steps = warm_up_steps
8 |         self.lr_start = lr_start
9 |         self.lr_end = lr_end
10 |         self.last_lr = 0.0
11 |         self.verbosity_interval = verbosity_interval
12 | 
13 |     def schedule(self, n, **kwargs):
14 |         if self.verbosity_interval > 0:
15 |             if n % self.verbosity_interval == 0:
16 |                 print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
17 |         if n < self.lr_warm_up_steps:
18 |             lr = (self.lr_end - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
19 |             self.last_lr = lr
20 |             return lr
21 |         else:
22 |             lr = self.lr_end
23 |             self.last_lr = lr
24 |             return lr
25 | 
26 |     def __call__(self, n, **kwargs):
27 |         return self.schedule(n, **kwargs)
28 | 
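
Note: `LambdaWarmUpScheduler` returns a learning-rate *multiplier*, which is why
its docstring says to use a base_lr of 1.0. A minimal wiring sketch with
PyTorch's `LambdaLR`; the model and optimizer below are illustrative stand-ins,
not code from this repository (the warm-up values match the training configs):

    import torch
    from models.lr_scheduler import LambdaWarmUpScheduler

    model = torch.nn.Linear(8, 8)                           # stand-in module
    opt = torch.optim.AdamW(model.parameters(), lr=5.0e-5)  # the actual target lr
    warmup = LambdaWarmUpScheduler(warm_up_steps=5000, lr_start=0.0, lr_end=1.0)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup)
    for step in range(10000):
        opt.step()
        sched.step()  # effective lr = 5.0e-5 * warmup(step): linear 0 -> 1 over 5000 steps, then constant
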
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 kyotovision-public
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/configs/obsnet/eval_obsnet.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   target: models.obsnet.ObsNetDiffusion
3 |   params:
4 |     linear_start: 0.0001
5 |     linear_end: 0.09
6 |     log_every_t: 2000
7 |     timesteps: 1000
8 |     loss_type: l2
9 |     first_stage_key: LrK
10 |     cond_stage_key: raw_refmap
11 |     padding_mode: noise
12 |     image_size: 128
13 |     channels: 3
14 |     concat_mode: true
15 |     ddim_steps: 50
16 |     monitor: val/loss
17 |     clip_denoised: false
18 |     masked_loss: false
19 |     obj_img_key: img
20 |     ckpt_path: ./checkpoints/obsnet.ckpt
21 |     init_from_ckpt_verbose: true
22 | 
23 |     unet_config:
24 |       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
25 |       params:
26 |         image_size: 128
27 |         in_channels: 6
28 |         out_channels: 3
29 |         model_channels: 128
30 |         attention_resolutions: [4, 8, 16]
31 |         num_res_blocks: 2
32 |         channel_mult: [1, 2, 3, 4, 5]
33 |         num_heads: 1
34 |         resblock_updown: False
35 |         conv_resample: False
36 | 
37 | data:
38 |   target: main.DataModuleFromConfig
39 |   params:
40 |     batch_size: 1
41 |     num_workers: 5
42 |     predict:
43 |       target: dataset.basedataset.BaseDataset
44 |       params:
45 |         size: 128
46 |         transform_func: resize_0p1tom1p1_normalizedLogarithmic_lowerbound1e-6
47 |         clamp_before_exp: 20
--------------------------------------------------------------------------------
/scripts/preprocess_envmap.py:
--------------------------------------------------------------------------------
1 | """collect and reshape environment maps"""
2 | 
3 | import argparse
4 | import sys
5 | from pathlib import Path
6 | from typing import List, Tuple
7 | 
8 | import cv2
9 | from tqdm import tqdm
10 | 
11 | sys.path.append(str(Path(__file__).parent.parent))
12 | 
13 | from utils.file_io import load_exr, save_exr
14 | 
15 | if __name__ == "__main__":
16 |     parser = argparse.ArgumentParser()
17 |     parser.add_argument("source_dirs", nargs="+", type=Path, help="The paths of directories containing environment maps")
18 |     parser.add_argument("--output_dir", "-o", type=Path, help="The output directory", default=Path("./data/LavalIndoor+PolyHaven_2k/"))
19 |     parser.add_argument("--resolution", type=str, help="The target resolution as WxH", default="2000x1000")
20 |     args = parser.parse_args()
21 | 
22 |     source_dirs: List[Path] = args.source_dirs
23 |     output_dir: Path = args.output_dir
24 |     output_dir.mkdir(exist_ok=True)
25 |     # cv2.resize takes dsize as (width, height), matching the WxH order of the argument
26 |     resolution: Tuple[int, int] = tuple(int(i) for i in args.resolution.split("x"))
27 | 
28 |     for source_dir in source_dirs:
29 |         _suffix = [".hdr", ".exr"]
30 |         for envmap_path in tqdm(sorted(source_dir.glob("*.*"))):
31 |             if envmap_path.suffix not in _suffix:
32 |                 continue
33 |             envmap = load_exr(envmap_path)
34 |             envmap = cv2.resize(envmap, resolution, interpolation=cv2.INTER_AREA)
35 |             save_exr(output_dir.joinpath(envmap_path.name).with_suffix(".exr"), envmap)
36 | 
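
A typical invocation, with placeholder source directories:
`python scripts/preprocess_envmap.py /path/to/LavalIndoor /path/to/PolyHaven -o ./data/LavalIndoor+PolyHaven_2k/ --resolution 2000x1000`.
Every .hdr/.exr map found in the given directories is area-resampled to the
target resolution and written back as .exr under the output directory.
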
--------------------------------------------------------------------------------
/scripts/preprocess_shape.py:
--------------------------------------------------------------------------------
1 | """preprocess the shapes of [Xu et al.](https://cseweb.ucsd.edu/~viscomp/projects/SIG18Relighting/)"""
2 | 
3 | import argparse
4 | import sys
5 | from pathlib import Path
6 | 
7 | import mitsuba as mi
8 | import numpy as np
9 | import torch
10 | from tqdm import tqdm
11 | 
12 | mi.set_variant("cuda_ad_rgb")
13 | 
14 | sys.path.append(str(Path(__file__).parent.parent))
15 | 
16 | from utils.mitsuba3_utils import load_mesh
17 | 
18 | if __name__ == "__main__":
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument("source_root", type=Path, help="The root path of the original shape dataset Shapes_Multi_5000")
21 |     parser.add_argument("--output_dir", "-o", type=Path, help="The output directory", default=Path("./data/DeepRelighting_shape5000/"))
22 |     args = parser.parse_args()
23 | 
24 |     source_root: Path = args.source_root
25 |     output_dir: Path = args.output_dir
26 |     output_dir.mkdir(exist_ok=True)
27 | 
28 |     mesh = mi.load_dict({"type": "obj", "filename": "./data/sample.obj"})
29 |     params = mi.traverse(mesh)
30 | 
31 |     num_shape = len(sorted(source_root.glob("Shape__*")))
32 |     for i in tqdm(range(num_shape)):
33 |         obj_path = source_root / f"Shape__{i}/object.obj"
34 |         obj_dict = load_mesh(obj_path)
35 | 
36 |         vertex_positions = obj_dict["vertex_positions"].torch().view(-1, 3)
37 |         # normalize the size of mesh
38 |         vertex_positions *= 0.9 / torch.linalg.vector_norm(vertex_positions, dim=-1).max()
39 |         # ensure there is no overflow on torch.int32
40 |         assert len(vertex_positions) < 2**31
41 | 
42 |         obj_dict["vertex_positions"] = vertex_positions.cpu()
43 |         obj_dict["vertex_normals"] = obj_dict["vertex_normals"].torch().view(-1, 3).cpu()
44 |         obj_dict["faces"] = torch.from_numpy(obj_dict["faces"].numpy().astype(np.int32)).view(-1, 3)
45 | 
46 |         torch.save(obj_dict, output_dir / f"Shape__{i}.pt")
47 | 
--------------------------------------------------------------------------------
/utils/img2refmap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .transform import xyz2thetaphi
4 | 
5 | 
6 | def refmap_mask_make(
7 |     colors: torch.Tensor,  # [n, 3]
8 |     normals: torch.Tensor,  # [n, 3]
9 |     res: int,
10 |     angle_threshold: float = None,
11 |     min_points=0,
12 |     refmap_batch_size=512,
13 | ):
14 |     device = colors.device
15 |     Height = Width = res
16 |     theta = (torch.arange(Height, device=device) + 0.5) * (torch.pi / Height)
17 |     phi = (torch.arange(Width, device=device) + 0.5) * (torch.pi / Width)
18 |     thetaphi = torch.stack(torch.meshgrid(theta, phi, indexing="ij"), -1)
19 |     # refmap_normal = thetaphi2xyz(thetaphi, normal=[0, 1, 0], tangent=[-1, 0, 0])
20 |     thetaphi_normals = xyz2thetaphi(normals, normal=[0, 1, 0], tangent=[-1, 0, 0])  # [n, 2(tp)]
21 |     refmap = torch.zeros((res * res), colors.size(-1), device=device, dtype=colors.dtype)
22 |     refmask = torch.zeros((res * res), device=device, dtype=torch.bool)
23 |     for i in range((res * res - 1) // refmap_batch_size + 1):
24 |         batch_slice = slice(i * refmap_batch_size, min(res * res, (i + 1) * refmap_batch_size))
25 |         refmap_thetaphi_batch = thetaphi.view(-1, 2)[batch_slice]  # [bs, 2(tp)]
26 |         angles = (refmap_thetaphi_batch[:, None] - thetaphi_normals[None]).abs().amax(dim=-1)  # [bs, n]
27 |         angle_mask = angles > angle_threshold
28 |         angle_mask.masked_fill_((~angle_mask).sum(-1, keepdim=True) < min_points, True)
29 | 
30 |         expanded_colors = colors.sum(-1).expand(refmap_thetaphi_batch.size(0), -1).masked_fill(angle_mask, torch.nan)  # [bs, n]
31 |         medians, indices = torch.nanmedian(expanded_colors, dim=-1)  # [bs]
32 |         median_mask = torch.isnan(medians)
33 |         if not median_mask.all():
34 |             refmap[batch_slice][~median_mask] = colors[indices[~median_mask]]
35 |             refmask[batch_slice][~median_mask] = True
36 | 
37 |     return refmap.view(res, res, colors.size(-1)), refmask.view(res, res)
38 | 
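
A minimal sketch of driving `refmap_mask_make` with the bundled sample data.
`scripts/estimate.py` is the real entry point; the glue below, the assumed
[H, W, 3] layout of normal.npy, and the angle_threshold value are illustrative
only:

    import numpy as np
    import torch
    from pathlib import Path
    from utils.file_io import load_exr, load_png
    from utils.img2refmap import refmap_mask_make

    img = load_exr(Path("data/sample/image.exr"), as_torch=True)  # [H, W, 3] HDR observation
    mask_arr = load_png(Path("data/sample/mask.png"))             # floats in [0, 1]
    mask_arr = mask_arr[..., 0] if mask_arr.ndim == 3 else mask_arr
    mask = torch.from_numpy(mask_arr > 0.5)                       # [H, W] bool object mask
    normal = torch.from_numpy(np.load("data/sample/normal.npy"))  # [H, W, 3] per-pixel normals (assumed layout)

    colors = img[mask]      # [n, 3] observed colors on the object
    normals = normal[mask]  # [n, 3] matching surface normals
    refmap, refmask = refmap_mask_make(colors, normals, res=128, angle_threshold=0.05, min_points=1)
    # refmap: [128, 128, 3] sparse reflectance map; refmask: [128, 128] validity mask
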
-------------------------------------------------------------------------------- /configs/drmnet/eval_drmnet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: models.drmnet.DRMNet 4 | params: 5 | log_every_k: 5 6 | max_timesteps: 150 7 | loss_type: l2 8 | input_key: LrK 9 | image_size: 128 10 | channels: 3 11 | parameterization: residual 12 | cond_stage_trainable: False 13 | concat_mode: True 14 | scale_factor: 1.0 15 | scale_by_std: False 16 | monitor: 'val/loss' 17 | use_ema: True 18 | sigma: 0.02 19 | delta: 0.025 20 | gamma: 0.95 21 | epsilon: 0.01 22 | l_refmap_weight: 10.0 23 | l_refcode_weight: 0.1 24 | brdf_param_names: ["metallic.value", "base_color.value.R", "base_color.value.G", "base_color.value.B", "roughness.value", "specular"] 25 | z0: [1, 1, 1, 1, 0, 1] 26 | refmap_input_scaler: 0.12 27 | 28 | renderer_config: 29 | target: utils.mitsuba3_utils.MitsubaRefMapRenderer 30 | params: 31 | refmap_res: 128 32 | spp: 256 33 | denoise: simple 34 | 35 | illnet_config: 36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 37 | params: 38 | image_size: 128 39 | in_channels: 6 40 | out_channels: 3 41 | model_channels: 128 42 | attention_resolutions: [8, 16, 32] 43 | num_res_blocks: 2 44 | dropout: 0.0 45 | channel_mult: [1, 2, 3, 4, 5, 6] 46 | num_heads: 1 47 | resblock_updown: False 48 | conv_resample: False 49 | 50 | refnet_config: 51 | target: ldm.modules.diffusionmodules.openaimodel.EncoderUNetModel 52 | params: 53 | image_size: 128 54 | in_channels: 6 55 | model_channels: 128 56 | out_channels: 6 57 | num_res_blocks: 2 58 | attention_resolutions: [8, 16] 59 | dropout: 0.0 60 | channel_mult: [1, 1, 2, 3, 4] 61 | conv_resample: False 62 | resblock_updown: False 63 | num_heads: 1 64 | use_scale_shift_norm: False 65 | pool: "adaptive" 66 | 67 | ckpt_path: ./checkpoints/drmnet.ckpt 68 | 69 | data: 70 | target: main.DataModuleFromConfig 71 | params: 72 | batch_size: 1 73 | num_workers: 3 74 | predict: 75 | target: dataset.basedataset.BaseDataset 76 | params: 77 | size: 128 78 | transform_func: log 79 | clamp_before_exp: 20 80 | -------------------------------------------------------------------------------- /environment/pip_freeze.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | albumentations==1.4.1 5 | altair==5.2.0 6 | antlr4-python3-runtime==4.9.3 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | blinker==1.7.0 10 | cachetools==5.3.3 11 | certifi==2024.2.2 12 | charset-normalizer==3.3.2 13 | click==8.1.7 14 | contourpy==1.1.1 15 | cycler==0.12.1 16 | dbus-python==1.2.16 17 | einops==0.7.0 18 | filelock==3.13.1 19 | fonttools==4.49.0 20 | frozenlist==1.4.1 21 | fsspec==2024.2.0 22 | gitdb==4.0.11 23 | GitPython==3.1.42 24 | google-auth==2.28.2 25 | google-auth-oauthlib==1.0.0 26 | grpcio==1.62.1 27 | huggingface-hub==0.21.4 28 | idna==3.6 29 | imageio==2.34.0 30 | imageio-ffmpeg==0.4.9 31 | importlib_metadata==7.0.2 32 | importlib_resources==6.1.3 33 | jedi==0.19.1 34 | Jinja2==3.1.3 35 | joblib==1.3.2 36 | jsonschema==4.21.1 37 | jsonschema-specifications==2023.12.1 38 | kiwisolver==1.4.5 39 | lazy_loader==0.3 40 | lightning-utilities==0.10.1 41 | lpips==0.1.4 42 | Markdown==3.5.2 43 | markdown-it-py==3.0.0 44 | MarkupSafe==2.1.5 45 | matplotlib==3.7.5 46 | mdurl==0.1.2 47 | multidict==6.0.5 48 | networkx==3.1 49 | numpy==1.24.4 50 | oauthlib==3.2.2 51 | omegaconf==2.3.0 52 | 
opencv-python-headless==4.9.0.80
53 | packaging==23.2
54 | pandas==2.0.3
55 | parso==0.8.3
56 | pillow==10.2.0
57 | pkgutil_resolve_name==1.3.10
58 | protobuf==4.25.3
59 | pudb==2024.1
60 | pyarrow==15.0.1
61 | pyasn1==0.5.1
62 | pyasn1-modules==0.3.0
63 | pydeck==0.8.1b0
64 | Pygments==2.17.2
65 | PyGObject==3.36.0
66 | pynvrtc==9.2
67 | pyparsing==3.1.2
68 | python-dateutil==2.9.0.post0
69 | pytorch-fid==0.3.0
70 | pytorch-lightning==1.9.0
71 | pytz==2024.1
72 | PyWavelets==1.4.1
73 | PyYAML==6.0.1
74 | referencing==0.33.0
75 | regex==2023.12.25
76 | requests==2.31.0
77 | requests-oauthlib==1.3.1
78 | rich==13.7.1
79 | rpds-py==0.18.0
80 | rsa==4.9
81 | safetensors==0.4.2
82 | scikit-image==0.21.0
83 | scikit-learn==1.3.2
84 | scipy==1.10.1
85 | six==1.16.0
86 | smmap==5.0.1
87 | streamlit==1.32.0
88 | # Editable Git install with no remote (taming-transformers==0.0.1)
89 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
90 | tenacity==8.2.3
91 | tensorboard==2.14.0
92 | tensorboard-data-server==0.7.2
93 | threadpoolctl==3.3.0
94 | tifffile==2023.7.10
95 | tokenizers==0.15.2
96 | toml==0.10.2
97 | toolz==0.12.1
98 | torch==1.12.1
99 | torch-fidelity==0.3.0
100 | torchmetrics==1.3.1
101 | torchvision==0.13.1
102 | tornado==6.4
103 | tqdm==4.66.2
104 | transformers==4.38.2
105 | typing_extensions==4.10.0
106 | tzdata==2024.1
107 | urllib3==2.2.1
108 | urwid==2.6.8
109 | urwid_readline==0.14
110 | watchdog==4.0.0
111 | wcwidth==0.2.13
112 | Werkzeug==3.0.1
113 | yarl==1.9.4
114 | zipp==3.17.0
115 | 
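
Once the image is built, a quick check that the container sees the custom
Mitsuba 3 build (exposed via PYTHONPATH) and OpenEXR-enabled OpenCV; this is a
sketch to run on a CUDA machine, not a script shipped with this repository:

    import os
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")  # set by %environment in the .def file below
    import cv2
    import mitsuba as mi

    mi.set_variant("cuda_ad_rgb")  # one of the variants enabled via the sed edit of mitsuba.conf during the build
    print("mitsuba variant:", mi.variant(), "| opencv:", cv2.__version__)
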
--------------------------------------------------------------------------------
/environment/drmnet_release.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | # base docker image
3 | From: nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04
4 | 
5 | # environment variables
6 | %environment
7 |     export LC_ALL=C
8 |     export PYTHONPATH="/opt/mitsuba3/build/python:${PYTHONPATH}"
9 |     export OPENCV_IO_ENABLE_OPENEXR=1
10 | # environment setup
11 | %post
12 |     export DEBIAN_FRONTEND=noninteractive
13 | 
14 |     # specify python version
15 |     PYTHON_VERSION=3.8
16 |     # add python repository
17 |     apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv f23c5a6cf475977595c89f51ba6932366a755776
18 |     echo "deb http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main " > /etc/apt/sources.list.d/python.list
19 |     echo "deb-src http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/python.list
20 |     apt-get update
21 | 
22 |     # remove default python
23 |     apt-get purge --auto-remove python* libpython*
24 | 
25 |     # install specified python
26 |     apt-get install -y python${PYTHON_VERSION}-dev
27 | 
28 |     # install pip
29 |     apt-get install -y python3-pip
30 |     ln -fs /usr/bin/python$PYTHON_VERSION /usr/bin/python
31 |     ln -fs /usr/bin/python$PYTHON_VERSION /usr/bin/python3
32 | 
33 |     apt-get install -y curl x11-apps git g++ make unzip
34 |     apt-get install -y cmake clang-format gdb zlib1g-dev libopenexr-dev libopencv-dev
35 | 
36 |     # required for matplotlib.show()
37 |     apt-get install -y python$PYTHON_VERSION-tk
38 | 
39 |     # install python packages
40 |     alias pip="/usr/bin/python$PYTHON_VERSION -m pip"
41 |     pip install --upgrade --no-cache-dir pip
42 |     pip install --no-cache-dir torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
43 |     pip install --no-cache-dir albumentations pudb imageio imageio-ffmpeg scikit-image lpips
44 |     pip install --no-cache-dir pytorch-lightning==1.9.0
45 |     pip install --no-cache-dir omegaconf streamlit torch-fidelity einops transformers
46 |     pip install --no-cache-dir pynvrtc==9.2
47 |     pip install --no-cache-dir matplotlib scikit-learn tqdm
48 |     pip install --no-cache-dir Pillow tensorboard
49 |     pip install --no-cache-dir pytorch-fid
50 |     pip install --no-cache-dir -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
51 | 
52 |     # build mitsuba3
53 |     cd /opt
54 |     git clone --branch v3.3.0 --recursive https://github.com/mitsuba-renderer/mitsuba3
55 |     cd mitsuba3
56 |     apt-get install -y clang-10 libc++-10-dev libc++abi-10-dev cmake ninja-build
57 |     apt-get install -y libpng-dev libjpeg-dev
58 |     apt-get install -y libpython3-dev python3-distutils
59 |     export CC=clang-10 CXX=clang++-10
60 |     mkdir build
61 |     cd build
62 |     cmake -GNinja ..
63 |     sed -i -e 's/ "scalar_rgb", "scalar_spectral", "cuda_ad_rgb", "llvm_ad_rgb"/ "scalar_rgb", "cuda_ad_rgb", "cuda_rgb"/g' mitsuba.conf
64 |     ninja
65 | 
66 |     rm -rf /var/lib/apt/lists/*
67 | 
68 | # commands run at container startup
69 | %runscript
70 |     echo "Container was created $(date)"
71 |     bash
72 | 
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class LitEma(nn.Module):
6 |     def __init__(self, model, decay=0.9999, use_num_upates=True):
7 |         super().__init__()
8 |         if decay < 0.0 or decay > 1.0:
9 |             raise ValueError('Decay must be between 0 and 1')
10 | 
11 |         self.m_name2s_name = {}
12 |         self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 |         self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 |                              else torch.tensor(-1,dtype=torch.int))
15 | 
16 |         for name, p in model.named_parameters():
17 |             if p.requires_grad:
18 |                 #remove as '.'-character is not allowed in buffers
19 |                 s_name = name.replace('.','')
20 |                 self.m_name2s_name.update({name:s_name})
21 |                 self.register_buffer(s_name,p.clone().detach().data)
22 | 
23 |         self.collected_params = []
24 | 
25 |     def forward(self,model):
26 |         decay = self.decay
27 | 
28 |         if self.num_updates >= 0:
29 |             self.num_updates += 1
30 |             decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 | 
32 |         one_minus_decay = 1.0 - decay
33 | 
34 |         with torch.no_grad():
35 |             m_param = dict(model.named_parameters())
36 |             shadow_params = dict(self.named_buffers())
37 | 
38 |             for key in m_param:
39 |                 if m_param[key].requires_grad:
40 |                     sname = self.m_name2s_name[key]
41 |                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 |                     shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 |                 else:
44 |                     assert not key in self.m_name2s_name
45 | 
46 |     def copy_to(self, model):
47 |         m_param = dict(model.named_parameters())
48 |         shadow_params = dict(self.named_buffers())
49 |         for key in m_param:
50 |             if m_param[key].requires_grad:
51 |                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 |             else:
53 |                 assert not key in self.m_name2s_name
54 | 
55 |     def store(self, parameters):
56 |         """
57 |         Save the current parameters for restoring later.
58 |         Args:
59 |           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 |             temporarily stored.
61 |         """
62 |         self.collected_params = [param.clone() for param in parameters]
63 | 
64 |     def restore(self, parameters):
65 |         """
66 |         Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /configs/obsnet/train_obsnet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: models.obsnet.ObsNetDiffusion 4 | params: 5 | linear_start: 0.0001 6 | linear_end: 0.09 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: LrK 11 | cond_stage_key: masked_LrK 12 | padding_mode: noise 13 | image_size: 128 14 | channels: 3 15 | concat_mode: True 16 | ddim_steps: 50 17 | monitor: val/loss 18 | clip_denoised: False 19 | masked_loss: False 20 | noisy_observe: 0.04 21 | cache_data: False 22 | refmap_cache_root: 23 | objimg_cache_root: 24 | envmap_dir: 25 | 26 | renderer_config: 27 | target: utils.mitsuba3_utils.MitsubaRefMapRenderer 28 | params: 29 | refmap_res: 128 30 | spp: 256 31 | denoise: simple 32 | brdf_param_names: ["metallic.value", "base_color.value.R", "base_color.value.G", "base_color.value.B", "roughness.value", "specular"] 33 | 34 | scheduler_config: 35 | target: models.lr_scheduler.LambdaWarmUpScheduler 36 | params: 37 | verbosity_interval: 0 38 | warm_up_steps: 5000 39 | lr_start: 0.0 40 | lr_end: 1.0 41 | 42 | unet_config: 43 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 44 | params: 45 | image_size: 128 46 | in_channels: 6 47 | out_channels: 3 48 | model_channels: 128 49 | attention_resolutions: [4, 8, 16] 50 | num_res_blocks: 2 51 | channel_mult: [1, 2, 3, 4, 5] 52 | num_heads: 1 53 | resblock_updown: False 54 | conv_resample: False 55 | 56 | data: 57 | target: main.DataModuleFromConfig 58 | params: 59 | batch_size: 20 60 | num_workers: 5 61 | train: 62 | target: dataset.parametricrefmap.ParametricRefmapDataset 63 | params: 64 | size: 128 65 | split: train 66 | return_envmap: False 67 | data_root: ./data/LavalIndoor+PolyHaven_2k 68 | mask_root: ./data/mvs_ortho_synth_refmap/ 69 | transform_func: 0p1tom1p1_normalizedLogarithmic_lowerbound1e-6 70 | zdim: 6 71 | epoch_cycle: 1000 72 | refmap_cache_root: ./data/cache/refmap 73 | 74 | train_sampler: 75 | target: main.CustomRandomSampler 76 | params: 77 | same_sampling: False 78 | 79 | validation: 80 | target: dataset.parametricrefmap.ParametricRefmapDataset 81 | params: 82 | size: 128 83 | split: val 84 | return_envmap: False 85 | data_root: ./data/LavalIndoor+PolyHaven_2k 86 | mask_root: ./data/mvs_ortho_synth_refmap/ 87 | transform_func: 0p1tom1p1_normalizedLogarithmic_lowerbound1e-6 88 | zdim: 6 89 | epoch_cycle: 1000 90 | refmap_cache_root: ./data/cache/refmap 91 | 92 | val_sampler: 93 | target: main.CustomRandomSampler 94 | params: 95 | same_sampling: True 96 | 97 | lightning: 98 | callbacks: 99 | image_logger: 100 | target: main.ImageLogger 101 | params: 102 | batch_frequency: 1000 103 | max_images: 10 104 | increase_log_steps: True 105 | log_images_kwargs: 106 | ddim_steps: 50 107 | 108 | 109 | trainer: 110 | benchmark: True 111 | replace_sampler_ddp: True 112 | max_epoch: 4000 -------------------------------------------------------------------------------- /configs/obsnet/finetune_obsnet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | 
target: models.obsnet.ObsNetDiffusion 4 | params: 5 | linear_start: 0.0001 6 | linear_end: 0.09 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: LrK 11 | cond_stage_key: raw_refmap 12 | padding_mode: noise 13 | image_size: 128 14 | channels: 3 15 | concat_mode: True 16 | ddim_steps: 50 17 | monitor: val/loss 18 | clip_denoised: False 19 | masked_loss: False 20 | obj_img_key: img 21 | cache_data: False 22 | refmap_cache_root: 23 | objimg_cache_root: 24 | envmap_dir: 25 | ckpt_path: ./logs/xxxx-xx-xxTxx-xx-xx_train_obsnet/checkpoints/last.ckpt 26 | 27 | renderer_config: 28 | target: utils.mitsuba3_utils.MitsubaRefMapRenderer 29 | params: 30 | refmap_res: 128 31 | spp: 256 32 | denoise: simple 33 | brdf_param_names: ["metallic.value", "base_color.value.R", "base_color.value.G", "base_color.value.B", "roughness.value", "specular"] 34 | img_renderer_config: 35 | target: utils.mitsuba3_utils.MitsubaOrthoRenderer 36 | params: 37 | image_size: [512, 512] 38 | spp: 64 39 | denoise: 40 | return_normal: True 41 | return_depth: True 42 | brdf_param_names: ["metallic.value", "base_color.value.R", "base_color.value.G", "base_color.value.B", "roughness.value", "specular"] 43 | 44 | scheduler_config: 45 | target: models.lr_scheduler.LambdaWarmUpScheduler 46 | params: 47 | verbosity_interval: 0 48 | warm_up_steps: 5000 49 | lr_start: 0.0 50 | lr_end: 1.0 51 | 52 | unet_config: 53 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 54 | params: 55 | image_size: 128 56 | in_channels: 6 57 | out_channels: 3 58 | model_channels: 128 59 | attention_resolutions: [4, 8, 16] 60 | num_res_blocks: 2 61 | channel_mult: [1, 2, 3, 4, 5] 62 | num_heads: 1 63 | resblock_updown: False 64 | conv_resample: False 65 | 66 | data: 67 | target: main.DataModuleFromConfig 68 | params: 69 | batch_size: 20 70 | num_workers: 5 71 | train: 72 | target: dataset.parametric_img2refmap.ParametricImg2RefmapDataset 73 | params: 74 | size: 128 75 | split: train 76 | zdim: 6 77 | data_root: ./data/LavalIndoor+PolyHaven_2k/ 78 | shape_root: ./data/DeepRelighting_shape5000/ 79 | return_envmap: True 80 | return_obj: True 81 | refmap_key: rK 82 | transform_func: 0p1tom1p1_normalizedLogarithmic_lowerbound1e-6 83 | 84 | train_sampler: 85 | target: main.CustomRandomSampler 86 | params: 87 | same_sampling: False 88 | 89 | validation: 90 | target: dataset.parametric_img2refmap.ParametricImg2RefmapDataset 91 | params: 92 | size: 128 93 | split: val 94 | zdim: 6 95 | data_root: ./data/LavalIndoor+PolyHaven_2k/ 96 | shape_root: ./data/DeepRelighting_shape5000/ 97 | return_envmap: True 98 | return_obj: True 99 | refmap_key: rK 100 | transform_func: 0p1tom1p1_normalizedLogarithmic_lowerbound1e-6 101 | 102 | val_sampler: 103 | target: main.CustomRandomSampler 104 | params: 105 | same_sampling: True 106 | 107 | lightning: 108 | callbacks: 109 | image_logger: 110 | target: main.ImageLogger 111 | params: 112 | batch_frequency: 1000 113 | max_images: 10 114 | increase_log_steps: True 115 | log_images_kwargs: 116 | ddim_steps: 50 117 | 118 | 119 | trainer: 120 | benchmark: True 121 | replace_sampler_ddp: True 122 | max_epoch: 300 -------------------------------------------------------------------------------- /utils/file_io.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def load_exr(path: Path, as_torch: bool = False, channel_first: bool 
= False) -> Union[np.ndarray, torch.Tensor]: 11 | # not support alpha channel 12 | img: np.ndarray = cv2.cvtColor(cv2.imread(str(path), -1)[..., :3], cv2.COLOR_BGR2RGB) 13 | if channel_first: 14 | img = img.transpose(2, 0, 1) 15 | if as_torch: 16 | img: torch.Tensor = torch.from_numpy(img) 17 | return img 18 | 19 | 20 | def save_exr(path: Path, img: Union[np.ndarray, torch.Tensor], channel_first: bool = False): 21 | if isinstance(img, torch.Tensor): 22 | img: np.ndarray = img.detach().cpu().numpy() 23 | if channel_first: 24 | img = img.transpose(1, 2, 0) 25 | return cv2.imwrite(str(path), cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 26 | 27 | 28 | def save_png(path: Path, ldr: Union[np.ndarray, torch.Tensor], channel_first: bool = False, mask: Union[np.ndarray, torch.Tensor] = None): 29 | # mask: [H, W] 30 | if isinstance(ldr, torch.Tensor): 31 | ldr: np.ndarray = ldr.detach().cpu().numpy() 32 | if isinstance(mask, torch.Tensor): 33 | mask: np.ndarray = mask.detach().cpu().numpy() 34 | if channel_first: 35 | ldr = ldr.transpose(1, 2, 0) 36 | ldr = ldr[:, :, :3] 37 | ldr = cv2.cvtColor(ldr, cv2.COLOR_BGR2RGB) 38 | if mask is not None: 39 | if mask.ndim == 2: 40 | mask = mask[:, :, None] 41 | ldr = np.concatenate([ldr, mask], axis=-1) 42 | return cv2.imwrite(str(path), ldr * 255) 43 | 44 | 45 | def load_png(path: Path, as_torch: bool = False, channel_first: bool = False): 46 | # not support alpha channel 47 | img = cv2.imread(str(path), -1) 48 | if img.ndim == 3 and img.shape[-1] == 3: 49 | img: np.ndarray = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 50 | if channel_first: 51 | if img.ndim == 2: 52 | img = img[:, :, None] 53 | img = img.transpose(2, 0, 1) 54 | if as_torch: 55 | img: torch.Tensor = torch.from_numpy(img) 56 | return img / 255.0 57 | 58 | 59 | BRDF_SAMPLING_RES_THETA_H = 90 60 | BRDF_SAMPLING_RES_THETA_D = 90 61 | BRDF_SAMPLING_RES_PHI_D = 360 62 | RED_SCALE = 1.0 / 1500.0 63 | GREEN_SCALE = 1.15 / 1500.0 64 | BLUE_SCALE = 1.66 / 1500.0 65 | 66 | 67 | def save_merl(data: np.ndarray, filename: Path): 68 | """save a merl format brdf 69 | 70 | :param np.ndarray data: an array of brdf with shape of (3, ThetaH, ThetaD, PhiD) 71 | :param Path filename: the path to save (.binary) 72 | """ 73 | data = np.reshape(data, [3, -1]) 74 | data = data / np.array([RED_SCALE, GREEN_SCALE, BLUE_SCALE]).reshape(3, 1) 75 | data = data.flatten() 76 | with open(filename, "wb") as f: 77 | f.write(struct.pack("iii", BRDF_SAMPLING_RES_THETA_H, BRDF_SAMPLING_RES_THETA_D, BRDF_SAMPLING_RES_PHI_D // 2)) 78 | for i in range(data.shape[0]): 79 | f.write(struct.pack("d", data[i])) 80 | 81 | 82 | def load_merl(filename: Path) -> np.ndarray: 83 | """load a merl format brdf 84 | 85 | :param Path filename: the path to load (.binary) 86 | :return np.ndarray: an array of brdf with shape of (3, ThetaH, ThetaD, PhiD) 87 | """ 88 | N_DIM = BRDF_SAMPLING_RES_THETA_H * BRDF_SAMPLING_RES_THETA_D * BRDF_SAMPLING_RES_PHI_D // 2 89 | with open(filename, "rb") as f: 90 | dim = struct.unpack("iii", f.read(4 * 3)) 91 | n = dim[0] * dim[1] * dim[2] 92 | if n != N_DIM: 93 | raise ValueError("invalid BRDF file") 94 | 95 | data = np.empty(3 * n, dtype=np.float32) 96 | for i in range(3 * n): 97 | data[i] = struct.unpack("d", f.read(8))[0] 98 | 99 | # color x theta_h x theta_d x phi_d 100 | data = data.reshape(3, BRDF_SAMPLING_RES_THETA_H, BRDF_SAMPLING_RES_THETA_D, BRDF_SAMPLING_RES_PHI_D // 2) 101 | data *= np.reshape(np.array([RED_SCALE, GREEN_SCALE, BLUE_SCALE]), [3, 1, 1, 1]) 102 | 103 | return data 104 | 
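
A round-trip sanity check for the MERL helpers above. Note that the last axis
is PhiD/2 = 180 as stored on disk, and the per-value struct loop makes this
slow on full-resolution data; the /tmp path is just an example:

    import numpy as np
    from pathlib import Path
    from utils.file_io import load_merl, save_merl

    brdf = np.random.rand(3, 90, 90, 180).astype(np.float32)  # (3, ThetaH, ThetaD, PhiD/2)
    save_merl(brdf, Path("/tmp/roundtrip.binary"))
    back = load_merl(Path("/tmp/roundtrip.binary"))
    assert back.shape == brdf.shape
    np.testing.assert_allclose(back, brdf, rtol=1e-4)  # exact up to float32 rounding
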
-------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /configs/drmnet/train_drmnet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 3 | target: models.drmnet.DRMNet 4 | params: 5 | log_every_k: 5 6 | max_timesteps: 150 7 | loss_type: l2 8 | input_key: LrK 9 | image_size: 128 10 | channels: 3 11 | parameterization: residual 12 | cond_stage_trainable: False 13 | concat_mode: True 14 | scale_factor: 1.0 15 | scale_by_std: False 16 | monitor: 'val/loss' 17 | use_ema: True 18 | sigma: 0.02 19 | delta: 0.025 20 | gamma: 0.95 21 | epsilon: 0.01 22 | train_with_zk_gt: True 23 | train_with_zk_gt_switch_epoch: 2000 24 | l_refmap_weight: 10.0 25 | l_refcode_weight: 0.1 26 | brdf_param_names: ["metallic.value", "base_color.value.R", "base_color.value.G", "base_color.value.B", "roughness.value", "specular"] 27 | z0: [1, 1, 1, 1, 0, 1] 28 | refmap_input_scaler: 0.12 29 | cache_refmap: False 30 | 31 | renderer_config: 32 | target: utils.mitsuba3_utils.MitsubaRefMapRenderer 33 | params: 34 | refmap_res: 128 35 | spp: 256 36 | denoise: simple 37 | 38 | scheduler_config: 39 | target: models.lr_scheduler.LambdaWarmUpScheduler 40 | params: 41 | verbosity_interval: 0 42 | warm_up_steps: 5000 43 | lr_start: 0.0 44 | lr_end: 1.0 45 | 46 | illnet_config: 47 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 48 | params: 49 | image_size: 128 50 | in_channels: 6 51 | out_channels: 3 52 | model_channels: 128 53 | attention_resolutions: [8, 16, 32] 54 | num_res_blocks: 2 55 | dropout: 0.0 56 | channel_mult: [1, 2, 3, 4, 5, 6] 57 | num_heads: 1 58 | resblock_updown: False 59 | conv_resample: False 60 | 61 | refnet_config: 62 | target: 
ldm.modules.diffusionmodules.openaimodel.EncoderUNetModel 63 | params: 64 | image_size: 128 65 | in_channels: 6 66 | model_channels: 128 67 | out_channels: 6 68 | num_res_blocks: 2 69 | attention_resolutions: [8, 16] 70 | dropout: 0.0 71 | channel_mult: [1, 1, 2, 3, 4] 72 | conv_resample: False 73 | resblock_updown: False 74 | num_heads: 1 75 | use_scale_shift_norm: False 76 | pool: "adaptive" 77 | 78 | data: 79 | target: main.DataModuleFromConfig 80 | params: 81 | batch_size: 20 82 | num_workers: 3 83 | train: 84 | target: dataset.parametricrefmap.ParametricRefmapDataset 85 | params: 86 | size: 128 87 | split: train 88 | data_root: ./data/LavalIndoor+PolyHaven_2k/ 89 | transform_func: log 90 | zdim: 6 91 | epoch_cycle: 1000 92 | return_envmap: True 93 | refmap_cache_root: ./data/cache/refmap/ 94 | 95 | train_sampler: 96 | target: main.CustomRandomSampler 97 | params: 98 | same_sampling: False 99 | 100 | validation: 101 | target: dataset.parametricrefmap.ParametricRefmapDataset 102 | params: 103 | size: 128 104 | split: val 105 | data_root: ./data/LavalIndoor+PolyHaven_2k/ 106 | transform_func: log 107 | zdim: 6 108 | epoch_cycle: 1000 109 | return_envmap: True 110 | refmap_cache_root: ./data/cache/refmap/ 111 | 112 | val_sampler: 113 | target: main.CustomRandomSampler 114 | params: 115 | same_sampling: True 116 | 117 | test: 118 | target: dataset.parametricrefmap.ParametricRefmapDataset 119 | params: 120 | size: 128 121 | split: test 122 | data_root: ./data/LavalIndoor+PolyHaven_2k/ 123 | transform_func: log 124 | zdim: 6 125 | epoch_cycle: 1000 126 | return_envmap: True 127 | refmap_cache_root: ./data/cache/refmap/ 128 | 129 | test_sampler: 130 | target: main.CustomRandomSampler 131 | params: 132 | same_sampling: True 133 | 134 | lightning: 135 | callbacks: 136 | image_logger: 137 | target: main.ImageLogger 138 | params: 139 | batch_frequency: 1000 140 | max_images: 10 141 | increase_log_steps: True 142 | 143 | 144 | trainer: 145 | benchmark: True 146 | replace_sampler_ddp: True 147 | max_epoch: 4000 148 | -------------------------------------------------------------------------------- /data/datalists/mvs_ortho_synth_refmap/sparsemaskobjects_val.txt: -------------------------------------------------------------------------------- 1 | 00008 2 | 00013 3 | 00025 4 | 00031 5 | 00033 6 | 00041 7 | 00044 8 | 00051 9 | 00054 10 | 00055 11 | 00059 12 | 00067 13 | 00070 14 | 00076 15 | 00078 16 | 00084 17 | 00087 18 | 00091 19 | 00092 20 | 00093 21 | 00102 22 | 00111 23 | 00121 24 | 00134 25 | 00138 26 | 00143 27 | 00147 28 | 00149 29 | 00161 30 | 00165 31 | 00173 32 | 00187 33 | 00192 34 | 00199 35 | 00227 36 | 00236 37 | 00240 38 | 00245 39 | 00246 40 | 00250 41 | 00251 42 | 00257 43 | 00265 44 | 00274 45 | 00281 46 | 00282 47 | 00284 48 | 00321 49 | 00324 50 | 00329 51 | 00344 52 | 00345 53 | 00373 54 | 00376 55 | 00377 56 | 00379 57 | 00385 58 | 00386 59 | 00401 60 | 00414 61 | 00417 62 | 00418 63 | 00422 64 | 00429 65 | 00435 66 | 00438 67 | 00440 68 | 00445 69 | 00448 70 | 00454 71 | 00457 72 | 00458 73 | 00462 74 | 00470 75 | 00471 76 | 00495 77 | 00496 78 | 00501 79 | 00503 80 | 00514 81 | 00516 82 | 00523 83 | 00525 84 | 00528 85 | 00530 86 | 00532 87 | 00533 88 | 00534 89 | 00540 90 | 00557 91 | 00567 92 | 00568 93 | 00580 94 | 00584 95 | 00598 96 | 00610 97 | 00615 98 | 00617 99 | 00618 100 | 00622 101 | 00625 102 | 00636 103 | 00639 104 | 00648 105 | 00660 106 | 00662 107 | 00668 108 | 00674 109 | 00682 110 | 00684 111 | 00690 112 | 00692 113 | 00693 114 | 00694 115 | 00695 116 | 
00704 117 | 00715 118 | 00717 119 | 00722 120 | 00730 121 | 00736 122 | 00737 123 | 00739 124 | 00748 125 | 00750 126 | 00754 127 | 00755 128 | 00760 129 | 00765 130 | 00770 131 | 00784 132 | 00785 133 | 00786 134 | 00792 135 | 00797 136 | 00798 137 | 00823 138 | 00824 139 | 00828 140 | 00832 141 | 00833 142 | 00839 143 | 00847 144 | 00851 145 | 00853 146 | 00870 147 | 00872 148 | 00884 149 | 00899 150 | 00906 151 | 00908 152 | 00912 153 | 00926 154 | 00944 155 | 00947 156 | 00953 157 | 00958 158 | 00959 159 | 00960 160 | 00961 161 | 00965 162 | 00968 163 | 00984 164 | 00988 165 | 00994 166 | 00997 167 | 01002 168 | 01003 169 | 01005 170 | 01014 171 | 01026 172 | 01070 173 | 01072 174 | 01074 175 | 01077 176 | 01087 177 | 01093 178 | 01097 179 | 01118 180 | 01122 181 | 01128 182 | 01143 183 | 01146 184 | 01147 185 | 01151 186 | 01173 187 | 01180 188 | 01185 189 | 01187 190 | 01190 191 | 01196 192 | 01197 193 | 01212 194 | 01219 195 | 01241 196 | 01243 197 | 01260 198 | 01262 199 | 01266 200 | 01267 201 | 01285 202 | 01287 203 | 01304 204 | 01306 205 | 01308 206 | 01317 207 | 01326 208 | 01341 209 | 01350 210 | 01356 211 | 01361 212 | 01364 213 | 01379 214 | 01382 215 | 01383 216 | 01388 217 | 01389 218 | 01391 219 | 01395 220 | 01408 221 | 01409 222 | 01414 223 | 01423 224 | 01424 225 | 01442 226 | 01448 227 | 01455 228 | 01456 229 | 01458 230 | 01466 231 | 01469 232 | 01470 233 | 01474 234 | 01482 235 | 01485 236 | 01489 237 | 01504 238 | 01510 239 | 01514 240 | 01521 241 | 01523 242 | 01524 243 | 01555 244 | 01558 245 | 01560 246 | 01561 247 | 01564 248 | 01585 249 | 01586 250 | 01593 251 | 01594 252 | 01597 253 | 01600 254 | 01611 255 | 01618 256 | 01633 257 | 01637 258 | 01640 259 | 01642 260 | 01648 261 | 01650 262 | 01658 263 | 01661 264 | 01665 265 | 01669 266 | 01678 267 | 01681 268 | 01686 269 | 01687 270 | 01707 271 | 01711 272 | 01712 273 | 01714 274 | 01718 275 | 01732 276 | 01738 277 | 01747 278 | 01755 279 | 01772 280 | 01773 281 | 01775 282 | 01785 283 | 01788 284 | 01802 285 | 01805 286 | 01812 287 | 01815 288 | 01822 289 | 01828 290 | 01833 291 | 01846 292 | 01848 293 | 01849 294 | 01852 295 | 01879 296 | 01882 297 | 01883 298 | 01884 299 | 01897 300 | 01899 301 | 01907 302 | 01918 303 | 01920 304 | 01923 305 | 01927 306 | 01935 307 | 01941 308 | 01968 309 | 01972 310 | 01994 311 | 01996 312 | 01997 313 | 02019 314 | 02029 315 | 02033 316 | 02034 317 | 02037 318 | 02042 319 | 02044 320 | 02045 321 | 02051 322 | 02056 323 | 02057 324 | 02067 325 | 02069 326 | 02071 327 | 02084 328 | 02093 329 | 02095 330 | 02105 331 | 02114 332 | 02116 333 | 02140 334 | 02141 335 | 02146 336 | 02155 337 | 02156 338 | 02158 339 | 02166 340 | 02169 341 | 02172 342 | 02173 343 | 02174 344 | 02182 345 | 02184 346 | 02186 347 | 02191 348 | 02205 349 | 02207 350 | 02209 351 | 02214 352 | 02217 353 | 02218 354 | 02219 355 | 02222 356 | 02224 357 | 02227 358 | 02232 359 | 02237 360 | 02247 361 | 02257 362 | 02267 363 | 02272 364 | 02279 365 | 02287 366 | 02291 367 | 02292 368 | 02305 369 | 02317 370 | 02332 371 | 02333 372 | 02334 373 | 02340 374 | 02344 375 | 02350 376 | 02352 377 | 02365 378 | 02375 379 | 02379 380 | 02388 381 | 02410 382 | 02415 383 | 02418 384 | 02426 385 | 02435 386 | 02439 387 | 02445 388 | 02453 389 | 02472 390 | 02475 391 | 02478 392 | 02491 393 | 02495 394 | 02497 395 | 02512 396 | 02521 397 | 02531 398 | 02532 399 | 02538 400 | 02544 401 | 02550 402 | 02552 403 | 02555 404 | 02559 405 | 02560 406 | 02566 407 | 02570 408 | 02583 409 | 02585 410 | 02587 411 | 02591 412 | 
02593 413 | 02596 414 | 02609 415 | 02610 416 | 02613 417 | 02617 418 | 02624 419 | 02631 420 | 02634 421 | 02641 422 | 02642 423 | 02654 424 | 02669 425 | 02673 426 | 02674 427 | 02677 428 | 02678 429 | 02680 430 | 02682 -------------------------------------------------------------------------------- /dataset/basedataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import torch 4 | import torchvision 5 | 6 | 7 | class BaseDataset(torch.utils.data.Dataset): 8 | """base class of dataset""" 9 | 10 | def __init__( 11 | self, 12 | size: int, 13 | transform_func: str = "log", 14 | clamp_before_exp: float = 0.0, 15 | ): 16 | """ 17 | 18 | :param int size: the resolution of height/width 19 | :param str transform_func: specifies transform function, must be written in the order of mathematical notation ( f(g()) -> f_g ), defaults to "log" 20 | :param float clamp_before_exp: if >0, clamp data with min value before exponential, defaults to 0.0. 21 | """ 22 | # transform_func must be written in the order of mathematical notation ( f(g()) -> f_g ) 23 | self.size = size 24 | self.transform_func_str = transform_func 25 | self.clamp_before_exp = 10 if isinstance(clamp_before_exp, bool) and not clamp_before_exp else clamp_before_exp 26 | self.transform_funcs = [self.get_tranfrom_func(func_name) for func_name in transform_func.split("_")[::-1]] 27 | self.rescale_funcs = [self.get_rescale_func(func_name) for func_name in transform_func.split("_")] 28 | 29 | def transform(self, x: torch.Tensor, dynamic_normalize: bool = False, mask: torch.Tensor = None): 30 | # x: [(batch, channel), height, width] 31 | assert x.size(-1) >= self.size 32 | for func in self.transform_funcs: 33 | x = func(x, dynamic_normalize=dynamic_normalize, mask=mask) 34 | return x 35 | 36 | def rescale(self, x: torch.Tensor): 37 | for func in self.rescale_funcs: 38 | x = func(x) 39 | return x 40 | 41 | def get_tranfrom_func(self, func_name: str): 42 | # x: [(batch, channel), height, width] 43 | assert "_" not in func_name 44 | if func_name.startswith("resize"): 45 | if len(func_name) > 6: 46 | InterpolationMode = getattr(torchvision.transforms.InterpolationMode, func_name[6:].replace("-", "_")) 47 | else: 48 | InterpolationMode = torchvision.transforms.InterpolationMode.BILINEAR 49 | return lambda x, **kwargs: torchvision.transforms.functional.resize( 50 | x, size=(self.size, self.size), interpolation=InterpolationMode, antialias=True 51 | ) 52 | elif func_name == "log": 53 | return lambda x, **kwargs: torch.log10(x + 1e-1) + 1 54 | elif func_name == "log10": 55 | return lambda x, **kwargs: torch.log10(x) 56 | elif func_name.startswith("lowerbound"): 57 | bottom = float(func_name[10:]) 58 | return lambda x, **kwargs: torch.clip(x, bottom) 59 | elif func_name == "0p1tom1p1": 60 | return lambda x, **kwargs: x * 2 - 1 61 | elif func_name == "normalizedLogarithmic": 62 | 63 | def func(x: torch.Tensor, mask: torch.Tensor, dynamic_normalize: bool, **kwargs): 64 | if dynamic_normalize: 65 | assert mask is not None 66 | linearmax = (x * mask).amax(dim=(-1, -2, -3), keepdim=True) 67 | log10max = torch.log10(linearmax) 68 | log10min = torch.log10((x * mask + (1 - mask.float()) * linearmax).amin(dim=(-1, -2, -3), keepdim=True)) 69 | self.Logarithmic_params = [log10min, log10max] 70 | log10min, log10max = self.Logarithmic_params 71 | assert x.ndim == log10min.ndim == log10max.ndim, f"{x.ndim}, {log10min.ndim}, {log10max.ndim}" 72 | log10min, log10max = 
log10min.to(x.device), log10max.to(x.device) 73 | x = (torch.log10(x) - log10min) / (log10max - log10min) 74 | return x 75 | 76 | return func 77 | else: 78 | raise NotImplementedError(func_name) 79 | 80 | def get_rescale_func(self, func_name: str): 81 | do_nothing = lambda x, **kwargs: x 82 | # x: [(batch, channel), height, width] 83 | assert "_" not in func_name 84 | if func_name.startswith("resize"): 85 | return do_nothing 86 | elif func_name == "log": 87 | if self.clamp_before_exp: 88 | return lambda x, **kwargs: torch.pow(10, torch.clamp(x - 1, max=self.clamp_before_exp)) - 1e-1 89 | else: 90 | return lambda x, **kwargs: torch.pow(10, x - 1) - 1e-1 91 | elif func_name == "log10": 92 | if self.clamp_before_exp: 93 | return lambda x, **kwargs: torch.pow(10, torch.clamp(x, max=self.clamp_before_exp)) 94 | else: 95 | return lambda x, **kwargs: torch.pow(10, x) 96 | elif func_name.startswith("lowerbound"): 97 | return do_nothing 98 | elif func_name == "0p1tom1p1": 99 | return lambda x, **kwargs: (x + 1) / 2 100 | elif func_name == "normalizedLogarithmic": 101 | log10 = self.get_rescale_func("log10") 102 | 103 | def func(x: torch.Tensor, **kwargs): 104 | log10min, log10max = self.Logarithmic_params 105 | log10min, log10max = log10min.to(x.device), log10max.to(x.device) 106 | assert x.ndim == log10min.ndim == log10max.ndim, f"{x.ndim=}, {log10min.ndim=}, {log10max.ndim=}" 107 | 108 | return log10(x * (log10max - log10min) + log10min, **kwargs) 109 | 110 | return func 111 | else: 112 | raise NotImplementedError(func_name) 113 | -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__( 9 | self, 10 | disc_start, 11 | logvar_init=0.0, 12 | kl_weight=1.0, 13 | pixelloss_weight=1.0, 14 | disc_num_layers=3, 15 | disc_in_channels=3, 16 | disc_factor=1.0, 17 | disc_weight=1.0, 18 | perceptual_weight=1.0, 19 | use_actnorm=False, 20 | disc_conditional=False, 21 | disc_loss="hinge", 22 | ): 23 | 24 | super().__init__() 25 | assert disc_loss in ["hinge", "vanilla"] 26 | self.kl_weight = kl_weight 27 | self.pixel_weight = pixelloss_weight 28 | self.perceptual_loss = LPIPS().eval() 29 | self.perceptual_weight = perceptual_weight 30 | # output log variance 31 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 32 | 33 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm).apply( 34 | weights_init 35 | ) 36 | self.discriminator_iter_start = disc_start 37 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 38 | self.disc_factor = disc_factor 39 | self.discriminator_weight = disc_weight 40 | self.disc_conditional = disc_conditional 41 | 42 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 43 | if last_layer is not None: 44 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 45 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 46 | else: 47 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 48 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 49 | 50 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 51 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 52 | d_weight = d_weight * self.discriminator_weight 53 | return d_weight 54 | 55 | def forward( 56 | self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, split="train", weights=None 57 | ): 58 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 59 | if self.perceptual_weight > 0: 60 | p_loss = self.perceptual_loss(inputs.contiguous() * 2 - 1, reconstructions.contiguous() * 2 - 1) 61 | rec_loss = rec_loss + self.perceptual_weight * p_loss 62 | 63 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 64 | weighted_nll_loss = nll_loss 65 | if weights is not None: 66 | weighted_nll_loss = weights * nll_loss 67 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 68 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 69 | kl_loss = posteriors.kl() 70 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 71 | 72 | # now the GAN part 73 | if optimizer_idx == 0: 74 | # generator update 75 | if cond is None: 76 | assert not self.disc_conditional 77 | logits_fake = self.discriminator(reconstructions.contiguous()) 78 | else: 79 | assert self.disc_conditional 80 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 81 | g_loss = -torch.mean(logits_fake) 82 | 83 | if self.disc_factor > 0.0: 84 | try: 85 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 86 | except RuntimeError: 87 | assert not self.training 88 | d_weight = torch.tensor(0.0) 89 | else: 90 | d_weight = torch.tensor(0.0) 91 | 92 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 93 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 94 | 95 | log = { 96 | "{}/total_loss".format(split): 
loss.clone().detach().mean(), 97 | "{}/logvar".format(split): self.logvar.detach(), 98 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 99 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 100 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 101 | "{}/d_weight".format(split): d_weight.detach(), 102 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 103 | "{}/g_loss".format(split): g_loss.detach().mean(), 104 | } 105 | return loss, log 106 | 107 | if optimizer_idx == 1: 108 | # second pass for discriminator update 109 | if cond is None: 110 | logits_real = self.discriminator(inputs.contiguous().detach()) 111 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 112 | else: 113 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 114 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 115 | 116 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 117 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 118 | 119 | log = { 120 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 121 | "{}/logits_real".format(split): logits_real.detach().mean(), 122 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 123 | } 124 | return d_loss, log 125 | -------------------------------------------------------------------------------- /data/datalists/mvs_ortho_synth_refmap/sparsemaskobjects_test.txt: -------------------------------------------------------------------------------- 1 | 00000 2 | 00002 3 | 00014 4 | 00017 5 | 00022 6 | 00024 7 | 00030 8 | 00032 9 | 00035 10 | 00042 11 | 00045 12 | 00048 13 | 00050 14 | 00052 15 | 00061 16 | 00065 17 | 00069 18 | 00079 19 | 00085 20 | 00094 21 | 00095 22 | 00096 23 | 00097 24 | 00106 25 | 00108 26 | 00109 27 | 00116 28 | 00126 29 | 00127 30 | 00128 31 | 00136 32 | 00142 33 | 00145 34 | 00146 35 | 00154 36 | 00157 37 | 00158 38 | 00160 39 | 00162 40 | 00178 41 | 00179 42 | 00183 43 | 00194 44 | 00195 45 | 00197 46 | 00201 47 | 00213 48 | 00215 49 | 00218 50 | 00219 51 | 00221 52 | 00225 53 | 00230 54 | 00231 55 | 00232 56 | 00234 57 | 00254 58 | 00260 59 | 00270 60 | 00273 61 | 00285 62 | 00288 63 | 00289 64 | 00293 65 | 00304 66 | 00306 67 | 00309 68 | 00315 69 | 00316 70 | 00318 71 | 00320 72 | 00322 73 | 00330 74 | 00331 75 | 00336 76 | 00353 77 | 00354 78 | 00366 79 | 00388 80 | 00396 81 | 00399 82 | 00402 83 | 00408 84 | 00415 85 | 00416 86 | 00424 87 | 00425 88 | 00428 89 | 00433 90 | 00436 91 | 00442 92 | 00449 93 | 00451 94 | 00453 95 | 00461 96 | 00463 97 | 00464 98 | 00465 99 | 00475 100 | 00485 101 | 00487 102 | 00489 103 | 00490 104 | 00491 105 | 00500 106 | 00504 107 | 00505 108 | 00510 109 | 00511 110 | 00519 111 | 00531 112 | 00535 113 | 00537 114 | 00539 115 | 00541 116 | 00542 117 | 00547 118 | 00550 119 | 00556 120 | 00561 121 | 00564 122 | 00565 123 | 00571 124 | 00573 125 | 00583 126 | 00585 127 | 00591 128 | 00595 129 | 00607 130 | 00609 131 | 00611 132 | 00619 133 | 00624 134 | 00628 135 | 00629 136 | 00633 137 | 00638 138 | 00640 139 | 00641 140 | 00642 141 | 00643 142 | 00644 143 | 00647 144 | 00652 145 | 00653 146 | 00655 147 | 00673 148 | 00678 149 | 00721 150 | 00725 151 | 00728 152 | 00729 153 | 00734 154 | 00738 155 | 00745 156 | 00761 157 | 00766 158 | 00768 159 | 00769 160 | 00772 161 | 00776 162 | 00778 163 | 00780 164 | 00789 165 | 00801 166 | 00802 167 | 00825 168 | 00843 169 | 00846 170 | 00849 171 | 00852 172 | 
00854 173 | 00858 174 | 00861 175 | 00871 176 | 00873 177 | 00875 178 | 00876 179 | 00888 180 | 00889 181 | 00892 182 | 00894 183 | 00900 184 | 00917 185 | 00920 186 | 00921 187 | 00930 188 | 00939 189 | 00941 190 | 00945 191 | 00946 192 | 00971 193 | 00979 194 | 00985 195 | 00986 196 | 00990 197 | 01001 198 | 01006 199 | 01009 200 | 01013 201 | 01019 202 | 01021 203 | 01028 204 | 01038 205 | 01043 206 | 01044 207 | 01054 208 | 01056 209 | 01061 210 | 01067 211 | 01068 212 | 01078 213 | 01085 214 | 01088 215 | 01091 216 | 01096 217 | 01105 218 | 01108 219 | 01123 220 | 01129 221 | 01130 222 | 01133 223 | 01135 224 | 01139 225 | 01142 226 | 01144 227 | 01153 228 | 01154 229 | 01159 230 | 01161 231 | 01163 232 | 01165 233 | 01166 234 | 01169 235 | 01171 236 | 01172 237 | 01175 238 | 01177 239 | 01195 240 | 01199 241 | 01210 242 | 01211 243 | 01214 244 | 01215 245 | 01223 246 | 01224 247 | 01227 248 | 01229 249 | 01230 250 | 01234 251 | 01236 252 | 01251 253 | 01254 254 | 01258 255 | 01264 256 | 01293 257 | 01295 258 | 01296 259 | 01297 260 | 01301 261 | 01309 262 | 01310 263 | 01312 264 | 01313 265 | 01316 266 | 01320 267 | 01322 268 | 01323 269 | 01325 270 | 01329 271 | 01330 272 | 01337 273 | 01340 274 | 01344 275 | 01347 276 | 01355 277 | 01357 278 | 01359 279 | 01360 280 | 01362 281 | 01368 282 | 01385 283 | 01387 284 | 01396 285 | 01401 286 | 01406 287 | 01415 288 | 01430 289 | 01436 290 | 01440 291 | 01441 292 | 01444 293 | 01445 294 | 01452 295 | 01457 296 | 01459 297 | 01462 298 | 01467 299 | 01471 300 | 01476 301 | 01478 302 | 01479 303 | 01492 304 | 01498 305 | 01499 306 | 01502 307 | 01506 308 | 01513 309 | 01516 310 | 01518 311 | 01528 312 | 01533 313 | 01535 314 | 01540 315 | 01545 316 | 01547 317 | 01579 318 | 01595 319 | 01598 320 | 01606 321 | 01607 322 | 01612 323 | 01613 324 | 01615 325 | 01616 326 | 01620 327 | 01623 328 | 01641 329 | 01643 330 | 01646 331 | 01647 332 | 01653 333 | 01656 334 | 01657 335 | 01663 336 | 01666 337 | 01670 338 | 01672 339 | 01675 340 | 01685 341 | 01689 342 | 01692 343 | 01694 344 | 01701 345 | 01703 346 | 01708 347 | 01710 348 | 01723 349 | 01729 350 | 01731 351 | 01734 352 | 01740 353 | 01741 354 | 01742 355 | 01743 356 | 01746 357 | 01761 358 | 01770 359 | 01774 360 | 01778 361 | 01782 362 | 01800 363 | 01808 364 | 01809 365 | 01810 366 | 01819 367 | 01820 368 | 01825 369 | 01830 370 | 01834 371 | 01839 372 | 01840 373 | 01844 374 | 01847 375 | 01850 376 | 01856 377 | 01858 378 | 01863 379 | 01864 380 | 01867 381 | 01871 382 | 01880 383 | 01885 384 | 01886 385 | 01895 386 | 01902 387 | 01904 388 | 01905 389 | 01912 390 | 01913 391 | 01916 392 | 01917 393 | 01947 394 | 01964 395 | 01973 396 | 01976 397 | 01978 398 | 01980 399 | 01982 400 | 01989 401 | 01993 402 | 01999 403 | 02009 404 | 02011 405 | 02013 406 | 02018 407 | 02026 408 | 02035 409 | 02038 410 | 02046 411 | 02048 412 | 02053 413 | 02058 414 | 02068 415 | 02072 416 | 02078 417 | 02080 418 | 02090 419 | 02097 420 | 02100 421 | 02104 422 | 02106 423 | 02126 424 | 02128 425 | 02136 426 | 02137 427 | 02142 428 | 02152 429 | 02159 430 | 02160 431 | 02161 432 | 02171 433 | 02188 434 | 02192 435 | 02200 436 | 02202 437 | 02203 438 | 02211 439 | 02221 440 | 02233 441 | 02236 442 | 02242 443 | 02251 444 | 02254 445 | 02255 446 | 02273 447 | 02277 448 | 02282 449 | 02285 450 | 02288 451 | 02293 452 | 02299 453 | 02300 454 | 02301 455 | 02303 456 | 02304 457 | 02308 458 | 02312 459 | 02328 460 | 02329 461 | 02342 462 | 02347 463 | 02354 464 | 02368 465 | 02370 466 | 02373 467 | 02386 468 | 
02387
469 | 02390
470 | 02392
471 | 02394
472 | 02396
473 | 02405
474 | 02407
475 | 02409
476 | 02419
477 | 02425
478 | 02427
479 | 02430
480 | 02436
481 | 02440
482 | 02444
483 | 02454
484 | 02456
485 | 02457
486 | 02458
487 | 02459
488 | 02460
489 | 02464
490 | 02465
491 | 02469
492 | 02470
493 | 02473
494 | 02476
495 | 02477
496 | 02481
497 | 02482
498 | 02484
499 | 02487
500 | 02488
501 | 02492
502 | 02493
503 | 02500
504 | 02502
505 | 02505
506 | 02509
507 | 02514
508 | 02523
509 | 02526
510 | 02529
511 | 02530
512 | 02543
513 | 02545
514 | 02547
515 | 02549
516 | 02551
517 | 02561
518 | 02572
519 | 02574
520 | 02576
521 | 02577
522 | 02595
523 | 02598
524 | 02599
525 | 02611
526 | 02615
527 | 02626
528 | 02636
529 | 02645
530 | 02651
531 | 02652
532 | 02657
533 | 02658
534 | 02659
535 | 02663
536 | 02667
537 | 02679
-------------------------------------------------------------------------------- /scripts/estimate.py: --------------------------------------------------------------------------------
1 | """Estimate illumination and reflectance from a single object image."""
2 | 
3 | import argparse
4 | import sys
5 | from pathlib import Path
6 | from typing import List, Tuple
7 | 
8 | import cv2
9 | import mitsuba as mi
10 | import numpy as np
11 | import torch
12 | from omegaconf import OmegaConf
13 | from tqdm import tqdm
14 | 
15 | mi.set_variant("cuda_ad_rgb")
16 | 
17 | sys.path.append(str(Path(__file__).parent.parent))
18 | 
19 | from dataset.basedataset import BaseDataset
20 | from ldm.util import instantiate_from_config
21 | from models.drmnet import DRMNet
22 | from models.obsnet import ObsNetDiffusion
23 | from utils.file_io import load_exr, load_png, save_png
24 | from utils.img2refmap import refmap_mask_make
25 | from utils.mitsuba3_utils import get_bsdf, visualize_bsdf
26 | from utils.tonemap import hdr2ldr
27 | 
28 | 
29 | def estimate(
30 |     DRMNet_model: DRMNet,
31 |     ObsNet_model: ObsNetDiffusion,
32 |     input_img: torch.Tensor,
33 |     input_normal: torch.Tensor,
34 |     mask: torch.Tensor,
35 |     tag: str = "sample",
36 |     erode_kernel_size: int = 5,
37 | ):
38 |     refmap_res = DRMNet_model.ds.size
39 | 
40 |     torch.cuda.synchronize()
41 | 
42 |     # remove object edge pixels by eroding the mask
43 |     if erode_kernel_size > 0:
44 |         inv_mask = ~mask
45 |         kernel = torch.stack(torch.meshgrid(*torch.arange(erode_kernel_size, device="cuda").expand(2, -1), indexing="ij"))
46 |         kernel = kernel + 0.5
47 |         kernel = torch.linalg.norm(kernel - erode_kernel_size / 2, axis=0) <= erode_kernel_size / 2
48 |         kernel = kernel[None, None].float()
49 |         inv_mask = torch.nn.functional.conv2d(inv_mask[None, None].float(), kernel, padding="same").bool()[0, 0]
50 |         mask = torch.logical_and(mask, ~inv_mask)
51 | 
52 |     print("Making refmap from object image...", flush=True)
53 |     refmap_est, refmask = refmap_mask_make(
54 |         input_img[mask],  # [N, 3]
55 |         input_normal[mask],
56 |         res=refmap_res,
57 |         angle_threshold=np.pi / 128 / 2,
58 |     )
59 | 
60 |     torch.cuda.synchronize()
61 |     print("Inpainting refmap ...", flush=True)
62 | 
63 |     batch = {
64 |         "tag": [tag],
65 |         "raw_refmap": refmap_est.permute(2, 0, 1)[None],
66 |         "raw_refmask": refmask[None],
67 |     }
68 |     c, _, _ = ObsNet_model.get_cond_for_predict(batch)
69 | 
70 |     # get denoise row
71 |     use_ddim = ObsNet_model.ddim_steps is not None
72 |     with ObsNet_model.ema_scope("Plotting"):
73 |         samples, _ = ObsNet_model.sample_log(
74 |             cond=c,
75 |             batch_size=len(batch["tag"]),
76 |             ddim=use_ddim,
77 |             ddim_steps=ObsNet_model.ddim_steps,
78 |             eta=ObsNet_model.ddim_eta,
79 |         )
80 |         inpaint_sample: torch.Tensor =
ObsNet_model.ds.rescale(ObsNet_model.decode_first_stage(samples))[0] 81 | 82 | torch.cuda.synchronize() 83 | print("Inverse Rendering ...", flush=True) 84 | 85 | batch = { 86 | "tag": [tag], 87 | "LrK": inpaint_sample[None], 88 | } 89 | LrK, _, illnet_c, refnet_c, _ = DRMNet_model.get_input_for_predict(batch) 90 | 91 | torch.cuda.synchronize() 92 | with DRMNet_model.ema_scope(): 93 | samples, zK_est, _ = DRMNet_model.p_sample_loop(LrK, illnet_c, refnet_c, verbose=False) 94 | 95 | Lr0_sample: torch.Tensor = DRMNet_model.ds.rescale(DRMNet_model.decode_first_stage(samples))[0].clip(0) 96 | zK_est = zK_est[0] 97 | 98 | torch.cuda.synchronize() 99 | if DRMNet_model.refmap_input_scaler is not None: 100 | Lr0_sample /= DRMNet_model.normalizing_scale[0] 101 | 102 | return Lr0_sample, zK_est 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument("--input_img", type=Path, help="The path of HDR image for an object (.exr, .hdr)") 108 | parser.add_argument("--input_normal", type=Path, help="The path of normal map for an object (.npy)") 109 | parser.add_argument("--input_mask", type=Path, help="The path of mask for an object (.png)", default=None) 110 | parser.add_argument( 111 | "--obsnet_base_path", type=Path, help="the config path for obsnet", default=Path("./configs/obsnet/eval_obsnet.yaml") 112 | ) 113 | parser.add_argument( 114 | "--drmnet_base_path", type=Path, help="the config path for drmnet", default=Path("./configs/drmnet/eval_drmnet.yaml") 115 | ) 116 | parser.add_argument("--output_dir", type=Path, help="the output directory", default=Path("./outputs/")) 117 | args = parser.parse_args() 118 | 119 | # load models 120 | obsnet_base_config = OmegaConf.load(args.obsnet_base_path) 121 | obsnet_model: ObsNetDiffusion = instantiate_from_config(obsnet_base_config.model).cuda() 122 | obsnet_model.ds: BaseDataset = instantiate_from_config(obsnet_base_config.data.params.predict) 123 | drmnet_base_config = OmegaConf.load(args.drmnet_base_path) 124 | drmnet_model: DRMNet = instantiate_from_config(drmnet_base_config.model).cuda() 125 | drmnet_model.ds: BaseDataset = instantiate_from_config(drmnet_base_config.data.params.predict) 126 | 127 | # load input 128 | input_img = load_exr(args.input_img, as_torch=True).cuda() 129 | input_normal = torch.from_numpy(np.load(args.input_normal)).cuda() 130 | normal_mask = torch.linalg.norm(input_normal, dim=-1) > 0.5 131 | if args.input_mask is not None: 132 | input_mask = load_png(args.input_mask, as_torch=True).cuda() 133 | if input_mask.ndim == 3: 134 | input_mask = input_mask[:, :, 0] 135 | mask = torch.logical_and(input_mask, normal_mask) 136 | else: 137 | mask = normal_mask 138 | 139 | Lr0_sample, zK_est = estimate(drmnet_model, obsnet_model, input_img, input_normal, mask) 140 | 141 | envmap_est = drmnet_model.r0toenvmap(Lr0_sample[None], (drmnet_model.image_size, drmnet_model.image_size * 2))[0] # [H, W, 3] 142 | 143 | output_dir: Path = args.output_dir 144 | output_dir.mkdir(exist_ok=True) 145 | envmap_est_ldr = hdr2ldr(envmap_est.cpu().numpy()) 146 | save_png(output_dir / f"sample_env.png", envmap_est_ldr) 147 | vis_ref, vis_ref_mask = visualize_bsdf(get_bsdf(zK_est, drmnet_model.brdf_param_names)) 148 | vis_ref = hdr2ldr(vis_ref) 149 | save_png(output_dir / f"sample_brdf.png", vis_ref, mask=vis_ref_mask) 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion 
Reflectance Map: Single-Image Stochastic Inverse Rendering of Illumination and Reflectance
2 | [arXiv](https://arxiv.org/abs/2312.04529)
3 | 
4 | This repository provides an implementation of our paper [Diffusion Reflectance Map: Single-Image Stochastic Inverse Rendering of Illumination and Reflectance](https://arxiv.org/abs/2312.04529).
5 | This implementation is based on [Latent Diffusion Model](https://github.com/CompVis/latent-diffusion/).
6 | 
7 | Please note that this is research software and may contain bugs or other issues – please use it at your own risk. If you experience major problems with it, you may contact us, but please note that we do not have the resources to deal with all issues.
8 | 
9 | Please cite the following paper if you use any part of our code and data.
10 | ```
11 | @InProceedings{Enyo_2024_CVPR,
12 |     author = {Enyo, Yuto and Nishino, Ko},
13 |     title = {Diffusion Reflectance Map: Single-Image Stochastic Inverse Rendering of Illumination and Reflectance},
14 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
15 |     month = {June},
16 |     year = {2024},
17 | }
18 | ```
19 | 
20 | ![drmnet_overall](assets/drmnet_overall.gif)
21 | 
22 | ## Requirements
23 | 
24 | We tested our code with Python 3.8 on Ubuntu 20.04 LTS using the following packages.
25 | 
26 | - numpy==1.23.5
27 | - pytorch==1.12.1
28 | - torchvision==0.13.1
29 | - mitsuba==3.2.0
30 | - pytorch-lightning==1.9.0
31 | - opencv-python==4.9.0.80
32 | - omegaconf
33 | - einops
34 | - and packages for [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion?tab=readme-ov-file#requirements) (some are imported to avoid errors but not actually used)
35 | 
36 | Please refer to [environment/pip_freeze.txt](environment/pip_freeze.txt) for the specific versions we used.
37 | 
38 | You can also use `singularity` to replicate our environment:
39 | ```bash
40 | singularity build environment/drmnet.sif environment/drmnet_release.def
41 | singularity run --nv environment/drmnet.sif
42 | ```
43 | 
44 | ## Usage
45 | 
46 | ### Demo
47 | 
48 | You can download the pretrained models (`drmnet.ckpt` and `obsnet.ckpt`) from [here](https://drive.google.com/drive/folders/1zWkmzOIIwueeUL0ryzK6FU8TtW6g4T6W) and save them in `./checkpoints`.
49 | You can apply the models to the sample data in the `data` directory by running the script below.
50 | 
51 | ```bash
52 | python scripts/estimate.py --input_img ./data/sample/image.exr --input_normal ./data/sample/normal.npy --input_mask ./data/sample/mask.png
53 | ```
54 | 
55 | You can view the outputs in `outputs`.
56 | 
57 | ### Training
58 | 
59 | #### Data
60 | 
61 | The reflectance maps for training are rendered on the fly with random sampling.
62 | 
63 | You can download the cached training data from [here](https://drive.google.com/drive/folders/1zWkmzOIIwueeUL0ryzK6FU8TtW6g4T6W).
64 | To use it, you need to use the same PyTorch version as listed above.
65 | Please unzip the cached data to `data/cache/` by running:
66 | ```bash
67 | unzip -d ./data/cache/refmap
68 | unzip -d ./data/cache/objimg
69 | ```
70 | 
71 | After unzipping the caches, `data/cache` will look like below:
72 | ```
73 | data/cache
74 | ├── <...>/refmap/<...>/
75 | │   ├── <...>/
76 | │   │   ├── b<...>v<...>.pt
77 | │   │   ...
78 | │   ...
79 | ├── <...>/objimg/<...>/
80 | │   ├── <...>
81 | │   │   ├── <...>/
82 | │   │   │   ├── <...>/
83 | │   │   │   │   ├── b<...>v<...>.pt
84 | │   │   │   │   ...
85 | │   │   │   ...
86 | │   │   ...
87 | │   ├── <...>_rawrefmap
88 | │   │   ├── <...>/
89 | │   │   │   ├── <...>/
90 | │   │   │   │   ├── b<...>v<...>.pt
91 | │   │   │   │   ...
92 | │   │   │   ...
93 | │   │   ...
94 | ```
95 | 
96 | Also, you need to download masks for reflectance maps (needed to train ObsNet) from [here](https://drive.google.com/drive/folders/1zWkmzOIIwueeUL0ryzK6FU8TtW6g4T6W) and unzip the cached data to `data/nLMVS-Synth_refmap_masks/` by running:
97 | ```bash
98 | unzip -d ./data/nLMVS-Synth_refmap_masks
99 | ```
100 | 
101 | 
102 | The training data is generated from the following sources:
103 | - HDR environment maps from the [Laval Indoor HDR Dataset](http://vision.gel.ulaval.ca/~jflalonde/publications/projects/deepIndoorLight/index.html) and [Poly Haven](https://polyhaven.com/)
104 | - 3D mesh models of [Xu et al.](https://cseweb.ucsd.edu/~viscomp/projects/SIG18Relighting/)
105 | - Normal maps from the nLMVS-Synth dataset of [nLMVS-Net](https://github.com/kyotovision-public/nLMVS-Net)
106 | 
107 | If you want to train without the above cache, please download the HDR environment maps from the sites above and save them to `./data/LavalIndoor+PolyHaven_2k` in OpenEXR format (`.exr`) with a resolution of 2000x1000. You can use `scripts/preprocess_envmap.py` for this.
108 | 
109 | Also, please download object shapes from [Xu et al.](https://cseweb.ucsd.edu/~viscomp/projects/SIG18Relighting/) and preprocess them by running:
110 | ```bash
111 | python scripts/preprocess_shape.py
112 | ```
113 | 
114 | 
115 | #### DRMNet
116 | 
117 | You can train DRMNet by running
118 | ```bash
119 | python main.py --base ./configs/drmnet/train_drmnet.yaml -t --device 0
120 | ```
121 | The logs and checkpoints are saved to `./logs/xxxx-xx-xxTxx-xx-xx_train_drmnet`.
122 | 
123 | #### ObsNet
124 | 
125 | You can train ObsNet by running
126 | ```bash
127 | python main.py --base ./configs/obsnet/train_obsnet.yaml -t --device 0
128 | ```
129 | The logs and checkpoints are saved to `./logs/xxxx-xx-xxTxx-xx-xx_train_obsnet`.
130 | 
131 | In order to finetune the ObsNet model, you need to modify the configuration file located at `./configs/obsnet/finetune_obsnet.yaml`.
132 | The default value for `model: params: ckpt_path` is set to `./logs/xxxx-xx-xxTxx-xx-xx_train_obsnet/checkpoints/last.ckpt`.
133 | To finetune the network using raw reflectance maps from random object images, replace this placeholder with the path to the checkpoint produced by the training run above, and run:
134 | ```bash
135 | python main.py --base ./configs/obsnet/finetune_obsnet.yaml -t --device 0
136 | ```
137 | 
138 | The logs and checkpoints are saved to `./logs/xxxx-xx-xxTxx-xx-xx_finetune_obsnet`.
139 | 
-------------------------------------------------------------------------------- /ldm/util.py: --------------------------------------------------------------------------------
1 | import importlib
2 | 
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 | 
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 | 
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 | 
16 | 
17 | def log_txt_as_img(wh, xc, size=10):
18 |     # wh a tuple of (width, height)
19 |     # xc a list of captions to plot
20 |     b = len(xc)
21 |     txts = list()
22 |     for bi in range(b):
23 |         txt = Image.new("RGB", wh, color="white")
24 |         draw = ImageDraw.Draw(txt)
25 |         font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 |         nc = int(40 * (wh[0] / 256))
27 |         lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 | 
29 |         try:
30 |             draw.text((0, 0), lines, fill="black", font=font)
31 |         except UnicodeEncodeError:
32 |             print("Cant encode string for logging.
Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /data/datalists/DeepRelighting_shape5000/val_idx_mvs.txt: -------------------------------------------------------------------------------- 1 | 00008 2 | 00013 3 | 00025 4 | 00031 5 | 00033 6 | 00041 7 | 00044 8 | 00051 9 | 00054 10 | 00055 11 | 00059 12 | 00067 13 | 00070 14 | 00076 15 | 00078 16 | 00084 17 | 00087 18 | 00091 19 | 00092 20 | 00093 21 | 00102 22 | 00111 23 | 00121 24 | 00134 25 | 00138 26 | 00143 27 | 00147 28 | 00149 29 | 00161 30 | 00165 31 | 00173 32 | 00187 33 | 00192 34 | 00199 35 | 00227 36 | 00236 37 | 00240 38 | 00245 39 | 00246 40 | 00250 41 | 00251 42 | 00257 43 | 00265 44 | 00274 45 | 00281 46 | 00282 47 | 00284 48 | 00321 49 | 00324 50 | 00329 51 | 00344 52 | 00345 53 | 00373 54 | 00376 55 | 00377 56 | 00379 57 | 00385 58 | 00386 59 | 00401 60 | 00414 61 | 00417 62 | 00418 63 | 00422 64 | 00429 65 | 00435 66 | 00438 67 | 00440 68 | 00445 69 | 00448 70 | 00454 71 | 00457 72 | 00458 73 | 00462 74 | 00470 75 | 00471 76 | 00495 77 | 00496 78 | 00501 79 | 00503 80 | 00514 81 | 00516 82 | 00523 83 | 00525 84 | 00528 85 | 00530 86 | 00532 87 | 00533 88 | 00534 89 | 00540 90 | 00557 91 | 00567 92 | 00568 93 | 00580 94 | 00584 95 | 00598 96 | 00610 97 | 00615 98 | 00617 99 | 00618 100 | 00622 101 | 00625 102 | 00636 103 | 00639 104 | 00648 105 | 00660 106 | 00662 107 | 00668 108 | 00674 109 | 00682 110 | 00684 111 | 00690 112 | 00692 113 | 00693 114 | 00694 115 | 00695 116 | 00704 117 | 00715 118 | 00717 119 | 00722 120 | 00730 121 | 00736 122 | 00737 123 | 00739 124 | 00748 125 | 00750 126 | 00754 127 | 00755 128 | 00760 129 | 00765 130 | 00770 131 | 00784 132 | 00785 133 | 00786 134 | 00792 135 | 
00797 136 | 00798 137 | 00823 138 | 00824 139 | 00828 140 | 00832 141 | 00833 142 | 00839 143 | 00847 144 | 00851 145 | 00853 146 | 00870 147 | 00872 148 | 00884 149 | 00899 150 | 00906 151 | 00908 152 | 00912 153 | 00926 154 | 00944 155 | 00947 156 | 00953 157 | 00958 158 | 00959 159 | 00960 160 | 00961 161 | 00965 162 | 00968 163 | 00984 164 | 00988 165 | 00994 166 | 00997 167 | 01002 168 | 01003 169 | 01005 170 | 01014 171 | 01026 172 | 01070 173 | 01072 174 | 01074 175 | 01077 176 | 01087 177 | 01093 178 | 01097 179 | 01118 180 | 01122 181 | 01128 182 | 01143 183 | 01146 184 | 01147 185 | 01151 186 | 01173 187 | 01180 188 | 01185 189 | 01187 190 | 01190 191 | 01196 192 | 01197 193 | 01212 194 | 01219 195 | 01241 196 | 01243 197 | 01260 198 | 01262 199 | 01266 200 | 01267 201 | 01285 202 | 01287 203 | 01304 204 | 01306 205 | 01308 206 | 01317 207 | 01326 208 | 01341 209 | 01350 210 | 01356 211 | 01361 212 | 01364 213 | 01379 214 | 01382 215 | 01383 216 | 01388 217 | 01389 218 | 01391 219 | 01395 220 | 01408 221 | 01409 222 | 01414 223 | 01423 224 | 01424 225 | 01442 226 | 01448 227 | 01455 228 | 01456 229 | 01458 230 | 01466 231 | 01469 232 | 01470 233 | 01474 234 | 01482 235 | 01485 236 | 01489 237 | 01504 238 | 01510 239 | 01514 240 | 01521 241 | 01523 242 | 01524 243 | 01555 244 | 01558 245 | 01560 246 | 01561 247 | 01564 248 | 01585 249 | 01586 250 | 01593 251 | 01594 252 | 01597 253 | 01600 254 | 01611 255 | 01618 256 | 01633 257 | 01637 258 | 01640 259 | 01642 260 | 01648 261 | 01650 262 | 01658 263 | 01661 264 | 01665 265 | 01669 266 | 01678 267 | 01681 268 | 01686 269 | 01687 270 | 01707 271 | 01711 272 | 01712 273 | 01714 274 | 01718 275 | 01732 276 | 01738 277 | 01747 278 | 01755 279 | 01772 280 | 01773 281 | 01775 282 | 01785 283 | 01788 284 | 01802 285 | 01805 286 | 01812 287 | 01815 288 | 01822 289 | 01828 290 | 01833 291 | 01846 292 | 01848 293 | 01849 294 | 01852 295 | 01879 296 | 01882 297 | 01883 298 | 01884 299 | 01897 300 | 01899 301 | 01907 302 | 01918 303 | 01920 304 | 01923 305 | 01927 306 | 01935 307 | 01941 308 | 01968 309 | 01972 310 | 01994 311 | 01996 312 | 01997 313 | 02019 314 | 02029 315 | 02033 316 | 02034 317 | 02037 318 | 02042 319 | 02044 320 | 02045 321 | 02051 322 | 02056 323 | 02057 324 | 02067 325 | 02069 326 | 02071 327 | 02084 328 | 02093 329 | 02095 330 | 02105 331 | 02114 332 | 02116 333 | 02140 334 | 02141 335 | 02146 336 | 02155 337 | 02156 338 | 02158 339 | 02166 340 | 02169 341 | 02172 342 | 02173 343 | 02174 344 | 02182 345 | 02184 346 | 02186 347 | 02191 348 | 02205 349 | 02207 350 | 02209 351 | 02214 352 | 02217 353 | 02218 354 | 02219 355 | 02222 356 | 02224 357 | 02227 358 | 02232 359 | 02237 360 | 02247 361 | 02257 362 | 02267 363 | 02272 364 | 02279 365 | 02287 366 | 02291 367 | 02292 368 | 02305 369 | 02317 370 | 02332 371 | 02333 372 | 02334 373 | 02340 374 | 02344 375 | 02350 376 | 02352 377 | 02365 378 | 02375 379 | 02379 380 | 02388 381 | 02410 382 | 02415 383 | 02418 384 | 02426 385 | 02435 386 | 02439 387 | 02445 388 | 02453 389 | 02472 390 | 02475 391 | 02478 392 | 02491 393 | 02495 394 | 02497 395 | 02512 396 | 02521 397 | 02531 398 | 02532 399 | 02538 400 | 02544 401 | 02550 402 | 02552 403 | 02555 404 | 02559 405 | 02560 406 | 02566 407 | 02570 408 | 02583 409 | 02585 410 | 02587 411 | 02591 412 | 02593 413 | 02596 414 | 02609 415 | 02610 416 | 02613 417 | 02617 418 | 02624 419 | 02631 420 | 02634 421 | 02641 422 | 02642 423 | 02654 424 | 02669 425 | 02673 426 | 02674 427 | 02677 428 | 02678 429 | 02680 430 | 02682 431 | 
02696 432 | 02697 433 | 02698 434 | 02724 435 | 02728 436 | 02731 437 | 02732 438 | 02737 439 | 02740 440 | 02743 441 | 02746 442 | 02749 443 | 02750 444 | 02753 445 | 02766 446 | 02769 447 | 02774 448 | 02778 449 | 02780 450 | 02790 451 | 02814 452 | 02816 453 | 02837 454 | 02842 455 | 02847 456 | 02862 457 | 02866 458 | 02867 459 | 02871 460 | 02873 461 | 02875 462 | 02880 463 | 02897 464 | 02903 465 | 02912 466 | 02915 467 | 02921 468 | 02922 469 | 02924 470 | 02927 471 | 02935 472 | 02938 473 | 02939 474 | 02943 475 | 02953 476 | 02954 477 | 02955 478 | 02969 479 | 02975 480 | 02998 481 | 03002 482 | 03003 483 | 03015 484 | 03021 485 | 03022 486 | 03030 487 | 03036 488 | 03037 489 | 03042 490 | 03047 491 | 03063 492 | 03066 493 | 03088 494 | 03092 495 | 03093 496 | 03095 497 | 03100 498 | 03102 499 | 03108 500 | 03115 501 | 03123 502 | 03129 503 | 03136 504 | 03138 505 | 03142 506 | 03151 507 | 03157 508 | 03159 509 | 03163 510 | 03164 511 | 03171 512 | 03177 513 | 03187 514 | 03190 515 | 03217 516 | 03221 517 | 03237 518 | 03238 519 | 03240 520 | 03252 521 | 03260 522 | 03271 523 | 03273 524 | 03276 525 | 03279 526 | 03285 527 | 03313 528 | 03316 529 | 03320 530 | 03326 531 | 03329 532 | 03336 533 | 03338 534 | 03346 535 | 03356 536 | 03361 537 | 03366 538 | 03380 539 | 03393 540 | 03407 541 | 03411 542 | 03430 543 | 03435 544 | 03437 545 | 03438 546 | 03455 547 | 03459 548 | 03460 549 | 03461 550 | 03466 551 | 03471 552 | 03474 553 | 03486 554 | 03490 555 | 03510 556 | 03512 557 | 03513 558 | 03523 559 | 03525 560 | 03528 561 | 03532 562 | 03550 563 | 03555 564 | 03557 565 | 03558 566 | 03580 567 | 03582 568 | 03583 569 | 03591 570 | 03600 571 | 03616 572 | 03620 573 | 03624 574 | 03626 575 | 03629 576 | 03631 577 | 03632 578 | 03638 579 | 03639 580 | 03640 581 | 03642 582 | 03643 583 | 03646 584 | 03647 585 | 03653 586 | 03670 587 | 03671 588 | 03683 589 | 03696 590 | 03699 591 | 03700 592 | 03705 593 | 03714 594 | 03715 595 | 03721 596 | 03745 597 | 03749 598 | 03753 599 | 03756 600 | 03757 601 | 03760 602 | 03771 603 | 03774 604 | 03782 605 | 03788 606 | 03789 607 | 03793 608 | 03809 609 | 03811 610 | 03819 611 | 03826 612 | 03830 613 | 03831 614 | 03835 615 | 03844 616 | 03848 617 | 03856 618 | 03866 619 | 03867 620 | 03880 621 | 03881 622 | 03887 623 | 03899 624 | 03907 625 | 03914 626 | 03932 627 | 03937 628 | 03940 629 | 03948 630 | 03952 631 | 03954 632 | 03959 633 | 03964 634 | 03970 635 | 03975 636 | 03990 637 | 03991 638 | 03993 639 | 03995 640 | 03996 641 | 03999 642 | 04005 643 | 04011 644 | 04014 645 | 04015 646 | 04016 647 | 04021 648 | 04022 649 | 04025 650 | 04027 651 | 04036 652 | 04039 653 | 04042 654 | 04055 655 | 04056 656 | 04062 657 | 04064 658 | 04082 659 | 04086 660 | 04088 661 | 04091 662 | 04093 663 | 04096 664 | 04098 665 | 04106 666 | 04108 667 | 04118 668 | 04128 669 | 04129 670 | 04130 671 | 04131 672 | 04139 673 | 04140 674 | 04146 675 | 04148 676 | 04150 677 | 04160 678 | 04163 679 | 04166 680 | 04182 681 | 04183 682 | 04184 683 | 04191 684 | 04196 685 | 04201 686 | 04206 687 | 04210 688 | 04218 689 | 04232 690 | 04251 691 | 04252 692 | 04253 693 | 04260 694 | 04287 695 | 04319 696 | 04321 697 | 04324 698 | 04332 699 | 04347 700 | 04350 701 | 04352 702 | 04357 703 | 04363 704 | 04381 705 | 04387 706 | 04395 707 | 04404 708 | 04405 709 | 04423 710 | 04426 711 | 04433 712 | 04434 713 | 04437 714 | 04446 715 | 04449 716 | 04455 717 | 04459 718 | 04462 719 | 04463 720 | 04464 721 | 04466 722 | 04476 723 | 04483 724 | 04487 725 | 04488 726 | 04490 727 | 
04505 728 | 04509 729 | 04510 730 | 04512 731 | 04515 732 | 04526 733 | 04531 734 | 04533 735 | 04538 736 | 04541 737 | 04546 738 | 04562 739 | 04576 740 | 04583 741 | 04584 742 | 04588 743 | 04593 744 | 04595 745 | 04599 746 | 04604 747 | 04612 748 | 04634 749 | 04638 750 | 04643 751 | 04644 752 | 04649 753 | 04652 754 | 04654 755 | 04657 756 | 04658 757 | 04668 758 | 04674 759 | 04675 760 | 04676 761 | 04681 762 | 04682 763 | 04697 764 | 04706 765 | 04711 766 | 04719 767 | 04727 768 | 04730 769 | 04732 770 | 04733 771 | 04742 772 | 04743 773 | 04745 774 | 04749 775 | 04750 776 | 04759 777 | 04769 778 | 04783 779 | 04794 780 | 04808 781 | 04809 782 | 04815 783 | 04826 784 | 04829 785 | 04830 786 | 04845 787 | 04850 788 | 04863 789 | 04870 790 | 04914 791 | 04915 792 | 04948 793 | 04950 794 | 04960 795 | 04962 796 | 04965 797 | 04968 798 | 04970 799 | 04988 800 | 04997 -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
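    Used below to zero-initialize the output projection of SpatialTransformer,
    so that each transformer block starts out as an identity mapping of its
    residual input.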
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /dataset/parametricrefmap.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from time import sleep 3 | from typing import Dict, List, Optional, Union 4 | 5 | import cv2 6 | import torch 7 | 8 | from models.drmnet import DRMNet 9 | from models.obsnet import ObsNetDiffusion 10 | from utils.file_io import load_exr 11 | from utils.transform import thetaphi2xyz 12 | 13 | from .basedataset import BaseDataset 14 | 15 | 16 | class ParametricRefmapDataset(BaseDataset): 17 | def __init__( 18 | self, 19 | size: int, 20 | split: str, 21 | data_root: str, 22 | zdim: int, 23 | 
transform_func: str = "log", 24 | clamp_before_exp: float = 0, 25 | return_envmap: bool = False, 26 | mask_root: str = None, 27 | mask_area_min_rate: float = 0.002, 28 | epoch_bias: int = 0, 29 | epoch_cycle: int = 1000, 30 | preload_envmap: bool = False, 31 | return_cache=True, 32 | refmap_cache_root: Optional[str] = None, 33 | ): 34 | super().__init__(size, transform_func=transform_func, clamp_before_exp=clamp_before_exp) 35 | 36 | assert split in ["train", "val", "test"] 37 | self.split = split 38 | self.root = Path(data_root) 39 | self.data_name = self.root.name 40 | assert self.data_name in ["LavalIndoor2kxMERL", "HDRIHAVEN_4k", "LavalIndoor+PolyHaven_2k"] 41 | self.t = "train" if split in ["train", "val"] else "test" 42 | with open(f"data/datalists/{self.data_name}/envs_{split}.txt", "r") as f: 43 | self.envs = f.read().splitlines() 44 | 45 | if mask_root is not None: 46 | self.with_mask = True 47 | self.mask_root = Path(mask_root) 48 | self.mask_name = self.mask_root.name 49 | with open(f"data/datalists/{self.mask_name}/sparsemaskannotations_{split}.txt", "r") as f: 50 | self.mask_annotations = f.read().splitlines() 51 | self.mask_len = len(self.mask_annotations) 52 | self.mask_area_min_rate = mask_area_min_rate 53 | else: 54 | self.with_mask = False 55 | 56 | self.zdim = zdim 57 | self.return_envmap = return_envmap 58 | 59 | self.generator = torch.Generator() 60 | self.current_epoch = 0 61 | 62 | self.model: Union[DRMNet, ObsNetDiffusion] = None 63 | 64 | self.return_cache = return_cache 65 | self.epoch_bias = epoch_bias 66 | self.epoch_cycle = epoch_cycle 67 | self.preload_envmap = preload_envmap 68 | if self.return_envmap and preload_envmap: 69 | self.envmaps = {} 70 | for env in self.envs: 71 | env_name = env[:-4] 72 | self.envmaps[env_name] = load_exr(self.root / f"{env_name}.exr", as_torch=True) 73 | 74 | self.refmap_cache_root = Path(refmap_cache_root) if refmap_cache_root is not None else None 75 | if self.return_cache: 76 | assert self.refmap_cache_root is not None, "specify refmap_cache_root to return cache" 77 | 78 | def __len__(self): 79 | return len(self.envs) 80 | 81 | def set_current_epoch(self, epoch): 82 | self.current_epoch = epoch 83 | 84 | def set_generator(self, idx: int, epoch: int = None): 85 | if self.split == "train": 86 | epoch = epoch or self.current_epoch 87 | epoch = epoch + self.epoch_bias 88 | if epoch >= self.epoch_cycle: 89 | epoch = epoch % self.epoch_cycle 90 | self.generator.manual_seed((epoch) * len(self) + idx) 91 | elif self.split == "val": 92 | self.generator.manual_seed(idx) 93 | self.generator.manual_seed(torch.empty((), dtype=torch.int64).random_(generator=self.generator).item()) 94 | elif self.split == "test": 95 | self.generator.manual_seed(idx) 96 | torch.empty((), dtype=torch.int64).random_(generator=self.generator) 97 | self.generator.manual_seed(torch.empty((), dtype=torch.int64).random_(generator=self.generator).item()) 98 | else: 99 | raise NotImplementedError() 100 | 101 | @torch.no_grad() 102 | def __getitem__(self, idx: int): 103 | env_name = self.envs[idx][:-4] 104 | self.set_generator(idx) 105 | zK = torch.rand((self.zdim,), generator=self.generator) 106 | 107 | data = {} 108 | data["zK"] = zK 109 | data["envmap_name"] = env_name 110 | 111 | normalized_k = torch.rand((), generator=self.generator) 112 | data["normalized_k"] = normalized_k 113 | 114 | phi = (torch.rand((), generator=self.generator) * 64).int() / 64 * torch.pi * 2 - torch.pi 115 | theta = (torch.rand((), generator=self.generator) * 0 + 0.5) * torch.pi # Invalid 
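        # Note: the `* 0 + 0.5` above pins theta to pi / 2, i.e. an equatorial viewing
        # direction, while still consuming a draw from the generator (presumably to keep
        # the random sequence aligned); phi is quantized to one of 64 azimuth bins.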
116 |         view_from = thetaphi2xyz(torch.stack([theta, phi]), normal=[0, 1, 0], tangent=[0, 0, 1])
117 |         data["view_from"] = view_from
118 | 
119 |         mask_idx = torch.rand((), generator=self.generator).item()
120 |         if self.with_mask:
121 |             mask_idx = int(mask_idx * self.mask_len)
122 |             while True:
123 |                 mask = cv2.imread(str(self.mask_root / self.t / self.mask_annotations[mask_idx]), -1)
124 |                 height, width = mask.shape[:2]
125 |                 # skip masks whose foreground region is too small
126 |                 if mask.astype(bool).sum() >= height * width * self.mask_area_min_rate:
127 |                     break
128 |                 else:
129 |                     mask_idx = (mask_idx + 1) % self.mask_len
130 |             mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
131 |             data["mask"] = mask / 255
132 | 
133 |         # load rendered data cache
134 |         if self.model is not None and self.return_cache:
135 |             brdf_param_names = self.model.renderer.brdf_param_names or self.model.brdf_param_names
136 |             size = self.model.renderer.refmap_res
137 |             spp = self.model.renderer.spp
138 |             denoise_suffix = f"_{self.model.renderer.denoise}denoise" if self.model.renderer.denoise else ""
139 |             pieces_cache_dir = self.refmap_cache_root / f'{"-".join(brdf_param_names)}/{size}x{size}_spp{spp}{denoise_suffix}/'
140 |             # fix the tensor print format; the cache keys below are built from str(z)
141 |             torch.set_printoptions(precision=4, sci_mode=True)
142 | 
143 |             # returns (cache_hit, rendering_results); a miss yields a NaN placeholder
144 |             def get_cache(z):
145 |                 pieces_key = "b" + str(z)[7:-1] + "v" + str(view_from)[7:-1]
146 |                 pieces_key = pieces_key.replace("\n", "").replace(" ", "")
147 |                 filename = pieces_key + ".pt"
148 |                 cache_file_path = pieces_cache_dir / env_name / filename
149 |                 if not cache_file_path.exists():
150 |                     return False, torch.full((3, size, size), torch.nan)
151 |                 for _ in range(3):
152 |                     try:
153 |                         cache: dict = torch.load(cache_file_path, map_location="cpu")
154 |                     except Exception as e:
155 |                         print(cache_file_path)
156 |                         print(e)
157 |                         sleep(0.01)
158 |                     else:
159 |                         break
160 |                 else:
161 |                     return False, torch.full((3, size, size), torch.nan)
162 |                 if (
163 |                     cache.get("envmap_name") == env_name
164 |                     and cache.get("brdf_param_names") == brdf_param_names
165 |                     and torch.allclose(cache.get("zk"), z)
166 |                     and torch.allclose(cache.get("view_from"), view_from)
167 |                     and cache.get("refmap_res") == size
168 |                 ):
169 |                     return True, cache.get("rendering_results")
170 |                 else:
171 |                     return False, torch.full((3, size, size), torch.nan)
172 | 
173 |             rK = get_cache(zK)
174 |             data["LrK"] = rK[1]
175 | 
176 |             if (z0 := getattr(self.model, "_z0", None)) is not None:  # to train DRMNet
177 |                 K, k, zk, zkm1 = self.model.get_schedule(zK, z0=z0, normalized_k=normalized_k, return_zkm1=True)
178 |                 # zk and zkm1 are of shape [batch, zdim]
179 |                 data["K"] = K
180 |                 data["k"] = k
181 |                 rk = get_cache(zk)
182 |                 data["zk"] = zk
183 |                 data["Lrk"] = rk[1]
184 |                 if K > 0:
185 |                     rkm1 = get_cache(zkm1)
186 |                     data["zkm1"] = zkm1
187 |                     data["Lrkm1"] = rkm1[1]
188 |                 else:
189 |                     rkm1 = (False, None)  # keep rkm1 defined so the cache check below also works when K == 0
190 |                     data["zkm1"] = torch.full_like(zkm1, torch.nan)
191 |                     data["Lrkm1"] = torch.full_like(rk[1], torch.nan)
192 |                 r0 = get_cache(z0)
193 |                 data["r0"] = r0[1]
194 |                 seem_need_envmap = not (rK[0] and rk[0] and (rkm1[0] or r0[0]))
195 |             else:  # to train ObsNet
196 |                 seem_need_envmap = not rK[0]
197 |         else:
198 |             seem_need_envmap = True
199 | 
200 |         if self.return_envmap and seem_need_envmap:
201 |             if self.preload_envmap:
202 |                 envmap = self.envmaps[env_name]
203 |             else:
204 |                 envmap = load_exr(self.root / f"{env_name}.exr", as_torch=True)
205 |             data["envmap"] = envmap
206 |         elif self.return_envmap:
207 |             if self.preload_envmap:
208 |                 envmap = self.envmaps[env_name]
209 |             else:
210 |                 # skip loading the environment map and return a NaN placeholder instead.
211 |                 try:
212 |                     envmap_size = self.model.renderer.envmap_size
213 |                 except Exception:
214 |                     envmap_size = (1000, 2000)
215 |                 envmap = torch.full((*envmap_size, 3), torch.nan, dtype=torch.float)
216 |             data["envmap"] = envmap
217 | 
218 |         data["tag"] = env_name
219 | 
220 |         return data
221 | 
-------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: --------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 | 
10 | 
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 | 
18 | from ldm.util import instantiate_from_config
19 | 
20 | 
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 |     if schedule == "linear":
23 |         betas = (
24 |             torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 |         )
26 | 
27 |     elif schedule == "cosine":
28 |         timesteps = (
29 |             torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 |         )
31 |         alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 |         alphas = torch.cos(alphas).pow(2)
33 |         alphas = alphas / alphas[0]
34 |         betas = 1 - alphas[1:] / alphas[:-1]
35 |         betas = np.clip(betas, a_min=0, a_max=0.999)
36 | 
37 |     elif schedule == "sqrt_linear":
38 |         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 |     elif schedule == "sqrt":
40 |         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 |     else:
42 |         raise ValueError(f"schedule '{schedule}' unknown.")
43 |     return betas.numpy()
44 | 
45 | 
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 |     if ddim_discr_method == 'uniform':
48 |         c = num_ddpm_timesteps // num_ddim_timesteps
49 |         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 |     elif ddim_discr_method == 'quad':
51 |         ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 |     else:
53 |         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 | 
55 |     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 |     # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 |     steps_out = ddim_timesteps + 1
58 |     if verbose:
59 |         print(f'Selected timesteps for ddim sampler: {steps_out}')
60 |     return steps_out
61 | 
62 | 
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 |     # select alphas for computing the variance schedule
65 |     alphas = alphacums[ddim_timesteps]
66 |     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 | 
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 |     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
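    # i.e. sigma_t = eta * sqrt((1 - a_prev) / (1 - a)) * sqrt(1 - a / a_prev), with
    # a = alphas_cumprod at the selected steps; cf. Eq. (16) of the DDIM paper
    # (arXiv:2010.02502). eta = 0 gives deterministic DDIM sampling, while eta = 1
    # recovers the DDPM posterior variance.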
70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 
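    :param repeat_only: if True, skip the sinusoidal embedding and simply tile
        the raw timestep values across the embedding dimension.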
159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | from functools import partial 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from ldm.modules.diffusionmodules.util import ( 10 | make_ddim_sampling_parameters, 11 | make_ddim_timesteps, 12 | noise_like, 13 | ) 14 | 15 | 16 | class DDIMSampler(object): 17 | def __init__(self, model, schedule="linear", **kwargs): 18 | super().__init__() 19 | self.model = model 20 | self.ddpm_num_timesteps = model.num_timesteps 21 | self.schedule = schedule 22 | 23 | def register_buffer(self, name, attr): 24 | if type(attr) == torch.Tensor: 25 | if attr.device != torch.device("cuda"): 26 | attr = attr.to(torch.device("cuda")) 27 | setattr(self, name, attr) 28 | 29 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): 30 | self.ddim_timesteps = make_ddim_timesteps( 31 | ddim_discr_method=ddim_discretize, 32 | num_ddim_timesteps=ddim_num_steps, 33 | num_ddpm_timesteps=self.ddpm_num_timesteps, 34 | verbose=verbose, 35 | ) 36 | alphas_cumprod = self.model.alphas_cumprod 37 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" 38 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 39 | 40 | self.register_buffer("betas", to_torch(self.model.betas)) 41 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 42 | self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) 43 | 44 | # calculations for diffusion q(x_t | x_{t-1}) and others 45 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) 46 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) 47 | self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) 48 | self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) 49 | self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) 50 | 51 | # ddim sampling parameters 52 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 53 | alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, 
verbose=verbose 54 | ) 55 | self.register_buffer("ddim_sigmas", ddim_sigmas) 56 | self.register_buffer("ddim_alphas", ddim_alphas) 57 | self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) 58 | self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) 59 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 60 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) 61 | ) 62 | self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) 63 | 64 | @torch.no_grad() 65 | def sample( 66 | self, 67 | S, 68 | batch_size, 69 | shape, 70 | conditioning=None, 71 | callback=None, 72 | normals_sequence=None, 73 | img_callback=None, 74 | quantize_x0=False, 75 | eta=0.0, 76 | mask=None, 77 | x0=None, 78 | temperature=1.0, 79 | noise_dropout=0.0, 80 | score_corrector=None, 81 | corrector_kwargs=None, 82 | verbose=True, 83 | x_T=None, 84 | log_every_t=100, 85 | unconditional_guidance_scale=1.0, 86 | unconditional_conditioning=None, 87 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 88 | **kwargs, 89 | ): 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 93 | if cbs != batch_size: 94 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 95 | else: 96 | if conditioning.shape[0] != batch_size: 97 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 98 | 99 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 100 | # sampling 101 | # C, H, W = shape 102 | # size = (batch_size, C, H, W) 103 | size = (batch_size, *shape) 104 | if verbose: 105 | print(f"Data shape for DDIM sampling is {size}, eta {eta}") 106 | 107 | samples, intermediates = self.ddim_sampling( 108 | conditioning, 109 | size, 110 | callback=callback, 111 | img_callback=img_callback, 112 | quantize_denoised=quantize_x0, 113 | mask=mask, 114 | x0=x0, 115 | ddim_use_original_steps=False, 116 | noise_dropout=noise_dropout, 117 | temperature=temperature, 118 | score_corrector=score_corrector, 119 | corrector_kwargs=corrector_kwargs, 120 | x_T=x_T, 121 | log_every_t=log_every_t, 122 | unconditional_guidance_scale=unconditional_guidance_scale, 123 | unconditional_conditioning=unconditional_conditioning, 124 | verbose=verbose, 125 | ) 126 | return samples, intermediates 127 | 128 | @torch.no_grad() 129 | def ddim_sampling( 130 | self, 131 | cond, 132 | shape, 133 | x_T=None, 134 | ddim_use_original_steps=False, 135 | callback=None, 136 | timesteps=None, 137 | quantize_denoised=False, 138 | mask=None, 139 | x0=None, 140 | img_callback=None, 141 | log_every_t=100, 142 | temperature=1.0, 143 | noise_dropout=0.0, 144 | score_corrector=None, 145 | corrector_kwargs=None, 146 | unconditional_guidance_scale=1.0, 147 | unconditional_conditioning=None, 148 | verbose=True, 149 | ): 150 | device = self.model.betas.device 151 | b = shape[0] 152 | if x_T is None: 153 | img = torch.randn(shape, device=device) 154 | else: 155 | img = x_T 156 | 157 | if timesteps is None: 158 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 159 | elif timesteps is not None and not ddim_use_original_steps: 160 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 161 | timesteps = self.ddim_timesteps[:subset_end] 162 | 163 | intermediates = {"x_inter": [img], 
"pred_x0": [img]} 164 | time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) 165 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 166 | if verbose: 167 | print(f"Running DDIM Sampling with {total_steps} timesteps") 168 | 169 | iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) if verbose else time_range 170 | 171 | for i, step in enumerate(iterator): 172 | index = total_steps - i - 1 173 | ts = torch.full((b,), step, device=device, dtype=torch.long) 174 | 175 | if mask is not None: 176 | assert x0 is not None 177 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 178 | img = img_orig * mask + (1.0 - mask) * img 179 | outs = self.p_sample_ddim( 180 | img, 181 | cond, 182 | ts, 183 | index=index, 184 | use_original_steps=ddim_use_original_steps, 185 | quantize_denoised=quantize_denoised, 186 | temperature=temperature, 187 | noise_dropout=noise_dropout, 188 | score_corrector=score_corrector, 189 | corrector_kwargs=corrector_kwargs, 190 | unconditional_guidance_scale=unconditional_guidance_scale, 191 | unconditional_conditioning=unconditional_conditioning, 192 | ) 193 | img, pred_x0 = outs 194 | if callback: 195 | callback(i) 196 | if img_callback: 197 | img_callback(pred_x0, i) 198 | 199 | if index % log_every_t == 0 or index == total_steps - 1: 200 | intermediates["x_inter"].append(img) 201 | intermediates["pred_x0"].append(pred_x0) 202 | 203 | # return pred_x0, intermediates 204 | return img, intermediates 205 | 206 | @torch.no_grad() 207 | def p_sample_ddim( 208 | self, 209 | x, 210 | c, 211 | t, 212 | index, 213 | repeat_noise=False, 214 | use_original_steps=False, 215 | quantize_denoised=False, 216 | temperature=1.0, 217 | noise_dropout=0.0, 218 | score_corrector=None, 219 | corrector_kwargs=None, 220 | unconditional_guidance_scale=1.0, 221 | unconditional_conditioning=None, 222 | ): 223 | b, *_, device = *x.shape, x.device 224 | 225 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: 226 | e_t = self.model.apply_model(x, t, c) 227 | else: 228 | x_in = torch.cat([x] * 2) 229 | t_in = torch.cat([t] * 2) 230 | c_in = torch.cat([unconditional_conditioning, c]) 231 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 232 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 233 | 234 | if score_corrector is not None: 235 | assert self.model.parameterization == "eps" 236 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 237 | 238 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 239 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 240 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 241 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 242 | # select parameters corresponding to the currently considered timestep 243 | param_shape = (b, *[1] * (x.ndim - 1)) 244 | a_t = torch.full(param_shape, alphas[index], device=device) 245 | a_prev = torch.full(param_shape, alphas_prev[index], device=device) 246 | sigma_t = torch.full(param_shape, sigmas[index], device=device) 247 | sqrt_one_minus_at = torch.full(param_shape, sqrt_one_minus_alphas[index], device=device) 248 | 249 | # current prediction for x_0 250 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 251 | if quantize_denoised: 252 | 
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 253 | # direction pointing to x_t 254 | dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t 255 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 256 | if noise_dropout > 0.0: 257 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 258 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 259 | return x_prev, pred_x0 260 | -------------------------------------------------------------------------------- /data/datalists/LavalIndoor+PolyHaven_2k/envs_val.txt: -------------------------------------------------------------------------------- 1 | 9C4A0004-db1d4a14f3.exr 2 | 9C4A0006-5133111e97.exr 3 | 9C4A0094-c96fec6e86.exr 4 | 9C4A0121-ff0fd23bdf.exr 5 | 9C4A0190-fdd2f9a709.exr 6 | 9C4A0195-2bf90f1d2b.exr 7 | 9C4A0223-e3fd38d26b.exr 8 | 9C4A0288-4c4f5603d4.exr 9 | 9C4A0363-e2a65b93c9.exr 10 | 9C4A0380-2a79f5a651.exr 11 | 9C4A0506-a2c5835c9d.exr 12 | 9C4A0515-c6bcfca78b.exr 13 | 9C4A0521-72b6a00e22.exr 14 | 9C4A0548-b8fe6edac7.exr 15 | 9C4A0557-2fe726c248.exr 16 | 9C4A0610-b489c4b14c.exr 17 | 9C4A0647-598b747ac7.exr 18 | 9C4A0699-98eeece4b5.exr 19 | 9C4A0734-48e1ccb990.exr 20 | 9C4A0773-f0f49cf80b.exr 21 | 9C4A0792-5b369124ba.exr 22 | 9C4A0809-3df674487f.exr 23 | 9C4A0820-e8765c372a.exr 24 | 9C4A0834-2de1f99cd3.exr 25 | 9C4A0862-f28cb60be4.exr 26 | 9C4A0891-380751a37a.exr 27 | 9C4A1025-08238e4b68.exr 28 | 9C4A1059-bc248da2fd.exr 29 | 9C4A1183-c1d8f3e7df.exr 30 | 9C4A1203-8752c564f9.exr 31 | 9C4A1234-2002212421.exr 32 | 9C4A1314-9c5d513964.exr 33 | 9C4A1360-73d1175297.exr 34 | 9C4A1444-ef87f639a3.exr 35 | 9C4A1528-b4b5787538.exr 36 | 9C4A1548-a9724e2b8c.exr 37 | 9C4A1570-a5b3087e38.exr 38 | 9C4A1620-dbad85c001.exr 39 | 9C4A1674-a25f73491e.exr 40 | 9C4A1691-5901d1aafe.exr 41 | 9C4A1716-a8e8b7e083.exr 42 | 9C4A1822-85ef877b1e.exr 43 | 9C4A1901-0eeb44294b.exr 44 | 9C4A1943-bdf79017dc.exr 45 | 9C4A1985-aab7b68d9a.exr 46 | 9C4A2085-215391218b.exr 47 | 9C4A2152-ef4bcb5309.exr 48 | 9C4A2166-9212f3b58a.exr 49 | 9C4A2185-da53a9297d.exr 50 | 9C4A2269-3107b3bf85.exr 51 | 9C4A2305-65871b69f5.exr 52 | 9C4A2311-bd26e0601a.exr 53 | 9C4A2334-79dbacbb08.exr 54 | 9C4A2431-e36b96d555.exr 55 | 9C4A2467-6cb5aa1891.exr 56 | 9C4A2495-40e081c571.exr 57 | 9C4A2512-b71383ee6c.exr 58 | 9C4A2551-3436caaf51.exr 59 | 9C4A2557-0c588a5205.exr 60 | 9C4A2596-d28c8f5782.exr 61 | 9C4A2680-6259d498ee.exr 62 | 9C4A2747-1a74e2cc6f.exr 63 | 9C4A2764-a6b3b13445.exr 64 | 9C4A2815-d320106a26.exr 65 | 9C4A2941-6477abc5de.exr 66 | 9C4A3025-333daf71a2.exr 67 | 9C4A3151-a74891cfa1.exr 68 | 9C4A3229-b96c0b1b22.exr 69 | 9C4A3355-a4fc639c23.exr 70 | 9C4A3359-3edd9ddf7d.exr 71 | 9C4A3377-89d0da6ab9.exr 72 | 9C4A3481-8a24936a18.exr 73 | 9C4A3569-419fcb3cd4.exr 74 | 9C4A3587-5b913b67f8.exr 75 | 9C4A3821-5e85168ed1.exr 76 | 9C4A3839-4be94bcaa6.exr 77 | 9C4A4091-08362b2a4a.exr 78 | 9C4A4133-12dc69df77.exr 79 | 9C4A4217-8355d768de.exr 80 | 9C4A4222-c4047ecd5e.exr 81 | 9C4A4259-ac42157b8b.exr 82 | 9C4A4427-c54f7780ac.exr 83 | 9C4A4458-47bb11126c.exr 84 | 9C4A4475-8f1167988a.exr 85 | 9C4A4685-842b5964a3.exr 86 | 9C4A4693-b2deb5d7b4.exr 87 | 9C4A4700-56b00c8adf.exr 88 | 9C4A4745-350f0b51f2.exr 89 | 9C4A4752-e9023a450f.exr 90 | 9C4A4769-2e8ee613aa.exr 91 | 9C4A4777-23b9945780.exr 92 | 9C4A4811-64470bd442.exr 93 | 9C4A4833-a8c0745bc0.exr 94 | 9C4A4861-19626f89e9.exr 95 | 9C4A4903-5827dd45e5.exr 96 | 9C4A4962-f08140d831.exr 97 | 9C4A4987-d943ae67ef.exr 98 | 9C4A5001-33efe6bc42.exr 99 | 9C4A5046-e0854552ca.exr 100 | 9C4A5081-4bec473451.exr 101 | 
9C4A5130-e626eb4e0e.exr 102 | 9C4A5207-40b28f1858.exr 103 | 9C4A5231-7156cc48fa.exr 104 | 9C4A5291-d47902d9d1.exr 105 | 9C4A5295-8b21b94afb.exr 106 | 9C4A5417-f5d395fb63.exr 107 | 9C4A5438-f8362ba6fb.exr 108 | 9C4A5463-bed9fb35dc.exr 109 | 9C4A5609-2b0a216e72.exr 110 | 9C4A5673-95c38effde.exr 111 | 9C4A5722-8d5177818a.exr 112 | 9C4A5756-39e65e0e7b.exr 113 | 9C4A5967-b2868d005f.exr 114 | 9C4A5984-edcea9657b.exr 115 | 9C4A6051-f7b81355d6.exr 116 | 9C4A6099-89226b9794.exr 117 | 9C4A6173-2a6a16867b.exr 118 | 9C4A6177-a190c43daf.exr 119 | 9C4A6257-8457dcbbcb.exr 120 | 9C4A6261-8ce9192df8.exr 121 | 9C4A6269-ace1d48e17.exr 122 | 9C4A6353-b2117208db.exr 123 | 9C4A6520-ffc199b651.exr 124 | 9C4A6563-d9748a38ee.exr 125 | 9C4A6605-738d18586f.exr 126 | 9C4A6688-66df351916.exr 127 | 9C4A6729-b4e7c82b9a.exr 128 | 9C4A6730-7762c53f2a.exr 129 | 9C4A6773-240584c2f6.exr 130 | 9C4A6813-8339c9db66.exr 131 | 9C4A6856-bf4f279ab7.exr 132 | 9C4A6898-5ae1360451.exr 133 | 9C4A6948-1c80f5988a.exr 134 | 9C4A6985-ab452edfac.exr 135 | 9C4A7024-135db952f8.exr 136 | 9C4A7149-796e57659b.exr 137 | 9C4A7157-8faaac3321.exr 138 | 9C4A7200-e5a05088f6.exr 139 | 9C4A7234-45127bb0ef.exr 140 | 9C4A7237-86a1e997c7.exr 141 | 9C4A7265-b5a565a5f9.exr 142 | 9C4A7279-f38d3b6e11.exr 143 | 9C4A7325-8acd063877.exr 144 | 9C4A7349-2a97cb8d20.exr 145 | 9C4A7363-e36682b7a6.exr 146 | 9C4A7475-fc16256c1a.exr 147 | 9C4A7527-1a4fddf812.exr 148 | 9C4A7643-621ea8fd87.exr 149 | 9C4A7830-0cd17635eb.exr 150 | 9C4A7853-7a8c7af2e2.exr 151 | 9C4A7905-838e1e619f.exr 152 | 9C4A7916-b255607e80.exr 153 | 9C4A7962-09a6a98dd4.exr 154 | 9C4A8029-bd2865eb6c.exr 155 | 9C4A8115-c461a3d429.exr 156 | 9C4A8126-eebf0a2b1d.exr 157 | 9C4A8201-9f470daffb.exr 158 | 9C4A8256-455a74db43.exr 159 | 9C4A8283-910c0942f5.exr 160 | 9C4A8294-0d27955d83.exr 161 | 9C4A8298-8660f8fb63.exr 162 | 9C4A8336-51cd9186c9.exr 163 | 9C4A8378-92f902cb9a.exr 164 | 9C4A8462-635c86de75.exr 165 | 9C4A8537-a746398768.exr 166 | 9C4A8588-a453bf3f11.exr 167 | 9C4A8673-317fbfa945.exr 168 | 9C4A8676-1428be837d.exr 169 | 9C4A8705-b2b109621a.exr 170 | 9C4A8735-21ad245f21.exr 171 | 9C4A8787-68bb5dfb7f.exr 172 | 9C4A8789-01f75c0d50.exr 173 | 9C4A8805-bb5caeacc0.exr 174 | 9C4A8854-c4d1835943.exr 175 | 9C4A8885-f6af76eb6f.exr 176 | 9C4A8896-9debb45a0a.exr 177 | 9C4A8903-e53dd3b694.exr 178 | 9C4A8938-c4764af7ea.exr 179 | 9C4A9018-a00e7c48c1.exr 180 | 9C4A9022-545962cb4a.exr 181 | 9C4A9067-5f321df977.exr 182 | 9C4A9102-2b2c2e1fc1.exr 183 | 9C4A9106-756ebb1b05.exr 184 | 9C4A9109-aaac1fa3e8.exr 185 | 9C4A9135-ac750c2355.exr 186 | 9C4A9151-d6e3ff9686.exr 187 | 9C4A9169-06db224630.exr 188 | 9C4A9177-01f88e7416.exr 189 | 9C4A9190-44bb277913.exr 190 | 9C4A9232-e34e8e8c5c.exr 191 | 9C4A9258-da8ecccf8b.exr 192 | 9C4A9261-8cd13bb112.exr 193 | 9C4A9274-cdf7c9a546.exr 194 | 9C4A9277-b68d965d55.exr 195 | 9C4A9300-be2f4c26f1.exr 196 | 9C4A9306-1b67f70b9b.exr 197 | 9C4A9319-0777867163.exr 198 | 9C4A9333-ca0390ec8f.exr 199 | 9C4A9337-4aeac4961b.exr 200 | 9C4A9442-1eccd6a5bd.exr 201 | 9C4A9457-674221ed7c.exr 202 | 9C4A9499-145bdcf85c.exr 203 | 9C4A9501-e46f5a89b9.exr 204 | 9C4A9510-7974446501.exr 205 | 9C4A9529-5fb98484b9.exr 206 | 9C4A9559-ba84d94357.exr 207 | 9C4A9560-3227dda698.exr 208 | 9C4A9631-24af5139b1.exr 209 | 9C4A9649-61b1caab87.exr 210 | 9C4A9733-3cbb4b4156.exr 211 | 9C4A9753-11ea66dac0.exr 212 | 9C4A9775-d5e94f8a67.exr 213 | 9C4A9792-d4e2f3874d.exr 214 | 9C4A9834-e1aeaf3b1d.exr 215 | 9C4A9876-0fe6327618.exr 216 | 9C4A9910-9d3680de8c.exr 217 | 9C4A9916-6fca6ab033.exr 218 | 9C4A9921-ff2c065635.exr 219 | 
9C4A9925-89234764b3.exr 220 | AG8A0109-6655eb6c45.exr 221 | AG8A0117-e610b0a3f4.exr 222 | AG8A0134-1c7e929b22.exr 223 | AG8A0151-7e2bd1fa37.exr 224 | AG8A0193-fffb462068.exr 225 | AG8A0251-ece968a2a4.exr 226 | AG8A0260-77edec782e.exr 227 | AG8A0277-069adeb954.exr 228 | AG8A0279-4a5da59d8d.exr 229 | AG8A0293-ea887d3710.exr 230 | AG8A0297-167904b271.exr 231 | AG8A0381-045c77009f.exr 232 | AG8A0403-54da1718a0.exr 233 | AG8A0428-e476ea5d27.exr 234 | AG8A0447-38278c576c.exr 235 | AG8A0470-26b8ffcddd.exr 236 | AG8A0549-1f8c0654df.exr 237 | AG8A0630-e5622e17d2.exr 238 | AG8A0697-f4d66ceb41.exr 239 | AG8A0741-1ef89b4831.exr 240 | AG8A0836-614a104327.exr 241 | AG8A0848-7299a90529.exr 242 | AG8A0865-a0a2934ca2.exr 243 | AG8A0924-4484ad8b59.exr 244 | AG8A0932-fa3e4b0159.exr 245 | AG8A0949-324a96cb40.exr 246 | AG8A1088-86892af86a.exr 247 | AG8A1092-47eb67a179.exr 248 | AG8A1097-49cfea8ebc.exr 249 | AG8A1130-166b53384f.exr 250 | AG8A1256-3f9f1a8dc8.exr 251 | AG8A1260-0265cb5f07.exr 252 | AG8A1302-0dd1e6995e.exr 253 | AG8A1344-4ebe8355c9.exr 254 | AG8A1466-8d4acf0f64.exr 255 | AG8A1495-8fdd205261.exr 256 | AG8A1579-e0e15ad898.exr 257 | AG8A1592-9e5e39d41a.exr 258 | AG8A1596-ac1f53bc32.exr 259 | AG8A1860-c6f23099fd.exr 260 | AG8A1886-94d5ffee58.exr 261 | AG8A2041-d3d9cfc286.exr 262 | AG8A2103-662bc79838.exr 263 | AG8A2112-275f101b69.exr 264 | AG8A2122-bc6d9fdb45.exr 265 | AG8A2271-1ba8126512.exr 266 | AG8A2342-7d9b329bb5.exr 267 | AG8A2479-971317db29.exr 268 | AG8A2647-92dd43be63.exr 269 | AG8A2773-3588b21e4b.exr 270 | AG8A2899-cab5ecbf68.exr 271 | AG8A2924-054c4b16cb.exr 272 | AG8A3116-89ae3d7341.exr 273 | AG8A3196-8e1dfd8d95.exr 274 | AG8A3260-65b5c8e148.exr 275 | AG8A3343-c4a982236d.exr 276 | AG8A3344-aa6bd87af1.exr 277 | AG8A3410-5265119841.exr 278 | AG8A3680-f4809b70f5.exr 279 | AG8A3921-5d51f11053.exr 280 | AG8A4131-e339e19f4f.exr 281 | AG8A4148-4171cf00f1.exr 282 | AG8A4274-2c4ee8b1b7.exr 283 | AG8A4341-85df8326c0.exr 284 | AG8A4757-eb65a62b4b.exr 285 | AG8A4786-d502998f99.exr 286 | AG8A4799-d198b87e8f.exr 287 | AG8A4912-7a6d07d4f4.exr 288 | AG8A4971-c5c2ef5b54.exr 289 | AG8A5146-0c0a37afe3.exr 290 | AG8A5318-e2e6dad171.exr 291 | AG8A5358-008e86604e.exr 292 | AG8A5402-c04fda61ec.exr 293 | AG8A5524-6e87e15d72.exr 294 | AG8A5582-5920abaaf0.exr 295 | AG8A6002-b1f17cce81.exr 296 | AG8A6196-f437fee405.exr 297 | AG8A6267-a7acfc754d.exr 298 | AG8A6298-558dbd3434.exr 299 | AG8A6383-695ecde991.exr 300 | AG8A6393-4aaa3e4996.exr 301 | AG8A6603-28fa0305d7.exr 302 | AG8A6803-6fb2a40ac6.exr 303 | AG8A6904-ea77ceb80f.exr 304 | AG8A6960-517298bbb4.exr 305 | AG8A7001-5bbf61b6dc.exr 306 | AG8A7043-510419a6ba.exr 307 | AG8A7044-8750436beb.exr 308 | AG8A7086-fb118b3e96.exr 309 | AG8A7128-802d24d0f1.exr 310 | AG8A7146-97ce1c2a87.exr 311 | AG8A7153-51f99279a5.exr 312 | AG8A7170-ae8392dbca.exr 313 | AG8A7321-3341d4ba77.exr 314 | AG8A7421-4fe7d8ce29.exr 315 | AG8A7440-3bbb2bfc63.exr 316 | AG8A7489-92f70a6360.exr 317 | AG8A7506-4b5788f5b5.exr 318 | AG8A7692-79e4b5baea.exr 319 | AG8A7758-a87e1737b1.exr 320 | AG8A7883-f48b891ed0.exr 321 | AG8A7884-aa41c65ad7.exr 322 | AG8A8143-444efdf806.exr 323 | AG8A8155-ff38e4311f.exr 324 | AG8A8196-6595b95de4.exr 325 | AG8A8247-391b2d3e6d.exr 326 | AG8A8262-ed5861c875.exr 327 | AG8A8280-1d6dd657cf.exr 328 | AG8A8311-3867aa4bbd.exr 329 | AG8A8322-553640ea4d.exr 330 | AG8A8416-e775940e93.exr 331 | AG8A8541-01c58c4cba.exr 332 | AG8A8561-7217d07680.exr 333 | AG8A8605-ca0b7fd8fe.exr 334 | AG8A8626-28011c026e.exr 335 | AG8A8693-52dc54169c.exr 336 | AG8A8793-2d47dbd7bb.exr 337 | 
AG8A8819-692bc91e9b.exr 338 | AG8A8976-13ab7ce192.exr 339 | AG8A9018-b5ec5d0356.exr 340 | AG8A9102-3ea9661619.exr 341 | AG8A9109-114c1f59d2.exr 342 | AG8A9142-74e4afd466.exr 343 | AG8A9184-3230b2c3f6.exr 344 | AG8A9226-7e9ad4b6a6.exr 345 | AG8A9312-26a5c8def4.exr 346 | AG8A9352-65dfeac3c5.exr 347 | AG8A9354-53b5596f43.exr 348 | AG8A9394-2f7b635088.exr 349 | AG8A9487-46c13d2ea9.exr 350 | AG8A9520-aa7d084f12.exr 351 | AG8A9604-1c325bcd15.exr 352 | AG8A9666-cc44ce23bc.exr 353 | AG8A9704-fc0d91c91f.exr 354 | AG8A9772-9fb9aa4491.exr 355 | AG8A9814-8f5d646009.exr 356 | AG8A9864-a24b11d21d.exr 357 | AG8A9956-30e880bb24.exr 358 | abandoned_games_room_02.exr 359 | abandoned_waterworks.exr 360 | aristea_wreck.exr 361 | autumn_forest_02.exr 362 | autumn_ground.exr 363 | beach_parking.exr 364 | blue_grotto.exr 365 | brick_factory_02.exr 366 | canary_wharf.exr 367 | cayley_lookout.exr 368 | chapmans_drive.exr 369 | childrens_hospital.exr 370 | cinema_hall.exr 371 | cloudy_cliffside_road.exr 372 | colorful_studio.exr 373 | combination_room.exr 374 | courtyard.exr 375 | decor_shop.exr 376 | dresden_moat.exr 377 | dresden_station_night.exr 378 | driving_school.exr 379 | emmarentia.exr 380 | fireplace.exr 381 | fish_eagle_hill.exr 382 | floral_tent.exr 383 | forest_cave.exr 384 | graffiti_shelter.exr 385 | greenwich_park_02.exr 386 | kiara_2_sunrise.exr 387 | kiara_9_dusk.exr 388 | kloppenheim_05.exr 389 | konzerthaus.exr 390 | lakeside.exr 391 | limehouse.exr 392 | lythwood_field.exr 393 | mall_parking_lot.exr 394 | missile_launch_facility_03.exr 395 | misty_dawn.exr 396 | monks_forest.exr 397 | mosaic_tunnel.exr 398 | mpumalanga_veld.exr 399 | museum_of_history.exr 400 | museumplein.exr 401 | mutianyu.exr 402 | old_hall.exr 403 | old_tree_in_city_park.exr 404 | ostrich_road.exr 405 | park_bench.exr 406 | phone_shop.exr 407 | potsdamer_platz.exr 408 | reinforced_concrete_02.exr 409 | roof_garden.exr 410 | royal_esplanade.exr 411 | rural_graffiti_tower.exr 412 | secluded_beach.exr 413 | shanghai_bund.exr 414 | simons_town_harbour.exr 415 | small_cathedral.exr 416 | small_hangar_02.exr 417 | small_rural_road.exr 418 | snowy_forest_path_01.exr 419 | snowy_park_01.exr 420 | stream.exr 421 | studio_small_05.exr 422 | studio_small_07.exr 423 | syferfontein_6d_clear.exr 424 | theater_02.exr 425 | vatican_road.exr 426 | venetian_crossroads.exr 427 | viale_giuseppe_garibaldi.exr 428 | vignaioli_night.exr 429 | whipple_creek_regional_park_01.exr 430 | xanderklinge.exr 431 | xiequ_yuan.exr 432 | zavelstein.exr -------------------------------------------------------------------------------- /data/datalists/DeepRelighting_shape5000/test_idx_mvs.txt: -------------------------------------------------------------------------------- 1 | 00000 2 | 00002 3 | 00014 4 | 00017 5 | 00022 6 | 00024 7 | 00030 8 | 00032 9 | 00035 10 | 00042 11 | 00045 12 | 00048 13 | 00050 14 | 00052 15 | 00061 16 | 00065 17 | 00069 18 | 00079 19 | 00085 20 | 00094 21 | 00095 22 | 00096 23 | 00097 24 | 00106 25 | 00108 26 | 00109 27 | 00116 28 | 00126 29 | 00127 30 | 00128 31 | 00136 32 | 00142 33 | 00145 34 | 00146 35 | 00154 36 | 00157 37 | 00158 38 | 00160 39 | 00162 40 | 00178 41 | 00179 42 | 00183 43 | 00194 44 | 00195 45 | 00197 46 | 00201 47 | 00213 48 | 00215 49 | 00218 50 | 00219 51 | 00221 52 | 00225 53 | 00230 54 | 00231 55 | 00232 56 | 00234 57 | 00254 58 | 00260 59 | 00270 60 | 00273 61 | 00285 62 | 00288 63 | 00289 64 | 00293 65 | 00304 66 | 00306 67 | 00309 68 | 00315 69 | 00316 70 | 00318 71 | 00320 72 | 00322 73 | 00330 74 | 00331 
75 | 00336 76 | 00353 77 | 00354 78 | 00366 79 | 00388 80 | 00396 81 | 00399 82 | 00402 83 | 00408 84 | 00415 85 | 00416 86 | 00424 87 | 00425 88 | 00428 89 | 00433 90 | 00436 91 | 00442 92 | 00449 93 | 00451 94 | 00453 95 | 00461 96 | 00463 97 | 00464 98 | 00465 99 | 00475 100 | 00485 101 | 00487 102 | 00489 103 | 00490 104 | 00491 105 | 00500 106 | 00504 107 | 00505 108 | 00510 109 | 00511 110 | 00519 111 | 00531 112 | 00535 113 | 00537 114 | 00539 115 | 00541 116 | 00542 117 | 00547 118 | 00550 119 | 00556 120 | 00561 121 | 00564 122 | 00565 123 | 00571 124 | 00573 125 | 00583 126 | 00585 127 | 00591 128 | 00595 129 | 00607 130 | 00609 131 | 00611 132 | 00619 133 | 00624 134 | 00628 135 | 00629 136 | 00633 137 | 00638 138 | 00640 139 | 00641 140 | 00642 141 | 00643 142 | 00644 143 | 00647 144 | 00652 145 | 00653 146 | 00655 147 | 00673 148 | 00678 149 | 00721 150 | 00725 151 | 00728 152 | 00729 153 | 00734 154 | 00738 155 | 00745 156 | 00761 157 | 00766 158 | 00768 159 | 00769 160 | 00772 161 | 00776 162 | 00778 163 | 00780 164 | 00789 165 | 00801 166 | 00802 167 | 00825 168 | 00843 169 | 00846 170 | 00849 171 | 00852 172 | 00854 173 | 00858 174 | 00861 175 | 00871 176 | 00873 177 | 00875 178 | 00876 179 | 00888 180 | 00889 181 | 00892 182 | 00894 183 | 00900 184 | 00917 185 | 00920 186 | 00921 187 | 00930 188 | 00939 189 | 00941 190 | 00945 191 | 00946 192 | 00971 193 | 00979 194 | 00985 195 | 00986 196 | 00990 197 | 01001 198 | 01006 199 | 01009 200 | 01013 201 | 01019 202 | 01021 203 | 01028 204 | 01038 205 | 01043 206 | 01044 207 | 01054 208 | 01056 209 | 01061 210 | 01067 211 | 01068 212 | 01078 213 | 01085 214 | 01088 215 | 01091 216 | 01096 217 | 01105 218 | 01108 219 | 01123 220 | 01129 221 | 01130 222 | 01133 223 | 01135 224 | 01139 225 | 01142 226 | 01144 227 | 01153 228 | 01154 229 | 01159 230 | 01161 231 | 01163 232 | 01165 233 | 01166 234 | 01169 235 | 01171 236 | 01172 237 | 01175 238 | 01177 239 | 01195 240 | 01199 241 | 01210 242 | 01211 243 | 01214 244 | 01215 245 | 01223 246 | 01224 247 | 01227 248 | 01229 249 | 01230 250 | 01234 251 | 01236 252 | 01251 253 | 01254 254 | 01258 255 | 01264 256 | 01293 257 | 01295 258 | 01296 259 | 01297 260 | 01301 261 | 01309 262 | 01310 263 | 01312 264 | 01313 265 | 01316 266 | 01320 267 | 01322 268 | 01323 269 | 01325 270 | 01329 271 | 01330 272 | 01337 273 | 01340 274 | 01344 275 | 01347 276 | 01355 277 | 01357 278 | 01359 279 | 01360 280 | 01362 281 | 01368 282 | 01385 283 | 01387 284 | 01396 285 | 01401 286 | 01406 287 | 01415 288 | 01430 289 | 01436 290 | 01440 291 | 01441 292 | 01444 293 | 01445 294 | 01452 295 | 01457 296 | 01459 297 | 01462 298 | 01467 299 | 01471 300 | 01476 301 | 01478 302 | 01479 303 | 01492 304 | 01498 305 | 01499 306 | 01502 307 | 01506 308 | 01513 309 | 01516 310 | 01518 311 | 01528 312 | 01533 313 | 01535 314 | 01540 315 | 01545 316 | 01547 317 | 01579 318 | 01595 319 | 01598 320 | 01606 321 | 01607 322 | 01612 323 | 01613 324 | 01615 325 | 01616 326 | 01620 327 | 01623 328 | 01641 329 | 01643 330 | 01646 331 | 01647 332 | 01653 333 | 01656 334 | 01657 335 | 01663 336 | 01666 337 | 01670 338 | 01672 339 | 01675 340 | 01685 341 | 01689 342 | 01692 343 | 01694 344 | 01701 345 | 01703 346 | 01708 347 | 01710 348 | 01723 349 | 01729 350 | 01731 351 | 01734 352 | 01740 353 | 01741 354 | 01742 355 | 01743 356 | 01746 357 | 01761 358 | 01770 359 | 01774 360 | 01778 361 | 01782 362 | 01800 363 | 01808 364 | 01809 365 | 01810 366 | 01819 367 | 01820 368 | 01825 369 | 01830 370 | 01834 371 | 01839 372 | 01840 373 
| 01844 374 | 01847 375 | 01850 376 | 01856 377 | 01858 378 | 01863 379 | 01864 380 | 01867 381 | 01871 382 | 01880 383 | 01885 384 | 01886 385 | 01895 386 | 01902 387 | 01904 388 | 01905 389 | 01912 390 | 01913 391 | 01916 392 | 01917 393 | 01947 394 | 01964 395 | 01973 396 | 01976 397 | 01978 398 | 01980 399 | 01982 400 | 01989 401 | 01993 402 | 01999 403 | 02009 404 | 02011 405 | 02013 406 | 02018 407 | 02026 408 | 02035 409 | 02038 410 | 02046 411 | 02048 412 | 02053 413 | 02058 414 | 02068 415 | 02072 416 | 02078 417 | 02080 418 | 02090 419 | 02097 420 | 02100 421 | 02104 422 | 02106 423 | 02126 424 | 02128 425 | 02136 426 | 02137 427 | 02142 428 | 02152 429 | 02159 430 | 02160 431 | 02161 432 | 02171 433 | 02188 434 | 02192 435 | 02200 436 | 02202 437 | 02203 438 | 02211 439 | 02221 440 | 02233 441 | 02236 442 | 02242 443 | 02251 444 | 02254 445 | 02255 446 | 02273 447 | 02277 448 | 02282 449 | 02285 450 | 02288 451 | 02293 452 | 02299 453 | 02300 454 | 02301 455 | 02303 456 | 02304 457 | 02308 458 | 02312 459 | 02328 460 | 02329 461 | 02342 462 | 02347 463 | 02354 464 | 02368 465 | 02370 466 | 02373 467 | 02386 468 | 02387 469 | 02390 470 | 02392 471 | 02394 472 | 02396 473 | 02405 474 | 02407 475 | 02409 476 | 02419 477 | 02425 478 | 02427 479 | 02430 480 | 02436 481 | 02440 482 | 02444 483 | 02454 484 | 02456 485 | 02457 486 | 02458 487 | 02459 488 | 02460 489 | 02464 490 | 02465 491 | 02469 492 | 02470 493 | 02473 494 | 02476 495 | 02477 496 | 02481 497 | 02482 498 | 02484 499 | 02487 500 | 02488 501 | 02492 502 | 02493 503 | 02500 504 | 02502 505 | 02505 506 | 02509 507 | 02514 508 | 02523 509 | 02526 510 | 02529 511 | 02530 512 | 02543 513 | 02545 514 | 02547 515 | 02549 516 | 02551 517 | 02561 518 | 02572 519 | 02574 520 | 02576 521 | 02577 522 | 02595 523 | 02598 524 | 02599 525 | 02611 526 | 02615 527 | 02626 528 | 02636 529 | 02645 530 | 02651 531 | 02652 532 | 02657 533 | 02658 534 | 02659 535 | 02663 536 | 02667 537 | 02679 538 | 02717 539 | 02723 540 | 02729 541 | 02733 542 | 02735 543 | 02739 544 | 02747 545 | 02759 546 | 02760 547 | 02770 548 | 02775 549 | 02779 550 | 02785 551 | 02792 552 | 02794 553 | 02799 554 | 02801 555 | 02806 556 | 02810 557 | 02812 558 | 02820 559 | 02821 560 | 02830 561 | 02831 562 | 02832 563 | 02834 564 | 02840 565 | 02841 566 | 02844 567 | 02855 568 | 02856 569 | 02857 570 | 02860 571 | 02870 572 | 02877 573 | 02898 574 | 02902 575 | 02905 576 | 02910 577 | 02911 578 | 02919 579 | 02923 580 | 02925 581 | 02926 582 | 02928 583 | 02929 584 | 02932 585 | 02933 586 | 02937 587 | 02940 588 | 02944 589 | 02960 590 | 02966 591 | 02971 592 | 02984 593 | 02985 594 | 02987 595 | 02992 596 | 02993 597 | 03005 598 | 03006 599 | 03018 600 | 03019 601 | 03027 602 | 03032 603 | 03039 604 | 03041 605 | 03045 606 | 03055 607 | 03060 608 | 03062 609 | 03065 610 | 03068 611 | 03074 612 | 03075 613 | 03076 614 | 03078 615 | 03080 616 | 03081 617 | 03082 618 | 03087 619 | 03096 620 | 03105 621 | 03107 622 | 03111 623 | 03116 624 | 03128 625 | 03133 626 | 03134 627 | 03137 628 | 03143 629 | 03144 630 | 03146 631 | 03165 632 | 03172 633 | 03175 634 | 03179 635 | 03184 636 | 03195 637 | 03198 638 | 03200 639 | 03204 640 | 03206 641 | 03208 642 | 03209 643 | 03214 644 | 03215 645 | 03223 646 | 03224 647 | 03225 648 | 03229 649 | 03235 650 | 03241 651 | 03251 652 | 03258 653 | 03262 654 | 03263 655 | 03266 656 | 03268 657 | 03281 658 | 03282 659 | 03283 660 | 03287 661 | 03293 662 | 03296 663 | 03298 664 | 03309 665 | 03311 666 | 03314 667 | 03315 668 | 03325 669 | 
03330 670 | 03334 671 | 03335 672 | 03344 673 | 03349 674 | 03351 675 | 03352 676 | 03367 677 | 03371 678 | 03381 679 | 03382 680 | 03383 681 | 03387 682 | 03388 683 | 03391 684 | 03397 685 | 03398 686 | 03399 687 | 03400 688 | 03402 689 | 03413 690 | 03416 691 | 03422 692 | 03423 693 | 03427 694 | 03444 695 | 03448 696 | 03449 697 | 03453 698 | 03463 699 | 03468 700 | 03469 701 | 03470 702 | 03475 703 | 03477 704 | 03489 705 | 03498 706 | 03501 707 | 03503 708 | 03506 709 | 03509 710 | 03511 711 | 03516 712 | 03519 713 | 03521 714 | 03526 715 | 03527 716 | 03535 717 | 03545 718 | 03567 719 | 03569 720 | 03570 721 | 03572 722 | 03573 723 | 03576 724 | 03579 725 | 03581 726 | 03588 727 | 03590 728 | 03592 729 | 03594 730 | 03608 731 | 03610 732 | 03618 733 | 03619 734 | 03621 735 | 03622 736 | 03623 737 | 03627 738 | 03657 739 | 03658 740 | 03661 741 | 03663 742 | 03664 743 | 03666 744 | 03672 745 | 03674 746 | 03680 747 | 03684 748 | 03685 749 | 03686 750 | 03690 751 | 03693 752 | 03694 753 | 03704 754 | 03707 755 | 03710 756 | 03713 757 | 03716 758 | 03719 759 | 03731 760 | 03734 761 | 03738 762 | 03741 763 | 03746 764 | 03758 765 | 03762 766 | 03764 767 | 03765 768 | 03770 769 | 03775 770 | 03779 771 | 03780 772 | 03794 773 | 03798 774 | 03813 775 | 03821 776 | 03823 777 | 03824 778 | 03825 779 | 03832 780 | 03834 781 | 03837 782 | 03838 783 | 03847 784 | 03849 785 | 03853 786 | 03861 787 | 03871 788 | 03872 789 | 03873 790 | 03876 791 | 03882 792 | 03892 793 | 03903 794 | 03911 795 | 03925 796 | 03927 797 | 03930 798 | 03938 799 | 03942 800 | 03955 801 | 03963 802 | 03967 803 | 03980 804 | 03983 805 | 03984 806 | 03985 807 | 03988 808 | 04000 809 | 04020 810 | 04029 811 | 04047 812 | 04050 813 | 04052 814 | 04053 815 | 04059 816 | 04060 817 | 04061 818 | 04065 819 | 04078 820 | 04089 821 | 04097 822 | 04105 823 | 04112 824 | 04115 825 | 04116 826 | 04119 827 | 04123 828 | 04126 829 | 04142 830 | 04143 831 | 04145 832 | 04152 833 | 04153 834 | 04155 835 | 04157 836 | 04162 837 | 04164 838 | 04167 839 | 04170 840 | 04176 841 | 04177 842 | 04211 843 | 04215 844 | 04222 845 | 04223 846 | 04224 847 | 04225 848 | 04226 849 | 04227 850 | 04234 851 | 04235 852 | 04240 853 | 04249 854 | 04250 855 | 04261 856 | 04265 857 | 04267 858 | 04269 859 | 04270 860 | 04271 861 | 04278 862 | 04282 863 | 04285 864 | 04286 865 | 04304 866 | 04310 867 | 04318 868 | 04328 869 | 04329 870 | 04337 871 | 04338 872 | 04340 873 | 04349 874 | 04353 875 | 04360 876 | 04361 877 | 04365 878 | 04366 879 | 04373 880 | 04378 881 | 04384 882 | 04385 883 | 04397 884 | 04398 885 | 04400 886 | 04403 887 | 04406 888 | 04409 889 | 04411 890 | 04419 891 | 04421 892 | 04422 893 | 04427 894 | 04438 895 | 04443 896 | 04450 897 | 04460 898 | 04470 899 | 04472 900 | 04491 901 | 04501 902 | 04502 903 | 04503 904 | 04504 905 | 04506 906 | 04521 907 | 04522 908 | 04527 909 | 04529 910 | 04535 911 | 04539 912 | 04544 913 | 04549 914 | 04551 915 | 04555 916 | 04560 917 | 04570 918 | 04574 919 | 04578 920 | 04582 921 | 04591 922 | 04592 923 | 04601 924 | 04602 925 | 04603 926 | 04606 927 | 04616 928 | 04617 929 | 04618 930 | 04619 931 | 04626 932 | 04629 933 | 04630 934 | 04632 935 | 04635 936 | 04636 937 | 04637 938 | 04664 939 | 04685 940 | 04689 941 | 04691 942 | 04693 943 | 04695 944 | 04696 945 | 04702 946 | 04707 947 | 04712 948 | 04714 949 | 04715 950 | 04722 951 | 04723 952 | 04735 953 | 04744 954 | 04754 955 | 04758 956 | 04760 957 | 04761 958 | 04768 959 | 04774 960 | 04776 961 | 04785 962 | 04791 963 | 04792 964 | 04801 965 | 
04804 966 | 04814 967 | 04822 968 | 04824 969 | 04825 970 | 04827 971 | 04833 972 | 04836 973 | 04838 974 | 04840 975 | 04842 976 | 04854 977 | 04855 978 | 04871 979 | 04872 980 | 04876 981 | 04878 982 | 04889 983 | 04895 984 | 04897 985 | 04911 986 | 04920 987 | 04921 988 | 04922 989 | 04925 990 | 04929 991 | 04932 992 | 04933 993 | 04951 994 | 04952 995 | 04953 996 | 04973 997 | 04985 998 | 04986 999 | 04989 1000 | 04994 -------------------------------------------------------------------------------- /data/datalists/DeepRelighting_shape5000/shapes_val.txt: -------------------------------------------------------------------------------- 1 | Shape__1004 2 | Shape__1009 3 | Shape__102 4 | Shape__1025 5 | Shape__1027 6 | Shape__1034 7 | Shape__1037 8 | Shape__1043 9 | Shape__1046 10 | Shape__1047 11 | Shape__1050 12 | Shape__1058 13 | Shape__1060 14 | Shape__1066 15 | Shape__1068 16 | Shape__1073 17 | Shape__1076 18 | Shape__108 19 | Shape__1080 20 | Shape__1081 21 | Shape__109 22 | Shape__1098 23 | Shape__1106 24 | Shape__1118 25 | Shape__1121 26 | Shape__1126 27 | Shape__113 28 | Shape__1131 29 | Shape__1142 30 | Shape__1146 31 | Shape__1153 32 | Shape__1166 33 | Shape__1170 34 | Shape__1177 35 | Shape__1201 36 | Shape__121 37 | Shape__1213 38 | Shape__1218 39 | Shape__1219 40 | Shape__1222 41 | Shape__1223 42 | Shape__1229 43 | Shape__1236 44 | Shape__1244 45 | Shape__1250 46 | Shape__1251 47 | Shape__1253 48 | Shape__1287 49 | Shape__129 50 | Shape__1294 51 | Shape__1307 52 | Shape__1308 53 | Shape__1333 54 | Shape__1336 55 | Shape__1337 56 | Shape__1339 57 | Shape__1344 58 | Shape__1345 59 | Shape__1359 60 | Shape__1370 61 | Shape__1373 62 | Shape__1374 63 | Shape__1378 64 | Shape__1384 65 | Shape__139 66 | Shape__1392 67 | Shape__1394 68 | Shape__1399 69 | Shape__1400 70 | Shape__1406 71 | Shape__1409 72 | Shape__141 73 | Shape__1413 74 | Shape__1420 75 | Shape__1421 76 | Shape__1443 77 | Shape__1444 78 | Shape__1449 79 | Shape__1450 80 | Shape__1460 81 | Shape__1462 82 | Shape__1469 83 | Shape__1470 84 | Shape__1473 85 | Shape__1475 86 | Shape__1477 87 | Shape__1478 88 | Shape__1479 89 | Shape__1484 90 | Shape__15 91 | Shape__1508 92 | Shape__1509 93 | Shape__152 94 | Shape__1523 95 | Shape__1536 96 | Shape__1547 97 | Shape__1551 98 | Shape__1553 99 | Shape__1554 100 | Shape__1558 101 | Shape__1560 102 | Shape__1570 103 | Shape__1573 104 | Shape__1581 105 | Shape__1592 106 | Shape__1594 107 | Shape__16 108 | Shape__1604 109 | Shape__1611 110 | Shape__1613 111 | Shape__1619 112 | Shape__1620 113 | Shape__1621 114 | Shape__1622 115 | Shape__1623 116 | Shape__1631 117 | Shape__1641 118 | Shape__1643 119 | Shape__1648 120 | Shape__1655 121 | Shape__1660 122 | Shape__1661 123 | Shape__1663 124 | Shape__1671 125 | Shape__1673 126 | Shape__1677 127 | Shape__1678 128 | Shape__1682 129 | Shape__1687 130 | Shape__1691 131 | Shape__1703 132 | Shape__1704 133 | Shape__1705 134 | Shape__1710 135 | Shape__1715 136 | Shape__1716 137 | Shape__1739 138 | Shape__174 139 | Shape__1743 140 | Shape__1747 141 | Shape__1748 142 | Shape__1753 143 | Shape__1760 144 | Shape__1764 145 | Shape__1766 146 | Shape__1781 147 | Shape__1783 148 | Shape__1794 149 | Shape__1807 150 | Shape__1813 151 | Shape__1815 152 | Shape__1819 153 | Shape__1831 154 | Shape__1848 155 | Shape__1850 156 | Shape__1856 157 | Shape__1860 158 | Shape__1861 159 | Shape__1862 160 | Shape__1863 161 | Shape__1867 162 | Shape__187 163 | Shape__1884 164 | Shape__1888 165 | Shape__1893 166 | Shape__1896 167 | Shape__190 168 | Shape__1900 169 | 
Shape__1902 170 | Shape__1910 171 | Shape__1921 172 | Shape__1961 173 | Shape__1963 174 | Shape__1965 175 | Shape__1968 176 | Shape__1977 177 | Shape__1982 178 | Shape__1986 179 | Shape__2003 180 | Shape__2007 181 | Shape__2012 182 | Shape__2026 183 | Shape__2029 184 | Shape__203 185 | Shape__2033 186 | Shape__2053 187 | Shape__206 188 | Shape__2064 189 | Shape__2066 190 | Shape__2069 191 | Shape__2074 192 | Shape__2075 193 | Shape__2089 194 | Shape__2095 195 | Shape__2114 196 | Shape__2116 197 | Shape__2131 198 | Shape__2133 199 | Shape__2137 200 | Shape__2138 201 | Shape__2154 202 | Shape__2156 203 | Shape__2171 204 | Shape__2173 205 | Shape__2175 206 | Shape__2183 207 | Shape__2191 208 | Shape__2204 209 | Shape__2212 210 | Shape__2218 211 | Shape__2222 212 | Shape__2225 213 | Shape__2239 214 | Shape__2241 215 | Shape__2242 216 | Shape__2247 217 | Shape__2248 218 | Shape__225 219 | Shape__2253 220 | Shape__2265 221 | Shape__2266 222 | Shape__2270 223 | Shape__2279 224 | Shape__228 225 | Shape__2296 226 | Shape__2300 227 | Shape__2307 228 | Shape__2308 229 | Shape__231 230 | Shape__2317 231 | Shape__232 232 | Shape__2320 233 | Shape__2324 234 | Shape__2331 235 | Shape__2334 236 | Shape__2338 237 | Shape__2351 238 | Shape__2357 239 | Shape__2360 240 | Shape__2367 241 | Shape__2369 242 | Shape__237 243 | Shape__2398 244 | Shape__240 245 | Shape__2401 246 | Shape__2402 247 | Shape__2405 248 | Shape__2424 249 | Shape__2425 250 | Shape__2431 251 | Shape__2432 252 | Shape__2435 253 | Shape__2438 254 | Shape__2448 255 | Shape__2454 256 | Shape__2468 257 | Shape__2471 258 | Shape__2474 259 | Shape__2476 260 | Shape__2481 261 | Shape__2483 262 | Shape__2490 263 | Shape__2493 264 | Shape__2497 265 | Shape__250 266 | Shape__2508 267 | Shape__2510 268 | Shape__2515 269 | Shape__2516 270 | Shape__2534 271 | Shape__2538 272 | Shape__2539 273 | Shape__2540 274 | Shape__2544 275 | Shape__2557 276 | Shape__2562 277 | Shape__2570 278 | Shape__2578 279 | Shape__2593 280 | Shape__2594 281 | Shape__2596 282 | Shape__2604 283 | Shape__2607 284 | Shape__262 285 | Shape__2622 286 | Shape__2629 287 | Shape__2631 288 | Shape__2638 289 | Shape__2643 290 | Shape__2648 291 | Shape__266 292 | Shape__2661 293 | Shape__2662 294 | Shape__2665 295 | Shape__269 296 | Shape__2692 297 | Shape__2693 298 | Shape__2694 299 | Shape__2705 300 | Shape__2707 301 | Shape__2714 302 | Shape__2724 303 | Shape__2726 304 | Shape__2729 305 | Shape__2732 306 | Shape__274 307 | Shape__2745 308 | Shape__277 309 | Shape__2773 310 | Shape__2793 311 | Shape__2795 312 | Shape__2796 313 | Shape__2815 314 | Shape__2824 315 | Shape__2828 316 | Shape__2829 317 | Shape__2831 318 | Shape__2836 319 | Shape__2838 320 | Shape__2839 321 | Shape__2844 322 | Shape__2849 323 | Shape__285 324 | Shape__2859 325 | Shape__2860 326 | Shape__2862 327 | Shape__2874 328 | Shape__2882 329 | Shape__2884 330 | Shape__2893 331 | Shape__2900 332 | Shape__2902 333 | Shape__2924 334 | Shape__2925 335 | Shape__293 336 | Shape__2938 337 | Shape__2939 338 | Shape__2940 339 | Shape__2948 340 | Shape__2950 341 | Shape__2953 342 | Shape__2954 343 | Shape__2955 344 | Shape__2962 345 | Shape__2964 346 | Shape__2966 347 | Shape__2970 348 | Shape__2983 349 | Shape__2985 350 | Shape__2987 351 | Shape__2991 352 | Shape__2994 353 | Shape__2995 354 | Shape__2996 355 | Shape__2999 356 | Shape__30 357 | Shape__3001 358 | Shape__3006 359 | Shape__3010 360 | Shape__302 361 | Shape__3029 362 | Shape__3038 363 | Shape__3042 364 | Shape__3049 365 | Shape__3056 366 | Shape__306 367 | 
Shape__3060 368 | Shape__3072 369 | Shape__3083 370 | Shape__3097 371 | Shape__3098 372 | Shape__3099 373 | Shape__3103 374 | Shape__3107 375 | Shape__3112 376 | Shape__3114 377 | Shape__3126 378 | Shape__3135 379 | Shape__3139 380 | Shape__3147 381 | Shape__3167 382 | Shape__3171 383 | Shape__3174 384 | Shape__3181 385 | Shape__319 386 | Shape__3193 387 | Shape__3199 388 | Shape__3205 389 | Shape__3222 390 | Shape__3225 391 | Shape__3228 392 | Shape__324 393 | Shape__3243 394 | Shape__3245 395 | Shape__3259 396 | Shape__3267 397 | Shape__3276 398 | Shape__3277 399 | Shape__3282 400 | Shape__3288 401 | Shape__3293 402 | Shape__3295 403 | Shape__3298 404 | Shape__3300 405 | Shape__3301 406 | Shape__3307 407 | Shape__3310 408 | Shape__3322 409 | Shape__3324 410 | Shape__3326 411 | Shape__333 412 | Shape__3331 413 | Shape__3334 414 | Shape__3346 415 | Shape__3347 416 | Shape__335 417 | Shape__3353 418 | Shape__336 419 | Shape__3366 420 | Shape__3369 421 | Shape__3375 422 | Shape__3376 423 | Shape__3387 424 | Shape__340 425 | Shape__3403 426 | Shape__3404 427 | Shape__3407 428 | Shape__3408 429 | Shape__341 430 | Shape__3411 431 | Shape__3424 432 | Shape__3425 433 | Shape__3426 434 | Shape__345 435 | Shape__3453 436 | Shape__3456 437 | Shape__3457 438 | Shape__3461 439 | Shape__3464 440 | Shape__3467 441 | Shape__347 442 | Shape__3472 443 | Shape__3473 444 | Shape__3476 445 | Shape__3488 446 | Shape__3490 447 | Shape__3495 448 | Shape__3499 449 | Shape__350 450 | Shape__3509 451 | Shape__3530 452 | Shape__3532 453 | Shape__3551 454 | Shape__3556 455 | Shape__3560 456 | Shape__3574 457 | Shape__3578 458 | Shape__3579 459 | Shape__3582 460 | Shape__3584 461 | Shape__3586 462 | Shape__3590 463 | Shape__3605 464 | Shape__3610 465 | Shape__3619 466 | Shape__3621 467 | Shape__3627 468 | Shape__3628 469 | Shape__363 470 | Shape__3632 471 | Shape__364 472 | Shape__3642 473 | Shape__3643 474 | Shape__3647 475 | Shape__3656 476 | Shape__3657 477 | Shape__3658 478 | Shape__3670 479 | Shape__3676 480 | Shape__3697 481 | Shape__370 482 | Shape__3700 483 | Shape__3711 484 | Shape__3717 485 | Shape__3718 486 | Shape__3725 487 | Shape__3730 488 | Shape__3731 489 | Shape__3736 490 | Shape__3740 491 | Shape__3755 492 | Shape__3758 493 | Shape__3778 494 | Shape__3781 495 | Shape__3782 496 | Shape__3784 497 | Shape__3789 498 | Shape__3790 499 | Shape__3796 500 | Shape__3801 501 | Shape__3809 502 | Shape__3814 503 | Shape__3820 504 | Shape__3822 505 | Shape__3826 506 | Shape__3834 507 | Shape__384 508 | Shape__3841 509 | Shape__3845 510 | Shape__3846 511 | Shape__3852 512 | Shape__3858 513 | Shape__3867 514 | Shape__387 515 | Shape__3894 516 | Shape__3898 517 | Shape__3911 518 | Shape__3912 519 | Shape__3914 520 | Shape__3925 521 | Shape__3932 522 | Shape__3942 523 | Shape__3944 524 | Shape__3947 525 | Shape__395 526 | Shape__3955 527 | Shape__3980 528 | Shape__3983 529 | Shape__3987 530 | Shape__3992 531 | Shape__3995 532 | Shape__400 533 | Shape__4001 534 | Shape__4009 535 | Shape__4018 536 | Shape__4022 537 | Shape__4027 538 | Shape__404 539 | Shape__4051 540 | Shape__4064 541 | Shape__4068 542 | Shape__4085 543 | Shape__409 544 | Shape__4091 545 | Shape__4092 546 | Shape__4107 547 | Shape__4110 548 | Shape__4111 549 | Shape__4112 550 | Shape__4117 551 | Shape__4121 552 | Shape__4124 553 | Shape__4135 554 | Shape__4139 555 | Shape__4157 556 | Shape__4159 557 | Shape__416 558 | Shape__4169 559 | Shape__4170 560 | Shape__4173 561 | Shape__4177 562 | Shape__4193 563 | Shape__4198 564 | Shape__42 565 | Shape__420 
566 | Shape__422 567 | Shape__4221 568 | Shape__4222 569 | Shape__423 570 | Shape__4238 571 | Shape__4252 572 | Shape__4256 573 | Shape__426 574 | Shape__4261 575 | Shape__4264 576 | Shape__4266 577 | Shape__4267 578 | Shape__4272 579 | Shape__4273 580 | Shape__4274 581 | Shape__4276 582 | Shape__4277 583 | Shape__428 584 | Shape__4280 585 | Shape__4286 586 | Shape__4300 587 | Shape__4301 588 | Shape__4312 589 | Shape__4324 590 | Shape__4327 591 | Shape__4328 592 | Shape__4332 593 | Shape__4340 594 | Shape__4341 595 | Shape__4347 596 | Shape__4369 597 | Shape__4372 598 | Shape__4376 599 | Shape__4379 600 | Shape__438 601 | Shape__4382 602 | Shape__4392 603 | Shape__4395 604 | Shape__4401 605 | Shape__4407 606 | Shape__4408 607 | Shape__4411 608 | Shape__4426 609 | Shape__4428 610 | Shape__4435 611 | Shape__4441 612 | Shape__4445 613 | Shape__4446 614 | Shape__445 615 | Shape__4458 616 | Shape__4461 617 | Shape__4469 618 | Shape__4478 619 | Shape__4479 620 | Shape__4490 621 | Shape__4491 622 | Shape__4497 623 | Shape__4507 624 | Shape__4514 625 | Shape__4520 626 | Shape__4537 627 | Shape__4541 628 | Shape__4544 629 | Shape__4551 630 | Shape__4555 631 | Shape__4557 632 | Shape__4561 633 | Shape__4566 634 | Shape__4571 635 | Shape__4576 636 | Shape__459 637 | Shape__4590 638 | Shape__4592 639 | Shape__4594 640 | Shape__4595 641 | Shape__4598 642 | Shape__4602 643 | Shape__4608 644 | Shape__4610 645 | Shape__4611 646 | Shape__4612 647 | Shape__4617 648 | Shape__4618 649 | Shape__4620 650 | Shape__4622 651 | Shape__4630 652 | Shape__4633 653 | Shape__4636 654 | Shape__4648 655 | Shape__4649 656 | Shape__4654 657 | Shape__4656 658 | Shape__4672 659 | Shape__4676 660 | Shape__4678 661 | Shape__4680 662 | Shape__4682 663 | Shape__4685 664 | Shape__4687 665 | Shape__4694 666 | Shape__4696 667 | Shape__4704 668 | Shape__4713 669 | Shape__4714 670 | Shape__4715 671 | Shape__4716 672 | Shape__4723 673 | Shape__4724 674 | Shape__473 675 | Shape__4731 676 | Shape__4733 677 | Shape__4742 678 | Shape__4745 679 | Shape__4748 680 | Shape__4762 681 | Shape__4763 682 | Shape__4764 683 | Shape__4770 684 | Shape__4775 685 | Shape__478 686 | Shape__4784 687 | Shape__4788 688 | Shape__4795 689 | Shape__4807 690 | Shape__4824 691 | Shape__4825 692 | Shape__4826 693 | Shape__4832 694 | Shape__4857 695 | Shape__4886 696 | Shape__4888 697 | Shape__4890 698 | Shape__4898 699 | Shape__4910 700 | Shape__4913 701 | Shape__4915 702 | Shape__492 703 | Shape__4925 704 | Shape__4941 705 | Shape__4947 706 | Shape__4954 707 | Shape__4962 708 | Shape__4963 709 | Shape__498 710 | Shape__4982 711 | Shape__4989 712 | Shape__499 713 | Shape__4992 714 | Shape__50 715 | Shape__502 716 | Shape__508 717 | Shape__511 718 | Shape__514 719 | Shape__515 720 | Shape__516 721 | Shape__518 722 | Shape__527 723 | Shape__533 724 | Shape__537 725 | Shape__538 726 | Shape__54 727 | Shape__553 728 | Shape__557 729 | Shape__558 730 | Shape__56 731 | Shape__562 732 | Shape__572 733 | Shape__577 734 | Shape__579 735 | Shape__583 736 | Shape__586 737 | Shape__590 738 | Shape__604 739 | Shape__617 740 | Shape__623 741 | Shape__624 742 | Shape__628 743 | Shape__632 744 | Shape__634 745 | Shape__638 746 | Shape__642 747 | Shape__65 748 | Shape__67 749 | Shape__673 750 | Shape__678 751 | Shape__679 752 | Shape__683 753 | Shape__686 754 | Shape__688 755 | Shape__690 756 | Shape__691 757 | Shape__70 758 | Shape__705 759 | Shape__706 760 | Shape__707 761 | Shape__711 762 | Shape__712 763 | Shape__726 764 | Shape__734 765 | Shape__739 766 | Shape__746 767 | 
Shape__753 768 | Shape__756 769 | Shape__758 770 | Shape__759 771 | Shape__767 772 | Shape__768 773 | Shape__77 774 | Shape__773 775 | Shape__774 776 | Shape__782 777 | Shape__791 778 | Shape__803 779 | Shape__813 780 | Shape__826 781 | Shape__827 782 | Shape__832 783 | Shape__842 784 | Shape__845 785 | Shape__846 786 | Shape__86 787 | Shape__864 788 | Shape__876 789 | Shape__882 790 | Shape__921 791 | Shape__922 792 | Shape__952 793 | Shape__954 794 | Shape__963 795 | Shape__965 796 | Shape__968 797 | Shape__970 798 | Shape__972 799 | Shape__989 800 | Shape__997 -------------------------------------------------------------------------------- /dataset/parametric_img2refmap.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from pathlib import Path 4 | from time import sleep 5 | from typing import Dict, List, Optional, Union 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | from models.obsnet import ObsNetDiffusion 12 | from utils.file_io import load_exr 13 | from utils.transform import thetaphi2xyz 14 | 15 | from .basedataset import BaseDataset 16 | 17 | 18 | class ParametricImg2RefmapDataset(BaseDataset): 19 | def __init__( 20 | self, 21 | size: int, 22 | split: str, 23 | data_root: str, 24 | shape_root: str, 25 | zdim: int, 26 | transform_func: str = "log", 27 | clamp_before_exp: float = 0, 28 | return_envmap: bool = False, 29 | return_obj: bool = False, 30 | refmap_key: str = "LrK", 31 | epoch_bias: int = 0, 32 | epoch_cycle: int = 1000, 33 | return_cache=True, 34 | refmap_cache_root: Optional[str] = None, 35 | objimg_cache_root: str = None, 36 | preload_envmap: bool = False, 37 | ): 38 | super().__init__(size, transform_func=transform_func, clamp_before_exp=clamp_before_exp) 39 | 40 | assert split in ["train", "val", "test"] 41 | self.split = split 42 | self.root = Path(data_root) 43 | self.data_name = self.root.name 44 | assert self.data_name in ["LavalIndoor+PolyHaven_2k"] 45 | self.t = "train" if split in ["train", "val"] else "test" 46 | with open(f"data/datalists/{self.data_name}/envs_{split}.txt", "r") as f: 47 | self.envs = f.read().splitlines() 48 | 49 | if shape_root is not None: 50 | self.shape_root = Path(shape_root) 51 | self.shape_set_name = self.shape_root.name 52 | with open(f"data/datalists/{self.shape_set_name}/shapes_{split}.txt", "r") as f: 53 | self.shape_tags = f.read().rstrip().splitlines() 54 | self.shape_len = len(self.shape_tags) 55 | 56 | self.zdim = zdim 57 | 58 | self.return_envmap = return_envmap 59 | self.return_obj = return_obj 60 | 61 | self.generator = torch.Generator() 62 | self.current_epoch = 0 63 | 64 | self.model: ObsNetDiffusion = None 65 | self.refmap_key = refmap_key 66 | 67 | self.epoch_bias = epoch_bias 68 | self.epoch_cycle = epoch_cycle 69 | 70 | self.return_cache = return_cache 71 | self.refmap_cache_root = Path(refmap_cache_root) if refmap_cache_root is not None else None 72 | self.objimg_cache_root = Path(objimg_cache_root) if objimg_cache_root is not None else None 73 | if self.return_cache: 74 | assert self.refmap_cache_root is not None, "specify refmap_cache_root to return cache" 75 | 76 | self.preload_envmap = preload_envmap 77 | if self.return_envmap and preload_envmap: 78 | self.envmaps = {} 79 | for env in self.envs: 80 | env_name = env[:-4] 81 | self.envmaps[env_name] = load_exr(self.root / f"{env_name}.exr", as_torch=True) 82 | 83 | def __len__(self): 84 | return len(self.envs) 85 | 86 | def set_current_epoch(self, epoch): 87 | 
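"""Record the current training epoch; set_generator() and the per-epoch shape selection in __getitem__ fold it into their seeds so that samples vary across epochs."""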
self.current_epoch = epoch 88 | 89 | def set_generator(self, idx: int, epoch: Optional[int] = None): 90 | if self.split == "train": 91 | epoch = self.current_epoch if epoch is None else epoch 92 | epoch = epoch + self.epoch_bias 93 | if epoch >= self.epoch_cycle: 94 | epoch = epoch % self.epoch_cycle 95 | self.generator.manual_seed(epoch * len(self) + idx) 96 | elif self.split == "val": 97 | self.generator.manual_seed(idx) 98 | self.generator.manual_seed(torch.empty((), dtype=torch.int64).random_(generator=self.generator).item()) 99 | elif self.split == "test": 100 | self.generator.manual_seed(idx) 101 | torch.empty((), dtype=torch.int64).random_(generator=self.generator)  # extra draw so that test-split seeds differ from val-split seeds for the same idx 102 | self.generator.manual_seed(torch.empty((), dtype=torch.int64).random_(generator=self.generator).item()) 103 | else: 104 | raise NotImplementedError() 105 | 106 | @torch.no_grad() 107 | def __getitem__(self, idx: int): 108 | env_name = self.envs[idx][:-4] 109 | self.set_generator(idx) 110 | zK = torch.rand((self.zdim,), generator=self.generator) 111 | data = {} 112 | data["zK"] = zK 113 | data["envmap_name"] = env_name 114 | 115 | normalized_k = torch.rand((), generator=self.generator) 116 | data["normalized_k"] = normalized_k 117 | 118 | phi = (torch.rand((), generator=self.generator) * 64).int() / 64 * torch.pi * 2 - torch.pi 119 | theta = (torch.rand((), generator=self.generator) * 0 + 0.5) * torch.pi  # theta is fixed at pi/2; the draw is kept only to preserve the RNG stream 120 | view_from = thetaphi2xyz(torch.stack([theta, phi]), normal=[0, 1, 0], tangent=[0, 0, 1]) 121 | data["view_from"] = view_from 122 | 123 | mask_idx = torch.rand((), generator=self.generator).item() # to keep consistency of sampling 124 | 125 | if self.split == "train": 126 | bias = len(self) * self.current_epoch 127 | else: 128 | bias = 0 129 | shape_idx = (idx + bias) % self.shape_len 130 | obj_name = self.shape_tags[shape_idx] 131 | data["obj_name"] = obj_name 132 | 133 | seem_need_envmap = True 134 | seem_need_obj = True 135 | if self.model is not None: 136 | brdf_param_names = self.model.refmap_renderer.brdf_param_names 137 | size = self.model.refmap_renderer.refmap_res 138 | denoise_suffix = f"_{self.model.refmap_renderer.denoise}denoise" if self.model.refmap_renderer.denoise else ""
spp = self.model.refmap_renderer.spp  # samples per pixel for the refmap cache path (assumed to mirror img_renderer.spp below; undefined otherwise)
139 | cache_dir = self.refmap_cache_root / "-".join(brdf_param_names) / f"{size}x{size}_spp{spp}{denoise_suffix}/" 140 | torch.set_printoptions(precision=4, sci_mode=True)  # fixed print format so that str(z) below yields a stable cache key 141 | 142 | def get_cache(z) -> tuple: 143 | pieces_key = "b" + str(z)[7:-1] + "v" + str(view_from)[7:-1]  # str(z) looks like "tensor([...])"; strip the "tensor(" prefix and the ")" suffix 144 | pieces_key = pieces_key.replace("\n", "").replace(" ", "") 145 | filename = pieces_key + ".pt" 146 | cache_file_path = cache_dir / env_name / filename 147 | if not cache_file_path.exists(): 148 | return False, torch.full((3, size, size), torch.nan) 149 | for _ in range(3):  # retry transient read failures a few times 150 | try: 151 | cache: dict = torch.load(cache_file_path, map_location="cpu") 152 | except Exception as e: 153 | print(cache_file_path) 154 | print(e) 155 | sleep(0.01) 156 | else: 157 | break 158 | else: 159 | return False, torch.full((3, size, size), torch.nan) 160 | if ( 161 | cache.get("envmap_name") == env_name 162 | and cache.get("brdf_param_names") == brdf_param_names 163 | and torch.allclose(cache.get("zk"), z) 164 | and torch.allclose(cache.get("view_from"), view_from) 165 | and cache.get("refmap_res") == size 166 | ): 167 | return True, cache.get("rendering_results") 168 | 169 | 170 | 171 | else: 172 | return False, torch.full((3, size, size), torch.nan) 173 | 174 | refmap = 
get_cache(zK) 175 | data[self.refmap_key] = refmap[1] 176 | seem_need_envmap = not refmap[0] 177 | 178 | size = self.model.img_renderer.image_size 179 | spp = self.model.img_renderer.spp 180 | denoise_suffix = f"_{self.model.img_renderer.denoise}denoise" if self.model.img_renderer.denoise else "" 181 | cache_dir = self.objimg_cache_root / "-".join(brdf_param_names) / f"{size[0]}x{size[1]}_spp{spp}{denoise_suffix}/" 182 | torch.set_printoptions(precision=4, sci_mode=True) 183 | 184 | def get_cache(z) -> tuple: 185 | pieces_key = "b" + str(z)[7:-1] + "v" + str(view_from)[7:-1] 186 | pieces_key = pieces_key.replace("\n", "").replace(" ", "") 187 | filename = pieces_key + ".pt" 188 | cache_file_path = cache_dir / env_name / obj_name / filename 189 | if not cache_file_path.exists(): 190 | return (False,) 191 | for _ in range(3): 192 | try: 193 | cache: dict = torch.load(cache_file_path, map_location="cpu") 194 | except Exception as e: 195 | print(cache_file_path) 196 | print(e) 197 | sleep(0.01) 198 | else: 199 | break 200 | else: 201 | return (False,) 202 | if ( 203 | cache.get("envmap_name") == env_name 204 | and cache.get("brdf_param_names") == brdf_param_names 205 | and cache.get("obj_name") == obj_name 206 | and torch.allclose(cache.get("zk"), z) 207 | and torch.allclose(cache.get("view_from"), view_from) 208 | and cache.get("image_size") == size 209 | ): 210 | return ( 211 | True, 212 | cache.get("rendering_results_image"), 213 | cache.get("rendering_results_normal"), 214 | cache.get("rendering_results_depth"), 215 | ) 216 | else: 217 | return (False,) 218 | 219 | result, *cache = get_cache(zK) 220 | if result: 221 | img, normal, depth = cache 222 | data["img"] = img 223 | data["img_normal"] = normal 224 | data["img_depth"] = depth 225 | else: 226 | data["img"], data["img_normal"] = torch.full((3, *size), torch.nan), torch.full((3, *size), torch.nan) # separate tensors to avoid aliasing the same placeholder 227 | data["img_depth"] = torch.full((1, *size), torch.nan) 228 | seem_need_obj = not result 229 | 230 | ##### get raw refmap cache ##### 231 | size = self.model.img_renderer.image_size 232 | spp = self.model.img_renderer.spp 233 | denoise_suffix = f"_{self.model.img_renderer.denoise}denoise" if self.model.img_renderer.denoise else "" 234 | cache_dir = self.objimg_cache_root / "-".join(brdf_param_names) / f"{size[0]}x{size[1]}_spp{spp}{denoise_suffix}_rawrefmap/" 235 | torch.set_printoptions(precision=4, sci_mode=True) 236 | 237 | def get_cache(z) -> tuple: 238 | pieces_key = "b" + str(z)[7:-1] + "v" + str(view_from)[7:-1] 239 | pieces_key = pieces_key.replace("\n", "").replace(" ", "") 240 | filename = pieces_key + ".pt" 241 | cache_file_path = cache_dir / env_name / obj_name / filename 242 | if not cache_file_path.exists(): 243 | return (False,) 244 | for _ in range(3): 245 | try: 246 | cache: dict = torch.load(cache_file_path, map_location="cpu") 247 | except Exception as e: 248 | print(cache_file_path) 249 | print(e) 250 | sleep(0.01) 251 | else: 252 | break 253 | else: 254 | return (False,) 255 | if ( 256 | cache.get("envmap_name") == env_name 257 | and cache.get("brdf_param_names") == brdf_param_names 258 | and cache.get("obj_name") == obj_name 259 | and torch.allclose(cache.get("zk"), z) 260 | and torch.allclose(cache.get("view_from"), view_from) 261 | and cache.get("image_size") == size 262 | ): 263 | return ( 264 | True, 265 | cache.get("raw_refmap"), 266 | cache.get("raw_refmask"), 267 | ) 268 | else: 269 | return (False,) 270 | 271 | result, *cache = get_cache(zK) 272 | if result: 273 | data["raw_refmap"], data["raw_refmask"] = cache 274 | else: 275 | 
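# on a cache miss, NaN placeholders of the expected shape are returned (here and in the branches above), presumably so batch collation keeps fixed tensor shapes and the caller can detect misses and render on the fly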
refmap_size = self.model.refmap_renderer.image_size 276 | data["raw_refmap"] = torch.full((3, *refmap_size), torch.nan) 277 | data["raw_refmask"] = torch.full((*refmap_size,), torch.nan) 278 | 279 | if self.return_envmap and seem_need_envmap: 280 | if self.preload_envmap: 281 | envmap = self.envmaps[env_name] 282 | else: 283 | envmap = load_exr(self.root / f"{env_name}.exr", as_torch=True) 284 | data["envmap"] = envmap 285 | elif self.return_envmap: 286 | # skip loading environment map. 287 | try: 288 | envmap_size = self.model.renderer.envmap_size 289 | except Exception: 290 | envmap_size = (1000, 2000) 291 | data["envmap"] = torch.full((*envmap_size, 3), torch.nan, dtype=torch.float) 292 | 293 | if self.return_obj and seem_need_obj: 294 | data["obj_shape"] = torch.load(self.shape_root / f"{obj_name}.pt") 295 | 296 | data["tag"] = env_name 297 | 298 | return data 299 | -------------------------------------------------------------------------------- /data/datalists/LavalIndoor+PolyHaven_2k/envs_test.txt: -------------------------------------------------------------------------------- 1 | 9C4A0022-6d8fe2e88e.exr 2 | 9C4A0027-9a856d679a.exr 3 | 9C4A0037-b2d1efd096.exr 4 | 9C4A0048-48d7dfa6b0.exr 5 | 9C4A0064-2bf2b7e178.exr 6 | 9C4A0071-20e04984d0.exr 7 | 9C4A0076-46c51ec2c2.exr 8 | 9C4A0087-5ad3395167.exr 9 | 9C4A0088-5162e36ae6.exr 10 | 9C4A0090-81b4c61bc1.exr 11 | 9C4A0120-0fd27f2a38.exr 12 | 9C4A0132-07352d1dd0.exr 13 | 9C4A0137-73c0813158.exr 14 | 9C4A0202-619f67a287.exr 15 | 9C4A0232-2c9e5539a7.exr 16 | 9C4A0269-9c20633669.exr 17 | 9C4A0286-fb330a56ab.exr 18 | 9C4A0289-598be55b88.exr 19 | 9C4A0330-55d4beffc9.exr 20 | 9C4A0389-f0ecdf64ec.exr 21 | 9C4A0442-721b8f15a6.exr 22 | 9C4A0456-6e650e11bc.exr 23 | 9C4A0473-bde5e7f570.exr 24 | 9C4A0531-176ffd54e6.exr 25 | 9C4A0590-aed3c9aeba.exr 26 | 9C4A0736-308c00e42e.exr 27 | 9C4A0741-2467498d3c.exr 28 | 9C4A0758-f8da9bd020.exr 29 | 9C4A0807-2e77596700.exr 30 | 9C4A0825-99f33c64e1.exr 31 | 9C4A0851-89bc5b7946.exr 32 | 9C4A0857-513e454df9.exr 33 | 9C4A0933-f725db0f7d.exr 34 | 9C4A0935-c41a73af78.exr 35 | 9C4A0951-3569b2c0d0.exr 36 | 9C4A0975-aebac9568b.exr 37 | 9C4A1061-bd1d2e34fc.exr 38 | 9C4A1143-0c066bc93e.exr 39 | 9C4A1193-3b242e921e.exr 40 | 9C4A1225-2ccdc48c9b.exr 41 | 9C4A1229-0eaf4321a8.exr 42 | 9C4A1254-adc42c0fa3.exr 43 | 9C4A1313-80a62190cc.exr 44 | 9C4A1380-c255100a1f.exr 45 | 9C4A1413-67bd114eb8.exr 46 | 9C4A1481-44ceed945c.exr 47 | 9C4A1491-5e42d4fc4c.exr 48 | 9C4A1506-2a1235e04f.exr 49 | 9C4A1553-e5182914b7.exr 50 | 9C4A1578-457ab482d7.exr 51 | 9C4A1612-974a98b1a7.exr 52 | 9C4A1632-254b17a8df.exr 53 | 9C4A1649-7114015c3f.exr 54 | 9C4A1654-5c2a973161.exr 55 | 9C4A1665-5b87a04322.exr 56 | 9C4A1704-1892449f96.exr 57 | 9C4A1732-a7851ff482.exr 58 | 9C4A1738-9f78e5e8ea.exr 59 | 9C4A1830-65ed5bab7f.exr 60 | 9C4A1891-85da239af9.exr 61 | 9C4A1914-18c9480a7a.exr 62 | 9C4A1933-1293562e3c.exr 63 | 9C4A1984-9309a49031.exr 64 | 9C4A2001-27637c0a62.exr 65 | 9C4A2026-601b4934e1.exr 66 | 9C4A2027-f19405ffbc.exr 67 | 9C4A2069-b987fb04c3.exr 68 | 9C4A2110-ea9bb8d077.exr 69 | 9C4A2127-d5643f8d0c.exr 70 | 9C4A2158-dbb5e3d29e.exr 71 | 9C4A2169-306412901a.exr 72 | 9C4A2253-2df4fe5a7c.exr 73 | 9C4A2285-d5ef0f4f37.exr 74 | 9C4A2353-d5d28ed5ca.exr 75 | 9C4A2470-cf8b261b44.exr 76 | 9C4A2515-1c7c6ec0a4.exr 77 | 9C4A2537-490a013587.exr 78 | 9C4A2725-19319b919a.exr 79 | 9C4A2789-4aa46ff9fc.exr 80 | 9C4A2806-fb8abb7257.exr 81 | 9C4A2809-8efdfd7b88.exr 82 | 9C4A2857-dc0a9513b2.exr 83 | 9C4A2859-8556a1b996.exr 84 | 9C4A2893-0b10daa3c5.exr 85 | 
9C4A2977-50d47c67ad.exr 86 | 9C4A3016-af62e79bd4.exr 87 | 9C4A3187-9c8f0e2b1b.exr 88 | 9C4A3233-df84227a2e.exr 89 | 9C4A3271-d56401b1b0.exr 90 | 9C4A3401-42728bf160.exr 91 | 9C4A3459-006926b10e.exr 92 | 9C4A3485-c2c366ffa1.exr 93 | 9C4A3527-8d62bea469.exr 94 | 9C4A3614-b7a70ee631.exr 95 | 9C4A3629-2fd9d1ce76.exr 96 | 9C4A3676-f2b2ebf249.exr 97 | 9C4A3695-cb51bf21c7.exr 98 | 9C4A3713-0e6fea92ca.exr 99 | 9C4A3779-3c4ea5d304.exr 100 | 9C4A3905-0ad61f66a6.exr 101 | 9C4A3915-7146a93a0d.exr 102 | 9C4A3947-48f6e05c56.exr 103 | 9C4A3957-e70b314287.exr 104 | 9C4A4048-19ec80027a.exr 105 | 9C4A4049-6fec018e1d.exr 106 | 9C4A4049-76117dc8f2.exr 107 | 9C4A4073-817fa1dd52.exr 108 | 9C4A4090-1e7ebda10a.exr 109 | 9C4A4091-df37c028b0.exr 110 | 9C4A4259-a93f8ed78d.exr 111 | 9C4A4265-5a4082902d.exr 112 | 9C4A4325-efc3b45a84.exr 113 | 9C4A4343-ca396e6daa.exr 114 | 9C4A4349-6f3dab07ed.exr 115 | 9C4A4416-b9d1ceaccc.exr 116 | 9C4A4493-16bb9d0804.exr 117 | 9C4A4511-cc26af8525.exr 118 | 9C4A4511-e4a50b76c8.exr 119 | 9C4A4517-223e5f175e.exr 120 | 9C4A4559-71821b41f8.exr 121 | 9C4A4574-5843a382ec.exr 122 | 9C4A4577-fb8c0938aa.exr 123 | 9C4A4601-018309e12a.exr 124 | 9C4A4658-4ea4a50f3c.exr 125 | 9C4A4710-d7f953501a.exr 126 | 9C4A4742-1d56d6a1cb.exr 127 | 9C4A4829-d460199b9e.exr 128 | 9C4A4836-f0c9bd7e8d.exr 129 | 9C4A4871-68852f9ed8.exr 130 | 9C4A4895-b235532a3a.exr 131 | 9C4A4917-d16e90f9bb.exr 132 | 9C4A5004-755c441d08.exr 133 | 9C4A5078-4dd5d8db46.exr 134 | 9C4A5088-2ea9851b2b.exr 135 | 9C4A5120-ae2015952e.exr 136 | 9C4A5123-b3b4664e04.exr 137 | 9C4A5169-f587ead874.exr 138 | 9C4A5253-7717cd1be9.exr 139 | 9C4A5273-818063299a.exr 140 | 9C4A5315-1f8aad4cf6.exr 141 | 9C4A5399-b0507175ad.exr 142 | 9C4A5421-c32f48efb8.exr 143 | 9C4A5463-8cae624e04.exr 144 | 9C4A5483-576ea585a5.exr 145 | 9C4A5585-1fb3d557eb.exr 146 | 9C4A5631-d2eb89337c.exr 147 | 9C4A5669-87172d02a7.exr 148 | 9C4A5690-7ea1980afe.exr 149 | 9C4A5757-78474365bd.exr 150 | 9C4A5798-88fa5bb259.exr 151 | 9C4A5816-c2e6f40b0f.exr 152 | 9C4A5848-d8e736c362.exr 153 | 9C4A5858-0338b2aed1.exr 154 | 9C4A5879-2239434041.exr 155 | 9C4A5890-3c5dc048a8.exr 156 | 9C4A6005-806579ae95.exr 157 | 9C4A6033-d76f04669f.exr 158 | 9C4A6047-0228188e22.exr 159 | 9C4A6057-427c31afde.exr 160 | 9C4A6142-ff9a15d233.exr 161 | 9C4A6173-1f516a7213.exr 162 | 9C4A6215-534513a8d9.exr 163 | 9C4A6225-ccf0d2da77.exr 164 | 9C4A6303-6fc942c8fe.exr 165 | 9C4A6351-0e98e87647.exr 166 | 9C4A6393-b903bc6eeb.exr 167 | 9C4A6435-79444159dd.exr 168 | 9C4A6436-19040d41c8.exr 169 | 9C4A6471-8995c9a5df.exr 170 | 9C4A6519-e2541d62cd.exr 171 | 9C4A6561-ce430d545f.exr 172 | 9C4A6593-cc89f63f32.exr 173 | 9C4A6677-a24db2ba03.exr 174 | 9C4A6681-d4f49ff53f.exr 175 | 9C4A6723-87ef41ab55.exr 176 | 9C4A6771-d80511f983.exr 177 | 9C4A6814-ab59fd0aee.exr 178 | 9C4A6845-bb1a55dea3.exr 179 | 9C4A6864-533a74ab58.exr 180 | 9C4A6898-2ab56f7592.exr 181 | 9C4A7027-f0a3809e39.exr 182 | 9C4A7108-805c1c44e7.exr 183 | 9C4A7139-052d15387b.exr 184 | 9C4A7158-8d4a4a5f27.exr 185 | 9C4A7206-b84ed793b8.exr 186 | 9C4A7409-3f66dfc3af.exr 187 | 9C4A7443-52fe66904e.exr 188 | 9C4A7494-e43fb40590.exr 189 | 9C4A7577-a706e51b49.exr 190 | 9C4A7584-f649778f07.exr 191 | 9C4A7601-d88ad72116.exr 192 | 9C4A7704-909a566b55.exr 193 | 9C4A7746-d909578cf2.exr 194 | 9C4A7748-c521dee90e.exr 195 | 9C4A7752-377d122518.exr 196 | 9C4A7790-b3c4813217.exr 197 | 9C4A7832-dc7ade72b1.exr 198 | 9C4A7836-6df8adffea.exr 199 | 9C4A7863-157974a593.exr 200 | 9C4A7945-662330df4f.exr 201 | 9C4A7956-179fee4051.exr 202 | 9C4A7989-0812eba9f2.exr 203 | 9C4A8040-0ccefe88e8.exr 204 
| 9C4A8063-9345e78404.exr 205 | 9C4A8071-3426acbb49.exr 206 | 9C4A8082-199782912e.exr 207 | 9C4A8113-2c7b06369d.exr 208 | 9C4A8130-5d942895cb.exr 209 | 9C4A8197-6ac47a6940.exr 210 | 9C4A8199-b9beac78e4.exr 211 | 9C4A8210-d02b7e05c7.exr 212 | 9C4A8214-62124c8576.exr 213 | 9C4A8239-7a9d7093e2.exr 214 | 9C4A8323-f49f59d214.exr 215 | 9C4A8407-ef992e38bd.exr 216 | 9C4A8420-c08f621416.exr 217 | 9C4A8453-98f5c2582e.exr 218 | 9C4A8493-110b3710e9.exr 219 | 9C4A8621-cf274f97ba.exr 220 | 9C4A8634-efda6b8d17.exr 221 | 9C4A8756-9a19dd7e37.exr 222 | 9C4A8757-4a91f1b4ad.exr 223 | 9C4A8819-15f7ba6f62.exr 224 | 9C4A8829-95bdc0c166.exr 225 | 9C4A8973-d0bf35e71b.exr 226 | 9C4A9006-697d7494fd.exr 227 | 9C4A9009-b83cabcdf1.exr 228 | 9C4A9012-d2b537d2b5.exr 229 | 9C4A9025-e5a2766805.exr 230 | 9C4A9060-b8192ab931.exr 231 | 9C4A9127-b1feeeb169.exr 232 | 9C4A9144-f730880c8d.exr 233 | 9C4A9151-fada4717af.exr 234 | 9C4A9186-b002fef034.exr 235 | 9C4A9211-dbd6a85a92.exr 236 | 9C4A9235-1a84e2f01f.exr 237 | 9C4A9249-9c8da3785f.exr 238 | 9C4A9277-9c862b2505.exr 239 | 9C4A9295-4a0dcd60c3.exr 240 | 9C4A9336-d2d26994a7.exr 241 | 9C4A9378-a4be8e207d.exr 242 | 9C4A9400-7efe755465.exr 243 | 9C4A9426-f14f216510.exr 244 | 9C4A9432-97c36cc65a.exr 245 | 9C4A9522-26b3895792.exr 246 | 9C4A9571-224c7db7ed.exr 247 | 9C4A9601-8d6d60f0a8.exr 248 | 9C4A9607-d4e1cfd7b1.exr 249 | 9C4A9613-820ff0c193.exr 250 | 9C4A9643-9178148dea.exr 251 | 9C4A9691-fa47a7b00a.exr 252 | 9C4A9757-7a3a1a823c.exr 253 | 9C4A9795-6c3750de78.exr 254 | 9C4A9799-7440cd3c3b.exr 255 | 9C4A9826-c3290de53d.exr 256 | 9C4A9853-a7d92fe209.exr 257 | 9C4A9942-420d60e303.exr 258 | AG8A0008-5f7fa416c3.exr 259 | AG8A0033-37e0cc5b3f.exr 260 | AG8A0069-45e4ea47f3.exr 261 | AG8A0111-ffd7e1ad93.exr 262 | AG8A0129-0fcfdc0b4f.exr 263 | AG8A0167-dd150f47a0.exr 264 | AG8A0235-601b71addf.exr 265 | AG8A0243-acbd6fe241.exr 266 | AG8A0277-ff75a19dbc.exr 267 | AG8A0302-0e0740d651.exr 268 | AG8A0321-d11f8bd853.exr 269 | AG8A0327-4c1ce87b15.exr 270 | AG8A0361-7d810ca94d.exr 271 | AG8A0376-e91cd665eb.exr 272 | AG8A0378-77a98f4b00.exr 273 | AG8A0405-da5e623484.exr 274 | AG8A0418-d45d849cac.exr 275 | AG8A0465-bc4e1415c2.exr 276 | AG8A0487-e8741b1602.exr 277 | AG8A0489-b02f817def.exr 278 | AG8A0504-24e9fdcaa1.exr 279 | AG8A0529-8f77a41bc1.exr 280 | AG8A0531-04ef73906f.exr 281 | AG8A0546-76034a82cc.exr 282 | AG8A0615-f3f7d3d997.exr 283 | AG8A0657-b378e38ddf.exr 284 | AG8A0670-6edcbd3efb.exr 285 | AG8A0680-38e426ed5a.exr 286 | AG8A0699-21dd19a918.exr 287 | AG8A0739-818c841b81.exr 288 | AG8A0823-d7d5fe7916.exr 289 | AG8A0865-f1e24d9db3.exr 290 | AG8A0882-aa605e646c.exr 291 | AG8A0949-925f52a2bb.exr 292 | AG8A0971-791e7557c5.exr 293 | AG8A1004-cb6d25d335.exr 294 | AG8A1008-442956e5fa.exr 295 | AG8A1016-f44fca3fb6.exr 296 | AG8A1058-8425e0a13e.exr 297 | AG8A1100-3dbe89c322.exr 298 | AG8A1134-16c73a2d0d.exr 299 | AG8A1159-b6bb1a94bd.exr 300 | AG8A1172-e85e165d1f.exr 301 | AG8A1218-e50a8d804f.exr 302 | AG8A1223-6a401cfc0f.exr 303 | AG8A1285-e7ec2299e7.exr 304 | AG8A1298-82e070d190.exr 305 | AG8A1307-70d0f8af04.exr 306 | AG8A1352-982b77f8e4.exr 307 | AG8A1386-5ab7a5892c.exr 308 | AG8A1550-3d74fa5285.exr 309 | AG8A1601-486a3db5d3.exr 310 | AG8A1663-fa155284f3.exr 311 | AG8A1680-fa60be802f.exr 312 | AG8A1722-d03677dbbf.exr 313 | AG8A1734-85d2ddf221.exr 314 | AG8A1747-1c71c4733f.exr 315 | AG8A1806-30607c0512.exr 316 | AG8A1912-6ce9504c01.exr 317 | AG8A1974-2ed26a0074.exr 318 | AG8A2019-e095d1c02d.exr 319 | AG8A2132-aa8bfbe11b.exr 320 | AG8A2174-9c794f54d8.exr 321 | AG8A2206-02ce50ddef.exr 322 | 
AG8A2248-ba2591a352.exr 323 | AG8A2332-b66be9cbc5.exr 324 | AG8A2420-cd7a0f2d9a.exr 325 | AG8A2462-25d762f046.exr 326 | AG8A2552-978a5a4da6.exr 327 | AG8A2689-8600cec23b.exr 328 | AG8A2727-dd421d4610.exr 329 | AG8A2769-ed34ba3286.exr 330 | AG8A2798-90b3f4324b.exr 331 | AG8A2840-efce5f1be3.exr 332 | AG8A3025-65c96016e5.exr 333 | AG8A3074-d699cd0768.exr 334 | AG8A3105-36a0469e3d.exr 335 | AG8A3154-56ba2a63ed.exr 336 | AG8A3158-c709f4dad8.exr 337 | AG8A3252-4f78d92263.exr 338 | AG8A3427-9277a6a188.exr 339 | AG8A3428-5bd15f314f.exr 340 | AG8A3459-b62f18a784.exr 341 | AG8A3469-d06a09bb74.exr 342 | AG8A3501-2d1ee22a3b.exr 343 | AG8A3553-c89f4da33f.exr 344 | AG8A3585-d77910dfd8.exr 345 | AG8A3753-e46eab5377.exr 346 | AG8A3879-55d6896ffe.exr 347 | AG8A3963-e4b95c9b52.exr 348 | AG8A4114-b7982086e0.exr 349 | AG8A4190-256e16c3a9.exr 350 | AG8A4232-b2f0bac421.exr 351 | AG8A4299-5147f35137.exr 352 | AG8A4491-b06d814bd8.exr 353 | AG8A4551-df55f70998.exr 354 | AG8A4593-730c5459fd.exr 355 | AG8A4803-80ba755341.exr 356 | AG8A4828-37a65ad7b3.exr 357 | AG8A4841-ee00cd67c5.exr 358 | AG8A4883-fbe6120d00.exr 359 | AG8A4887-71c7cb7b1f.exr 360 | AG8A5009-83fad15e72.exr 361 | AG8A5013-ad4e1b1b14.exr 362 | AG8A5276-e798820521.exr 363 | AG8A5314-1e720b63af.exr 364 | AG8A5360-dddda72365.exr 365 | AG8A5400-d1783082d2.exr 366 | AG8A5444-f8f35e0b62.exr 367 | AG8A5540-f5c7b0336d.exr 368 | AG8A5624-7f67c67643.exr 369 | AG8A5776-2e736db2be.exr 370 | AG8A5792-db79c07e52.exr 371 | AG8A5878-c9289fa855.exr 372 | AG8A6044-fba42ebf98.exr 373 | AG8A6154-f4c9fd92da.exr 374 | AG8A6172-f49ab429b3.exr 375 | AG8A6225-e1a6beee99.exr 376 | AG8A6322-4381d3c535.exr 377 | AG8A6351-195e7c65df.exr 378 | AG8A6424-c44e6566e6.exr 379 | AG8A6635-3f90581acd.exr 380 | AG8A6677-6e3b70fb20.exr 381 | AG8A6687-06dff16191.exr 382 | AG8A6719-5b305f5869.exr 383 | AG8A6729-acabc12ff1.exr 384 | AG8A6761-d4d27f08a0.exr 385 | AG8A6901-fc04a8708a.exr 386 | AG8A6917-a1e454ea5a.exr 387 | AG8A7062-4cba8652de.exr 388 | AG8A7093-90313edd36.exr 389 | AG8A7169-c9225e32bf.exr 390 | AG8A7188-43a3aa916f.exr 391 | AG8A7212-9a6baaca0c.exr 392 | AG8A7230-2b83e04fdb.exr 393 | AG8A7254-667abf1aa1.exr 394 | AG8A7303-70e2aca011.exr 395 | AG8A7345-fab0887a58.exr 396 | AG8A7482-0c9efe2fda.exr 397 | AG8A7573-ef5c737eea.exr 398 | AG8A7716-11e38d9191.exr 399 | AG8A7849-375e867fcd.exr 400 | AG8A7891-6d756d2939.exr 401 | AG8A7925-42cd6017a9.exr 402 | AG8A7945-e32b43f81e.exr 403 | AG8A7968-6a0b00102a.exr 404 | AG8A8059-71b594469b.exr 405 | AG8A8154-8f567cc42f.exr 406 | AG8A8178-4a01c74148.exr 407 | AG8A8238-1cb4388703.exr 408 | AG8A8269-5b4709233a.exr 409 | AG8A8280-bb535895b6.exr 410 | AG8A8289-0820a81f56.exr 411 | AG8A8331-15461a6c90.exr 412 | AG8A8373-61755ee58f.exr 413 | AG8A8435-a258e7d58d.exr 414 | AG8A8584-cf76711ac4.exr 415 | AG8A8658-cdca6ead27.exr 416 | AG8A8687-d92cfefbfa.exr 417 | AG8A8892-fd0d5be873.exr 418 | AG8A8899-8c17c18be3.exr 419 | AG8A8934-1986519736.exr 420 | AG8A8941-f47739d637.exr 421 | AG8A8974-4fb7fcf8d6.exr 422 | AG8A9078-86a3710c04.exr 423 | AG8A9100-9b0d212129.exr 424 | AG8A9155-ac93623025.exr 425 | AG8A9162-dd89565d2f.exr 426 | AG8A9184-bd188e62f8.exr 427 | AG8A9213-6ccf62da91.exr 428 | AG8A9288-cfd59522a3.exr 429 | AG8A9330-9d5b5d3cf8.exr 430 | AG8A9372-8ac67e4116.exr 431 | AG8A9394-eec2500e82.exr 432 | AG8A9396-c764bf47a9.exr 433 | AG8A9403-6f7a7ca1e5.exr 434 | AG8A9436-757daaad80.exr 435 | AG8A9445-4474dd834d.exr 436 | AG8A9465-a5244feb26.exr 437 | AG8A9478-08018709fd.exr 438 | AG8A9562-bfbcc44936.exr 439 | AG8A9564-208017fcb7.exr 440 | 
AG8A9591-325e4dc778.exr 441 | AG8A9746-14547b1a7d.exr 442 | AG8A9750-f8ecbc9cb4.exr 443 | AG8A9834-c3c757aec7.exr 444 | AG8A9872-d23929b45b.exr 445 | AG8A9898-70529f41e7.exr 446 | AG8A9940-dad1503196.exr 447 | abandoned_church.exr 448 | abandoned_games_room_01.exr 449 | abandoned_hall_01.exr 450 | abandoned_tank_farm_03.exr 451 | altanka.exr 452 | approaching_storm.exr 453 | art_studio.exr 454 | autoshop_01.exr 455 | autumn_forest_01.exr 456 | bell_park_dawn.exr 457 | bergen.exr 458 | blue_lagoon_night.exr 459 | cabin.exr 460 | cambridge.exr 461 | carpentry_shop_02.exr 462 | cave_wall.exr 463 | cliffside.exr 464 | colosseum.exr 465 | de_balie.exr 466 | dikhololo_night.exr 467 | dry_field.exr 468 | eilenriede_labyrinth.exr 469 | eilenriede_park.exr 470 | evening_road_01.exr 471 | freight_station.exr 472 | future_parking.exr 473 | garage.exr 474 | glass_passage.exr 475 | hospital_room.exr 476 | kiara_1_dawn.exr 477 | kiara_3_morning.exr 478 | kiara_6_afternoon.exr 479 | kloofendal_48d_partly_cloudy.exr 480 | kloppenheim_07.exr 481 | lapa.exr 482 | lauter_waterfall.exr 483 | lot_02.exr 484 | lythwood_room.exr 485 | lythwood_terrace.exr 486 | mealie_road.exr 487 | missile_launch_facility_01.exr 488 | moonlit_golf.exr 489 | muddy_autumn_forest.exr 490 | museum_of_ethnography.exr 491 | ninomaru_teien.exr 492 | noon_grass.exr 493 | outdoor_umbrellas.exr 494 | parched_canal.exr 495 | peppermint_powerplant.exr 496 | piazza_martin_lutero.exr 497 | piazza_san_marco.exr 498 | pond.exr 499 | pool.exr 500 | qwantani.exr 501 | red_hill_curve.exr 502 | reichstag_1.exr 503 | reinforced_concrete_01.exr 504 | rhodes_memorial.exr 505 | river_walk_2.exr 506 | rocky_ridge.exr 507 | roofless_ruins.exr 508 | rural_landscape.exr 509 | sisulu.exr 510 | skate_park.exr 511 | skukuza_golf.exr 512 | skylit_garage.exr 513 | small_cathedral_02.exr 514 | small_hangar_01.exr 515 | small_harbor_02.exr 516 | spiaggia_di_mondello.exr 517 | st_fagans_interior.exr 518 | stadium_01.exr 519 | studio_small_01.exr 520 | studio_small_02.exr 521 | sunflowers.exr 522 | sunset_forest.exr 523 | syferfontein_0d_clear.exr 524 | table_mountain_2.exr 525 | teatro_massimo.exr 526 | the_sky_is_on_fire.exr 527 | theater_01.exr 528 | tiber_island.exr 529 | ulmer_muenster.exr 530 | urban_alley_01.exr 531 | urban_courtyard.exr 532 | urban_street_04.exr 533 | veld_fire.exr 534 | venice_sunset.exr 535 | wide_street_02.exr 536 | winter_river.exr 537 | winter_sky.exr 538 | wooden_motel.exr 539 | zwinger_night.exr -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def convert_array_to_torch(array, device: Optional[torch.device], dtype: Optional[torch.dtype] = None): 9 | if isinstance(array, torch.Tensor): 10 | return array.to(device) 11 | if isinstance(array, np.ndarray): 12 | return torch.from_numpy(array).to(device).to(dtype) 13 | else: 14 | return torch.tensor(array, device=device, dtype=dtype) 15 | 16 | 17 | def thetaphi2xyz( 18 | thetaphi: Union[torch.Tensor, np.ndarray], 19 | normal: Optional[Union[torch.Tensor, np.ndarray, List[float]]] = [0.0, 0.0, 1.0], 20 | tangent: Optional[Union[torch.Tensor, np.ndarray, List[float]]] = [1.0, 0.0, 0.0], 21 | reverse_phi: bool = False, 22 | assume_normalized: bool = True, 23 | ) -> Union[torch.Tensor, np.ndarray]: 24 | """ 25 | thetaphi : [theta, phi] 26 | 
normal: [x, y, z], default [0, 0, 1] 27 | tangent: [x, y, z], default [1, 0, 0] 28 | reverse_phi: if set to True, positive phi is clockwise. 29 | assume_normalized: if set to True, normal and tangent are assumed to be normalized. 30 | return : [x, y, z] 31 | """ 32 | if isinstance(thetaphi, torch.Tensor): 33 | device = thetaphi.device 34 | dtype = thetaphi.dtype 35 | normal = convert_array_to_torch(normal, device=device, dtype=dtype) 36 | tangent = convert_array_to_torch(tangent, device=device, dtype=dtype) 37 | module = torch 38 | else: 39 | normal = np.array(normal) 40 | tangent = np.array(tangent) 41 | module = np 42 | if not assume_normalized: 43 | normal = normalize(normal) 44 | tangent = normalize(tangent) 45 | binormal = module.cross(normal, tangent) 46 | if reverse_phi: 47 | binormal *= -1 48 | xyz = module.cos(thetaphi[..., 0:1]) * normal 49 | sin_theta = module.sin(thetaphi[..., 0:1]) 50 | xyz += sin_theta * module.cos(thetaphi[..., 1:2]) * tangent 51 | xyz += sin_theta * module.sin(thetaphi[..., 1:2]) * binormal 52 | return xyz 53 | 54 | 55 | def xyz2thetaphi( 56 | xyz: Union[torch.Tensor, np.ndarray], 57 | normal: Optional[Union[torch.Tensor, np.ndarray, List[float]]] = [0.0, 0.0, 1.0], 58 | tangent: Optional[Union[torch.Tensor, np.ndarray, List[float]]] = [1.0, 0.0, 0.0], 59 | reverse_phi: bool = False, 60 | assume_normalized: bool = True, 61 | ) -> Union[torch.Tensor, np.ndarray]: 62 | """ 63 | xyz : [x, y, z] 64 | normal: [x, y, z], default [0, 0, 1] 65 | tangent: [x, y, z], default [1, 0, 0] 66 | reverse_phi: if set to True, positive phi is clockwise. 67 | assume_normalized: if set to True, xyz, normal and tangent are assumed to be normalized. 68 | return : [theta, phi] (theta in [0, pi], phi in (-pi, pi)) 69 | """ 70 | if isinstance(xyz, torch.Tensor): 71 | device = xyz.device 72 | dtype = xyz.dtype 73 | normal = convert_array_to_torch(normal, device=device, dtype=dtype) 74 | tangent = convert_array_to_torch(tangent, device=device, dtype=dtype) 75 | module = torch 76 | else: 77 | normal = np.array(normal) 78 | tangent = np.array(tangent) 79 | module = np 80 | if not assume_normalized: 81 | xyz = normalize(xyz) 82 | normal = normalize(normal) 83 | tangent = normalize(tangent) 84 | binormal = module.cross(normal, tangent) 85 | if reverse_phi: 86 | binormal *= -1 87 | theta = module.arccos(module.matmul(xyz, normal[..., None]))[..., 0] 88 | phi = module.arctan2(module.matmul(xyz, binormal[..., None]), module.matmul(xyz, tangent[..., None]))[..., 0] 89 | return module.stack((theta, phi), -1) 90 | 91 | 92 | def normalize( 93 | xyz: Union[torch.Tensor, np.ndarray], 94 | dim: int = -1, 95 | eps: float = 1e-12, 96 | ) -> Union[torch.Tensor, np.ndarray]: 97 | assert xyz.shape[-1] == 3 98 | if isinstance(xyz, torch.Tensor): 99 | return torch.nn.functional.normalize(xyz, dim=dim, eps=eps) 100 | else: 101 | length = np.linalg.norm(xyz, axis=dim, keepdims=True) 102 | length = np.clip(length, eps, None) 103 | return xyz / length 104 | 105 | 106 | def mirmap2envmap( 107 | mirmap: torch.Tensor, 108 | output_shape: tuple, 109 | view: Union[torch.Tensor, List[float]] = [0, 0, 1], 110 | top: Union[torch.Tensor, List[float]] = [0, 1, 0], 111 | envmap_zenith: Union[torch.Tensor, List[float]] = [0, 1, 0], 112 | envmap_left_edge: Union[torch.Tensor, List[float]] = [0, 0, -1], 113 | reverse_azimuth: bool = True, 114 | log_scale_interpolation: bool = False, 115 | ) -> torch.Tensor: 116 | assert view == [0, 0, 1], "only the view direction [0, 0, 1] is currently supported" 117 | device = mirmap.device 118 | dtype = 
mirmap.dtype 119 | view = convert_array_to_torch(view, device=device, dtype=dtype) 120 | top = convert_array_to_torch(top, device=device, dtype=dtype) 121 | height, width = mirmap.shape[-2:] 122 | OH, OW = output_shape 123 | theta = (torch.arange(OH, device=device) + 0.5) * (torch.pi / OH) 124 | phi = (torch.arange(OW, device=device) + 0.5) * (torch.pi * 2 / OW) 125 | if reverse_azimuth: 126 | phi = -phi 127 | thetaphi = torch.stack(torch.meshgrid(theta, phi, indexing="ij"), axis=-1) 128 | xyz = thetaphi2xyz(thetaphi, normal=envmap_zenith, tangent=envmap_left_edge) 129 | normal_map = xyz2thetaphi(normalize(xyz + view), normal=top, tangent=view) 130 | u = normal_map[..., 1] * (2 / torch.pi) 131 | v = normal_map[..., 0] * (2 / torch.pi) - 1 132 | uv = torch.stack([u, v], axis=-1) 133 | if log_scale_interpolation: 134 | mirmap = torch.log(mirmap.clip(1e-7)) 135 | envmap = torch.nn.functional.grid_sample( 136 | mirmap, 137 | uv[None].expand(mirmap.size(0), -1, -1, -1), 138 | mode="bilinear", 139 | padding_mode="border", 140 | align_corners=False, 141 | ) 142 | if log_scale_interpolation: 143 | envmap = torch.exp(envmap) 144 | return envmap 145 | 146 | 147 | def gen_sphere_normals_realcentering(radius, edge=0): 148 | # real centering 149 | """Generate a set of normals of a spherical object from an orthographic camera.""" 150 | normals = np.zeros((radius * 2, radius * 2, 3), dtype=np.float32) 151 | x = np.linspace(-radius + 0.5, radius - 0.5, num=2 * radius, endpoint=True) 152 | y = np.linspace(radius - 0.5, -radius + 0.5, num=2 * radius, endpoint=True) 153 | x, y = np.meshgrid(x, y) 154 | 155 | zsq = radius**2 - (x**2 + y**2) 156 | 157 | normals[..., 0] = x 158 | normals[..., 1] = y 159 | normals[zsq >= 0.0, 2] = np.sqrt(zsq[zsq >= 0.0]) 160 | normals[...] 
/= np.sqrt(np.sum(normals**2, axis=2, keepdims=True)) 161 | normals[zsq < 0.0] = 0.0 162 | 163 | xx, yy = np.ogrid[0 : radius * 2, 0 : radius * 2] 164 | xx, yy = xx + 0.5, yy + 0.5 165 | mask = ((xx - radius) ** 2 + (yy - radius) ** 2) <= ((radius - edge) * (radius - edge)) 166 | 167 | return normals * mask[..., None], mask 168 | 169 | 170 | def refmap2refimg_torch(refmap: torch.Tensor, radius: int = None, return_mask: bool = False) -> torch.Tensor: 171 | """ 172 | input: [(Batch), Channel, Height, Width] 173 | return: [(Batch), Channel, Height, Width] 174 | """ 175 | if radius is None: 176 | radius = max(refmap.shape[-2:]) 177 | res = radius * 2 178 | dtype = refmap.dtype 179 | device = refmap.device 180 | height, width = refmap.shape[-2:] 181 | sphere_normal_map, mask = gen_sphere_normals_realcentering(radius) # x, y, z 182 | sphere_normal_map = torch.from_numpy(sphere_normal_map).to(device=device, dtype=dtype) 183 | mask = torch.from_numpy(mask).to(device=device) 184 | uv = xyz2thetaphi(sphere_normal_map[mask, :], [0, 1, 0], [-1, 0, 0]) # [masked HxW, 2(theta, phi)] 185 | uv = uv.flip(-1) * (2 / torch.pi) - 1 186 | batch_flag = True 187 | if refmap.ndim == 3: 188 | batch_flag = False 189 | refmap = refmap[None] 190 | mirimg = torch.zeros((refmap.shape[0], refmap.shape[-3], res, res), dtype=dtype, device=device) 191 | mirimg[:, :, mask] = torch.nn.functional.grid_sample( 192 | refmap, uv.expand(refmap.size(0), 1, -1, -1), mode="bilinear", padding_mode="border", align_corners=False 193 | )[:, :, 0, :] 194 | if not batch_flag: 195 | mirimg = mirimg[0] 196 | if return_mask: 197 | return mirimg, mask 198 | return mirimg 199 | 200 | 201 | def envmap2mirmap( 202 | envmap: torch.Tensor, 203 | output_shape: Tuple[int, int], 204 | flip_horizontal: bool = False, 205 | view_from: Union[torch.Tensor, List[float]] = [1, 0, 0], 206 | top: Union[torch.Tensor, List[float]] = [0, 1, 0], 207 | envmap_zenith: Union[torch.Tensor, List[float]] = [0, 1, 0], 208 | envmap_left_edge: Union[torch.Tensor, List[float]] = [0, 0, -1], 209 | reverse_azimuth_envmap: bool = True, 210 | mitigate_aliasing: bool = True, 211 | log_scale_interpolation: bool = False, 212 | ) -> Union[np.ndarray, torch.Tensor]: 213 | device = envmap.device 214 | dtype = envmap.dtype 215 | view_from = convert_array_to_torch(view_from, device=device, dtype=dtype) 216 | top = convert_array_to_torch(top, device=device, dtype=dtype) 217 | if (torch.einsum("...i, ...i -> ...", view_from, top) != 0).any(): 218 | top = normalize(torch.cross(torch.cross(view_from, top), view_from)) 219 | height, width = envmap.shape[-2:] 220 | OH, OW = output_shape 221 | if mitigate_aliasing: 222 | H = W = min(height, width) if OH < height else max(OH, OW) 223 | else: 224 | H, W = OH, OW 225 | theta = (torch.arange(H, device=device, dtype=dtype) + 0.5) * (torch.pi / H) 226 | phi = (torch.arange(W, device=device, dtype=dtype) - (W - 1) / 2) * (torch.pi / W) 227 | thetaphi = torch.stack(torch.meshgrid(theta, phi, indexing="ij"), -1) 228 | xyz = thetaphi2xyz(thetaphi, normal=top, tangent=view_from, reverse_phi=flip_horizontal) 229 | xyz_env = normalize(2 * torch.matmul(xyz, view_from[..., None]) * xyz - view_from) 230 | thetaphi_env = xyz2thetaphi(xyz_env, normal=envmap_zenith, tangent=envmap_left_edge, reverse_phi=reverse_azimuth_envmap) 231 | u = thetaphi_env[..., 1] % (2 * torch.pi) / torch.pi - 1 232 | v = thetaphi_env[..., 0] * (2 / torch.pi) - 1 233 | uv = torch.stack([u, v], axis=-1) 234 | if uv.dim() == 3: 235 | uv = uv[None].expand(envmap.size(0), -1, -1, -1) 236 
| if log_scale_interpolation: 237 | envmap = torch.log(envmap.clip(1e-7)) 238 | mirmap = torch.nn.functional.grid_sample(envmap, uv, mode="bilinear", padding_mode="border", align_corners=False) 239 | mirmap = torch.nn.functional.adaptive_avg_pool2d(mirmap, output_shape) 240 | if log_scale_interpolation: 241 | mirmap = torch.exp(mirmap) 242 | return mirmap 243 | 244 | 245 | def mirimg2envmap( 246 | refimg: torch.Tensor, 247 | output_shape: Tuple[int, int], 248 | view_from: Union[torch.Tensor, List[float]] = [0, 0, 1], 249 | top: Union[torch.Tensor, List[float]] = [0, 1, 0], 250 | envmap_zenith: Union[torch.Tensor, List[float]] = [0, 1, 0], 251 | envmap_left_edge: Union[torch.Tensor, List[float]] = [0, 0, -1], 252 | reverse_azimuth: bool = True, 253 | log_scale_interpolation: bool = False, 254 | ) -> torch.Tensor: 255 | """ 256 | realcentering 257 | refimg: [BS, channel, Height, Width] 258 | output: [BS, channel, Height, Width] 259 | """ 260 | device = refimg.device 261 | dtype = refimg.dtype 262 | view_from = convert_array_to_torch(view_from, device=device, dtype=dtype) 263 | top = convert_array_to_torch(top, device=device, dtype=dtype) 264 | if (torch.einsum("...i, ...i -> ...", view_from, top) != 0).any(): 265 | top = normalize(torch.cross(torch.cross(view_from, top), view_from)) 266 | OH, OW = output_shape 267 | theta = (torch.arange(OH, device=device, dtype=dtype) + 0.5) * (torch.pi / OH) 268 | phi = (torch.arange(OW, device=device, dtype=dtype) + 0.5) * (torch.pi * 2 / OW) 269 | thetaphi = torch.stack(torch.meshgrid(theta, phi, indexing="ij"), axis=-1) 270 | xyz = thetaphi2xyz(thetaphi, normal=envmap_zenith, tangent=envmap_left_edge, reverse_phi=reverse_azimuth) 271 | normal_map = xyz2thetaphi(normalize(xyz + view_from), normal=top, tangent=torch.cross(view_from, top)) 272 | normal_map = normal_map - (torch.pi / 2) 273 | theta, phi = normal_map[..., 0], normal_map[..., 1] 274 | v = torch.sin(theta) 275 | u = torch.cos(theta) * torch.sin(phi) 276 | uv = torch.stack([u, v], axis=-1) 277 | if uv.dim() == 3: 278 | uv = uv[None].expand(refimg.size(0), -1, -1, -1) 279 | if log_scale_interpolation: 280 | refimg = torch.log(refimg.clip(1e-7)) 281 | envmap = torch.nn.functional.grid_sample(refimg, uv, mode="bilinear", padding_mode="border", align_corners=False) 282 | if log_scale_interpolation: 283 | envmap = torch.exp(envmap) 284 | return envmap 285 | 286 | 287 | def mirimg2envmap_numpy( 288 | refimg: np.ndarray, 289 | output_shape: tuple, 290 | view_from: Union[np.ndarray, List[float]] = [0, 0, 1], 291 | top: Union[np.ndarray, List[float]] = [0, 1, 0], 292 | envmap_zenith: Union[np.ndarray, List[float]] = [0, 1, 0], 293 | envmap_left_edge: Union[np.ndarray, List[float]] = [0, 0, -1], 294 | reverse_azimuth: bool = True, 295 | log_scale_interpolation: bool = False, 296 | ) -> np.ndarray: 297 | """ 298 | realcentering 299 | refimg: [Height, Width, Channel] 300 | output: [Height, Width, Channel] 301 | """ 302 | device = "cuda" if torch.cuda.is_available() else "cpu" 303 | refimg = torch.from_numpy(refimg).permute(2, 0, 1)[None].to(device) # [1, C, H, W] 304 | envmap = mirimg2envmap( 305 | refimg, 306 | output_shape, 307 | view_from=view_from, 308 | top=top, 309 | envmap_zenith=envmap_zenith, 310 | envmap_left_edge=envmap_left_edge, 311 | reverse_azimuth=reverse_azimuth, 312 | log_scale_interpolation=log_scale_interpolation, 313 | ) 314 | return envmap[0].permute(1, 2, 0).cpu().numpy() 315 | 316 | 317 | def rotate_envmap( 318 | envmap: torch.Tensor, 319 | src_envmap_zenith: Union[np.ndarray, 
torch.Tensor, List[float]] = [0, 1, 0], 320 | src_envmap_left_edge: Union[np.ndarray, torch.Tensor, List[float]] = [0, 0, -1], 321 | tgt_envmap_zenith: Union[np.ndarray, torch.Tensor, List[float]] = None, 322 | tgt_envmap_left_edge: Union[np.ndarray, torch.Tensor, List[float]] = None, 323 | out_shape: Optional[Tuple[int, int]] = None, 324 | ): 325 | """ 326 | envmap: [BS, C, H, W] 327 | src/tgt envmap_zenith and envmap_left_edge: [3] (or broadcastable [BS, 3]) frame vectors; the map is resampled from the src frame into the tgt frame 328 | out_shape: (H, W) of the output; defaults to the input resolution 329 | """ 330 | spec_src_coord = src_envmap_zenith is not None and src_envmap_left_edge is not None 331 | spec_tgt_coord = tgt_envmap_zenith is not None and tgt_envmap_left_edge is not None 332 | assert spec_src_coord and spec_tgt_coord, "zenith and left_edge must be specified for both src and tgt frames" 333 | height, width = envmap.shape[-2:] 334 | tgt_height, tgt_width = (height, width) if out_shape is None else out_shape 335 | 336 | device = envmap.device 337 | dtype = envmap.dtype 338 | src_envmap_zenith = convert_array_to_torch(src_envmap_zenith, device=device, dtype=dtype)[..., None, None, :] 339 | src_envmap_left_edge = convert_array_to_torch(src_envmap_left_edge, device=device, dtype=dtype)[..., None, None, :] 340 | tgt_envmap_zenith = convert_array_to_torch(tgt_envmap_zenith, device=device, dtype=dtype)[..., None, None, :] 341 | tgt_envmap_left_edge = convert_array_to_torch(tgt_envmap_left_edge, device=device, dtype=dtype)[..., None, None, :] 342 | 343 | h_shift, w_shift = torch.pi / tgt_height / 2, torch.pi / tgt_width 344 | theta = torch.linspace(h_shift, torch.pi - h_shift, tgt_height, device=device) 345 | phi = torch.linspace(w_shift, torch.pi * 2 - w_shift, tgt_width, device=device) 346 | thetaphi_map = torch.stack(torch.meshgrid(theta, phi, indexing="ij"), axis=-1) 347 | xyz_map = thetaphi2xyz(thetaphi_map, normal=tgt_envmap_zenith, tangent=tgt_envmap_left_edge, reverse_phi=True) # [BS, H, W, 3] 348 | thetaphi_map = xyz2thetaphi(xyz_map, normal=src_envmap_zenith, tangent=src_envmap_left_edge, reverse_phi=True) # [BS, H, W, 2] 349 | 350 | # [theta, phi]: [[0, pi], [-pi, pi]] -> [[-1, 1], [-1, 1]] 351 | thetaphi_map[..., 0] /= torch.pi / 2 352 | thetaphi_map[..., 0] -= 1 353 | thetaphi_map[..., 1] /= torch.pi 354 | thetaphi_map[..., 1] %= 2 355 | thetaphi_map[..., 1] -= 1 356 | envmap = torch.nn.functional.grid_sample( 357 | envmap, 358 | thetaphi_map.flip(-1), 359 | mode="bilinear", 360 | padding_mode="border", 361 | align_corners=False, 362 | ) 363 | return envmap 364 | --------------------------------------------------------------------------------
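Usage note (not part of the repository): a minimal sketch of how the transforms in utils/transform.py compose. The angles, the random mirror-sphere image, and the 2:1 output shape are illustrative stand-ins; the tensor layouts follow the docstrings above.

import torch

from utils.transform import mirimg2envmap, thetaphi2xyz, xyz2thetaphi

# round trip: xyz2thetaphi inverts thetaphi2xyz in the same (normal, tangent) frame
thetaphi = torch.tensor([0.3 * torch.pi, -0.4 * torch.pi])
xyz = thetaphi2xyz(thetaphi)  # unit direction; theta measured from +z, phi from +x
assert torch.allclose(xyz2thetaphi(xyz), thetaphi, atol=1e-5)

# unwrap a mirror-sphere image into a 2:1 equirectangular environment map
mirimg = torch.rand(1, 3, 256, 256)  # stand-in for a real [BS, C, H, W] mirror image
envmap = mirimg2envmap(mirimg, output_shape=(256, 512), log_scale_interpolation=True)
print(envmap.shape)  # torch.Size([1, 3, 256, 512])

log_scale_interpolation interpolates in log radiance, which avoids bright HDR pixels bleeding into their neighbors during bilinear resampling.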