├── .github └── workflows │ ├── black.yml │ ├── test-build.yaml │ └── test-inference.yml ├── .gitignore ├── .gitmodules ├── .project-root ├── 3drec ├── configs │ └── nvsadapter.yaml ├── launch.py ├── load │ ├── images │ │ ├── a_beautiful_rainbow_fish_512.png │ │ ├── a_beautiful_rainbow_fish_512_rgba.png │ │ ├── a_cozy_sofa_in_the_shape_of_a_llama.png │ │ ├── a_cozy_sofa_in_the_shape_of_a_llama2.png │ │ ├── a_cozy_sofa_in_the_shape_of_a_llama2_rgba.png │ │ └── a_cozy_sofa_in_the_shape_of_a_llama_rgba.png │ ├── lights │ │ ├── LICENSE.txt │ │ ├── bsdf_256_256.bin │ │ └── mud_road_puresky_1k.hdr │ ├── make_prompt_library.py │ ├── prompt_library.json │ ├── shapes │ │ ├── README.md │ │ ├── animal.obj │ │ ├── blub.obj │ │ ├── cabin.obj │ │ ├── env_sphere.obj │ │ ├── hand_prismatic.obj │ │ ├── human.obj │ │ ├── nascar.obj │ │ ├── potion.obj │ │ └── teddy.obj │ └── tets │ │ ├── 128_tets.npz │ │ ├── 32_tets.npz │ │ ├── 64_tets.npz │ │ └── generate_tets.py ├── requirements.txt └── threestudio │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── co3d.py │ ├── image.py │ ├── multiview.py │ └── uncond.py │ ├── models │ ├── __init__.py │ ├── background │ │ ├── __init__.py │ │ ├── base.py │ │ ├── neural_environment_map_background.py │ │ ├── solid_color_background.py │ │ └── textured_background.py │ ├── estimators.py │ ├── exporters │ │ ├── __init__.py │ │ ├── base.py │ │ └── mesh_exporter.py │ ├── geometry │ │ ├── __init__.py │ │ ├── base.py │ │ ├── custom_mesh.py │ │ ├── implicit_sdf.py │ │ ├── implicit_volume.py │ │ ├── tetrahedra_sdf_grid.py │ │ └── volume_grid.py │ ├── guidance │ │ ├── __init__.py │ │ └── nvsadapter_guidance.py │ ├── isosurface.py │ ├── materials │ │ ├── __init__.py │ │ ├── base.py │ │ ├── diffuse_with_point_light_material.py │ │ ├── hybrid_rgb_latent_material.py │ │ ├── neural_radiance_material.py │ │ ├── no_material.py │ │ ├── pbr_material.py │ │ └── sd_latent_adapter_material.py │ ├── mesh.py │ ├── networks.py │ ├── prompt_processors │ │ ├── __init__.py │ │ ├── base.py │ │ ├── deepfloyd_prompt_processor.py │ │ ├── dummy_prompt_processor.py │ │ └── stable_diffusion_prompt_processor.py │ └── renderers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── deferred_volume_renderer.py │ │ ├── gan_volume_renderer.py │ │ ├── nerf_volume_renderer.py │ │ ├── neus_volume_renderer.py │ │ ├── nvdiff_rasterizer.py │ │ └── patch_renderer.py │ ├── systems │ ├── __init__.py │ ├── base.py │ ├── nvsadapter.py │ ├── optimizers.py │ └── utils.py │ └── utils │ ├── GAN │ ├── attention.py │ ├── discriminator.py │ ├── distribution.py │ ├── loss.py │ ├── mobilenet.py │ ├── network_util.py │ ├── util.py │ └── vae.py │ ├── __init__.py │ ├── base.py │ ├── callbacks.py │ ├── config.py │ ├── misc.py │ ├── ops.py │ ├── perceptual │ ├── __init__.py │ ├── perceptual.py │ └── utils.py │ ├── rasterize.py │ ├── saving.py │ └── typing.py ├── LICENSE-CODE ├── LICENSE-MODEL ├── README.md ├── assets ├── main_framework.png └── teaser.png ├── configs ├── ablation │ ├── all_to_all_attn.yaml │ ├── camera_extrinsic.yaml │ ├── no_image_attn.yaml │ ├── query_emb_scale_2.yaml │ ├── query_emb_scale_half.yaml │ └── rayo_rayd.yaml ├── base.yaml ├── base_sd15.yaml ├── num_queries │ ├── 15_queries.yaml │ ├── 16_queries.yaml │ ├── 1_queries.yaml │ ├── 2_queries.yaml │ ├── 4_queries.yaml │ └── 6_queries.yaml └── options │ ├── controlnet_canny.yaml │ ├── controlnet_canny_depth.yaml │ ├── controlnet_canny_hed.yaml │ ├── controlnet_canny_hed_depth.yaml │ ├── controlnet_depth.yaml │ ├── controlnet_hed.yaml │ ├── controlnet_hed_depth.yaml │ ├── 
full_finetune.yaml │ ├── lora_blueresin.yaml │ ├── lora_cofzee.yaml │ ├── lora_friedegg.yaml │ ├── lora_gelato.yaml │ ├── lora_gemstone.yaml │ ├── lora_watce.yaml │ ├── lora_wood.yaml │ └── sd15.yaml ├── data └── DejaVuSans.ttf ├── demo.py ├── licenses ├── LICENSE_DPT ├── LICENSE_SD ├── LICENSE_SD_MODEL └── LICENSE_SD_XL ├── main.py ├── requirements.txt ├── sample ├── deer.png ├── dolphine.png └── kunkun.png ├── scripts ├── convert_diffusers_to_original_stable_diffusion.py ├── convert_lora_safetensor_to_diffusers.py ├── eval_images.py ├── novel_view_sampling.py ├── objaverse_renderings_to_webdataset.py └── sampling.py └── sgm ├── __init__.py ├── data ├── __init__.py ├── cifar10.py ├── dataset.py ├── dirdataset.py ├── mnist.py ├── objaverse.py ├── single_image.py └── utils.py ├── geometry.py ├── inference ├── api.py └── helpers.py ├── lr_scheduler.py ├── models ├── __init__.py ├── autoencoder.py ├── diffusion.py └── nvsadapter.py ├── modules ├── __init__.py ├── attention.py ├── autoencoding │ ├── __init__.py │ ├── losses │ │ └── __init__.py │ ├── lpips │ │ ├── __init__.py │ │ ├── loss │ │ │ ├── .gitignore │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── lpips.py │ │ ├── model │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── util.py │ │ └── vqperceptual.py │ └── regularizers │ │ └── __init__.py ├── diffusionmodules │ ├── __init__.py │ ├── denoiser.py │ ├── denoiser_scaling.py │ ├── denoiser_weighting.py │ ├── discretizer.py │ ├── guiders.py │ ├── loss.py │ ├── model.py │ ├── openaimodel.py │ ├── sampling.py │ ├── sampling_utils.py │ ├── sigma_sampling.py │ ├── util.py │ └── wrappers.py ├── distributions │ ├── __init__.py │ └── distributions.py ├── ema.py ├── encoders │ ├── __init__.py │ └── modules.py └── nvsadapter │ ├── __init__.py │ ├── canny │ └── api.py │ ├── conditioner.py │ ├── controlnet.py │ ├── hed │ └── api.py │ ├── lora │ ├── __init__.py │ ├── lora.py │ ├── safe_open.py │ └── utils.py │ ├── midas │ ├── __init__.py │ ├── api.py │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ └── utils.py │ ├── threedim.py │ └── wrappers.py └── util.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Run black 2 | on: [pull_request] 3 | 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v3 9 | - name: Install venv 10 | run: | 11 | sudo apt-get -y install python3.10-venv 12 | - uses: psf/black@stable 13 | with: 14 | options: "--check --verbose -l88" 15 | src: "./sgm ./scripts ./main.py" 16 | -------------------------------------------------------------------------------- /.github/workflows/test-build.yaml: -------------------------------------------------------------------------------- 1 | name: Build package 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | name: Build 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.8", "3.10"] 16 | requirements-file: ["pt2", "pt13"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements/${{ matrix.requirements-file }}.txt 27 | pip install . 
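The two workflows above lint the `sgm`, `scripts`, and `main.py` sources with black (check-only, 88-character line length) and verify that the package installs against both requirement sets. A contributor could reproduce the lint step locally; the sketch below is only an assumption about local usage and presumes `black` is installed in the active environment and the command is run from the repository root.

```python
import subprocess

# mirror .github/workflows/black.yml: check-only run, 88-char lines,
# restricted to the same sources the CI job inspects
subprocess.run(
    ["black", "--check", "--verbose", "-l88", "./sgm", "./scripts", "./main.py"],
    check=True,  # raise CalledProcessError if any file would be reformatted
)
```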
-------------------------------------------------------------------------------- /.github/workflows/test-inference.yml: -------------------------------------------------------------------------------- 1 | name: Test inference 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | test: 11 | name: "Test inference" 12 | # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment 13 | if: github.repository == 'stability-ai/generative-models' 14 | runs-on: [self-hosted, slurm, g40] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: "Symlink checkpoints" 18 | run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints 19 | - name: "Setup python" 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | - name: "Install Hatch" 24 | run: pip install hatch 25 | - name: "Run inference tests" 26 | run: hatch run ci:test-inference --junit-xml test-results.xml 27 | - name: Surface failing tests 28 | if: always() 29 | uses: pmeier/pytest-results-action@main 30 | with: 31 | path: test-results.xml 32 | summary: true 33 | display-options: fEX 34 | fail-on-empty: true 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # extensions 2 | *.egg-info 3 | *.py[cod] 4 | 5 | # envs 6 | .pt13 7 | .pt2 8 | 9 | # directories 10 | /checkpoints 11 | /dist 12 | /outputs 13 | /build 14 | /src 15 | .ipynb_checkpoints 16 | 17 | logs/ 18 | logs_viz/ 19 | logs_sampling/ 20 | lightning_logs/ 21 | logs_sampling_control 22 | notebooks/* 23 | 24 | sample/* 25 | bash_scripts/* 26 | data/* 27 | .vscode/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/carvekit"] 2 | path = thirdparty/carvekit 3 | url = https://github.com/OPHoperHPO/image-background-remove-tool.git 4 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /3drec/load/images/a_beautiful_rainbow_fish_512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_beautiful_rainbow_fish_512.png -------------------------------------------------------------------------------- /3drec/load/images/a_beautiful_rainbow_fish_512_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_beautiful_rainbow_fish_512_rgba.png -------------------------------------------------------------------------------- /3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama.png -------------------------------------------------------------------------------- 
/3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama2.png -------------------------------------------------------------------------------- /3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama2_rgba.png -------------------------------------------------------------------------------- /3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/images/a_cozy_sofa_in_the_shape_of_a_llama_rgba.png -------------------------------------------------------------------------------- /3drec/load/lights/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky 2 | CC0 License. 3 | -------------------------------------------------------------------------------- /3drec/load/lights/bsdf_256_256.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/lights/bsdf_256_256.bin -------------------------------------------------------------------------------- /3drec/load/lights/mud_road_puresky_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/lights/mud_road_puresky_1k.hdr -------------------------------------------------------------------------------- /3drec/load/shapes/README.md: -------------------------------------------------------------------------------- 1 | # Shape Credits 2 | 3 | - `animal.obj` - Ido Richardson 4 | - `hand_prismatic.obj` - Ido Richardson 5 | - `potion.obj` - Ido Richardson 6 | - `blub.obj` - [Keenan's 3D Model Repository](https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/) 7 | - `nascar.obj` - [Princeton ModelNet](https://modelnet.cs.princeton.edu/) 8 | - `cabin.obj` - [Princeton ModelNet](https://modelnet.cs.princeton.edu/) 9 | - `teddy.obj` - [Gal Metzer](https://galmetzer.github.io/) 10 | - `human.obj` - [TurboSquid](https://www.turbosquid.com/3d-models/3d-model-character-base/524860) 11 | -------------------------------------------------------------------------------- /3drec/load/tets/128_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/tets/128_tets.npz -------------------------------------------------------------------------------- /3drec/load/tets/32_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/tets/32_tets.npz 
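The `128_tets.npz` / `32_tets.npz` archives (and `64_tets.npz` below) are pre-built tetrahedral grids, presumably consumed by the tetrahedra-based geometry (`tetrahedra_sdf_grid.py` / `isosurface.py`); as `generate_tets.py` further below shows, each archive stores a `vertices` array and an `indices` array written with `np.savez_compressed`. A minimal sketch for inspecting one of them, assuming it is run from the repository root:

```python
import numpy as np

# keys match what generate_tets.py writes:
# np.savez_compressed(npzfile, vertices=vertices, indices=indices)
tets = np.load("3drec/load/tets/32_tets.npz")
vertices = tets["vertices"]  # float array, one xyz position per grid vertex
indices = tets["indices"]    # int array, four vertex indices per tetrahedron
print(vertices.shape, indices.shape)
```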
-------------------------------------------------------------------------------- /3drec/load/tets/64_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/3drec/load/tets/64_tets.npz -------------------------------------------------------------------------------- /3drec/load/tets/generate_tets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | 12 | import numpy as np 13 | 14 | """ 15 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 16 | to generate a tet grid 17 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 18 | 2) Run the function below to generate a file `cube_32_tet.tet` 19 | """ 20 | 21 | 22 | def generate_tetrahedron_grid_file(res=32, root=".."): 23 | frac = 1.0 / res 24 | command = f"cd {root}; ./quartet meshes/cube.obj {frac} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj" 25 | os.system(command) 26 | 27 | 28 | """ 29 | This code segment shows how to convert from a quartet .tet file to compressed npz file 30 | """ 31 | 32 | 33 | def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"): 34 | file1 = open(quartetfile, "r") 35 | header = file1.readline() 36 | numvertices = int(header.split(" ")[1]) 37 | numtets = int(header.split(" ")[2]) 38 | print(numvertices, numtets) 39 | 40 | # load vertices 41 | vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices) 42 | print(vertices.shape) 43 | 44 | # load indices 45 | indices = np.loadtxt( 46 | quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets 47 | ) 48 | print(indices.shape) 49 | 50 | np.savez_compressed(npzfile, vertices=vertices, indices=indices) 51 | 52 | 53 | root = "/home/gyc/quartet" 54 | for res in [300, 350, 400]: 55 | generate_tetrahedron_grid_file(res, root) 56 | convert_from_quartet_to_npz( 57 | os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets" 58 | ) 59 | -------------------------------------------------------------------------------- /3drec/requirements.txt: -------------------------------------------------------------------------------- 1 | lightning==2.0.0 2 | omegaconf==2.3.0 3 | jaxtyping 4 | typeguard 5 | git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 6 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 7 | diffusers 8 | transformers==4.28.1 9 | accelerate 10 | opencv-python 11 | tensorboard 12 | matplotlib 13 | imageio>=2.28.0 14 | imageio[ffmpeg] 15 | git+https://github.com/NVlabs/nvdiffrast.git 16 | libigl 17 | xatlas 18 | trimesh[easy] 19 | networkx 20 | pysdf 21 | PyMCubes 22 | wandb 23 | gradio 24 | git+https://github.com/ashawkey/envlight.git 25 | torchmetrics 26 | 27 | # deepfloyd 28 | xformers 29 | bitsandbytes==0.38.1 30 | sentencepiece 31 | safetensors 32 | huggingface_hub 33 | 34 | # for 
zero123 35 | einops 36 | kornia 37 | taming-transformers-rom1504 38 | git+https://github.com/openai/CLIP.git 39 | 40 | #controlnet 41 | controlnet_aux 42 | 43 | # nvs-adapter 44 | open-clip-torch>=2.20.0 45 | -------------------------------------------------------------------------------- /3drec/threestudio/__init__.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | __modules__[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def find(name): 13 | return __modules__[name] 14 | 15 | 16 | ### syntactic sugar for logging utilities ### 17 | import logging 18 | 19 | logger = logging.getLogger("pytorch_lightning") 20 | 21 | from pytorch_lightning.utilities.rank_zero import ( 22 | rank_zero_debug, 23 | rank_zero_info, 24 | rank_zero_only, 25 | ) 26 | 27 | debug = rank_zero_debug 28 | info = rank_zero_info 29 | 30 | 31 | @rank_zero_only 32 | def warn(*args, **kwargs): 33 | logger.warning(*args, **kwargs) 34 | 35 | 36 | from . import data, models, systems 37 | -------------------------------------------------------------------------------- /3drec/threestudio/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import co3d, image, multiview, uncond 2 | -------------------------------------------------------------------------------- /3drec/threestudio/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | background, 3 | exporters, 4 | geometry, 5 | guidance, 6 | materials, 7 | prompt_processors, 8 | renderers, 9 | ) 10 | -------------------------------------------------------------------------------- /3drec/threestudio/models/background/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | neural_environment_map_background, 4 | solid_color_background, 5 | textured_background, 6 | ) 7 | -------------------------------------------------------------------------------- /3drec/threestudio/models/background/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseBackground(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | 20 | def configure(self): 21 | pass 22 | 23 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /3drec/threestudio/models/background/neural_environment_map_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-environment-map-background") 16 | class NeuralEnvironmentMapBackground(BaseBackground): 17 | @dataclass 18 | class Config(BaseBackground.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "VanillaMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | random_aug: bool = False 33 | random_aug_prob: float = 0.5 34 | eval_color: Optional[Tuple[float, float, float]] = None 35 | 36 | cfg: Config 37 | 38 | def configure(self) -> None: 39 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 40 | self.network = get_mlp( 41 | self.encoding.n_output_dims, 42 | self.cfg.n_output_dims, 43 | self.cfg.mlp_network_config, 44 | ) 45 | 46 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 47 | if not self.training and self.cfg.eval_color is not None: 48 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 49 | dirs 50 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs) 51 | # viewdirs must be normalized before passing to this function 52 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 53 | dirs_embd = self.encoding(dirs.view(-1, 3)) 54 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) 55 | color = get_activation(self.cfg.color_activation)(color) 56 | if ( 57 | self.training 58 | and self.cfg.random_aug 59 | and random.random() < self.cfg.random_aug_prob 60 | ): 61 | # use random background color with probability random_aug_prob 62 | color = color * 0 + ( # prevent checking for unused parameters in DDP 63 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 64 | .to(dirs) 65 | .expand(*dirs.shape[:-1], -1) 66 | ) 67 | return color 68 | 
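`NeuralEnvironmentMapBackground` above maps normalized view directions through a spherical-harmonics encoding and a small MLP to a per-pixel background color, with optional random-color augmentation during training. Because it is registered under `"neural-environment-map-background"`, other components retrieve it via `threestudio.find`. The sketch below shows one way it could be queried standalone; it assumes `BaseModule` accepts a plain dict config (as in upstream threestudio), that tiny-cuda-nn is installed, and that a CUDA device is available — the config values are illustrative only.

```python
import torch
import torch.nn.functional as F

import threestudio  # importing the package registers all decorated modules

# look up the class by its registry name and build it from an illustrative config;
# in the real pipeline the config comes from the experiment YAML
Background = threestudio.find("neural-environment-map-background")
background = Background({"n_output_dims": 3, "random_aug": False}).cuda()

# the forward pass expects normalized view directions of shape (B, H, W, 3)
dirs = F.normalize(torch.randn(1, 64, 64, 3, device="cuda"), dim=-1)
color = background(dirs)  # (1, 64, 64, 3), sigmoid-activated colors in (0, 1)
```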
-------------------------------------------------------------------------------- /3drec/threestudio/models/background/solid_color_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("solid-color-background") 14 | class SolidColorBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | color: Tuple = (1.0, 1.0, 1.0) 19 | learned: bool = False 20 | random_aug: bool = False 21 | random_aug_prob: float = 0.5 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.env_color: Float[Tensor, "Nc"] 27 | if self.cfg.learned: 28 | self.env_color = nn.Parameter( 29 | torch.as_tensor(self.cfg.color, dtype=torch.float32) 30 | ) 31 | else: 32 | self.register_buffer( 33 | "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32) 34 | ) 35 | 36 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 37 | color = ( 38 | torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) 39 | * self.env_color 40 | ) 41 | if ( 42 | self.training 43 | and self.cfg.random_aug 44 | and random.random() < self.cfg.random_aug_prob 45 | ): 46 | # use random background color with probability random_aug_prob 47 | color = color * 0 + ( # prevent checking for unused parameters in DDP 48 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 49 | .to(dirs) 50 | .expand(*dirs.shape[:-1], -1) 51 | ) 52 | return color 53 | -------------------------------------------------------------------------------- /3drec/threestudio/models/background/textured_background.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.utils.ops import get_activation 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("textured-background") 14 | class TexturedBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | height: int = 64 19 | width: int = 64 20 | color_activation: str = "sigmoid" 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | self.texture = nn.Parameter( 26 | torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width)) 27 | ) 28 | 29 | def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]: 30 | x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2] 31 | xy = (x**2 + y**2) ** 0.5 32 | u = torch.atan2(xy, z) / torch.pi 33 | v = torch.atan2(y, x) / (torch.pi * 2) + 0.5 34 | uv = torch.stack([u, v], -1) 35 | return uv 36 | 37 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: 38 | dirs_shape = dirs.shape[:-1] 39 | uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1])) 40 | uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample 41 | uv = uv.reshape(1, -1, 1, 2) 42 | color = ( 43 | F.grid_sample( 44 | self.texture, 45 | uv, 46 | mode="bilinear", 47 | padding_mode="reflection", 48 | align_corners=False, 49 | ) 50 | 
.reshape(self.cfg.n_output_dims, -1) 51 | .T.reshape(*dirs_shape, self.cfg.n_output_dims) 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /3drec/threestudio/models/estimators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | try: 4 | from typing import Literal 5 | except ImportError: 6 | from typing_extensions import Literal 7 | 8 | import torch 9 | from nerfacc.data_specs import RayIntervals 10 | from nerfacc.estimators.base import AbstractEstimator 11 | from nerfacc.pdf import importance_sampling, searchsorted 12 | from nerfacc.volrend import render_transmittance_from_density 13 | from torch import Tensor 14 | 15 | 16 | class ImportanceEstimator(AbstractEstimator): 17 | def __init__( 18 | self, 19 | ) -> None: 20 | super().__init__() 21 | 22 | @torch.no_grad() 23 | def sampling( 24 | self, 25 | prop_sigma_fns: List[Callable], 26 | prop_samples: List[int], 27 | num_samples: int, 28 | # rendering options 29 | n_rays: int, 30 | near_plane: float, 31 | far_plane: float, 32 | sampling_type: Literal["uniform", "lindisp"] = "uniform", 33 | # training options 34 | stratified: bool = False, 35 | requires_grad: bool = False, 36 | ) -> Tuple[Tensor, Tensor]: 37 | """Sampling with CDFs from proposal networks. 38 | 39 | Args: 40 | prop_sigma_fns: Proposal network evaluate functions. It should be a list 41 | of functions that take in samples {t_starts (n_rays, n_samples), 42 | t_ends (n_rays, n_samples)} and returns the post-activation densities 43 | (n_rays, n_samples). 44 | prop_samples: Number of samples to draw from each proposal network. Should 45 | be the same length as `prop_sigma_fns`. 46 | num_samples: Number of samples to draw in the end. 47 | n_rays: Number of rays. 48 | near_plane: Near plane. 49 | far_plane: Far plane. 50 | sampling_type: Sampling type. Either "uniform" or "lindisp". Default to 51 | "lindisp". 52 | stratified: Whether to use stratified sampling. Default to `False`. 53 | 54 | Returns: 55 | A tuple of {Tensor, Tensor}: 56 | 57 | - **t_starts**: The starts of the samples. Shape (n_rays, num_samples). 58 | - **t_ends**: The ends of the samples. Shape (n_rays, num_samples). 59 | 60 | """ 61 | assert len(prop_sigma_fns) == len(prop_samples), ( 62 | "The number of proposal networks and the number of samples " 63 | "should be the same." 
64 | ) 65 | cdfs = torch.cat( 66 | [ 67 | torch.zeros((n_rays, 1), device=self.device), 68 | torch.ones((n_rays, 1), device=self.device), 69 | ], 70 | dim=-1, 71 | ) 72 | intervals = RayIntervals(vals=cdfs) 73 | 74 | for level_fn, level_samples in zip(prop_sigma_fns, prop_samples): 75 | intervals, _ = importance_sampling( 76 | intervals, cdfs, level_samples, stratified 77 | ) 78 | t_vals = _transform_stot( 79 | sampling_type, intervals.vals, near_plane, far_plane 80 | ) 81 | t_starts = t_vals[..., :-1] 82 | t_ends = t_vals[..., 1:] 83 | 84 | with torch.set_grad_enabled(requires_grad): 85 | sigmas = level_fn(t_starts, t_ends) 86 | assert sigmas.shape == t_starts.shape 87 | trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas) 88 | cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1) 89 | 90 | intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified) 91 | t_vals_fine = _transform_stot( 92 | sampling_type, intervals.vals, near_plane, far_plane 93 | ) 94 | 95 | t_vals = torch.cat([t_vals, t_vals_fine], dim=-1) 96 | t_vals, _ = torch.sort(t_vals, dim=-1) 97 | 98 | t_starts_ = t_vals[..., :-1] 99 | t_ends_ = t_vals[..., 1:] 100 | 101 | return t_starts_, t_ends_ 102 | 103 | 104 | def _transform_stot( 105 | transform_type: Literal["uniform", "lindisp"], 106 | s_vals: torch.Tensor, 107 | t_min: torch.Tensor, 108 | t_max: torch.Tensor, 109 | ) -> torch.Tensor: 110 | if transform_type == "uniform": 111 | _contract_fn, _icontract_fn = lambda x: x, lambda x: x 112 | elif transform_type == "lindisp": 113 | _contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x 114 | else: 115 | raise ValueError(f"Unknown transform_type: {transform_type}") 116 | s_min, s_max = _contract_fn(t_min), _contract_fn(t_max) 117 | icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min) 118 | return icontract_fn(s_vals) 119 | -------------------------------------------------------------------------------- /3drec/threestudio/models/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base, mesh_exporter 2 | -------------------------------------------------------------------------------- /3drec/threestudio/models/exporters/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import threestudio 4 | from threestudio.models.background.base import BaseBackground 5 | from threestudio.models.geometry.base import BaseImplicitGeometry 6 | from threestudio.models.materials.base import BaseMaterial 7 | from threestudio.utils.base import BaseObject 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @dataclass 12 | class ExporterOutput: 13 | save_name: str 14 | save_type: str 15 | params: Dict[str, Any] 16 | 17 | 18 | class Exporter(BaseObject): 19 | @dataclass 20 | class Config(BaseObject.Config): 21 | save_video: bool = False 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | @dataclass 32 | class SubModules: 33 | geometry: BaseImplicitGeometry 34 | material: BaseMaterial 35 | background: BaseBackground 36 | 37 | self.sub_modules = SubModules(geometry, material, background) 38 | 39 | @property 40 | def geometry(self) -> BaseImplicitGeometry: 41 | return self.sub_modules.geometry 42 | 43 | @property 44 | def material(self) -> BaseMaterial: 45 | return self.sub_modules.material 46 | 47 | @property 48 | def background(self) -> BaseBackground: 49 | return self.sub_modules.background 50 | 51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 52 | raise NotImplementedError 53 | 54 | 55 | @threestudio.register("dummy-exporter") 56 | class DummyExporter(Exporter): 57 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 58 | # DummyExporter does not export anything 59 | return [] 60 | -------------------------------------------------------------------------------- /3drec/threestudio/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | custom_mesh, 4 | implicit_sdf, 5 | implicit_volume, 6 | tetrahedra_sdf_grid, 7 | volume_grid, 8 | ) 9 | -------------------------------------------------------------------------------- /3drec/threestudio/models/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | nvsadapter_guidance, 3 | ) 4 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | diffuse_with_point_light_material, 4 | hybrid_rgb_latent_material, 5 | neural_radiance_material, 6 | no_material, 7 | pbr_material, 8 | sd_latent_adapter_material, 9 | ) 10 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseMaterial(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | requires_normal: bool = False 20 | requires_tangent: bool = False 21 | 22 | def configure(self): 23 | pass 24 | 25 | def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: 26 | raise NotImplementedError 27 | 28 | def export(self, *args, **kwargs) -> Dict[str, Any]: 29 | return {} 30 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/diffuse_with_point_light_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.ops import dot, get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("diffuse-with-point-light-material") 15 | class DiffuseWithPointLightMaterial(BaseMaterial): 16 | @dataclass 17 | class Config(BaseMaterial.Config): 18 | ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) 19 | diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) 20 | ambient_only_steps: int = 1000 21 | diffuse_prob: float = 0.75 22 | textureless_prob: float = 0.5 23 | albedo_activation: str = "sigmoid" 24 | soft_shading: bool = False 25 | 26 | cfg: Config 27 | 28 | def configure(self) -> None: 29 | self.requires_normal = True 30 | 31 | self.ambient_light_color: Float[Tensor, "3"] 32 | self.register_buffer( 33 | "ambient_light_color", 34 | torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), 35 | ) 36 | self.diffuse_light_color: Float[Tensor, "3"] 37 | self.register_buffer( 38 | "diffuse_light_color", 39 | torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), 40 | ) 41 | self.ambient_only = False 42 | 43 | def forward( 44 | self, 45 | features: Float[Tensor, "B ... Nf"], 46 | positions: Float[Tensor, "B ... 3"], 47 | shading_normal: Float[Tensor, "B ... 3"], 48 | light_positions: Float[Tensor, "B ... 3"], 49 | ambient_ratio: Optional[float] = None, 50 | shading: Optional[str] = None, 51 | **kwargs, 52 | ) -> Float[Tensor, "B ... 
3"]: 53 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) 54 | 55 | if ambient_ratio is not None: 56 | # if ambient ratio is specified, use it 57 | diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( 58 | self.diffuse_light_color 59 | ) 60 | ambient_light_color = ambient_ratio * torch.ones_like( 61 | self.ambient_light_color 62 | ) 63 | elif self.training and self.cfg.soft_shading: 64 | # otherwise if in training and soft shading is enabled, random a ambient ratio 65 | diffuse_light_color = torch.full_like( 66 | self.diffuse_light_color, random.random() 67 | ) 68 | ambient_light_color = 1.0 - diffuse_light_color 69 | else: 70 | # otherwise use the default fixed values 71 | diffuse_light_color = self.diffuse_light_color 72 | ambient_light_color = self.ambient_light_color 73 | 74 | light_directions: Float[Tensor, "B ... 3"] = F.normalize( 75 | light_positions - positions, dim=-1 76 | ) 77 | diffuse_light: Float[Tensor, "B ... 3"] = ( 78 | dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color 79 | ) 80 | textureless_color = diffuse_light + ambient_light_color 81 | # clamp albedo to [0, 1] to compute shading 82 | color = albedo.clamp(0.0, 1.0) * textureless_color 83 | 84 | if shading is None: 85 | if self.training: 86 | # adopt the same type of augmentation for the whole batch 87 | if self.ambient_only or random.random() > self.cfg.diffuse_prob: 88 | shading = "albedo" 89 | elif random.random() < self.cfg.textureless_prob: 90 | shading = "textureless" 91 | else: 92 | shading = "diffuse" 93 | else: 94 | if self.ambient_only: 95 | shading = "albedo" 96 | else: 97 | # return shaded color by default in evaluation 98 | shading = "diffuse" 99 | 100 | # multiply by 0 to prevent checking for unused parameters in DDP 101 | if shading == "albedo": 102 | return albedo + textureless_color * 0 103 | elif shading == "textureless": 104 | return albedo * 0 + textureless_color 105 | elif shading == "diffuse": 106 | return color 107 | else: 108 | raise ValueError(f"Unknown shading type {shading}") 109 | 110 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 111 | if global_step < self.cfg.ambient_only_steps: 112 | self.ambient_only = True 113 | else: 114 | self.ambient_only = False 115 | 116 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 117 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( 118 | 0.0, 1.0 119 | ) 120 | return {"albedo": albedo} 121 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/hybrid_rgb_latent_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("hybrid-rgb-latent-material") 16 | class HybridRGBLatentMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | requires_normal: bool = True 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.requires_normal = self.cfg.requires_normal 27 | 28 
| def forward( 29 | self, features: Float[Tensor, "B ... Nf"], **kwargs 30 | ) -> Float[Tensor, "B ... Nc"]: 31 | assert ( 32 | features.shape[-1] == self.cfg.n_output_dims 33 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 34 | color = features 35 | color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3]) 36 | return color 37 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/neural_radiance_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-radiance-material") 16 | class NeuralRadianceMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | input_feature_dims: int = 8 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "FullyFusedMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | 33 | cfg: Config 34 | 35 | def configure(self) -> None: 36 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 37 | self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore 38 | self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) 39 | 40 | def forward( 41 | self, 42 | features: Float[Tensor, "*B Nf"], 43 | viewdirs: Float[Tensor, "*B 3"], 44 | **kwargs, 45 | ) -> Float[Tensor, "*B 3"]: 46 | # viewdirs and normals must be normalized before passing to this function 47 | viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 48 | viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) 49 | network_inp = torch.cat( 50 | [features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 51 | ) 52 | color = self.network(network_inp).view(*features.shape[:-1], 3) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/no_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("no-material") 16 | class NoMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | input_feature_dims: Optional[int] = None 22 | mlp_network_config: Optional[dict] = None 23 | requires_normal: bool = False 24 | 25 | cfg: Config 26 | 27 | def configure(self) -> None: 28 | self.use_network = False 
29 | if ( 30 | self.cfg.input_feature_dims is not None 31 | and self.cfg.mlp_network_config is not None 32 | ): 33 | self.network = get_mlp( 34 | self.cfg.input_feature_dims, 35 | self.cfg.n_output_dims, 36 | self.cfg.mlp_network_config, 37 | ) 38 | self.use_network = True 39 | self.requires_normal = self.cfg.requires_normal 40 | 41 | def forward( 42 | self, features: Float[Tensor, "B ... Nf"], **kwargs 43 | ) -> Float[Tensor, "B ... Nc"]: 44 | if not self.use_network: 45 | assert ( 46 | features.shape[-1] == self.cfg.n_output_dims 47 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 48 | color = get_activation(self.cfg.color_activation)(features) 49 | else: 50 | color = self.network(features.view(-1, features.shape[-1])).view( 51 | *features.shape[:-1], self.cfg.n_output_dims 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | 56 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 57 | color = self(features, **kwargs).clamp(0, 1) 58 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels" 59 | if color.shape[-1] > 3: 60 | threestudio.warn( 61 | "Output color has >3 channels, treating the first 3 as RGB" 62 | ) 63 | return {"albedo": color[..., :3]} 64 | -------------------------------------------------------------------------------- /3drec/threestudio/models/materials/sd_latent_adapter_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("sd-latent-adapter-material") 14 | class StableDiffusionLatentAdapterMaterial(BaseMaterial): 15 | @dataclass 16 | class Config(BaseMaterial.Config): 17 | pass 18 | 19 | cfg: Config 20 | 21 | def configure(self) -> None: 22 | adapter = nn.Parameter( 23 | torch.as_tensor( 24 | [ 25 | # R G B 26 | [0.298, 0.207, 0.208], # L1 27 | [0.187, 0.286, 0.173], # L2 28 | [-0.158, 0.189, 0.264], # L3 29 | [-0.184, -0.271, -0.473], # L4 30 | ] 31 | ) 32 | ) 33 | self.register_parameter("adapter", adapter) 34 | 35 | def forward( 36 | self, features: Float[Tensor, "B ... 4"], **kwargs 37 | ) -> Float[Tensor, "B ... 3"]: 38 | assert features.shape[-1] == 4 39 | color = features @ self.adapter 40 | color = (color + 1) / 2 41 | color = color.clamp(0.0, 1.0) 42 | return color 43 | -------------------------------------------------------------------------------- /3drec/threestudio/models/prompt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | deepfloyd_prompt_processor, 4 | dummy_prompt_processor, 5 | stable_diffusion_prompt_processor, 6 | ) 7 | -------------------------------------------------------------------------------- /3drec/threestudio/models/prompt_processors/deepfloyd_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from diffusers import IFPipeline 8 | from transformers import T5EncoderModel, T5Tokenizer 9 | 10 | import threestudio 11 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 12 | from threestudio.utils.misc import cleanup 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("deep-floyd-prompt-processor") 17 | class DeepFloydPromptProcessor(PromptProcessor): 18 | @dataclass 19 | class Config(PromptProcessor.Config): 20 | pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" 21 | 22 | cfg: Config 23 | 24 | ### these functions are unused, kept for debugging ### 25 | def configure_text_encoder(self) -> None: 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | self.text_encoder = T5EncoderModel.from_pretrained( 28 | self.cfg.pretrained_model_name_or_path, 29 | subfolder="text_encoder", 30 | load_in_8bit=True, 31 | variant="8bit", 32 | device_map="auto", 33 | ) # FIXME: behavior of auto device map in multi-GPU training 34 | self.pipe = IFPipeline.from_pretrained( 35 | self.cfg.pretrained_model_name_or_path, 36 | text_encoder=self.text_encoder, # pass the previously instantiated 8bit text encoder 37 | unet=None, 38 | ) 39 | 40 | def destroy_text_encoder(self) -> None: 41 | del self.text_encoder 42 | del self.pipe 43 | cleanup() 44 | 45 | def get_text_embeddings( 46 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 47 | ) -> Tuple[Float[Tensor, "B 77 4096"], Float[Tensor, "B 77 4096"]]: 48 | text_embeddings, uncond_text_embeddings = self.pipe.encode_prompt( 49 | prompt=prompt, negative_prompt=negative_prompt, device=self.device 50 | ) 51 | return text_embeddings, uncond_text_embeddings 52 | 53 | ### 54 | 55 | @staticmethod 56 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 57 | max_length = 77 58 | tokenizer = T5Tokenizer.from_pretrained( 59 | pretrained_model_name_or_path, subfolder="tokenizer" 60 | ) 61 | text_encoder = T5EncoderModel.from_pretrained( 62 | pretrained_model_name_or_path, 63 | subfolder="text_encoder", 64 | torch_dtype=torch.float16, # suppress warning 65 | load_in_8bit=True, 66 | variant="8bit", 67 | device_map="auto", 68 | ) 69 | with torch.no_grad(): 70 | text_inputs = tokenizer( 71 | prompts, 72 | padding="max_length", 73 | max_length=max_length, 74 | truncation=True, 75 | add_special_tokens=True, 76 | return_tensors="pt", 77 | ) 78 | text_input_ids = text_inputs.input_ids 79 | attention_mask = text_inputs.attention_mask 80 | text_embeddings = text_encoder( 81 | text_input_ids.to(text_encoder.device), 82 | attention_mask=attention_mask.to(text_encoder.device), 83 | ) 84 | text_embeddings = text_embeddings[0] 85 | 86 | for prompt, embedding in zip(prompts, text_embeddings): 87 | torch.save( 88 | embedding, 89 | os.path.join( 90 | cache_dir, 91 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 92 | ), 93 | ) 94 | 95 | del text_encoder 96 | -------------------------------------------------------------------------------- 
/3drec/threestudio/models/prompt_processors/dummy_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import threestudio 6 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 7 | from threestudio.utils.misc import cleanup 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @threestudio.register("dummy-prompt-processor") 12 | class DummyPromptProcessor(PromptProcessor): 13 | @dataclass 14 | class Config(PromptProcessor.Config): 15 | pretrained_model_name_or_path: str = "" 16 | prompt: str = "" 17 | 18 | cfg: Config 19 | -------------------------------------------------------------------------------- /3drec/threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoTokenizer, CLIPTextModel 8 | 9 | import threestudio 10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 11 | from threestudio.utils.misc import cleanup 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("stable-diffusion-prompt-processor") 16 | class StableDiffusionPromptProcessor(PromptProcessor): 17 | @dataclass 18 | class Config(PromptProcessor.Config): 19 | pass 20 | 21 | cfg: Config 22 | 23 | ### these functions are unused, kept for debugging ### 24 | def configure_text_encoder(self) -> None: 25 | self.tokenizer = AutoTokenizer.from_pretrained( 26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 27 | ) 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | self.text_encoder = CLIPTextModel.from_pretrained( 30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 31 | ).to(self.device) 32 | 33 | for p in self.text_encoder.parameters(): 34 | p.requires_grad_(False) 35 | 36 | def destroy_text_encoder(self) -> None: 37 | del self.tokenizer 38 | del self.text_encoder 39 | cleanup() 40 | 41 | def get_text_embeddings( 42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 44 | if isinstance(prompt, str): 45 | prompt = [prompt] 46 | if isinstance(negative_prompt, str): 47 | negative_prompt = [negative_prompt] 48 | # Tokenize text and get embeddings 49 | tokens = self.tokenizer( 50 | prompt, 51 | padding="max_length", 52 | max_length=self.tokenizer.model_max_length, 53 | return_tensors="pt", 54 | ) 55 | uncond_tokens = self.tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | 62 | with torch.no_grad(): 63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] 64 | uncond_text_embeddings = self.text_encoder( 65 | uncond_tokens.input_ids.to(self.device) 66 | )[0] 67 | 68 | return text_embeddings, uncond_text_embeddings 69 | 70 | ### 71 | 72 | @staticmethod 73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 74 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | pretrained_model_name_or_path, subfolder="tokenizer" 77 | ) 78 | text_encoder = CLIPTextModel.from_pretrained( 79 | pretrained_model_name_or_path, 80 | subfolder="text_encoder", 81 | device_map="auto", 82 | ) 
83 | 84 | with torch.no_grad(): 85 | tokens = tokenizer( 86 | prompts, 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | return_tensors="pt", 90 | ) 91 | text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] 92 | 93 | for prompt, embedding in zip(prompts, text_embeddings): 94 | torch.save( 95 | embedding, 96 | os.path.join( 97 | cache_dir, 98 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 99 | ), 100 | ) 101 | 102 | del text_encoder 103 | -------------------------------------------------------------------------------- /3drec/threestudio/models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | deferred_volume_renderer, 4 | gan_volume_renderer, 5 | nerf_volume_renderer, 6 | neus_volume_renderer, 7 | nvdiff_rasterizer, 8 | patch_renderer, 9 | ) 10 | -------------------------------------------------------------------------------- /3drec/threestudio/models/renderers/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.utils.base import BaseModule 12 | from threestudio.utils.typing import * 13 | 14 | 15 | class Renderer(BaseModule): 16 | @dataclass 17 | class Config(BaseModule.Config): 18 | radius: float = 1.0 19 | 20 | cfg: Config 21 | 22 | def configure( 23 | self, 24 | geometry: BaseImplicitGeometry, 25 | material: BaseMaterial, 26 | background: BaseBackground, 27 | ) -> None: 28 | # keep references to submodules using namedtuple, avoid being registered as modules 29 | @dataclass 30 | class SubModules: 31 | geometry: BaseImplicitGeometry 32 | material: BaseMaterial 33 | background: BaseBackground 34 | 35 | self.sub_modules = SubModules(geometry, material, background) 36 | 37 | # set up bounding box 38 | self.bbox: Float[Tensor, "2 3"] 39 | self.register_buffer( 40 | "bbox", 41 | torch.as_tensor( 42 | [ 43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 45 | ], 46 | dtype=torch.float32, 47 | ), 48 | ) 49 | 50 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 51 | raise NotImplementedError 52 | 53 | @property 54 | def geometry(self) -> BaseImplicitGeometry: 55 | return self.sub_modules.geometry 56 | 57 | @property 58 | def material(self) -> BaseMaterial: 59 | return self.sub_modules.material 60 | 61 | @property 62 | def background(self) -> BaseBackground: 63 | return self.sub_modules.background 64 | 65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None: 66 | self.sub_modules.geometry = geometry 67 | 68 | def set_material(self, material: BaseMaterial) -> None: 69 | self.sub_modules.material = material 70 | 71 | def set_background(self, background: BaseBackground) -> None: 72 | self.sub_modules.background = background 73 | 74 | 75 | class VolumeRenderer(Renderer): 76 | pass 77 | 78 | 79 | class Rasterizer(Renderer): 80 | pass 81 | -------------------------------------------------------------------------------- /3drec/threestudio/models/renderers/deferred_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from 
dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.renderers.base import VolumeRenderer 8 | 9 | 10 | class DeferredVolumeRenderer(VolumeRenderer): 11 | pass 12 | -------------------------------------------------------------------------------- /3drec/threestudio/models/renderers/nvdiff_rasterizer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.models.renderers.base import Rasterizer, VolumeRenderer 12 | from threestudio.utils.misc import get_device 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("nvdiff-rasterizer") 18 | class NVDiffRasterizer(Rasterizer): 19 | @dataclass 20 | class Config(VolumeRenderer.Config): 21 | context_type: str = "gl" 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | super().configure(geometry, material, background) 32 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, get_device()) 33 | 34 | def forward( 35 | self, 36 | mvp_mtx: Float[Tensor, "B 4 4"], 37 | camera_positions: Float[Tensor, "B 3"], 38 | light_positions: Float[Tensor, "B 3"], 39 | height: int, 40 | width: int, 41 | render_rgb: bool = True, 42 | **kwargs 43 | ) -> Dict[str, Any]: 44 | batch_size = mvp_mtx.shape[0] 45 | mesh = self.geometry.isosurface() 46 | 47 | v_pos_clip: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( 48 | mesh.v_pos, mvp_mtx 49 | ) 50 | rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width)) 51 | mask = rast[..., 3:] > 0 52 | mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx) 53 | 54 | out = {"opacity": mask_aa, "mesh": mesh} 55 | 56 | gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx) 57 | gb_normal = F.normalize(gb_normal, dim=-1) 58 | gb_normal_aa = torch.lerp( 59 | torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float() 60 | ) 61 | gb_normal_aa = self.ctx.antialias( 62 | gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx 63 | ) 64 | out.update({"comp_normal": gb_normal_aa}) # in [0, 1] 65 | 66 | # TODO: make it clear whether to compute the normal, now we compute it in all cases 67 | # consider using: require_normal_computation = render_normal or (render_rgb and material.requires_normal) 68 | # or 69 | # render_normal = render_normal or (render_rgb and material.requires_normal) 70 | 71 | if render_rgb: 72 | selector = mask[..., 0] 73 | 74 | gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx) 75 | gb_viewdirs = F.normalize( 76 | gb_pos - camera_positions[:, None, None, :], dim=-1 77 | ) 78 | gb_light_positions = light_positions[:, None, None, :].expand( 79 | -1, height, width, -1 80 | ) 81 | 82 | positions = gb_pos[selector] 83 | geo_out = self.geometry(positions, output_normal=False) 84 | 85 | extra_geo_info = {} 86 | if self.material.requires_normal: 87 | extra_geo_info["shading_normal"] = gb_normal[selector] 88 | if self.material.requires_tangent: 89 | gb_tangent, _ 
= self.ctx.interpolate_one( 90 | mesh.v_tng, rast, mesh.t_pos_idx 91 | ) 92 | gb_tangent = F.normalize(gb_tangent, dim=-1) 93 | extra_geo_info["tangent"] = gb_tangent[selector] 94 | 95 | rgb_fg = self.material( 96 | viewdirs=gb_viewdirs[selector], 97 | positions=positions, 98 | light_positions=gb_light_positions[selector], 99 | **extra_geo_info, 100 | **geo_out 101 | ) 102 | gb_rgb_fg = torch.zeros(batch_size, height, width, 3).to(rgb_fg) 103 | gb_rgb_fg[selector] = rgb_fg 104 | 105 | gb_rgb_bg = self.background(dirs=gb_viewdirs) 106 | gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float()) 107 | gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) 108 | 109 | out.update({"comp_rgb": gb_rgb_aa, "comp_rgb_bg": gb_rgb_bg}) 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /3drec/threestudio/models/renderers/patch_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.background.base import BaseBackground 8 | from threestudio.models.geometry.base import BaseImplicitGeometry 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.renderers.base import VolumeRenderer 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("patch-renderer") 15 | class PatchRenderer(VolumeRenderer): 16 | @dataclass 17 | class Config(VolumeRenderer.Config): 18 | patch_size: int = 128 19 | base_renderer_type: str = "" 20 | base_renderer: Optional[VolumeRenderer.Config] = None 21 | global_detach: bool = False 22 | global_downsample: int = 4 23 | 24 | cfg: Config 25 | 26 | def configure( 27 | self, 28 | geometry: BaseImplicitGeometry, 29 | material: BaseMaterial, 30 | background: BaseBackground, 31 | ) -> None: 32 | self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( 33 | self.cfg.base_renderer, 34 | geometry=geometry, 35 | material=material, 36 | background=background, 37 | ) 38 | 39 | def forward( 40 | self, 41 | rays_o: Float[Tensor, "B H W 3"], 42 | rays_d: Float[Tensor, "B H W 3"], 43 | light_positions: Float[Tensor, "B 3"], 44 | bg_color: Optional[Tensor] = None, 45 | **kwargs 46 | ) -> Dict[str, Float[Tensor, "..."]]: 47 | B, H, W, _ = rays_o.shape 48 | 49 | if self.base_renderer.training: 50 | downsample = self.cfg.global_downsample 51 | global_rays_o = torch.nn.functional.interpolate( 52 | rays_o.permute(0, 3, 1, 2), 53 | (H // downsample, W // downsample), 54 | mode="bilinear", 55 | ).permute(0, 2, 3, 1) 56 | global_rays_d = torch.nn.functional.interpolate( 57 | rays_d.permute(0, 3, 1, 2), 58 | (H // downsample, W // downsample), 59 | mode="bilinear", 60 | ).permute(0, 2, 3, 1) 61 | out_global = self.base_renderer( 62 | global_rays_o, global_rays_d, light_positions, bg_color, **kwargs 63 | ) 64 | 65 | PS = self.cfg.patch_size 66 | patch_x = torch.randint(0, W - PS, (1,)).item() 67 | patch_y = torch.randint(0, H - PS, (1,)).item() 68 | patch_rays_o = rays_o[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 69 | patch_rays_d = rays_d[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 70 | out = self.base_renderer( 71 | patch_rays_o, patch_rays_d, light_positions, bg_color, **kwargs 72 | ) 73 | 74 | valid_patch_key = [] 75 | for key in out: 76 | if torch.is_tensor(out[key]): 77 | if len(out[key].shape) == len(out["comp_rgb"].shape): 78 | if out[key][..., 0].shape == 
out["comp_rgb"][..., 0].shape: 79 | valid_patch_key.append(key) 80 | for key in valid_patch_key: 81 | out_global[key] = F.interpolate( 82 | out_global[key].permute(0, 3, 1, 2), (H, W), mode="bilinear" 83 | ).permute(0, 2, 3, 1) 84 | if self.cfg.global_detach: 85 | out_global[key] = out_global[key].detach() 86 | out_global[key][ 87 | :, patch_y : patch_y + PS, patch_x : patch_x + PS 88 | ] = out[key] 89 | out = out_global 90 | else: 91 | out = self.base_renderer( 92 | rays_o, rays_d, light_positions, bg_color, **kwargs 93 | ) 94 | 95 | return out 96 | 97 | def update_step( 98 | self, epoch: int, global_step: int, on_load_weights: bool = False 99 | ) -> None: 100 | self.base_renderer.update_step(epoch, global_step, on_load_weights) 101 | 102 | def train(self, mode=True): 103 | return self.base_renderer.train(mode) 104 | 105 | def eval(self): 106 | return self.base_renderer.eval() 107 | -------------------------------------------------------------------------------- /3drec/threestudio/systems/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | nvsadapter, 3 | ) 4 | -------------------------------------------------------------------------------- /3drec/threestudio/systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import threestudio 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | 33 | 34 | def parse_optimizer(config, model): 35 | if hasattr(config, "params"): 36 | params = [ 37 | {"params": get_parameters(model, name), "name": name, **args} 38 | for name, args in config.params.items() 39 | ] 40 | threestudio.debug(f"Specify optimizer params: {config.params}") 41 | else: 42 | params = model.parameters() 43 | if config.name in ["FusedAdam"]: 44 | import apex 45 | 46 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 47 | elif config.name in ["Adan"]: 48 | from threestudio.systems import optimizers 49 | 50 | optim = getattr(optimizers, config.name)(params, **config.args) 51 | else: 52 | optim = getattr(torch.optim, config.name)(params, **config.args) 53 | return optim 54 | 55 | 56 | def parse_scheduler_to_instance(config, optimizer): 57 | if config.name == "ChainedScheduler": 58 | schedulers = [ 59 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 60 | ] 61 | scheduler = lr_scheduler.ChainedScheduler(schedulers) 62 | elif config.name == "Sequential": 63 | schedulers = [ 64 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 65 | ] 66 | scheduler = lr_scheduler.SequentialLR( 67 | optimizer, schedulers, milestones=config.milestones 68 | ) 69 | else: 70 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) 71 | return scheduler 72 | 73 | 74 | def parse_scheduler(config, optimizer): 75 | interval = config.get("interval", "epoch") 76 | assert interval in ["epoch", 
"step"] 77 | if config.name == "SequentialLR": 78 | scheduler = { 79 | "scheduler": lr_scheduler.SequentialLR( 80 | optimizer, 81 | [ 82 | parse_scheduler(conf, optimizer)["scheduler"] 83 | for conf in config.schedulers 84 | ], 85 | milestones=config.milestones, 86 | ), 87 | "interval": interval, 88 | } 89 | elif config.name == "ChainedScheduler": 90 | scheduler = { 91 | "scheduler": lr_scheduler.ChainedScheduler( 92 | [ 93 | parse_scheduler(conf, optimizer)["scheduler"] 94 | for conf in config.schedulers 95 | ] 96 | ), 97 | "interval": interval, 98 | } 99 | else: 100 | scheduler = { 101 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 102 | "interval": interval, 103 | } 104 | return scheduler 105 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/GAN/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 |     logvar1, logvar2 = [ 92 |         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 |         for x in (logvar1, logvar2) 94 |     ] 95 | 96 |     return 0.5 * ( 97 |         -1.0 98 |         + logvar2 99 |         - logvar1 100 |         + torch.exp(logvar1 - logvar2) 101 |         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 |     ) 103 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/GAN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def generator_loss(discriminator, inputs, reconstructions, cond=None): 6 |     if cond is None: 7 |         logits_fake = discriminator(reconstructions.contiguous()) 8 |     else: 9 |         logits_fake = discriminator( 10 |             torch.cat((reconstructions.contiguous(), cond), dim=1) 11 |         ) 12 |     g_loss = -torch.mean(logits_fake) 13 |     return g_loss 14 | 15 | 16 | def hinge_d_loss(logits_real, logits_fake): 17 |     loss_real = torch.mean(F.relu(1.0 - logits_real)) 18 |     loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 19 |     d_loss = 0.5 * (loss_real + loss_fake) 20 |     return d_loss 21 | 22 | 23 | def discriminator_loss(discriminator, inputs, reconstructions, cond=None): 24 |     if cond is None: 25 |         logits_real = discriminator(inputs.contiguous().detach()) 26 |         logits_fake = discriminator(reconstructions.contiguous().detach()) 27 |     else: 28 |         logits_real = discriminator( 29 |             torch.cat((inputs.contiguous().detach(), cond), dim=1) 30 |         ) 31 |         logits_fake = discriminator( 32 |             torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 33 |         ) 34 |     d_loss = hinge_d_loss(logits_real, logits_fake).mean() 35 |     return d_loss 36 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from threestudio.utils.config import parse_structured 7 | from threestudio.utils.misc import get_device, load_module_weights 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class Configurable: 12 |     @dataclass 13 |     class Config: 14 |         pass 15 | 16 |     def __init__(self, cfg: Optional[dict] = None) -> None: 17 |         super().__init__() 18 |         self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 |     def do_update_step( 23 |         self, epoch: int, global_step: int, on_load_weights: bool = False 24 |     ): 25 |         for attr in self.__dir__(): 26 |             if attr.startswith("_"): 27 |                 continue 28 |             try: 29 |                 module = getattr(self, attr) 30 |             except: 31 |                 continue  # ignore attributes like property, which can't be retrieved using getattr? 32 |             if isinstance(module, Updateable): 33 |                 module.do_update_step( 34 |                     epoch, global_step, on_load_weights=on_load_weights 35 |                 ) 36 |         self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 |     def do_update_step_end(self, epoch: int, global_step: int): 39 |         for attr in self.__dir__(): 40 |             if attr.startswith("_"): 41 |                 continue 42 |             try: 43 |                 module = getattr(self, attr) 44 |             except: 45 |                 continue  # ignore attributes like property, which can't be retrieved using getattr?
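            # recurse into child Updateable attributes first; this module's own update_step_end hook runs after the loop finishes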
46 |             if isinstance(module, Updateable): 47 |                 module.do_update_step_end(epoch, global_step) 48 |         self.update_step_end(epoch, global_step) 49 | 50 |     def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 51 |         # override this method to implement custom update logic 52 |         # if on_load_weights is True, you should be careful doing things related to model evaluations, 53 |         # as the models and tensors are not guaranteed to be on the same device 54 |         pass 55 | 56 |     def update_step_end(self, epoch: int, global_step: int): 57 |         pass 58 | 59 | 60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 61 |     if isinstance(module, Updateable): 62 |         module.do_update_step(epoch, global_step) 63 | 64 | 65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 66 |     if isinstance(module, Updateable): 67 |         module.do_update_step_end(epoch, global_step) 68 | 69 | 70 | class BaseObject(Updateable): 71 |     @dataclass 72 |     class Config: 73 |         pass 74 | 75 |     cfg: Config  # add this to every subclass of BaseObject to enable static type checking 76 | 77 |     def __init__( 78 |         self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 79 |     ) -> None: 80 |         super().__init__() 81 |         self.cfg = parse_structured(self.Config, cfg) 82 |         self.device = get_device() 83 |         self.configure(*args, **kwargs) 84 | 85 |     def configure(self, *args, **kwargs) -> None: 86 |         pass 87 | 88 | 89 | class BaseModule(nn.Module, Updateable): 90 |     @dataclass 91 |     class Config: 92 |         weights: Optional[str] = None 93 | 94 |     cfg: Config  # add this to every subclass of BaseModule to enable static type checking 95 | 96 |     def __init__( 97 |         self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 |     ) -> None: 99 |         super().__init__() 100 |         self.cfg = parse_structured(self.Config, cfg) 101 |         self.device = get_device() 102 |         self.configure(*args, **kwargs) 103 |         if self.cfg.weights is not None: 104 |             # format: path/to/weights:module_name 105 |             weights_path, module_name = self.cfg.weights.split(":") 106 |             state_dict, epoch, global_step = load_module_weights( 107 |                 weights_path, module_name=module_name, map_location="cpu" 108 |             ) 109 |             self.load_state_dict(state_dict) 110 |             self.do_update_step( 111 |                 epoch, global_step, on_load_weights=True 112 |             )  # restore states 113 |         # dummy tensor to indicate model state 114 |         self._dummy: Float[Tensor, "..."] 115 |         self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 116 | 117 |     def configure(self, *args, **kwargs) -> None: 118 |         pass 119 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | import threestudio 8 | from threestudio.utils.typing import * 9 | 10 | # ============ Register OmegaConf Resolvers ============= # 11 | OmegaConf.register_new_resolver( 12 |     "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20
| OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) 24 | OmegaConf.register_new_resolver("not", lambda s: not s) 25 | OmegaConf.register_new_resolver( 26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 27 | ) 28 | # ======================================================= # 29 | 30 | 31 | def C_max(value: Any) -> float: 32 | if isinstance(value, int) or isinstance(value, float): 33 | pass 34 | else: 35 | value = config_to_primitive(value) 36 | if not isinstance(value, list): 37 | raise TypeError("Scalar specification only supports list, got", type(value)) 38 | if len(value) == 3: 39 | value = [0] + value 40 | assert len(value) == 4 41 | start_step, start_value, end_value, end_step = value 42 | value = max(start_value, end_value) 43 | return value 44 | 45 | 46 | @dataclass 47 | class ExperimentConfig: 48 | name: str = "default" 49 | description: str = "" 50 | tag: str = "" 51 | seed: int = 0 52 | use_timestamp: bool = True 53 | timestamp: Optional[str] = None 54 | exp_root_dir: str = "outputs" 55 | 56 | ### these shouldn't be set manually 57 | exp_dir: str = "outputs/default" 58 | trial_name: str = "exp" 59 | trial_dir: str = "outputs/default/exp" 60 | n_gpus: int = 1 61 | ### 62 | 63 | resume: Optional[str] = None 64 | 65 | data_type: str = "" 66 | data: dict = field(default_factory=dict) 67 | 68 | system_type: str = "" 69 | system: dict = field(default_factory=dict) 70 | 71 | # accept pytorch-lightning trainer parameters 72 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 73 | trainer: dict = field(default_factory=dict) 74 | 75 | # accept pytorch-lightning checkpoint callback parameters 76 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 77 | checkpoint: dict = field(default_factory=dict) 78 | 79 | def __post_init__(self): 80 | if not self.tag and not self.use_timestamp: 81 | raise ValueError("Either tag is specified or use_timestamp is True.") 82 | self.trial_name = self.tag 83 | # if resume from an existing config, self.timestamp should not be None 84 | if self.timestamp is None: 85 | self.timestamp = "" 86 | if self.use_timestamp: 87 | if self.n_gpus > 1: 88 | threestudio.warn( 89 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 
90 | ) 91 | else: 92 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 93 | self.trial_name += self.timestamp 94 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 95 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 96 | os.makedirs(self.trial_dir, exist_ok=True) 97 | 98 | 99 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: 100 | if from_string: 101 | yaml_confs = [OmegaConf.create(s) for s in yamls] 102 | else: 103 | yaml_confs = [OmegaConf.load(f) for f in yamls] 104 | cli_conf = OmegaConf.from_cli(cli_args) 105 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 106 | OmegaConf.resolve(cfg) 107 | assert isinstance(cfg, DictConfig) 108 | scfg = parse_structured(ExperimentConfig, cfg) 109 | return scfg 110 | 111 | 112 | def config_to_primitive(config, resolve: bool = True) -> Any: 113 | return OmegaConf.to_container(config, resolve=resolve) 114 | 115 | 116 | def dump_config(path: str, config) -> None: 117 | with open(path, "w") as fp: 118 | OmegaConf.save(config=config, f=fp) 119 | 120 | 121 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 122 | scfg = OmegaConf.structured(fields(**cfg)) 123 | return scfg 124 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import re 4 | 5 | import tinycudann as tcnn 6 | import torch 7 | from packaging import version 8 | 9 | from threestudio.utils.config import config_to_primitive 10 | from threestudio.utils.typing import * 11 | 12 | 13 | def parse_version(ver: str): 14 | return version.parse(ver) 15 | 16 | 17 | def get_rank(): 18 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 19 | # therefore LOCAL_RANK needs to be checked first 20 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 21 | for key in rank_keys: 22 | rank = os.environ.get(key) 23 | if rank is not None: 24 | return int(rank) 25 | return 0 26 | 27 | 28 | def get_device(): 29 | return torch.device(f"cuda:{get_rank()}") 30 | 31 | 32 | def load_module_weights( 33 | path, module_name=None, ignore_modules=None, map_location=None 34 | ) -> Tuple[dict, int, int]: 35 | if module_name is not None and ignore_modules is not None: 36 | raise ValueError("module_name and ignore_modules cannot be both set") 37 | if map_location is None: 38 | map_location = get_device() 39 | 40 | ckpt = torch.load(path, map_location=map_location) 41 | state_dict = ckpt["state_dict"] 42 | state_dict_to_load = state_dict 43 | 44 | if ignore_modules is not None: 45 | state_dict_to_load = {} 46 | for k, v in state_dict.items(): 47 | ignore = any( 48 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 49 | ) 50 | if ignore: 51 | continue 52 | state_dict_to_load[k] = v 53 | 54 | if module_name is not None: 55 | state_dict_to_load = {} 56 | for k, v in state_dict.items(): 57 | m = re.match(rf"^{module_name}\.(.*)$", k) 58 | if m is None: 59 | continue 60 | state_dict_to_load[m.group(1)] = v 61 | 62 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 63 | 64 | 65 | def C(value: Any, epoch: int, global_step: int) -> float: 66 | if isinstance(value, int) or isinstance(value, float): 67 | pass 68 | else: 69 | value = config_to_primitive(value) 70 | if not isinstance(value, list): 71 | raise TypeError("Scalar specification only supports list, got", 
type(value)) 72 | if len(value) == 3: 73 | value = [0] + value 74 | assert len(value) == 4 75 | start_step, start_value, end_value, end_step = value 76 | if isinstance(end_step, int): 77 | current_step = global_step 78 | value = start_value + (end_value - start_value) * max( 79 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 80 | ) 81 | elif isinstance(end_step, float): 82 | current_step = epoch 83 | value = start_value + (end_value - start_value) * max( 84 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 85 | ) 86 | return value 87 | 88 | 89 | def cleanup(): 90 | gc.collect() 91 | torch.cuda.empty_cache() 92 | tcnn.free_temporary_memory() 93 | 94 | 95 | def finish_with_cleanup(func: Callable): 96 | def wrapper(*args, **kwargs): 97 | out = func(*args, **kwargs) 98 | cleanup() 99 | return out 100 | 101 | return wrapper 102 | 103 | 104 | def _distributed_available(): 105 | return torch.distributed.is_available() and torch.distributed.is_initialized() 106 | 107 | 108 | def barrier(): 109 | if not _distributed_available(): 110 | return 111 | else: 112 | torch.distributed.barrier() 113 | 114 | 115 | def broadcast(tensor, src=0): 116 | if not _distributed_available(): 117 | return tensor 118 | else: 119 | torch.distributed.broadcast(tensor, src=src) 120 | return tensor 121 | 122 | 123 | def enable_gradient(model, enabled: bool = True) -> None: 124 | for param in model.parameters(): 125 | param.requires_grad_(enabled) 126 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/perceptual/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual import PerceptualLoss 2 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/rasterize.py: -------------------------------------------------------------------------------- 1 | import nvdiffrast.torch as dr 2 | import torch 3 | 4 | from threestudio.utils.typing import * 5 | 6 | 7 | class NVDiffRasterizerContext: 8 | def __init__(self, context_type: str, device: torch.device) -> None: 9 | self.device = device 10 | self.ctx = self.initialize_context(context_type, device) 11 | 12 | def initialize_context( 13 | self, context_type: str, device: torch.device 14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: 15 | if context_type == "gl": 16 | return dr.RasterizeGLContext(device=device) 17 | elif context_type == "cuda": 18 | return dr.RasterizeCudaContext(device=device) 19 | else: 20 | raise ValueError(f"Unknown rasterizer context type: {context_type}") 21 | 22 | def vertex_transform( 23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] 24 | ) -> Float[Tensor, "B Nv 4"]: 25 | verts_homo = torch.cat( 26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 27 | ) 28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) 29 | 30 | def rasterize( 31 | self, 32 | pos: Float[Tensor, "B Nv 4"], 33 | tri: Integer[Tensor, "Nf 3"], 34 | resolution: Union[int, Tuple[int, int]], 35 | ): 36 | # rasterize in instance mode (single topology) 37 | return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) 38 | 39 | def rasterize_one( 40 | self, 41 | pos: Float[Tensor, "Nv 4"], 42 | tri: Integer[Tensor, "Nf 3"], 43 | resolution: Union[int, Tuple[int, int]], 44 | ): 45 | # rasterize one single mesh under a single viewpoint 46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) 47 | return rast[0], 
rast_db[0] 48 | 49 | def antialias( 50 | self, 51 | color: Float[Tensor, "B H W C"], 52 | rast: Float[Tensor, "B H W 4"], 53 | pos: Float[Tensor, "B Nv 4"], 54 | tri: Integer[Tensor, "Nf 3"], 55 | ) -> Float[Tensor, "B H W C"]: 56 | return dr.antialias(color.float(), rast, pos.float(), tri.int()) 57 | 58 | def interpolate( 59 | self, 60 | attr: Float[Tensor, "B Nv C"], 61 | rast: Float[Tensor, "B H W 4"], 62 | tri: Integer[Tensor, "Nf 3"], 63 | rast_db=None, 64 | diff_attrs=None, 65 | ) -> Float[Tensor, "B H W C"]: 66 | return dr.interpolate( 67 | attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs 68 | ) 69 | 70 | def interpolate_one( 71 | self, 72 | attr: Float[Tensor, "Nv C"], 73 | rast: Float[Tensor, "B H W 4"], 74 | tri: Integer[Tensor, "Nf 3"], 75 | rast_db=None, 76 | diff_attrs=None, 77 | ) -> Float[Tensor, "B H W C"]: 78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) 79 | -------------------------------------------------------------------------------- /3drec/threestudio/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kakao Brain and POSTECH 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/main_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/assets/main_framework.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/assets/teaser.png -------------------------------------------------------------------------------- /configs/ablation/camera_extrinsic.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | posemb_dim: 496 6 | 7 | conditioner_config: 8 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 9 | params: 10 | conditioners_config: 11 | - name: support_latents 12 | target: sgm.modules.GeneralConditioner 13 | params: 14 | emb_models: 15 | - is_trainable: False 16 | input_key: support_latents 17 | target: sgm.modules.encoders.modules.IdentityEncoder 18 | params: {} 19 | 20 | - name: ray 21 | target: sgm.modules.GeneralConditioner 22 | params: 23 | emb_models: 24 | # ray embedding 25 | - is_trainable: False 26 | input_keys: 27 | - support_c2ws 28 | - query_c2ws 29 | - support_latents 30 | target: sgm.modules.nvsadapter.conditioner.ExtrinsicEmbedder 31 | params: 32 | deg: [0, 15] 33 | 34 | - name: image 35 | target: sgm.modules.GeneralConditioner 36 | params: 37 | emb_models: 38 | # image embedding 39 | - is_trainable: True 40 | input_key: support_rgbs 41 | target: sgm.modules.nvsadapter.conditioner.ImageEmbedAttentionProjector 42 | params: 43 | unsqueeze_dim: true 44 | 45 | - name: txt 46 | target: sgm.modules.GeneralConditioner 47 | params: 48 | emb_models: 49 | # text embedding 50 | - is_trainable: False 51 | input_key: txt 52 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 53 | params: 54 | freeze: true 55 | layer: penultimate -------------------------------------------------------------------------------- /configs/ablation/no_image_attn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | image_attn_mode: null -------------------------------------------------------------------------------- /configs/ablation/query_emb_scale_2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | query_emb_scale: 2.0 -------------------------------------------------------------------------------- /configs/ablation/query_emb_scale_half.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | query_emb_scale: 0.5 -------------------------------------------------------------------------------- /configs/ablation/rayo_rayd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | conditioner_config: 4 | target: 
sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 5 | params: 6 | conditioners_config: 7 | - name: support_latents 8 | target: sgm.modules.GeneralConditioner 9 | params: 10 | emb_models: 11 | - is_trainable: False 12 | input_key: support_latents 13 | target: sgm.modules.encoders.modules.IdentityEncoder 14 | params: {} 15 | 16 | - name: ray 17 | target: sgm.modules.GeneralConditioner 18 | params: 19 | emb_models: 20 | # ray embedding 21 | - is_trainable: False 22 | input_keys: 23 | - support_rays_offset 24 | - support_rays_direction 25 | - query_rays_offset 26 | - query_rays_direction 27 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 28 | params: 29 | offset_deg: [0, 15] 30 | direction_deg: [0, 8] 31 | use_plucker: false 32 | 33 | - name: txt 34 | target: sgm.modules.GeneralConditioner 35 | params: 36 | emb_models: 37 | # text embedding 38 | - is_trainable: False 39 | input_key: txt 40 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 41 | params: 42 | freeze: true 43 | layer: penultimate 44 | 45 | - name: image 46 | target: sgm.modules.GeneralConditioner 47 | params: 48 | emb_models: 49 | # image embedding 50 | - is_trainable: False 51 | input_key: support_rgbs 52 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 53 | params: 54 | unsqueeze_dim: true 55 | -------------------------------------------------------------------------------- /configs/num_queries/15_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 15 5 | val_config: 6 | num_query_views: 15 7 | 8 | model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 15 -------------------------------------------------------------------------------- /configs/num_queries/16_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 16 5 | val_config: 6 | num_query_views: 16 7 | 8 | model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 16 -------------------------------------------------------------------------------- /configs/num_queries/1_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 1 5 | val_config: 6 | num_query_views: 1 7 | 8 | model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 1 -------------------------------------------------------------------------------- /configs/num_queries/2_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 2 5 | val_config: 6 | num_query_views: 2 7 | 8 | model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 2 -------------------------------------------------------------------------------- /configs/num_queries/4_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 4 5 | val_config: 6 | num_query_views: 4 7 | 8 | model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 4 -------------------------------------------------------------------------------- /configs/num_queries/6_queries.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | params: 3 | train_config: 4 | num_query_views: 6 5 | val_config: 6 | num_query_views: 6 7 | 8 
| model: 9 | params: 10 | network_config: 11 | params: 12 | num_query: 6 -------------------------------------------------------------------------------- /configs/options/controlnet_canny.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNet 7 | params: 8 | image_size: 32 # unused 9 | in_channels: 4 10 | hint_channels: 3 11 | model_channels: 320 12 | attention_resolutions: [ 4, 2, 1 ] 13 | num_res_blocks: 2 14 | channel_mult: [ 1, 2, 4, 4 ] 15 | num_heads: 8 16 | use_spatial_transformer: True 17 | transformer_depth: 1 18 | context_dim: 768 19 | use_checkpoint: True 20 | legacy: False 21 | 22 | conditioner_config: 23 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 24 | params: 25 | conditioners_config: 26 | - name: support_latents 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: False 31 | input_key: support_latents 32 | target: sgm.modules.encoders.modules.IdentityEncoder 33 | params: {} 34 | 35 | - name: ray 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | # ray embedding 40 | - is_trainable: False 41 | input_keys: 42 | - support_rays_offset 43 | - support_rays_direction 44 | - query_rays_offset 45 | - query_rays_direction 46 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 47 | params: 48 | offset_deg: [0, 15] 49 | direction_deg: [0, 8] 50 | use_plucker: true 51 | 52 | - name: txt 53 | target: sgm.modules.GeneralConditioner 54 | params: 55 | emb_models: 56 | # text embedding 57 | - is_trainable: False 58 | input_key: txt 59 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 60 | 61 | - name: image 62 | target: sgm.modules.GeneralConditioner 63 | params: 64 | emb_models: 65 | # image embedding 66 | - is_trainable: False 67 | input_key: support_rgbs 68 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 69 | params: 70 | unsqueeze_dim: true 71 | 72 | - name: control 73 | target: sgm.modules.GeneralConditioner 74 | params: 75 | emb_models: 76 | # image embedding 77 | - is_trainable: False 78 | input_keys: 79 | - support_rgbs_cond 80 | - query_rgbs_cond 81 | target: sgm.modules.nvsadapter.conditioner.CannyConditioner 82 | 83 | # path to the pre-trained SD model checkpoint 84 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 85 | controlnet_ckpt_path: checkpoints/control_sd15_canny.pth 86 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 87 | -------------------------------------------------------------------------------- /configs/options/controlnet_canny_depth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNets 7 | params: 8 | num_controlnets: 2 9 | image_size: 32 # unused 10 | in_channels: 4 11 | hint_channels: 3 12 | model_channels: 320 13 | attention_resolutions: [ 4, 2, 1 ] 14 | num_res_blocks: 2 15 | channel_mult: [ 1, 2, 4, 4 ] 16 | num_heads: 8 17 | use_spatial_transformer: True 18 | transformer_depth: 1 19 | context_dim: 768 20 | use_checkpoint: True 21 | legacy: False 22 | 23 | conditioner_config: 24 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 25 | params: 26 | conditioners_config: 27 | - name: support_latents 28 | target: sgm.modules.GeneralConditioner 29 | params: 30 | 
emb_models: 31 | - is_trainable: False 32 | input_key: support_latents 33 | target: sgm.modules.encoders.modules.IdentityEncoder 34 | params: {} 35 | 36 | - name: ray 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | # ray embedding 41 | - is_trainable: False 42 | input_keys: 43 | - support_rays_offset 44 | - support_rays_direction 45 | - query_rays_offset 46 | - query_rays_direction 47 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 48 | params: 49 | offset_deg: [0, 15] 50 | direction_deg: [0, 8] 51 | use_plucker: true 52 | 53 | - name: txt 54 | target: sgm.modules.GeneralConditioner 55 | params: 56 | emb_models: 57 | # text embedding 58 | - is_trainable: False 59 | input_key: txt 60 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 61 | 62 | - name: image 63 | target: sgm.modules.GeneralConditioner 64 | params: 65 | emb_models: 66 | # image embedding 67 | - is_trainable: False 68 | input_key: support_rgbs 69 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 70 | params: 71 | unsqueeze_dim: true 72 | 73 | - name: control 74 | target: sgm.modules.GeneralConditioner 75 | params: 76 | emb_models: 77 | # image embedding 78 | - is_trainable: False 79 | input_keys: 80 | - support_rgbs_cond 81 | - query_rgbs_cond 82 | target: sgm.modules.nvsadapter.conditioner.CannyDepthConditioner 83 | 84 | # path to the pre-trained SD model checkpoint 85 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 86 | controlnet_ckpt_path: 87 | - checkpoints/control_sd15_canny.pth 88 | - checkpoints/control_sd15_depth.pth 89 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 90 | -------------------------------------------------------------------------------- /configs/options/controlnet_canny_hed.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNets 7 | params: 8 | num_controlnets: 2 9 | image_size: 32 # unused 10 | in_channels: 4 11 | hint_channels: 3 12 | model_channels: 320 13 | attention_resolutions: [ 4, 2, 1 ] 14 | num_res_blocks: 2 15 | channel_mult: [ 1, 2, 4, 4 ] 16 | num_heads: 8 17 | use_spatial_transformer: True 18 | transformer_depth: 1 19 | context_dim: 768 20 | use_checkpoint: True 21 | legacy: False 22 | 23 | conditioner_config: 24 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 25 | params: 26 | conditioners_config: 27 | - name: support_latents 28 | target: sgm.modules.GeneralConditioner 29 | params: 30 | emb_models: 31 | - is_trainable: False 32 | input_key: support_latents 33 | target: sgm.modules.encoders.modules.IdentityEncoder 34 | params: {} 35 | 36 | - name: ray 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | # ray embedding 41 | - is_trainable: False 42 | input_keys: 43 | - support_rays_offset 44 | - support_rays_direction 45 | - query_rays_offset 46 | - query_rays_direction 47 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 48 | params: 49 | offset_deg: [0, 15] 50 | direction_deg: [0, 8] 51 | use_plucker: true 52 | 53 | - name: txt 54 | target: sgm.modules.GeneralConditioner 55 | params: 56 | emb_models: 57 | # text embedding 58 | - is_trainable: False 59 | input_key: txt 60 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 61 | 62 | - name: image 63 | target: sgm.modules.GeneralConditioner 64 | params: 65 | emb_models: 66 | # image embedding 67 | - is_trainable: False 
68 | input_key: support_rgbs 69 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 70 | params: 71 | unsqueeze_dim: true 72 | 73 | - name: control 74 | target: sgm.modules.GeneralConditioner 75 | params: 76 | emb_models: 77 | # image embedding 78 | - is_trainable: False 79 | input_keys: 80 | - support_rgbs_cond 81 | - query_rgbs_cond 82 | target: sgm.modules.nvsadapter.conditioner.CannyHEDConditioner 83 | 84 | # path to the pre-trained SD model checkpoint 85 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 86 | controlnet_ckpt_path: 87 | - checkpoints/control_sd15_canny.pth 88 | - checkpoints/control_sd15_hed.pth 89 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 90 | -------------------------------------------------------------------------------- /configs/options/controlnet_canny_hed_depth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNets 7 | params: 8 | num_controlnets: 3 9 | image_size: 32 # unused 10 | in_channels: 4 11 | hint_channels: 3 12 | model_channels: 320 13 | attention_resolutions: [ 4, 2, 1 ] 14 | num_res_blocks: 2 15 | channel_mult: [ 1, 2, 4, 4 ] 16 | num_heads: 8 17 | use_spatial_transformer: True 18 | transformer_depth: 1 19 | context_dim: 768 20 | use_checkpoint: True 21 | legacy: False 22 | 23 | conditioner_config: 24 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 25 | params: 26 | conditioners_config: 27 | - name: support_latents 28 | target: sgm.modules.GeneralConditioner 29 | params: 30 | emb_models: 31 | - is_trainable: False 32 | input_key: support_latents 33 | target: sgm.modules.encoders.modules.IdentityEncoder 34 | params: {} 35 | 36 | - name: ray 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | # ray embedding 41 | - is_trainable: False 42 | input_keys: 43 | - support_rays_offset 44 | - support_rays_direction 45 | - query_rays_offset 46 | - query_rays_direction 47 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 48 | params: 49 | offset_deg: [0, 15] 50 | direction_deg: [0, 8] 51 | use_plucker: true 52 | 53 | - name: txt 54 | target: sgm.modules.GeneralConditioner 55 | params: 56 | emb_models: 57 | # text embedding 58 | - is_trainable: False 59 | input_key: txt 60 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 61 | 62 | - name: image 63 | target: sgm.modules.GeneralConditioner 64 | params: 65 | emb_models: 66 | # image embedding 67 | - is_trainable: False 68 | input_key: support_rgbs 69 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 70 | params: 71 | unsqueeze_dim: true 72 | 73 | - name: control 74 | target: sgm.modules.GeneralConditioner 75 | params: 76 | emb_models: 77 | # image embedding 78 | - is_trainable: False 79 | input_keys: 80 | - support_rgbs_cond 81 | - query_rgbs_cond 82 | target: sgm.modules.nvsadapter.conditioner.CannyHEDDepthConditioner 83 | 84 | # path to the pre-trained SD model checkpoint 85 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 86 | controlnet_ckpt_path: 87 | - checkpoints/control_sd15_canny.pth 88 | - checkpoints/control_sd15_hed.pth 89 | - checkpoints/control_sd15_depth.pth 90 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 91 | -------------------------------------------------------------------------------- /configs/options/controlnet_depth.yaml: -------------------------------------------------------------------------------- 1 
| model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNet 7 | params: 8 | image_size: 32 # unused 9 | in_channels: 4 10 | hint_channels: 3 11 | model_channels: 320 12 | attention_resolutions: [ 4, 2, 1 ] 13 | num_res_blocks: 2 14 | channel_mult: [ 1, 2, 4, 4 ] 15 | num_heads: 8 16 | use_spatial_transformer: True 17 | transformer_depth: 1 18 | context_dim: 768 19 | use_checkpoint: True 20 | legacy: False 21 | 22 | conditioner_config: 23 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 24 | params: 25 | conditioners_config: 26 | - name: support_latents 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: False 31 | input_key: support_latents 32 | target: sgm.modules.encoders.modules.IdentityEncoder 33 | params: {} 34 | 35 | - name: ray 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | # ray embedding 40 | - is_trainable: False 41 | input_keys: 42 | - support_rays_offset 43 | - support_rays_direction 44 | - query_rays_offset 45 | - query_rays_direction 46 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 47 | params: 48 | offset_deg: [0, 15] 49 | direction_deg: [0, 8] 50 | use_plucker: true 51 | 52 | - name: txt 53 | target: sgm.modules.GeneralConditioner 54 | params: 55 | emb_models: 56 | # text embedding 57 | - is_trainable: False 58 | input_key: txt 59 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 60 | 61 | - name: image 62 | target: sgm.modules.GeneralConditioner 63 | params: 64 | emb_models: 65 | # image embedding 66 | - is_trainable: False 67 | input_key: support_rgbs 68 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 69 | params: 70 | unsqueeze_dim: true 71 | 72 | - name: control 73 | target: sgm.modules.GeneralConditioner 74 | params: 75 | emb_models: 76 | # image embedding 77 | - is_trainable: False 78 | input_keys: 79 | - support_rgbs_cond 80 | - query_rgbs_cond 81 | target: sgm.modules.nvsadapter.conditioner.MiDASDepthConditioner 82 | params: 83 | model_type: dpt_hybrid 84 | 85 | # path to the pre-trained SD model checkpoint 86 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 87 | controlnet_ckpt_path: checkpoints/control_sd15_depth.pth 88 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 89 | -------------------------------------------------------------------------------- /configs/options/controlnet_hed.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNet 7 | params: 8 | image_size: 32 # unused 9 | in_channels: 4 10 | hint_channels: 3 11 | model_channels: 320 12 | attention_resolutions: [ 4, 2, 1 ] 13 | num_res_blocks: 2 14 | channel_mult: [ 1, 2, 4, 4 ] 15 | num_heads: 8 16 | use_spatial_transformer: True 17 | transformer_depth: 1 18 | context_dim: 768 19 | use_checkpoint: True 20 | legacy: False 21 | 22 | conditioner_config: 23 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 24 | params: 25 | conditioners_config: 26 | - name: support_latents 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: False 31 | input_key: support_latents 32 | target: sgm.modules.encoders.modules.IdentityEncoder 33 | params: {} 34 | 35 | - name: ray 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | # ray embedding 40 | 
- is_trainable: False 41 | input_keys: 42 | - support_rays_offset 43 | - support_rays_direction 44 | - query_rays_offset 45 | - query_rays_direction 46 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 47 | params: 48 | offset_deg: [0, 15] 49 | direction_deg: [0, 8] 50 | use_plucker: true 51 | 52 | - name: txt 53 | target: sgm.modules.GeneralConditioner 54 | params: 55 | emb_models: 56 | # text embedding 57 | - is_trainable: False 58 | input_key: txt 59 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 60 | 61 | - name: image 62 | target: sgm.modules.GeneralConditioner 63 | params: 64 | emb_models: 65 | # image embedding 66 | - is_trainable: False 67 | input_key: support_rgbs 68 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 69 | params: 70 | unsqueeze_dim: true 71 | 72 | - name: control 73 | target: sgm.modules.GeneralConditioner 74 | params: 75 | emb_models: 76 | # image embedding 77 | - is_trainable: False 78 | input_keys: 79 | - support_rgbs_cond 80 | - query_rgbs_cond 81 | target: sgm.modules.nvsadapter.conditioner.HEDConditioner 82 | 83 | # path to the pre-trained SD model checkpoint 84 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 85 | controlnet_ckpt_path: checkpoints/control_sd15_hed.pth 86 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 87 | -------------------------------------------------------------------------------- /configs/options/controlnet_hed_depth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | control_model_config: 6 | target: sgm.modules.nvsadapter.controlnet.ControlNets 7 | params: 8 | num_controlnets: 2 9 | image_size: 32 # unused 10 | in_channels: 4 11 | hint_channels: 3 12 | model_channels: 320 13 | attention_resolutions: [ 4, 2, 1 ] 14 | num_res_blocks: 2 15 | channel_mult: [ 1, 2, 4, 4 ] 16 | num_heads: 8 17 | use_spatial_transformer: True 18 | transformer_depth: 1 19 | context_dim: 768 20 | use_checkpoint: True 21 | legacy: False 22 | 23 | conditioner_config: 24 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 25 | params: 26 | conditioners_config: 27 | - name: support_latents 28 | target: sgm.modules.GeneralConditioner 29 | params: 30 | emb_models: 31 | - is_trainable: False 32 | input_key: support_latents 33 | target: sgm.modules.encoders.modules.IdentityEncoder 34 | params: {} 35 | 36 | - name: ray 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | # ray embedding 41 | - is_trainable: False 42 | input_keys: 43 | - support_rays_offset 44 | - support_rays_direction 45 | - query_rays_offset 46 | - query_rays_direction 47 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 48 | params: 49 | offset_deg: [0, 15] 50 | direction_deg: [0, 8] 51 | use_plucker: true 52 | 53 | - name: txt 54 | target: sgm.modules.GeneralConditioner 55 | params: 56 | emb_models: 57 | # text embedding 58 | - is_trainable: False 59 | input_key: txt 60 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 61 | 62 | - name: image 63 | target: sgm.modules.GeneralConditioner 64 | params: 65 | emb_models: 66 | # image embedding 67 | - is_trainable: False 68 | input_key: support_rgbs 69 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 70 | params: 71 | unsqueeze_dim: true 72 | 73 | - name: control 74 | target: sgm.modules.GeneralConditioner 75 | params: 76 | emb_models: 77 | # image embedding 78 | - is_trainable: False 79 | input_keys: 80 | - 
support_rgbs_cond 81 | - query_rgbs_cond 82 | target: sgm.modules.nvsadapter.conditioner.HEDDepthConditioner 83 | 84 | # path to the pre-trained SD model checkpoint 85 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 86 | controlnet_ckpt_path: 87 | - checkpoints/control_sd15_hed.pth 88 | - checkpoints/control_sd15_depth.pth 89 | ckpt_path: checkpoints/4query_sd_15_last.ckpt 90 | -------------------------------------------------------------------------------- /configs/options/full_finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lr_mult_for_pretrained: 0.1 4 | -------------------------------------------------------------------------------- /configs/options/lora_blueresin.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/blueresin.ckpt -------------------------------------------------------------------------------- /configs/options/lora_cofzee.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/cofzee.ckpt -------------------------------------------------------------------------------- /configs/options/lora_friedegg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/friedegg.ckpt -------------------------------------------------------------------------------- /configs/options/lora_gelato.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/gelato.ckpt -------------------------------------------------------------------------------- /configs/options/lora_gemstone.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/ntrgmstn.ckpt -------------------------------------------------------------------------------- /configs/options/lora_watce.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/cofzee.ckpt -------------------------------------------------------------------------------- /configs/options/lora_wood.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | lora_ckpt_path: ./checkpoints/lora_checkpoint/woodfigurez.ckpt -------------------------------------------------------------------------------- /configs/options/sd15.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | network_config: 4 | params: 5 | # pre-trained SD configuration 6 | sd_config: 7 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 8 | params: 9 | use_checkpoint: False 10 | in_channels: 4 11 | out_channels: 4 12 | model_channels: 320 13 | attention_resolutions: [ 4, 2, 1 ] 14 | num_res_blocks: 2 15 | channel_mult: [ 1, 2, 4, 4 ] 16 | num_heads: 8 17 | num_head_channels: -1 18 | use_linear_in_transformer: False 19 | use_spatial_transformer: True 20 | transformer_depth: 1 21 | context_dim: 768 22 | legacy: False 23 | 24 | conditioner_config: 25 | target: sgm.modules.nvsadapter.conditioner.MultipleGeneralConditioners 26 | params: 27 | conditioners_config: 28 | - name: support_latents 29 
| target: sgm.modules.GeneralConditioner 30 | params: 31 | emb_models: 32 | - is_trainable: False 33 | input_key: support_latents 34 | target: sgm.modules.encoders.modules.IdentityEncoder 35 | params: {} 36 | 37 | - name: ray 38 | target: sgm.modules.GeneralConditioner 39 | params: 40 | emb_models: 41 | # ray embedding 42 | - is_trainable: False 43 | input_keys: 44 | - support_rays_offset 45 | - support_rays_direction 46 | - query_rays_offset 47 | - query_rays_direction 48 | target: sgm.modules.nvsadapter.conditioner.RayPosConditionEmbedder 49 | params: 50 | offset_deg: [0, 15] 51 | direction_deg: [0, 8] 52 | use_plucker: true 53 | 54 | - name: txt 55 | target: sgm.modules.GeneralConditioner 56 | params: 57 | emb_models: 58 | # text embedding 59 | - is_trainable: False 60 | input_key: txt 61 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 62 | 63 | - name: image 64 | target: sgm.modules.GeneralConditioner 65 | params: 66 | emb_models: 67 | # image embedding 68 | - is_trainable: False 69 | input_key: support_rgbs 70 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 71 | params: 72 | unsqueeze_dim: true 73 | 74 | # path to the pre-trained SD model checkpoint 75 | sd_ckpt_path: checkpoints/v1-5-pruned-emaonly.ckpt 76 | -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /licenses/LICENSE_DPT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /licenses/LICENSE_SD: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_SD_XL: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==23.7.0 2 | chardet==5.1.0 3 | clip @ git+https://github.com/openai/CLIP.git 4 | einops>=0.6.1 5 | fairscale>=0.4.13 6 | fire>=0.5.0 7 | fsspec>=2023.6.0 8 | invisible-watermark>=0.2.0 9 | kornia==0.6.9 10 | matplotlib>=3.7.2 11 | natsort>=8.4.0 12 | ninja>=1.11.1 13 | numpy>=1.24.4 14 | omegaconf>=2.3.0 15 | open-clip-torch>=2.20.0 16 | opencv-python==4.6.0.66 17 | pandas>=2.0.3 18 | pillow>=9.5.0 19 | pudb>=2022.1.3 20 | pytorch-lightning==2.0.1 21 | pyyaml>=6.0.1 22 | scipy>=1.10.1 23 | streamlit>=0.73.1 24 | tensorboardx==2.6 25 | timm>=0.9.2 26 | tokenizers==0.12.1 27 | torch>=2.0.1 28 | torchaudio>=2.0.2 29 | torchdata==0.6.1 30 | torchmetrics>=1.0.1 31 | torchvision>=0.15.2 32 | tqdm>=4.65.0 33 | transformers==4.19.1 34 | triton==2.0.0 35 | urllib3<1.27,>=1.25.4 36 | wandb>=0.15.6 37 | webdataset>=0.2.33 38 | wheel>=0.41.0 39 | xformers>=0.0.20 40 | streamlit-keyup==0.2.0 41 | piqa==1.3.2 42 | pyrootutils==1.0.4 43 | gradio==4.10.0 44 | -------------------------------------------------------------------------------- /sample/deer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sample/deer.png -------------------------------------------------------------------------------- /sample/dolphine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sample/dolphine.png -------------------------------------------------------------------------------- /sample/kunkun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sample/kunkun.png -------------------------------------------------------------------------------- /scripts/eval_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | import torch 9 | from skimage.io import imread 10 | from piqa.ssim import SSIM 11 | from piqa.lpips import LPIPS 12 | from piqa.psnr import PSNR 13 | 14 | 15 | def compute_psnr_float(img_gt, img_pr): 16 | img_gt = img_gt.reshape([-1, 3]).astype(np.float32) 17 | img_pr = img_pr.reshape([-1, 3]).astype(np.float32) 18 | mse = np.mean((img_gt - img_pr) ** 2, 0) 19 | mse = np.mean(mse) 20 | psnr = 10 * np.log10(1 / mse) 21 | return psnr 22 | 23 | 24 | def color_map_forward(rgb): 25 | dim = rgb.shape[-1] 26 | if dim==3: 27 | return rgb.astype(np.float32)/255 28 | else: 29 | rgb = rgb.astype(np.float32)/255 30 | rgb, alpha = rgb[:,:,:3], rgb[:,:,3:] 31 | rgb = rgb * alpha + (1-alpha) 32 | return rgb 33 | 34 | 35 | def main(): 36 | """ 37 | input_dir 38 | - folder_0 39 | - pred.png 40 | - target.png 41 | - folder_1 42 | - pred.png 43 | - target.png 44 | ... 
45 | """ 46 | parser = ArgumentParser() 47 | parser.add_argument('--input_dir', type=str) 48 | args = parser.parse_args() 49 | 50 | output_log_path = os.path.join(args.input_dir, "metric.txt") 51 | 52 | target_path_list = glob(os.path.join(args.input_dir, "**", "target.png"), recursive=True) 53 | 54 | psnr_fn = PSNR().cuda() 55 | ssim_fn = SSIM().cuda() 56 | lpips_fn = LPIPS(network="vgg").cuda() 57 | 58 | psnrs, ssims, lpipss = [], [], [] 59 | for target_path in tqdm(target_path_list): 60 | pred_path = target_path.replace("target.png", "pred.png") 61 | 62 | img_gt_int = imread(target_path) 63 | img_pr_int = imread(pred_path) 64 | 65 | img_gt = color_map_forward(img_gt_int) 66 | img_pr = color_map_forward(img_pr_int) 67 | 68 | with torch.no_grad(): 69 | img_gt_tensor = torch.from_numpy(img_gt.astype(np.float32)).permute(2,0,1).unsqueeze(0).cuda() 70 | img_pr_tensor = torch.from_numpy(img_pr.astype(np.float32)).permute(2,0,1).unsqueeze(0).cuda() 71 | 72 | ssims.append(ssim_fn(img_pr_tensor, img_gt_tensor).cpu().numpy()) 73 | lpipss.append(lpips_fn(img_pr_tensor, img_gt_tensor).cpu().numpy()) 74 | psnrs.append(psnr_fn(img_pr_tensor, img_gt_tensor).cpu().numpy()) 75 | 76 | msg = f'psnr: {np.mean(psnrs):.5f}\nssim: {np.mean(ssims):.5f}\nlpips: {np.mean(lpipss):.5f}' 77 | print(msg) 78 | with open(output_log_path, 'w') as f: 79 | f.write(msg + '\n') 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /scripts/objaverse_renderings_to_webdataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import math 4 | import json 5 | from glob import glob 6 | from tqdm import tqdm 7 | 8 | import fire 9 | import webdataset as wds 10 | 11 | 12 | def paths_to_webdataset(paths: List[str], 13 | basepath: str, 14 | target_filepath_pattern: str, 15 | maxcount_per_file: int = 100000, 16 | maxsize_in_bytes_per_file: float = 1e9): 17 | 18 | shard_writer = wds.ShardWriter(pattern=target_filepath_pattern, 19 | maxcount=maxcount_per_file, # default 100,000 20 | maxsize=maxsize_in_bytes_per_file) # default 3e9, in bytes 21 | total_view = 12 22 | 23 | for path in tqdm(paths): 24 | basename = path 25 | sample = { 26 | "__key__": basename 27 | } 28 | for view_index in range(total_view): 29 | view_index_str = f"{view_index:03d}" 30 | 31 | full_basepath = os.path.join(basepath, basename, view_index_str) 32 | 33 | with open(f"{full_basepath}.png", "rb") as stream: 34 | image = stream.read() 35 | with open(f"{full_basepath}.npy", "rb") as stream: 36 | data = stream.read() 37 | 38 | sample[f"png_{view_index_str}"] = image 39 | sample[f"npy_{view_index_str}"] = data 40 | shard_writer.write(sample) 41 | shard_writer.close() 42 | 43 | 44 | def objaverse_renderings_to_webdataset(paths_json_path: str, 45 | basepath: str = "./", 46 | maxcount_per_file: int = 100000, 47 | maxsize_in_bytes_per_file: float = 1e9): 48 | assert os.path.exists(paths_json_path), f"{paths_json_path} does not exist." 49 | with open(paths_json_path) as file: 50 | paths = json.load(file) 51 | 52 | total_objects = len(paths) 53 | assert total_objects > 0, f"total objects: {total_objects}, no valid objects exist." 54 | 55 | split_index = math.floor(total_objects * 0.99) 56 | 57 | train_paths = paths[:split_index] # first 99 % as training 58 | valid_paths = paths[split_index:] # last 1 % as validation 59 | 60 | assert len(train_paths) > 0, f"{len(train_paths)}, no train paths exist."
61 | assert len(valid_paths) > 0, f"{len(valid_paths)}, no valid paths exist." 62 | 63 | paths_to_webdataset(paths=train_paths, 64 | basepath=basepath, 65 | target_filepath_pattern="objaverse_rendering_train_%06d.tar", 66 | maxcount_per_file=maxcount_per_file, 67 | maxsize_in_bytes_per_file=maxsize_in_bytes_per_file) 68 | 69 | paths_to_webdataset(paths=valid_paths, 70 | basepath=basepath, 71 | target_filepath_pattern="objaverse_rendering_valid_%06d.tar", 72 | maxcount_per_file=maxcount_per_file, 73 | maxsize_in_bytes_per_file=maxsize_in_bytes_per_file) 74 | 75 | 76 | if __name__ == "__main__": 77 | fire.Fire(objaverse_renderings_to_webdataset) 78 | -------------------------------------------------------------------------------- /scripts/sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import os 4 | 5 | import pytorch_lightning as pl 6 | from omegaconf import OmegaConf 7 | import pyrootutils 8 | 9 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 10 | 11 | from sgm.util import instantiate_from_config 12 | from sgm.data.dirdataset import DirDataModule 13 | 14 | 15 | def evaluate(args): 16 | pl.seed_everything(args.seed) 17 | 18 | name = args.name 19 | if name is None: 20 | name = "noname" 21 | 22 | expname = os.path.splitext(os.path.basename(args.config_path))[0] 23 | 24 | save_dir = Path(args.logdir, expname) 25 | save_dir.mkdir(exist_ok=True) 26 | 27 | with open(args.config_path) as fp: 28 | config = OmegaConf.load(fp) 29 | 30 | for cfg_path in args.additional_configs: 31 | with open(cfg_path) as fp: 32 | config = OmegaConf.merge(config, OmegaConf.load(fp)) 33 | 34 | model_config = config.model 35 | model_config.params.use_ema = args.use_ema 36 | model_config.params.sd_ckpt_path = None 37 | model_config.params.ckpt_path = args.ckpt_path 38 | 39 | data_config = config.data 40 | 41 | if args.cfg_scale is not None: 42 | model_config.params.sampler_config.params.guider_config.params.scale = args.cfg_scale 43 | cfg_scale = args.cfg_scale 44 | else: 45 | cfg_scale = model_config.params.sampler_config.params.guider_config.params.scale 46 | 47 | dirname = f"{name}_cfg_scale_{cfg_scale}_use_ema_{args.use_ema}_seed_{args.seed}" 48 | if args.split_idx is not None: 49 | dirname = dirname + "_" + f"{args.split_idx}" 50 | 51 | save_dir = save_dir.joinpath(dirname) 52 | save_dir.mkdir(exist_ok=True) 53 | 54 | litmodule = instantiate_from_config(model_config) 55 | litmodule.save_dir = save_dir 56 | 57 | datamodule = DirDataModule( 58 | ds_root_path=args.ds_root_path, 59 | ds_list_json_path=args.ds_list_json_path, 60 | num_total_views=args.ds_num_total_views, 61 | batch_size=args.batch_size, 62 | num_workers=data_config.params.num_workers, 63 | resolution=data_config.params.val_config.resolution, 64 | use_relative=data_config.params.val_config.use_relative, 65 | ) 66 | 67 | trainer = pl.Trainer(devices=1) 68 | trainer.test(litmodule, dataloaders=datamodule) 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--config_path", type=str, default=None, required=True, help="path to config of trained model") 74 | parser.add_argument("--ckpt_path", type=str, default=None, required=True, help="path to checkpoint of trained model") 75 | parser.add_argument("-n", "--name", type=str, default=None, help="name of the visualization") 76 | parser.add_argument("--logdir", type=str, default="./logs_sampling", help="path to save the
visualization") 77 | parser.add_argument("--use_ema", action="store_true", default=False, help="whether to use EMA model") 78 | parser.add_argument("--cfg_scale", type=float, default=None, help="scale for classifier free guidance") 79 | parser.add_argument("--ds_name", type=str, default="objaverse", help="the name of dataset") 80 | parser.add_argument("--ds_root_path", type=str, help="path to dataset for test", required=True) 81 | parser.add_argument("--ds_list_json_path", type=str, help="json path for list of dataset", required=True) 82 | parser.add_argument("--ds_num_total_views", type=int, help="number of total views per scene", required=True) 83 | parser.add_argument("--split_idx", type=int, default=None, help="split index for dataset") 84 | parser.add_argument("--batch_size", type=int, default=8, help="batch size for test") 85 | parser.add_argument("--seed", type=int, default=0, help="seed for random number generator") 86 | parser.add_argument("-c", "--additional_configs", nargs="*", default=list()) 87 | args = parser.parse_args() 88 | 89 | print('=' * 100) 90 | for k, v in vars(args).items(): 91 | print(f'{k}: {v}') 92 | print('=' * 100) 93 | 94 | evaluate(args) 95 | -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/data/__init__.py -------------------------------------------------------------------------------- /sgm/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class CIFAR10DataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class CIFAR10Loader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.shuffle = shuffle 31 | self.train_dataset = CIFAR10DataDictWrapper( 32 | torchvision.datasets.CIFAR10( 33 | root=".data/", train=True, download=True, transform=transform 34 | ) 35 | ) 36 | self.test_dataset = CIFAR10DataDictWrapper( 37 | torchvision.datasets.CIFAR10( 38 | root=".data/", train=False, download=True, transform=transform 39 | ) 40 | ) 41 | 42 | def prepare_data(self): 43 | pass 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.train_dataset, 48 | batch_size=self.batch_size, 49 | shuffle=self.shuffle, 50 | num_workers=self.num_workers, 51 | ) 52 | 53 | def test_dataloader(self): 54 | return DataLoader( 55 | self.test_dataset, 56 | batch_size=self.batch_size, 57 | shuffle=self.shuffle, 58 | 
num_workers=self.num_workers, 59 | ) 60 | 61 | def val_dataloader(self): 62 | return DataLoader( 63 | self.test_dataset, 64 | batch_size=self.batch_size, 65 | shuffle=self.shuffle, 66 | num_workers=self.num_workers, 67 | ) 68 | -------------------------------------------------------------------------------- /sgm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchdata.datapipes.iter 4 | import webdataset as wds 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule 7 | 8 | try: 9 | from sdata import create_dataset, create_dummy_dataset, create_loader 10 | except ImportError as e: 11 | print("#" * 100) 12 | print("Datasets not yet available") 13 | print("to enable, we need to add stable-datasets as a submodule") 14 | print("please use ``git submodule update --init --recursive``") 15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo") 16 | print("#" * 100) 17 | exit(1) 18 | 19 | 20 | class StableDataModuleFromConfig(LightningDataModule): 21 | def __init__( 22 | self, 23 | train: DictConfig, 24 | validation: Optional[DictConfig] = None, 25 | test: Optional[DictConfig] = None, 26 | skip_val_loader: bool = False, 27 | dummy: bool = False, 28 | ): 29 | super().__init__() 30 | self.train_config = train 31 | assert ( 32 | "datapipeline" in self.train_config and "loader" in self.train_config 33 | ), "train config requires the fields `datapipeline` and `loader`" 34 | 35 | self.val_config = validation 36 | if not skip_val_loader: 37 | if self.val_config is not None: 38 | assert ( 39 | "datapipeline" in self.val_config and "loader" in self.val_config 40 | ), "validation config requires the fields `datapipeline` and `loader`" 41 | else: 42 | print( 43 | "Warning: No Validation datapipeline defined, using that one from training" 44 | ) 45 | self.val_config = train 46 | 47 | self.test_config = test 48 | if self.test_config is not None: 49 | assert ( 50 | "datapipeline" in self.test_config and "loader" in self.test_config 51 | ), "test config requires the fields `datapipeline` and `loader`" 52 | 53 | self.dummy = dummy 54 | if self.dummy: 55 | print("#" * 100) 56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") 57 | print("#" * 100) 58 | 59 | def setup(self, stage: str) -> None: 60 | print("Preparing datasets") 61 | if self.dummy: 62 | data_fn = create_dummy_dataset 63 | else: 64 | data_fn = create_dataset 65 | 66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline) 67 | if self.val_config: 68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline) 69 | if self.test_config: 70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline) 71 | 72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: 73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader) 74 | return loader 75 | 76 | def val_dataloader(self) -> wds.DataPipeline: 77 | return create_loader(self.val_datapipeline, **self.val_config.loader) 78 | 79 | def test_dataloader(self) -> wds.DataPipeline: 80 | return create_loader(self.test_datapipeline, **self.test_config.loader) 81 | -------------------------------------------------------------------------------- /sgm/data/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 
| 6 | 7 | class MNISTDataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class MNISTLoader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 31 | self.shuffle = shuffle 32 | self.train_dataset = MNISTDataDictWrapper( 33 | torchvision.datasets.MNIST( 34 | root=".data/", train=True, download=True, transform=transform 35 | ) 36 | ) 37 | self.test_dataset = MNISTDataDictWrapper( 38 | torchvision.datasets.MNIST( 39 | root=".data/", train=False, download=True, transform=transform 40 | ) 41 | ) 42 | 43 | def prepare_data(self): 44 | pass 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 48 | self.train_dataset, 49 | batch_size=self.batch_size, 50 | shuffle=self.shuffle, 51 | num_workers=self.num_workers, 52 | prefetch_factor=self.prefetch_factor, 53 | ) 54 | 55 | def test_dataloader(self): 56 | return DataLoader( 57 | self.test_dataset, 58 | batch_size=self.batch_size, 59 | shuffle=self.shuffle, 60 | num_workers=self.num_workers, 61 | prefetch_factor=self.prefetch_factor, 62 | ) 63 | 64 | def val_dataloader(self): 65 | return DataLoader( 66 | self.test_dataset, 67 | batch_size=self.batch_size, 68 | shuffle=self.shuffle, 69 | num_workers=self.num_workers, 70 | prefetch_factor=self.prefetch_factor, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | dset = MNISTDataDictWrapper( 76 | torchvision.datasets.MNIST( 77 | root=".data/", 78 | train=False, 79 | download=True, 80 | transform=transforms.Compose( 81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 82 | ), 83 | ) 84 | ) 85 | ex = dset[0] 86 | -------------------------------------------------------------------------------- /sgm/data/single_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | 11 | from sgm.geometry import make_view_matrix, make_intrinsic_matrix 12 | 13 | 14 | def decode_image(path: str, color: List, has_alpha: bool = True) -> np.array: 15 | img = Image.open(path) 16 | img = np.array(img, dtype=np.float32) 17 | if has_alpha: 18 | img[img[:, :, -1] == 0.0] = color 19 | return Image.fromarray(np.uint8(img[:, :, :3])) 20 | 21 | 22 | class SingleImageDataset(Dataset): 23 | fov_rad = np.deg2rad(49.1) # for objaverse rendering dataset 24 | color_background = [255.0, 255.0, 255.0, 255.0] 25 | 26 | def __init__(self, image_path: str, 27 | support_elevation: float, support_azimuth: float, support_dist: float, 28 | elevations: list, azimuths: list, dists: list, 29 | resolution: int, num_query: int, use_relative: bool = True): 30 | super().__init__() 31 | 32 | assert len(elevations) == len(azimuths) == len(dists), \ 33 | f"{len(elevations)=} == {len(azimuths)=} == {len(dists)=}" 34 | 35 | self.image_path = image_path 36 | self.num_query = num_query 37 | self.elevations =
np.array(elevations).reshape([-1, self.num_query]) 38 | self.azimuths = np.array(azimuths).reshape([-1, self.num_query]) 39 | self.dists = np.array(dists).reshape([-1, self.num_query]) 40 | self.resolution = resolution 41 | self.use_relative = use_relative 42 | 43 | image_transform = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Resize( 46 | (self.resolution, self.resolution), 47 | interpolation=transforms.InterpolationMode.BICUBIC, 48 | antialias=True 49 | ), 50 | transforms.Lambda(lambda x: x * 2.0 - 1.0), 51 | ]) 52 | self.support_rgb = image_transform(decode_image(self.image_path, self.color_background)) 53 | 54 | self.support_c2w = make_view_matrix(azimuth=np.deg2rad(support_azimuth), elevation=np.deg2rad(support_elevation), dist=support_dist) 55 | self.support_c2w[:3, :3] *= -1 56 | 57 | self.intrinsic = make_intrinsic_matrix(fov_rad=self.fov_rad, h=self.resolution, w=self.resolution) 58 | 59 | def __len__(self): 60 | return len(self.elevations) 61 | 62 | def __getitem__(self, index): 63 | num_views_each = [1, self.num_query] 64 | 65 | intrinsics = [self.intrinsic] 66 | c2ws = [self.support_c2w] 67 | 68 | for azimuth, elevation, dist in zip(self.azimuths[index], self.elevations[index], self.dists[index]): 69 | c2w = make_view_matrix(azimuth=np.deg2rad(azimuth), elevation=np.deg2rad(elevation), dist=dist) 70 | c2w[:3, :3] *= -1 71 | 72 | intrinsics.append(self.intrinsic) 73 | c2ws.append(c2w) 74 | 75 | intrinsics, c2ws = map(lambda x: torch.stack(x), (intrinsics, c2ws)) 76 | 77 | support_rgbs = self.support_rgb[None, ...] 78 | support_intrinsics, query_intrinsics = torch.split(intrinsics, num_views_each) 79 | support_c2ws, query_c2ws = torch.split(c2ws, num_views_each) 80 | 81 | if self.use_relative: 82 | inverse_support_c2ws = torch.inverse(support_c2ws) 83 | support_c2ws = inverse_support_c2ws @ support_c2ws 84 | query_c2ws = inverse_support_c2ws @ query_c2ws 85 | 86 | return dict( 87 | support_rgbs=support_rgbs, 88 | support_intrinsics=support_intrinsics, 89 | support_c2ws=support_c2ws, 90 | query_intrinsics=query_intrinsics, 91 | query_c2ws=query_c2ws, 92 | ) 93 | -------------------------------------------------------------------------------- /sgm/data/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import io 3 | 4 | from PIL import Image 5 | import numpy as np 6 | 7 | 8 | def decode_image(data, color: List, has_alpha: bool = True) -> np.array: 9 | img = Image.open(io.BytesIO(data)) 10 | img = np.array(img, dtype=np.float32) 11 | if has_alpha: 12 | img[img[:, :, -1] == 0.0] = color 13 | return Image.fromarray(np.uint8(img[:, :, :3])) 14 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent 
lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return 
self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution 9 | 10 | 11 | class AbstractRegularizer(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def get_trainable_parameters(self) -> Any: 20 | raise NotImplementedError() 21 | 22 | 23 | class DiagonalGaussianRegularizer(AbstractRegularizer): 24 | def __init__(self, sample: bool = True): 25 | super().__init__() 26 | self.sample = sample 27 | 28 | def get_trainable_parameters(self) -> Any: 29 | yield from () 30 | 31 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 32 | log = dict() 33 | posterior = DiagonalGaussianDistribution(z) 34 | if self.sample: 35 | z = posterior.sample() 36 | else: 37 | z = posterior.mode() 38 | kl_loss = posterior.kl() 39 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 40 | log["kl_loss"] = kl_loss 41 | return z, log 42 | 43 | 44 | def measure_perplexity(predicted_indices, num_centroids): 45 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 46 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 47 | encodings = ( 48 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 49 | ) 50 | avg_probs = encodings.mean(0) 51 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 52 | cluster_use = torch.sum(avg_probs > 0) 53 | return perplexity, cluster_use 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .loss import StandardDiffusionLoss 4 | from .model import Decoder, Encoder, Model 5 | from .openaimodel import UNetModel 6 | from .sampling import BaseDiffusionSampler 7 | from .wrappers import OpenAIWrapper 8 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ...util import append_dims, instantiate_from_config 4 | 5 | 6 | class Denoiser(nn.Module): 7 | def __init__(self, weighting_config, scaling_config): 8 | super().__init__() 9 | 10 | self.weighting = instantiate_from_config(weighting_config) 11 | self.scaling = instantiate_from_config(scaling_config) 12 | 13 | def possibly_quantize_sigma(self, sigma): 14 | return sigma 15 | 16 | def possibly_quantize_c_noise(self, c_noise): 17 | return c_noise 18 | 19 | def w(self, sigma): 20 | return self.weighting(sigma) 21 | 22 | def __call__(self, network, input, sigma, cond): 23 | sigma = self.possibly_quantize_sigma(sigma) 24 | sigma_shape = sigma.shape 25 | sigma = append_dims(sigma, input.ndim) 26 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 27 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 28 | return network(input * c_in, c_noise, cond) * c_out + input * c_skip 29 | 30 | 31 | class DiscreteDenoiser(Denoiser): 32 | def __init__( 33 | self, 34 | weighting_config, 35 | scaling_config, 36 | num_idx, 37 | discretization_config, 38 | do_append_zero=False, 39 | quantize_c_noise=True, 40 | flip=True, 41 | ): 42 | super().__init__(weighting_config, scaling_config) 43 | sigmas = instantiate_from_config(discretization_config)( 44 | num_idx, do_append_zero=do_append_zero, flip=flip 45 | ) 46 | self.register_buffer("sigmas", sigmas) 47 | self.quantize_c_noise = quantize_c_noise 48 | 49 | def sigma_to_idx(self, sigma): 50 | dists = sigma - self.sigmas[:, None] 51 | return dists.abs().argmin(dim=0).view(sigma.shape) 52 | 53 | def idx_to_sigma(self, idx): 54 | return self.sigmas[idx] 55 | 56 | def possibly_quantize_sigma(self, sigma): 57 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 58 | 59 | def possibly_quantize_c_noise(self, c_noise): 60 | if self.quantize_c_noise: 61 | return self.sigma_to_idx(c_noise) 62 | else: 63 | return c_noise 64 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EDMScaling: 5 | def __init__(self, sigma_data=0.5): 6 | self.sigma_data = sigma_data 7 | 8 | def __call__(self, sigma): 9 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 10 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 11 | c_in = 1 / 
(sigma**2 + self.sigma_data**2) ** 0.5 12 | c_noise = 0.25 * sigma.log() 13 | return c_skip, c_out, c_in, c_noise 14 | 15 | 16 | class EpsScaling: 17 | def __call__(self, sigma): 18 | c_skip = torch.ones_like(sigma, device=sigma.device) 19 | c_out = -sigma 20 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 21 | c_noise = sigma.clone() 22 | return c_skip, c_out, c_in, c_noise 23 | 24 | 25 | class VScaling: 26 | def __call__(self, sigma): 27 | c_skip = 1.0 / (sigma**2 + 1.0) 28 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 29 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 30 | c_noise = sigma.clone() 31 | return c_skip, c_out, c_in, c_noise 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, 
self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | from ...util import default, instantiate_from_config 6 | 7 | 8 | class VanillaCFG: 9 | """ 10 | implements parallelized CFG 11 | """ 12 | 13 | def __init__(self, scale, dyn_thresh_config=None): 14 | scale_schedule = lambda scale, sigma: scale # independent of step 15 | self.scale_schedule = partial(scale_schedule, scale) 16 | self.dyn_thresh = instantiate_from_config( 17 | default( 18 | dyn_thresh_config, 19 | { 20 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" 21 | }, 22 | ) 23 | ) 24 | 25 | def __call__(self, x, sigma): 26 | x_u, x_c = x.chunk(2) 27 | scale_value = self.scale_schedule(sigma) 28 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 29 | return x_pred 30 | 31 | def prepare_inputs(self, x, s, c, uc): 32 | c_out = dict() 33 | 34 | for k in c: 35 | if k in ["vector", "crossattn", "concat"]: 36 | c_out[k] = torch.cat((uc[k], c[k]), 0) 37 | else: 38 | assert c[k] == uc[k] 39 | c_out[k] = c[k] 40 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 41 | 42 | 43 | class IdentityGuider: 44 | def __call__(self, x, sigma): 45 | return x 46 | 47 | def prepare_inputs(self, x, s, c, uc): 48 | c_out = dict() 49 | 50 | for k in c: 51 | c_out[k] = c[k] 52 | 53 | return x, s, c_out 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from omegaconf import ListConfig 6 | 7 | from ...util import append_dims, instantiate_from_config 8 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 9 | 10 | 11 | class StandardDiffusionLoss(nn.Module): 12 | def __init__( 13 | self, 14 | sigma_sampler_config, 15 | type="l2", 16 | offset_noise_level=0.0, 17 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, 18 | ): 19 | super().__init__() 20 | 21 | assert type in ["l2", "l1", "lpips"] 22 | 23 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 24 | 25 | self.type = type 26 | self.offset_noise_level = offset_noise_level 27 | 28 | if type == "lpips": 29 | self.lpips = LPIPS().eval() 30 | 31 | if not batch2model_keys: 32 | batch2model_keys = [] 33 | 34 | if isinstance(batch2model_keys, str): 35 | batch2model_keys = [batch2model_keys] 36 | 37 | self.batch2model_keys = set(batch2model_keys) 38 | 39 | def __call__(self, network, denoiser, conditioner, input, batch): 40 | cond = conditioner(batch) 41 | additional_model_inputs = { 42 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 43 | } 44 | 45 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 46 | noise = torch.randn_like(input) 47 | if self.offset_noise_level > 0.0: 48 | noise = noise + self.offset_noise_level * append_dims( 49 | torch.randn(input.shape[0], device=input.device), 
input.ndim 50 | ) 51 | noised_input = input + noise * append_dims(sigmas, input.ndim) 52 | model_output = denoiser( 53 | network, noised_input, sigmas, cond, **additional_model_inputs 54 | ) 55 | w = append_dims(denoiser.w(sigmas), input.ndim) 56 | return self.get_loss(model_output, input, w) 57 | 58 | def get_loss(self, model_output, target, w): 59 | if self.type == "l2": 60 | return torch.mean( 61 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 62 | ) 63 | elif self.type == "l1": 64 | return torch.mean( 65 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 66 | ) 67 | elif self.type == "lpips": 68 | loss = self.lpips(model_output, target).reshape(-1) 69 | return loss 70 | 71 | # for score distillation sampling 72 | def add_noise(self, input, noise, sigmas): 73 | if self.offset_noise_level > 0.0: 74 | noise = noise + self.offset_noise_level * append_dims( 75 | torch.randn(input.shape[0], device=input.device), input.ndim 76 | ) 77 | noised_input = input + noise * append_dims(sigmas, input.ndim) 78 | return noised_input 79 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | class NoDynamicThresholding: 8 | def __call__(self, uncond, cond, scale): 9 | return uncond + scale * (cond - uncond) 10 | 11 | 12 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 13 | if order - 1 > i: 14 | raise ValueError(f"Order {order} too high for step {i}") 15 | 16 | def fn(tau): 17 | prod = 1.0 18 | for k in range(order): 19 | if j == k: 20 | continue 21 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 22 | return prod 23 | 24 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 25 | 26 | 27 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 28 | if not eta: 29 | return sigma_to, 0.0 30 | sigma_up = torch.minimum( 31 | sigma_to, 32 | eta 33 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 34 | ) 35 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 36 | return sigma_down, sigma_up 37 | 38 | 39 | def to_d(x, sigma, denoised): 40 | return (x - denoised) / append_dims(sigma, x.ndim) 41 | 42 | 43 | def to_neg_log_sigma(sigma): 44 | return sigma.log().neg() 45 | 46 | 47 | def to_sigma(neg_log_sigma): 48 | return neg_log_sigma.neg().exp() 49 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, 
self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | context=c.get("crossattn", None), 32 | y=c.get("vector", None), 33 | **kwargs, 34 | ) 35 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + 
self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 
71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/nvsadapter/__init__.py -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/canny/api.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | class CannyInference: 5 | 6 | def __init__(self, low_threshold, high_threshold): 7 | self.low_threshold = low_threshold 8 | self.high_threshold = high_threshold 9 | 10 | def forward_each(self, img): 11 | numpy_img = img.numpy().astype(np.uint8).transpose(1, 2, 0) 12 | return cv2.Canny(numpy_img, self.low_threshold, self.high_threshold)[np.newaxis, :, :] 13 | 14 | def __call__(self, images): 15 | preds = [] 16 | for image in images: 17 | preds.append(self.forward_each(image)) 18 | return np.stack(preds) -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/hed/api.py: -------------------------------------------------------------------------------- 1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0. 2 | # Please use this implementation in your products 3 | # This implementation may produce slightly different results from Saining Xie's official implementations, 4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR) 6 | # and in this way it works better for gradio's RGB protocol 7 | 8 | import os 9 | import cv2 10 | import torch 11 | import numpy as np 12 | 13 | from einops import rearrange 14 | 15 | 16 | class DoubleConvBlock(torch.nn.Module): 17 | def __init__(self, input_channel, output_channel, layer_number): 18 | super().__init__() 19 | self.convs = torch.nn.Sequential() 20 | self.convs.append( 21 | torch.nn.Conv2d( 22 | in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1 23 | ) 24 | ) 25 | for i in range(1, layer_number): 26 | self.convs.append( 27 | torch.nn.Conv2d( 28 | in_channels=output_channel, 29 | out_channels=output_channel, 30 | kernel_size=(3, 3), 31 | stride=(1, 1), 32 | padding=1, 33 | ) 34 | ) 35 | self.projection = torch.nn.Conv2d( 36 | in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0 37 | ) 38 | 39 | def __call__(self, x, down_sampling=False): 40 | h = x 41 | if down_sampling: 42 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) 43 | for conv in self.convs: 44 | h = conv(h) 45 | h = torch.nn.functional.relu(h) 46 | return h, self.projection(h) 47 | 48 | 49 | class ControlNetHED_Apache2(torch.nn.Module): 50 | def __init__(self): 51 | super().__init__() 52 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) 53 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) 54 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) 55 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) 56 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) 57 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) 58 | 59 | def __call__(self, x): 60 | h = x - self.norm 61 | h, projection1 = self.block1(h) 62 | h, projection2 = self.block2(h, down_sampling=True) 63 | h, projection3 = self.block3(h, down_sampling=True) 64 | h, projection4 = self.block4(h, down_sampling=True) 65 | h, projection5 = self.block5(h, down_sampling=True) 66 | return projection1, projection2, projection3, projection4, projection5 67 | 68 | 69 | class HEDdetector(torch.nn.Module): 70 | def __init__(self): 71 | super(HEDdetector, self).__init__() 72 | modelpath = os.path.join("./checkpoints", "ControlNetHED.pth") 73 | self.netNetwork = ControlNetHED_Apache2().float().eval() 74 | self.netNetwork.load_state_dict(torch.load(modelpath, map_location="cpu")) 75 | 76 | def forward_each(self, input_image): 77 | assert input_image.ndim == 3 78 | input_image = input_image.permute(1, 2, 0) 79 | H, W, C = input_image.shape 80 | with torch.no_grad(): 81 | image_hed = rearrange(input_image, "h w c -> 1 c h w") 82 | edges = self.netNetwork(image_hed) 83 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 84 | edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] 85 | edges = np.stack(edges, axis=2) 86 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 87 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 88 | return edge[np.newaxis, :, :] 89 | 90 | def forward(self, images): 91 | preds = [] 92 | for input_image in images: 93 | preds.append(self.forward_each(input_image)) 94 | return np.stack(preds) 95 | 96 | 97 | def nms(x, t, s): 98 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 99 | 
100 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 101 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 102 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 103 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 104 | 105 | y = np.zeros_like(x) 106 | 107 | for f in [f1, f2, f3, f4]: 108 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 109 | 110 | z = np.zeros_like(y, dtype=np.uint8) 111 | z[y > t] = 255 112 | return z 113 | -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/lora/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/nvsadapter/lora/__init__.py -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/lora/safe_open.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pure python version of Safetensors safe_open 3 | From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282 4 | """ 5 | 6 | import json 7 | import mmap 8 | import os 9 | 10 | import torch 11 | 12 | 13 | class SafetensorsWrapper: 14 | def __init__(self, metadata, tensors): 15 | self._metadata = metadata 16 | self._tensors = tensors 17 | 18 | def metadata(self): 19 | return self._metadata 20 | 21 | def keys(self): 22 | return self._tensors.keys() 23 | 24 | def get_tensor(self, k): 25 | return self._tensors[k] 26 | 27 | 28 | DTYPES = { 29 | "F32": torch.float32, 30 | "F16": torch.float16, 31 | "BF16": torch.bfloat16, 32 | } 33 | 34 | 35 | def create_tensor(storage, info, offset): 36 | dtype = DTYPES[info["dtype"]] 37 | shape = info["shape"] 38 | start, stop = info["data_offsets"] 39 | return ( 40 | torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8) 41 | .view(dtype=dtype) 42 | .reshape(shape) 43 | ) 44 | 45 | 46 | def safe_open(filename, framework="pt", device="cpu"): 47 | if framework != "pt": 48 | raise ValueError("`framework` must be 'pt'") 49 | 50 | with open(filename, mode="r", encoding="utf8") as file_obj: 51 | with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m: 52 | header = m.read(8) 53 | n = int.from_bytes(header, "little") 54 | metadata_bytes = m.read(n) 55 | metadata = json.loads(metadata_bytes) 56 | 57 | size = os.stat(filename).st_size 58 | storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped() 59 | offset = n + 8 60 | 61 | return SafetensorsWrapper( 62 | metadata=metadata.get("__metadata__", {}), 63 | tensors={ 64 | name: create_tensor(storage, info, offset).to(device) 65 | for name, info in metadata.items() 66 | if name != "__metadata__" 67 | }, 68 | ) -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/lora/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/nvsadapter/lora/utils.py -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/nvsadapter/midas/__init__.py 
-------------------------------------------------------------------------------- /sgm/modules/nvsadapter/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/nvsadapter/e18786d7bab844eefba60805aded4fb5460895b2/sgm/modules/nvsadapter/midas/midas/__init__.py -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = 
kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /sgm/modules/nvsadapter/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from sgm.modules.diffusionmodules.denoiser import DiscreteDenoiser, append_dims 4 | from sgm.modules.diffusionmodules.sampling import EulerEDMSampler 5 | from sgm.modules.diffusionmodules.wrappers import IdentityWrapper 6 | from sgm.modules.diffusionmodules.guiders import VanillaCFG 7 | 8 | 9 | class NVSAdapterDiscreteDenoiser(DiscreteDenoiser): 10 | def __call__(self, network, input, sigma, cond, ucg_mask=None): 11 | # follows the original implementation but adding ucg_mask for forwarding 12 | # for training, ucg_mask is None since CondDrop will cover the unconditional cases 13 | sigma = self.possibly_quantize_sigma(sigma) 14 | sigma_shape = sigma.shape 15 | sigma = append_dims(sigma, input.ndim) 16 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 17 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 18 | return network(input * c_in, c_noise, cond, ucg_mask=ucg_mask) * c_out + input * c_skip 19 | 20 | 21 | class NVSAdapterEulerEDMSampler(EulerEDMSampler): 22 | def denoise(self, x, denoiser, sigma, cond, uc): 23 | batch_size = x.shape[0] 24 | # first #bsz elements are conditional forward 25 | # last #bsz elements are unconditional forward 26 | ucg_mask = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)]).to(device=x.device, dtype=torch.bool) 27 | denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), ucg_mask=ucg_mask) 28 | denoised = self.guider(denoised, sigma) 29 | return denoised 30 | 31 | 32 | class NVSAdapterWrapper(IdentityWrapper): 33 | def forward( 34 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 35 | ) -> torch.Tensor: 36 | txt_context = {k[len("txt/"):]: v for k, v in c.items() if k.startswith("txt/")} 37 | image_context = {k[len("image/"):]: v for k, v in c.items() if k.startswith("image/")} 38 | rays_context = {k[len("ray/"):]: v for k, v in c.items() if k.startswith("ray/")} 39 | support_latents = {k[len("support_latents/"):]: v for k, v in c.items() if k.startswith("support_latents/")} 40 | control_context = {k[len("control/"):]: v for k, v in c.items() if k.startswith("control/")} 41 | 42 | assert rays_context is not None, "ray conditions are required" 43 | assert support_latents is not None, "support latents are required" 44 | 45 | image_emb = image_context.get("crossattn", None) 46 | txt_emb = txt_context.get("crossattn", None) 47 | 48 | rays_emb = rays_context.get("concat", None) 49 | support_latents = support_latents.get("concat", None) 50 | control_context = 
control_context.get("concat", None) 51 | 52 | return self.diffusion_model( 53 | x, 54 | timesteps=t, 55 | control_context=control_context, 56 | txt_context=txt_emb, 57 | image_context=image_emb, 58 | rays_context=rays_emb, 59 | support_latents=support_latents, 60 | **kwargs, 61 | ) 62 | 63 | 64 | class NVSAdapterCFG(VanillaCFG): 65 | def prepare_inputs(self, x, s, c, uc): 66 | c_out = dict() 67 | 68 | for k in c: 69 | if k in ["image/crossattn", "txt/crossattn", "ray/concat", "support_latents/concat", "control/concat"]: 70 | c_out[k] = torch.cat((uc[k], c[k]), 0) 71 | else: 72 | assert c[k] == uc[k] 73 | c_out[k] = c[k] 74 | 75 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 76 | --------------------------------------------------------------------------------
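The guided denoising step used by NVSAdapterCFG above is inherited unchanged from VanillaCFG with NoDynamicThresholding, so at sampling time it reduces to one batch-doubled forward pass (unconditional conditioning concatenated before conditional in prepare_inputs) followed by uncond + scale * (cond - uncond). The short sketch below isolates that combine step on dummy tensors; the helper name, tensor shapes, and guidance scale are illustrative assumptions, not values taken from the repository.

import torch

def cfg_combine(denoised: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors VanillaCFG.__call__ with NoDynamicThresholding: prepare_inputs
    # stacks the unconditional batch first and the conditional batch second,
    # so chunk(2) recovers the two halves in that order.
    x_u, x_c = denoised.chunk(2)
    return x_u + scale * (x_c - x_u)

# Illustrative check with placeholder latent shapes (doubled batch of 2, 4 latent channels).
denoised = torch.randn(4, 4, 32, 32)  # layout: [unconditional | conditional]
guided = cfg_combine(denoised, scale=3.0)
print(guided.shape)  # torch.Size([2, 4, 32, 32])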