├── .gitignore ├── LICENSE ├── README.md ├── configs └── 4diffusion.yaml ├── dataset └── uid.npy ├── launch.py ├── load └── prompt_library.json ├── requirements.txt └── threestudio ├── __init__.py ├── data ├── __init__.py ├── co3d.py ├── image.py ├── multiview.py ├── random_multiview.py ├── single_multiview_combined.py └── uncond.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── models ├── __init__.py ├── background │ ├── __init__.py │ ├── base.py │ ├── neural_environment_map_background.py │ ├── solid_color_background.py │ └── textured_background.py ├── exporters │ ├── __init__.py │ ├── base.py │ └── mesh_exporter.py ├── geometry │ ├── __init__.py │ ├── base.py │ ├── implicit_sdf.py │ ├── implicit_volume.py │ ├── implicit_volume_kplane_e.py │ ├── tetrahedra_sdf_grid.py │ └── volume_grid.py ├── guidance │ ├── __init__.py │ ├── deep_floyd_guidance.py │ ├── multiview_video_diffusion_guidance.py │ ├── stable_diffusion_guidance.py │ ├── stable_diffusion_vsd_guidance.py │ ├── zero123_guidance.py │ └── zeroscope_guidance.py ├── hexplane.py ├── imagedream │ ├── .gitignore │ ├── LICENSE-CODE │ ├── __init__.py │ ├── assets │ │ └── yoda │ │ │ ├── 0.png │ │ │ ├── 1.png │ │ │ ├── 10.png │ │ │ ├── 11.png │ │ │ ├── 12.png │ │ │ ├── 13.png │ │ │ ├── 14.png │ │ │ ├── 15.png │ │ │ ├── 16.png │ │ │ ├── 17.png │ │ │ ├── 18.png │ │ │ ├── 19.png │ │ │ ├── 2.png │ │ │ ├── 20.png │ │ │ ├── 21.png │ │ │ ├── 22.png │ │ │ ├── 23.png │ │ │ ├── 24.png │ │ │ ├── 3.png │ │ │ ├── 4.png │ │ │ ├── 5.png │ │ │ ├── 6.png │ │ │ ├── 7.png │ │ │ ├── 8.png │ │ │ └── 9.png │ ├── imagedream │ │ ├── __init__.py │ │ ├── camera_utils.py │ │ ├── configs │ │ │ ├── sd_v2_base_ipmv.yaml │ │ │ └── sd_v2_base_ipmv_local.yaml │ │ ├── ldm │ │ │ ├── __init__.py │ │ │ ├── interface.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── autoencoder.py │ │ │ │ └── diffusion │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ddim.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── diffusionmodules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── adaptors.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── model.py │ │ │ │ │ ├── motion_module.py │ │ │ │ │ ├── openaimodel.py │ │ │ │ │ └── util.py │ │ │ │ ├── distributions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── distributions.py │ │ │ │ ├── ema.py │ │ │ │ └── encoders │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── modules.py │ │ │ └── util.py │ │ └── model_zoo.py │ └── scripts │ │ ├── demo.py │ │ ├── demo.sh │ │ └── remove_bg.py ├── isosurface.py ├── materials │ ├── __init__.py │ ├── base.py │ ├── diffuse_with_point_light_material.py │ ├── neural_radiance_material.py │ ├── no_material.py │ └── sd_latent_adapter_material.py ├── mesh.py ├── networks.py ├── prompt_processors │ ├── __init__.py │ ├── base.py │ ├── deepfloyd_prompt_processor.py │ ├── stable_diffusion_prompt_processor.py │ ├── zero123_prompt_processor.py │ └── zeroscope_prompt_processor.py └── renderers │ ├── __init__.py │ ├── base.py │ ├── deferred_volume_renderer.py │ ├── nerf_volume_renderer.py │ ├── neus_volume_renderer.py │ └── nvdiff_rasterizer.py ├── systems ├── __init__.py ├── base.py ├── fourdiffusion.py └── utils.py └── utils ├── __init__.py ├── base.py ├── callbacks.py ├── config.py ├── loss_utils.py ├── misc.py ├── ops.py ├── rasterize.py ├── saving.py └── typing.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.DS_Store -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # 4Diffusion: Multi-view Video Diffusion Model for 4D Generation
2 |
3 | | [Project Page](https://aejion.github.io/4diffusion) | [Paper](http://arxiv.org/abs/2405.20674) |
4 |
5 | Official code for 4Diffusion: Multi-view Video Diffusion Model for 4D Generation.
6 |
7 | The paper presents 4Diffusion, a novel 4D generation pipeline for producing spatio-temporally consistent 4D content from a monocular video. We design a multi-view video diffusion model, 4DM, that captures multi-view spatio-temporal correlations for multi-view video generation.
8 |
9 | ## Installation Requirements
10 |
11 | The code is compatible with Python 3.10.0 and PyTorch 2.0.1. To create a conda environment named `4diffusion` with the required dependencies, run:
12 |
13 | ```
14 | conda create -n 4diffusion python==3.10.0
15 | conda activate 4diffusion
16 |
17 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ## 4D Data
22 |
23 | We curate animated 3D shapes from the vast 3D data corpus of [Objaverse-1.0](https://objaverse.allenai.org/objaverse-1.0/). We provide the IDs of the curated data in `dataset/uid.npy`. We will also release the rendered multi-view videos (to be uploaded) for future work.
24 |
25 |
26 | ## Quickstart
27 |
28 | ### Download pre-trained models
29 |
30 | Please download the [4DM](https://drive.google.com/drive/folders/19k3p2CfzQ6ArqpDNOy73RJeJhWfNs4i6?usp=sharing) and [ImageDream](https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv.pt?download=true) checkpoints and put them under `./ckpts/`.
31 |
32 | ### Multi-view Video Generation
33 |
34 | To generate multi-view videos, run:
35 | ```
36 | bash threestudio/models/imagedream/scripts/demo.sh
37 | ```
38 | Please configure `image` (input monocular video path), `text` (text prompt), and `num_video_frames` (number of frames in the input monocular video) in `demo.sh`. The results can be found in `threestudio/models/imagedream/4dm`.
39 |
40 | We use [rembg](https://github.com/danielgatis/rembg) to segment the foreground object for 4D generation:
41 | ```
42 | # name denotes the folder's name under threestudio/models/imagedream/4dm
43 | python threestudio/models/imagedream/scripts/remove_bg.py --name yoda
44 | ```
45 |
46 |
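As a quick sanity check before 4D generation, you can composite the segmented frames over a white background and stitch them into a preview video. The snippet below is a minimal sketch that only uses packages already listed in `requirements.txt` (imageio, numpy); it assumes the frames are saved as numbered `*_rgba.png` files under `threestudio/models/imagedream/4dm/<name>/`, so adjust the folder and filename pattern to match what `demo.sh`/`remove_bg.py` actually produced on your machine.

```
# preview_frames.py -- minimal sketch; the folder layout and the <index>_rgba.png
# naming are assumptions, adjust them to your actual output.
import glob
import os

import imageio
import numpy as np

name = "yoda"  # folder name under threestudio/models/imagedream/4dm
frame_dir = os.path.join("threestudio/models/imagedream/4dm", name)

# collect RGBA frames in numeric order
paths = sorted(
    glob.glob(os.path.join(frame_dir, "*_rgba.png")),
    key=lambda p: int(os.path.basename(p).split("_")[0]),
)
assert paths, f"no *_rgba.png frames found in {frame_dir}"

frames = []
for path in paths:
    rgba = imageio.imread(path).astype(np.float32) / 255.0
    # composite onto a white background, mirroring the treatment in threestudio/data/image.py
    rgb = rgba[..., :3] * rgba[..., 3:] + (1.0 - rgba[..., 3:])
    frames.append((rgb * 255).astype(np.uint8))

imageio.mimsave(os.path.join(frame_dir, "preview.mp4"), frames, fps=8)
print(f"wrote {len(frames)}-frame preview to {frame_dir}/preview.mp4")
```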
47 | ### 4D Generation
48 |
49 | To generate 4D content from a monocular video, run:
50 | ```
51 | # system.prompt_processor_multi_view.prompt: text prompt
52 | # system.prompt_processor_multi_view.image_path: monocular video path
53 | # data.multi_view.image_path: anchor video path (anchor loss in Sec. 3.3)
54 | # system.prompt_processor_multi_view.image_num: number of frames used for training, default: 8
55 | # system.prompt_processor_multi_view.total_num: number of frames in the input monocular video
56 | # data.multi_view.anchor_view_num: anchor view for the anchor loss (0: 0 azimuth; 1: 90 azimuth; 2: 180 azimuth; 3: 270 azimuth)
57 | python launch.py --config ./configs/4diffusion.yaml --train \
58 | system.prompt_processor_multi_view.prompt='baby yoda in the style of Mormookiee' \
59 | system.prompt_processor_multi_view.image_path='./threestudio/models/imagedream/assets/yoda/0_rgba.png' \
60 | data.multi_view.image_path='./threestudio/models/imagedream/4dm/yoda' \
61 | system.prompt_processor_multi_view.image_num=8 \
62 | system.prompt_processor_multi_view.total_num=25 \
63 | data.multi_view.anchor_view_num=0
64 | ```
65 | The results can be found in `outputs/4diffusion`.
66 |
67 |
68 | ## Citing
69 |
70 | If you find 4Diffusion helpful, please consider citing:
71 |
72 | ```
73 | @article{zhang20244diffusion,
74 |   title={4Diffusion: Multi-view Video Diffusion Model for 4D Generation},
75 |   author={Zhang, Haiyu and Chen, Xinyuan and Wang, Yaohui and Liu, Xihui and Wang, Yunhong and Qiao, Yu},
76 |   journal={arXiv preprint arXiv:2405.20674},
77 |   year={2024}
78 | }
79 | ```
80 |
81 | ## Credits
82 |
83 | This code is built on [threestudio](https://github.com/threestudio-project/threestudio), [4D-fy](https://github.com/sherwinbahmani/4dfy), and [ImageDream](https://github.com/bytedance/ImageDream). Thanks to the maintainers for their contributions to the community!
84 |
-------------------------------------------------------------------------------- /configs/4diffusion.yaml: --------------------------------------------------------------------------------
1 | name: "4diffusion"
2 | tag: "${rmspace:${system.prompt_processor_multi_view.prompt},_}"
3 | exp_root_dir: "outputs"
4 | seed: 23
5 |
6 | data_type: "single-multiview-combined-camera-datamodule"
7 | data:
8 |   prob_multi_view: 1.0
9 |   single_view:
10 |     batch_size: [1,1]
11 |     # rendering resolution before/after resolution_milestones (256x256 for both stages here)
12 |     # starting at a lower resolution drastically reduces VRAM usage, as empty space is pruned in early training
13 |     width: [256, 256]
14 |     height: [256, 256]
15 |     resolution_milestones: [5000]
16 |     camera_distance_range: [2.5, 3.0]
17 |     fovy_range: [15, 60]
18 |     elevation_range: [0, 30]
19 |     camera_perturb: 0.
20 |     center_perturb: 0.
21 |     up_perturb: 0.
22 |     eval_camera_distance: 1.1
23 |     eval_fovy_deg: 45
24 |     eval_elevation_deg: 0
25 |     static: false
26 |     num_frames: 8
27 |     simultan: true
28 |     prob_single_view_video: 1.0
29 |     width_vid: 144
30 |     height_vid: 80
31 |     sample_rand_frames: t1
32 |     num_frames_factor: 1
33 |     eval_height: 256
34 |     eval_width: 256
35 |     test_traj: 'constant'
36 |
37 |   multi_view:
38 |     batch_size: [4,4,4] # must be divisible by n_view
39 |     n_view: 4
40 |     width: [64, 192, 256]
41 |     height: [64, 192, 256]
42 |     resolution_milestones: [5000, 10000]
43 |     camera_distance_range: [1.1, 1.1]
44 |     fovy_range: [45, 45]
45 |     elevation_range: [0, 5]
46 |     camera_perturb: 0.
47 |     center_perturb: 0.
48 |     up_perturb: 0.
49 |     n_val_views: 4
50 |     eval_camera_distance: 2.0
51 |     eval_fovy_deg: 40.
52 | relative_radius: false 53 | num_frames: ${data.single_view.num_frames} 54 | sample_rand_frames: ${data.single_view.sample_rand_frames} 55 | eval_height: ${data.single_view.eval_height} 56 | eval_width: ${data.single_view.eval_width} 57 | stage_one: false 58 | stage_one_step: 0 59 | 60 | system_type: "Fourdiffsion-system" 61 | system: 62 | prob_multi_view: ${data.prob_multi_view} 63 | prob_single_view_video: ${data.single_view.prob_single_view_video} 64 | stage: coarse 65 | geometry_type: "implicit-volume" 66 | geometry: 67 | radius: 1.0 68 | normal_type: "analytic" 69 | 70 | density_bias: "blob_magic3d" 71 | density_activation: softplus 72 | density_blob_scale: 10. 73 | density_blob_std: 0.5 74 | 75 | pos_encoding_config: 76 | otype: HashGridSpatialTime 77 | n_levels: 16 78 | n_features_per_level: 2 79 | log2_hashmap_size: 19 80 | base_resolution: 16 81 | per_level_scale: 1.447269237440378 # max resolution 4096 82 | static: ${data.single_view.static} 83 | num_frames: ${data.single_view.num_frames} 84 | 85 | anneal_density_blob_std_config: 86 | min_anneal_step: 0 87 | max_anneal_step: 50000 88 | start_val: ${system.geometry.density_blob_std} 89 | end_val: 0.5 90 | 91 | material_type: "no-material" 92 | material: 93 | n_output_dims: 3 94 | color_activation: sigmoid 95 | 96 | background_type: "solid-color-background" 97 | background: 98 | learned: false 99 | 100 | renderer_type: "nerf-volume-renderer" 101 | renderer: 102 | radius: ${system.geometry.radius} 103 | num_samples_per_ray: 256 104 | 105 | prompt_processor_type_multi_view: "stable-diffusion-prompt-processor" 106 | prompt_processor_multi_view: 107 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 108 | prompt: ??? 109 | negative_prompt: "ugly, bad anatomy, blurry, pixelated obscure, unnatural colors, poor lighting, dull, and unclear, cropped, lowres, low quality, artifacts, duplicate, morbid, mutilated, poorly drawn face, deformed, dehydrated, bad proportions" 110 | front_threshold: 30. 111 | back_threshold: 30. 112 | image_num: 0 113 | 114 | guidance_type_multi_view: "multiview-video-diffusion-guidance" 115 | guidance_multi_view: 116 | model_name: "sd-v2.1-base-4view" 117 | ckpt_path: './ckpts/sd-v2.1-base-4view-ipmv.pt' # path to a pre-downloaded checkpoint file (null for loading from URL) 118 | config_path: './threestudio/models/imagedream/imagedream/configs/sd_v2_base_ipmv.yaml' 119 | guidance_scale: 5.0 120 | min_step_percent: [ 0, 0.98, 0.02, 5000 ] # (start_iter, start_val, end_val, end_iter) 121 | max_step_percent: [ 0, 0.98, 0.25, 5000 ] 122 | recon_loss: true 123 | recon_std_rescale: 0.2 124 | ip_mode: "pixel" 125 | 126 | 127 | loggers: 128 | wandb: 129 | enable: false 130 | project: "threestudio" 131 | 132 | loss: 133 | lambda_sds: 1. 134 | lambda_orient: [0, 10., 1000., 5000] 135 | lambda_sparsity: 100. 136 | lambda_opaque: [10000, 0., 100., 10001] 137 | lambda_ssim: 50. 138 | lambda_lpips: 100. 
139 | lambda_z_variance: 0 140 | optimizer: 141 | name: AdamW 142 | args: 143 | betas: [0.9, 0.99] 144 | eps: 1.e-15 145 | params: 146 | geometry.encoding: 147 | lr: 0.01 148 | geometry.density_network: 149 | lr: 0.001 150 | geometry.feature_network: 151 | lr: 0.001 152 | background: 153 | lr: 0.001 154 | 155 | trainer: 156 | max_steps: 35000 157 | log_every_n_steps: 1 158 | num_sanity_val_steps: 0 159 | val_check_interval: 1000 160 | enable_progress_bar: true 161 | precision: 16-mixed 162 | 163 | checkpoint: 164 | save_last: true 165 | save_top_k: 0 #-1 166 | every_n_train_steps: ${trainer.max_steps} 167 | -------------------------------------------------------------------------------- /dataset/uid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/dataset/uid.npy -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | 6 | 7 | class ColoredFilter(logging.Filter): 8 | """ 9 | A logging filter to add color to certain log levels. 10 | """ 11 | 12 | RESET = "\033[0m" 13 | RED = "\033[31m" 14 | GREEN = "\033[32m" 15 | YELLOW = "\033[33m" 16 | BLUE = "\033[34m" 17 | MAGENTA = "\033[35m" 18 | CYAN = "\033[36m" 19 | 20 | COLORS = { 21 | "WARNING": YELLOW, 22 | "INFO": GREEN, 23 | "DEBUG": BLUE, 24 | "CRITICAL": MAGENTA, 25 | "ERROR": RED, 26 | } 27 | 28 | RESET = "\x1b[0m" 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def filter(self, record): 34 | if record.levelname in self.COLORS: 35 | color_start = self.COLORS[record.levelname] 36 | record.levelname = f"{color_start}[{record.levelname}]" 37 | record.msg = f"{record.msg}{self.RESET}" 38 | return True 39 | 40 | 41 | def main() -> None: 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--config", required=True, help="path to config file") 44 | parser.add_argument("--gpu", default="0", help="GPU(s) to be used") 45 | 46 | group = parser.add_mutually_exclusive_group(required=True) 47 | group.add_argument("--train", action="store_true") 48 | group.add_argument("--validate", action="store_true") 49 | group.add_argument("--test", action="store_true") 50 | group.add_argument("--export", action="store_true") 51 | 52 | parser.add_argument( 53 | "--verbose", action="store_true", help="if true, set logging level to DEBUG" 54 | ) 55 | 56 | parser.add_argument( 57 | "--typecheck", 58 | action="store_true", 59 | help="whether to enable dynamic type checking", 60 | ) 61 | 62 | args, extras = parser.parse_known_args() 63 | 64 | # set CUDA_VISIBLE_DEVICES then import pytorch-lightning 65 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 66 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 67 | n_gpus = len(args.gpu.split(",")) 68 | 69 | import pytorch_lightning as pl 70 | import torch 71 | from pytorch_lightning import Trainer 72 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 73 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger 74 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 75 | 76 | if args.typecheck: 77 | from jaxtyping import install_import_hook 78 | 79 | install_import_hook("threestudio", "typeguard.typechecked") 80 | 81 | import threestudio 82 | from threestudio.systems.base import BaseSystem 83 | from threestudio.utils.callbacks import ( 
84 | CodeSnapshotCallback, 85 | ConfigSnapshotCallback, 86 | CustomProgressBar, 87 | ) 88 | from threestudio.utils.config import ExperimentConfig, load_config 89 | from threestudio.utils.typing import Optional 90 | 91 | logger = logging.getLogger("pytorch_lightning") 92 | if args.verbose: 93 | logger.setLevel(logging.DEBUG) 94 | 95 | for handler in logger.handlers: 96 | if handler.stream == sys.stderr: # type: ignore 97 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) 98 | handler.addFilter(ColoredFilter()) 99 | 100 | # parse YAML config to OmegaConf 101 | cfg: ExperimentConfig 102 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) 103 | 104 | pl.seed_everything(cfg.seed) 105 | 106 | dm = threestudio.find(cfg.data_type)(cfg.data) 107 | system: BaseSystem = threestudio.find(cfg.system_type)( 108 | cfg.system, resumed=cfg.resume is not None 109 | ) 110 | system.set_save_dir(os.path.join(cfg.trial_dir, "save")) 111 | callbacks = [] 112 | if args.train: 113 | callbacks += [ 114 | ModelCheckpoint( 115 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint 116 | ), 117 | LearningRateMonitor(logging_interval="step"), 118 | CustomProgressBar(refresh_rate=1), 119 | CodeSnapshotCallback( 120 | os.path.join(cfg.trial_dir, "code"), use_version=False 121 | ), 122 | ConfigSnapshotCallback( 123 | args.config, 124 | cfg, 125 | os.path.join(cfg.trial_dir, "configs"), 126 | use_version=False, 127 | ), 128 | ] 129 | 130 | def write_to_text(file, lines): 131 | with open(file, "w") as f: 132 | for line in lines: 133 | f.write(line + "\n") 134 | 135 | loggers = [] 136 | if args.train: 137 | # make tensorboard logging dir to suppress warning 138 | rank_zero_only( 139 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) 140 | )() 141 | loggers += [ 142 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"), 143 | CSVLogger(cfg.trial_dir, name="csv_logs"), 144 | ] + system.get_loggers() 145 | rank_zero_only( 146 | lambda: write_to_text( 147 | os.path.join(cfg.trial_dir, "log.txt"), 148 | ["python " + " ".join(sys.argv), str(args)], 149 | ) 150 | )() 151 | trainer = Trainer( 152 | callbacks=callbacks, logger=loggers, inference_mode=False, **cfg.trainer 153 | ) 154 | 155 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): 156 | if ckpt_path is None: 157 | return 158 | ckpt = torch.load(ckpt_path, map_location="cpu") 159 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) 160 | 161 | if args.train: 162 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) 163 | trainer.test(system, datamodule=dm) 164 | elif args.validate: 165 | # manually set epoch and global_step as they cannot be automatically resumed 166 | set_system_status(system, cfg.resume) 167 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) 168 | elif args.test: 169 | # manually set epoch and global_step as they cannot be automatically resumed 170 | set_system_status(system, cfg.resume) 171 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) 172 | elif args.export: 173 | set_system_status(system, cfg.resume) 174 | trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume) 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning==2.2.0 2 | omegaconf==2.3.0 3 | jaxtyping 4 | typeguard 5 | git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 
6 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 7 | diffusers==0.21.2 8 | transformers==4.25.1 9 | accelerate 10 | opencv-python 11 | tensorboard 12 | matplotlib 13 | imageio 14 | imageio[ffmpeg] 15 | trimesh 16 | git+https://github.com/NVlabs/nvdiffrast.git 17 | libigl 18 | xatlas 19 | trimesh[easy] 20 | networkx 21 | pysdf 22 | 23 | # deepfloyd 24 | xformers==0.0.21 25 | bitsandbytes 26 | sentencepiece 27 | safetensors 28 | huggingface_hub 29 | libigl 30 | xatlas 31 | trimesh 32 | networkx 33 | PyMCubes 34 | wandb 35 | 36 | # for zero123 37 | einops 38 | kornia 39 | taming-transformers-rom1504 40 | git+https://github.com/openai/CLIP.git 41 | open-clip-torch==2.24.0 42 | rembg 43 | av==11.0.0 -------------------------------------------------------------------------------- /threestudio/__init__.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | __modules__[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def find(name): 13 | return __modules__[name] 14 | 15 | 16 | ### grammar sugar for logging utilities ### 17 | import logging 18 | 19 | logger = logging.getLogger("pytorch_lightning") 20 | 21 | from pytorch_lightning.utilities.rank_zero import ( 22 | rank_zero_debug, 23 | rank_zero_info, 24 | rank_zero_only, 25 | ) 26 | 27 | debug = rank_zero_debug 28 | info = rank_zero_info 29 | 30 | 31 | @rank_zero_only 32 | def warn(*args, **kwargs): 33 | logger.warn(*args, **kwargs) 34 | 35 | 36 | from . import data, models, systems 37 | -------------------------------------------------------------------------------- /threestudio/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import co3d, image, uncond, multiview, random_multiview, single_multiview_combined 2 | -------------------------------------------------------------------------------- /threestudio/data/image.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from dataclasses import dataclass, field 4 | 5 | import cv2 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader, Dataset, IterableDataset 11 | 12 | from threestudio import register 13 | from threestudio.data.uncond import ( 14 | RandomCameraDataModuleConfig, 15 | RandomCameraDataset, 16 | RandomCameraIterableDataset, 17 | ) 18 | from threestudio.utils.config import parse_structured 19 | from threestudio.utils.misc import get_rank 20 | from threestudio.utils.ops import ( 21 | get_mvp_matrix, 22 | get_projection_matrix, 23 | get_ray_directions, 24 | get_rays, 25 | ) 26 | from threestudio.utils.typing import * 27 | 28 | 29 | @dataclass 30 | class SingleImageDataModuleConfig: 31 | height: int = 96 32 | width: int = 96 33 | default_elevation_deg: float = 0.0 34 | default_azimuth_deg: float = -180.0 35 | default_camera_distance: float = 1.2 36 | default_fovy_deg: float = 60.0 37 | image_path: str = "" 38 | use_random_camera: bool = True 39 | random_camera: dict = field(default_factory=dict) 40 | rays_noise_scale: float = 2e-3 41 | batch_size: int = 1 42 | 43 | 44 | class SingleImageDataBase: 45 | def setup(self, cfg, split): 46 | self.split = split 47 | self.rank = get_rank() 48 | self.cfg: SingleImageDataModuleConfig = cfg 49 | 50 | if self.cfg.use_random_camera: 51 | random_camera_cfg = parse_structured( 52 | RandomCameraDataModuleConfig, self.cfg.get("random_camera", {}) 53 | ) 54 | if split == "train": 55 | self.random_pose_generator = RandomCameraIterableDataset( 56 | random_camera_cfg 57 | ) 58 | else: 59 | self.random_pose_generator = RandomCameraDataset( 60 | random_camera_cfg, split 61 | ) 62 | 63 | # load image 64 | assert os.path.exists(self.cfg.image_path) 65 | rgba = cv2.cvtColor( 66 | cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA 67 | ) 68 | rgba = ( 69 | cv2.resize( 70 | rgba, (self.cfg.width, self.cfg.height), interpolation=cv2.INTER_AREA 71 | ).astype(np.float32) 72 | / 255.0 73 | ) 74 | rgb = rgba[..., :3] * rgba[..., 3:] + (1 - rgba[..., 3:]) 75 | self.rgb: Float[Tensor, "1 H W 3"] = ( 76 | torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank) 77 | ) 78 | self.mask: Float[Tensor, "1 H W 1"] = ( 79 | torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank) 80 | ) 81 | print( 82 | f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}" 83 | ) 84 | 85 | # load depth 86 | depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png") 87 | assert os.path.exists(depth_path) 88 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) 89 | depth = cv2.resize( 90 | depth, (self.cfg.width, self.cfg.height), interpolation=cv2.INTER_AREA 91 | ) 92 | self.depth: Float[Tensor, "1 H W 1"] = ( 93 | torch.from_numpy(depth.astype(np.float32) / 255.0) 94 | .unsqueeze(0) 95 | .to(self.rank) 96 | ) 97 | print( 98 | f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}" 99 | ) 100 | 101 | elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg]) 102 | azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg]) 103 | camera_distance = 
torch.FloatTensor([self.cfg.default_camera_distance]) 104 | 105 | elevation = elevation_deg * math.pi / 180 106 | azimuth = azimuth_deg * math.pi / 180 107 | camera_position: Float[Tensor, "1 3"] = torch.stack( 108 | [ 109 | camera_distance * torch.cos(elevation) * torch.cos(azimuth), 110 | camera_distance * torch.cos(elevation) * torch.sin(azimuth), 111 | camera_distance * torch.sin(elevation), 112 | ], 113 | dim=-1, 114 | ) 115 | 116 | center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position) 117 | up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None] 118 | 119 | light_position: Float[Tensor, "1 3"] = camera_position 120 | lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1) 121 | right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1) 122 | up = F.normalize(torch.cross(right, lookat), dim=-1) 123 | c2w: Float[Tensor, "1 3 4"] = torch.cat( 124 | [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]], 125 | dim=-1, 126 | ) 127 | 128 | # get directions by dividing directions_unit_focal by focal length 129 | fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg])) 130 | focal_length = 0.5 * self.cfg.height / torch.tan(0.5 * fovy) 131 | directions_unit_focal = get_ray_directions( 132 | H=self.cfg.height, W=self.cfg.width, focal=1.0 133 | ) 134 | directions: Float[Tensor, "1 H W 3"] = directions_unit_focal[None] 135 | directions[:, :, :, :2] = directions[:, :, :, :2] / focal_length 136 | 137 | rays_o, rays_d = get_rays( 138 | directions, c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale 139 | ) 140 | 141 | proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix( 142 | fovy, self.cfg.width / self.cfg.height, 0.1, 100.0 143 | ) # FIXME: hard-coded near and far 144 | mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(c2w, proj_mtx) 145 | 146 | self.rays_o, self.rays_d = rays_o, rays_d 147 | self.mvp_mtx = mvp_mtx 148 | self.camera_position = camera_position 149 | self.light_position = light_position 150 | self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg 151 | self.camera_distance = camera_distance 152 | 153 | def get_all_images(self): 154 | return self.rgb 155 | 156 | 157 | class SingleImageIterableDataset(IterableDataset, SingleImageDataBase): 158 | def __init__(self, cfg: Any, split: str) -> None: 159 | super().__init__() 160 | self.setup(cfg, split) 161 | 162 | def collate(self, batch) -> Dict[str, Any]: 163 | batch = { 164 | "rays_o": self.rays_o, 165 | "rays_d": self.rays_d, 166 | "mvp_mtx": self.mvp_mtx, 167 | "camera_positions": self.camera_position, 168 | "light_positions": self.light_position, 169 | "elevation": self.elevation_deg, 170 | "azimuth": self.azimuth_deg, 171 | "camera_distances": self.camera_distance, 172 | "rgb": self.rgb, 173 | "depth": self.depth, 174 | "mask": self.mask, 175 | } 176 | if self.cfg.use_random_camera: 177 | batch["random_camera"] = self.random_pose_generator.collate(None) 178 | 179 | return batch 180 | 181 | def __iter__(self): 182 | while True: 183 | yield {} 184 | 185 | 186 | class SingleImageDataset(Dataset, SingleImageDataBase): 187 | def __init__(self, cfg: Any, split: str) -> None: 188 | super().__init__() 189 | self.setup(cfg, split) 190 | 191 | def __len__(self): 192 | return len(self.random_pose_generator) 193 | 194 | def __getitem__(self, index): 195 | return self.random_pose_generator[index] 196 | # if index == 0: 197 | # return { 198 | # 'rays_o': self.rays_o[0], 199 | # 'rays_d': self.rays_d[0], 200 | # 'mvp_mtx': 
self.mvp_mtx[0], 201 | # 'camera_positions': self.camera_position[0], 202 | # 'light_positions': self.light_position[0], 203 | # 'elevation': self.elevation_deg[0], 204 | # 'azimuth': self.azimuth_deg[0], 205 | # 'camera_distances': self.camera_distance[0], 206 | # 'rgb': self.rgb[0], 207 | # 'depth': self.depth[0], 208 | # 'mask': self.mask[0] 209 | # } 210 | # else: 211 | # return self.random_pose_generator[index - 1] 212 | 213 | 214 | @register("single-image-datamodule") 215 | class SingleImageDataModule(pl.LightningDataModule): 216 | cfg: SingleImageDataModuleConfig 217 | 218 | def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: 219 | super().__init__() 220 | self.cfg = parse_structured(SingleImageDataModuleConfig, cfg) 221 | 222 | def setup(self, stage=None) -> None: 223 | if stage in [None, "fit"]: 224 | self.train_dataset = SingleImageIterableDataset(self.cfg, "train") 225 | if stage in [None, "fit", "validate"]: 226 | self.val_dataset = SingleImageDataset(self.cfg, "val") 227 | if stage in [None, "test", "predict"]: 228 | self.test_dataset = SingleImageDataset(self.cfg, "test") 229 | 230 | def prepare_data(self): 231 | pass 232 | 233 | def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: 234 | return DataLoader( 235 | dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn 236 | ) 237 | 238 | def train_dataloader(self) -> DataLoader: 239 | return self.general_loader( 240 | self.train_dataset, 241 | batch_size=self.cfg.batch_size, 242 | collate_fn=self.train_dataset.collate, 243 | ) 244 | 245 | def val_dataloader(self) -> DataLoader: 246 | return self.general_loader(self.val_dataset, batch_size=1) 247 | 248 | def test_dataloader(self) -> DataLoader: 249 | return self.general_loader(self.test_dataset, batch_size=1) 250 | 251 | def predict_dataloader(self) -> DataLoader: 252 | return self.general_loader(self.test_dataset, batch_size=1) 253 | -------------------------------------------------------------------------------- /threestudio/lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /threestudio/lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | 18 | def __init__(self, net_type: str = "alex", version: str = "0.1"): 19 | 20 | assert version in ["0.1"], "v0.1 is only supported now" 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type) 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list) 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | self.eval() 31 | 32 | def forward(self, x: torch.Tensor, y: torch.Tensor): 33 | feat_x, feat_y = self.net(x), self.net(y) 34 | 35 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 36 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 37 | 38 | return torch.sum(torch.cat(res, 0), 0, True) 39 | -------------------------------------------------------------------------------- /threestudio/lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 
512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /threestudio/lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /threestudio/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | background, 3 | exporters, 4 | geometry, 5 | guidance, 6 | materials, 7 | prompt_processors, 8 | renderers, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/background/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | neural_environment_map_background, 4 | solid_color_background, 5 | textured_background, 6 | ) 7 | -------------------------------------------------------------------------------- /threestudio/models/background/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseBackground(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | 20 | def configure(self): 21 | pass 22 | 23 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 3"]: 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /threestudio/models/background/neural_environment_map_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-environment-map-background") 16 | class NeuralEnvironmentMapBackground(BaseBackground): 17 | @dataclass 18 | class Config(BaseBackground.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 
| default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "VanillaMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | random_aug: bool = False 33 | random_aug_prob: float = 0.5 34 | eval_color: Optional[Tuple[float, float, float]] = None 35 | share_aug_bg: bool = False 36 | 37 | cfg: Config 38 | 39 | def configure(self) -> None: 40 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 41 | self.network = get_mlp( 42 | self.encoding.n_output_dims, 43 | self.cfg.n_output_dims, 44 | self.cfg.mlp_network_config, 45 | ) 46 | self.random_color = torch.rand(self.cfg.n_output_dims) 47 | self.use_aug = False 48 | 49 | def config_aug(self): 50 | if ( self.cfg.random_aug 51 | and random.random() < self.cfg.random_aug_prob 52 | ): 53 | self.use_aug = True 54 | else: 55 | self.use_aug = False 56 | 57 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 3"]: 58 | if not self.training and self.cfg.eval_color is not None: 59 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 60 | dirs 61 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs) 62 | # viewdirs must be normalized before passing to this function 63 | squeezed_dim = dirs.view(-1, 3).shape[0] 64 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 65 | dirs_embd = self.encoding(dirs.view(-1, 3)) 66 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) 67 | color = get_activation(self.cfg.color_activation)(color) 68 | 69 | if (self.use_aug and self.training): 70 | # use random background color with probability random_aug_prob 71 | # n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0] 72 | color = color * 0 + ( # prevent checking for unused parameters in DDP 73 | self.random_color 74 | .to(dirs)[None, :] 75 | .expand(squeezed_dim, -1) 76 | .view(*dirs.shape[:-1], -1) 77 | ) 78 | return color 79 | -------------------------------------------------------------------------------- /threestudio/models/background/solid_color_background.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.utils.typing import * 10 | 11 | 12 | @threestudio.register("solid-color-background") 13 | class SolidColorBackground(BaseBackground): 14 | @dataclass 15 | class Config(BaseBackground.Config): 16 | n_output_dims: int = 3 17 | color: Tuple = (1.0, 1.0, 1.0) 18 | learned: bool = False 19 | 20 | cfg: Config 21 | 22 | def configure(self) -> None: 23 | self.env_color: Float[Tensor, "Nc"] 24 | if self.cfg.learned: 25 | self.env_color = nn.Parameter( 26 | torch.as_tensor(self.cfg.color, dtype=torch.float32) 27 | ) 28 | else: 29 | self.register_buffer( 30 | "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32) 31 | ) 32 | 33 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: 34 | if not self.training: 35 | return ( 36 | torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) 37 | ) 38 | return ( 39 | torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) 40 | * self.env_color 41 | ) 42 | -------------------------------------------------------------------------------- /threestudio/models/background/textured_background.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.utils.ops import get_activation 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("textured-background") 14 | class TexturedBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | height: int = 64 19 | width: int = 64 20 | color_activation: str = "sigmoid" 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | self.texture = nn.Parameter( 26 | torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width)) 27 | ) 28 | 29 | def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]: 30 | x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2] 31 | xy = (x**2 + y**2) ** 0.5 32 | u = torch.atan2(xy, z) / torch.pi 33 | v = torch.atan2(y, x) / (torch.pi * 2) + 0.5 34 | uv = torch.stack([u, v], -1) 35 | return uv 36 | 37 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: 38 | dirs_shape = dirs.shape[:-1] 39 | uv = self.spherical_xyz_to_uv(dirs) 40 | uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample 41 | uv = uv.reshape(1, -1, 1, 2) 42 | color = ( 43 | F.grid_sample( 44 | self.texture, 45 | uv, 46 | mode="bilinear", 47 | padding_mode="reflection", 48 | align_corners=False, 49 | ) 50 | .reshape(self.cfg.n_output_dims, -1) 51 | .T.reshape(*dirs_shape, self.cfg.n_output_dims) 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /threestudio/models/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base, mesh_exporter 2 | -------------------------------------------------------------------------------- /threestudio/models/exporters/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import threestudio 4 | from threestudio.models.background.base import BaseBackground 5 | from threestudio.models.geometry.base import BaseImplicitGeometry 6 | from threestudio.models.materials.base import BaseMaterial 7 | from threestudio.utils.base import BaseObject 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @dataclass 12 | class ExporterOutput: 13 | save_name: str 14 | save_type: str 15 | params: Dict[str, Any] 16 | 17 | 18 | class Exporter(BaseObject): 19 | @dataclass 20 | class Config(BaseObject.Config): 21 | save_video: bool = False 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | @dataclass 32 | class SubModules: 33 | geometry: BaseImplicitGeometry 34 | material: BaseMaterial 35 | background: BaseBackground 36 | 37 | self.sub_modules = SubModules(geometry, material, background) 38 | 39 | @property 40 | def geometry(self) -> BaseImplicitGeometry: 41 | return self.sub_modules.geometry 42 | 43 | @property 44 | def material(self) -> BaseMaterial: 45 | return self.sub_modules.material 46 | 47 | @property 48 | def background(self) -> BaseBackground: 49 | return self.sub_modules.background 50 | 51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 52 | raise NotImplementedError 53 | -------------------------------------------------------------------------------- /threestudio/models/exporters/mesh_exporter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.exporters.base import Exporter, ExporterOutput 10 | from threestudio.models.geometry.base import BaseImplicitGeometry 11 | from threestudio.models.materials.base import BaseMaterial 12 | from threestudio.models.mesh import Mesh 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("mesh-exporter") 18 | class MeshExporter(Exporter): 19 | @dataclass 20 | class Config(Exporter.Config): 21 | fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx 22 | save_name: str = "model" 23 | save_normal: bool = False 24 | save_uv: bool = True 25 | save_texture: bool = True 26 | texture_size: int = 1024 27 | texture_format: str = "jpg" 28 | xatlas_chart_options: dict = field(default_factory=dict) 29 | xatlas_pack_options: dict = field(default_factory=dict) 30 | context_type: str = "gl" 31 | 32 | cfg: Config 33 | 34 | def configure( 35 | self, 36 | geometry: BaseImplicitGeometry, 37 | material: BaseMaterial, 38 | background: BaseBackground, 39 | ) -> None: 40 | super().configure(geometry, material, background) 41 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device) 42 | 43 | def __call__(self) -> List[ExporterOutput]: 44 | mesh: Mesh = self.geometry.isosurface() 45 | 46 | if self.cfg.fmt == "obj-mtl": 47 | return self.export_obj_with_mtl(mesh) 48 | elif self.cfg.fmt == "obj": 49 | return self.export_obj(mesh) 50 | else: 51 | raise ValueError(f"Unsupported mesh 
export format: {self.cfg.fmt}") 52 | 53 | def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]: 54 | params = { 55 | "mesh": mesh, 56 | "save_mat": True, 57 | "save_normal": self.cfg.save_normal, 58 | "save_uv": self.cfg.save_uv, 59 | "save_vertex_color": False, 60 | "map_Kd": None, 61 | "map_Ks": None, 62 | "map_Bump": None, 63 | "map_format": self.cfg.texture_format, 64 | } 65 | 66 | if self.cfg.save_uv: 67 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 68 | 69 | if self.cfg.save_texture: 70 | threestudio.info("Exporting textures ...") 71 | assert self.cfg.save_uv, "save_uv must be True when save_texture is True" 72 | # clip space transform 73 | uv_clip = mesh.v_tex * 2.0 - 1.0 74 | # pad to four component coordinate 75 | uv_clip4 = torch.cat( 76 | ( 77 | uv_clip, 78 | torch.zeros_like(uv_clip[..., 0:1]), 79 | torch.ones_like(uv_clip[..., 0:1]), 80 | ), 81 | dim=-1, 82 | ) 83 | # rasterize 84 | rast, _ = self.ctx.rasterize_one( 85 | uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size) 86 | ) 87 | 88 | hole_mask = ~(rast[:, :, 3] > 0) 89 | 90 | def uv_padding(image): 91 | uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2) 92 | inpaint_image = ( 93 | cv2.inpaint( 94 | (image.detach().cpu().numpy() * 255).astype(np.uint8), 95 | (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), 96 | uv_padding_size, 97 | cv2.INPAINT_TELEA, 98 | ) 99 | / 255.0 100 | ) 101 | return torch.from_numpy(inpaint_image).to(image) 102 | 103 | # Interpolate world space position 104 | gb_pos, _ = self.ctx.interpolate_one( 105 | mesh.v_pos, rast[None, ...], mesh.t_pos_idx 106 | ) 107 | gb_pos = gb_pos[0] 108 | 109 | # Sample out textures from MLP 110 | geo_out = self.geometry.export(points=gb_pos) 111 | mat_out = self.material.export(points=gb_pos, **geo_out) 112 | 113 | threestudio.info( 114 | "Perform UV padding on texture maps to avoid seams, may take a while ..." 
115 | ) 116 | if "normal" in geo_out: 117 | params["map_Bump"] = uv_padding(geo_out["normal"]) 118 | 119 | if "albedo" in mat_out: 120 | params["map_Kd"] = uv_padding(mat_out["albedo"]) 121 | else: 122 | threestudio.warn( 123 | "save_texture is True but no albedo texture found, using default white texture" 124 | ) 125 | # TODO: map_Ks 126 | return [ 127 | ExporterOutput( 128 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 129 | ) 130 | ] 131 | 132 | def export_obj(self, mesh: Mesh) -> List[ExporterOutput]: 133 | params = { 134 | "mesh": mesh, 135 | "save_mat": False, 136 | "save_normal": self.cfg.save_normal, 137 | "save_uv": self.cfg.save_uv, 138 | "save_vertex_color": False, 139 | "map_Kd": None, 140 | "map_Ks": None, 141 | "map_Bump": None, 142 | "map_format": self.cfg.texture_format, 143 | } 144 | 145 | if self.cfg.save_uv: 146 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 147 | 148 | if self.cfg.save_texture: 149 | threestudio.info("Exporting textures ...") 150 | geo_out = self.geometry.export(points=mesh.v_pos) 151 | mat_out = self.material.export(points=mesh.v_pos, **geo_out) 152 | 153 | if "albedo" in mat_out: 154 | mesh.set_vertex_color(mat_out["albedo"]) 155 | params["save_vertex_color"] = True 156 | else: 157 | threestudio.warn( 158 | "save_texture is True but no albedo texture found, not saving vertex color" 159 | ) 160 | 161 | return [ 162 | ExporterOutput( 163 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 164 | ) 165 | ] 166 | -------------------------------------------------------------------------------- /threestudio/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, implicit_sdf, implicit_volume, tetrahedra_sdf_grid, volume_grid, implicit_volume_kplane_e 2 | -------------------------------------------------------------------------------- /threestudio/models/geometry/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.isosurface import ( 10 | IsosurfaceHelper, 11 | MarchingCubeCPUHelper, 12 | MarchingTetrahedraHelper, 13 | ) 14 | from threestudio.models.mesh import Mesh 15 | from threestudio.utils.base import BaseModule 16 | from threestudio.utils.ops import chunk_batch, scale_tensor 17 | from threestudio.utils.typing import * 18 | 19 | 20 | def contract_to_unisphere( 21 | x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False 22 | ) -> Float[Tensor, "... 
3"]: 23 | if unbounded: 24 | x = scale_tensor(x, bbox, (0, 1)) 25 | x = x * 2 - 1 # aabb is at [-1, 1] 26 | mag = x.norm(dim=-1, keepdim=True) 27 | mask = mag.squeeze(-1) > 1 28 | x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) 29 | x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] 30 | else: 31 | x = scale_tensor(x, bbox, (0, 1)) 32 | return x 33 | 34 | 35 | class BaseGeometry(BaseModule): 36 | @dataclass 37 | class Config(BaseModule.Config): 38 | pass 39 | 40 | cfg: Config 41 | 42 | @staticmethod 43 | def create_from( 44 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs 45 | ) -> "BaseGeometry": 46 | raise TypeError( 47 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" 48 | ) 49 | 50 | def export(self, *args, **kwargs) -> Dict[str, Any]: 51 | return {} 52 | 53 | 54 | class BaseImplicitGeometry(BaseGeometry): 55 | @dataclass 56 | class Config(BaseGeometry.Config): 57 | radius: float = 1.0 58 | isosurface: bool = True 59 | isosurface_method: str = "mt" 60 | isosurface_resolution: int = 128 61 | isosurface_threshold: Union[float, str] = 0.0 62 | isosurface_chunk: int = 0 63 | isosurface_coarse_to_fine: bool = True 64 | isosurface_deformable_grid: bool = False 65 | isosurface_remove_outliers: bool = True 66 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 67 | 68 | cfg: Config 69 | 70 | def configure(self) -> None: 71 | self.bbox: Float[Tensor, "2 3"] 72 | self.register_buffer( 73 | "bbox", 74 | torch.as_tensor( 75 | [ 76 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 77 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 78 | ], 79 | dtype=torch.float32, 80 | ), 81 | ) 82 | self.isosurface_helper: Optional[IsosurfaceHelper] = None 83 | self.unbounded: bool = False 84 | 85 | def _initilize_isosurface_helper(self): 86 | if self.cfg.isosurface and self.isosurface_helper is None: 87 | if self.cfg.isosurface_method == "mc-cpu": 88 | self.isosurface_helper = MarchingCubeCPUHelper( 89 | self.cfg.isosurface_resolution 90 | ).to(self.device) 91 | elif self.cfg.isosurface_method == "mt": 92 | self.isosurface_helper = MarchingTetrahedraHelper( 93 | self.cfg.isosurface_resolution, 94 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", 95 | ).to(self.device) 96 | else: 97 | raise AttributeError( 98 | "Unknown isosurface method {self.cfg.isosurface_method}" 99 | ) 100 | 101 | def forward( 102 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 103 | ) -> Dict[str, Float[Tensor, "..."]]: 104 | raise NotImplementedError 105 | 106 | def forward_field( 107 | self, points: Float[Tensor, "*N Di"] 108 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 109 | # return the value of the implicit field, could be density / signed distance 110 | # also return a deformation field if the grid vertices can be optimized 111 | raise NotImplementedError 112 | 113 | def forward_level( 114 | self, field: Float[Tensor, "*N 1"], threshold: float 115 | ) -> Float[Tensor, "*N 1"]: 116 | # return the value of the implicit field, where the zero level set represents the surface 117 | raise NotImplementedError 118 | 119 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: 120 | def batch_func(x): 121 | # scale to bbox as the input vertices are in [0, 1] 122 | field, deformation = self.forward_field( 123 | scale_tensor( 124 | x.to(bbox.device), self.isosurface_helper.points_range, bbox 125 | ), 126 | ) 127 | field = field.to( 128 | x.device 129 | ) # move to the same device as 
the input (could be CPU) 130 | if deformation is not None: 131 | deformation = deformation.to(x.device) 132 | return field, deformation 133 | 134 | assert self.isosurface_helper is not None 135 | 136 | field, deformation = chunk_batch( 137 | batch_func, 138 | self.cfg.isosurface_chunk, 139 | self.isosurface_helper.grid_vertices, 140 | ) 141 | 142 | threshold: float 143 | 144 | if isinstance(self.cfg.isosurface_threshold, float): 145 | threshold = self.cfg.isosurface_threshold 146 | elif self.cfg.isosurface_threshold == "auto": 147 | eps = 1.0e-5 148 | threshold = field[field > eps].mean().item() 149 | threestudio.info( 150 | f"Automatically determined isosurface threshold: {threshold}" 151 | ) 152 | else: 153 | raise TypeError( 154 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" 155 | ) 156 | 157 | level = self.forward_level(field, threshold) 158 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation) 159 | mesh.v_pos = scale_tensor( 160 | mesh.v_pos, self.isosurface_helper.points_range, bbox 161 | ) # scale to bbox as the grid vertices are in [0, 1] 162 | mesh.add_extra("bbox", bbox) 163 | 164 | if self.cfg.isosurface_remove_outliers: 165 | # remove outliers components with small number of faces 166 | # only enabled when the mesh is not differentiable 167 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) 168 | 169 | return mesh 170 | 171 | def isosurface(self) -> Mesh: 172 | if not self.cfg.isosurface: 173 | raise NotImplementedError( 174 | "Isosurface is not enabled in the current configuration" 175 | ) 176 | self._initilize_isosurface_helper() 177 | if self.cfg.isosurface_coarse_to_fine: 178 | threestudio.debug("First run isosurface to get a tight bounding box ...") 179 | with torch.no_grad(): 180 | mesh_coarse = self._isosurface(self.bbox) 181 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) 182 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) 183 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) 184 | threestudio.debug("Run isosurface again with the tight bounding box ...") 185 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) 186 | else: 187 | mesh = self._isosurface(self.bbox) 188 | return mesh 189 | 190 | 191 | class BaseExplicitGeometry(BaseGeometry): 192 | @dataclass 193 | class Config(BaseGeometry.Config): 194 | radius: float = 1.0 195 | 196 | cfg: Config 197 | 198 | def configure(self) -> None: 199 | self.bbox: Float[Tensor, "2 3"] 200 | self.register_buffer( 201 | "bbox", 202 | torch.as_tensor( 203 | [ 204 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 205 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 206 | ], 207 | dtype=torch.float32, 208 | ), 209 | ) 210 | -------------------------------------------------------------------------------- /threestudio/models/geometry/volume_grid.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere 10 | from threestudio.utils.ops import get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("volume-grid") 15 | class VolumeGrid(BaseImplicitGeometry): 16 | @dataclass 17 | class Config(BaseImplicitGeometry.Config): 18 | grid_size: Tuple[int, int, int] = 
field(default_factory=lambda: (100, 100, 100)) 19 | n_feature_dims: int = 3 20 | density_activation: Optional[str] = "softplus" 21 | density_bias: Union[float, str] = "blob" 22 | density_blob_scale: float = 5.0 23 | density_blob_std: float = 0.5 24 | normal_type: Optional[ 25 | str 26 | ] = "finite_difference" # in ['pred', 'finite_difference'] 27 | 28 | # automatically determine the threshold 29 | isosurface_threshold: Union[float, str] = "auto" 30 | 31 | cfg: Config 32 | 33 | def configure(self) -> None: 34 | super().configure() 35 | self.grid_size = self.cfg.grid_size 36 | 37 | self.grid = nn.Parameter( 38 | torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size) 39 | ) 40 | if self.cfg.density_bias == "blob": 41 | self.register_buffer("density_scale", torch.tensor(0.0)) 42 | else: 43 | self.density_scale = nn.Parameter(torch.tensor(0.0)) 44 | 45 | if self.cfg.normal_type == "pred": 46 | self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size)) 47 | 48 | def get_density_bias(self, points: Float[Tensor, "*N Di"]): 49 | if self.cfg.density_bias == "blob": 50 | density_bias: Float[Tensor, "*N 1"] = ( 51 | self.cfg.density_blob_scale 52 | * ( 53 | 1 54 | - torch.sqrt((points.detach() ** 2).sum(dim=-1)) 55 | / self.cfg.density_blob_std 56 | )[..., None] 57 | ) 58 | return density_bias 59 | elif isinstance(self.cfg.density_bias, float): 60 | return self.cfg.density_bias 61 | else: 62 | raise AttributeError(f"Unknown density bias {self.cfg.density_bias}") 63 | 64 | def get_trilinear_feature( 65 | self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"] 66 | ) -> Float[Tensor, "*N Df"]: 67 | points_shape = points.shape[:-1] 68 | df = grid.shape[1] 69 | di = points.shape[-1] 70 | out = F.grid_sample( 71 | grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear" 72 | ) 73 | out = out.reshape(df, -1).T.reshape(*points_shape, df) 74 | return out 75 | 76 | def forward( 77 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 78 | ) -> Dict[str, Float[Tensor, "..."]]: 79 | points_unscaled = points # points in the original scale 80 | points = contract_to_unisphere( 81 | points, self.bbox, self.unbounded 82 | ) # points normalized to (0, 1) 83 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 84 | 85 | out = self.get_trilinear_feature(points, self.grid) 86 | density, features = out[..., 0:1], out[..., 1:] 87 | density = density * torch.exp(self.density_scale) # exp scaling in DreamFusion 88 | 89 | density = get_activation(self.cfg.density_activation)( 90 | density + self.get_density_bias(points_unscaled) 91 | ) 92 | 93 | output = { 94 | "density": density, 95 | "features": features, 96 | } 97 | 98 | if output_normal: 99 | if self.cfg.normal_type == "finite_difference": 100 | eps = 1.0e-3 101 | offsets: Float[Tensor, "6 3"] = torch.as_tensor( 102 | [ 103 | [eps, 0.0, 0.0], 104 | [-eps, 0.0, 0.0], 105 | [0.0, eps, 0.0], 106 | [0.0, -eps, 0.0], 107 | [0.0, 0.0, eps], 108 | [0.0, 0.0, -eps], 109 | ] 110 | ).to(points_unscaled) 111 | points_offset: Float[Tensor, "... 6 3"] = ( 112 | points_unscaled[..., None, :] + offsets 113 | ).clamp(-self.cfg.radius, self.cfg.radius) 114 | density_offset: Float[Tensor, "... 
6 1"] = self.forward_density( 115 | points_offset 116 | ) 117 | normal = ( 118 | -0.5 119 | * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0]) 120 | / eps 121 | ) 122 | normal = F.normalize(normal, dim=-1) 123 | elif self.cfg.normal_type == "pred": 124 | normal = self.get_trilinear_feature(points, self.normal_grid) 125 | normal = F.normalize(normal, dim=-1) 126 | else: 127 | raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") 128 | output.update({"normal": normal, "shading_normal": normal}) 129 | return output 130 | 131 | def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: 132 | points_unscaled = points 133 | points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) 134 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 135 | 136 | out = self.get_trilinear_feature(points, self.grid) 137 | density = out[..., 0:1] 138 | density = density * torch.exp(self.density_scale) 139 | 140 | density = get_activation(self.cfg.density_activation)( 141 | density + self.get_density_bias(points_unscaled) 142 | ) 143 | return density 144 | 145 | def forward_field( 146 | self, points: Float[Tensor, "*N Di"] 147 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 148 | if self.cfg.isosurface_deformable_grid: 149 | threestudio.warn( 150 | f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring." 151 | ) 152 | density = self.forward_density(points) 153 | return density, None 154 | 155 | def forward_level( 156 | self, field: Float[Tensor, "*N 1"], threshold: float 157 | ) -> Float[Tensor, "*N 1"]: 158 | return -(field - threshold) 159 | 160 | def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: 161 | out: Dict[str, Any] = {} 162 | if self.cfg.n_feature_dims == 0: 163 | return out 164 | points_unscaled = points 165 | points = contract_to_unisphere(points, self.bbox, self.unbounded) 166 | points = points * 2 - 1 # convert to [-1, 1] for grid sample 167 | features = self.get_trilinear_feature(points, self.grid)[..., 1:] 168 | out.update( 169 | { 170 | "features": features, 171 | } 172 | ) 173 | return out 174 | -------------------------------------------------------------------------------- /threestudio/models/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | deep_floyd_guidance, 3 | stable_diffusion_guidance, 4 | stable_diffusion_vsd_guidance, 5 | zero123_guidance, 6 | zeroscope_guidance, 7 | multiview_video_diffusion_guidance, 8 | ) 9 | -------------------------------------------------------------------------------- /threestudio/models/hexplane.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging as log 3 | from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_normalized_directions(directions): 11 | """SH encoding must be in the range [0, 1] 12 | 13 | Args: 14 | directions: batch of directions 15 | """ 16 | return (directions + 1.0) / 2.0 17 | 18 | 19 | def normalize_aabb(pts, aabb): 20 | return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 21 | 22 | 23 | def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: 24 | grid_dim = coords.shape[-1] 25 | 26 | if grid.dim() == grid_dim + 1: 27 | # no batch dimension present, need to add it 28 | grid = grid.unsqueeze(0) 29 | if coords.dim() == 2: 30 | coords = coords.unsqueeze(0) 31 | 32 | if grid_dim == 2 or grid_dim == 3: 33 | grid_sampler = F.grid_sample 34 | else: 35 | raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " 36 | f"implemented for 2 and 3D data.") 37 | 38 | coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) 39 | B, feature_dim = grid.shape[:2] 40 | n = coords.shape[-2] 41 | interp = grid_sampler( 42 | grid, # [B, feature_dim, reso, ...] 43 | coords, # [B, 1, ..., n, grid_dim] 44 | align_corners=align_corners, 45 | mode='bilinear', padding_mode='border') 46 | interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] 47 | interp = interp.squeeze() # [B?, n, feature_dim?] 48 | return interp 49 | 50 | def init_grid_param( 51 | grid_nd: int, 52 | in_dim: int, 53 | out_dim: int, 54 | reso: Sequence[int], 55 | a: float = 0.1, 56 | b: float = 0.5): 57 | assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" 58 | has_time_planes = in_dim == 4 59 | assert grid_nd <= in_dim 60 | coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) 61 | grid_coefs = nn.ParameterList() 62 | for ci, coo_comb in enumerate(coo_combs): 63 | if not 3 in coo_comb: continue 64 | new_grid_coef = nn.Parameter(torch.empty( 65 | [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] 66 | )) 67 | if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 68 | nn.init.ones_(new_grid_coef) 69 | else: 70 | nn.init.uniform_(new_grid_coef, a=a, b=b) 71 | grid_coefs.append(new_grid_coef) 72 | 73 | return grid_coefs 74 | 75 | 76 | def interpolate_ms_features(pts: torch.Tensor, 77 | ms_grids: Collection[Iterable[nn.Module]], 78 | grid_dimensions: int, 79 | concat_features: bool, 80 | num_levels: Optional[int], 81 | ) -> torch.Tensor: 82 | coo_combs = list(itertools.combinations( 83 | range(pts.shape[-1]), grid_dimensions) 84 | ) 85 | if num_levels is None: 86 | num_levels = len(ms_grids) 87 | multi_scale_interp = [] if concat_features else 0. 88 | grid: nn.ParameterList 89 | for scale_id, grid in enumerate(ms_grids[:num_levels]): 90 | interp_space = 1. 
91 | idx_grid = 0 92 | for ci, coo_comb in enumerate(coo_combs): 93 | if not 3 in coo_comb: continue 94 | # interpolate in plane 95 | feature_dim = grid[idx_grid].shape[1] # shape of grid[ci]: 1, out_dim, *reso 96 | interp_out_plane = ( 97 | grid_sample_wrapper(grid[idx_grid], pts[..., coo_comb]) 98 | .view(-1, feature_dim) 99 | ) 100 | # compute product over planes 101 | interp_space = interp_space * interp_out_plane 102 | idx_grid += 1 103 | 104 | # combine over scales 105 | if concat_features: 106 | multi_scale_interp.append(interp_space) 107 | else: 108 | multi_scale_interp = multi_scale_interp + interp_space 109 | 110 | if concat_features: 111 | multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) 112 | return multi_scale_interp 113 | 114 | 115 | class HexPlaneField(nn.Module): 116 | def __init__( 117 | self, 118 | bounds, 119 | planeconfig, 120 | multires 121 | ) -> None: 122 | super().__init__() 123 | aabb = torch.tensor([[bounds,bounds,bounds], 124 | [-bounds,-bounds,-bounds]]) 125 | self.aabb = nn.Parameter(aabb, requires_grad=False) 126 | self.grid_config = [planeconfig] 127 | self.multiscale_res_multipliers = multires 128 | self.concat_features = False 129 | 130 | # 1. Init planes 131 | self.grids = nn.ModuleList() 132 | self.feat_dim = 0 133 | for res in self.multiscale_res_multipliers: 134 | # initialize coordinate grid 135 | config = self.grid_config[0].copy() 136 | # Resolution fix: multi-res only on spatial planes 137 | config["resolution"] = [ 138 | r * res for r in config["resolution"][:3] 139 | ] + config["resolution"][3:] 140 | gp = init_grid_param( 141 | grid_nd=config["grid_dimensions"], 142 | in_dim=config["input_coordinate_dim"], 143 | out_dim=config["output_coordinate_dim"], 144 | reso=config["resolution"], 145 | ) 146 | # shape[1] is out-dim - Concatenate over feature len for each scale 147 | if self.concat_features: 148 | self.feat_dim += gp[-1].shape[1] 149 | else: 150 | self.feat_dim = gp[-1].shape[1] 151 | self.grids.append(gp) 152 | print('planes: ', len(gp)) 153 | # print(f"Initialized model grids: {self.grids}") 154 | print("feature_dim:", self.feat_dim) 155 | 156 | 157 | def set_aabb(self,xyz_max, xyz_min): 158 | aabb = torch.tensor([ 159 | xyz_max, 160 | xyz_min 161 | ]) 162 | self.aabb = nn.Parameter(aabb,requires_grad=True) 163 | print("Voxel Plane: set aabb=",self.aabb) 164 | 165 | def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): 166 | """Computes and returns the densities.""" 167 | pts = normalize_aabb(pts, self.aabb) 168 | pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] 169 | 170 | pts = pts.reshape(-1, pts.shape[-1]) 171 | features = interpolate_ms_features( 172 | pts, ms_grids=self.grids, # noqa 173 | grid_dimensions=self.grid_config[0]["grid_dimensions"], 174 | concat_features=self.concat_features, num_levels=None) 175 | if len(features) < 1: 176 | features = torch.zeros((0, 1)).to(features.device) 177 | 178 | 179 | return features 180 | 181 | def forward(self, 182 | pts: torch.Tensor): 183 | features = self.get_density(pts[:,:3], pts[:,3,None]) 184 | 185 | return features 186 | 187 | def _regularize(self): 188 | multi_res_grids = self.grids 189 | total = 0 190 | 191 | for grids in multi_res_grids: 192 | # if len(grids) == 3: 193 | # continue 194 | # else: 195 | # These are the spatiotemporal grids 196 | spatiotemporal_grids = [0, 1, 2] 197 | for grid_id in spatiotemporal_grids: 198 | total += torch.abs(1 - grids[grid_id]).mean() 199 | 200 | for grids in multi_res_grids: 201 | 
time_grids = [0, 1, 2] 202 | for grid_id in time_grids: 203 | total += compute_plane_smoothness(grids[grid_id]) 204 | 205 | return total 206 | 207 | def compute_plane_smoothness(t): 208 | batch_size, c, h, w = t.shape 209 | # Convolve with a second derivative filter, in the time dimension which is dimension 2 210 | first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] 211 | second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] 212 | # Take the L2 norm of the result 213 | return torch.square(second_difference).mean() -------------------------------------------------------------------------------- /threestudio/models/imagedream/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | */__pycache__/ 6 | 7 | # dataset-related, pre-trained models, 8 | vae_models/vqgan 9 | vae_models/*.gz 10 | vae_models/*.pt 11 | vae_models/*vqgan 12 | *.pt 13 | *.pth 14 | 15 | # log files 16 | log/*.log 17 | out* 18 | test_results 19 | err* 20 | 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | pip-wheel-metadata/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | db.sqlite3-journal 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # PyBuilder 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 111 | __pypackages__/ 112 | 113 | # Celery stuff 114 | celerybeat-schedule 115 | celerybeat.pid 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | *.zip 148 | *.pkl 149 | *.csv 150 | *.ckpt 151 | *.parquet 152 | 153 | *.whl 154 | *.th 155 | *.onnx -------------------------------------------------------------------------------- /threestudio/models/imagedream/LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ByteDance 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/0.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/1.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/10.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/11.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/12.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/13.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/14.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/15.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/16.png 
-------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/17.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/18.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/19.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/2.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/20.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/21.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/22.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/23.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/24.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/3.png 
-------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/4.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/5.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/6.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/7.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/8.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/assets/yoda/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/assets/yoda/9.png -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_zoo import build_model 2 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/camera_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def create_camera_to_world_matrix(elevation, azimuth): 6 | elevation = np.radians(elevation) 7 | azimuth = np.radians(azimuth) 8 | # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere 9 | x = np.cos(elevation) * np.sin(azimuth) 10 | y = np.sin(elevation) 11 | z = np.cos(elevation) * np.cos(azimuth) 12 | 13 | # Calculate camera position, target, and up vectors 14 | camera_pos = np.array([x, y, z]) 15 | target = np.array([0, 0, 0]) 16 | up = np.array([0, 1, 0]) 17 | 18 | # Construct view matrix 19 | forward = target - camera_pos 20 | forward /= np.linalg.norm(forward) 21 | right = np.cross(forward, up) 22 | right /= np.linalg.norm(right) 23 | new_up = np.cross(right, forward) 24 | new_up /= np.linalg.norm(new_up) 25 | cam2world = np.eye(4) 26 | cam2world[:3, :3] = np.array([right, new_up, -forward]).T 27 | cam2world[:3, 3] = camera_pos 28 | return cam2world 29 | 30 
| 31 | def convert_opengl_to_blender(camera_matrix): 32 | if isinstance(camera_matrix, np.ndarray): 33 | # Construct transformation matrix to convert from OpenGL space to Blender space 34 | flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 35 | camera_matrix_blender = np.dot(flip_yz, camera_matrix) 36 | else: 37 | # Construct transformation matrix to convert from OpenGL space to Blender space 38 | flip_yz = torch.tensor( 39 | [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] 40 | ) 41 | if camera_matrix.ndim == 3: 42 | flip_yz = flip_yz.unsqueeze(0) 43 | camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix) 44 | return camera_matrix_blender 45 | 46 | 47 | def normalize_camera(camera_matrix): 48 | """normalize the camera location onto a unit-sphere""" 49 | if isinstance(camera_matrix, np.ndarray): 50 | camera_matrix = camera_matrix.reshape(-1, 4, 4) 51 | translation = camera_matrix[:, :3, 3] 52 | translation = translation / ( 53 | np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8 54 | ) 55 | camera_matrix[:, :3, 3] = translation 56 | else: 57 | camera_matrix = camera_matrix.reshape(-1, 4, 4) 58 | translation = camera_matrix[:, :3, 3] 59 | translation = translation / ( 60 | torch.norm(translation, dim=1, keepdim=True) + 1e-8 61 | ) 62 | camera_matrix[:, :3, 3] = translation 63 | return camera_matrix.reshape(-1, 16) 64 | 65 | 66 | def get_camera( 67 | num_frames, 68 | elevation=15, 69 | azimuth_start=0, 70 | azimuth_span=360, 71 | blender_coord=True, 72 | extra_view=False, 73 | ): 74 | angle_gap = azimuth_span / num_frames 75 | cameras = [] 76 | for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): 77 | camera_matrix = create_camera_to_world_matrix(elevation, azimuth) 78 | if blender_coord: 79 | camera_matrix = convert_opengl_to_blender(camera_matrix) 80 | cameras.append(camera_matrix.flatten()) 81 | 82 | if extra_view: 83 | dim = len(cameras[0]) 84 | cameras.append(np.zeros(dim)) 85 | return torch.tensor(np.stack(cameras, 0)).float() 86 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/configs/sd_v2_base_ipmv.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: threestudio.models.imagedream.imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: threestudio.models.imagedream.imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 12 | params: 13 | image_size: 32 # unused 14 | in_channels: 4 15 | out_channels: 4 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | 32 | use_motion_module: True 33 | motion_module_type: Vanilla 34 | motion_module_kwargs: 35 | num_attention_heads: 8 36 | num_transformer_block: 1 37 | attention_block_types: 38 | - Temporal_Self 39 | - Temporal_Self 40 | temporal_position_encoding: true 41 | temporal_position_encoding_max_len: 32 42 | temporal_attention_dim_div: 1 43 | 44 
| vae_config: 45 | target: threestudio.models.imagedream.imagedream.ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | #attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | clip_config: 69 | target: threestudio.models.imagedream.imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | ip_mode: "local_resample" 74 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/configs/sd_v2_base_ipmv_local.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 12 | params: 13 | image_size: 32 # unused 14 | in_channels: 4 15 | out_channels: 4 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | ip_weight: 1.0 # adjust for similarity to image 32 | 33 | vae_config: 34 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 35 | params: 36 | embed_dim: 4 37 | monitor: val/rec_loss 38 | ddconfig: 39 | #attn_type: "vanilla-xformers" 40 | double_z: true 41 | z_channels: 4 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | clip_config: 58 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 59 | params: 60 | freeze: True 61 | layer: "penultimate" 62 | ip_mode: "local_resample" 63 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/interface.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .modules.diffusionmodules.util import ( 9 | make_beta_schedule, 10 | extract_into_tensor, 11 | noise_like, 12 | ) 13 | from .util import exists, default, instantiate_from_config 14 | from .modules.distributions.distributions import DiagonalGaussianDistribution 15 | 16 | 17 | class DiffusionWrapper(nn.Module): 
18 | def __init__(self, diffusion_model): 19 | super().__init__() 20 | self.diffusion_model = diffusion_model 21 | 22 | def forward(self, *args, **kwargs): 23 | return self.diffusion_model(*args, **kwargs) 24 | 25 | 26 | class LatentDiffusionInterface(nn.Module): 27 | """a simple interface class for LDM inference""" 28 | 29 | def __init__( 30 | self, 31 | unet_config, 32 | clip_config, 33 | vae_config, 34 | parameterization="eps", 35 | scale_factor=0.18215, 36 | beta_schedule="linear", 37 | timesteps=1000, 38 | linear_start=0.00085, 39 | linear_end=0.0120, 40 | cosine_s=8e-3, 41 | given_betas=None, 42 | *args, 43 | **kwargs, 44 | ): 45 | super().__init__() 46 | 47 | unet = instantiate_from_config(unet_config) 48 | self.model = DiffusionWrapper(unet) 49 | self.clip_model = instantiate_from_config(clip_config) 50 | self.vae_model = instantiate_from_config(vae_config) 51 | 52 | self.parameterization = parameterization 53 | self.scale_factor = scale_factor 54 | self.register_schedule( 55 | given_betas=given_betas, 56 | beta_schedule=beta_schedule, 57 | timesteps=timesteps, 58 | linear_start=linear_start, 59 | linear_end=linear_end, 60 | cosine_s=cosine_s, 61 | ) 62 | 63 | def register_schedule( 64 | self, 65 | given_betas=None, 66 | beta_schedule="linear", 67 | timesteps=1000, 68 | linear_start=1e-4, 69 | linear_end=2e-2, 70 | cosine_s=8e-3, 71 | ): 72 | if exists(given_betas): 73 | betas = given_betas 74 | else: 75 | betas = make_beta_schedule( 76 | beta_schedule, 77 | timesteps, 78 | linear_start=linear_start, 79 | linear_end=linear_end, 80 | cosine_s=cosine_s, 81 | ) 82 | alphas = 1.0 - betas 83 | alphas_cumprod = np.cumprod(alphas, axis=0) 84 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 85 | 86 | (timesteps,) = betas.shape 87 | self.num_timesteps = int(timesteps) 88 | self.linear_start = linear_start 89 | self.linear_end = linear_end 90 | assert ( 91 | alphas_cumprod.shape[0] == self.num_timesteps 92 | ), "alphas have to be defined for each timestep" 93 | 94 | to_torch = partial(torch.tensor, dtype=torch.float32) 95 | 96 | self.register_buffer("betas", to_torch(betas)) 97 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 98 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 99 | 100 | # calculations for diffusion q(x_t | x_{t-1}) and others 101 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 102 | self.register_buffer( 103 | "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) 104 | ) 105 | self.register_buffer( 106 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) 107 | ) 108 | self.register_buffer( 109 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) 110 | ) 111 | self.register_buffer( 112 | "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) 113 | ) 114 | 115 | # calculations for posterior q(x_{t-1} | x_t, x_0) 116 | self.v_posterior = 0 117 | posterior_variance = (1 - self.v_posterior) * betas * ( 118 | 1.0 - alphas_cumprod_prev 119 | ) / (1.0 - alphas_cumprod) + self.v_posterior * betas 120 | # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) 121 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 122 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 123 | self.register_buffer( 124 | "posterior_log_variance_clipped", 125 | to_torch(np.log(np.maximum(posterior_variance, 1e-20))), 126 | ) 127 | self.register_buffer( 128 | "posterior_mean_coef1", 129 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), 130 | ) 131 | self.register_buffer( 132 | "posterior_mean_coef2", 133 | to_torch( 134 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) 135 | ), 136 | ) 137 | 138 | def q_sample(self, x_start, t, noise=None): 139 | noise = default(noise, lambda: torch.randn_like(x_start)) 140 | return ( 141 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 142 | + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 143 | * noise 144 | ) 145 | 146 | def get_v(self, x, noise, t): 147 | return ( 148 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise 149 | - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x 150 | ) 151 | 152 | def predict_start_from_noise(self, x_t, t, noise): 153 | return ( 154 | extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 155 | - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 156 | * noise 157 | ) 158 | 159 | def predict_start_from_z_and_v(self, x_t, t, v): 160 | return ( 161 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t 162 | - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 163 | ) 164 | 165 | def predict_eps_from_z_and_v(self, x_t, t, v): 166 | return ( 167 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v 168 | + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) 169 | * x_t 170 | ) 171 | 172 | def apply_model(self, x_noisy, t, cond, **kwargs): 173 | assert isinstance(cond, dict), "cond has to be a dictionary" 174 | return self.model(x_noisy, t, **cond, **kwargs) 175 | 176 | def get_learned_conditioning(self, prompts: List[str]): 177 | return self.clip_model(prompts) 178 | 179 | def get_learned_image_conditioning(self, images): 180 | return self.clip_model.forward_image(images) 181 | 182 | def get_first_stage_encoding(self, encoder_posterior): 183 | if isinstance(encoder_posterior, DiagonalGaussianDistribution): 184 | z = encoder_posterior.sample() 185 | elif isinstance(encoder_posterior, torch.Tensor): 186 | z = encoder_posterior 187 | else: 188 | raise NotImplementedError( 189 | f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" 190 | ) 191 | return self.scale_factor * z 192 | 193 | def encode_first_stage(self, x): 194 | return self.vae_model.encode(x) 195 | 196 | def decode_first_stage(self, z): 197 | z = 1.0 / self.scale_factor * z 198 | return self.vae_model.decode(z) 199 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/models/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/models/autoencoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from contextlib import contextmanager 4 | 5 | from ..modules.diffusionmodules.model import Encoder, Decoder 6 | from ..modules.distributions.distributions import DiagonalGaussianDistribution 7 | 8 | from ..util import instantiate_from_config 9 | from ..modules.ema import LitEma 10 | 11 | 12 | class AutoencoderKL(torch.nn.Module): 13 | def __init__( 14 | self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False, 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels) == int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0.0 < ema_decay < 1.0 46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = 
self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss( 116 | inputs, 117 | reconstructions, 118 | posterior, 119 | optimizer_idx, 120 | self.global_step, 121 | last_layer=self.get_last_layer(), 122 | split="train", 123 | ) 124 | self.log( 125 | "aeloss", 126 | aeloss, 127 | prog_bar=True, 128 | logger=True, 129 | on_step=True, 130 | on_epoch=True, 131 | ) 132 | self.log_dict( 133 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False 134 | ) 135 | return aeloss 136 | 137 | if optimizer_idx == 1: 138 | # train the discriminator 139 | discloss, log_dict_disc = self.loss( 140 | inputs, 141 | reconstructions, 142 | posterior, 143 | optimizer_idx, 144 | self.global_step, 145 | last_layer=self.get_last_layer(), 146 | split="train", 147 | ) 148 | 149 | self.log( 150 | "discloss", 151 | discloss, 152 | prog_bar=True, 153 | logger=True, 154 | on_step=True, 155 | on_epoch=True, 156 | ) 157 | self.log_dict( 158 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False 159 | ) 160 | return discloss 161 | 162 | def validation_step(self, batch, batch_idx): 163 | log_dict = self._validation_step(batch, batch_idx) 164 | with self.ema_scope(): 165 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 166 | return log_dict 167 | 168 | def _validation_step(self, batch, batch_idx, postfix=""): 169 | inputs = self.get_input(batch, self.image_key) 170 | reconstructions, posterior = self(inputs) 171 | aeloss, log_dict_ae = self.loss( 172 | inputs, 173 | reconstructions, 174 | posterior, 175 | 0, 176 | self.global_step, 177 | last_layer=self.get_last_layer(), 178 | split="val" + postfix, 179 | ) 180 | 181 | discloss, log_dict_disc = self.loss( 182 | inputs, 183 | reconstructions, 184 | posterior, 185 | 1, 186 | self.global_step, 187 | last_layer=self.get_last_layer(), 188 | split="val" + postfix, 189 | ) 190 | 191 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 192 | self.log_dict(log_dict_ae) 193 | self.log_dict(log_dict_disc) 194 | return self.log_dict 195 | 196 | def configure_optimizers(self): 197 | lr = self.learning_rate 198 | ae_params_list = ( 199 | list(self.encoder.parameters()) 200 | + list(self.decoder.parameters()) 201 | + list(self.quant_conv.parameters()) 202 | + list(self.post_quant_conv.parameters()) 203 | ) 204 | if self.learn_logvar: 205 | print(f"{self.__class__.__name__}: Learning logvar") 206 | ae_params_list.append(self.loss.logvar) 207 | opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) 208 | opt_disc = torch.optim.Adam( 209 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 210 | ) 211 | return [opt_ae, opt_disc], [] 212 | 213 | def get_last_layer(self): 214 | return self.decoder.conv_out.weight 215 | 216 | @torch.no_grad() 217 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 218 | log = dict() 219 | x = self.get_input(batch, self.image_key) 220 | x = x.to(self.device) 221 | if not only_inputs: 222 | xrec, posterior = self(x) 223 | if x.shape[1] > 3: 224 | # colorize with random projection 225 | assert xrec.shape[1] > 3 226 | x = self.to_rgb(x) 227 | xrec = self.to_rgb(xrec) 228 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 229 | log["reconstructions"] = xrec 230 | if log_ema or self.use_ema: 231 | with self.ema_scope(): 232 | xrec_ema, posterior_ema = self(x) 233 | if x.shape[1] > 3: 234 | # colorize with random projection 235 | assert xrec_ema.shape[1] > 3 236 
| xrec_ema = self.to_rgb(xrec_ema) 237 | log["samples_ema"] = self.decode( 238 | torch.randn_like(posterior_ema.sample()) 239 | ) 240 | log["reconstructions_ema"] = xrec_ema 241 | log["inputs"] = x 242 | return log 243 | 244 | def to_rgb(self, x): 245 | assert self.image_key == "segmentation" 246 | if not hasattr(self, "colorize"): 247 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 248 | x = F.conv2d(x, weight=self.colorize) 249 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 250 | return x 251 | 252 | 253 | class IdentityFirstStage(torch.nn.Module): 254 | def __init__(self, *args, vq_interface=False, **kwargs): 255 | self.vq_interface = vq_interface 256 | super().__init__() 257 | 258 | def encode(self, x, *args, **kwargs): 259 | return x 260 | 261 | def decode(self, x, *args, **kwargs): 262 | return x 263 | 264 | def quantize(self, x, *args, **kwargs): 265 | if self.vq_interface: 266 | return x, None, [None, None, None] 267 | return x 268 | 269 | def forward(self, x, *args, **kwargs): 270 | return x 271 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/modules/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/diffusionmodules/adaptors.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | # FFN 9 | def FeedForward(dim, mult=4): 10 | inner_dim = int(dim * mult) 11 | return nn.Sequential( 12 | nn.LayerNorm(dim), 13 | nn.Linear(dim, inner_dim, bias=False), 14 | nn.GELU(), 15 | nn.Linear(inner_dim, dim, bias=False), 16 | ) 17 | 18 | 19 | def reshape_tensor(x, heads): 20 | bs, length, width = x.shape 21 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 22 | x = x.view(bs, length, heads, -1) 23 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 24 | x = x.transpose(1, 2) 25 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 26 | x = x.reshape(bs, heads, length, -1) 27 | return x 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.dim_head = 
dim_head 35 | self.heads = heads 36 | inner_dim = dim_head * heads 37 | 38 | self.norm1 = nn.LayerNorm(dim) 39 | self.norm2 = nn.LayerNorm(dim) 40 | 41 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 42 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 43 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 44 | 45 | 46 | def forward(self, x, latents): 47 | """ 48 | Args: 49 | x (torch.Tensor): image features 50 | shape (b, n1, D) 51 | latent (torch.Tensor): latent features 52 | shape (b, n2, D) 53 | """ 54 | x = self.norm1(x) 55 | latents = self.norm2(latents) 56 | 57 | b, l, _ = latents.shape 58 | 59 | q = self.to_q(latents) 60 | kv_input = torch.cat((x, latents), dim=-2) 61 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 62 | 63 | q = reshape_tensor(q, self.heads) 64 | k = reshape_tensor(k, self.heads) 65 | v = reshape_tensor(v, self.heads) 66 | 67 | # attention 68 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 69 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 70 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 71 | out = weight @ v 72 | 73 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 74 | 75 | return self.to_out(out) 76 | 77 | 78 | class ImageProjModel(torch.nn.Module): 79 | """Projection Model""" 80 | def __init__(self, 81 | cross_attention_dim=1024, 82 | clip_embeddings_dim=1024, 83 | clip_extra_context_tokens=4): 84 | super().__init__() 85 | self.cross_attention_dim = cross_attention_dim 86 | self.clip_extra_context_tokens = clip_extra_context_tokens 87 | 88 | # from 1024 -> 4 * 1024 89 | self.proj = torch.nn.Linear( 90 | clip_embeddings_dim, 91 | self.clip_extra_context_tokens * cross_attention_dim) 92 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 93 | 94 | def forward(self, image_embeds): 95 | embeds = image_embeds 96 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 97 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 98 | return clip_extra_context_tokens 99 | 100 | 101 | class SimpleReSampler(nn.Module): 102 | def __init__(self, embedding_dim=1280, output_dim=1024): 103 | super().__init__() 104 | self.proj_out = nn.Linear(embedding_dim, output_dim) 105 | self.norm_out = nn.LayerNorm(output_dim) 106 | 107 | def forward(self, latents): 108 | """ 109 | latents: B 256 N 110 | """ 111 | latents = self.proj_out(latents) 112 | return self.norm_out(latents) 113 | 114 | 115 | class Resampler(nn.Module): 116 | def __init__( 117 | self, 118 | dim=1024, 119 | depth=8, 120 | dim_head=64, 121 | heads=16, 122 | num_queries=8, 123 | embedding_dim=768, 124 | output_dim=1024, 125 | ff_mult=4, 126 | ): 127 | super().__init__() 128 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 129 | self.proj_in = nn.Linear(embedding_dim, dim) 130 | self.proj_out = nn.Linear(dim, output_dim) 131 | self.norm_out = nn.LayerNorm(output_dim) 132 | 133 | self.layers = nn.ModuleList([]) 134 | for _ in range(depth): 135 | self.layers.append( 136 | nn.ModuleList( 137 | [ 138 | PerceiverAttention(dim=dim, 139 | dim_head=dim_head, 140 | heads=heads), 141 | FeedForward(dim=dim, mult=ff_mult), 142 | ] 143 | ) 144 | ) 145 | 146 | def forward(self, x): 147 | latents = self.latents.repeat(x.size(0), 1, 1) 148 | x = self.proj_in(x) 149 | for attn, ff in self.layers: 150 | latents = attn(x, latents) + latents 151 | latents = ff(latents) + latents 152 | 153 | latents = self.proj_out(latents) 154 | return 
self.norm_out(latents) 155 | 156 | 157 | if __name__ == '__main__': 158 | resampler = Resampler(embedding_dim=1280) 159 | resampler = SimpleReSampler(embedding_dim=1280) 160 | tensor = torch.rand(4, 257, 1280) 161 | embed = resampler(tensor) 162 | # embed = (tensor) 163 | print(embed.shape) 164 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
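        Typical evaluation flow: `store(model.parameters())`, then
        `copy_to(model)` to swap in the EMA weights, run validation, and
        finally `restore(model.parameters())` to bring back the training weights.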
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aejion/4Diffusion/627d0cd3f1b987822c4086033f8ea6b8045d145b/threestudio/models/imagedream/imagedream/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import random 4 | import torch 5 | import numpy as np 6 | from collections import abc 7 | from einops import rearrange 8 | from functools import partial 9 | 10 | import multiprocessing as mp 11 | from threading import Thread 12 | from queue import Queue 13 | 14 | from inspect import isfunction 15 | from PIL import Image, ImageDraw, ImageFont 16 | 17 | 18 | def log_txt_as_img(wh, xc, size=10): 19 | # wh a tuple of (width, height) 20 | # xc a list of captions to plot 21 | b = len(xc) 22 | txts = list() 23 | for bi in range(b): 24 | txt = Image.new("RGB", wh, color="white") 25 | draw = ImageDraw.Draw(txt) 26 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 27 | nc = int(40 * (wh[0] / 256)) 28 | lines = "\n".join( 29 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 30 | ) 31 | 32 | try: 33 | draw.text((0, 0), lines, fill="black", font=font) 34 | except UnicodeEncodeError: 35 | print("Cant encode string for logging. Skipping.") 36 | 37 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 38 | txts.append(txt) 39 | txts = np.stack(txts) 40 | txts = torch.tensor(txts) 41 | return txts 42 | 43 | 44 | def ismap(x): 45 | if not isinstance(x, torch.Tensor): 46 | return False 47 | return (len(x.shape) == 4) and (x.shape[1] > 3) 48 | 49 | 50 | def isimage(x): 51 | if not isinstance(x, torch.Tensor): 52 | return False 53 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 54 | 55 | 56 | def exists(x): 57 | return x is not None 58 | 59 | 60 | def default(val, d): 61 | if exists(val): 62 | return val 63 | return d() if isfunction(d) else d 64 | 65 | 66 | def mean_flat(tensor): 67 | """ 68 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 69 | Take the mean over all non-batch dimensions. 
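    For example, a tensor of shape (B, C, H, W) is reduced to shape (B,).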
70 | """ 71 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 72 | 73 | 74 | def count_params(model, verbose=False): 75 | total_params = sum(p.numel() for p in model.parameters()) 76 | if verbose: 77 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 78 | return total_params 79 | 80 | 81 | def instantiate_from_config(config): 82 | if not "target" in config: 83 | if config == "__is_first_stage__": 84 | return None 85 | elif config == "__is_unconditional__": 86 | return None 87 | raise KeyError("Expected key `target` to instantiate.") 88 | # import pdb; pdb.set_trace() 89 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 90 | 91 | 92 | def get_obj_from_str(string, reload=False): 93 | module, cls = string.rsplit(".", 1) 94 | # import pdb; pdb.set_trace() 95 | if reload: 96 | module_imp = importlib.import_module(module) 97 | importlib.reload(module_imp) 98 | return getattr(importlib.import_module(module, package=None), cls) 99 | 100 | 101 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 102 | # create dummy dataset instance 103 | 104 | # run prefetching 105 | if idx_to_fn: 106 | res = func(data, worker_id=idx) 107 | else: 108 | res = func(data) 109 | Q.put([idx, res]) 110 | Q.put("Done") 111 | 112 | 113 | def parallel_data_prefetch( 114 | func: callable, 115 | data, 116 | n_proc, 117 | target_data_type="ndarray", 118 | cpu_intensive=True, 119 | use_worker_id=False, 120 | ): 121 | # if target_data_type not in ["ndarray", "list"]: 122 | # raise ValueError( 123 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 124 | # ) 125 | if isinstance(data, np.ndarray) and target_data_type == "list": 126 | raise ValueError("list expected but function got ndarray.") 127 | elif isinstance(data, abc.Iterable): 128 | if isinstance(data, dict): 129 | print( 130 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 131 | ) 132 | data = list(data.values()) 133 | if target_data_type == "ndarray": 134 | data = np.asarray(data) 135 | else: 136 | data = list(data) 137 | else: 138 | raise TypeError( 139 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
140 | ) 141 | 142 | if cpu_intensive: 143 | Q = mp.Queue(1000) 144 | proc = mp.Process 145 | else: 146 | Q = Queue(1000) 147 | proc = Thread 148 | # spawn processes 149 | if target_data_type == "ndarray": 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate(np.array_split(data, n_proc)) 153 | ] 154 | else: 155 | step = ( 156 | int(len(data) / n_proc + 1) 157 | if len(data) % n_proc != 0 158 | else int(len(data) / n_proc) 159 | ) 160 | arguments = [ 161 | [func, Q, part, i, use_worker_id] 162 | for i, part in enumerate( 163 | [data[i : i + step] for i in range(0, len(data), step)] 164 | ) 165 | ] 166 | processes = [] 167 | for i in range(n_proc): 168 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 169 | processes += [p] 170 | 171 | # start processes 172 | print(f"Start prefetching...") 173 | import time 174 | 175 | start = time.time() 176 | gather_res = [[] for _ in range(n_proc)] 177 | try: 178 | for p in processes: 179 | p.start() 180 | 181 | k = 0 182 | while k < n_proc: 183 | # get result 184 | res = Q.get() 185 | if res == "Done": 186 | k += 1 187 | else: 188 | gather_res[res[0]] = res[1] 189 | 190 | except Exception as e: 191 | print("Exception: ", e) 192 | for p in processes: 193 | p.terminate() 194 | 195 | raise e 196 | finally: 197 | for p in processes: 198 | p.join() 199 | print(f"Prefetching complete. [{time.time() - start} sec.]") 200 | 201 | if target_data_type == "ndarray": 202 | if not isinstance(gather_res[0], np.ndarray): 203 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 204 | 205 | # order outputs 206 | return np.concatenate(gather_res, axis=0) 207 | elif target_data_type == "list": 208 | out = [] 209 | for r in gather_res: 210 | out.extend(r) 211 | return out 212 | else: 213 | return gather_res 214 | 215 | def set_seed(seed): 216 | random.seed(seed) 217 | np.random.seed(seed) 218 | torch.manual_seed(seed) 219 | torch.cuda.manual_seed_all(seed) 220 | 221 | def add_random_background(image, bg_color=None): 222 | if np.array(image).shape[-1] == 3: return image 223 | bg_color = np.random.rand() * 255 if bg_color is None else bg_color 224 | # print(bg_color) 225 | image = np.array(image) 226 | rgb, alpha = image[..., :3], image[..., 3:] 227 | alpha = alpha.astype(np.float32) / 255.0 228 | image_new = rgb * alpha + bg_color * (1 - alpha) 229 | # print(image_new) 230 | return Image.fromarray(image_new.astype(np.uint8)) 231 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/imagedream/model_zoo.py: -------------------------------------------------------------------------------- 1 | """ Utiliy functions to load pre-trained models more easily """ 2 | import os 3 | import pkg_resources 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | from huggingface_hub import hf_hub_download 8 | 9 | from .ldm.util import instantiate_from_config 10 | 11 | 12 | PRETRAINED_MODELS = { 13 | "sd-v2.1-base-4view-ipmv": { 14 | "config": "sd_v2_base_ipmv.yaml", 15 | "repo_id": "Peng-Wang/ImageDream", 16 | "filename": "sd-v2.1-base-4view-ipmv.pt", 17 | }, 18 | "sd-v2.1-base-4view-ipmv-local": { 19 | "config": "sd_v2_base_ipmv_local.yaml", 20 | "repo_id": "Peng-Wang/ImageDream", 21 | "filename": "sd-v2.1-base-4view-ipmv-local.pt", 22 | }, 23 | } 24 | 25 | 26 | def get_config_file(config_path): 27 | cfg_file = pkg_resources.resource_filename( 28 | "imagedream", os.path.join("configs", config_path) 29 | ) 30 | if not os.path.exists(cfg_file): 31 | raise 
RuntimeError(f"Config {config_path} not available!") 32 | return cfg_file 33 | 34 | 35 | def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None): 36 | if (config_path is not None) and (ckpt_path is not None): 37 | config = OmegaConf.load(config_path) 38 | model = instantiate_from_config(config.model) 39 | ckpt = torch.load(ckpt_path, map_location="cpu") 40 | new_ckpt = ckpt 41 | if config.model.params.unet_config.params.use_motion_module: 42 | new_ckpt = {} 43 | keys = list(ckpt.keys()) 44 | for key in keys: 45 | if 'middle_block.2' in key: 46 | new_ckpt[key.replace('middle_block.2', 'middle_block.3')] = ckpt[key] 47 | elif 'output_blocks.2.1' in key: 48 | new_ckpt[key.replace('output_blocks.2.1', 'output_blocks.2.2')] = ckpt[key] 49 | elif 'output_blocks.5.2' in key: 50 | new_ckpt[key.replace('output_blocks.5.2', 'output_blocks.5.3')] = ckpt[key] 51 | elif 'output_blocks.8.2' in key: 52 | new_ckpt[key.replace('output_blocks.8.2', 'output_blocks.8.3')] = ckpt[key] 53 | else: 54 | new_ckpt[key] = ckpt[key] 55 | 56 | missing, unexpected = model.load_state_dict(new_ckpt, strict=False) 57 | print(f"### missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};") 58 | ckpt = torch.load( 59 | './ckpts/4dm.ckpt', 60 | map_location="cpu")['state_dict'] 61 | new_ckpt = {} 62 | keys = list(ckpt.keys()) 63 | for key in keys: 64 | new_ckpt[key.replace('module.', '')] = ckpt[key] 65 | missing, unexpected = model.model.diffusion_model.load_state_dict(new_ckpt, strict=False) 66 | print(f"### missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};") 67 | return model 68 | 69 | if not model_name in PRETRAINED_MODELS: 70 | raise RuntimeError( 71 | f"Model name {model_name} is not a pre-trained model. Available models are:\n- " 72 | + "\n- ".join(PRETRAINED_MODELS.keys()) 73 | ) 74 | model_info = PRETRAINED_MODELS[model_name] 75 | 76 | # Instiantiate the model 77 | print(f"Loading model from config: {model_info['config']}") 78 | config_file = get_config_file(model_info["config"]) 79 | config = OmegaConf.load(config_file) 80 | model = instantiate_from_config(config.model) 81 | 82 | # Load pre-trained checkpoint from huggingface 83 | if not ckpt_path: 84 | ckpt_path = hf_hub_download( 85 | repo_id=model_info["repo_id"], 86 | filename=model_info["filename"], 87 | cache_dir=cache_dir, 88 | ) 89 | print(f"Loading model from cache file: {ckpt_path}") 90 | model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) 91 | return model 92 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/scripts/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from PIL import Image 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | import torch 8 | 9 | from threestudio.models.imagedream.imagedream.camera_utils import get_camera 10 | from threestudio.models.imagedream.imagedream.ldm.util import ( 11 | instantiate_from_config, 12 | set_seed, 13 | add_random_background 14 | ) 15 | from threestudio.models.imagedream.imagedream.ldm.models.diffusion.ddim import DDIMSampler 16 | from threestudio.models.imagedream.imagedream.model_zoo import build_model 17 | from torchvision import transforms as T 18 | 19 | import torchvision 20 | from einops import rearrange 21 | import imageio 22 | 23 | def i2i( 24 | model, 25 | image_size, 26 | prompt, 27 | uc, 28 | sampler, 29 | ip=None, 30 | step=20, 31 | scale=5.0, 32 
| batch_size=8, 33 | ddim_eta=0.0, 34 | dtype=torch.float32, 35 | device="cuda", 36 | camera=None, 37 | num_frames=4, 38 | pixel_control=False, 39 | transform=None 40 | ): 41 | """ The function supports additional image prompt. 42 | Args: 43 | model (_type_): the image dream model 44 | image_size (_type_): size of diffusion output 45 | prompt (_type_): text prompt for the image 46 | uc (_type_): _description_ 47 | sampler (_type_): _description_ 48 | ip (Image, optional): the image prompt. Defaults to None. 49 | step (int, optional): _description_. Defaults to 20. 50 | scale (float, optional): _description_. Defaults to 7.5. 51 | batch_size (int, optional): _description_. Defaults to 8. 52 | ddim_eta (float, optional): _description_. Defaults to 0.0. 53 | dtype (_type_, optional): _description_. Defaults to torch.float32. 54 | device (str, optional): _description_. Defaults to "cuda". 55 | camera (_type_, optional): _description_. Defaults to None. 56 | num_frames (int, optional): _description_. Defaults to 4 57 | pixel_control: whether to use pixel conditioning. Defaults to False. 58 | """ 59 | if type(prompt) != list: 60 | prompt = [prompt] 61 | 62 | with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype): 63 | c = model.get_learned_conditioning(prompt).to(device) 64 | c_ = {"context": c.repeat(batch_size, 1, 1)} 65 | uc_ = {"context": uc.repeat(batch_size, 1, 1)} 66 | 67 | if camera is not None: 68 | c_["camera"] = uc_["camera"] = camera 69 | c_["num_frames"] = uc_["num_frames"] = num_frames 70 | 71 | if ip is not None: 72 | ip_embed = model.get_learned_image_conditioning(ip).to(device) 73 | ip_ = ip_embed.repeat(batch_size // 8, 1, 1) 74 | c_["ip"] = ip_ 75 | uc_["ip"] = torch.zeros_like(ip_) 76 | 77 | if pixel_control: 78 | assert camera is not None 79 | # ip = transform(ip).to(device) 80 | ip = ((ip / 255. - 0.5) * 2).to(device).permute(0, 3, 1, 2) 81 | ip_img = model.get_first_stage_encoding( 82 | model.encode_first_stage(ip) 83 | ) 84 | c_["ip_img"] = ip_img 85 | uc_["ip_img"] = torch.zeros_like(ip_img) 86 | 87 | shape = [4, image_size // 8, image_size // 8] 88 | samples_ddim, _ = sampler.sample( 89 | S=step, 90 | conditioning=c_, 91 | batch_size=batch_size, 92 | shape=shape, 93 | verbose=False, 94 | unconditional_guidance_scale=scale, 95 | unconditional_conditioning=uc_, 96 | eta=ddim_eta, 97 | x_T=None, 98 | ) 99 | x_sample = model.decode_first_stage(samples_ddim) 100 | x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) 101 | x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() 102 | 103 | return list(x_sample.astype(np.uint8)) 104 | 105 | 106 | class ImageDreamDiffusion(): 107 | def __init__(self, args) -> None: 108 | assert args.mode in ["pixel", "local"] 109 | assert args.num_frames % 4 == 1 if args.mode == "pixel" else True 110 | 111 | set_seed(args.seed) 112 | dtype = torch.float16 if args.fp16 else torch.float32 113 | device = args.device 114 | batch_size = max(4, args.num_frames) * 8 115 | 116 | print("load image dream diffusion model ... ") 117 | model = build_model(args.model_name, 118 | config_path=args.config_path, 119 | ckpt_path=args.ckpt_path) 120 | # quit() 121 | model.device = device 122 | model.to(device) 123 | model.eval() 124 | 125 | neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." 
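        # `neg_texts` is encoded once into `uc` below; `i2i` then feeds it to the
        # DDIM sampler as `unconditional_conditioning`, i.e. the negative branch
        # of classifier-free guidance controlled by `scale`.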
126 | # neg_texts = "" 127 | sampler = DDIMSampler(model) 128 | uc = model.get_learned_conditioning([neg_texts]).to(device) 129 | print("image dream model load done . ") 130 | 131 | # pre-compute camera matrices 132 | if args.use_camera: 133 | camera = get_camera( 134 | num_frames=4, 135 | elevation=0, 136 | azimuth_start=90, 137 | azimuth_span=360, 138 | extra_view=(args.mode == "pixel") 139 | ) 140 | camera = camera.repeat(batch_size // args.num_frames, 1).to(device) 141 | else: 142 | camera = None 143 | 144 | self.image_transform = T.Compose( 145 | [ 146 | T.Resize((args.size, args.size)), 147 | T.ToTensor(), 148 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 149 | ] 150 | ) 151 | 152 | self.dtype = dtype 153 | self.device = device 154 | self.batch_size = batch_size 155 | self.args = args 156 | self.model = model 157 | self.sampler = sampler 158 | self.uc = uc 159 | self.camera = camera 160 | 161 | def diffuse(self, t, ip, n_test=3): 162 | images = [] 163 | for _ in range(n_test): 164 | img = i2i( 165 | self.model, 166 | self.args.size, 167 | t, 168 | self.uc, 169 | self.sampler, 170 | ip=ip, 171 | step=100, 172 | scale=5, 173 | batch_size=self.batch_size, 174 | ddim_eta=0.0, 175 | dtype=self.dtype, 176 | device=self.device, 177 | camera=self.camera, 178 | num_frames=args.num_frames, 179 | pixel_control=(args.mode == "pixel"), 180 | transform=self.image_transform 181 | ) 182 | img = np.concatenate(img, 1) 183 | images.append(img) 184 | return images 185 | 186 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 187 | videos = rearrange(videos, "b c t h w -> t b c h w") 188 | outputs = [] 189 | for x in videos: 190 | x = torchvision.utils.make_grid(x, nrow=n_rows) 191 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 192 | x = x.numpy().astype(np.uint8) 193 | outputs.append(x) 194 | 195 | os.makedirs(os.path.dirname(path), exist_ok=True) 196 | imageio.mimsave(path, outputs, fps=fps) 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument( 201 | "--model_name", 202 | type=str, 203 | default="sd-v2.1-base-4view-ipmv", 204 | help="load pre-trained model from hugginface", 205 | ) 206 | parser.add_argument( 207 | "--config_path", 208 | type=str, 209 | default=None, 210 | help="load model from local config (override model_name)", 211 | ) 212 | parser.add_argument( 213 | "--ckpt_path", type=str, default=None, help="path to local checkpoint" 214 | ) 215 | parser.add_argument("--text", type=str, default="an astronaut riding a horse") 216 | parser.add_argument("--image", type=str, default="./assets/astrounaut.png") 217 | parser.add_argument("--name", type=str, default="yoda") 218 | parser.add_argument("--suffix", type=str, default=", 3d asset") 219 | parser.add_argument("--size", type=int, default=256) 220 | parser.add_argument("--num_video_frames", type=int, default=25) 221 | parser.add_argument( 222 | "--num_frames", type=int, default=5, help=" \ 223 | num of frames (views) to generate, should be in [4 or 5], \ 224 | 5 for pixel control, 4 for local control" 225 | ) 226 | parser.add_argument("--use_camera", type=int, default=1) 227 | parser.add_argument("--camera_elev", type=int, default=5) 228 | parser.add_argument("--camera_azim", type=int, default=90) 229 | parser.add_argument("--camera_azim_span", type=int, default=360) 230 | parser.add_argument("--seed", type=int, default=23) 231 | parser.add_argument("--fp16", action="store_true") 232 | parser.add_argument("--device", type=str, default="cuda") 233 | 
parser.add_argument( 234 | "--mode", type=str, default="pixel", 235 | help="ip mode default pixel" 236 | ) 237 | args = parser.parse_args() 238 | 239 | t = args.text 240 | 241 | image_dream = ImageDreamDiffusion(args) 242 | assert args.num_frames in [4, 5], "num_frames should be in [4, 5]" 243 | assert os.path.exists(args.image), "image does not exist!" 244 | ip = [] 245 | for i in range(8): 246 | img_idx = int(args.num_video_frames / 8 * i) 247 | img_path = os.path.join(args.image, '{0}.png'.format(img_idx)) 248 | img1 = Image.open(img_path) 249 | img1 = add_random_background(img1, bg_color=255).resize((256, 256)) 250 | ip.append(np.array(img1)) 251 | ip_ = torch.from_numpy(np.array(ip)) 252 | batch_size = max(4, args.num_frames) * 8 253 | device = args.device 254 | 255 | camera = get_camera( 256 | num_frames=4, 257 | elevation=0, 258 | azimuth_start=90, 259 | azimuth_span=360, 260 | extra_view=(args.mode == "pixel") 261 | ) 262 | camera = camera.repeat(batch_size // args.num_frames, 1).to(device) 263 | image_dream.camera = camera 264 | images = image_dream.diffuse(t, ip_, n_test=1) 265 | 266 | images = np.concatenate(images, 0) 267 | 268 | test_save_path = './threestudio/models/imagedream/4dm/{0}'.format(args.name) 269 | for i in range(4): 270 | # out_image = [] 271 | out_image_path = os.path.join(test_save_path, str(i)) 272 | os.makedirs(out_image_path, exist_ok=True) 273 | for j in range(8): 274 | out_image = images[:, 256 * 5 * j + 256 * i: 256 * 5 * j + 256 * i + 256, :] 275 | Image.fromarray(out_image).save(f"{out_image_path}/{j}.png") 276 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/scripts/demo.sh: -------------------------------------------------------------------------------- 1 | # Run this script under ImageDream/ 2 | export PYTHONPATH=$PYTHONPATH:./ 3 | #export HF_ENDPOINT=https://hf-mirror.com 4 | 5 | # test pixel version 6 | python3 threestudio/models/imagedream/scripts/demo.py \ 7 | --image "./threestudio/models/imagedream/assets/yoda" \ 8 | --name "yoda" \ 9 | --text "baby yoda in the style of Mormookiee" \ 10 | --num_video_frames 25 \ 11 | --config_path "./threestudio/models/imagedream/imagedream/configs/sd_v2_base_ipmv.yaml" \ 12 | --ckpt_path "./ckpts/sd-v2.1-base-4view-ipmv.pt" \ 13 | --mode "pixel" \ 14 | --num_frames 5 15 | -------------------------------------------------------------------------------- /threestudio/models/imagedream/scripts/remove_bg.py: -------------------------------------------------------------------------------- 1 | import rembg 2 | import cv2 3 | import os 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--name", 9 | type=str, 10 | default="yoda", 11 | ) 12 | args = parser.parse_args() 13 | 14 | name = args.name 15 | bg_remover = rembg.new_session() 16 | test_save_path = './threestudio/models/imagedream/4dm/{0}'.format(name) 17 | for i in range(4): 18 | out_image_path = os.path.join(test_save_path, str(i)) 19 | for j in range(8): 20 | file = f"{out_image_path}/{j}.png" 21 | file_out = f"{out_image_path}/{j}_rgba.png" 22 | 23 | img = cv2.imread(file, cv2.IMREAD_UNCHANGED) 24 | if img.shape[-1] == 3: 25 | img = rembg.remove(img, session=bg_remover) 26 | cv2.imwrite(file_out, img) 27 | else: 28 | cv2.imwrite(file_out, img) 29 | 30 | reference_video_path = './threestudio/models/imagedream/assets/{0}'.format(name) 31 | imgs = os.listdir(reference_video_path) 32 | for img_path in imgs: 33 | if not img_path.endswith('.png'): continue 34 | 
img_num = img_path.split('.')[0] 35 | file = f"{reference_video_path}/{img_num}.png" 36 | file_out = f"{reference_video_path}/{img_num}_rgba.png" 37 | img = cv2.imread(file, cv2.IMREAD_UNCHANGED) 38 | if img.shape[-1] == 3: 39 | img = rembg.remove(img, session=bg_remover) 40 | cv2.imwrite(file_out, img) 41 | else: 42 | cv2.imwrite(file_out, img) 43 | -------------------------------------------------------------------------------- /threestudio/models/isosurface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.mesh import Mesh 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class IsosurfaceHelper(nn.Module): 12 | points_range: Tuple[float, float] = (0, 1) 13 | 14 | @property 15 | def grid_vertices(self) -> Float[Tensor, "N 3"]: 16 | raise NotImplementedError 17 | 18 | 19 | class MarchingCubeCPUHelper(IsosurfaceHelper): 20 | def __init__(self, resolution: int) -> None: 21 | super().__init__() 22 | self.resolution = resolution 23 | import mcubes 24 | 25 | self.mc_func: Callable = mcubes.marching_cubes 26 | self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None 27 | self._dummy: Float[Tensor, "..."] 28 | self.register_buffer( 29 | "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False 30 | ) 31 | 32 | @property 33 | def grid_vertices(self) -> Float[Tensor, "N3 3"]: 34 | if self._grid_vertices is None: 35 | # keep the vertices on CPU so that we can support very large resolution 36 | x, y, z = ( 37 | torch.linspace(*self.points_range, self.resolution), 38 | torch.linspace(*self.points_range, self.resolution), 39 | torch.linspace(*self.points_range, self.resolution), 40 | ) 41 | x, y, z = torch.meshgrid(x, y, z, indexing="ij") 42 | verts = torch.cat( 43 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 44 | ).reshape(-1, 3) 45 | self._grid_vertices = verts 46 | return self._grid_vertices 47 | 48 | def forward( 49 | self, 50 | level: Float[Tensor, "N3 1"], 51 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 52 | ) -> Mesh: 53 | if deformation is not None: 54 | threestudio.warn( 55 | f"{self.__class__.__name__} does not support deformation. Ignoring." 
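                # (vertex deformation is only consumed by MarchingTetrahedraHelper further below)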
56 | ) 57 | level = -level.view(self.resolution, self.resolution, self.resolution) 58 | v_pos, t_pos_idx = self.mc_func( 59 | level.detach().cpu().numpy(), 0.0 60 | ) # transform to numpy 61 | v_pos, t_pos_idx = ( 62 | torch.from_numpy(v_pos).float().to(self._dummy.device), 63 | torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device), 64 | ) # transform back to torch tensor on CUDA 65 | v_pos = v_pos / (self.resolution - 1.0) 66 | return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) 67 | 68 | 69 | class MarchingTetrahedraHelper(IsosurfaceHelper): 70 | def __init__(self, resolution: int, tets_path: str): 71 | super().__init__() 72 | self.resolution = resolution 73 | self.tets_path = tets_path 74 | 75 | self.triangle_table: Float[Tensor, "..."] 76 | self.register_buffer( 77 | "triangle_table", 78 | torch.as_tensor( 79 | [ 80 | [-1, -1, -1, -1, -1, -1], 81 | [1, 0, 2, -1, -1, -1], 82 | [4, 0, 3, -1, -1, -1], 83 | [1, 4, 2, 1, 3, 4], 84 | [3, 1, 5, -1, -1, -1], 85 | [2, 3, 0, 2, 5, 3], 86 | [1, 4, 0, 1, 5, 4], 87 | [4, 2, 5, -1, -1, -1], 88 | [4, 5, 2, -1, -1, -1], 89 | [4, 1, 0, 4, 5, 1], 90 | [3, 2, 0, 3, 5, 2], 91 | [1, 3, 5, -1, -1, -1], 92 | [4, 1, 2, 4, 3, 1], 93 | [3, 0, 4, -1, -1, -1], 94 | [2, 0, 1, -1, -1, -1], 95 | [-1, -1, -1, -1, -1, -1], 96 | ], 97 | dtype=torch.long, 98 | ), 99 | persistent=False, 100 | ) 101 | self.num_triangles_table: Integer[Tensor, "..."] 102 | self.register_buffer( 103 | "num_triangles_table", 104 | torch.as_tensor( 105 | [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long 106 | ), 107 | persistent=False, 108 | ) 109 | self.base_tet_edges: Integer[Tensor, "..."] 110 | self.register_buffer( 111 | "base_tet_edges", 112 | torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long), 113 | persistent=False, 114 | ) 115 | 116 | tets = np.load(self.tets_path) 117 | self._grid_vertices: Float[Tensor, "..."] 118 | self.register_buffer( 119 | "_grid_vertices", 120 | torch.from_numpy(tets["vertices"]).float(), 121 | persistent=False, 122 | ) 123 | self.indices: Integer[Tensor, "..."] 124 | self.register_buffer( 125 | "indices", torch.from_numpy(tets["indices"]).long(), persistent=False 126 | ) 127 | 128 | self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None 129 | 130 | def normalize_grid_deformation( 131 | self, grid_vertex_offsets: Float[Tensor, "Nv 3"] 132 | ) -> Float[Tensor, "Nv 3"]: 133 | return ( 134 | (self.points_range[1] - self.points_range[0]) 135 | / (self.resolution) # half tet size is approximately 1 / self.resolution 136 | * torch.tanh(grid_vertex_offsets) 137 | ) # FIXME: hard-coded activation 138 | 139 | @property 140 | def grid_vertices(self) -> Float[Tensor, "Nv 3"]: 141 | return self._grid_vertices 142 | 143 | @property 144 | def all_edges(self) -> Integer[Tensor, "Ne 2"]: 145 | if self._all_edges is None: 146 | # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation) 147 | edges = torch.tensor( 148 | [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], 149 | dtype=torch.long, 150 | device=self.indices.device, 151 | ) 152 | _all_edges = self.indices[:, edges].reshape(-1, 2) 153 | _all_edges_sorted = torch.sort(_all_edges, dim=1)[0] 154 | _all_edges = torch.unique(_all_edges_sorted, dim=0) 155 | self._all_edges = _all_edges 156 | return self._all_edges 157 | 158 | def sort_edges(self, edges_ex2): 159 | with torch.no_grad(): 160 | order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() 161 | order = order.unsqueeze(dim=1) 162 | 163 | a = torch.gather(input=edges_ex2, index=order, dim=1) 164 | b = 
torch.gather(input=edges_ex2, index=1 - order, dim=1) 165 | 166 | return torch.stack([a, b], -1) 167 | 168 | def _forward(self, pos_nx3, sdf_n, tet_fx4): 169 | with torch.no_grad(): 170 | occ_n = sdf_n > 0 171 | occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) 172 | occ_sum = torch.sum(occ_fx4, -1) 173 | valid_tets = (occ_sum > 0) & (occ_sum < 4) 174 | occ_sum = occ_sum[valid_tets] 175 | 176 | # find all vertices 177 | all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) 178 | all_edges = self.sort_edges(all_edges) 179 | unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) 180 | 181 | unique_edges = unique_edges.long() 182 | mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 183 | mapping = ( 184 | torch.ones( 185 | (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device 186 | ) 187 | * -1 188 | ) 189 | mapping[mask_edges] = torch.arange( 190 | mask_edges.sum(), dtype=torch.long, device=pos_nx3.device 191 | ) 192 | idx_map = mapping[idx_map] # map edges to verts 193 | 194 | interp_v = unique_edges[mask_edges] 195 | edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) 196 | edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) 197 | edges_to_interp_sdf[:, -1] *= -1 198 | 199 | denominator = edges_to_interp_sdf.sum(1, keepdim=True) 200 | 201 | edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator 202 | verts = (edges_to_interp * edges_to_interp_sdf).sum(1) 203 | 204 | idx_map = idx_map.reshape(-1, 6) 205 | 206 | v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device)) 207 | tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) 208 | num_triangles = self.num_triangles_table[tetindex] 209 | 210 | # Generate triangle indices 211 | faces = torch.cat( 212 | ( 213 | torch.gather( 214 | input=idx_map[num_triangles == 1], 215 | dim=1, 216 | index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], 217 | ).reshape(-1, 3), 218 | torch.gather( 219 | input=idx_map[num_triangles == 2], 220 | dim=1, 221 | index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], 222 | ).reshape(-1, 3), 223 | ), 224 | dim=0, 225 | ) 226 | 227 | return verts, faces 228 | 229 | def forward( 230 | self, 231 | level: Float[Tensor, "N3 1"], 232 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 233 | ) -> Mesh: 234 | if deformation is not None: 235 | grid_vertices = self.grid_vertices + self.normalize_grid_deformation( 236 | deformation 237 | ) 238 | else: 239 | grid_vertices = self.grid_vertices 240 | 241 | v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices) 242 | 243 | mesh = Mesh( 244 | v_pos=v_pos, 245 | t_pos_idx=t_pos_idx, 246 | # extras 247 | grid_vertices=grid_vertices, 248 | tet_edges=self.all_edges, 249 | grid_level=level, 250 | grid_deformation=deformation, 251 | ) 252 | 253 | return mesh 254 | -------------------------------------------------------------------------------- /threestudio/models/materials/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | diffuse_with_point_light_material, 4 | neural_radiance_material, 5 | no_material, 6 | sd_latent_adapter_material, 7 | ) 8 | -------------------------------------------------------------------------------- /threestudio/models/materials/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseMaterial(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | requires_normal: bool = False 20 | 21 | def configure(self): 22 | pass 23 | 24 | def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: 25 | raise NotImplementedError 26 | 27 | def export(self, *args, **kwargs) -> Dict[str, Any]: 28 | return {} 29 | -------------------------------------------------------------------------------- /threestudio/models/materials/diffuse_with_point_light_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.ops import dot, get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("diffuse-with-point-light-material") 15 | class DiffuseWithPointLightMaterial(BaseMaterial): 16 | @dataclass 17 | class Config(BaseMaterial.Config): 18 | ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) 19 | diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) 20 | ambient_only_steps: int = 1000 21 | diffuse_prob: float = 0.75 22 | textureless_prob: float = 0.5 23 | albedo_activation: str = "sigmoid" 24 | soft_shading: bool = False 25 | 26 | cfg: Config 27 | requires_normal: bool = True 28 | 29 | def configure(self) -> None: 30 | self.ambient_light_color: Float[Tensor, "3"] 31 | self.register_buffer( 32 | "ambient_light_color", 33 | torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), 34 | ) 35 | self.diffuse_light_color: Float[Tensor, "3"] 36 | self.register_buffer( 37 | "diffuse_light_color", 38 | torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), 39 | ) 40 | self.ambient_only = False 41 | 42 | def forward( 43 | self, 44 | features: Float[Tensor, "B ... Nf"], 45 | positions: Float[Tensor, "B ... 3"], 46 | shading_normal: Float[Tensor, "B ... 3"], 47 | light_positions: Float[Tensor, "B ... 3"], 48 | ambient_ratio: Optional[float] = None, 49 | shading: Optional[str] = None, 50 | **kwargs, 51 | ) -> Float[Tensor, "B ... 
3"]: 52 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) 53 | if ambient_ratio is not None: 54 | # if ambient ratio is specified, use it 55 | diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( 56 | self.diffuse_light_color 57 | ) 58 | ambient_light_color = ambient_ratio * torch.ones_like( 59 | self.ambient_light_color 60 | ) 61 | elif self.training and self.cfg.soft_shading: 62 | # otherwise if in training and soft shading is enabled, random a ambient ratio 63 | diffuse_light_color = torch.full_like( 64 | self.diffuse_light_color, random.random() 65 | ) 66 | ambient_light_color = 1.0 - diffuse_light_color 67 | else: 68 | # otherwise use the default fixed values 69 | diffuse_light_color = self.diffuse_light_color 70 | ambient_light_color = self.ambient_light_color 71 | 72 | light_directions: Float[Tensor, "B ... 3"] = F.normalize( 73 | light_positions - positions, dim=-1 74 | ) 75 | diffuse_light: Float[Tensor, "B ... 3"] = ( 76 | dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color 77 | ) 78 | textureless_color = diffuse_light + ambient_light_color 79 | # clamp albedo to [0, 1] to compute shading 80 | color = albedo.clamp(0.0, 1.0) * textureless_color 81 | 82 | if shading is None: 83 | if self.training: 84 | # adopt the same type of augmentation for the whole batch 85 | if self.ambient_only or random.random() > self.cfg.diffuse_prob: 86 | shading = "albedo" 87 | elif random.random() < self.cfg.textureless_prob: 88 | shading = "textureless" 89 | else: 90 | shading = "diffuse" 91 | else: 92 | if self.ambient_only: 93 | shading = "albedo" 94 | else: 95 | # return shaded color by default in evaluation 96 | shading = "diffuse" 97 | 98 | # multiply by 0 to prevent checking for unused parameters in DDP 99 | if shading == "albedo": 100 | return albedo + textureless_color * 0 101 | elif shading == "textureless": 102 | return albedo * 0 + textureless_color 103 | elif shading == "diffuse": 104 | return color 105 | else: 106 | raise ValueError(f"Unknown shading type {shading}") 107 | 108 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 109 | if global_step < self.cfg.ambient_only_steps: 110 | self.ambient_only = True 111 | else: 112 | self.ambient_only = False 113 | 114 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 115 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( 116 | 0.0, 1.0 117 | ) 118 | return {"albedo": albedo} 119 | -------------------------------------------------------------------------------- /threestudio/models/materials/neural_radiance_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-radiance-material") 16 | class NeuralRadianceMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | input_feature_dims: int = 8 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | 
default_factory=lambda: { 26 | "otype": "FullyFusedMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | 33 | cfg: Config 34 | 35 | def configure(self) -> None: 36 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 37 | self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore 38 | self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) 39 | 40 | def forward( 41 | self, 42 | features: Float[Tensor, "*B Nf"], 43 | viewdirs: Float[Tensor, "*B 3"], 44 | **kwargs, 45 | ) -> Float[Tensor, "*B 3"]: 46 | # viewdirs and normals must be normalized before passing to this function 47 | viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 48 | viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) 49 | network_inp = torch.cat( 50 | [features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 51 | ) 52 | color = self.network(network_inp).view(*features.shape[:-1], 3) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /threestudio/models/materials/no_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("no-material") 16 | class NoMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | input_feature_dims: Optional[int] = None 22 | mlp_network_config: Optional[dict] = None 23 | 24 | cfg: Config 25 | # requires_normal = True 26 | 27 | def configure(self) -> None: 28 | self.use_network = False 29 | if ( 30 | self.cfg.input_feature_dims is not None 31 | and self.cfg.mlp_network_config is not None 32 | ): 33 | self.network = get_mlp( 34 | self.cfg.input_feature_dims, 35 | self.cfg.n_output_dims, 36 | self.cfg.mlp_network_config, 37 | ) 38 | self.use_network = True 39 | 40 | def forward( 41 | self, features: Float[Tensor, "B ... Nf"], **kwargs 42 | ) -> Float[Tensor, "B ... Nc"]: 43 | if not self.use_network: 44 | assert ( 45 | features.shape[-1] == self.cfg.n_output_dims 46 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 
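            # Without an MLP head, the geometry features are passed straight
            # through the color activation and used as the output color.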
47 | color = get_activation(self.cfg.color_activation)(features) 48 | else: 49 | color = self.network(features.view(-1, features.shape[-1])).view( 50 | *features.shape[:-1], self.cfg.n_output_dims 51 | ) 52 | color = get_activation(self.cfg.color_activation)(color) 53 | return color 54 | 55 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 56 | color = self(features, **kwargs).clamp(0, 1) 57 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels" 58 | if color.shape[-1] > 3: 59 | threestudio.warn( 60 | "Output color has >3 channels, treating the first 3 as RGB" 61 | ) 62 | return {"albedo": color[..., :3]} 63 | -------------------------------------------------------------------------------- /threestudio/models/materials/sd_latent_adapter_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("sd-latent-adapter-material") 14 | class StableDiffusionLatentAdapterMaterial(BaseMaterial): 15 | @dataclass 16 | class Config(BaseMaterial.Config): 17 | pass 18 | 19 | cfg: Config 20 | 21 | def configure(self) -> None: 22 | adapter = nn.Parameter( 23 | torch.as_tensor( 24 | [ 25 | # R G B 26 | [0.298, 0.207, 0.208], # L1 27 | [0.187, 0.286, 0.173], # L2 28 | [-0.158, 0.189, 0.264], # L3 29 | [-0.184, -0.271, -0.473], # L4 30 | ] 31 | ) 32 | ) 33 | self.register_parameter("adapter", adapter) 34 | 35 | def forward( 36 | self, features: Float[Tensor, "B ... 4"], **kwargs 37 | ) -> Float[Tensor, "B ... 3"]: 38 | assert features.shape[-1] == 4 39 | color = features @ self.adapter 40 | color = (color + 1) / 2 41 | color = color.clamp(0.0, 1.0) 42 | return color 43 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base, deepfloyd_prompt_processor, stable_diffusion_prompt_processor, zeroscope_prompt_processor 2 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/deepfloyd_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from diffusers import IFPipeline 8 | from transformers import T5EncoderModel, T5Tokenizer 9 | 10 | import threestudio 11 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 12 | from threestudio.utils.misc import cleanup 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("deep-floyd-prompt-processor") 17 | class DeepFloydPromptProcessor(PromptProcessor): 18 | @dataclass 19 | class Config(PromptProcessor.Config): 20 | pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" 21 | 22 | cfg: Config 23 | 24 | ### these functions are unused, kept for debugging ### 25 | def configure_text_encoder(self) -> None: 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | self.text_encoder = T5EncoderModel.from_pretrained( 28 | self.cfg.pretrained_model_name_or_path, 29 | subfolder="text_encoder", 30 | load_in_8bit=True, 31 | variant="8bit", 32 | device_map="auto", 33 | ) # FIXME: behavior of auto device map in multi-GPU training 34 | self.pipe = IFPipeline.from_pretrained( 35 | self.cfg.pretrained_model_name_or_path, 36 | text_encoder=self.text_encoder, # pass the previously instantiated 8bit text encoder 37 | unet=None, 38 | ) 39 | 40 | def destroy_text_encoder(self) -> None: 41 | del self.text_encoder 42 | del self.pipe 43 | cleanup() 44 | 45 | def get_text_embeddings( 46 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 47 | ) -> Tuple[Float[Tensor, "B 77 4096"], Float[Tensor, "B 77 4096"]]: 48 | text_embeddings, uncond_text_embeddings = self.pipe.encode_prompt( 49 | prompt=prompt, negative_prompt=negative_prompt, device=self.device 50 | ) 51 | return text_embeddings, uncond_text_embeddings 52 | 53 | ### 54 | 55 | @staticmethod 56 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 57 | max_length = 77 58 | tokenizer = T5Tokenizer.from_pretrained( 59 | pretrained_model_name_or_path, subfolder="tokenizer" 60 | ) 61 | text_encoder = T5EncoderModel.from_pretrained( 62 | pretrained_model_name_or_path, 63 | subfolder="text_encoder", 64 | torch_dtype=torch.float16, # suppress warning 65 | load_in_8bit=True, 66 | variant="8bit", 67 | device_map="auto", 68 | ) 69 | with torch.no_grad(): 70 | text_inputs = tokenizer( 71 | prompts, 72 | padding="max_length", 73 | max_length=max_length, 74 | truncation=True, 75 | add_special_tokens=True, 76 | return_tensors="pt", 77 | ) 78 | text_input_ids = text_inputs.input_ids 79 | attention_mask = text_inputs.attention_mask 80 | text_embeddings = text_encoder( 81 | text_input_ids, 82 | attention_mask=attention_mask, 83 | ) 84 | text_embeddings = text_embeddings[0] 85 | 86 | for prompt, embedding in zip(prompts, text_embeddings): 87 | torch.save( 88 | embedding, 89 | os.path.join( 90 | cache_dir, 91 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 92 | ), 93 | ) 94 | 95 | del text_encoder 96 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py: 
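The prompt processors in this directory (DeepFloyd above, Stable Diffusion and Zeroscope below) share the same caching scheme: `spawn_func` encodes each prompt once and writes the embedding to `cache_dir` under a `hash_prompt(model, prompt)` key. A minimal sketch of reading such a cache entry back, assuming the same `hash_prompt` helper and `cache_dir` used when saving (`load_cached_embedding` is an illustrative name, not a function from this repository):

```
import os

import torch

from threestudio.models.prompt_processors.base import hash_prompt


def load_cached_embedding(cache_dir: str, model_name: str, prompt: str) -> torch.Tensor:
    # Mirrors the file naming used by spawn_func: "<hash_prompt(model, prompt)>.pt"
    path = os.path.join(cache_dir, f"{hash_prompt(model_name, prompt)}.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"No cached text embedding for prompt {prompt!r}")
    return torch.load(path, map_location="cpu")
```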
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoTokenizer, CLIPTextModel 8 | 9 | import threestudio 10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 11 | from threestudio.utils.misc import cleanup 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("stable-diffusion-prompt-processor") 16 | class StableDiffusionPromptProcessor(PromptProcessor): 17 | @dataclass 18 | class Config(PromptProcessor.Config): 19 | pass 20 | 21 | cfg: Config 22 | 23 | ### these functions are unused, kept for debugging ### 24 | def configure_text_encoder(self) -> None: 25 | self.tokenizer = AutoTokenizer.from_pretrained( 26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 27 | ) 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | self.text_encoder = CLIPTextModel.from_pretrained( 30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 31 | ).to(self.device) 32 | 33 | for p in self.text_encoder.parameters(): 34 | p.requires_grad_(False) 35 | 36 | def destroy_text_encoder(self) -> None: 37 | del self.tokenizer 38 | del self.text_encoder 39 | cleanup() 40 | 41 | def get_text_embeddings( 42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 44 | if isinstance(prompt, str): 45 | prompt = [prompt] 46 | if isinstance(negative_prompt, str): 47 | negative_prompt = [negative_prompt] 48 | # Tokenize text and get embeddings 49 | tokens = self.tokenizer( 50 | prompt, 51 | padding="max_length", 52 | max_length=self.tokenizer.model_max_length, 53 | return_tensors="pt", 54 | ) 55 | uncond_tokens = self.tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | 62 | with torch.no_grad(): 63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] 64 | uncond_text_embeddings = self.text_encoder( 65 | uncond_tokens.input_ids.to(self.device) 66 | )[0] 67 | 68 | return text_embeddings, uncond_text_embeddings 69 | 70 | ### 71 | 72 | @staticmethod 73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 74 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | pretrained_model_name_or_path, subfolder="tokenizer" 77 | ) 78 | text_encoder = CLIPTextModel.from_pretrained( 79 | pretrained_model_name_or_path, 80 | subfolder="text_encoder", 81 | device_map="auto", 82 | ) 83 | 84 | with torch.no_grad(): 85 | tokens = tokenizer( 86 | prompts, 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | return_tensors="pt", 90 | ) 91 | text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] 92 | 93 | for prompt, embedding in zip(prompts, text_embeddings): 94 | torch.save( 95 | embedding, 96 | os.path.join( 97 | cache_dir, 98 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 99 | ), 100 | ) 101 | 102 | del text_encoder 103 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/zero123_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import threestudio 6 | from 
threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 7 | from threestudio.utils.misc import cleanup 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @threestudio.register("zero123-prompt-processor") 12 | class Zero123PromptProcessor(PromptProcessor): 13 | @dataclass 14 | class Config(PromptProcessor.Config): 15 | pretrained_model_name_or_path: str = "" 16 | 17 | cfg: Config 18 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/zeroscope_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import CLIPTokenizer, CLIPTextModel 8 | 9 | import threestudio 10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 11 | from threestudio.utils.misc import cleanup 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("zeroscope-prompt-processor") 16 | class ZeroscopePromptProcessor(PromptProcessor): 17 | @dataclass 18 | class Config(PromptProcessor.Config): 19 | pass 20 | 21 | cfg: Config 22 | 23 | ### these functions are unused, kept for debugging ### 24 | def configure_text_encoder(self) -> None: 25 | self.tokenizer = CLIPTokenizer.from_pretrained( 26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 27 | ) 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | self.text_encoder = CLIPTextModel.from_pretrained( 30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 31 | ).to(self.device) 32 | 33 | for p in self.text_encoder.parameters(): 34 | p.requires_grad_(False) 35 | 36 | def destroy_text_encoder(self) -> None: 37 | del self.tokenizer 38 | del self.text_encoder 39 | cleanup() 40 | 41 | def get_text_embeddings( 42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 44 | if isinstance(prompt, str): 45 | prompt = [prompt] 46 | if isinstance(negative_prompt, str): 47 | negative_prompt = [negative_prompt] 48 | # Tokenize text and get embeddings 49 | tokens = self.tokenizer( 50 | prompt, 51 | padding="max_length", 52 | max_length=self.tokenizer.model_max_length, 53 | return_tensors="pt", 54 | ) 55 | uncond_tokens = self.tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | 62 | with torch.no_grad(): 63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] 64 | uncond_text_embeddings = self.text_encoder( 65 | uncond_tokens.input_ids.to(self.device) 66 | )[0] 67 | 68 | return text_embeddings, uncond_text_embeddings 69 | 70 | ### 71 | 72 | @staticmethod 73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 74 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 75 | tokenizer = CLIPTokenizer.from_pretrained( 76 | pretrained_model_name_or_path, subfolder="tokenizer" 77 | ) 78 | text_encoder = CLIPTextModel.from_pretrained( 79 | pretrained_model_name_or_path, 80 | subfolder="text_encoder", 81 | device_map="auto", 82 | ) 83 | with torch.no_grad(): 84 | tokens = tokenizer( 85 | prompts, 86 | padding="max_length", 87 | max_length=tokenizer.model_max_length, 88 | return_tensors="pt", 89 | ) 90 | text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] 91 | 92 | for prompt, embedding in zip(prompts, 
text_embeddings): 93 | torch.save( 94 | embedding, 95 | os.path.join( 96 | cache_dir, 97 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 98 | ), 99 | ) 100 | 101 | del text_encoder 102 | -------------------------------------------------------------------------------- /threestudio/models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | deferred_volume_renderer, 4 | nerf_volume_renderer, 5 | neus_volume_renderer, 6 | nvdiff_rasterizer, 7 | ) 8 | -------------------------------------------------------------------------------- /threestudio/models/renderers/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.utils.base import BaseModule 12 | from threestudio.utils.typing import * 13 | 14 | 15 | class Renderer(BaseModule): 16 | @dataclass 17 | class Config(BaseModule.Config): 18 | radius: float = 1.0 19 | 20 | cfg: Config 21 | 22 | def configure( 23 | self, 24 | geometry: BaseImplicitGeometry, 25 | material: BaseMaterial, 26 | background: BaseBackground, 27 | ) -> None: 28 | # keep references to submodules using namedtuple, avoid being registered as modules 29 | @dataclass 30 | class SubModules: 31 | geometry: BaseImplicitGeometry 32 | material: BaseMaterial 33 | background: BaseBackground 34 | 35 | self.sub_modules = SubModules(geometry, material, background) 36 | 37 | # set up bounding box 38 | self.bbox: Float[Tensor, "2 3"] 39 | self.register_buffer( 40 | "bbox", 41 | torch.as_tensor( 42 | [ 43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 45 | ], 46 | dtype=torch.float32, 47 | ), 48 | ) 49 | 50 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 51 | raise NotImplementedError 52 | 53 | @property 54 | def geometry(self) -> BaseImplicitGeometry: 55 | return self.sub_modules.geometry 56 | 57 | @property 58 | def material(self) -> BaseMaterial: 59 | return self.sub_modules.material 60 | 61 | @property 62 | def background(self) -> BaseBackground: 63 | return self.sub_modules.background 64 | 65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None: 66 | self.sub_modules.geometry = geometry 67 | 68 | def set_material(self, material: BaseMaterial) -> None: 69 | self.sub_modules.material = material 70 | 71 | def set_background(self, background: BaseBackground) -> None: 72 | self.sub_modules.background = background 73 | 74 | 75 | class VolumeRenderer(Renderer): 76 | pass 77 | 78 | 79 | class Rasterizer(Renderer): 80 | pass 81 | -------------------------------------------------------------------------------- /threestudio/models/renderers/deferred_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.renderers.base import VolumeRenderer 8 | 9 | 10 | class DeferredVolumeRenderer(VolumeRenderer): 11 | pass 12 | -------------------------------------------------------------------------------- 
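`Renderer.configure` above keeps `geometry`, `material`, and `background` inside a plain `SubModules` dataclass instead of assigning them directly as attributes, so they are not registered as child modules of the renderer. A small self-contained sketch of the effect (the `Toy*` classes are illustrative only, not part of this repository):

```
from dataclasses import dataclass

import torch
import torch.nn as nn


class ToySub(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))


@dataclass
class SubModules:
    sub: nn.Module


class ToyRenderer(nn.Module):
    def __init__(self, sub: nn.Module) -> None:
        super().__init__()
        self.registered = sub               # nn.Module attribute -> registered as a child
        self.sub_modules = SubModules(sub)  # dataclass wrapper -> plain reference only


renderer = ToyRenderer(ToySub())
print([name for name, _ in renderer.named_children()])  # ['registered']
print(list(renderer.state_dict().keys()))               # ['registered.weight']
```

This keeps the shared geometry/material/background out of the renderer's `state_dict` while the `geometry`/`material`/`background` properties still expose them.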
/threestudio/models/renderers/neus_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.renderers.base import VolumeRenderer 9 | from threestudio.utils.typing import * 10 | 11 | 12 | class NeuSVolumeRenderer(VolumeRenderer): 13 | pass 14 | -------------------------------------------------------------------------------- /threestudio/models/renderers/nvdiff_rasterizer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.models.renderers.base import Rasterizer, VolumeRenderer 12 | from threestudio.utils.misc import get_device 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("nvdiff-rasterizer") 18 | class NVDiffRasterizer(Rasterizer): 19 | @dataclass 20 | class Config(VolumeRenderer.Config): 21 | context_type: str = "gl" 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | super().configure(geometry, material, background) 32 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, get_device()) 33 | 34 | def forward( 35 | self, 36 | mvp_mtx: Float[Tensor, "B 4 4"], 37 | camera_positions: Float[Tensor, "B 3"], 38 | light_positions: Float[Tensor, "B 3"], 39 | height: int, 40 | width: int, 41 | render_normal: bool = True, 42 | render_rgb: bool = True, 43 | **kwargs 44 | ) -> Dict[str, Any]: 45 | batch_size = mvp_mtx.shape[0] 46 | mesh = self.geometry.isosurface() 47 | 48 | v_pos_clip: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( 49 | mesh.v_pos, mvp_mtx 50 | ) 51 | rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width)) 52 | mask = rast[..., 3:] > 0 53 | mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx) 54 | 55 | out = {"opacity": mask_aa, "mesh": mesh} 56 | 57 | if render_normal: 58 | gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx) 59 | gb_normal = F.normalize(gb_normal, dim=-1) 60 | gb_normal_aa = torch.lerp( 61 | torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float() 62 | ) 63 | gb_normal_aa = self.ctx.antialias( 64 | gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx 65 | ) 66 | out.update({"comp_normal": gb_normal_aa}) # in [0, 1] 67 | 68 | if render_rgb: 69 | selector = mask[..., 0] 70 | 71 | gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx) 72 | gb_viewdirs = F.normalize( 73 | gb_pos - camera_positions[:, None, None, :], dim=-1 74 | ) 75 | gb_light_positions = light_positions[:, None, None, :].expand( 76 | -1, height, width, -1 77 | ) 78 | 79 | positions = gb_pos[selector] 80 | geo_out = self.geometry(positions, output_normal=False) 81 | rgb_fg = self.material( 82 | viewdirs=gb_viewdirs[selector], 83 | positions=positions, 84 | light_positions=gb_light_positions[selector], 85 | shading_normal=gb_normal[selector], 86 | **geo_out 87 | ) 88 | gb_rgb_fg = 
torch.zeros(batch_size, height, width, 3).to(rgb_fg) 89 | gb_rgb_fg[selector] = rgb_fg 90 | 91 | gb_rgb_bg = self.background(dirs=gb_viewdirs) 92 | gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float()) 93 | gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) 94 | 95 | out.update({"comp_rgb": gb_rgb_aa}) 96 | 97 | return out 98 | -------------------------------------------------------------------------------- /threestudio/systems/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | fourdiffusion, 3 | ) 4 | -------------------------------------------------------------------------------- /threestudio/systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import threestudio 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | 33 | 34 | def parse_optimizer(config, model): 35 | if hasattr(config, "params"): 36 | params = [ 37 | {"params": get_parameters(model, name), "name": name, **args} 38 | for name, args in config.params.items() 39 | ] 40 | threestudio.debug(f"Specify optimizer params: {config.params}") 41 | else: 42 | params = model.parameters() 43 | if config.name in ["FusedAdam"]: 44 | import apex 45 | 46 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 47 | elif config.name in ["Adan"]: 48 | from threestudio.systems import optimizers 49 | 50 | optim = getattr(optimizers, config.name)(params, **config.args) 51 | else: 52 | optim = getattr(torch.optim, config.name)(params, **config.args) 53 | return optim 54 | 55 | 56 | def parse_scheduler(config, optimizer): 57 | interval = config.get("interval", "epoch") 58 | assert interval in ["epoch", "step"] 59 | if config.name == "SequentialLR": 60 | scheduler = { 61 | "scheduler": lr_scheduler.SequentialLR( 62 | optimizer, 63 | [ 64 | parse_scheduler(conf, optimizer)["scheduler"] 65 | for conf in config.schedulers 66 | ], 67 | milestones=config.milestones, 68 | ), 69 | "interval": interval, 70 | } 71 | elif config.name == "ChainedScheduler": 72 | scheduler = { 73 | "scheduler": lr_scheduler.ChainedScheduler( 74 | [ 75 | parse_scheduler(conf, optimizer)["scheduler"] 76 | for conf in config.schedulers 77 | ] 78 | ), 79 | "interval": interval, 80 | } 81 | else: 82 | scheduler = { 83 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 84 | "interval": interval, 85 | } 86 | return scheduler 87 | -------------------------------------------------------------------------------- /threestudio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base 2 | -------------------------------------------------------------------------------- /threestudio/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from threestudio.utils.config import parse_structured 7 | from threestudio.utils.misc import get_device, load_module_weights 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 39 | # override this method to implement custom update logic 40 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 41 | # as the models and tensors are not guarenteed to be on the same device 42 | pass 43 | 44 | 45 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 46 | if isinstance(module, Updateable): 47 | module.do_update_step(epoch, global_step) 48 | 49 | 50 | class BaseObject(Updateable): 51 | @dataclass 52 | class Config: 53 | pass 54 | 55 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 56 | 57 | def __init__( 58 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 59 | ) -> None: 60 | super().__init__() 61 | self.cfg = parse_structured(self.Config, cfg) 62 | self.device = get_device() 63 | self.configure(*args, **kwargs) 64 | 65 | def configure(self, *args, **kwargs) -> None: 66 | pass 67 | 68 | 69 | class BaseModule(nn.Module, Updateable): 70 | @dataclass 71 | class Config: 72 | weights: Optional[str] = None 73 | 74 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 75 | 76 | def __init__( 77 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 78 | ) -> None: 79 | super().__init__() 80 | self.cfg = parse_structured(self.Config, cfg) 81 | self.device = get_device() 82 | self.configure(*args, **kwargs) 83 | if self.cfg.weights is not None: 84 | # format: path/to/weights:module_name 85 | weights_path, module_name = self.cfg.weights.split(":") 86 | state_dict, epoch, global_step = load_module_weights( 87 | weights_path, module_name=module_name, map_location="cpu" 88 | ) 89 | self.load_state_dict(state_dict) 90 | self.do_update_step( 91 | epoch, global_step, on_load_weights=True 92 | ) # restore states 93 | # dummy tensor to indicate model state 94 | self._dummy: Float[Tensor, "..."] 95 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 96 | 97 | def configure(self, *args, **kwargs) -> None: 98 | pass 99 | -------------------------------------------------------------------------------- /threestudio/utils/callbacks.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | 5 | import pytorch_lightning 6 | 7 | from threestudio.utils.config import dump_config 8 | from threestudio.utils.misc import parse_version 9 | 10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): 11 | from pytorch_lightning.callbacks import Callback 12 | else: 13 | from pytorch_lightning.callbacks.base import Callback 14 | 15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 17 | from pytorch_lightning.callbacks import ModelCheckpoint 18 | 19 | 20 | class VersionedCallback(Callback): 21 | def __init__(self, save_root, version=None, use_version=True): 22 | self.save_root = save_root 23 | self._version = version 24 | self.use_version = use_version 25 | 26 | @property 27 | def version(self) -> int: 28 | """Get the experiment version. 29 | 30 | Returns: 31 | The experiment version if specified else the next version. 32 | """ 33 | if self._version is None: 34 | self._version = self._get_next_version() 35 | return self._version 36 | 37 | def _get_next_version(self): 38 | existing_versions = [] 39 | if os.path.isdir(self.save_root): 40 | for f in os.listdir(self.save_root): 41 | bn = os.path.basename(f) 42 | if bn.startswith("version_"): 43 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 44 | existing_versions.append(int(dir_ver)) 45 | if len(existing_versions) == 0: 46 | return 0 47 | return max(existing_versions) + 1 48 | 49 | @property 50 | def savedir(self): 51 | if not self.use_version: 52 | return self.save_root 53 | return os.path.join( 54 | self.save_root, 55 | self.version 56 | if isinstance(self.version, str) 57 | else f"version_{self.version}", 58 | ) 59 | 60 | 61 | class CodeSnapshotCallback(VersionedCallback): 62 | def __init__(self, save_root, version=None, use_version=True): 63 | super().__init__(save_root, version, use_version) 64 | 65 | def get_file_list(self): 66 | return [ 67 | b.decode() 68 | for b in set( 69 | subprocess.check_output( 70 | 'git ls-files -- ":!:load/*"', shell=True 71 | ).splitlines() 72 | ) 73 | | set( # hard code, TODO: use config to exclude folders or files 74 | subprocess.check_output( 75 | "git ls-files --others --exclude-standard", shell=True 76 | ).splitlines() 77 | ) 78 | ] 79 | 80 | @rank_zero_only 81 | def save_code_snapshot(self): 82 | os.makedirs(self.savedir, exist_ok=True) 83 | for f in self.get_file_list(): 84 | if not os.path.exists(f) or os.path.isdir(f): 85 | continue 86 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 87 | shutil.copyfile(f, os.path.join(self.savedir, f)) 88 | 89 | def on_fit_start(self, trainer, pl_module): 90 | try: 91 | self.save_code_snapshot() 92 | except: 93 | rank_zero_warn( 94 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
95 | ) 96 | 97 | 98 | class ConfigSnapshotCallback(VersionedCallback): 99 | def __init__(self, config_path, config, save_root, version=None, use_version=True): 100 | super().__init__(save_root, version, use_version) 101 | self.config_path = config_path 102 | self.config = config 103 | 104 | @rank_zero_only 105 | def save_config_snapshot(self): 106 | os.makedirs(self.savedir, exist_ok=True) 107 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) 108 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) 109 | 110 | def on_fit_start(self, trainer, pl_module): 111 | self.save_config_snapshot() 112 | 113 | 114 | class CustomProgressBar(TQDMProgressBar): 115 | def get_metrics(self, *args, **kwargs): 116 | # don't show the version number 117 | items = super().get_metrics(*args, **kwargs) 118 | items.pop("v_num", None) 119 | return items -------------------------------------------------------------------------------- /threestudio/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | import threestudio 8 | from threestudio.utils.typing import * 9 | 10 | # ============ Register OmegaConf Recolvers ============= # 11 | OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("not", lambda s: not s) 24 | # ======================================================= # 25 | 26 | 27 | @dataclass 28 | class ExperimentConfig: 29 | name: str = "default" 30 | description: str = "" 31 | tag: str = "" 32 | seed: int = 0 33 | use_timestamp: bool = True 34 | timestamp: Optional[str] = None 35 | exp_root_dir: str = "outputs" 36 | 37 | ### these shouldn't be set manually 38 | exp_dir: str = "outputs/default" 39 | trial_name: str = "exp" 40 | trial_dir: str = "outputs/default/exp" 41 | n_gpus: int = 1 42 | ### 43 | 44 | resume: Optional[str] = None 45 | 46 | data_type: str = "" 47 | data: dict = field(default_factory=dict) 48 | 49 | system_type: str = "" 50 | system: dict = field(default_factory=dict) 51 | 52 | # accept pytorch-lightning trainer parameters 53 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 54 | trainer: dict = field(default_factory=dict) 55 | 56 | # accept pytorch-lightning checkpoint callback parameters 57 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 58 | checkpoint: dict = field(default_factory=dict) 59 | 60 | def __post_init__(self): 61 | if not self.tag and not self.use_timestamp: 62 | raise ValueError("Either tag is specified or use_timestamp is True.") 63 | self.trial_name = self.tag 64 | # if resume from an existing config, self.timestamp should not be None 65 | 
if self.timestamp is None: 66 | self.timestamp = "" 67 | if self.use_timestamp: 68 | if self.n_gpus > 1: 69 | threestudio.warn( 70 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 71 | ) 72 | else: 73 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 74 | self.trial_name += self.timestamp 75 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 76 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 77 | 78 | 79 | def load_config(*yaml_files: str, cli_args: list = [], **kwargs) -> Any: 80 | yaml_confs = [OmegaConf.load(f) for f in yaml_files] 81 | cli_conf = OmegaConf.from_cli(cli_args) 82 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 83 | print(cfg) 84 | OmegaConf.resolve(cfg) 85 | assert isinstance(cfg, DictConfig) 86 | scfg = parse_structured(ExperimentConfig, cfg) 87 | return scfg 88 | 89 | 90 | def config_to_primitive(config, resolve: bool = True) -> Any: 91 | return OmegaConf.to_container(config, resolve=resolve) 92 | 93 | 94 | def dump_config(path: str, config) -> None: 95 | with open(path, "w") as fp: 96 | OmegaConf.save(config=config, f=fp) 97 | 98 | 99 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 100 | scfg = OmegaConf.structured(fields(**cfg)) 101 | return scfg 102 | -------------------------------------------------------------------------------- /threestudio/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | from threestudio.lpipsPyTorch import lpips as lpips_fn 17 | from threestudio.lpipsPyTorch.modules.lpips import LPIPS 18 | 19 | _lpips = None 20 | 21 | 22 | def l1_loss(network_output, gt): 23 | return torch.abs((network_output - gt)).mean() 24 | 25 | 26 | def l2_loss(network_output, gt): 27 | return ((network_output - gt) ** 2).mean() 28 | 29 | 30 | def gaussian(window_size, sigma): 31 | gauss = torch.Tensor( 32 | [ 33 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 34 | for x in range(window_size) 35 | ] 36 | ) 37 | return gauss / gauss.sum() 38 | 39 | 40 | def create_window(window_size, channel): 41 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 42 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 43 | window = Variable( 44 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 45 | ) 46 | return window 47 | 48 | 49 | def ssim(img1, img2, window_size=11, size_average=True): 50 | channel = img1.size(-3) 51 | window = create_window(window_size, channel) 52 | 53 | if img1.is_cuda: 54 | window = window.cuda(img1.get_device()) 55 | window = window.type_as(img1) 56 | 57 | return _ssim(img1, img2, window, window_size, channel, size_average) 58 | 59 | 60 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 61 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 62 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 63 | 64 | mu1_sq = mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1 * mu2 67 | 68 | sigma1_sq = ( 69 | F.conv2d(img1 
* img1, window, padding=window_size // 2, groups=channel) - mu1_sq 70 | ) 71 | sigma2_sq = ( 72 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 73 | ) 74 | sigma12 = ( 75 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 76 | - mu1_mu2 77 | ) 78 | 79 | C1 = 0.01**2 80 | C2 = 0.03**2 81 | 82 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 83 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 84 | ) 85 | 86 | if size_average: 87 | return ssim_map.mean() 88 | else: 89 | return ssim_map.mean(1).mean(1).mean(1) 90 | 91 | 92 | def lpips(img1, img2): 93 | global _lpips 94 | if _lpips is None: 95 | _lpips = LPIPS("vgg", "0.1").to("cuda") 96 | return _lpips(img1, img2).mean() -------------------------------------------------------------------------------- /threestudio/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import re 4 | 5 | import tinycudann as tcnn 6 | import torch 7 | from packaging import version 8 | 9 | from threestudio.utils.config import config_to_primitive 10 | from threestudio.utils.typing import * 11 | 12 | 13 | def parse_version(ver: str): 14 | return version.parse(ver) 15 | 16 | 17 | def get_rank(): 18 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 19 | # therefore LOCAL_RANK needs to be checked first 20 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 21 | for key in rank_keys: 22 | rank = os.environ.get(key) 23 | if rank is not None: 24 | return int(rank) 25 | return 0 26 | 27 | 28 | def get_device(): 29 | return torch.device(f"cuda:{get_rank()}") 30 | 31 | 32 | def load_module_weights( 33 | path, module_name=None, ignore_modules=None, map_location=None 34 | ) -> Tuple[dict, int, int]: 35 | if module_name is not None and ignore_modules is not None: 36 | raise ValueError("module_name and ignore_modules cannot be both set") 37 | if map_location is None: 38 | map_location = get_device() 39 | 40 | ckpt = torch.load(path, map_location=map_location) 41 | state_dict = ckpt["state_dict"] 42 | state_dict_to_load = state_dict 43 | 44 | if ignore_modules is not None: 45 | state_dict_to_load = {} 46 | for k, v in state_dict.items(): 47 | ignore = any( 48 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 49 | ) 50 | if ignore: 51 | continue 52 | state_dict_to_load[k] = v 53 | 54 | if module_name is not None: 55 | state_dict_to_load = {} 56 | for k, v in state_dict.items(): 57 | m = re.match(rf"^{module_name}\.(.*)$", k) 58 | if m is None: 59 | continue 60 | state_dict_to_load[m.group(1)] = v 61 | 62 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 63 | 64 | 65 | def C(value: Any, epoch: int, global_step: int) -> float: 66 | if isinstance(value, int) or isinstance(value, float): 67 | pass 68 | else: 69 | value = config_to_primitive(value) 70 | if not isinstance(value, list): 71 | raise TypeError("Scalar specification only supports list, got", type(value)) 72 | if len(value) == 3: 73 | value = [0] + value 74 | assert len(value) == 4 75 | start_step, start_value, end_value, end_step = value 76 | if isinstance(end_step, int): 77 | current_step = global_step 78 | value = start_value + (end_value - start_value) * max( 79 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 80 | ) 81 | elif isinstance(end_step, float): 82 | current_step = epoch 83 | value = start_value + (end_value - start_value) * max( 84 | min(1.0, (current_step 
- start_step) / (end_step - start_step)), 0.0 85 | ) 86 | return value 87 | 88 | 89 | def cleanup(): 90 | gc.collect() 91 | torch.cuda.empty_cache() 92 | tcnn.free_temporary_memory() 93 | 94 | 95 | def finish_with_cleanup(func: Callable): 96 | def wrapper(*args, **kwargs): 97 | out = func(*args, **kwargs) 98 | cleanup() 99 | return out 100 | 101 | return wrapper 102 | 103 | 104 | def _distributed_available(): 105 | return torch.distributed.is_available() and torch.distributed.is_initialized() 106 | 107 | 108 | def barrier(): 109 | if not _distributed_available(): 110 | return 111 | else: 112 | torch.distributed.barrier() 113 | -------------------------------------------------------------------------------- /threestudio/utils/rasterize.py: -------------------------------------------------------------------------------- 1 | import nvdiffrast.torch as dr 2 | import torch 3 | 4 | from threestudio.utils.typing import * 5 | 6 | 7 | class NVDiffRasterizerContext: 8 | def __init__(self, context_type: str, device: torch.device) -> None: 9 | self.device = device 10 | self.ctx = self.initialize_context(context_type, device) 11 | 12 | def initialize_context( 13 | self, context_type: str, device: torch.device 14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: 15 | if context_type == "gl": 16 | return dr.RasterizeGLContext(device=device) 17 | elif context_type == "cuda": 18 | return dr.RasterizeCudaContext(device=device) 19 | else: 20 | raise ValueError(f"Unknown rasterizer context type: {context_type}") 21 | 22 | def vertex_transform( 23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] 24 | ) -> Float[Tensor, "B Nv 4"]: 25 | verts_homo = torch.cat( 26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 27 | ) 28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) 29 | 30 | def rasterize( 31 | self, 32 | pos: Float[Tensor, "B Nv 4"], 33 | tri: Integer[Tensor, "Nf 3"], 34 | resolution: Union[int, Tuple[int, int]], 35 | ): 36 | # rasterize in instance mode (single topology) 37 | return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) 38 | 39 | def rasterize_one( 40 | self, 41 | pos: Float[Tensor, "Nv 4"], 42 | tri: Integer[Tensor, "Nf 3"], 43 | resolution: Union[int, Tuple[int, int]], 44 | ): 45 | # rasterize one single mesh under a single viewpoint 46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) 47 | return rast[0], rast_db[0] 48 | 49 | def antialias( 50 | self, 51 | color: Float[Tensor, "B H W C"], 52 | rast: Float[Tensor, "B H W 4"], 53 | pos: Float[Tensor, "B Nv 4"], 54 | tri: Integer[Tensor, "Nf 3"], 55 | ) -> Float[Tensor, "B H W C"]: 56 | return dr.antialias(color.float(), rast, pos.float(), tri.int()) 57 | 58 | def interpolate( 59 | self, 60 | attr: Float[Tensor, "B Nv C"], 61 | rast: Float[Tensor, "B H W 4"], 62 | tri: Integer[Tensor, "Nf 3"], 63 | rast_db=None, 64 | diff_attrs=None, 65 | ) -> Float[Tensor, "B H W C"]: 66 | return dr.interpolate( 67 | attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs 68 | ) 69 | 70 | def interpolate_one( 71 | self, 72 | attr: Float[Tensor, "Nv C"], 73 | rast: Float[Tensor, "B H W 4"], 74 | tri: Integer[Tensor, "Nf 3"], 75 | rast_db=None, 76 | diff_attrs=None, 77 | ) -> Float[Tensor, "B H W C"]: 78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) 79 | -------------------------------------------------------------------------------- /threestudio/utils/typing.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | --------------------------------------------------------------------------------
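The aliases re-exported by `typing.py` are what give the shape-annotated signatures seen throughout the code their meaning (for example `Float[Tensor, "B 4 4"]` for MVP matrices in the rasterizer, or `Float[Tensor, "2 3"]` for the renderer bounding box). A small, self-contained sketch of how these annotations read and can be checked at runtime is given below; `normalize_directions` is an illustrative function, not one defined in the repository, and since the decorator spelling for full per-argument checking varies between jaxtyping versions, only `isinstance`-style checks are shown.

```
# Sketch only: Float[Tensor, "B 3"] means "a floating-point torch.Tensor whose
# first axis is a batch dimension and whose second axis has size 3".
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor


def normalize_directions(dirs: Float[Tensor, "B 3"]) -> Float[Tensor, "B 3"]:
    # Annotated in the same style as NVDiffRasterizerContext.vertex_transform.
    return F.normalize(dirs, dim=-1)


if __name__ == "__main__":
    dirs = torch.randn(8, 3)
    print(isinstance(dirs, Float[Tensor, "B 3"]))  # True: dtype and shape pattern match
    print(isinstance(dirs, Float[Tensor, "B 4"]))  # False: last axis must have size 4
    print(normalize_directions(dirs).shape)        # torch.Size([8, 3])
```

For enforcing at runtime that symbols such as `B` agree across arguments, jaxtyping's `jaxtyped` decorator can be combined with the `typechecker` alias imported at the bottom of `typing.py`; static checking with mypy works from the annotations alone, as the module docstring notes.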