├── threestudio
│   ├── utils
│   │   ├── GAN
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   ├── distribution.py
│   │   │   ├── util.py
│   │   │   ├── discriminator.py
│   │   │   ├── mobilenet.py
│   │   │   ├── attention.py
│   │   │   └── network_util.py
│   │   ├── __init__.py
│   │   ├── perceptual
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── perceptual.py
│   │   ├── loss.py
│   │   ├── typing.py
│   │   ├── rasterize.py
│   │   ├── base.py
│   │   ├── config.py
│   │   ├── misc.py
│   │   └── callbacks.py
│   ├── data
│   │   └── __init__.py
│   ├── models
│   │   ├── exporters
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── mesh_exporter.py
│   │   ├── guidance
│   │   │   └── __init__.py
│   │   ├── background
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── solid_color_background.py
│   │   │   ├── textured_background.py
│   │   │   └── neural_environment_map_background.py
│   │   ├── geometry
│   │   │   ├── __init__.py
│   │   │   ├── custom_mesh.py
│   │   │   ├── base.py
│   │   │   └── volume_grid.py
│   │   ├── prompt_processors
│   │   │   ├── __init__.py
│   │   │   ├── dummy_prompt_processor.py
│   │   │   ├── deepfloyd_prompt_processor.py
│   │   │   └── stable_diffusion_prompt_processor.py
│   │   ├── __init__.py
│   │   ├── renderers
│   │   │   ├── __init__.py
│   │   │   ├── deferred_volume_renderer.py
│   │   │   ├── base.py
│   │   │   ├── patch_renderer.py
│   │   │   ├── nvdiff_rasterizer.py
│   │   │   └── gan_volume_renderer.py
│   │   ├── materials
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── hybrid_rgb_latent_material.py
│   │   │   ├── sd_latent_adapter_material.py
│   │   │   ├── neural_radiance_material.py
│   │   │   ├── no_material.py
│   │   │   ├── diffuse_with_point_light_material.py
│   │   │   └── pbr_material.py
│   │   ├── estimators.py
│   │   └── isosurface.py
│   ├── systems
│   │   ├── __init__.py
│   │   └── utils.py
│   └── __init__.py
├── load
│   ├── tets
│   │   ├── 128_tets.npz
│   │   ├── 32_tets.npz
│   │   ├── 64_tets.npz
│   │   └── generate_tets.py
│   ├── images
│   │   └── firekeeper.jpg
│   ├── lights
│   │   ├── bsdf_256_256.bin
│   │   └── mud_road_puresky_1k.hdr
│   ├── zero123
│   │   ├── download.sh
│   │   └── sd-objaverse-finetune-c_concat-256.yaml
│   └── shapes
│       └── README.md
├── run.sh
├── docker
│   ├── compose.yaml
│   └── Dockerfile
├── .gitignore
├── configs
│   ├── sweetdreamer-stage2.yaml
│   └── sweetdreamer-stage1.yaml
└── README.md
/threestudio/utils/GAN/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/threestudio/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import uncond
2 | 
--------------------------------------------------------------------------------
/threestudio/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 | 
--------------------------------------------------------------------------------
/threestudio/models/exporters/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base, mesh_exporter
2 | 
--------------------------------------------------------------------------------
/threestudio/systems/__init__.py:
--------------------------------------------------------------------------------
1 | from .
import ( 2 | sweetdreamer, 3 | ) 4 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual import PerceptualLoss 2 | -------------------------------------------------------------------------------- /load/tets/128_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/tets/128_tets.npz -------------------------------------------------------------------------------- /load/tets/32_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/tets/32_tets.npz -------------------------------------------------------------------------------- /load/tets/64_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/tets/64_tets.npz -------------------------------------------------------------------------------- /load/images/firekeeper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/images/firekeeper.jpg -------------------------------------------------------------------------------- /load/lights/bsdf_256_256.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/lights/bsdf_256_256.bin -------------------------------------------------------------------------------- /load/lights/mud_road_puresky_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/SweetDreamer/HEAD/load/lights/mud_road_puresky_1k.hdr -------------------------------------------------------------------------------- /threestudio/models/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | controlnet_guidance, 3 | deep_floyd_guidance, 4 | stable_diffusion_guidance, 5 | ) 6 | -------------------------------------------------------------------------------- /threestudio/models/background/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | neural_environment_map_background, 4 | solid_color_background, 5 | textured_background, 6 | ) 7 | -------------------------------------------------------------------------------- /threestudio/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | custom_mesh, 4 | implicit_sdf, 5 | implicit_volume, 6 | tetrahedra_sdf_grid, 7 | volume_grid, 8 | ) 9 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | deepfloyd_prompt_processor, 4 | dummy_prompt_processor, 5 | stable_diffusion_prompt_processor, 6 | ) 7 | -------------------------------------------------------------------------------- /threestudio/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | background, 3 | exporters, 4 | geometry, 5 | guidance, 6 | materials, 7 | prompt_processors, 8 | renderers, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | deferred_volume_renderer, 4 | gan_volume_renderer, 5 | nerf_volume_renderer, 6 | neus_volume_renderer, 7 | nvdiff_rasterizer, 8 | patch_renderer, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/materials/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | diffuse_with_point_light_material, 4 | hybrid_rgb_latent_material, 5 | neural_radiance_material, 6 | no_material, 7 | pbr_material, 8 | sd_latent_adapter_material, 9 | ) 10 | -------------------------------------------------------------------------------- /load/zero123/download.sh: -------------------------------------------------------------------------------- 1 | # wget https://huggingface.co/cvlab/zero123-weights/resolve/main/105000.ckpt 2 | # mv 105000.ckpt zero123-original.ckpt 3 | wget https://zero123.cs.columbia.edu/assets/zero123-xl.ckpt 4 | # Download stable_zero123.ckpt from https://huggingface.co/stabilityai/stable-zero123 5 | -------------------------------------------------------------------------------- /threestudio/models/renderers/deferred_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.renderers.base import VolumeRenderer 8 | 9 | 10 | class DeferredVolumeRenderer(VolumeRenderer): 11 | pass 12 | -------------------------------------------------------------------------------- /threestudio/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _tensor_size(t): 5 | return t.size()[1] * t.size()[2] * t.size()[3] 6 | 7 | 8 | def tv_loss(x): 9 | batch_size = x.size()[0] 10 | h_x = x.size()[2] 11 | w_x = x.size()[3] 12 | count_h = _tensor_size(x[:, :, 1:, :]) 13 | count_w = _tensor_size(x[:, :, :, 1:]) 14 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, : h_x - 1, :]), 2).sum() 15 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, : w_x - 1]), 2).sum() 16 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size 17 | -------------------------------------------------------------------------------- /load/shapes/README.md: -------------------------------------------------------------------------------- 1 | # Shape Credits 2 | 3 | - `animal.obj` - Ido Richardson 4 | - `hand_prismatic.obj` - Ido Richardson 5 | - `potion.obj` - Ido Richardson 6 | - `blub.obj` - [Keenan's 3D Model Repository](https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/) 7 | - `nascar.obj` - [Princeton ModelNet](https://modelnet.cs.princeton.edu/) 8 | - `cabin.obj` - [Princeton ModelNet](https://modelnet.cs.princeton.edu/) 9 | - `teddy.obj` - [Gal Metzer](https://galmetzer.github.io/) 10 | - `human.obj` - [TurboSquid](https://www.turbosquid.com/3d-models/3d-model-character-base/524860) 11 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/dummy_prompt_processor.py: 
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 | 
5 | import threestudio
6 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt
7 | from threestudio.utils.misc import cleanup
8 | from threestudio.utils.typing import *
9 | 
10 | 
11 | @threestudio.register("dummy-prompt-processor")
12 | class DummyPromptProcessor(PromptProcessor):
13 |     @dataclass
14 |     class Config(PromptProcessor.Config):
15 |         pretrained_model_name_or_path: str = ""
16 |         prompt: str = ""
17 | 
18 |     cfg: Config
19 | 
--------------------------------------------------------------------------------
/threestudio/models/background/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | import threestudio
9 | from threestudio.utils.base import BaseModule
10 | from threestudio.utils.typing import *
11 | 
12 | 
13 | class BaseBackground(BaseModule):
14 |     @dataclass
15 |     class Config(BaseModule.Config):
16 |         pass
17 | 
18 |     cfg: Config
19 | 
20 |     def configure(self):
21 |         pass
22 | 
23 |     def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
24 |         raise NotImplementedError
25 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | python launch.py --config configs/sweetdreamer-stage1.yaml --train --gpu 0 \
2 |     system.prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \
3 |     system.cmm_prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \
4 |     tag=einstein
5 | 
6 | python launch.py --config configs/sweetdreamer-stage2.yaml --train --gpu 0 \
7 |     system.prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \
8 |     system.cmm_prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \
9 |     tag=einstein
--------------------------------------------------------------------------------
/threestudio/models/materials/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | import threestudio
9 | from threestudio.utils.base import BaseModule
10 | from threestudio.utils.typing import *
11 | 
12 | 
13 | class BaseMaterial(BaseModule):
14 |     @dataclass
15 |     class Config(BaseModule.Config):
16 |         pass
17 | 
18 |     cfg: Config
19 |     requires_normal: bool = False
20 |     requires_tangent: bool = False
21 | 
22 |     def configure(self):
23 |         pass
24 | 
25 |     def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]:
26 |         raise NotImplementedError
27 | 
28 |     def export(self, *args, **kwargs) -> Dict[str, Any]:
29 |         return {}
30 | 
--------------------------------------------------------------------------------
/docker/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 |   threestudio:
3 |     build:
4 |       context: ../
5 |       dockerfile: docker/Dockerfile
6 |       args:
7 |         # you can set environment variables, otherwise default values will be used
8 |         USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
9 |         GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
10 |         UID: ${HOST_UID:-1000} # export
HOST_UID=$(id -u) 11 | GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g) 12 | shm_size: '4gb' 13 | environment: 14 | NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error` 15 | tty: true 16 | volumes: 17 | - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | capabilities: [gpu] 24 | -------------------------------------------------------------------------------- /threestudio/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def generator_loss(discriminator, inputs, reconstructions, cond=None): 6 | if cond is None: 7 | logits_fake = discriminator(reconstructions.contiguous()) 8 | else: 9 | logits_fake = discriminator( 10 | torch.cat((reconstructions.contiguous(), cond), dim=1) 11 | ) 12 | g_loss = -torch.mean(logits_fake) 13 | return g_loss 14 | 15 | 16 | def hinge_d_loss(logits_real, logits_fake): 17 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 18 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 19 | d_loss = 0.5 * (loss_real + loss_fake) 20 | return d_loss 21 | 22 | 23 | def discriminator_loss(discriminator, inputs, reconstructions, cond=None): 24 | if cond is None: 25 | logits_real = discriminator(inputs.contiguous().detach()) 26 | logits_fake = discriminator(reconstructions.contiguous().detach()) 27 | else: 28 | logits_real = discriminator( 29 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 30 | ) 31 | logits_fake = discriminator( 32 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 33 | ) 34 | d_loss = hinge_d_loss(logits_real, logits_fake).mean() 35 | return d_loss 36 | -------------------------------------------------------------------------------- /threestudio/models/materials/hybrid_rgb_latent_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import 
threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("hybrid-rgb-latent-material") 16 | class HybridRGBLatentMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | requires_normal: bool = True 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.requires_normal = self.cfg.requires_normal 27 | 28 | def forward( 29 | self, features: Float[Tensor, "B ... Nf"], **kwargs 30 | ) -> Float[Tensor, "B ... Nc"]: 31 | assert ( 32 | features.shape[-1] == self.cfg.n_output_dims 33 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 34 | color = features 35 | color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3]) 36 | return color 37 | -------------------------------------------------------------------------------- /threestudio/models/materials/sd_latent_adapter_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("sd-latent-adapter-material") 14 | class StableDiffusionLatentAdapterMaterial(BaseMaterial): 15 | @dataclass 16 | class Config(BaseMaterial.Config): 17 | pass 18 | 19 | cfg: Config 20 | 21 | def configure(self) -> None: 22 | adapter = nn.Parameter( 23 | torch.as_tensor( 24 | [ 25 | # R G B 26 | [0.298, 0.207, 0.208], # L1 27 | [0.187, 0.286, 0.173], # L2 28 | [-0.158, 0.189, 0.264], # L3 29 | [-0.184, -0.271, -0.473], # L4 30 | ] 31 | ) 32 | ) 33 | self.register_parameter("adapter", adapter) 34 | 35 | def forward( 36 | self, features: Float[Tensor, "B ... 4"], **kwargs 37 | ) -> Float[Tensor, "B ... 3"]: 38 | assert features.shape[-1] == 4 39 | color = features @ self.adapter 40 | color = (color + 1) / 2 41 | color = color.clamp(0.0, 1.0) 42 | return color 43 | -------------------------------------------------------------------------------- /threestudio/__init__.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | __version__ = "0.2.3" 3 | 4 | 5 | def register(name): 6 | def decorator(cls): 7 | if name in __modules__: 8 | raise ValueError( 9 | f"Module {name} already exists! Names of extensions conflict!" 
10 | ) 11 | else: 12 | __modules__[name] = cls 13 | return cls 14 | 15 | return decorator 16 | 17 | 18 | def find(name): 19 | if ":" in name: 20 | main_name, sub_name = name.split(":") 21 | if "," in sub_name: 22 | name_list = sub_name.split(",") 23 | else: 24 | name_list = [sub_name] 25 | name_list.append(main_name) 26 | NewClass = type( 27 | f"{main_name}.{sub_name}", 28 | tuple([__modules__[name] for name in name_list]), 29 | {}, 30 | ) 31 | return NewClass 32 | return __modules__[name] 33 | 34 | 35 | ### grammar sugar for logging utilities ### 36 | import logging 37 | 38 | logger = logging.getLogger("pytorch_lightning") 39 | 40 | from pytorch_lightning.utilities.rank_zero import ( 41 | rank_zero_debug, 42 | rank_zero_info, 43 | rank_zero_only, 44 | ) 45 | 46 | debug = rank_zero_debug 47 | info = rank_zero_info 48 | 49 | 50 | @rank_zero_only 51 | def warn(*args, **kwargs): 52 | logger.warn(*args, **kwargs) 53 | 54 | 55 | from . import data, models, systems 56 | -------------------------------------------------------------------------------- /threestudio/models/exporters/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import threestudio 4 | from threestudio.models.background.base import BaseBackground 5 | from threestudio.models.geometry.base import BaseImplicitGeometry 6 | from threestudio.models.materials.base import BaseMaterial 7 | from threestudio.utils.base import BaseObject 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @dataclass 12 | class ExporterOutput: 13 | save_name: str 14 | save_type: str 15 | params: Dict[str, Any] 16 | 17 | 18 | class Exporter(BaseObject): 19 | @dataclass 20 | class Config(BaseObject.Config): 21 | save_video: bool = False 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | @dataclass 32 | class SubModules: 33 | geometry: BaseImplicitGeometry 34 | material: BaseMaterial 35 | background: BaseBackground 36 | 37 | self.sub_modules = SubModules(geometry, material, background) 38 | 39 | @property 40 | def geometry(self) -> BaseImplicitGeometry: 41 | return self.sub_modules.geometry 42 | 43 | @property 44 | def material(self) -> BaseMaterial: 45 | return self.sub_modules.material 46 | 47 | @property 48 | def background(self) -> BaseBackground: 49 | return self.sub_modules.background 50 | 51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 52 | raise NotImplementedError 53 | 54 | 55 | @threestudio.register("dummy-exporter") 56 | class DummyExporter(Exporter): 57 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 58 | # DummyExporter does not export anything 59 | return [] 60 | -------------------------------------------------------------------------------- /threestudio/models/background/solid_color_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("solid-color-background") 14 | class SolidColorBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | color: Tuple = (1.0, 1.0, 1.0) 19 | learned: 
bool = False 20 | random_aug: bool = False 21 | random_aug_prob: float = 0.5 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.env_color: Float[Tensor, "Nc"] 27 | if self.cfg.learned: 28 | self.env_color = nn.Parameter( 29 | torch.as_tensor(self.cfg.color, dtype=torch.float32) 30 | ) 31 | else: 32 | self.register_buffer( 33 | "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32) 34 | ) 35 | 36 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 37 | color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 38 | dirs 39 | ) * self.env_color.to(dirs) 40 | if ( 41 | self.training 42 | and self.cfg.random_aug 43 | and random.random() < self.cfg.random_aug_prob 44 | ): 45 | # use random background color with probability random_aug_prob 46 | color = color * 0 + ( # prevent checking for unused parameters in DDP 47 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 48 | .to(dirs) 49 | .expand(*dirs.shape[:-1], -1) 50 | ) 51 | return color 52 | -------------------------------------------------------------------------------- /threestudio/models/background/textured_background.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.utils.ops import get_activation 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("textured-background") 14 | class TexturedBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | height: int = 64 19 | width: int = 64 20 | color_activation: str = "sigmoid" 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | self.texture = nn.Parameter( 26 | torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width)) 27 | ) 28 | 29 | def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]: 30 | x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2] 31 | xy = (x**2 + y**2) ** 0.5 32 | u = torch.atan2(xy, z) / torch.pi 33 | v = torch.atan2(y, x) / (torch.pi * 2) + 0.5 34 | uv = torch.stack([u, v], -1) 35 | return uv 36 | 37 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: 38 | dirs_shape = dirs.shape[:-1] 39 | uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1])) 40 | uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample 41 | uv = uv.reshape(1, -1, 1, 2) 42 | color = ( 43 | F.grid_sample( 44 | self.texture, 45 | uv, 46 | mode="bilinear", 47 | padding_mode="reflection", 48 | align_corners=False, 49 | ) 50 | .reshape(self.cfg.n_output_dims, -1) 51 | .T.reshape(*dirs_shape, self.cfg.n_output_dims) 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /threestudio/models/materials/neural_radiance_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from 
threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-radiance-material") 16 | class NeuralRadianceMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | input_feature_dims: int = 8 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "FullyFusedMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | 33 | cfg: Config 34 | 35 | def configure(self) -> None: 36 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 37 | self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore 38 | self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) 39 | 40 | def forward( 41 | self, 42 | features: Float[Tensor, "*B Nf"], 43 | viewdirs: Float[Tensor, "*B 3"], 44 | **kwargs, 45 | ) -> Float[Tensor, "*B 3"]: 46 | # viewdirs and normals must be normalized before passing to this function 47 | viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 48 | viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) 49 | network_inp = torch.cat( 50 | [features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 51 | ) 52 | color = self.network(network_inp).view(*features.shape[:-1], 3) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /load/tets/generate_tets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | 12 | import numpy as np 13 | 14 | """ 15 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 16 | to generate a tet grid 17 | 1) Download, compile and run Quartet as described in the link above. 
Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 18 | 2) Run the function below to generate a file `cube_32_tet.tet` 19 | """ 20 | 21 | 22 | def generate_tetrahedron_grid_file(res=32, root=".."): 23 | frac = 1.0 / res 24 | command = f"cd {root}; ./quartet meshes/cube.obj {frac} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj" 25 | os.system(command) 26 | 27 | 28 | """ 29 | This code segment shows how to convert from a quartet .tet file to compressed npz file 30 | """ 31 | 32 | 33 | def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"): 34 | file1 = open(quartetfile, "r") 35 | header = file1.readline() 36 | numvertices = int(header.split(" ")[1]) 37 | numtets = int(header.split(" ")[2]) 38 | print(numvertices, numtets) 39 | 40 | # load vertices 41 | vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices) 42 | print(vertices.shape) 43 | 44 | # load indices 45 | indices = np.loadtxt( 46 | quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets 47 | ) 48 | print(indices.shape) 49 | 50 | np.savez_compressed(npzfile, vertices=vertices, indices=indices) 51 | 52 | 53 | root = "/home/gyc/quartet" 54 | for res in [300, 350, 400]: 55 | generate_tetrahedron_grid_file(res, root) 56 | convert_from_quartet_to_npz( 57 | os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets" 58 | ) 59 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/cvpaperchallenge/Ascender 3 | # https://github.com/nerfstudio-project/nerfstudio 4 | 5 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 6 | 7 | ARG USER_NAME=dreamer 8 | ARG GROUP_NAME=dreamers 9 | ARG UID=1000 10 | ARG GID=1000 11 | 12 | # Set compute capability for nerfacc and tiny-cuda-nn 13 | # See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build 14 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX" 15 | ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60 16 | # Speed-up build for RTX 30xx 17 | # ENV TORCH_CUDA_ARCH_LIST="8.6" 18 | # ENV TCNN_CUDA_ARCHITECTURES=86 19 | # Speed-up build for RTX 40xx 20 | # ENV TORCH_CUDA_ARCH_LIST="8.9" 21 | # ENV TCNN_CUDA_ARCHITECTURES=89 22 | 23 | ENV CUDA_HOME=/usr/local/cuda 24 | ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH} 25 | ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 26 | ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH} 27 | 28 | # apt install by root user 29 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 30 | build-essential \ 31 | curl \ 32 | git \ 33 | libegl1-mesa-dev \ 34 | libgl1-mesa-dev \ 35 | libgles2-mesa-dev \ 36 | libglib2.0-0 \ 37 | libsm6 \ 38 | libxext6 \ 39 | libxrender1 \ 40 | python-is-python3 \ 41 | python3.10-dev \ 42 | python3-pip \ 43 | wget \ 44 | && rm -rf /var/lib/apt/lists/* 45 | 46 | # Change user to non-root user 47 | RUN groupadd -g ${GID} ${GROUP_NAME} \ 48 | && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME} 49 | USER ${USER_NAME} 50 | 51 | RUN pip install --upgrade pip setuptools ninja 52 | RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 53 | # Install nerfacc and tiny-cuda-nn before installing requirements.txt 54 | # because these two installations are time consuming and error prone 55 | RUN pip install 
git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 56 | RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch 57 | 58 | COPY requirements.txt /tmp 59 | RUN cd /tmp && pip install -r requirements.txt 60 | WORKDIR /home/${USER_NAME}/threestudio 61 | -------------------------------------------------------------------------------- /threestudio/models/materials/no_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("no-material") 16 | class NoMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | input_feature_dims: Optional[int] = None 22 | mlp_network_config: Optional[dict] = None 23 | requires_normal: bool = False 24 | 25 | cfg: Config 26 | 27 | def configure(self) -> None: 28 | self.use_network = False 29 | if ( 30 | self.cfg.input_feature_dims is not None 31 | and self.cfg.mlp_network_config is not None 32 | ): 33 | self.network = get_mlp( 34 | self.cfg.input_feature_dims, 35 | self.cfg.n_output_dims, 36 | self.cfg.mlp_network_config, 37 | ) 38 | self.use_network = True 39 | self.requires_normal = self.cfg.requires_normal 40 | 41 | def forward( 42 | self, features: Float[Tensor, "B ... Nf"], **kwargs 43 | ) -> Float[Tensor, "B ... Nc"]: 44 | if not self.use_network: 45 | assert ( 46 | features.shape[-1] == self.cfg.n_output_dims 47 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 
48 | color = get_activation(self.cfg.color_activation)(features) 49 | else: 50 | color = self.network(features.view(-1, features.shape[-1])).view( 51 | *features.shape[:-1], self.cfg.n_output_dims 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | 56 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 57 | color = self(features, **kwargs).clamp(0, 1) 58 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels" 59 | if color.shape[-1] > 3: 60 | threestudio.warn( 61 | "Output color has >3 channels, treating the first 3 as RGB" 62 | ) 63 | return {"albedo": color[..., :3]} 64 | -------------------------------------------------------------------------------- /threestudio/models/renderers/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.utils.base import BaseModule 12 | from threestudio.utils.typing import * 13 | 14 | 15 | class Renderer(BaseModule): 16 | @dataclass 17 | class Config(BaseModule.Config): 18 | radius: float = 1.0 19 | 20 | cfg: Config 21 | 22 | def configure( 23 | self, 24 | geometry: BaseImplicitGeometry, 25 | material: BaseMaterial, 26 | background: BaseBackground, 27 | ) -> None: 28 | # keep references to submodules using namedtuple, avoid being registered as modules 29 | @dataclass 30 | class SubModules: 31 | geometry: BaseImplicitGeometry 32 | material: BaseMaterial 33 | background: BaseBackground 34 | 35 | self.sub_modules = SubModules(geometry, material, background) 36 | 37 | # set up bounding box 38 | self.bbox: Float[Tensor, "2 3"] 39 | self.register_buffer( 40 | "bbox", 41 | torch.as_tensor( 42 | [ 43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 45 | ], 46 | dtype=torch.float32, 47 | ), 48 | ) 49 | 50 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 51 | raise NotImplementedError 52 | 53 | @property 54 | def geometry(self) -> BaseImplicitGeometry: 55 | return self.sub_modules.geometry 56 | 57 | @property 58 | def material(self) -> BaseMaterial: 59 | return self.sub_modules.material 60 | 61 | @property 62 | def background(self) -> BaseBackground: 63 | return self.sub_modules.background 64 | 65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None: 66 | self.sub_modules.geometry = geometry 67 | 68 | def set_material(self, material: BaseMaterial) -> None: 69 | self.sub_modules.material = material 70 | 71 | def set_background(self, background: BaseBackground) -> None: 72 | self.sub_modules.background = background 73 | 74 | 75 | class VolumeRenderer(Renderer): 76 | pass 77 | 78 | 79 | class Rasterizer(Renderer): 80 | pass 81 | -------------------------------------------------------------------------------- /threestudio/models/background/neural_environment_map_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from 
threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-environment-map-background") 16 | class NeuralEnvironmentMapBackground(BaseBackground): 17 | @dataclass 18 | class Config(BaseBackground.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "VanillaMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | random_aug: bool = False 33 | random_aug_prob: float = 0.5 34 | eval_color: Optional[Tuple[float, float, float]] = None 35 | 36 | cfg: Config 37 | 38 | def configure(self) -> None: 39 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 40 | self.network = get_mlp( 41 | self.encoding.n_output_dims, 42 | self.cfg.n_output_dims, 43 | self.cfg.mlp_network_config, 44 | ) 45 | 46 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 47 | if not self.training and self.cfg.eval_color is not None: 48 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 49 | dirs 50 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs) 51 | # viewdirs must be normalized before passing to this function 52 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 53 | dirs_embd = self.encoding(dirs.view(-1, 3)) 54 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) 55 | color = get_activation(self.cfg.color_activation)(color) 56 | if ( 57 | self.training 58 | and self.cfg.random_aug 59 | and random.random() < self.cfg.random_aug_prob 60 | ): 61 | # use random background color with probability random_aug_prob 62 | color = color * 0 + ( # prevent checking for unused parameters in DDP 63 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 64 | .to(dirs) 65 | .expand(*dirs.shape[:-1], -1) 66 | ) 67 | return color 68 | -------------------------------------------------------------------------------- /threestudio/utils/rasterize.py: -------------------------------------------------------------------------------- 1 | import nvdiffrast.torch as dr 2 | import torch 3 | 4 | from threestudio.utils.typing import * 5 | 6 | 7 | class NVDiffRasterizerContext: 8 | def __init__(self, context_type: str, device: torch.device) -> None: 9 | self.device = device 10 | self.ctx = self.initialize_context(context_type, device) 11 | 12 | def initialize_context( 13 | self, context_type: str, device: torch.device 14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: 15 | if context_type == "gl": 16 | return dr.RasterizeGLContext(device=device) 17 | elif context_type == "cuda": 18 | return dr.RasterizeCudaContext(device=device) 19 | else: 20 | raise ValueError(f"Unknown rasterizer context type: {context_type}") 21 | 22 | def vertex_transform( 23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] 24 | ) -> Float[Tensor, "B Nv 4"]: 25 | verts_homo = torch.cat( 26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 27 | ) 28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) 29 | 30 | def rasterize( 31 | self, 32 | pos: Float[Tensor, "B Nv 4"], 33 | tri: Integer[Tensor, "Nf 3"], 34 | resolution: Union[int, Tuple[int, int]], 35 | ): 36 | # rasterize in instance mode (single topology) 37 | return dr.rasterize(self.ctx, 
pos.float(), tri.int(), resolution, grad_db=True) 38 | 39 | def rasterize_one( 40 | self, 41 | pos: Float[Tensor, "Nv 4"], 42 | tri: Integer[Tensor, "Nf 3"], 43 | resolution: Union[int, Tuple[int, int]], 44 | ): 45 | # rasterize one single mesh under a single viewpoint 46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) 47 | return rast[0], rast_db[0] 48 | 49 | def antialias( 50 | self, 51 | color: Float[Tensor, "B H W C"], 52 | rast: Float[Tensor, "B H W 4"], 53 | pos: Float[Tensor, "B Nv 4"], 54 | tri: Integer[Tensor, "Nf 3"], 55 | ) -> Float[Tensor, "B H W C"]: 56 | return dr.antialias(color.float(), rast, pos.float(), tri.int()) 57 | 58 | def interpolate( 59 | self, 60 | attr: Float[Tensor, "B Nv C"], 61 | rast: Float[Tensor, "B H W 4"], 62 | tri: Integer[Tensor, "Nf 3"], 63 | rast_db=None, 64 | diff_attrs=None, 65 | ) -> Float[Tensor, "B H W C"]: 66 | return dr.interpolate( 67 | attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs 68 | ) 69 | 70 | def interpolate_one( 71 | self, 72 | attr: Float[Tensor, "Nv C"], 73 | rast: Float[Tensor, "B H W 4"], 74 | tri: Integer[Tensor, "Nf 3"], 75 | rast_db=None, 76 | diff_attrs=None, 77 | ) -> Float[Tensor, "B H W C"]: 78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) 79 | -------------------------------------------------------------------------------- /threestudio/systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import threestudio 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | 33 | 34 | def parse_optimizer(config, model): 35 | if hasattr(config, "params"): 36 | params = [ 37 | {"params": get_parameters(model, name), "name": name, **args} 38 | for name, args in config.params.items() 39 | ] 40 | threestudio.debug(f"Specify optimizer params: {config.params}") 41 | else: 42 | params = model.parameters() 43 | if config.name in ["FusedAdam"]: 44 | import apex 45 | 46 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 47 | elif config.name in ["Adan"]: 48 | from threestudio.systems import optimizers 49 | 50 | optim = getattr(optimizers, config.name)(params, **config.args) 51 | else: 52 | optim = getattr(torch.optim, config.name)(params, **config.args) 53 | return optim 54 | 55 | 56 | def parse_scheduler_to_instance(config, optimizer): 57 | if config.name == "ChainedScheduler": 58 | schedulers = [ 59 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 60 | ] 61 | scheduler = lr_scheduler.ChainedScheduler(schedulers) 62 | elif config.name == "Sequential": 63 | schedulers = [ 64 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 65 | ] 66 | scheduler = lr_scheduler.SequentialLR( 67 | optimizer, schedulers, milestones=config.milestones 68 | ) 69 | else: 70 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) 71 | return 
scheduler 72 | 73 | 74 | def parse_scheduler(config, optimizer): 75 | interval = config.get("interval", "epoch") 76 | assert interval in ["epoch", "step"] 77 | if config.name == "SequentialLR": 78 | scheduler = { 79 | "scheduler": lr_scheduler.SequentialLR( 80 | optimizer, 81 | [ 82 | parse_scheduler(conf, optimizer)["scheduler"] 83 | for conf in config.schedulers 84 | ], 85 | milestones=config.milestones, 86 | ), 87 | "interval": interval, 88 | } 89 | elif config.name == "ChainedScheduler": 90 | scheduler = { 91 | "scheduler": lr_scheduler.ChainedScheduler( 92 | [ 93 | parse_scheduler(conf, optimizer)["scheduler"] 94 | for conf in config.schedulers 95 | ] 96 | ), 97 | "interval": interval, 98 | } 99 | else: 100 | scheduler = { 101 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 102 | "interval": interval, 103 | } 104 | return scheduler 105 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 
81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /load/zero123/sd-objaverse-finetune-c_concat-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: extern.ldm_zero123.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | 19 | scheduler_config: # 10000 warmup steps 20 | target: extern.ldm_zero123.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: [ 100 ] 23 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 24 | f_start: [ 1.e-6 ] 25 | f_max: [ 1. ] 26 | f_min: [ 1. ] 27 | 28 | unet_config: 29 | target: extern.ldm_zero123.modules.diffusionmodules.openaimodel.UNetModel 30 | params: 31 | image_size: 32 # unused 32 | in_channels: 8 33 | out_channels: 4 34 | model_channels: 320 35 | attention_resolutions: [ 4, 2, 1 ] 36 | num_res_blocks: 2 37 | channel_mult: [ 1, 2, 4, 4 ] 38 | num_heads: 8 39 | use_spatial_transformer: True 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: True 43 | legacy: False 44 | 45 | first_stage_config: 46 | target: extern.ldm_zero123.models.autoencoder.AutoencoderKL 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: extern.ldm_zero123.modules.encoders.modules.FrozenCLIPImageEmbedder 70 | 71 | 72 | # data: 73 | # target: extern.ldm_zero123.data.simple.ObjaverseDataModuleFromConfig 74 | # params: 75 | # root_dir: 'views_whole_sphere' 76 | # batch_size: 192 77 | # num_workers: 16 78 | # total_view: 4 79 | # train: 80 | # validation: False 81 | # image_transforms: 82 | # size: 256 83 | 84 | # validation: 85 | # validation: True 86 | # image_transforms: 87 | # size: 256 88 | 89 | 90 | # lightning: 91 | # find_unused_parameters: false 92 | # metrics_over_trainsteps_checkpoint: True 93 | # modelcheckpoint: 94 | # params: 95 | # every_n_train_steps: 5000 96 | # callbacks: 97 | # image_logger: 98 | # target: main.ImageLogger 99 | # params: 100 | # batch_frequency: 500 101 | # max_images: 32 102 | # increase_log_steps: False 103 | # log_first_step: True 104 | # log_images_kwargs: 105 | # use_ema_scope: False 106 | # 
inpaint: False 107 | # plot_progressive_rows: False 108 | # plot_diffusion_rows: False 109 | # N: 32 110 | # unconditional_scale: 3.0 111 | # unconditional_label: [""] 112 | 113 | # trainer: 114 | # benchmark: True 115 | # val_check_interval: 5000000 # really sorry 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 1 118 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/deepfloyd_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from diffusers import IFPipeline 8 | from transformers import T5EncoderModel, T5Tokenizer 9 | 10 | import threestudio 11 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 12 | from threestudio.utils.misc import cleanup 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("deep-floyd-prompt-processor") 17 | class DeepFloydPromptProcessor(PromptProcessor): 18 | @dataclass 19 | class Config(PromptProcessor.Config): 20 | pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" 21 | 22 | cfg: Config 23 | 24 | ### these functions are unused, kept for debugging ### 25 | def configure_text_encoder(self) -> None: 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | self.text_encoder = T5EncoderModel.from_pretrained( 28 | self.cfg.pretrained_model_name_or_path, 29 | subfolder="text_encoder", 30 | load_in_8bit=True, 31 | variant="8bit", 32 | device_map="auto", 33 | ) # FIXME: behavior of auto device map in multi-GPU training 34 | self.pipe = IFPipeline.from_pretrained( 35 | self.cfg.pretrained_model_name_or_path, 36 | text_encoder=self.text_encoder, # pass the previously instantiated 8bit text encoder 37 | unet=None, 38 | ) 39 | 40 | def destroy_text_encoder(self) -> None: 41 | del self.text_encoder 42 | del self.pipe 43 | cleanup() 44 | 45 | def get_text_embeddings( 46 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 47 | ) -> Tuple[Float[Tensor, "B 77 4096"], Float[Tensor, "B 77 4096"]]: 48 | text_embeddings, uncond_text_embeddings = self.pipe.encode_prompt( 49 | prompt=prompt, negative_prompt=negative_prompt, device=self.device 50 | ) 51 | return text_embeddings, uncond_text_embeddings 52 | 53 | ### 54 | 55 | @staticmethod 56 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 57 | max_length = 77 58 | tokenizer = T5Tokenizer.from_pretrained( 59 | pretrained_model_name_or_path, subfolder="tokenizer" 60 | ) 61 | text_encoder = T5EncoderModel.from_pretrained( 62 | pretrained_model_name_or_path, 63 | subfolder="text_encoder", 64 | torch_dtype=torch.float16, # suppress warning 65 | load_in_8bit=True, 66 | variant="8bit", 67 | device_map="auto", 68 | ) 69 | with torch.no_grad(): 70 | text_inputs = tokenizer( 71 | prompts, 72 | padding="max_length", 73 | max_length=max_length, 74 | truncation=True, 75 | add_special_tokens=True, 76 | return_tensors="pt", 77 | ) 78 | text_input_ids = text_inputs.input_ids 79 | attention_mask = text_inputs.attention_mask 80 | text_embeddings = text_encoder( 81 | text_input_ids.to(text_encoder.device), 82 | attention_mask=attention_mask.to(text_encoder.device), 83 | ) 84 | text_embeddings = text_embeddings[0] 85 | 86 | for prompt, embedding in zip(prompts, text_embeddings): 87 | torch.save( 88 | embedding, 89 | os.path.join( 90 | cache_dir, 91 | 
f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 92 | ), 93 | ) 94 | 95 | del text_encoder 96 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoTokenizer, CLIPTextModel 8 | 9 | import threestudio 10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 11 | from threestudio.utils.misc import cleanup 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("stable-diffusion-prompt-processor") 16 | class StableDiffusionPromptProcessor(PromptProcessor): 17 | @dataclass 18 | class Config(PromptProcessor.Config): 19 | pass 20 | 21 | cfg: Config 22 | 23 | ### these functions are unused, kept for debugging ### 24 | def configure_text_encoder(self) -> None: 25 | self.tokenizer = AutoTokenizer.from_pretrained( 26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 27 | ) 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | self.text_encoder = CLIPTextModel.from_pretrained( 30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 31 | ).to(self.device) 32 | 33 | for p in self.text_encoder.parameters(): 34 | p.requires_grad_(False) 35 | 36 | def destroy_text_encoder(self) -> None: 37 | del self.tokenizer 38 | del self.text_encoder 39 | cleanup() 40 | 41 | def get_text_embeddings( 42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 44 | if isinstance(prompt, str): 45 | prompt = [prompt] 46 | if isinstance(negative_prompt, str): 47 | negative_prompt = [negative_prompt] 48 | # Tokenize text and get embeddings 49 | tokens = self.tokenizer( 50 | prompt, 51 | padding="max_length", 52 | max_length=self.tokenizer.model_max_length, 53 | return_tensors="pt", 54 | ) 55 | uncond_tokens = self.tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | 62 | with torch.no_grad(): 63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] 64 | uncond_text_embeddings = self.text_encoder( 65 | uncond_tokens.input_ids.to(self.device) 66 | )[0] 67 | 68 | return text_embeddings, uncond_text_embeddings 69 | 70 | ### 71 | 72 | @staticmethod 73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 74 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | pretrained_model_name_or_path, subfolder="tokenizer" 77 | ) 78 | text_encoder = CLIPTextModel.from_pretrained( 79 | pretrained_model_name_or_path, 80 | subfolder="text_encoder", 81 | device_map="auto", 82 | ) 83 | 84 | with torch.no_grad(): 85 | tokens = tokenizer( 86 | prompts, 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | return_tensors="pt", 90 | ) 91 | text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] 92 | 93 | for prompt, embedding in zip(prompts, text_embeddings): 94 | torch.save( 95 | embedding, 96 | os.path.join( 97 | cache_dir, 98 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 99 | ), 100 | ) 101 | 102 | del text_encoder 103 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | 163 | .threestudio_cache 164 | .vscode 165 | outputs 166 | *.png 167 | debug 168 | *.obj 169 | pretrained_models 170 | raw_data 171 | data 172 | wandb 173 | *.ply 174 | *.txt 175 | archive 176 | configs_debug 177 | run_debug.sh -------------------------------------------------------------------------------- /configs/sweetdreamer-stage2.yaml: -------------------------------------------------------------------------------- 1 | name: "sweetdreamer-stage2" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | use_timestamp: false 6 | 7 | data_type: "random-camera-datamodule" 8 | data: 9 | batch_size: 2 10 | width: 512 11 | height: 512 12 | camera_distance_range: [1.4, 1.6] 13 | # fovy_range: [30, 40] # more precise 14 | fovy_range: [44, 46] # more precise 15 | elevation_range: [-10, 45] 16 | camera_perturb: 0. 17 | center_perturb: 0. 18 | up_perturb: 0. 19 | eval_camera_distance: 1.5 20 | eval_fovy_deg: 35 21 | n_val_views: 4 22 | 23 | system_type: "prolificdreamer-system" 24 | system: 25 | start_app: 0 26 | before_start_app_weight: 0.0 27 | app_weight: 1.0 28 | end_app: 100000 29 | after_end_app_weight: 1.0 30 | 31 | start_cmm: 0 32 | before_start_cmm_weight: 20.0 33 | cmm_weight: 20.0 34 | end_cmm: 0 35 | after_end_cmm_weight: 0.5 36 | 37 | geometry_type: "implicit-volume" 38 | geometry: 39 | radius: 1.0 40 | isosurface: true 41 | isosurface_method: "mc-cpu" 42 | isosurface_resolution: 64 43 | isosurface_threshold: "auto" 44 | isosurface_coarse_to_fine: false 45 | normal_type: "finite_difference" 46 | 47 | density_bias: "blob_magic3d" 48 | density_activation: softplus 49 | density_blob_scale: 10.
50 | density_blob_std: 0.5 51 | 52 | pos_encoding_config: 53 | otype: HashGrid 54 | n_levels: 16 55 | n_features_per_level: 2 56 | log2_hashmap_size: 19 57 | base_resolution: 16 58 | per_level_scale: 1.447269237440378 # max resolution 4096 59 | 60 | material_type: "no-material" 61 | material: 62 | n_output_dims: 3 63 | color_activation: sigmoid 64 | requires_normal: true 65 | 66 | background_type: "neural-environment-map-background" 67 | background: 68 | color_activation: sigmoid 69 | random_aug: true 70 | random_aug_prob: 0.5 71 | 72 | renderer_type: "nerf-volume-renderer" 73 | renderer: 74 | radius: ${system.geometry.radius} 75 | num_samples_per_ray: 512 76 | return_comp_normal: false 77 | 78 | prompt_processor_type: "stable-diffusion-prompt-processor" 79 | prompt_processor: 80 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 81 | negative_prompt: "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting" 82 | prompt: ??? 83 | front_threshold: 45. 84 | back_threshold: 45. 85 | view_dependent_prompt_front: true 86 | 87 | cmm_prompt_processor_type: "stable-diffusion-prompt-processor" 88 | cmm_prompt_processor: 89 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 90 | prompt: ??? 91 | 92 | guidance_type: "stable-diffusion-guidance" 93 | guidance: 94 | cmm_pretrained_model_name_or_path: "ckpt/checkpoint-latent-tiny" 95 | cmm_schedule_pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 96 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 97 | guidance_scale: 100 98 | min_step_percent: 0.02 99 | max_step_percent: [0, 0.98, 0.5, 4001] 100 | 101 | loggers: 102 | wandb: 103 | enable: false 104 | project: "threestudio" 105 | name: None 106 | 107 | loss: 108 | lambda_orient: [0, 10., 1000., 4000] 109 | lambda_sparsity: 0. 110 | lambda_opaque: 0. 111 | lambda_z_variance: 0. 
112 | optimizer: 113 | name: AdamW 114 | args: 115 | betas: [0.9, 0.99] 116 | eps: 1.e-15 117 | params: 118 | geometry.encoding: 119 | lr: 0.01 120 | geometry.density_network: 121 | lr: 0.001 122 | geometry.feature_network: 123 | lr: 0.001 124 | background: 125 | lr: 0.001 126 | guidance: 127 | lr: 0.0001 128 | 129 | trainer: 130 | max_steps: 10000 131 | log_every_n_steps: 1 132 | num_sanity_val_steps: 0 133 | val_check_interval: 200 134 | enable_progress_bar: true 135 | precision: 16-mixed 136 | 137 | checkpoint: 138 | save_last: true 139 | save_top_k: -1 140 | every_n_train_steps: ${trainer.max_steps} 141 | -------------------------------------------------------------------------------- /threestudio/models/renderers/patch_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.background.base import BaseBackground 8 | from threestudio.models.geometry.base import BaseImplicitGeometry 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.renderers.base import VolumeRenderer 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("patch-renderer") 15 | class PatchRenderer(VolumeRenderer): 16 | @dataclass 17 | class Config(VolumeRenderer.Config): 18 | patch_size: int = 128 19 | base_renderer_type: str = "" 20 | base_renderer: Optional[VolumeRenderer.Config] = None 21 | global_detach: bool = False 22 | global_downsample: int = 4 23 | 24 | cfg: Config 25 | 26 | def configure( 27 | self, 28 | geometry: BaseImplicitGeometry, 29 | material: BaseMaterial, 30 | background: BaseBackground, 31 | ) -> None: 32 | self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( 33 | self.cfg.base_renderer, 34 | geometry=geometry, 35 | material=material, 36 | background=background, 37 | ) 38 | 39 | def forward( 40 | self, 41 | rays_o: Float[Tensor, "B H W 3"], 42 | rays_d: Float[Tensor, "B H W 3"], 43 | light_positions: Float[Tensor, "B 3"], 44 | bg_color: Optional[Tensor] = None, 45 | **kwargs 46 | ) -> Dict[str, Float[Tensor, "..."]]: 47 | B, H, W, _ = rays_o.shape 48 | 49 | if self.base_renderer.training: 50 | downsample = self.cfg.global_downsample 51 | global_rays_o = torch.nn.functional.interpolate( 52 | rays_o.permute(0, 3, 1, 2), 53 | (H // downsample, W // downsample), 54 | mode="bilinear", 55 | ).permute(0, 2, 3, 1) 56 | global_rays_d = torch.nn.functional.interpolate( 57 | rays_d.permute(0, 3, 1, 2), 58 | (H // downsample, W // downsample), 59 | mode="bilinear", 60 | ).permute(0, 2, 3, 1) 61 | out_global = self.base_renderer( 62 | global_rays_o, global_rays_d, light_positions, bg_color, **kwargs 63 | ) 64 | 65 | PS = self.cfg.patch_size 66 | patch_x = torch.randint(0, W - PS, (1,)).item() 67 | patch_y = torch.randint(0, H - PS, (1,)).item() 68 | patch_rays_o = rays_o[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 69 | patch_rays_d = rays_d[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 70 | out = self.base_renderer( 71 | patch_rays_o, patch_rays_d, light_positions, bg_color, **kwargs 72 | ) 73 | 74 | valid_patch_key = [] 75 | for key in out: 76 | if torch.is_tensor(out[key]): 77 | if len(out[key].shape) == len(out["comp_rgb"].shape): 78 | if out[key][..., 0].shape == out["comp_rgb"][..., 0].shape: 79 | valid_patch_key.append(key) 80 | for key in valid_patch_key: 81 | out_global[key] = F.interpolate( 82 | out_global[key].permute(0, 3, 1, 
2), (H, W), mode="bilinear" 83 | ).permute(0, 2, 3, 1) 84 | if self.cfg.global_detach: 85 | out_global[key] = out_global[key].detach() 86 | out_global[key][ 87 | :, patch_y : patch_y + PS, patch_x : patch_x + PS 88 | ] = out[key] 89 | out = out_global 90 | else: 91 | out = self.base_renderer( 92 | rays_o, rays_d, light_positions, bg_color, **kwargs 93 | ) 94 | 95 | return out 96 | 97 | def update_step( 98 | self, epoch: int, global_step: int, on_load_weights: bool = False 99 | ) -> None: 100 | self.base_renderer.update_step(epoch, global_step, on_load_weights) 101 | 102 | def train(self, mode=True): 103 | return self.base_renderer.train(mode) 104 | 105 | def eval(self): 106 | return self.base_renderer.eval() 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SweetDreamer: Aligning Geometric Priors in 2D Diffusion for Consistent Text-to-3D (ICLR 2024) 2 | 3 | #####
[Weiyu Li](https://wyysf-98.github.io/), [Rui Chen](https://aruichen.github.io/), [Xuelin Chen](https://xuelin-chen.github.io/), [Ping Tan](https://ece.hkust.edu.hk/pingtan)
4 | 5 |
6 |
7 |
[Project Page](https://sweetdreamer3d.github.io/) | [ArXiv](https://arxiv.org/abs/2310.02596) | [Paper]() | [Video]()
10 | All code and checkpoints will be released in the next few days, sorry for the delay due to some permission issues :( 🏗️ 🚧 🔨
11 | 12 | ### Important: This repo is under construction. Finally, I got time to sort it out :) really sorry for the delay 13 | 14 | - [x] Release the reorganized code 15 | - [ ] Release the pretrained model (tiny-version) 16 | - [ ] Release the full model 17 | 18 | ## Prerequisites 19 | 20 | ### Setup environment (Install threestudio) 21 | 22 | **This part is the same as the original threestudio. Skip it if you have already installed the environment.** 23 | 24 | See [installation.md](https://github.com/threestudio-project/threestudio/blob/main/docs/installation.md) for additional information, including installation via Docker. 25 | 26 | - You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed. 27 | - Install `Python >= 3.8`. 28 | - (Optional, Recommended) Create a virtual environment: 29 | 30 | ```sh 31 | python3 -m virtualenv venv 32 | . venv/bin/activate 33 | 34 | # Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x. 35 | # For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later. 36 | python3 -m pip install --upgrade pip 37 | ``` 38 | 39 | - Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine. 40 | 41 | ```sh 42 | # torch1.12.1+cu113 43 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 44 | # or torch2.0.0+cu118 45 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 46 | ``` 47 | 48 | - (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions: 49 | 50 | ```sh 51 | pip install ninja 52 | ``` 53 | 54 | - Install dependencies: 55 | 56 | ```sh 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | 61 | ### Download the pretrained CCM model (TBD) 62 | 63 | ```sh 64 | sh download.sh 65 | ``` 66 | 67 | 68 | 69 | ## Quick demo 70 | 71 | ```sh 72 | python launch.py --config configs/sweetdreamer-stage1.yaml --train --gpu 0 \ 73 | system.prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \ 74 | system.cmm_prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \ 75 | tag=einstein 76 | 77 | python launch.py --config configs/sweetdreamer-stage2.yaml --train --gpu 0 \ 78 | system.prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \ 79 | system.cmm_prompt_processor.prompt="Albert Einstein with grey suit is riding a bicycle" \ 80 | tag=einstein 81 | ``` 82 | 83 | 84 | ## Acknowledgement 85 | 86 | This code is built on these amazing open-source projects: 87 | - [threestudio-project](https://github.com/threestudio-project/threestudio?tab=readme-ov-file) 88 | - [diffusers](https://github.com/huggingface/diffusers) 89 | - [stable-diffusion](https://stability.ai/news/stable-diffusion-public-release) 90 | - [deep-floyd](https://github.com/deep-floyd/IF?tab=readme-ov-file) 91 | 92 | We also thank Jianxiong Pan and Feipeng Tian for their help with the data and the GPU server. 93 | 94 | ## Citation 95 | 96 | If you find our work useful for your research, please consider citing it using the following BibTeX entry.
97 | 98 | ```BibTeX 99 | @article{sweetdreamer, 100 | author = {Weiyu Li and Rui Chen and Xuelin Chen and Ping Tan}, 101 | title = {SweetDreamer: Aligning Geometric Priors in 2D Diffusion for Consistent Text-to-3D}, 102 | journal = {arxiv:2310.02596}, 103 | year = {2023}, 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /configs/sweetdreamer-stage1.yaml: -------------------------------------------------------------------------------- 1 | name: "sweetdreamer-stage1" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | use_timestamp: false 6 | 7 | data_type: "random-camera-datamodule" 8 | data: 9 | batch_size: 8 10 | width: 64 11 | height: 64 12 | camera_distance_range: [1.4, 1.6] 13 | fovy_range: [35, 45] 14 | elevation_range: [-10, 45] 15 | camera_perturb: 0. 16 | center_perturb: 0. 17 | up_perturb: 0. 18 | eval_camera_distance: 1.5 19 | eval_fovy_deg: 35 20 | n_val_views: 4 21 | 22 | system_type: "sweetdreamer-system" 23 | system: 24 | start_app: 1500 25 | before_start_app_weight: 0.0 26 | app_weight: 1.0 27 | end_app: 100000 28 | after_end_app_weight: 1.0 29 | 30 | start_cmm: 0 31 | before_start_cmm_weight: 0.0 32 | cmm_weight: 20.0 33 | end_cmm: 1500 34 | after_end_cmm_weight: 1.0 35 | 36 | geometry_type: "implicit-volume" 37 | geometry: 38 | radius: 2.0 39 | isosurface: true 40 | isosurface_method: "mc-cpu" 41 | isosurface_resolution: 64 42 | isosurface_threshold: "auto" 43 | isosurface_coarse_to_fine: false 44 | normal_type: "finite_difference" 45 | 46 | density_bias: "blob_magic3d" 47 | density_activation: softplus 48 | density_blob_scale: 10. 49 | density_blob_std: 0.5 50 | 51 | pos_encoding_config: 52 | otype: HashGrid 53 | n_levels: 16 54 | n_features_per_level: 2 55 | log2_hashmap_size: 19 56 | base_resolution: 16 57 | per_level_scale: 1.447269237440378 # max resolution 4096 58 | 59 | material_type: "no-material" 60 | material: 61 | n_output_dims: 3 62 | color_activation: sigmoid 63 | requires_normal: true 64 | 65 | background_type: "neural-environment-map-background" 66 | background: 67 | color_activation: sigmoid 68 | random_aug: true 69 | random_aug_prob: 0.5 70 | 71 | renderer_type: "nerf-volume-renderer" 72 | renderer: 73 | radius: ${system.geometry.radius} 74 | num_samples_per_ray: 512 75 | return_comp_normal: true 76 | 77 | prompt_processor_type: "deep-floyd-prompt-processor" 78 | prompt_processor: 79 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 80 | negative_prompt: "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting" 81 | prompt: ??? 82 | front_threshold: 45. 83 | back_threshold: 45. 84 | view_dependent_prompt_front: true 85 | 86 | cmm_prompt_processor_type: "stable-diffusion-prompt-processor" 87 | cmm_prompt_processor: 88 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 89 | negative_prompt: "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting" 90 | prompt: ??? 91 | front_threshold: 45. 92 | back_threshold: 45. 
93 | view_dependent_prompt_front: true 94 | 95 | guidance_type: "deep-floyd-guidance" 96 | guidance: 97 | cmm_pretrained_model_name_or_path: "ckpt/checkpoint-latent-tiny" 98 | cmm_schedule_pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 99 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 100 | guidance_scale: 20. 101 | min_step_percent: 0.02 102 | max_step_percent: [0, 0.98, 0.5, 5001] 103 | 104 | loggers: 105 | wandb: 106 | enable: false 107 | project: "threestudio" 108 | name: None 109 | 110 | loss: 111 | lambda_orient: [0, 10., 1000., 5000] 112 | lambda_sparsity: 0. 113 | lambda_opaque: 0. 114 | lambda_z_variance: 0. 115 | optimizer: 116 | name: AdamW 117 | args: 118 | betas: [0.9, 0.99] 119 | eps: 1.e-15 120 | params: 121 | geometry.encoding: 122 | lr: 0.01 123 | geometry.density_network: 124 | lr: 0.001 125 | geometry.feature_network: 126 | lr: 0.001 127 | background: 128 | lr: 0.001 129 | guidance: 130 | lr: 0.0001 131 | 132 | trainer: 133 | max_steps: 6000 134 | log_every_n_steps: 1 135 | num_sanity_val_steps: 0 136 | val_check_interval: 200 137 | enable_progress_bar: true 138 | precision: 16-mixed 139 | 140 | checkpoint: 141 | save_last: true 142 | save_top_k: -1 143 | every_n_train_steps: ${trainer.max_steps} 144 | -------------------------------------------------------------------------------- /threestudio/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from threestudio.utils.config import parse_structured 7 | from threestudio.utils.misc import get_device, load_module_weights 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def do_update_step_end(self, epoch: int, global_step: int): 39 | for attr in self.__dir__(): 40 | if attr.startswith("_"): 41 | continue 42 | try: 43 | module = getattr(self, attr) 44 | except: 45 | continue # ignore attributes like property, which can't be retrived using getattr? 
46 | if isinstance(module, Updateable): 47 | module.do_update_step_end(epoch, global_step) 48 | self.update_step_end(epoch, global_step) 49 | 50 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 51 | # override this method to implement custom update logic 52 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 53 | # as the models and tensors are not guarenteed to be on the same device 54 | pass 55 | 56 | def update_step_end(self, epoch: int, global_step: int): 57 | pass 58 | 59 | 60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 61 | if isinstance(module, Updateable): 62 | module.do_update_step(epoch, global_step) 63 | 64 | 65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 66 | if isinstance(module, Updateable): 67 | module.do_update_step_end(epoch, global_step) 68 | 69 | 70 | class BaseObject(Updateable): 71 | @dataclass 72 | class Config: 73 | pass 74 | 75 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 76 | 77 | def __init__( 78 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 79 | ) -> None: 80 | super().__init__() 81 | self.cfg = parse_structured(self.Config, cfg) 82 | self.device = get_device() 83 | self.configure(*args, **kwargs) 84 | 85 | def configure(self, *args, **kwargs) -> None: 86 | pass 87 | 88 | 89 | class BaseModule(nn.Module, Updateable): 90 | @dataclass 91 | class Config: 92 | weights: Optional[str] = None 93 | 94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 95 | 96 | def __init__( 97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 | ) -> None: 99 | super().__init__() 100 | self.cfg = parse_structured(self.Config, cfg) 101 | self.device = get_device() 102 | self.configure(*args, **kwargs) 103 | if self.cfg.weights is not None: 104 | # format: path/to/weights:module_name 105 | weights_path, module_name = self.cfg.weights.split(":") 106 | state_dict, epoch, global_step = load_module_weights( 107 | weights_path, module_name=module_name, map_location="cpu" 108 | ) 109 | self.load_state_dict(state_dict) 110 | self.do_update_step( 111 | epoch, global_step, on_load_weights=True 112 | ) # restore states 113 | # dummy tensor to indicate model state 114 | self._dummy: Float[Tensor, "..."] 115 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 116 | 117 | def configure(self, *args, **kwargs) -> None: 118 | pass 119 | -------------------------------------------------------------------------------- /threestudio/models/estimators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | try: 4 | from typing import Literal 5 | except ImportError: 6 | from typing_extensions import Literal 7 | 8 | import torch 9 | from nerfacc.data_specs import RayIntervals 10 | from nerfacc.estimators.base import AbstractEstimator 11 | from nerfacc.pdf import importance_sampling, searchsorted 12 | from nerfacc.volrend import render_transmittance_from_density 13 | from torch import Tensor 14 | 15 | 16 | class ImportanceEstimator(AbstractEstimator): 17 | def __init__( 18 | self, 19 | ) -> None: 20 | super().__init__() 21 | 22 | @torch.no_grad() 23 | def sampling( 24 | self, 25 | prop_sigma_fns: List[Callable], 26 | prop_samples: List[int], 27 | num_samples: int, 28 | # rendering options 29 | n_rays: int, 30 | near_plane: 
float, 31 | far_plane: float, 32 | sampling_type: Literal["uniform", "lindisp"] = "uniform", 33 | # training options 34 | stratified: bool = False, 35 | requires_grad: bool = False, 36 | ) -> Tuple[Tensor, Tensor]: 37 | """Sampling with CDFs from proposal networks. 38 | 39 | Args: 40 | prop_sigma_fns: Proposal network evaluate functions. It should be a list 41 | of functions that take in samples {t_starts (n_rays, n_samples), 42 | t_ends (n_rays, n_samples)} and returns the post-activation densities 43 | (n_rays, n_samples). 44 | prop_samples: Number of samples to draw from each proposal network. Should 45 | be the same length as `prop_sigma_fns`. 46 | num_samples: Number of samples to draw in the end. 47 | n_rays: Number of rays. 48 | near_plane: Near plane. 49 | far_plane: Far plane. 50 | sampling_type: Sampling type. Either "uniform" or "lindisp". Default to 51 | "lindisp". 52 | stratified: Whether to use stratified sampling. Default to `False`. 53 | 54 | Returns: 55 | A tuple of {Tensor, Tensor}: 56 | 57 | - **t_starts**: The starts of the samples. Shape (n_rays, num_samples). 58 | - **t_ends**: The ends of the samples. Shape (n_rays, num_samples). 59 | 60 | """ 61 | assert len(prop_sigma_fns) == len(prop_samples), ( 62 | "The number of proposal networks and the number of samples " 63 | "should be the same." 64 | ) 65 | cdfs = torch.cat( 66 | [ 67 | torch.zeros((n_rays, 1), device=self.device), 68 | torch.ones((n_rays, 1), device=self.device), 69 | ], 70 | dim=-1, 71 | ) 72 | intervals = RayIntervals(vals=cdfs) 73 | 74 | for level_fn, level_samples in zip(prop_sigma_fns, prop_samples): 75 | intervals, _ = importance_sampling( 76 | intervals, cdfs, level_samples, stratified 77 | ) 78 | t_vals = _transform_stot( 79 | sampling_type, intervals.vals, near_plane, far_plane 80 | ) 81 | t_starts = t_vals[..., :-1] 82 | t_ends = t_vals[..., 1:] 83 | 84 | with torch.set_grad_enabled(requires_grad): 85 | sigmas = level_fn(t_starts, t_ends) 86 | assert sigmas.shape == t_starts.shape 87 | trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas) 88 | cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1) 89 | 90 | intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified) 91 | t_vals_fine = _transform_stot( 92 | sampling_type, intervals.vals, near_plane, far_plane 93 | ) 94 | 95 | t_vals = torch.cat([t_vals, t_vals_fine], dim=-1) 96 | t_vals, _ = torch.sort(t_vals, dim=-1) 97 | 98 | t_starts_ = t_vals[..., :-1] 99 | t_ends_ = t_vals[..., 1:] 100 | 101 | return t_starts_, t_ends_ 102 | 103 | 104 | def _transform_stot( 105 | transform_type: Literal["uniform", "lindisp"], 106 | s_vals: torch.Tensor, 107 | t_min: torch.Tensor, 108 | t_max: torch.Tensor, 109 | ) -> torch.Tensor: 110 | if transform_type == "uniform": 111 | _contract_fn, _icontract_fn = lambda x: x, lambda x: x 112 | elif transform_type == "lindisp": 113 | _contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x 114 | else: 115 | raise ValueError(f"Unknown transform_type: {transform_type}") 116 | s_min, s_max = _contract_fn(t_min), _contract_fn(t_max) 117 | icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min) 118 | return icontract_fn(s_vals) 119 | -------------------------------------------------------------------------------- /threestudio/models/materials/diffuse_with_point_light_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import 
torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.ops import dot, get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("diffuse-with-point-light-material") 15 | class DiffuseWithPointLightMaterial(BaseMaterial): 16 | @dataclass 17 | class Config(BaseMaterial.Config): 18 | ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) 19 | diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) 20 | ambient_only_steps: int = 1000 21 | diffuse_prob: float = 0.75 22 | textureless_prob: float = 0.5 23 | albedo_activation: str = "sigmoid" 24 | soft_shading: bool = False 25 | 26 | cfg: Config 27 | 28 | def configure(self) -> None: 29 | self.requires_normal = True 30 | 31 | self.ambient_light_color: Float[Tensor, "3"] 32 | self.register_buffer( 33 | "ambient_light_color", 34 | torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), 35 | ) 36 | self.diffuse_light_color: Float[Tensor, "3"] 37 | self.register_buffer( 38 | "diffuse_light_color", 39 | torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), 40 | ) 41 | self.ambient_only = False 42 | 43 | def forward( 44 | self, 45 | features: Float[Tensor, "B ... Nf"], 46 | positions: Float[Tensor, "B ... 3"], 47 | shading_normal: Float[Tensor, "B ... 3"], 48 | light_positions: Float[Tensor, "B ... 3"], 49 | ambient_ratio: Optional[float] = None, 50 | shading: Optional[str] = None, 51 | **kwargs, 52 | ) -> Float[Tensor, "B ... 3"]: 53 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) 54 | 55 | if ambient_ratio is not None: 56 | # if ambient ratio is specified, use it 57 | diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( 58 | self.diffuse_light_color 59 | ) 60 | ambient_light_color = ambient_ratio * torch.ones_like( 61 | self.ambient_light_color 62 | ) 63 | elif self.training and self.cfg.soft_shading: 64 | # otherwise if in training and soft shading is enabled, random a ambient ratio 65 | diffuse_light_color = torch.full_like( 66 | self.diffuse_light_color, random.random() 67 | ) 68 | ambient_light_color = 1.0 - diffuse_light_color 69 | else: 70 | # otherwise use the default fixed values 71 | diffuse_light_color = self.diffuse_light_color 72 | ambient_light_color = self.ambient_light_color 73 | 74 | light_directions: Float[Tensor, "B ... 3"] = F.normalize( 75 | light_positions - positions, dim=-1 76 | ) 77 | diffuse_light: Float[Tensor, "B ... 
3"] = ( 78 | dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color 79 | ) 80 | textureless_color = diffuse_light + ambient_light_color 81 | # clamp albedo to [0, 1] to compute shading 82 | color = albedo.clamp(0.0, 1.0) * textureless_color 83 | 84 | if shading is None: 85 | if self.training: 86 | # adopt the same type of augmentation for the whole batch 87 | if self.ambient_only or random.random() > self.cfg.diffuse_prob: 88 | shading = "albedo" 89 | elif random.random() < self.cfg.textureless_prob: 90 | shading = "textureless" 91 | else: 92 | shading = "diffuse" 93 | else: 94 | if self.ambient_only: 95 | shading = "albedo" 96 | else: 97 | # return shaded color by default in evaluation 98 | shading = "diffuse" 99 | 100 | # multiply by 0 to prevent checking for unused parameters in DDP 101 | if shading == "albedo": 102 | return albedo + textureless_color * 0 103 | elif shading == "textureless": 104 | return albedo * 0 + textureless_color 105 | elif shading == "diffuse": 106 | return color 107 | else: 108 | raise ValueError(f"Unknown shading type {shading}") 109 | 110 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 111 | if global_step < self.cfg.ambient_only_steps: 112 | self.ambient_only = True 113 | else: 114 | self.ambient_only = False 115 | 116 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 117 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( 118 | 0.0, 1.0 119 | ) 120 | return {"albedo": albedo} 121 | -------------------------------------------------------------------------------- /threestudio/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | import threestudio 8 | from threestudio.utils.typing import * 9 | 10 | # ============ Register OmegaConf Recolvers ============= # 11 | OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) 24 | OmegaConf.register_new_resolver("not", lambda s: not s) 25 | OmegaConf.register_new_resolver( 26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 27 | ) 28 | # ======================================================= # 29 | 30 | 31 | def C_max(value: Any) -> float: 32 | if isinstance(value, int) or isinstance(value, float): 33 | pass 34 | else: 35 | value = config_to_primitive(value) 36 | if not isinstance(value, list): 37 | raise TypeError("Scalar specification only supports list, got", type(value)) 38 | if len(value) >= 6: 39 | max_value = value[2] 40 | for i in range(4, len(value), 2): 41 | max_value = max(max_value, value[i]) 42 | value = [value[0], value[1], max_value, value[3]] 43 | if 
len(value) == 3: 44 | value = [0] + value 45 | assert len(value) == 4 46 | start_step, start_value, end_value, end_step = value 47 | value = max(start_value, end_value) 48 | return value 49 | 50 | 51 | @dataclass 52 | class ExperimentConfig: 53 | name: str = "default" 54 | description: str = "" 55 | tag: str = "" 56 | seed: int = 0 57 | use_timestamp: bool = True 58 | timestamp: Optional[str] = None 59 | exp_root_dir: str = "outputs" 60 | 61 | ### these shouldn't be set manually 62 | exp_dir: str = "outputs/default" 63 | trial_name: str = "exp" 64 | trial_dir: str = "outputs/default/exp" 65 | n_gpus: int = 1 66 | ### 67 | 68 | resume: Optional[str] = None 69 | 70 | data_type: str = "" 71 | data: dict = field(default_factory=dict) 72 | 73 | system_type: str = "" 74 | system: dict = field(default_factory=dict) 75 | 76 | # accept pytorch-lightning trainer parameters 77 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 78 | trainer: dict = field(default_factory=dict) 79 | 80 | # accept pytorch-lightning checkpoint callback parameters 81 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 82 | checkpoint: dict = field(default_factory=dict) 83 | 84 | def __post_init__(self): 85 | if not self.tag and not self.use_timestamp: 86 | raise ValueError("Either tag is specified or use_timestamp is True.") 87 | self.trial_name = self.tag 88 | # if resume from an existing config, self.timestamp should not be None 89 | if self.timestamp is None: 90 | self.timestamp = "" 91 | if self.use_timestamp: 92 | if self.n_gpus > 1: 93 | threestudio.warn( 94 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 95 | ) 96 | else: 97 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 98 | self.trial_name += self.timestamp 99 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 100 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 101 | os.makedirs(self.trial_dir, exist_ok=True) 102 | 103 | 104 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: 105 | if from_string: 106 | yaml_confs = [OmegaConf.create(s) for s in yamls] 107 | else: 108 | yaml_confs = [OmegaConf.load(f) for f in yamls] 109 | cli_conf = OmegaConf.from_cli(cli_args) 110 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 111 | OmegaConf.resolve(cfg) 112 | assert isinstance(cfg, DictConfig) 113 | scfg = parse_structured(ExperimentConfig, cfg) 114 | return scfg 115 | 116 | 117 | def config_to_primitive(config, resolve: bool = True) -> Any: 118 | return OmegaConf.to_container(config, resolve=resolve) 119 | 120 | 121 | def dump_config(path: str, config) -> None: 122 | with open(path, "w") as fp: 123 | OmegaConf.save(config=config, f=fp) 124 | 125 | 126 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 127 | scfg = OmegaConf.structured(fields(**cfg)) 128 | return scfg 129 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | from tqdm import tqdm 6 | 7 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 8 | 9 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 10 | 11 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 12 | 13 | 14 | def download(url, local_path, 
chunk_size=1024): 15 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 16 | with requests.get(url, stream=True) as r: 17 | total_size = int(r.headers.get("content-length", 0)) 18 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 19 | with open(local_path, "wb") as f: 20 | for data in r.iter_content(chunk_size=chunk_size): 21 | if data: 22 | f.write(data) 23 | pbar.update(chunk_size) 24 | 25 | 26 | def md5_hash(path): 27 | with open(path, "rb") as f: 28 | content = f.read() 29 | return hashlib.md5(content).hexdigest() 30 | 31 | 32 | def get_ckpt_path(name, root, check=False): 33 | assert name in URL_MAP 34 | path = os.path.join(root, CKPT_MAP[name]) 35 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 36 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 37 | download(URL_MAP[name], path) 38 | md5 = md5_hash(path) 39 | assert md5 == MD5_MAP[name], md5 40 | return path 41 | 42 | 43 | class KeyNotFoundError(Exception): 44 | def __init__(self, cause, keys=None, visited=None): 45 | self.cause = cause 46 | self.keys = keys 47 | self.visited = visited 48 | messages = list() 49 | if keys is not None: 50 | messages.append("Key not found: {}".format(keys)) 51 | if visited is not None: 52 | messages.append("Visited: {}".format(visited)) 53 | messages.append("Cause:\n{}".format(cause)) 54 | message = "\n".join(messages) 55 | super().__init__(message) 56 | 57 | 58 | def retrieve( 59 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 60 | ): 61 | """Given a nested list or dict return the desired value at key expanding 62 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 63 | is done in-place. 64 | 65 | Parameters 66 | ---------- 67 | list_or_dict : list or dict 68 | Possibly nested list or dictionary. 69 | key : str 70 | key/to/value, path like string describing all keys necessary to 71 | consider to get to the desired value. List indices can also be 72 | passed here. 73 | splitval : str 74 | String that defines the delimiter between keys of the 75 | different depth levels in `key`. 76 | default : obj 77 | Value returned if :attr:`key` is not found. 78 | expand : bool 79 | Whether to expand callable nodes on the path or not. 80 | 81 | Returns 82 | ------- 83 | The desired value or if :attr:`default` is not ``None`` and the 84 | :attr:`key` is not found returns ``default``. 85 | 86 | Raises 87 | ------ 88 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 89 | ``None``. 90 | """ 91 | 92 | keys = key.split(splitval) 93 | 94 | success = True 95 | try: 96 | visited = [] 97 | parent = None 98 | last_key = None 99 | for key in keys: 100 | if callable(list_or_dict): 101 | if not expand: 102 | raise KeyNotFoundError( 103 | ValueError( 104 | "Trying to get past callable node with expand=False." 
105 | ), 106 | keys=keys, 107 | visited=visited, 108 | ) 109 | list_or_dict = list_or_dict() 110 | parent[last_key] = list_or_dict 111 | 112 | last_key = key 113 | parent = list_or_dict 114 | 115 | try: 116 | if isinstance(list_or_dict, dict): 117 | list_or_dict = list_or_dict[key] 118 | else: 119 | list_or_dict = list_or_dict[int(key)] 120 | except (KeyError, IndexError, ValueError) as e: 121 | raise KeyNotFoundError(e, keys=keys, visited=visited) 122 | 123 | visited += [key] 124 | # final expansion of retrieved value 125 | if expand and callable(list_or_dict): 126 | list_or_dict = list_or_dict() 127 | parent[last_key] = list_or_dict 128 | except KeyNotFoundError as e: 129 | if default is None: 130 | raise e 131 | else: 132 | list_or_dict = default 133 | success = False 134 | 135 | if not pass_success: 136 | return list_or_dict 137 | else: 138 | return list_or_dict, success 139 | 140 | 141 | if __name__ == "__main__": 142 | config = { 143 | "keya": "a", 144 | "keyb": "b", 145 | "keyc": { 146 | "cc1": 1, 147 | "cc2": 2, 148 | }, 149 | } 150 | from omegaconf import OmegaConf 151 | 152 | config = OmegaConf.create(config) 153 | print(config) 154 | retrieve(config, "keya") 155 | -------------------------------------------------------------------------------- /threestudio/models/materials/pbr_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import envlight 5 | import numpy as np 6 | import nvdiffrast.torch as dr 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import threestudio 12 | from threestudio.models.materials.base import BaseMaterial 13 | from threestudio.utils.ops import get_activation 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("pbr-material") 18 | class PBRMaterial(BaseMaterial): 19 | @dataclass 20 | class Config(BaseMaterial.Config): 21 | material_activation: str = "sigmoid" 22 | environment_texture: str = "load/lights/mud_road_puresky_1k.hdr" 23 | environment_scale: float = 2.0 24 | min_metallic: float = 0.0 25 | max_metallic: float = 0.9 26 | min_roughness: float = 0.08 27 | max_roughness: float = 0.9 28 | use_bump: bool = True 29 | 30 | cfg: Config 31 | 32 | def configure(self) -> None: 33 | self.requires_normal = True 34 | self.requires_tangent = self.cfg.use_bump 35 | 36 | self.light = envlight.EnvLight( 37 | self.cfg.environment_texture, scale=self.cfg.environment_scale 38 | ) 39 | 40 | FG_LUT = torch.from_numpy( 41 | np.fromfile("load/lights/bsdf_256_256.bin", dtype=np.float32).reshape( 42 | 1, 256, 256, 2 43 | ) 44 | ) 45 | self.register_buffer("FG_LUT", FG_LUT) 46 | 47 | def forward( 48 | self, 49 | features: Float[Tensor, "*B Nf"], 50 | viewdirs: Float[Tensor, "*B 3"], 51 | shading_normal: Float[Tensor, "B ... 3"], 52 | tangent: Optional[Float[Tensor, "B ... 
3"]] = None, 53 | **kwargs, 54 | ) -> Float[Tensor, "*B 3"]: 55 | prefix_shape = features.shape[:-1] 56 | 57 | material: Float[Tensor, "*B Nf"] = get_activation(self.cfg.material_activation)( 58 | features 59 | ) 60 | albedo = material[..., :3] 61 | metallic = ( 62 | material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) 63 | + self.cfg.min_metallic 64 | ) 65 | roughness = ( 66 | material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) 67 | + self.cfg.min_roughness 68 | ) 69 | 70 | if self.cfg.use_bump: 71 | assert tangent is not None 72 | # perturb_normal is a delta to the initialization [0, 0, 1] 73 | perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( 74 | [0, 0, 1], dtype=material.dtype, device=material.device 75 | ) 76 | perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) 77 | 78 | # apply normal perturbation in tangent space 79 | bitangent = F.normalize(torch.cross(tangent, shading_normal), dim=-1) 80 | shading_normal = ( 81 | tangent * perturb_normal[..., 0:1] 82 | - bitangent * perturb_normal[..., 1:2] 83 | + shading_normal * perturb_normal[..., 2:3] 84 | ) 85 | shading_normal = F.normalize(shading_normal, dim=-1) 86 | 87 | v = -viewdirs 88 | n_dot_v = (shading_normal * v).sum(-1, keepdim=True) 89 | reflective = n_dot_v * shading_normal * 2 - v 90 | 91 | diffuse_albedo = (1 - metallic) * albedo 92 | 93 | fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) 94 | fg = dr.texture( 95 | self.FG_LUT, 96 | fg_uv.reshape(1, -1, 1, 2).contiguous(), 97 | filter_mode="linear", 98 | boundary_mode="clamp", 99 | ).reshape(*prefix_shape, 2) 100 | F0 = (1 - metallic) * 0.04 + metallic * albedo 101 | specular_albedo = F0 * fg[:, 0:1] + fg[:, 1:2] 102 | 103 | diffuse_light = self.light(shading_normal) 104 | specular_light = self.light(reflective, roughness) 105 | 106 | color = diffuse_albedo * diffuse_light + specular_albedo * specular_light 107 | color = color.clamp(0.0, 1.0) 108 | 109 | return color 110 | 111 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 112 | material: Float[Tensor, "*N Nf"] = get_activation(self.cfg.material_activation)( 113 | features 114 | ) 115 | albedo = material[..., :3] 116 | metallic = ( 117 | material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) 118 | + self.cfg.min_metallic 119 | ) 120 | roughness = ( 121 | material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) 122 | + self.cfg.min_roughness 123 | ) 124 | 125 | out = { 126 | "albedo": albedo, 127 | "metallic": metallic, 128 | "roughness": roughness, 129 | } 130 | 131 | if self.cfg.use_bump: 132 | perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( 133 | [0, 0, 1], dtype=material.dtype, device=material.device 134 | ) 135 | perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) 136 | perturb_normal = (perturb_normal + 1) / 2 137 | out.update( 138 | { 139 | "bump": perturb_normal, 140 | } 141 | ) 142 | 143 | return out 144 | -------------------------------------------------------------------------------- /threestudio/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import math 3 | import os 4 | import re 5 | 6 | import tinycudann as tcnn 7 | import torch 8 | from packaging import version 9 | 10 | from threestudio.utils.config import config_to_primitive 11 | from threestudio.utils.typing import * 12 | 13 | 14 | def parse_version(ver: str): 15 | return version.parse(ver) 16 | 17 | 18 | def get_rank(): 19 | # 
SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 20 | # therefore LOCAL_RANK needs to be checked first 21 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 22 | for key in rank_keys: 23 | rank = os.environ.get(key) 24 | if rank is not None: 25 | return int(rank) 26 | return 0 27 | 28 | 29 | def get_device(): 30 | return torch.device(f"cuda:{get_rank()}") 31 | 32 | 33 | def load_module_weights( 34 | path, module_name=None, ignore_modules=None, map_location=None 35 | ) -> Tuple[dict, int, int]: 36 | if module_name is not None and ignore_modules is not None: 37 | raise ValueError("module_name and ignore_modules cannot be both set") 38 | if map_location is None: 39 | map_location = get_device() 40 | 41 | ckpt = torch.load(path, map_location=map_location) 42 | state_dict = ckpt["state_dict"] 43 | state_dict_to_load = state_dict 44 | 45 | if ignore_modules is not None: 46 | state_dict_to_load = {} 47 | for k, v in state_dict.items(): 48 | ignore = any( 49 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 50 | ) 51 | if ignore: 52 | continue 53 | state_dict_to_load[k] = v 54 | 55 | if module_name is not None: 56 | state_dict_to_load = {} 57 | for k, v in state_dict.items(): 58 | m = re.match(rf"^{module_name}\.(.*)$", k) 59 | if m is None: 60 | continue 61 | state_dict_to_load[m.group(1)] = v 62 | 63 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 64 | 65 | 66 | def C(value: Any, epoch: int, global_step: int, interpolation="linear") -> float: 67 | if isinstance(value, int) or isinstance(value, float): 68 | pass 69 | else: 70 | value = config_to_primitive(value) 71 | if not isinstance(value, list): 72 | raise TypeError("Scalar specification only supports list, got", type(value)) 73 | if len(value) == 3: 74 | value = [0] + value 75 | if len(value) >= 6: 76 | select_i = 3 77 | for i in range(3, len(value) - 2, 2): 78 | if global_step >= value[i]: 79 | select_i = i + 2 80 | if select_i != 3: 81 | start_value, start_step = value[select_i - 3], value[select_i - 2] 82 | else: 83 | start_step, start_value = value[:2] 84 | end_value, end_step = value[select_i - 1], value[select_i] 85 | value = [start_step, start_value, end_value, end_step] 86 | assert len(value) == 4 87 | start_step, start_value, end_value, end_step = value 88 | if isinstance(end_step, int): 89 | current_step = global_step 90 | elif isinstance(end_step, float): 91 | current_step = epoch 92 | t = max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) 93 | if interpolation == "linear": 94 | value = start_value + (end_value - start_value) * t 95 | elif interpolation == "exp": 96 | value = math.exp(math.log(start_value) * (1 - t) + math.log(end_value) * t) 97 | else: 98 | raise ValueError( 99 | f"Unknown interpolation method: {interpolation}, only support linear and exp" 100 | ) 101 | return value 102 | 103 | 104 | def cleanup(): 105 | gc.collect() 106 | torch.cuda.empty_cache() 107 | tcnn.free_temporary_memory() 108 | 109 | 110 | def finish_with_cleanup(func: Callable): 111 | def wrapper(*args, **kwargs): 112 | out = func(*args, **kwargs) 113 | cleanup() 114 | return out 115 | 116 | return wrapper 117 | 118 | 119 | def _distributed_available(): 120 | return torch.distributed.is_available() and torch.distributed.is_initialized() 121 | 122 | 123 | def barrier(): 124 | if not _distributed_available(): 125 | return 126 | else: 127 | torch.distributed.barrier() 128 | 129 | 130 | def broadcast(tensor, src=0): 131 | if not 
_distributed_available(): 132 | return tensor 133 | else: 134 | torch.distributed.broadcast(tensor, src=src) 135 | return tensor 136 | 137 | 138 | def enable_gradient(model, enabled: bool = True) -> None: 139 | for param in model.parameters(): 140 | param.requires_grad_(enabled) 141 | 142 | 143 | def find_last_path(path: str): 144 | if (path is not None) and ("LAST" in path): 145 | path = path.replace(" ", "_") 146 | base_dir_prefix, suffix = path.split("LAST", 1) 147 | base_dir = os.path.dirname(base_dir_prefix) 148 | prefix = os.path.split(base_dir_prefix)[-1] 149 | base_dir_prefix = os.path.join(base_dir, prefix) 150 | all_path = os.listdir(base_dir) 151 | all_path = [os.path.join(base_dir, dir) for dir in all_path] 152 | filtered_path = [dir for dir in all_path if dir.startswith(base_dir_prefix)] 153 | filtered_path.sort(reverse=True) 154 | last_path = filtered_path[0] 155 | new_path = last_path + suffix 156 | if os.path.exists(new_path): 157 | return new_path 158 | else: 159 | raise FileNotFoundError(new_path) 160 | else: 161 | return path 162 | -------------------------------------------------------------------------------- /threestudio/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | 5 | import pytorch_lightning 6 | 7 | from threestudio.utils.config import dump_config 8 | from threestudio.utils.misc import parse_version 9 | 10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): 11 | from pytorch_lightning.callbacks import Callback 12 | else: 13 | from pytorch_lightning.callbacks.base import Callback 14 | 15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 17 | 18 | 19 | class VersionedCallback(Callback): 20 | def __init__(self, save_root, version=None, use_version=True): 21 | self.save_root = save_root 22 | self._version = version 23 | self.use_version = use_version 24 | 25 | @property 26 | def version(self) -> int: 27 | """Get the experiment version. 28 | 29 | Returns: 30 | The experiment version if specified else the next version. 
31 | """ 32 | if self._version is None: 33 | self._version = self._get_next_version() 34 | return self._version 35 | 36 | def _get_next_version(self): 37 | existing_versions = [] 38 | if os.path.isdir(self.save_root): 39 | for f in os.listdir(self.save_root): 40 | bn = os.path.basename(f) 41 | if bn.startswith("version_"): 42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 43 | existing_versions.append(int(dir_ver)) 44 | if len(existing_versions) == 0: 45 | return 0 46 | return max(existing_versions) + 1 47 | 48 | @property 49 | def savedir(self): 50 | if not self.use_version: 51 | return self.save_root 52 | return os.path.join( 53 | self.save_root, 54 | self.version 55 | if isinstance(self.version, str) 56 | else f"version_{self.version}", 57 | ) 58 | 59 | 60 | class CodeSnapshotCallback(VersionedCallback): 61 | def __init__(self, save_root, version=None, use_version=True): 62 | super().__init__(save_root, version, use_version) 63 | 64 | def get_file_list(self): 65 | return [ 66 | b.decode() 67 | for b in set( 68 | subprocess.check_output( 69 | 'git ls-files -- ":!:load/*"', shell=True 70 | ).splitlines() 71 | ) 72 | | set( # hard code, TODO: use config to exclude folders or files 73 | subprocess.check_output( 74 | "git ls-files --others --exclude-standard", shell=True 75 | ).splitlines() 76 | ) 77 | ] 78 | 79 | @rank_zero_only 80 | def save_code_snapshot(self): 81 | os.makedirs(self.savedir, exist_ok=True) 82 | for f in self.get_file_list(): 83 | if not os.path.exists(f) or os.path.isdir(f): 84 | continue 85 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 86 | shutil.copyfile(f, os.path.join(self.savedir, f)) 87 | 88 | def on_fit_start(self, trainer, pl_module): 89 | try: 90 | self.save_code_snapshot() 91 | except: 92 | rank_zero_warn( 93 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
94 | ) 95 | 96 | 97 | class ConfigSnapshotCallback(VersionedCallback): 98 | def __init__(self, config_path, config, save_root, version=None, use_version=True): 99 | super().__init__(save_root, version, use_version) 100 | self.config_path = config_path 101 | self.config = config 102 | 103 | @rank_zero_only 104 | def save_config_snapshot(self): 105 | os.makedirs(self.savedir, exist_ok=True) 106 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) 107 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) 108 | 109 | def on_fit_start(self, trainer, pl_module): 110 | self.save_config_snapshot() 111 | 112 | 113 | class CustomProgressBar(TQDMProgressBar): 114 | def get_metrics(self, *args, **kwargs): 115 | # don't show the version number 116 | items = super().get_metrics(*args, **kwargs) 117 | items.pop("v_num", None) 118 | return items 119 | 120 | 121 | class ProgressCallback(Callback): 122 | def __init__(self, save_path): 123 | super().__init__() 124 | self.save_path = save_path 125 | self._file_handle = None 126 | 127 | @property 128 | def file_handle(self): 129 | if self._file_handle is None: 130 | self._file_handle = open(self.save_path, "w") 131 | return self._file_handle 132 | 133 | @rank_zero_only 134 | def write(self, msg: str) -> None: 135 | self.file_handle.seek(0) 136 | self.file_handle.truncate() 137 | self.file_handle.write(msg) 138 | self.file_handle.flush() 139 | 140 | @rank_zero_only 141 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): 142 | self.write( 143 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" 144 | ) 145 | 146 | @rank_zero_only 147 | def on_validation_start(self, trainer, pl_module): 148 | self.write(f"Rendering validation image ...") 149 | 150 | @rank_zero_only 151 | def on_test_start(self, trainer, pl_module): 152 | self.write(f"Rendering video ...") 153 | 154 | @rank_zero_only 155 | def on_predict_start(self, trainer, pl_module): 156 | self.write(f"Exporting mesh assets ...") 157 | -------------------------------------------------------------------------------- /threestudio/models/renderers/nvdiff_rasterizer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.models.renderers.base import Rasterizer, VolumeRenderer 12 | from threestudio.utils.misc import get_device 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("nvdiff-rasterizer") 18 | class NVDiffRasterizer(Rasterizer): 19 | @dataclass 20 | class Config(VolumeRenderer.Config): 21 | context_type: str = "cuda" 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | super().configure(geometry, material, background) 32 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, get_device()) 33 | 34 | def forward( 35 | self, 36 | mvp_mtx: Float[Tensor, "B 4 4"], 37 | camera_positions: Float[Tensor, "B 3"], 38 | light_positions: Float[Tensor, "B 3"], 39 | height: int, 40 | width: int, 41 | render_rgb: bool = True, 
42 | render_cmm: bool = True, 43 | rand_rotate_normal: bool = False, 44 | **kwargs 45 | ) -> Dict[str, Any]: 46 | batch_size = mvp_mtx.shape[0] 47 | mesh = self.geometry.isosurface() 48 | 49 | v_pos_clip: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( 50 | mesh.v_pos, mvp_mtx 51 | ) 52 | rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width)) 53 | mask = rast[..., 3:] > 0 54 | mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx) 55 | 56 | out = {"opacity": mask_aa, "mesh": mesh} 57 | 58 | if rand_rotate_normal: 59 | theta = torch.rand(1) * 2 * 3.1415926 60 | rot_matrix = torch.tensor([[torch.cos(theta), 0, torch.sin(theta)], 61 | [0, 1, 0], 62 | [-torch.sin(theta), 0, torch.cos(theta)]]).to(mesh.v_nrm.device) 63 | v_nrm = torch.matmul(rot_matrix, mesh.v_nrm.T).T 64 | gb_normal, _ = self.ctx.interpolate_one(v_nrm.contiguous(), rast, mesh.t_pos_idx) 65 | else: 66 | gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx) 67 | out.update({"comp_normal": (gb_normal + 1.0) / 2.0}) # in [0, 1] 68 | 69 | # gb_normal = F.normalize(gb_normal, dim=-1) 70 | # gb_normal_aa = torch.lerp( 71 | # torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float() 72 | # ) 73 | # gb_normal_aa = self.ctx.antialias( 74 | # gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx 75 | # ) 76 | # out.update({"comp_normal": gb_normal_aa}) # in [0, 1] 77 | 78 | if render_cmm: 79 | vmin, vmax = mesh.v_pos.amin(dim=0) * 1.1, mesh.v_pos.amax(dim=0) * 1.1 80 | center = (vmin + vmax) / 2 81 | scale = (vmax - vmin) / 2 82 | cmm = (mesh.v_pos - center) / scale 83 | cmm = cmm[..., [1, 0, 2]] 84 | cmm[..., 1] = - cmm[..., 1] 85 | 86 | 87 | gb_cmm, _ = self.ctx.interpolate_one(cmm, rast, mesh.t_pos_idx) # in [-1, 1] 88 | if self.training: 89 | out.update({"comp_cmm": torch.cat([gb_cmm, mask_aa * 2 - 1], dim=-1)}) # in [0, 1] 90 | else: 91 | # gb_cmm = F.normalize(gb_cmm, dim=-1) 92 | # gb_cmm_aa = torch.lerp( 93 | # torch.zeros_like(gb_cmm), (gb_cmm + 1.0) / 2.0, mask.float() 94 | # ) 95 | # gb_cmm_aa = self.ctx.antialias( 96 | # gb_cmm_aa, rast, v_pos_clip, mesh.t_pos_idx 97 | # ) 98 | out.update({"comp_cmm": (gb_cmm + 1) / 2 * mask_aa}) # in [0, 1] 99 | 100 | 101 | # TODO: make it clear whether to compute the normal, now we compute it in all cases 102 | # consider using: require_normal_computation = render_normal or (render_rgb and material.requires_normal) 103 | # or 104 | # render_normal = render_normal or (render_rgb and material.requires_normal) 105 | 106 | if render_rgb: 107 | selector = mask[..., 0] 108 | 109 | gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx) 110 | gb_viewdirs = F.normalize( 111 | gb_pos - camera_positions[:, None, None, :], dim=-1 112 | ) 113 | gb_light_positions = light_positions[:, None, None, :].expand( 114 | -1, height, width, -1 115 | ) 116 | 117 | positions = gb_pos[selector] 118 | geo_out = self.geometry(positions, output_normal=False) 119 | 120 | extra_geo_info = {} 121 | if self.material.requires_normal: 122 | extra_geo_info["shading_normal"] = gb_normal[selector] 123 | if self.material.requires_tangent: 124 | gb_tangent, _ = self.ctx.interpolate_one( 125 | mesh.v_tng, rast, mesh.t_pos_idx 126 | ) 127 | gb_tangent = F.normalize(gb_tangent, dim=-1) 128 | extra_geo_info["tangent"] = gb_tangent[selector] 129 | 130 | rgb_fg = self.material( 131 | viewdirs=gb_viewdirs[selector], 132 | positions=positions, 133 | light_positions=gb_light_positions[selector], 134 | **extra_geo_info, 135 | **geo_out 136 | ) 137 | 
gb_rgb_fg = torch.zeros(batch_size, height, width, 3).to(rgb_fg) 138 | gb_rgb_fg[selector] = rgb_fg 139 | 140 | gb_rgb_bg = self.background(dirs=gb_viewdirs) 141 | gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float()) 142 | gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) 143 | 144 | out.update({"comp_rgb": gb_rgb_aa, "comp_rgb_bg": gb_rgb_bg}) 145 | 146 | return out 147 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/perceptual.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | from dataclasses import dataclass, field 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import models 9 | 10 | import threestudio 11 | from threestudio.utils.base import BaseObject 12 | from threestudio.utils.perceptual.utils import get_ckpt_path 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("perceptual-loss") 17 | class PerceptualLossObject(BaseObject): 18 | @dataclass 19 | class Config(BaseObject.Config): 20 | use_dropout: bool = True 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | self.perceptual_loss = PerceptualLoss(self.cfg.use_dropout).to(self.device) 26 | 27 | def __call__( 28 | self, 29 | x: Float[Tensor, "B 3 256 256"], 30 | y: Float[Tensor, "B 3 256 256"], 31 | ): 32 | return self.perceptual_loss(x, y) 33 | 34 | 35 | class PerceptualLoss(nn.Module): 36 | # Learned perceptual metric 37 | def __init__(self, use_dropout=True): 38 | super().__init__() 39 | self.scaling_layer = ScalingLayer() 40 | self.chns = [64, 128, 256, 512, 512] # vg16 features 41 | self.net = vgg16(pretrained=True, requires_grad=False) 42 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 43 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 44 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 45 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 46 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 47 | self.load_from_pretrained() 48 | for param in self.parameters(): 49 | param.requires_grad = False 50 | 51 | def load_from_pretrained(self, name="vgg_lpips"): 52 | ckpt = get_ckpt_path(name, "threestudio/utils/lpips") 53 | self.load_state_dict( 54 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 55 | ) 56 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 57 | 58 | @classmethod 59 | def from_pretrained(cls, name="vgg_lpips"): 60 | if name != "vgg_lpips": 61 | raise NotImplementedError 62 | model = cls() 63 | ckpt = get_ckpt_path(name) 64 | model.load_state_dict( 65 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 66 | ) 67 | return model 68 | 69 | def forward(self, input, target): 70 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 71 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 72 | feats0, feats1, diffs = {}, {}, {} 73 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 74 | for kk in range(len(self.chns)): 75 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 76 | outs1[kk] 77 | ) 78 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 79 | 80 | res = [ 81 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 82 | for kk in range(len(self.chns)) 83 | ] 84 | val = res[0] 85 | for l in range(1, 
len(self.chns)): 86 | val += res[l] 87 | return val 88 | 89 | 90 | class ScalingLayer(nn.Module): 91 | def __init__(self): 92 | super(ScalingLayer, self).__init__() 93 | self.register_buffer( 94 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 95 | ) 96 | self.register_buffer( 97 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 98 | ) 99 | 100 | def forward(self, inp): 101 | return (inp - self.shift) / self.scale 102 | 103 | 104 | class NetLinLayer(nn.Module): 105 | """A single linear layer which does a 1x1 conv""" 106 | 107 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 108 | super(NetLinLayer, self).__init__() 109 | layers = ( 110 | [ 111 | nn.Dropout(), 112 | ] 113 | if (use_dropout) 114 | else [] 115 | ) 116 | layers += [ 117 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 118 | ] 119 | self.model = nn.Sequential(*layers) 120 | 121 | 122 | class vgg16(torch.nn.Module): 123 | def __init__(self, requires_grad=False, pretrained=True): 124 | super(vgg16, self).__init__() 125 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 126 | self.slice1 = torch.nn.Sequential() 127 | self.slice2 = torch.nn.Sequential() 128 | self.slice3 = torch.nn.Sequential() 129 | self.slice4 = torch.nn.Sequential() 130 | self.slice5 = torch.nn.Sequential() 131 | self.N_slices = 5 132 | for x in range(4): 133 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(4, 9): 135 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(9, 16): 137 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(16, 23): 139 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 140 | for x in range(23, 30): 141 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 142 | if not requires_grad: 143 | for param in self.parameters(): 144 | param.requires_grad = False 145 | 146 | def forward(self, X): 147 | h = self.slice1(X) 148 | h_relu1_2 = h 149 | h = self.slice2(h) 150 | h_relu2_2 = h 151 | h = self.slice3(h) 152 | h_relu3_3 = h 153 | h = self.slice4(h) 154 | h_relu4_3 = h 155 | h = self.slice5(h) 156 | h_relu5_3 = h 157 | vgg_outputs = namedtuple( 158 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 159 | ) 160 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 161 | return out 162 | 163 | 164 | def normalize_tensor(x, eps=1e-10): 165 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 166 | return x / (norm_factor + eps) 167 | 168 | 169 | def spatial_average(x, keepdim=True): 170 | return x.mean([2, 3], keepdim=keepdim) 171 | -------------------------------------------------------------------------------- /threestudio/models/renderers/gan_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.background.base import BaseBackground 8 | from threestudio.models.geometry.base import BaseImplicitGeometry 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.renderers.base import VolumeRenderer 11 | from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init 12 | from threestudio.utils.GAN.distribution import DiagonalGaussianDistribution 13 | from threestudio.utils.GAN.mobilenet import MobileNetV3 as GlobalEncoder 14 | from 
threestudio.utils.GAN.vae import Decoder as Generator 15 | from threestudio.utils.GAN.vae import Encoder as LocalEncoder 16 | from threestudio.utils.typing import * 17 | 18 | 19 | @threestudio.register("gan-volume-renderer") 20 | class GANVolumeRenderer(VolumeRenderer): 21 | @dataclass 22 | class Config(VolumeRenderer.Config): 23 | base_renderer_type: str = "" 24 | base_renderer: Optional[VolumeRenderer.Config] = None 25 | 26 | cfg: Config 27 | 28 | def configure( 29 | self, 30 | geometry: BaseImplicitGeometry, 31 | material: BaseMaterial, 32 | background: BaseBackground, 33 | ) -> None: 34 | self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( 35 | self.cfg.base_renderer, 36 | geometry=geometry, 37 | material=material, 38 | background=background, 39 | ) 40 | self.ch_mult = [1, 2, 4] 41 | self.generator = Generator( 42 | ch=64, 43 | out_ch=3, 44 | ch_mult=self.ch_mult, 45 | num_res_blocks=1, 46 | attn_resolutions=[], 47 | dropout=0.0, 48 | resamp_with_conv=True, 49 | in_channels=7, 50 | resolution=512, 51 | z_channels=4, 52 | ) 53 | self.local_encoder = LocalEncoder( 54 | ch=32, 55 | out_ch=3, 56 | ch_mult=self.ch_mult, 57 | num_res_blocks=1, 58 | attn_resolutions=[], 59 | dropout=0.0, 60 | resamp_with_conv=True, 61 | in_channels=3, 62 | resolution=512, 63 | z_channels=4, 64 | ) 65 | self.global_encoder = GlobalEncoder(n_class=64) 66 | self.discriminator = NLayerDiscriminator( 67 | input_nc=3, n_layers=3, use_actnorm=False, ndf=64 68 | ).apply(weights_init) 69 | 70 | def forward( 71 | self, 72 | rays_o: Float[Tensor, "B H W 3"], 73 | rays_d: Float[Tensor, "B H W 3"], 74 | light_positions: Float[Tensor, "B 3"], 75 | bg_color: Optional[Tensor] = None, 76 | gt_rgb: Float[Tensor, "B H W 3"] = None, 77 | multi_level_guidance: Bool = False, 78 | **kwargs 79 | ) -> Dict[str, Float[Tensor, "..."]]: 80 | B, H, W, _ = rays_o.shape 81 | if gt_rgb is not None and multi_level_guidance: 82 | generator_level = torch.randint(0, 3, (1,)).item() 83 | interval_x = torch.randint(0, 8, (1,)).item() 84 | interval_y = torch.randint(0, 8, (1,)).item() 85 | int_rays_o = rays_o[:, interval_y::8, interval_x::8] 86 | int_rays_d = rays_d[:, interval_y::8, interval_x::8] 87 | out = self.base_renderer( 88 | int_rays_o, int_rays_d, light_positions, bg_color, **kwargs 89 | ) 90 | comp_int_rgb = out["comp_rgb"][..., :3] 91 | comp_gt_rgb = gt_rgb[:, interval_y::8, interval_x::8] 92 | else: 93 | generator_level = 0 94 | scale_ratio = 2 ** (len(self.ch_mult) - 1) 95 | rays_o = torch.nn.functional.interpolate( 96 | rays_o.permute(0, 3, 1, 2), 97 | (H // scale_ratio, W // scale_ratio), 98 | mode="bilinear", 99 | ).permute(0, 2, 3, 1) 100 | rays_d = torch.nn.functional.interpolate( 101 | rays_d.permute(0, 3, 1, 2), 102 | (H // scale_ratio, W // scale_ratio), 103 | mode="bilinear", 104 | ).permute(0, 2, 3, 1) 105 | 106 | out = self.base_renderer(rays_o, rays_d, light_positions, bg_color, **kwargs) 107 | comp_rgb = out["comp_rgb"][..., :3] 108 | latent = out["comp_rgb"][..., 3:] 109 | out["comp_lr_rgb"] = comp_rgb.clone() 110 | 111 | posterior = DiagonalGaussianDistribution(latent.permute(0, 3, 1, 2)) 112 | if multi_level_guidance: 113 | z_map = posterior.sample() 114 | else: 115 | z_map = posterior.mode() 116 | lr_rgb = comp_rgb.permute(0, 3, 1, 2) 117 | 118 | if generator_level == 0: 119 | g_code_rgb = self.global_encoder(F.interpolate(lr_rgb, (224, 224))) 120 | comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) 121 | elif generator_level == 1: 122 | g_code_rgb = self.global_encoder( 123 
| F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224)) 124 | ) 125 | comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) 126 | elif generator_level == 2: 127 | g_code_rgb = self.global_encoder( 128 | F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224)) 129 | ) 130 | l_code_rgb = self.local_encoder(gt_rgb.permute(0, 3, 1, 2)) 131 | posterior = DiagonalGaussianDistribution(l_code_rgb) 132 | z_map = posterior.sample() 133 | comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) 134 | 135 | comp_rgb = F.interpolate(comp_rgb.permute(0, 3, 1, 2), (H, W), mode="bilinear") 136 | comp_gan_rgb = F.interpolate(comp_gan_rgb, (H, W), mode="bilinear") 137 | out.update( 138 | { 139 | "posterior": posterior, 140 | "comp_gan_rgb": comp_gan_rgb.permute(0, 2, 3, 1), 141 | "comp_rgb": comp_rgb.permute(0, 2, 3, 1), 142 | "generator_level": generator_level, 143 | } 144 | ) 145 | 146 | if gt_rgb is not None and multi_level_guidance: 147 | out.update({"comp_int_rgb": comp_int_rgb, "comp_gt_rgb": comp_gt_rgb}) 148 | return out 149 | 150 | def update_step( 151 | self, epoch: int, global_step: int, on_load_weights: bool = False 152 | ) -> None: 153 | self.base_renderer.update_step(epoch, global_step, on_load_weights) 154 | 155 | def train(self, mode=True): 156 | return self.base_renderer.train(mode) 157 | 158 | def eval(self): 159 | return self.base_renderer.eval() 160 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import multiprocessing as mp 3 | from collections import abc 4 | from functools import partial 5 | from inspect import isfunction 6 | from queue import Queue 7 | from threading import Thread 8 | 9 | import numpy as np 10 | import torch 11 | from einops import rearrange 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | 15 | def log_txt_as_img(wh, xc, size=10): 16 | # wh a tuple of (width, height) 17 | # xc a list of captions to plot 18 | b = len(xc) 19 | txts = list() 20 | for bi in range(b): 21 | txt = Image.new("RGB", wh, color="white") 22 | draw = ImageDraw.Draw(txt) 23 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 24 | nc = int(40 * (wh[0] / 256)) 25 | lines = "\n".join( 26 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 27 | ) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
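For example, a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,).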
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == "__is_first_stage__": 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, 110 | data, 111 | n_proc, 112 | target_data_type="ndarray", 113 | cpu_intensive=True, 114 | use_worker_id=False, 115 | ): 116 | # if target_data_type not in ["ndarray", "list"]: 117 | # raise ValueError( 118 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 119 | # ) 120 | if isinstance(data, np.ndarray) and target_data_type == "list": 121 | raise ValueError("list expected but function got ndarray.") 122 | elif isinstance(data, abc.Iterable): 123 | if isinstance(data, dict): 124 | print( 125 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 126 | ) 127 | data = list(data.values()) 128 | if target_data_type == "ndarray": 129 | data = np.asarray(data) 130 | else: 131 | data = list(data) 132 | else: 133 | raise TypeError( 134 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
135 | ) 136 | 137 | if cpu_intensive: 138 | Q = mp.Queue(1000) 139 | proc = mp.Process 140 | else: 141 | Q = Queue(1000) 142 | proc = Thread 143 | # spawn processes 144 | if target_data_type == "ndarray": 145 | arguments = [ 146 | [func, Q, part, i, use_worker_id] 147 | for i, part in enumerate(np.array_split(data, n_proc)) 148 | ] 149 | else: 150 | step = ( 151 | int(len(data) / n_proc + 1) 152 | if len(data) % n_proc != 0 153 | else int(len(data) / n_proc) 154 | ) 155 | arguments = [ 156 | [func, Q, part, i, use_worker_id] 157 | for i, part in enumerate( 158 | [data[i : i + step] for i in range(0, len(data), step)] 159 | ) 160 | ] 161 | processes = [] 162 | for i in range(n_proc): 163 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 164 | processes += [p] 165 | 166 | # start processes 167 | print(f"Start prefetching...") 168 | import time 169 | 170 | start = time.time() 171 | gather_res = [[] for _ in range(n_proc)] 172 | try: 173 | for p in processes: 174 | p.start() 175 | 176 | k = 0 177 | while k < n_proc: 178 | # get result 179 | res = Q.get() 180 | if res == "Done": 181 | k += 1 182 | else: 183 | gather_res[res[0]] = res[1] 184 | 185 | except Exception as e: 186 | print("Exception: ", e) 187 | for p in processes: 188 | p.terminate() 189 | 190 | raise e 191 | finally: 192 | for p in processes: 193 | p.join() 194 | print(f"Prefetching complete. [{time.time() - start} sec.]") 195 | 196 | if target_data_type == "ndarray": 197 | if not isinstance(gather_res[0], np.ndarray): 198 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 199 | 200 | # order outputs 201 | return np.concatenate(gather_res, axis=0) 202 | elif target_data_type == "list": 203 | out = [] 204 | for r in gather_res: 205 | out.extend(r) 206 | return out 207 | else: 208 | return gather_res 209 | -------------------------------------------------------------------------------- /threestudio/models/geometry/custom_mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import threestudio 10 | from threestudio.models.geometry.base import ( 11 | BaseExplicitGeometry, 12 | BaseGeometry, 13 | contract_to_unisphere, 14 | ) 15 | from threestudio.models.mesh import Mesh 16 | from threestudio.models.networks import get_encoding, get_mlp 17 | from threestudio.utils.ops import scale_tensor 18 | from threestudio.utils.typing import * 19 | 20 | 21 | @threestudio.register("custom-mesh") 22 | class CustomMesh(BaseExplicitGeometry): 23 | @dataclass 24 | class Config(BaseExplicitGeometry.Config): 25 | n_input_dims: int = 3 26 | n_feature_dims: int = 3 27 | pos_encoding_config: dict = field( 28 | default_factory=lambda: { 29 | "otype": "HashGrid", 30 | "n_levels": 16, 31 | "n_features_per_level": 2, 32 | "log2_hashmap_size": 19, 33 | "base_resolution": 16, 34 | "per_level_scale": 1.447269237440378, 35 | } 36 | ) 37 | mlp_network_config: dict = field( 38 | default_factory=lambda: { 39 | "otype": "VanillaMLP", 40 | "activation": "ReLU", 41 | "output_activation": "none", 42 | "n_neurons": 64, 43 | "n_hidden_layers": 1, 44 | } 45 | ) 46 | shape_init: str = "" 47 | shape_init_params: Optional[Any] = None 48 | shape_init_mesh_up: str = "+z" 49 | shape_init_mesh_front: str = "+x" 50 | 51 | cfg: Config 52 | 53 | def configure(self) -> None: 54 | super().configure() 55 | 56 | self.encoding = get_encoding( 57 | 
self.cfg.n_input_dims, self.cfg.pos_encoding_config 58 | ) 59 | self.feature_network = get_mlp( 60 | self.encoding.n_output_dims, 61 | self.cfg.n_feature_dims, 62 | self.cfg.mlp_network_config, 63 | ) 64 | 65 | # Initialize custom mesh 66 | if self.cfg.shape_init.startswith("mesh:"): 67 | assert isinstance(self.cfg.shape_init_params, float) 68 | mesh_path = self.cfg.shape_init[5:] 69 | if not os.path.exists(mesh_path): 70 | raise ValueError(f"Mesh file {mesh_path} does not exist.") 71 | 72 | import trimesh 73 | 74 | scene = trimesh.load(mesh_path) 75 | if isinstance(scene, trimesh.Trimesh): 76 | mesh = scene 77 | elif isinstance(scene, trimesh.scene.Scene): 78 | mesh = trimesh.Trimesh() 79 | for obj in scene.geometry.values(): 80 | mesh = trimesh.util.concatenate([mesh, obj]) 81 | else: 82 | raise ValueError(f"Unknown mesh type at {mesh_path}.") 83 | 84 | # move to center 85 | centroid = mesh.vertices.mean(0) 86 | mesh.vertices = mesh.vertices - centroid 87 | 88 | # align to up-z and front-x 89 | dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] 90 | dir2vec = { 91 | "+x": np.array([1, 0, 0]), 92 | "+y": np.array([0, 1, 0]), 93 | "+z": np.array([0, 0, 1]), 94 | "-x": np.array([-1, 0, 0]), 95 | "-y": np.array([0, -1, 0]), 96 | "-z": np.array([0, 0, -1]), 97 | } 98 | if ( 99 | self.cfg.shape_init_mesh_up not in dirs 100 | or self.cfg.shape_init_mesh_front not in dirs 101 | ): 102 | raise ValueError( 103 | f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." 104 | ) 105 | if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: 106 | raise ValueError( 107 | "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." 108 | ) 109 | z_, x_ = ( 110 | dir2vec[self.cfg.shape_init_mesh_up], 111 | dir2vec[self.cfg.shape_init_mesh_front], 112 | ) 113 | y_ = np.cross(z_, x_) 114 | std2mesh = np.stack([x_, y_, z_], axis=0).T 115 | mesh2std = np.linalg.inv(std2mesh) 116 | 117 | # scaling 118 | scale = np.abs(mesh.vertices).max() 119 | mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params 120 | mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T 121 | 122 | v_pos = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device) 123 | t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device) 124 | self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) 125 | self.register_buffer( 126 | "v_buffer", 127 | v_pos, 128 | ) 129 | self.register_buffer( 130 | "t_buffer", 131 | t_pos_idx, 132 | ) 133 | 134 | else: 135 | raise ValueError( 136 | f"Unknown shape initialization type: {self.cfg.shape_init}" 137 | ) 138 | print(self.mesh.v_pos.device) 139 | 140 | def isosurface(self) -> Mesh: 141 | if hasattr(self, "mesh"): 142 | return self.mesh 143 | elif hasattr(self, "v_buffer"): 144 | self.mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer) 145 | return self.mesh 146 | else: 147 | raise ValueError(f"custom mesh is not initialized") 148 | 149 | def forward( 150 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 151 | ) -> Dict[str, Float[Tensor, "..."]]: 152 | assert ( 153 | output_normal == False 154 | ), f"Normal output is not supported for {self.__class__.__name__}" 155 | points_unscaled = points # points in the original scale 156 | points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1) 157 | enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) 158 | features = self.feature_network(enc).view( 159 | *points.shape[:-1], self.cfg.n_feature_dims 160 | ) 161 | return {"features": features} 162 | 163 | def 
export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: 164 | out: Dict[str, Any] = {} 165 | if self.cfg.n_feature_dims == 0: 166 | return out 167 | points_unscaled = points 168 | points = contract_to_unisphere(points_unscaled, self.bbox) 169 | enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) 170 | features = self.feature_network(enc).view( 171 | *points.shape[:-1], self.cfg.n_feature_dims 172 | ) 173 | out.update( 174 | { 175 | "features": features, 176 | } 177 | ) 178 | return out 179 | -------------------------------------------------------------------------------- /threestudio/models/exporters/mesh_exporter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.exporters.base import Exporter, ExporterOutput 10 | from threestudio.models.geometry.base import BaseImplicitGeometry 11 | from threestudio.models.materials.base import BaseMaterial 12 | from threestudio.models.mesh import Mesh 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("mesh-exporter") 18 | class MeshExporter(Exporter): 19 | @dataclass 20 | class Config(Exporter.Config): 21 | fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx 22 | save_name: str = "model" 23 | save_normal: bool = False 24 | save_uv: bool = True 25 | save_texture: bool = True 26 | texture_size: int = 1024 27 | texture_format: str = "jpg" 28 | xatlas_chart_options: dict = field(default_factory=dict) 29 | xatlas_pack_options: dict = field(default_factory=dict) 30 | context_type: str = "gl" 31 | 32 | cfg: Config 33 | 34 | def configure( 35 | self, 36 | geometry: BaseImplicitGeometry, 37 | material: BaseMaterial, 38 | background: BaseBackground, 39 | ) -> None: 40 | super().configure(geometry, material, background) 41 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device) 42 | 43 | def __call__(self) -> List[ExporterOutput]: 44 | mesh: Mesh = self.geometry.isosurface() 45 | 46 | if self.cfg.fmt == "obj-mtl": 47 | return self.export_obj_with_mtl(mesh) 48 | elif self.cfg.fmt == "obj": 49 | return self.export_obj(mesh) 50 | else: 51 | raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}") 52 | 53 | def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]: 54 | params = { 55 | "mesh": mesh, 56 | "save_mat": True, 57 | "save_normal": self.cfg.save_normal, 58 | "save_uv": self.cfg.save_uv, 59 | "save_vertex_color": False, 60 | "map_Kd": None, # Base Color 61 | "map_Ks": None, # Specular 62 | "map_Bump": None, # Normal 63 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering 64 | "map_Pm": None, # Metallic 65 | "map_Pr": None, # Roughness 66 | "map_format": self.cfg.texture_format, 67 | } 68 | 69 | if self.cfg.save_uv: 70 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 71 | 72 | if self.cfg.save_texture: 73 | threestudio.info("Exporting textures ...") 74 | assert self.cfg.save_uv, "save_uv must be True when save_texture is True" 75 | # clip space transform 76 | uv_clip = mesh.v_tex * 2.0 - 1.0 77 | # pad to four component coordinate 78 | uv_clip4 = torch.cat( 79 | ( 80 | uv_clip, 81 | torch.zeros_like(uv_clip[..., 0:1]), 82 | torch.ones_like(uv_clip[..., 0:1]), 83 | ), 84 | dim=-1, 
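# note: z is set to 0 and w to 1, so the UV plane itself is rasterized as the image plane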
85 | ) 86 | # rasterize 87 | rast, _ = self.ctx.rasterize_one( 88 | uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size) 89 | ) 90 | 91 | hole_mask = ~(rast[:, :, 3] > 0) 92 | 93 | def uv_padding(image): 94 | uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2) 95 | inpaint_image = ( 96 | cv2.inpaint( 97 | (image.detach().cpu().numpy() * 255).astype(np.uint8), 98 | (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), 99 | uv_padding_size, 100 | cv2.INPAINT_TELEA, 101 | ) 102 | / 255.0 103 | ) 104 | return torch.from_numpy(inpaint_image).to(image) 105 | 106 | # Interpolate world space position 107 | gb_pos, _ = self.ctx.interpolate_one( 108 | mesh.v_pos, rast[None, ...], mesh.t_pos_idx 109 | ) 110 | gb_pos = gb_pos[0] 111 | 112 | # Sample out textures from MLP 113 | geo_out = self.geometry.export(points=gb_pos) 114 | mat_out = self.material.export(points=gb_pos, **geo_out) 115 | 116 | threestudio.info( 117 | "Perform UV padding on texture maps to avoid seams, may take a while ..." 118 | ) 119 | 120 | if "albedo" in mat_out: 121 | params["map_Kd"] = uv_padding(mat_out["albedo"]) 122 | else: 123 | threestudio.warn( 124 | "save_texture is True but no albedo texture found, using default white texture" 125 | ) 126 | if "metallic" in mat_out: 127 | params["map_Pm"] = uv_padding(mat_out["metallic"]) 128 | if "roughness" in mat_out: 129 | params["map_Pr"] = uv_padding(mat_out["roughness"]) 130 | if "bump" in mat_out: 131 | params["map_Bump"] = uv_padding(mat_out["bump"]) 132 | # TODO: map_Ks 133 | return [ 134 | ExporterOutput( 135 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 136 | ) 137 | ] 138 | 139 | def export_obj(self, mesh: Mesh) -> List[ExporterOutput]: 140 | params = { 141 | "mesh": mesh, 142 | "save_mat": False, 143 | "save_normal": self.cfg.save_normal, 144 | "save_uv": self.cfg.save_uv, 145 | "save_vertex_color": False, 146 | "map_Kd": None, # Base Color 147 | "map_Ks": None, # Specular 148 | "map_Bump": None, # Normal 149 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering 150 | "map_Pm": None, # Metallic 151 | "map_Pr": None, # Roughness 152 | "map_format": self.cfg.texture_format, 153 | } 154 | 155 | if self.cfg.save_uv: 156 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 157 | 158 | if self.cfg.save_texture: 159 | threestudio.info("Exporting textures ...") 160 | geo_out = self.geometry.export(points=mesh.v_pos) 161 | mat_out = self.material.export(points=mesh.v_pos, **geo_out) 162 | 163 | if "albedo" in mat_out: 164 | mesh.set_vertex_color(mat_out["albedo"]) 165 | params["save_vertex_color"] = True 166 | else: 167 | threestudio.warn( 168 | "save_texture is True but no albedo texture found, not saving vertex color" 169 | ) 170 | 171 | return [ 172 | ExporterOutput( 173 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 174 | ) 175 | ] 176 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def count_params(model): 8 | total_params = sum(p.numel() for p in model.parameters()) 9 | return total_params 10 | 11 | 12 | class ActNorm(nn.Module): 13 | def __init__( 14 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 15 | ): 16 | assert affine 17 | super().__init__() 18 | self.logdet = logdet 
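# note: per-channel shift/scale; initialize() below fills them with the negative mean and inverse std of the first batch (data-dependent init, Glow-style ActNorm)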
19 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 20 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 21 | self.allow_reverse_init = allow_reverse_init 22 | 23 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 24 | 25 | def initialize(self, input): 26 | with torch.no_grad(): 27 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 28 | mean = ( 29 | flatten.mean(1) 30 | .unsqueeze(1) 31 | .unsqueeze(2) 32 | .unsqueeze(3) 33 | .permute(1, 0, 2, 3) 34 | ) 35 | std = ( 36 | flatten.std(1) 37 | .unsqueeze(1) 38 | .unsqueeze(2) 39 | .unsqueeze(3) 40 | .permute(1, 0, 2, 3) 41 | ) 42 | 43 | self.loc.data.copy_(-mean) 44 | self.scale.data.copy_(1 / (std + 1e-6)) 45 | 46 | def forward(self, input, reverse=False): 47 | if reverse: 48 | return self.reverse(input) 49 | if len(input.shape) == 2: 50 | input = input[:, :, None, None] 51 | squeeze = True 52 | else: 53 | squeeze = False 54 | 55 | _, _, height, width = input.shape 56 | 57 | if self.training and self.initialized.item() == 0: 58 | self.initialize(input) 59 | self.initialized.fill_(1) 60 | 61 | h = self.scale * (input + self.loc) 62 | 63 | if squeeze: 64 | h = h.squeeze(-1).squeeze(-1) 65 | 66 | if self.logdet: 67 | log_abs = torch.log(torch.abs(self.scale)) 68 | logdet = height * width * torch.sum(log_abs) 69 | logdet = logdet * torch.ones(input.shape[0]).to(input) 70 | return h, logdet 71 | 72 | return h 73 | 74 | def reverse(self, output): 75 | if self.training and self.initialized.item() == 0: 76 | if not self.allow_reverse_init: 77 | raise RuntimeError( 78 | "Initializing ActNorm in reverse direction is " 79 | "disabled by default. Use allow_reverse_init=True to enable." 80 | ) 81 | else: 82 | self.initialize(output) 83 | self.initialized.fill_(1) 84 | 85 | if len(output.shape) == 2: 86 | output = output[:, :, None, None] 87 | squeeze = True 88 | else: 89 | squeeze = False 90 | 91 | h = output / self.scale - self.loc 92 | 93 | if squeeze: 94 | h = h.squeeze(-1).squeeze(-1) 95 | return h 96 | 97 | 98 | class AbstractEncoder(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | 102 | def encode(self, *args, **kwargs): 103 | raise NotImplementedError 104 | 105 | 106 | class Labelator(AbstractEncoder): 107 | """Net2Net Interface for Class-Conditional Model""" 108 | 109 | def __init__(self, n_classes, quantize_interface=True): 110 | super().__init__() 111 | self.n_classes = n_classes 112 | self.quantize_interface = quantize_interface 113 | 114 | def encode(self, c): 115 | c = c[:, None] 116 | if self.quantize_interface: 117 | return c, None, [None, None, c.long()] 118 | return c 119 | 120 | 121 | class SOSProvider(AbstractEncoder): 122 | # for unconditional training 123 | def __init__(self, sos_token, quantize_interface=True): 124 | super().__init__() 125 | self.sos_token = sos_token 126 | self.quantize_interface = quantize_interface 127 | 128 | def encode(self, x): 129 | # get batch size from data and replicate sos_token 130 | c = torch.ones(x.shape[0], 1) * self.sos_token 131 | c = c.long().to(x.device) 132 | if self.quantize_interface: 133 | return c, None, [None, None, c] 134 | return c 135 | 136 | 137 | def weights_init(m): 138 | classname = m.__class__.__name__ 139 | if classname.find("Conv") != -1: 140 | nn.init.normal_(m.weight.data, 0.0, 0.02) 141 | elif classname.find("BatchNorm") != -1: 142 | nn.init.normal_(m.weight.data, 1.0, 0.02) 143 | nn.init.constant_(m.bias.data, 0) 144 | 145 | 146 | class NLayerDiscriminator(nn.Module): 147 | """Defines 
a PatchGAN discriminator as in Pix2Pix 148 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 149 | """ 150 | 151 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 152 | """Construct a PatchGAN discriminator 153 | Parameters: 154 | input_nc (int) -- the number of channels in input images 155 | ndf (int) -- the number of filters in the last conv layer 156 | n_layers (int) -- the number of conv layers in the discriminator 157 | norm_layer -- normalization layer 158 | """ 159 | super(NLayerDiscriminator, self).__init__() 160 | if not use_actnorm: 161 | norm_layer = nn.BatchNorm2d 162 | else: 163 | norm_layer = ActNorm 164 | if ( 165 | type(norm_layer) == functools.partial 166 | ): # no need to use bias as BatchNorm2d has affine parameters 167 | use_bias = norm_layer.func != nn.BatchNorm2d 168 | else: 169 | use_bias = norm_layer != nn.BatchNorm2d 170 | 171 | kw = 4 172 | padw = 1 173 | sequence = [ 174 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 175 | nn.LeakyReLU(0.2, True), 176 | ] 177 | nf_mult = 1 178 | nf_mult_prev = 1 179 | for n in range(1, n_layers): # gradually increase the number of filters 180 | nf_mult_prev = nf_mult 181 | nf_mult = min(2**n, 8) 182 | sequence += [ 183 | nn.Conv2d( 184 | ndf * nf_mult_prev, 185 | ndf * nf_mult, 186 | kernel_size=kw, 187 | stride=2, 188 | padding=padw, 189 | bias=use_bias, 190 | ), 191 | norm_layer(ndf * nf_mult), 192 | nn.LeakyReLU(0.2, True), 193 | ] 194 | 195 | nf_mult_prev = nf_mult 196 | nf_mult = min(2**n_layers, 8) 197 | sequence += [ 198 | nn.Conv2d( 199 | ndf * nf_mult_prev, 200 | ndf * nf_mult, 201 | kernel_size=kw, 202 | stride=1, 203 | padding=padw, 204 | bias=use_bias, 205 | ), 206 | norm_layer(ndf * nf_mult), 207 | nn.LeakyReLU(0.2, True), 208 | ] 209 | 210 | sequence += [ 211 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 212 | ] # output 1 channel prediction map 213 | self.main = nn.Sequential(*sequence) 214 | 215 | def forward(self, input): 216 | """Standard forward.""" 217 | return self.main(input) 218 | -------------------------------------------------------------------------------- /threestudio/models/geometry/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.isosurface import ( 10 | IsosurfaceHelper, 11 | MarchingCubeCPUHelper, 12 | MarchingTetrahedraHelper, 13 | ) 14 | from threestudio.models.mesh import Mesh 15 | from threestudio.utils.base import BaseModule 16 | from threestudio.utils.ops import chunk_batch, scale_tensor 17 | from threestudio.utils.typing import * 18 | 19 | 20 | def contract_to_unisphere( 21 | x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False 22 | ) -> Float[Tensor, "... 
3"]: 23 | if unbounded: 24 | x = scale_tensor(x, bbox, (0, 1)) 25 | x = x * 2 - 1 # aabb is at [-1, 1] 26 | mag = x.norm(dim=-1, keepdim=True) 27 | mask = mag.squeeze(-1) > 1 28 | x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) 29 | x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] 30 | else: 31 | x = scale_tensor(x, bbox, (0, 1)) 32 | return x 33 | 34 | 35 | class BaseGeometry(BaseModule): 36 | @dataclass 37 | class Config(BaseModule.Config): 38 | pass 39 | 40 | cfg: Config 41 | 42 | @staticmethod 43 | def create_from( 44 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs 45 | ) -> "BaseGeometry": 46 | raise TypeError( 47 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" 48 | ) 49 | 50 | def export(self, *args, **kwargs) -> Dict[str, Any]: 51 | return {} 52 | 53 | 54 | class BaseImplicitGeometry(BaseGeometry): 55 | @dataclass 56 | class Config(BaseGeometry.Config): 57 | radius: float = 1.0 58 | isosurface: bool = True 59 | isosurface_method: str = "mt" 60 | isosurface_resolution: int = 128 61 | isosurface_threshold: Union[float, str] = 0.0 62 | isosurface_chunk: int = 0 63 | isosurface_coarse_to_fine: bool = True 64 | isosurface_deformable_grid: bool = False 65 | isosurface_remove_outliers: bool = True 66 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 67 | 68 | cfg: Config 69 | 70 | def configure(self) -> None: 71 | self.bbox: Float[Tensor, "2 3"] 72 | self.register_buffer( 73 | "bbox", 74 | torch.as_tensor( 75 | [ 76 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 77 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 78 | ], 79 | dtype=torch.float32, 80 | ), 81 | ) 82 | self.isosurface_helper: Optional[IsosurfaceHelper] = None 83 | self.unbounded: bool = False 84 | 85 | def _initilize_isosurface_helper(self): 86 | if self.cfg.isosurface and self.isosurface_helper is None: 87 | if self.cfg.isosurface_method == "mc-cpu": 88 | self.isosurface_helper = MarchingCubeCPUHelper( 89 | self.cfg.isosurface_resolution 90 | ).to(self.device) 91 | elif self.cfg.isosurface_method == "mt": 92 | self.isosurface_helper = MarchingTetrahedraHelper( 93 | self.cfg.isosurface_resolution, 94 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", 95 | ).to(self.device) 96 | else: 97 | raise AttributeError( 98 | "Unknown isosurface method {self.cfg.isosurface_method}" 99 | ) 100 | 101 | def forward( 102 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 103 | ) -> Dict[str, Float[Tensor, "..."]]: 104 | raise NotImplementedError 105 | 106 | def forward_field( 107 | self, points: Float[Tensor, "*N Di"] 108 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 109 | # return the value of the implicit field, could be density / signed distance 110 | # also return a deformation field if the grid vertices can be optimized 111 | raise NotImplementedError 112 | 113 | def forward_level( 114 | self, field: Float[Tensor, "*N 1"], threshold: float 115 | ) -> Float[Tensor, "*N 1"]: 116 | # return the value of the implicit field, where the zero level set represents the surface 117 | raise NotImplementedError 118 | 119 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: 120 | def batch_func(x): 121 | # scale to bbox as the input vertices are in [0, 1] 122 | field, deformation = self.forward_field( 123 | scale_tensor( 124 | x.to(bbox.device), self.isosurface_helper.points_range, bbox 125 | ), 126 | ) 127 | field = field.to( 128 | x.device 129 | ) # move to the same device as 
the input (could be CPU) 130 | if deformation is not None: 131 | deformation = deformation.to(x.device) 132 | return field, deformation 133 | 134 | assert self.isosurface_helper is not None 135 | 136 | field, deformation = chunk_batch( 137 | batch_func, 138 | self.cfg.isosurface_chunk, 139 | self.isosurface_helper.grid_vertices, 140 | ) 141 | 142 | threshold: float 143 | 144 | if isinstance(self.cfg.isosurface_threshold, float): 145 | threshold = self.cfg.isosurface_threshold 146 | elif self.cfg.isosurface_threshold == "auto": 147 | eps = 1.0e-5 148 | threshold = field[field > eps].mean().item() 149 | threestudio.info( 150 | f"Automatically determined isosurface threshold: {threshold}" 151 | ) 152 | else: 153 | raise TypeError( 154 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" 155 | ) 156 | 157 | level = self.forward_level(field, threshold) 158 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation) 159 | mesh.v_pos = scale_tensor( 160 | mesh.v_pos, self.isosurface_helper.points_range, bbox 161 | ) # scale to bbox as the grid vertices are in [0, 1] 162 | mesh.add_extra("bbox", bbox) 163 | 164 | if self.cfg.isosurface_remove_outliers: 165 | # remove outliers components with small number of faces 166 | # only enabled when the mesh is not differentiable 167 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) 168 | 169 | return mesh 170 | 171 | def isosurface(self) -> Mesh: 172 | if not self.cfg.isosurface: 173 | raise NotImplementedError( 174 | "Isosurface is not enabled in the current configuration" 175 | ) 176 | self._initilize_isosurface_helper() 177 | if self.cfg.isosurface_coarse_to_fine: 178 | threestudio.debug("First run isosurface to get a tight bounding box ...") 179 | with torch.no_grad(): 180 | mesh_coarse = self._isosurface(self.bbox) 181 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) 182 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) 183 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) 184 | threestudio.debug("Run isosurface again with the tight bounding box ...") 185 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) 186 | else: 187 | mesh = self._isosurface(self.bbox) 188 | return mesh 189 | 190 | 191 | class BaseExplicitGeometry(BaseGeometry): 192 | @dataclass 193 | class Config(BaseGeometry.Config): 194 | radius: float = 1.0 195 | 196 | cfg: Config 197 | 198 | def configure(self) -> None: 199 | self.bbox: Float[Tensor, "2 3"] 200 | self.register_buffer( 201 | "bbox", 202 | torch.as_tensor( 203 | [ 204 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 205 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 206 | ], 207 | dtype=torch.float32, 208 | ), 209 | ) 210 | -------------------------------------------------------------------------------- /threestudio/models/geometry/volume_grid.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere 10 | from threestudio.utils.ops import get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("volume-grid") 15 | class VolumeGrid(BaseImplicitGeometry): 16 | @dataclass 17 | class Config(BaseImplicitGeometry.Config): 18 | grid_size: Tuple[int, int, int] = 
field(default_factory=lambda: (100, 100, 100)) 19 | n_feature_dims: int = 3 20 | density_activation: Optional[str] = "softplus" 21 | density_bias: Union[float, str] = "blob" 22 | density_blob_scale: float = 5.0 23 | density_blob_std: float = 0.5 24 | normal_type: Optional[ 25 | str 26 | ] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian'] 27 | 28 | # automatically determine the threshold 29 | isosurface_threshold: Union[float, str] = "auto" 30 | 31 | cfg: Config 32 | 33 | def configure(self) -> None: 34 | super().configure() 35 | self.grid_size = self.cfg.grid_size 36 | 37 | self.grid = nn.Parameter( 38 | torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size) 39 | ) 40 | if self.cfg.density_bias == "blob": 41 | self.register_buffer("density_scale", torch.tensor(0.0)) 42 | else: 43 | self.density_scale = nn.Parameter(torch.tensor(0.0)) 44 | 45 | if self.cfg.normal_type == "pred": 46 | self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size)) 47 | 48 | def get_density_bias(self, points: Float[Tensor, "*N Di"]): 49 | if self.cfg.density_bias == "blob": 50 | # density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[...,None] 51 | density_bias: Float[Tensor, "*N 1"] = ( 52 | self.cfg.density_blob_scale 53 | * ( 54 | 1 55 | - torch.sqrt((points.detach() ** 2).sum(dim=-1)) 56 | / self.cfg.density_blob_std 57 | )[..., None] 58 | ) 59 | return density_bias 60 | elif isinstance(self.cfg.density_bias, float): 61 | return self.cfg.density_bias 62 | else: 63 | raise AttributeError(f"Unknown density bias {self.cfg.density_bias}") 64 | 65 | def get_trilinear_feature( 66 | self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"] 67 | ) -> Float[Tensor, "*N Df"]: 68 | points_shape = points.shape[:-1] 69 | df = grid.shape[1] 70 | di = points.shape[-1] 71 | out = F.grid_sample( 72 | grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear" 73 | ) 74 | out = out.reshape(df, -1).T.reshape(*points_shape, df) 75 | return out 76 | 77 | def forward( 78 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 79 | ) -> Dict[str, Float[Tensor, "..."]]: 80 | points_unscaled = points # points in the original scale 81 | points = contract_to_unisphere( 82 | points, self.bbox, self.unbounded 83 | ) # points normalized to (0, 1) 84 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 85 | 86 | out = self.get_trilinear_feature(points, self.grid) 87 | density, features = out[..., 0:1], out[..., 1:] 88 | density = density * torch.exp(self.density_scale) # exp scaling in DreamFusion 89 | 90 | # breakpoint() 91 | density = get_activation(self.cfg.density_activation)( 92 | density + self.get_density_bias(points_unscaled) 93 | ) 94 | 95 | output = { 96 | "density": density, 97 | "features": features, 98 | } 99 | 100 | if output_normal: 101 | if ( 102 | self.cfg.normal_type == "finite_difference" 103 | or self.cfg.normal_type == "finite_difference_laplacian" 104 | ): 105 | eps = 1.0e-3 106 | if self.cfg.normal_type == "finite_difference_laplacian": 107 | offsets: Float[Tensor, "6 3"] = torch.as_tensor( 108 | [ 109 | [eps, 0.0, 0.0], 110 | [-eps, 0.0, 0.0], 111 | [0.0, eps, 0.0], 112 | [0.0, -eps, 0.0], 113 | [0.0, 0.0, eps], 114 | [0.0, 0.0, -eps], 115 | ] 116 | ).to(points_unscaled) 117 | points_offset: Float[Tensor, "... 
6 3"] = ( 118 | points_unscaled[..., None, :] + offsets 119 | ).clamp(-self.cfg.radius, self.cfg.radius) 120 | density_offset: Float[Tensor, "... 6 1"] = self.forward_density( 121 | points_offset 122 | ) 123 | normal = ( 124 | -0.5 125 | * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0]) 126 | / eps 127 | ) 128 | else: 129 | offsets: Float[Tensor, "3 3"] = torch.as_tensor( 130 | [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] 131 | ).to(points_unscaled) 132 | points_offset: Float[Tensor, "... 3 3"] = ( 133 | points_unscaled[..., None, :] + offsets 134 | ).clamp(-self.cfg.radius, self.cfg.radius) 135 | density_offset: Float[Tensor, "... 3 1"] = self.forward_density( 136 | points_offset 137 | ) 138 | normal = -(density_offset[..., 0::1, 0] - density) / eps 139 | normal = F.normalize(normal, dim=-1) 140 | elif self.cfg.normal_type == "pred": 141 | normal = self.get_trilinear_feature(points, self.normal_grid) 142 | normal = F.normalize(normal, dim=-1) 143 | else: 144 | raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") 145 | output.update({"normal": normal, "shading_normal": normal}) 146 | return output 147 | 148 | def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: 149 | points_unscaled = points 150 | points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) 151 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 152 | 153 | out = self.get_trilinear_feature(points, self.grid) 154 | density = out[..., 0:1] 155 | density = density * torch.exp(self.density_scale) 156 | 157 | density = get_activation(self.cfg.density_activation)( 158 | density + self.get_density_bias(points_unscaled) 159 | ) 160 | return density 161 | 162 | def forward_field( 163 | self, points: Float[Tensor, "*N Di"] 164 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 165 | if self.cfg.isosurface_deformable_grid: 166 | threestudio.warn( 167 | f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring." 
168 | ) 169 | density = self.forward_density(points) 170 | return density, None 171 | 172 | def forward_level( 173 | self, field: Float[Tensor, "*N 1"], threshold: float 174 | ) -> Float[Tensor, "*N 1"]: 175 | return -(field - threshold) 176 | 177 | def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: 178 | out: Dict[str, Any] = {} 179 | if self.cfg.n_feature_dims == 0: 180 | return out 181 | points_unscaled = points 182 | points = contract_to_unisphere(points, self.bbox, self.unbounded) 183 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 184 | features = self.get_trilinear_feature(points, self.grid)[..., 1:] 185 | out.update( 186 | { 187 | "features": features, 188 | } 189 | ) 190 | return out 191 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ["MobileNetV3", "mobilenetv3"] 6 | 7 | 8 | def conv_bn( 9 | inp, 10 | oup, 11 | stride, 12 | conv_layer=nn.Conv2d, 13 | norm_layer=nn.BatchNorm2d, 14 | nlin_layer=nn.ReLU, 15 | ): 16 | return nn.Sequential( 17 | conv_layer(inp, oup, 3, stride, 1, bias=False), 18 | norm_layer(oup), 19 | nlin_layer(inplace=True), 20 | ) 21 | 22 | 23 | def conv_1x1_bn( 24 | inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU 25 | ): 26 | return nn.Sequential( 27 | conv_layer(inp, oup, 1, 1, 0, bias=False), 28 | norm_layer(oup), 29 | nlin_layer(inplace=True), 30 | ) 31 | 32 | 33 | class Hswish(nn.Module): 34 | def __init__(self, inplace=True): 35 | super(Hswish, self).__init__() 36 | self.inplace = inplace 37 | 38 | def forward(self, x): 39 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 40 | 41 | 42 | class Hsigmoid(nn.Module): 43 | def __init__(self, inplace=True): 44 | super(Hsigmoid, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return F.relu6(x + 3.0, inplace=self.inplace) / 6.0 49 | 50 | 51 | class SEModule(nn.Module): 52 | def __init__(self, channel, reduction=4): 53 | super(SEModule, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, channel // reduction, bias=False), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(channel // reduction, channel, bias=False), 59 | Hsigmoid() 60 | # nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1, 1) 67 | return x * y.expand_as(x) 68 | 69 | 70 | class Identity(nn.Module): 71 | def __init__(self, channel): 72 | super(Identity, self).__init__() 73 | 74 | def forward(self, x): 75 | return x 76 | 77 | 78 | def make_divisible(x, divisible_by=8): 79 | import numpy as np 80 | 81 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) 82 | 83 | 84 | class MobileBottleneck(nn.Module): 85 | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl="RE"): 86 | super(MobileBottleneck, self).__init__() 87 | assert stride in [1, 2] 88 | assert kernel in [3, 5] 89 | padding = (kernel - 1) // 2 90 | self.use_res_connect = stride == 1 and inp == oup 91 | 92 | conv_layer = nn.Conv2d 93 | norm_layer = nn.BatchNorm2d 94 | if nl == "RE": 95 | nlin_layer = nn.ReLU # or ReLU6 96 | elif nl == "HS": 97 | nlin_layer = Hswish 98 | else: 99 | raise NotImplementedError 100 | if se: 101 | SELayer = SEModule 102 | else: 103 | SELayer = 
Identity 104 | 105 | self.conv = nn.Sequential( 106 | # pw 107 | conv_layer(inp, exp, 1, 1, 0, bias=False), 108 | norm_layer(exp), 109 | nlin_layer(inplace=True), 110 | # dw 111 | conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), 112 | norm_layer(exp), 113 | SELayer(exp), 114 | nlin_layer(inplace=True), 115 | # pw-linear 116 | conv_layer(exp, oup, 1, 1, 0, bias=False), 117 | norm_layer(oup), 118 | ) 119 | 120 | def forward(self, x): 121 | if self.use_res_connect: 122 | return x + self.conv(x) 123 | else: 124 | return self.conv(x) 125 | 126 | 127 | class MobileNetV3(nn.Module): 128 | def __init__( 129 | self, n_class=1000, input_size=224, dropout=0.0, mode="small", width_mult=1.0 130 | ): 131 | super(MobileNetV3, self).__init__() 132 | input_channel = 16 133 | last_channel = 1280 134 | if mode == "large": 135 | # refer to Table 1 in paper 136 | mobile_setting = [ 137 | # k, exp, c, se, nl, s, 138 | [3, 16, 16, False, "RE", 1], 139 | [3, 64, 24, False, "RE", 2], 140 | [3, 72, 24, False, "RE", 1], 141 | [5, 72, 40, True, "RE", 2], 142 | [5, 120, 40, True, "RE", 1], 143 | [5, 120, 40, True, "RE", 1], 144 | [3, 240, 80, False, "HS", 2], 145 | [3, 200, 80, False, "HS", 1], 146 | [3, 184, 80, False, "HS", 1], 147 | [3, 184, 80, False, "HS", 1], 148 | [3, 480, 112, True, "HS", 1], 149 | [3, 672, 112, True, "HS", 1], 150 | [5, 672, 160, True, "HS", 2], 151 | [5, 960, 160, True, "HS", 1], 152 | [5, 960, 160, True, "HS", 1], 153 | ] 154 | elif mode == "small": 155 | # refer to Table 2 in paper 156 | mobile_setting = [ 157 | # k, exp, c, se, nl, s, 158 | [3, 16, 16, True, "RE", 2], 159 | [3, 72, 24, False, "RE", 2], 160 | [3, 88, 24, False, "RE", 1], 161 | [5, 96, 40, True, "HS", 2], 162 | [5, 240, 40, True, "HS", 1], 163 | [5, 240, 40, True, "HS", 1], 164 | [5, 120, 48, True, "HS", 1], 165 | [5, 144, 48, True, "HS", 1], 166 | [5, 288, 96, True, "HS", 2], 167 | [5, 576, 96, True, "HS", 1], 168 | [5, 576, 96, True, "HS", 1], 169 | ] 170 | else: 171 | raise NotImplementedError 172 | 173 | # building first layer 174 | assert input_size % 32 == 0 175 | last_channel = ( 176 | make_divisible(last_channel * width_mult) 177 | if width_mult > 1.0 178 | else last_channel 179 | ) 180 | self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] 181 | self.classifier = [] 182 | 183 | # building mobile blocks 184 | for k, exp, c, se, nl, s in mobile_setting: 185 | output_channel = make_divisible(c * width_mult) 186 | exp_channel = make_divisible(exp * width_mult) 187 | self.features.append( 188 | MobileBottleneck( 189 | input_channel, output_channel, k, s, exp_channel, se, nl 190 | ) 191 | ) 192 | input_channel = output_channel 193 | 194 | # building last several layers 195 | if mode == "large": 196 | last_conv = make_divisible(960 * width_mult) 197 | self.features.append( 198 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) 199 | ) 200 | self.features.append(nn.AdaptiveAvgPool2d(1)) 201 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 202 | self.features.append(Hswish(inplace=True)) 203 | elif mode == "small": 204 | last_conv = make_divisible(576 * width_mult) 205 | self.features.append( 206 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) 207 | ) 208 | # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake 209 | self.features.append(nn.AdaptiveAvgPool2d(1)) 210 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 211 | self.features.append(Hswish(inplace=True)) 212 | else: 213 | raise 
NotImplementedError 214 | 215 | # make it nn.Sequential 216 | self.features = nn.Sequential(*self.features) 217 | 218 | # building classifier 219 | self.classifier = nn.Sequential( 220 | nn.Dropout(p=dropout), # refer to paper section 6 221 | nn.Linear(last_channel, n_class), 222 | ) 223 | 224 | self._initialize_weights() 225 | 226 | def forward(self, x): 227 | x = self.features(x) 228 | x = x.mean(3).mean(2) 229 | x = self.classifier(x) 230 | return x 231 | 232 | def _initialize_weights(self): 233 | # weight initialization 234 | for m in self.modules(): 235 | if isinstance(m, nn.Conv2d): 236 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 237 | if m.bias is not None: 238 | nn.init.zeros_(m.bias) 239 | elif isinstance(m, nn.BatchNorm2d): 240 | nn.init.ones_(m.weight) 241 | nn.init.zeros_(m.bias) 242 | elif isinstance(m, nn.Linear): 243 | nn.init.normal_(m.weight, 0, 0.01) 244 | if m.bias is not None: 245 | nn.init.zeros_(m.bias) 246 | 247 | 248 | def mobilenetv3(pretrained=False, **kwargs): 249 | model = MobileNetV3(**kwargs) 250 | if pretrained: 251 | state_dict = torch.load("mobilenetv3_small_67.4.pth.tar") 252 | model.load_state_dict(state_dict, strict=True) 253 | # raise NotImplementedError 254 | return model 255 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from torch import einsum, nn 8 | 9 | from threestudio.utils.GAN.network_util import checkpoint 10 | 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | 16 | def uniq(arr): 17 | return {el: True for el in arr}.keys() 18 | 19 | 20 | def default(val, d): 21 | if exists(val): 22 | return val 23 | return d() if isfunction(d) else d 24 | 25 | 26 | def max_neg_value(t): 27 | return -torch.finfo(t.dtype).max 28 | 29 | 30 | def init_(tensor): 31 | dim = tensor.shape[-1] 32 | std = 1 / math.sqrt(dim) 33 | tensor.uniform_(-std, std) 34 | return tensor 35 | 36 | 37 | # feedforward 38 | class GEGLU(nn.Module): 39 | def __init__(self, dim_in, dim_out): 40 | super().__init__() 41 | self.proj = nn.Linear(dim_in, dim_out * 2) 42 | 43 | def forward(self, x): 44 | x, gate = self.proj(x).chunk(2, dim=-1) 45 | return x * F.gelu(gate) 46 | 47 | 48 | class FeedForward(nn.Module): 49 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 50 | super().__init__() 51 | inner_dim = int(dim * mult) 52 | dim_out = default(dim_out, dim) 53 | project_in = ( 54 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 55 | if not glu 56 | else GEGLU(dim, inner_dim) 57 | ) 58 | 59 | self.net = nn.Sequential( 60 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm( 78 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 79 | ) 80 | 81 | 82 | class LinearAttention(nn.Module): 83 | def __init__(self, dim, heads=4, dim_head=32): 84 | super().__init__() 85 | self.heads = heads 86 | hidden_dim = dim_head * heads 87 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 88 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 89 | 90 | def forward(self, x): 91 | b, c, h, w = x.shape 92 | qkv = self.to_qkv(x) 93 | q, k, v = rearrange( 94 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 95 | ) 96 | k = k.softmax(dim=-1) 97 | context = torch.einsum("bhdn,bhen->bhde", k, v) 98 | out = torch.einsum("bhde,bhdn->bhen", context, q) 99 | out = rearrange( 100 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 101 | ) 102 | return self.to_out(out) 103 | 104 | 105 | class SpatialSelfAttention(nn.Module): 106 | def __init__(self, in_channels): 107 | super().__init__() 108 | self.in_channels = in_channels 109 | 110 | self.norm = Normalize(in_channels) 111 | self.q = torch.nn.Conv2d( 112 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 113 | ) 114 | self.k = torch.nn.Conv2d( 115 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 116 | ) 117 | self.v = torch.nn.Conv2d( 118 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 119 | ) 120 | self.proj_out = torch.nn.Conv2d( 121 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 122 | ) 123 | 124 | def forward(self, x): 125 | h_ = x 126 | h_ = self.norm(h_) 127 | q = self.q(h_) 128 | k = self.k(h_) 129 | v = self.v(h_) 130 | 131 | # compute attention 132 | b, c, h, w = q.shape 133 | q = rearrange(q, "b c h w -> b (h w) c") 134 | k = rearrange(k, "b c h w -> b c (h w)") 135 | w_ = torch.einsum("bij,bjk->bik", q, k) 136 | 137 | w_ = w_ * (int(c) ** (-0.5)) 138 | w_ = torch.nn.functional.softmax(w_, dim=2) 139 | 140 | # attend to values 141 | v = rearrange(v, "b c h w -> b c (h w)") 142 | w_ = rearrange(w_, "b i j -> b j i") 143 | h_ = torch.einsum("bij,bjk->bik", v, w_) 144 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 145 | h_ = self.proj_out(h_) 146 | 147 | return x + h_ 148 | 149 | 150 | class CrossAttention(nn.Module): 151 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 152 | super().__init__() 153 | inner_dim = dim_head * heads 154 | context_dim = default(context_dim, query_dim) 155 | 156 | self.scale = dim_head**-0.5 157 | self.heads = heads 158 | 159 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 160 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 161 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 162 | 163 | self.to_out = nn.Sequential( 164 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 165 | ) 166 | 167 | def forward(self, x, context=None, mask=None): 168 | h = self.heads 169 | 170 | q = self.to_q(x) 171 | context = default(context, x) 172 | k = self.to_k(context) 173 | v = self.to_v(context) 174 | 175 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 176 | 177 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 178 | 179 | if exists(mask): 180 | mask = rearrange(mask, "b ... 
-> b (...)") 181 | max_neg_value = -torch.finfo(sim.dtype).max 182 | mask = repeat(mask, "b j -> (b h) () j", h=h) 183 | sim.masked_fill_(~mask, max_neg_value) 184 | 185 | # attention, what we cannot get enough of 186 | attn = sim.softmax(dim=-1) 187 | 188 | out = einsum("b i j, b j d -> b i d", attn, v) 189 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 190 | return self.to_out(out) 191 | 192 | 193 | class BasicTransformerBlock(nn.Module): 194 | def __init__( 195 | self, 196 | dim, 197 | n_heads, 198 | d_head, 199 | dropout=0.0, 200 | context_dim=None, 201 | gated_ff=True, 202 | checkpoint=True, 203 | ): 204 | super().__init__() 205 | self.attn1 = CrossAttention( 206 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 207 | ) # is a self-attention 208 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 209 | self.attn2 = CrossAttention( 210 | query_dim=dim, 211 | context_dim=context_dim, 212 | heads=n_heads, 213 | dim_head=d_head, 214 | dropout=dropout, 215 | ) # is self-attn if context is none 216 | self.norm1 = nn.LayerNorm(dim) 217 | self.norm2 = nn.LayerNorm(dim) 218 | self.norm3 = nn.LayerNorm(dim) 219 | self.checkpoint = checkpoint 220 | 221 | def forward(self, x, context=None): 222 | return checkpoint( 223 | self._forward, (x, context), self.parameters(), self.checkpoint 224 | ) 225 | 226 | def _forward(self, x, context=None): 227 | x = self.attn1(self.norm1(x)) + x 228 | x = self.attn2(self.norm2(x), context=context) + x 229 | x = self.ff(self.norm3(x)) + x 230 | return x 231 | 232 | 233 | class SpatialTransformer(nn.Module): 234 | """ 235 | Transformer block for image-like data. 236 | First, project the input (aka embedding) 237 | and reshape to b, t, d. 238 | Then apply standard transformer action. 239 | Finally, reshape to image 240 | """ 241 | 242 | def __init__( 243 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 244 | ): 245 | super().__init__() 246 | self.in_channels = in_channels 247 | inner_dim = n_heads * d_head 248 | self.norm = Normalize(in_channels) 249 | 250 | self.proj_in = nn.Conv2d( 251 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 252 | ) 253 | 254 | self.transformer_blocks = nn.ModuleList( 255 | [ 256 | BasicTransformerBlock( 257 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 258 | ) 259 | for d in range(depth) 260 | ] 261 | ) 262 | 263 | self.proj_out = zero_module( 264 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 265 | ) 266 | 267 | def forward(self, x, context=None): 268 | # note: if no context is given, cross-attention defaults to self-attention 269 | b, c, h, w = x.shape 270 | x_in = x 271 | x = self.norm(x) 272 | x = self.proj_in(x) 273 | x = rearrange(x, "b c h w -> b (h w) c") 274 | for block in self.transformer_blocks: 275 | x = block(x, context=context) 276 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 277 | x = self.proj_out(x) 278 | return x + x_in 279 | -------------------------------------------------------------------------------- /threestudio/models/isosurface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.mesh import Mesh 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class IsosurfaceHelper(nn.Module): 12 | points_range: Tuple[float, float] = (0, 1) 13 | 14 | @property 15 | def grid_vertices(self) -> Float[Tensor, "N 
3"]: 16 | raise NotImplementedError 17 | 18 | 19 | class MarchingCubeCPUHelper(IsosurfaceHelper): 20 | def __init__(self, resolution: int) -> None: 21 | super().__init__() 22 | self.resolution = resolution 23 | import mcubes 24 | 25 | self.mc_func: Callable = mcubes.marching_cubes 26 | self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None 27 | self._dummy: Float[Tensor, "..."] 28 | self.register_buffer( 29 | "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False 30 | ) 31 | 32 | @property 33 | def grid_vertices(self) -> Float[Tensor, "N3 3"]: 34 | if self._grid_vertices is None: 35 | # keep the vertices on CPU so that we can support very large resolution 36 | x, y, z = ( 37 | torch.linspace(*self.points_range, self.resolution), 38 | torch.linspace(*self.points_range, self.resolution), 39 | torch.linspace(*self.points_range, self.resolution), 40 | ) 41 | x, y, z = torch.meshgrid(x, y, z, indexing="ij") 42 | verts = torch.cat( 43 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 44 | ).reshape(-1, 3) 45 | self._grid_vertices = verts 46 | return self._grid_vertices 47 | 48 | def forward( 49 | self, 50 | level: Float[Tensor, "N3 1"], 51 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 52 | ) -> Mesh: 53 | if deformation is not None: 54 | threestudio.warn( 55 | f"{self.__class__.__name__} does not support deformation. Ignoring." 56 | ) 57 | level = -level.view(self.resolution, self.resolution, self.resolution) 58 | v_pos, t_pos_idx = self.mc_func( 59 | level.detach().cpu().numpy(), 0.0 60 | ) # transform to numpy 61 | v_pos, t_pos_idx = ( 62 | torch.from_numpy(v_pos).float().to(self._dummy.device), 63 | torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device), 64 | ) # transform back to torch tensor on CUDA 65 | v_pos = v_pos / (self.resolution - 1.0) 66 | return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) 67 | 68 | 69 | class MarchingTetrahedraHelper(IsosurfaceHelper): 70 | def __init__(self, resolution: int, tets_path: str): 71 | super().__init__() 72 | self.resolution = resolution 73 | self.tets_path = tets_path 74 | 75 | self.triangle_table: Float[Tensor, "..."] 76 | self.register_buffer( 77 | "triangle_table", 78 | torch.as_tensor( 79 | [ 80 | [-1, -1, -1, -1, -1, -1], 81 | [1, 0, 2, -1, -1, -1], 82 | [4, 0, 3, -1, -1, -1], 83 | [1, 4, 2, 1, 3, 4], 84 | [3, 1, 5, -1, -1, -1], 85 | [2, 3, 0, 2, 5, 3], 86 | [1, 4, 0, 1, 5, 4], 87 | [4, 2, 5, -1, -1, -1], 88 | [4, 5, 2, -1, -1, -1], 89 | [4, 1, 0, 4, 5, 1], 90 | [3, 2, 0, 3, 5, 2], 91 | [1, 3, 5, -1, -1, -1], 92 | [4, 1, 2, 4, 3, 1], 93 | [3, 0, 4, -1, -1, -1], 94 | [2, 0, 1, -1, -1, -1], 95 | [-1, -1, -1, -1, -1, -1], 96 | ], 97 | dtype=torch.long, 98 | ), 99 | persistent=False, 100 | ) 101 | self.num_triangles_table: Integer[Tensor, "..."] 102 | self.register_buffer( 103 | "num_triangles_table", 104 | torch.as_tensor( 105 | [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long 106 | ), 107 | persistent=False, 108 | ) 109 | self.base_tet_edges: Integer[Tensor, "..."] 110 | self.register_buffer( 111 | "base_tet_edges", 112 | torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long), 113 | persistent=False, 114 | ) 115 | 116 | tets = np.load(self.tets_path) 117 | self._grid_vertices: Float[Tensor, "..."] 118 | self.register_buffer( 119 | "_grid_vertices", 120 | torch.from_numpy(tets["vertices"]).float(), 121 | persistent=False, 122 | ) 123 | self.indices: Integer[Tensor, "..."] 124 | self.register_buffer( 125 | "indices", torch.from_numpy(tets["indices"]).long(), 
persistent=False 126 | ) 127 | 128 | self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None 129 | 130 | def normalize_grid_deformation( 131 | self, grid_vertex_offsets: Float[Tensor, "Nv 3"] 132 | ) -> Float[Tensor, "Nv 3"]: 133 | return ( 134 | (self.points_range[1] - self.points_range[0]) 135 | / (self.resolution) # half tet size is approximately 1 / self.resolution 136 | * torch.tanh(grid_vertex_offsets) 137 | ) # FIXME: hard-coded activation 138 | 139 | @property 140 | def grid_vertices(self) -> Float[Tensor, "Nv 3"]: 141 | return self._grid_vertices 142 | 143 | @property 144 | def all_edges(self) -> Integer[Tensor, "Ne 2"]: 145 | if self._all_edges is None: 146 | # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation) 147 | edges = torch.tensor( 148 | [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], 149 | dtype=torch.long, 150 | device=self.indices.device, 151 | ) 152 | _all_edges = self.indices[:, edges].reshape(-1, 2) 153 | _all_edges_sorted = torch.sort(_all_edges, dim=1)[0] 154 | _all_edges = torch.unique(_all_edges_sorted, dim=0) 155 | self._all_edges = _all_edges 156 | return self._all_edges 157 | 158 | def sort_edges(self, edges_ex2): 159 | with torch.no_grad(): 160 | order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() 161 | order = order.unsqueeze(dim=1) 162 | 163 | a = torch.gather(input=edges_ex2, index=order, dim=1) 164 | b = torch.gather(input=edges_ex2, index=1 - order, dim=1) 165 | 166 | return torch.stack([a, b], -1) 167 | 168 | def _forward(self, pos_nx3, sdf_n, tet_fx4): 169 | with torch.no_grad(): 170 | occ_n = sdf_n > 0 171 | occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) 172 | occ_sum = torch.sum(occ_fx4, -1) 173 | valid_tets = (occ_sum > 0) & (occ_sum < 4) 174 | occ_sum = occ_sum[valid_tets] 175 | 176 | # find all vertices 177 | all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) 178 | all_edges = self.sort_edges(all_edges) 179 | unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) 180 | 181 | unique_edges = unique_edges.long() 182 | mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 183 | mapping = ( 184 | torch.ones( 185 | (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device 186 | ) 187 | * -1 188 | ) 189 | mapping[mask_edges] = torch.arange( 190 | mask_edges.sum(), dtype=torch.long, device=pos_nx3.device 191 | ) 192 | idx_map = mapping[idx_map] # map edges to verts 193 | 194 | interp_v = unique_edges[mask_edges] 195 | edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) 196 | edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) 197 | edges_to_interp_sdf[:, -1] *= -1 198 | 199 | denominator = edges_to_interp_sdf.sum(1, keepdim=True) 200 | 201 | edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator 202 | verts = (edges_to_interp * edges_to_interp_sdf).sum(1) 203 | 204 | idx_map = idx_map.reshape(-1, 6) 205 | 206 | v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device)) 207 | tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) 208 | num_triangles = self.num_triangles_table[tetindex] 209 | 210 | # Generate triangle indices 211 | faces = torch.cat( 212 | ( 213 | torch.gather( 214 | input=idx_map[num_triangles == 1], 215 | dim=1, 216 | index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], 217 | ).reshape(-1, 3), 218 | torch.gather( 219 | input=idx_map[num_triangles == 2], 220 | dim=1, 221 | index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], 222 | ).reshape(-1, 3), 
223 | ), 224 | dim=0, 225 | ) 226 | 227 | return verts, faces 228 | 229 | def forward( 230 | self, 231 | level: Float[Tensor, "N3 1"], 232 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 233 | ) -> Mesh: 234 | if deformation is not None: 235 | grid_vertices = self.grid_vertices + self.normalize_grid_deformation( 236 | deformation 237 | ) 238 | else: 239 | grid_vertices = self.grid_vertices 240 | 241 | v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices) 242 | 243 | mesh = Mesh( 244 | v_pos=v_pos, 245 | t_pos_idx=t_pos_idx, 246 | # extras 247 | grid_vertices=grid_vertices, 248 | tet_edges=self.all_edges, 249 | grid_level=level, 250 | grid_deformation=deformation, 251 | ) 252 | 253 | return mesh 254 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/network_util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import math 12 | import os 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | from einops import repeat 18 | 19 | from threestudio.utils.GAN.util import instantiate_from_config 20 | 21 | 22 | def make_beta_schedule( 23 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 24 | ): 25 | if schedule == "linear": 26 | betas = ( 27 | torch.linspace( 28 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 29 | ) 30 | ** 2 31 | ) 32 | 33 | elif schedule == "cosine": 34 | timesteps = ( 35 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 36 | ) 37 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 38 | alphas = torch.cos(alphas).pow(2) 39 | alphas = alphas / alphas[0] 40 | betas = 1 - alphas[1:] / alphas[:-1] 41 | betas = np.clip(betas, a_min=0, a_max=0.999) 42 | 43 | elif schedule == "sqrt_linear": 44 | betas = torch.linspace( 45 | linear_start, linear_end, n_timestep, dtype=torch.float64 46 | ) 47 | elif schedule == "sqrt": 48 | betas = ( 49 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 50 | ** 0.5 51 | ) 52 | else: 53 | raise ValueError(f"schedule '{schedule}' unknown.") 54 | return betas.numpy() 55 | 56 | 57 | def make_ddim_timesteps( 58 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 59 | ): 60 | if ddim_discr_method == "uniform": 61 | c = num_ddpm_timesteps // num_ddim_timesteps 62 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 63 | elif ddim_discr_method == "quad": 64 | ddim_timesteps = ( 65 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 66 | ).astype(int) 67 | else: 68 | raise NotImplementedError( 69 | f'There is no ddim discretization method called "{ddim_discr_method}"' 70 | ) 71 | 72 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 73 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 74 | steps_out = ddim_timesteps + 1 75 | if verbose: 76 | print(f"Selected timesteps for ddim sampler: {steps_out}") 77 | return steps_out 78 | 79 | 80 | def 
make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 81 | # select alphas for computing the variance schedule 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt( 87 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 88 | ) 89 | if verbose: 90 | print( 91 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 92 | ) 93 | print( 94 | f"For the chosen value of eta, which is {eta}, " 95 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 96 | ) 97 | return sigmas, alphas, alphas_prev 98 | 99 | 100 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 101 | """ 102 | Create a beta schedule that discretizes the given alpha_t_bar function, 103 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 104 | :param num_diffusion_timesteps: the number of betas to produce. 105 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 106 | produces the cumulative product of (1-beta) up to that 107 | part of the diffusion process. 108 | :param max_beta: the maximum beta to use; use values lower than 1 to 109 | prevent singularities. 110 | """ 111 | betas = [] 112 | for i in range(num_diffusion_timesteps): 113 | t1 = i / num_diffusion_timesteps 114 | t2 = (i + 1) / num_diffusion_timesteps 115 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 116 | return np.array(betas) 117 | 118 | 119 | def extract_into_tensor(a, t, x_shape): 120 | b, *_ = t.shape 121 | out = a.gather(-1, t) 122 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 123 | 124 | 125 | def checkpoint(func, inputs, params, flag): 126 | """ 127 | Evaluate a function without caching intermediate activations, allowing for 128 | reduced memory at the expense of extra compute in the backward pass. 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | 149 | with torch.no_grad(): 150 | output_tensors = ctx.run_function(*ctx.input_tensors) 151 | return output_tensors 152 | 153 | @staticmethod 154 | def backward(ctx, *output_grads): 155 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 156 | with torch.enable_grad(): 157 | # Fixes a bug where the first op in run_function modifies the 158 | # Tensor storage in place, which is not allowed for detach()'d 159 | # Tensors. 
160 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 161 | output_tensors = ctx.run_function(*shallow_copies) 162 | input_grads = torch.autograd.grad( 163 | output_tensors, 164 | ctx.input_tensors + ctx.input_params, 165 | output_grads, 166 | allow_unused=True, 167 | ) 168 | del ctx.input_tensors 169 | del ctx.input_params 170 | del output_tensors 171 | return (None, None) + input_grads 172 | 173 | 174 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 175 | """ 176 | Create sinusoidal timestep embeddings. 177 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 178 | These may be fractional. 179 | :param dim: the dimension of the output. 180 | :param max_period: controls the minimum frequency of the embeddings. 181 | :return: an [N x dim] Tensor of positional embeddings. 182 | """ 183 | if not repeat_only: 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) 187 | * torch.arange(start=0, end=half, dtype=torch.float32) 188 | / half 189 | ).to(device=timesteps.device) 190 | args = timesteps[:, None].float() * freqs[None] 191 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 192 | if dim % 2: 193 | embedding = torch.cat( 194 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 195 | ) 196 | else: 197 | embedding = repeat(timesteps, "b -> b d", d=dim) 198 | return embedding 199 | 200 | 201 | def zero_module(module): 202 | """ 203 | Zero out the parameters of a module and return it. 204 | """ 205 | for p in module.parameters(): 206 | p.detach().zero_() 207 | return module 208 | 209 | 210 | def scale_module(module, scale): 211 | """ 212 | Scale the parameters of a module and return it. 213 | """ 214 | for p in module.parameters(): 215 | p.detach().mul_(scale) 216 | return module 217 | 218 | 219 | def mean_flat(tensor): 220 | """ 221 | Take the mean over all non-batch dimensions. 222 | """ 223 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 224 | 225 | 226 | def normalization(channels): 227 | """ 228 | Make a standard normalization layer. 229 | :param channels: number of input channels. 230 | :return: an nn.Module for normalization. 231 | """ 232 | return GroupNorm32(32, channels) 233 | 234 | 235 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 236 | class SiLU(nn.Module): 237 | def forward(self, x): 238 | return x * torch.sigmoid(x) 239 | 240 | 241 | class GroupNorm32(nn.GroupNorm): 242 | def forward(self, x): 243 | return super().forward(x.float()).type(x.dtype) 244 | 245 | 246 | def conv_nd(dims, *args, **kwargs): 247 | """ 248 | Create a 1D, 2D, or 3D convolution module. 249 | """ 250 | if dims == 1: 251 | return nn.Conv1d(*args, **kwargs) 252 | elif dims == 2: 253 | return nn.Conv2d(*args, **kwargs) 254 | elif dims == 3: 255 | return nn.Conv3d(*args, **kwargs) 256 | raise ValueError(f"unsupported dimensions: {dims}") 257 | 258 | 259 | def linear(*args, **kwargs): 260 | """ 261 | Create a linear module. 262 | """ 263 | return nn.Linear(*args, **kwargs) 264 | 265 | 266 | def avg_pool_nd(dims, *args, **kwargs): 267 | """ 268 | Create a 1D, 2D, or 3D average pooling module. 
269 | """ 270 | if dims == 1: 271 | return nn.AvgPool1d(*args, **kwargs) 272 | elif dims == 2: 273 | return nn.AvgPool2d(*args, **kwargs) 274 | elif dims == 3: 275 | return nn.AvgPool3d(*args, **kwargs) 276 | raise ValueError(f"unsupported dimensions: {dims}") 277 | 278 | 279 | class HybridConditioner(nn.Module): 280 | def __init__(self, c_concat_config, c_crossattn_config): 281 | super().__init__() 282 | self.concat_conditioner = instantiate_from_config(c_concat_config) 283 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 284 | 285 | def forward(self, c_concat, c_crossattn): 286 | c_concat = self.concat_conditioner(c_concat) 287 | c_crossattn = self.crossattn_conditioner(c_crossattn) 288 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 289 | 290 | 291 | def noise_like(shape, device, repeat=False): 292 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 293 | shape[0], *((1,) * (len(shape) - 1)) 294 | ) 295 | noise = lambda: torch.randn(shape, device=device) 296 | return repeat_noise() if repeat else noise() 297 | --------------------------------------------------------------------------------