├── DOCUMENTATION.md ├── LICENSE ├── README.md ├── SMPL_poses ├── 0004_smpl.pkl ├── 0069_smpl.pkl ├── 0073_smpl.pkl ├── 0159_smpl.pkl ├── 0167_smpl.pkl ├── 0188_smpl.pkl ├── 0279_smpl.pkl ├── 0308_smpl.pkl ├── 0392_smpl.pkl ├── 0484_smpl.pkl └── 0501_smpl.pkl ├── configs └── dreamavatar.yaml ├── docker ├── Dockerfile └── compose.yaml ├── docs ├── DreamAvatar-supp.pdf ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── gif │ ├── clown.gif │ ├── deadpool.gif │ ├── joker.gif │ └── link.gif │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ ├── index.js │ └── result.js │ └── video │ ├── Pipeline-n.png │ ├── canonical │ ├── Alien.mp4 │ ├── Buddhist_monk.mp4 │ ├── C-3PO.mp4 │ ├── Crystal_maiden.mp4 │ ├── Electro.mp4 │ ├── Flash.mp4 │ ├── Groot.mp4 │ ├── Joker.mp4 │ ├── Link.mp4 │ ├── Link_2.mp4 │ ├── Luffy.mp4 │ ├── Spiderman.mp4 │ ├── Track_field_athlete.mp4 │ ├── Wonder_woman.mp4 │ ├── Woody.mp4 │ ├── Woody_in_joker.mp4 │ ├── body_builder.mp4 │ ├── clown.mp4 │ ├── hipster_man.mp4 │ ├── kakashi.mp4 │ ├── sasuke.mp4 │ └── woman_hippie.mp4 │ ├── integration │ ├── clown.mp4 │ ├── deadpool.mp4 │ ├── joker.mp4 │ └── link.mp4 │ ├── poses │ ├── Flash-00576.mp4 │ ├── Flash-00596.mp4 │ ├── groot-00084.mp4 │ ├── groot-00230.mp4 │ ├── groot-00308.mp4 │ ├── groot-00350.mp4 │ ├── joker-00296.mp4 │ ├── joker-00510.mp4 │ ├── joker-00530.mp4 │ ├── joker-00536.mp4 │ ├── spiderman-00028.mp4 │ └── spiderman-0279.mp4 │ ├── shapes │ ├── groot-0-1.mp4 │ ├── groot-0-3.mp4 │ ├── groot-1+2.mp4 │ └── groot-1-2.mp4 │ └── text_manipulation │ ├── joker-black.mp4 │ ├── joker-green.mp4 │ ├── joker-pink.mp4 │ └── joker-texudo.mp4 ├── environment.yml ├── launch.py ├── load ├── make_prompt_library.py └── prompt_library.json ├── requirements.txt └── threestudio ├── __init__.py ├── data ├── __init__.py ├── co3d.py ├── image.py ├── multiview.py └── uncond.py ├── models ├── __init__.py ├── background │ ├── __init__.py │ ├── base.py │ └── neural_environment_map_background.py ├── exporters │ ├── __init__.py │ ├── base.py │ └── mesh_exporter.py ├── geometry │ ├── __init__.py │ ├── base.py │ ├── implicit_volume.py │ └── inv_deformation.py ├── guidance │ ├── __init__.py │ ├── stable_diffusion_guidance.py │ └── stable_diffusion_vsd_guidance.py ├── isosurface.py ├── materials │ ├── __init__.py │ ├── base.py │ └── no_material.py ├── mesh.py ├── networks.py ├── prompt_processors │ ├── __init__.py │ ├── base.py │ └── stable_diffusion_prompt_processor.py └── renderers │ ├── __init__.py │ ├── base.py │ └── nerf_volume_renderer.py ├── systems ├── __init__.py ├── base.py ├── dreamavatar.py ├── optimizers.py └── utils.py └── utils ├── GAN ├── attention.py ├── discriminator.py ├── distribution.py ├── loss.py ├── mobilenet.py ├── network_util.py ├── util.py └── vae.py ├── __init__.py ├── base.py ├── callbacks.py ├── config.py ├── misc.py ├── ops.py ├── perceptual ├── __init__.py ├── perceptual.py └── utils.py ├── rasterize.py ├── saving.py └── typing.py /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DreamAvatar: Text-and-Shape Guided 3D Human Avatar Generation via Diffusion Models 4 | 5 | Yukang Cao\*, 6 | Yan-Pei Cao\*, 7 | Kai Han, 8 | Ying Shan, 9 | Kwan-Yee K. Wong 10 | 11 | 12 | [![Paper](http://img.shields.io/badge/Paper-arxiv.2306.03038-B31B1B.svg)](https://arxiv.org/abs/2304.00916) 13 | page 14 | 15 | 16 | 17 | 18 | 19 | 20 | Please refer to our webpage for more visualizations. 21 |
22 | 23 | ## Abstract 24 | We present **DreamAvatar**, a text-and-shape guided framework for generating high-quality 3D human avatars with controllable poses. While encouraging results have been reported by recent methods on text-guided 3D common object generation, generating high-quality human avatars remains an open challenge due to the complexity of the human body's shape, pose, and appearance. We propose DreamAvatar to tackle this challenge, which utilizes a trainable NeRF for predicting density and color for 3D points and pretrained text-to-image diffusion models for providing 2D self-supervision. Specifically, we leverage the SMPL model to provide shape and pose guidance for the generation. We introduce a dual-observation-space design that involves the joint optimization of a canonical space and a posed space that are related by a learnable deformation field. This facilitates the generation of more complete textures and geometry faithful to the target pose. We also jointly optimize the losses computed from the full body and from the zoomed-in 3D head to alleviate the common multi-face ''Janus'' problem and improve facial details in the generated avatars. Extensive evaluations demonstrate that DreamAvatar significantly outperforms existing methods, establishing a new state-of-the-art for text-and-shape guided 3D human avatar generation. 25 | 26 |
27 | 28 |
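## Method sketch

The pipeline described in the abstract boils down to jointly optimizing a shared NeRF in a canonical space and in a posed space linked by a learnable deformation field, with score-distillation supervision applied to both full-body and zoomed-in head renderings. The snippet below is only a structural sketch of that loop with placeholder components: `TinyNeRF`, `render`, `sds_loss`, and `deform` are stand-ins rather than this repository's API; the actual implementation lives in `threestudio/systems/dreamavatar.py` and the modules under `threestudio/models/`.

```python
# Structural sketch only -- all names here are hypothetical placeholders,
# NOT the repository's actual API (see threestudio/systems/dreamavatar.py).
import torch
import torch.nn as nn


class TinyNeRF(nn.Module):
    """Stand-in for the trainable NeRF that predicts density and color per 3D point."""

    def __init__(self) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 4))

    def forward(self, pts: torch.Tensor):
        out = self.mlp(pts)
        density = torch.relu(out[..., :1])   # non-negative density
        color = torch.sigmoid(out[..., 1:])  # RGB in [0, 1]
        return density, color


def render(density: torch.Tensor, color: torch.Tensor) -> torch.Tensor:
    # Placeholder "renderer": a real NeRF volume renderer integrates along camera rays.
    return (density * color).mean()


def sds_loss(rendered: torch.Tensor) -> torch.Tensor:
    # Placeholder for score-distillation supervision from a frozen
    # text-to-image diffusion model; here just a dummy scalar so the sketch runs.
    return rendered.pow(2)


nerf = TinyNeRF()
deform = nn.Linear(3, 3)  # stand-in for the learnable posed-to-canonical deformation field
optimizer = torch.optim.AdamW(list(nerf.parameters()) + list(deform.parameters()), lr=1e-3)

for step in range(3):
    # Canonical-space branch: query the NeRF directly at canonical coordinates.
    pts_canonical = torch.rand(1024, 3) * 2 - 1
    d_c, c_c = nerf(pts_canonical)
    loss_canonical = sds_loss(render(d_c, c_c))

    # Posed-space branch: map posed-space samples to the canonical space through the
    # deformation field (guided by SMPL in the real system), then query the same NeRF.
    pts_posed = torch.rand(1024, 3) * 2 - 1
    d_p, c_p = nerf(deform(pts_posed))
    loss_posed = sds_loss(render(d_p, c_p))

    # Additional zoomed-in head term (placeholder crop) to mitigate the "Janus"
    # problem and improve facial detail.
    loss_head = sds_loss(render(d_p[:128], c_p[:128]))

    loss = loss_canonical + loss_posed + loss_head
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.6f}")
```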
29 | 30 | ## Installation 31 | 32 | See [installation.md](docs/installation.md) for additional information, including installation via Docker. 33 | 34 | The following steps have been tested on Ubuntu 20.04. 35 | 36 | - You must have an NVIDIA graphics card with at least 48GB of VRAM for the default 512x512 setting, and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed. 37 | - Install `Python >= 3.8`. 38 | - (Optional, Recommended) Create a virtual environment: 39 | 40 | ```sh 41 | python3 -m virtualenv venv 42 | . venv/bin/activate 43 | 44 | # Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x. 45 | # For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later. 46 | python3 -m pip install --upgrade pip 47 | ``` 48 | 49 | - Install `PyTorch >= 1.12`. We have tested with `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work. 50 | 51 | ```sh 52 | # torch1.12.1+cu113 53 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 54 | # or torch2.0.0+cu118 55 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 56 | ``` 57 | 58 | - Install [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) and [kaolin](https://kaolin.readthedocs.io/en/latest/notes/installation.html). 59 | 60 | ```sh 61 | pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.1 62 | pip install git+https://github.com/NVIDIAGameWorks/kaolin.git 63 | ``` 64 | 65 | - (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions: 66 | 67 | ```sh 68 | pip install ninja 69 | ``` 70 | 71 | - Install dependencies: 72 | 73 | ```sh 74 | pip install -r requirements.txt 75 | ``` 76 | 77 | ## Download SMPL-X model 78 | 79 | - Please follow the instructions at [SMPL-X](https://smpl-x.is.tue.mpg.de/) to download the required models.
80 | ``` 81 | ./smpl_data 82 | ├── SMPLX_NEUTRAL.pkl 83 | ├── SMPLX_FEMALE.pkl 84 | ├── SMPLX_MALE.pkl 85 | ``` 86 | 87 | ## Training canonical DreamAvatar 88 | 89 | ```sh 90 | # avatar generation with 512x512 NeRF rendering, ~48GB VRAM 91 | python launch.py --config configs/dreamavatar.yaml --train --gpu 0 system.prompt_processor.prompt="Wonder Woman" 92 | # if you don't have enough VRAM, try training with 64x64 NeRF rendering 93 | python launch.py --config configs/dreamavatar.yaml --train --gpu 0 system.prompt_processor.prompt="Wonder Woman" data.width=64 data.height=64 data.batch_size=1 94 | ``` 95 | 96 | ### Resume from checkpoints 97 | 98 | If you want to resume from a checkpoint, do: 99 | 100 | ```sh 101 | # resume training from the last checkpoint; you may replace last.ckpt with any other checkpoint 102 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt 103 | # if training has completed, you can continue training for longer by setting trainer.max_steps 104 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt trainer.max_steps=20000 105 | # you can also perform testing using resumed checkpoints 106 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --test --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt 107 | # note that the above commands use parsed configuration files from previous trials, 108 | # which will continue using the same trial directory 109 | # if you want to save to a new trial directory, replace parsed.yaml with raw.yaml in the command 110 | 111 | # only load weights from a saved checkpoint but don't resume training (i.e. don't load the optimizer state): 112 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 system.weights=path/to/trial/dir/ckpts/last.ckpt 113 | ``` 114 | 115 | ## BibTeX
116 | If you want to cite our work, please use the following bib entry: 117 | ``` 118 | @inproceedings{cao2024dreamavatar, 119 | title={Dreamavatar: Text-and-shape guided 3d human avatar generation via diffusion models}, 120 | author={Cao, Yukang and Cao, Yan-Pei and Han, Kai and Shan, Ying and Wong, Kwan-Yee~K.}, 121 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 122 | pages={958--968}, 123 | year={2024} 124 | } 125 | ``` 126 | 127 | ## Acknowledgement 128 | Thanks to the brilliant works from [Threestudio](https://github.com/threestudio-project/threestudio) and [Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion) 129 | -------------------------------------------------------------------------------- /SMPL_poses/0004_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0004_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0069_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0069_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0073_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0073_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0159_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0159_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0167_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0167_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0188_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0188_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0279_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0279_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0308_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0308_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0392_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0392_smpl.pkl 
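The `SMPL_poses/*.pkl` files above presumably store the SMPL pose parameters used for posed-space generation; their exact layout is not documented in this dump. A quick, non-authoritative way to inspect one of them (a sketch, assuming standard pickle serialization):

```python
# Minimal inspection sketch -- not part of the repository's code.
# Assumes the .pkl files are ordinary pickles; key names are whatever the file contains.
import pickle

with open("SMPL_poses/0004_smpl.pkl", "rb") as f:
    data = pickle.load(f, encoding="latin1")  # latin1 helps with pickles written from Python 2

if isinstance(data, dict):
    for key, value in data.items():
        shape = getattr(value, "shape", None)
        print(key, shape if shape is not None else type(value))
else:
    print(type(data))
```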
-------------------------------------------------------------------------------- /SMPL_poses/0484_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0484_smpl.pkl -------------------------------------------------------------------------------- /SMPL_poses/0501_smpl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0501_smpl.pkl -------------------------------------------------------------------------------- /configs/dreamavatar.yaml: -------------------------------------------------------------------------------- 1 | name: "dreamavatar" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "random-camera-datamodule" 7 | data: 8 | batch_size: [1, 1] 9 | # 0-4999: 64x64, >=5000: 512x512 10 | # this drastically reduces VRAM usage as empty space is pruned in early training 11 | width: [64, 512] 12 | height: [64, 512] 13 | resolution_milestones: [5000] 14 | camera_distance_range: [1.0, 1.5] 15 | fovy_range: [40, 70] 16 | elevation_range: [-10, 45] 17 | camera_perturb: 0. 18 | center_perturb: 0. 19 | up_perturb: 0. 20 | eval_camera_distance: 1.5 21 | eval_fovy_deg: 70. 22 | 23 | system_type: "dreamavatar-system" 24 | system: 25 | stage: coarse 26 | geometry_type: "implicit-volume" 27 | geometry: 28 | radius: 1.0 29 | normal_type: "pred" 30 | 31 | density_bias: "blob_magic3d" 32 | density_activation: softplus 33 | density_blob_scale: 10. 34 | density_blob_std: 0.5 35 | 36 | pos_encoding_config: 37 | otype: HashGrid 38 | n_levels: 16 39 | n_features_per_level: 2 40 | log2_hashmap_size: 19 41 | base_resolution: 16 42 | per_level_scale: 1.447269237440378 # max resolution 4096 43 | 44 | material_type: "no-material" 45 | material: 46 | n_output_dims: 3 47 | color_activation: sigmoid 48 | 49 | background_type: "neural-environment-map-background" 50 | background: 51 | color_activation: sigmoid 52 | random_aug: true 53 | 54 | renderer_type: "nerf-volume-renderer" 55 | renderer: 56 | radius: ${system.geometry.radius} 57 | num_samples_per_ray: 512 58 | 59 | # prompt_processor_type: "stable-diffusion-prompt-processor" 60 | # prompt_processor: 61 | # pretrained_model_name_or_path: "/root/autodl-tmp/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06" 62 | # prompt: ??? 63 | 64 | # guidance_type: "stable-diffusion-guidance" 65 | # guidance: 66 | # pretrained_model_name_or_path: "/root/autodl-tmp/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06" 67 | # guidance_scale: 100. 68 | # weighting_strategy: sds 69 | 70 | prompt_processor_type: "stable-diffusion-prompt-processor" 71 | prompt_processor: 72 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 73 | prompt: ??? 74 | front_threshold: 30. 75 | back_threshold: 30. 
76 | 77 | guidance_type: "stable-diffusion-vsd-guidance" 78 | guidance: 79 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 80 | pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1" 81 | guidance_scale: 7.5 82 | min_step_percent: 0.02 83 | max_step_percent: [5000, 0.98, 0.5, 5001] # annealed to 0.5 after 5000 steps 84 | 85 | loggers: 86 | wandb: 87 | enable: false 88 | project: "threestudio" 89 | name: None 90 | 91 | loss: 92 | lambda_vsd: 1. 93 | lambda_lora: 1. 94 | lambda_orient: 0. 95 | lambda_sparsity: 10. 96 | lambda_opaque: [10000, 0.0, 1000.0, 10001] 97 | lambda_z_variance: 0. 98 | optimizer: 99 | name: AdamW 100 | args: 101 | betas: [0.9, 0.99] 102 | eps: 1.e-15 103 | params: 104 | geometry.encoding: 105 | lr: 0.01 106 | geometry.density_network: 107 | lr: 0.001 108 | geometry.feature_network: 109 | lr: 0.001 110 | background: 111 | lr: 0.001 112 | guidance: 113 | lr: 0.0001 114 | 115 | trainer: 116 | max_steps: 10000 117 | log_every_n_steps: 1 118 | num_sanity_val_steps: 0 119 | val_check_interval: 200 120 | enable_progress_bar: true 121 | precision: 32 122 | 123 | checkpoint: 124 | save_last: true 125 | save_top_k: -1 126 | every_n_train_steps: ${trainer.max_steps} 127 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/cvpaperchallenge/Ascender 3 | # https://github.com/nerfstudio-project/nerfstudio 4 | 5 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 6 | 7 | ARG USER_NAME=dreamer 8 | ARG GROUP_NAME=dreamers 9 | ARG UID=1000 10 | ARG GID=1000 11 | 12 | # Set compute capability for nerfacc and tiny-cuda-nn 13 | # See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build 14 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX" 15 | ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60 16 | # Speed-up build for RTX 30xx 17 | # ENV TORCH_CUDA_ARCH_LIST="8.6" 18 | # ENV TCNN_CUDA_ARCHITECTURES=86 19 | # Speed-up build for RTX 40xx 20 | # ENV TORCH_CUDA_ARCH_LIST="8.9" 21 | # ENV TCNN_CUDA_ARCHITECTURES=89 22 | 23 | ENV CUDA_HOME=/usr/local/cuda 24 | ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH} 25 | ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 26 | ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH} 27 | 28 | # apt install by root user 29 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 30 | build-essential \ 31 | curl \ 32 | git \ 33 | libegl1-mesa-dev \ 34 | libgl1-mesa-dev \ 35 | libgles2-mesa-dev \ 36 | libglib2.0-0 \ 37 | libsm6 \ 38 | libxext6 \ 39 | libxrender1 \ 40 | python-is-python3 \ 41 | python3.10-dev \ 42 | python3-pip \ 43 | wget \ 44 | && rm -rf /var/lib/apt/lists/* 45 | 46 | # Change user to non-root user 47 | RUN groupadd -g ${GID} ${GROUP_NAME} \ 48 | && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME} 49 | USER ${USER_NAME} 50 | 51 | RUN pip install --upgrade pip setuptools ninja 52 | RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 53 | # Install nerfacc and tiny-cuda-nn before installing requirements.txt 54 | # because these two installations are time consuming and error prone 55 | RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 56 | RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch 57 | 58 | COPY 
requirements.txt /tmp 59 | RUN cd /tmp && pip install -r requirements.txt 60 | WORKDIR /home/${USER_NAME}/threestudio 61 | -------------------------------------------------------------------------------- /docker/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | threestudio: 3 | build: 4 | context: ../ 5 | dockerfile: docker/Dockerfile 6 | args: 7 | # you can set environment variables, otherwise default values will be used 8 | USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER 9 | GROUP_NAME: ${HOST_GROUP_NAME:-dreamers} 10 | UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u) 11 | GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g) 12 | shm_size: '4gb' 13 | environment: 14 | NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error` 15 | tty: true 16 | volumes: 17 | - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | capabilities: [gpu] 24 | -------------------------------------------------------------------------------- /docs/DreamAvatar-supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/DreamAvatar-supp.pdf -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform 
.3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | 
height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | .results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | -------------------------------------------------------------------------------- /docs/static/gif/clown.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/clown.gif -------------------------------------------------------------------------------- /docs/static/gif/deadpool.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/deadpool.gif -------------------------------------------------------------------------------- /docs/static/gif/joker.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/joker.gif -------------------------------------------------------------------------------- /docs/static/gif/link.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/link.gif -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var 
e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /docs/static/js/result.js: -------------------------------------------------------------------------------- 1 | const container = document.querySelector('.container'); 2 | const divider = document.querySelector('.divider'); 3 | const leftGif = document.querySelector('.gif:first-child'); 4 | const rightGif = document.querySelector('.gif:last-child'); 5 | let isDragging = false; 6 | 7 | divider.addEventListener('mousedown', function(e) { 8 | isDragging = true; 9 | }); 10 | 11 | container.addEventListener('mousemove', function(e) { 12 | if (!isDragging) return; 13 | const containerRect = container.getBoundingClientRect(); 14 | const 
mousePosition = e.clientX - containerRect.left; 15 | const containerWidth = containerRect.width; 16 | const dividerWidth = divider.offsetWidth; 17 | const minLeft = 0; 18 | const maxLeft = containerWidth - dividerWidth; 19 | const newLeft = Math.max(minLeft, Math.min(maxLeft, mousePosition)); 20 | const leftWidth = newLeft / containerWidth * 100; 21 | const rightWidth = 100 - leftWidth; 22 | 23 | divider.style.left = leftWidth + '%'; 24 | leftGif.style.width = leftWidth + '%'; 25 | rightGif.style.left = leftWidth + '%'; 26 | rightGif.style.width = rightWidth + '%'; 27 | }); 28 | 29 | document.addEventListener('mouseup', function(e) { 30 | isDragging = false; 31 | }); 32 | -------------------------------------------------------------------------------- /docs/static/video/Pipeline-n.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/Pipeline-n.png -------------------------------------------------------------------------------- /docs/static/video/canonical/Alien.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Alien.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Buddhist_monk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Buddhist_monk.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/C-3PO.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/C-3PO.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Crystal_maiden.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Crystal_maiden.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Electro.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Electro.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Flash.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Flash.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Groot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Groot.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Joker.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Joker.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Link.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Link.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Link_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Link_2.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Luffy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Luffy.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Spiderman.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Spiderman.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Track_field_athlete.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Track_field_athlete.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Wonder_woman.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Wonder_woman.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Woody.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Woody.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/Woody_in_joker.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Woody_in_joker.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/body_builder.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/body_builder.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/clown.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/clown.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/hipster_man.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/hipster_man.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/kakashi.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/kakashi.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/sasuke.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/sasuke.mp4 -------------------------------------------------------------------------------- /docs/static/video/canonical/woman_hippie.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/woman_hippie.mp4 -------------------------------------------------------------------------------- /docs/static/video/integration/clown.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/clown.mp4 -------------------------------------------------------------------------------- /docs/static/video/integration/deadpool.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/deadpool.mp4 -------------------------------------------------------------------------------- /docs/static/video/integration/joker.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/joker.mp4 -------------------------------------------------------------------------------- /docs/static/video/integration/link.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/link.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/Flash-00576.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/Flash-00576.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/Flash-00596.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/Flash-00596.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/groot-00084.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00084.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/groot-00230.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00230.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/groot-00308.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00308.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/groot-00350.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00350.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/joker-00296.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00296.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/joker-00510.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00510.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/joker-00530.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00530.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/joker-00536.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00536.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/spiderman-00028.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/spiderman-00028.mp4 -------------------------------------------------------------------------------- /docs/static/video/poses/spiderman-0279.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/spiderman-0279.mp4 -------------------------------------------------------------------------------- /docs/static/video/shapes/groot-0-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-0-1.mp4 -------------------------------------------------------------------------------- /docs/static/video/shapes/groot-0-3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-0-3.mp4 -------------------------------------------------------------------------------- /docs/static/video/shapes/groot-1+2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-1+2.mp4 -------------------------------------------------------------------------------- /docs/static/video/shapes/groot-1-2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-1-2.mp4 -------------------------------------------------------------------------------- /docs/static/video/text_manipulation/joker-black.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-black.mp4 -------------------------------------------------------------------------------- /docs/static/video/text_manipulation/joker-green.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-green.mp4 -------------------------------------------------------------------------------- /docs/static/video/text_manipulation/joker-pink.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-pink.mp4 -------------------------------------------------------------------------------- /docs/static/video/text_manipulation/joker-texudo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-texudo.mp4 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: threestudio 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1t=h7f8727e_0 15 | - 
pip=23.0.1=py38h06a4308_0 16 | - python=3.8.16=h7a1cb2a_3 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=66.0.0=py38h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.38.4=py38h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==1.4.0 26 | - accelerate==0.19.0 27 | - aiofiles==23.1.0 28 | - aiohttp==3.8.4 29 | - aiosignal==1.3.1 30 | - altair==5.0.1 31 | - antlr4-python3-runtime==4.9.3 32 | - anyio==3.6.2 33 | - appdirs==1.4.4 34 | - arrow==1.2.3 35 | - asttokens==2.2.1 36 | - astunparse==1.6.3 37 | - async-timeout==4.0.2 38 | - attrs==23.1.0 39 | - backcall==0.2.0 40 | - beautifulsoup4==4.12.2 41 | - bitsandbytes==0.39.0 42 | - blessed==1.20.0 43 | - blinker==1.6.2 44 | - cachetools==5.3.0 45 | - certifi==2023.5.7 46 | - chardet==5.1.0 47 | - charset-normalizer==3.1.0 48 | - chumpy==0.70 49 | - click==8.1.3 50 | - clip==1.0 51 | - colorlog==6.7.0 52 | - comm==0.1.3 53 | - contourpy==1.0.7 54 | - controlnet-aux==0.0.6 55 | - croniter==1.3.14 56 | - cycler==0.11.0 57 | - dateutils==0.6.12 58 | - debugpy==1.6.7 59 | - decorator==5.1.1 60 | - deepdiff==6.3.0 61 | - diffusers==0.16.1 62 | - distro==1.8.0 63 | - docker-pycreds==0.4.0 64 | - einops==0.6.1 65 | - entrypoints==0.4 66 | - envlight==0.1.0 67 | - executing==1.2.0 68 | - fastapi==0.88.0 69 | - ffmpy==0.3.1 70 | - filelock==3.12.0 71 | - flask==2.3.2 72 | - flatbuffers==23.5.26 73 | - fonttools==4.39.4 74 | - frozenlist==1.3.3 75 | - fsspec==2023.5.0 76 | - ftfy==6.1.1 77 | - fvcore==0.1.5.post20221221 78 | - gast==0.4.0 79 | - gitdb==4.0.10 80 | - gitpython==3.1.31 81 | - google-auth==2.18.1 82 | - google-auth-oauthlib==1.0.0 83 | - google-pasta==0.2.0 84 | - gradio==3.38.0 85 | - gradio-client==0.2.10 86 | - grpcio==1.54.2 87 | - h11==0.14.0 88 | - h5py==3.9.0 89 | - httpcore==0.17.3 90 | - httpx==0.24.1 91 | - huggingface-hub==0.14.1 92 | - idna==3.4 93 | - imageio==2.29.0 94 | - imageio-ffmpeg==0.4.8 95 | - importlib-metadata==6.6.0 96 | - importlib-resources==5.12.0 97 | - inquirer==3.1.3 98 | - iopath==0.1.10 99 | - ipycanvas==0.13.1 100 | - ipyevents==2.0.1 101 | - ipykernel==6.23.1 102 | - ipython==8.12.2 103 | - ipywidgets==8.0.6 104 | - itsdangerous==2.1.2 105 | - jaxtyping==0.2.19 106 | - jedi==0.18.2 107 | - jinja2==3.1.2 108 | - jsonschema==4.18.4 109 | - jsonschema-specifications==2023.7.1 110 | - jupyter-client==7.4.9 111 | - jupyter-core==5.3.0 112 | - jupyterlab-widgets==3.0.7 113 | - kaolin==0.14.0a0 114 | - keras==2.13.1 115 | - kiwisolver==1.4.4 116 | - kornia==0.6.12 117 | - lazy-loader==0.3 118 | - libclang==16.0.6 119 | - libigl==2.4.1 120 | - lightning==2.0.0 121 | - lightning-cloud==0.5.36 122 | - lightning-utilities==0.8.0 123 | - linkify-it-py==2.0.2 124 | - lxml==4.9.3 125 | - mapbox-earcut==1.0.1 126 | - markdown==3.4.3 127 | - markdown-it-py==2.2.0 128 | - markupsafe==2.1.2 129 | - matplotlib==3.7.1 130 | - matplotlib-inline==0.1.6 131 | - mdit-py-plugins==0.3.3 132 | - mdurl==0.1.2 133 | - mpmath==1.3.0 134 | - multidict==6.0.4 135 | - mypy-extensions==1.0.0 136 | - nerfacc==0.5.2 137 | - nest-asyncio==1.5.6 138 | - networkx==3.1 139 | - ninja==1.11.1 140 | - numpy==1.23.0 141 | - nvdiffrast==0.3.1 142 | - oauthlib==3.2.2 143 | - omegaconf==2.3.0 144 | - opencv-python==4.7.0.72 145 | - opt-einsum==3.3.0 146 | - ordered-set==4.1.0 147 | - orjson==3.9.2 148 | - packaging==23.1 149 | - pandas==2.0.3 150 | - parso==0.8.3 151 | - pathtools==0.1.2 152 | - pexpect==4.8.0 153 | - pickleshare==0.7.5 154 | - pillow==9.5.0 155 | - 
pkgutil-resolve-name==1.3.10 156 | - platformdirs==3.5.1 157 | - plyfile==0.9 158 | - portalocker==2.7.0 159 | - prompt-toolkit==3.0.38 160 | - protobuf==4.23.1 161 | - psutil==5.9.5 162 | - ptyprocess==0.7.0 163 | - pure-eval==0.2.2 164 | - pyasn1==0.5.0 165 | - pyasn1-modules==0.3.0 166 | - pybind11==2.10.4 167 | - pycollada==0.7.2 168 | - pydantic==1.10.8 169 | - pydub==0.25.1 170 | - pygments==2.15.1 171 | - pyjwt==2.7.0 172 | - pymcubes==0.1.4 173 | - pyparsing==3.0.9 174 | - pyre-extensions==0.0.29 175 | - pysdf==0.1.9 176 | - python-dateutil==2.8.2 177 | - python-editor==1.0.4 178 | - python-multipart==0.0.6 179 | - pytorch-lightning==2.0.2 180 | - pytorch3d==0.7.4 181 | - pytz==2023.3 182 | - pywavelets==1.4.1 183 | - pyyaml==6.0 184 | - pyzmq==24.0.1 185 | - readchar==4.0.5 186 | - referencing==0.30.0 187 | - regex==2023.5.5 188 | - requests==2.31.0 189 | - requests-oauthlib==1.3.1 190 | - rich==13.3.5 191 | - rpds-py==0.9.2 192 | - rsa==4.9 193 | - rtree==1.0.1 194 | - safetensors==0.3.1 195 | - scikit-build==0.17.5 196 | - scikit-image==0.21.0 197 | - scipy==1.10.1 198 | - semantic-version==2.10.0 199 | - sentencepiece==0.1.99 200 | - sentry-sdk==1.25.0 201 | - setproctitle==1.3.2 202 | - shapely==2.0.1 203 | - six==1.16.0 204 | - smmap==5.0.0 205 | - smplx==0.1.28 206 | - sniffio==1.3.0 207 | - soupsieve==2.4.1 208 | - stack-data==0.6.2 209 | - starlette==0.22.0 210 | - starsessions==1.3.0 211 | - svg-path==6.3 212 | - sympy==1.12 213 | - tabulate==0.9.0 214 | - taming-transformers-rom1504==0.0.6 215 | - tensorboard==2.13.0 216 | - tensorboard-data-server==0.7.0 217 | - tensorflow==2.13.0 218 | - tensorflow-estimator==2.13.0 219 | - tensorflow-io-gcs-filesystem==0.32.0 220 | - termcolor==2.3.0 221 | - tifffile==2023.7.10 222 | - timm==0.9.2 223 | - tinycudann==1.7 224 | - tokenizers==0.13.3 225 | - tomli==2.0.1 226 | - toolz==0.12.0 227 | - torch==1.12.1+cu113 228 | - torchmetrics==0.11.4 229 | - torchvision==0.13.1+cu113 230 | - tornado==6.3.2 231 | - tqdm==4.65.0 232 | - traitlets==5.9.0 233 | - transformers==4.29.2 234 | - trimesh==3.21.7 235 | - typeguard==4.0.0 236 | - typing-extensions==4.5.0 237 | - typing-inspect==0.9.0 238 | - tzdata==2023.3 239 | - uc-micro-py==1.0.2 240 | - urllib3==1.26.16 241 | - usd-core==23.5 242 | - uvicorn==0.22.0 243 | - wandb==0.15.3 244 | - wcwidth==0.2.6 245 | - websocket-client==1.5.2 246 | - websockets==11.0.3 247 | - werkzeug==2.3.4 248 | - widgetsnbextension==4.0.7 249 | - wrapt==1.15.0 250 | - xatlas==0.0.7 251 | - xformers==0.0.21+efcd789.d20230525 252 | - xxhash==3.2.0 253 | - yacs==0.1.8 254 | - yarl==1.9.2 255 | - zipp==3.15.0 256 | prefix: /data2/ykcao/anaconda3/envs/threestudio 257 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | class ColoredFilter(logging.Filter): 9 | """ 10 | A logging filter to add color to certain log levels. 
11 | """ 12 | 13 | RESET = "\033[0m" 14 | RED = "\033[31m" 15 | GREEN = "\033[32m" 16 | YELLOW = "\033[33m" 17 | BLUE = "\033[34m" 18 | MAGENTA = "\033[35m" 19 | CYAN = "\033[36m" 20 | 21 | COLORS = { 22 | "WARNING": YELLOW, 23 | "INFO": GREEN, 24 | "DEBUG": BLUE, 25 | "CRITICAL": MAGENTA, 26 | "ERROR": RED, 27 | } 28 | 29 | RESET = "\x1b[0m" 30 | 31 | def __init__(self): 32 | super().__init__() 33 | 34 | def filter(self, record): 35 | if record.levelname in self.COLORS: 36 | color_start = self.COLORS[record.levelname] 37 | record.levelname = f"{color_start}[{record.levelname}]" 38 | record.msg = f"{record.msg}{self.RESET}" 39 | return True 40 | 41 | 42 | def main(args, extras) -> None: 43 | # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning 44 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 45 | env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) 46 | env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else [] 47 | selected_gpus = [0] 48 | 49 | # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified. 50 | # As far as Pytorch Lightning is concerned, we always use all available GPUs 51 | # (possibly filtered by CUDA_VISIBLE_DEVICES). 52 | devices = -1 53 | if len(env_gpus) > 0: 54 | # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script. 55 | n_gpus = len(env_gpus) 56 | else: 57 | selected_gpus = list(args.gpu.split(",")) 58 | n_gpus = len(selected_gpus) 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 60 | 61 | import pytorch_lightning as pl 62 | import torch 63 | from pytorch_lightning import Trainer 64 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 65 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger 66 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 67 | 68 | if args.typecheck: 69 | from jaxtyping import install_import_hook 70 | 71 | install_import_hook("threestudio", "typeguard.typechecked") 72 | 73 | import threestudio 74 | from threestudio.systems.base import BaseSystem 75 | from threestudio.utils.callbacks import ( 76 | CodeSnapshotCallback, 77 | ConfigSnapshotCallback, 78 | CustomProgressBar, 79 | ProgressCallback, 80 | ) 81 | from threestudio.utils.config import ExperimentConfig, load_config 82 | from threestudio.utils.misc import get_rank 83 | from threestudio.utils.typing import Optional 84 | 85 | logger = logging.getLogger("pytorch_lightning") 86 | if args.verbose: 87 | logger.setLevel(logging.DEBUG) 88 | 89 | for handler in logger.handlers: 90 | if handler.stream == sys.stderr: # type: ignore 91 | if not args.gradio: 92 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) 93 | handler.addFilter(ColoredFilter()) 94 | else: 95 | handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) 96 | 97 | # parse YAML config to OmegaConf 98 | cfg: ExperimentConfig 99 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) 100 | 101 | # set a different seed for each device 102 | pl.seed_everything(cfg.seed + get_rank(), workers=True) 103 | 104 | dm = threestudio.find(cfg.data_type)(cfg.data) 105 | system: BaseSystem = threestudio.find(cfg.system_type)( 106 | cfg.system, resumed=cfg.resume is not None 107 | ) 108 | system.set_save_dir(os.path.join(cfg.trial_dir, "save")) 109 | 110 | if args.gradio: 111 | fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs")) 112 | fh.setLevel(logging.INFO) 113 | if args.verbose: 114 | fh.setLevel(logging.DEBUG) 115 | 
fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) 116 | logger.addHandler(fh) 117 | 118 | callbacks = [] 119 | if args.train: 120 | callbacks += [ 121 | ModelCheckpoint( 122 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint 123 | ), 124 | LearningRateMonitor(logging_interval="step"), 125 | CodeSnapshotCallback( 126 | os.path.join(cfg.trial_dir, "code"), use_version=False 127 | ), 128 | ConfigSnapshotCallback( 129 | args.config, 130 | cfg, 131 | os.path.join(cfg.trial_dir, "configs"), 132 | use_version=False, 133 | ), 134 | ] 135 | if args.gradio: 136 | callbacks += [ 137 | ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress")) 138 | ] 139 | else: 140 | callbacks += [CustomProgressBar(refresh_rate=1)] 141 | 142 | def write_to_text(file, lines): 143 | with open(file, "w") as f: 144 | for line in lines: 145 | f.write(line + "\n") 146 | 147 | loggers = [] 148 | if args.train: 149 | # make tensorboard logging dir to suppress warning 150 | rank_zero_only( 151 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) 152 | )() 153 | loggers += [ 154 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"), 155 | CSVLogger(cfg.trial_dir, name="csv_logs"), 156 | ] + system.get_loggers() 157 | rank_zero_only( 158 | lambda: write_to_text( 159 | os.path.join(cfg.trial_dir, "cmd.txt"), 160 | ["python " + " ".join(sys.argv), str(args)], 161 | ) 162 | )() 163 | 164 | trainer = Trainer( 165 | callbacks=callbacks, 166 | logger=loggers, 167 | inference_mode=False, 168 | accelerator="gpu", 169 | devices=devices, 170 | **cfg.trainer, 171 | ) 172 | 173 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): 174 | if ckpt_path is None: 175 | return 176 | ckpt = torch.load(ckpt_path, map_location="cpu") 177 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) 178 | 179 | if args.train: 180 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) 181 | trainer.test(system, datamodule=dm) 182 | if args.gradio: 183 | # also export assets if in gradio mode 184 | trainer.predict(system, datamodule=dm) 185 | elif args.validate: 186 | # manually set epoch and global_step as they cannot be automatically resumed 187 | set_system_status(system, cfg.resume) 188 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) 189 | elif args.test: 190 | # manually set epoch and global_step as they cannot be automatically resumed 191 | set_system_status(system, cfg.resume) 192 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) 193 | elif args.export: 194 | set_system_status(system, cfg.resume) 195 | trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume) 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--config", required=True, help="path to config file") 201 | parser.add_argument( 202 | "--gpu", 203 | default="0", 204 | help="GPU(s) to be used. 0 means use the 1st available GPU. " 205 | "1,2 means use the 2nd and 3rd available GPU. 
" 206 | "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " 207 | "this argument is ignored and all available GPUs are always used.", 208 | ) 209 | 210 | group = parser.add_mutually_exclusive_group(required=True) 211 | group.add_argument("--train", action="store_true") 212 | group.add_argument("--validate", action="store_true") 213 | group.add_argument("--test", action="store_true") 214 | group.add_argument("--export", action="store_true") 215 | 216 | parser.add_argument( 217 | "--gradio", action="store_true", help="if true, run in gradio mode" 218 | ) 219 | 220 | parser.add_argument( 221 | "--verbose", action="store_true", help="if true, set logging level to DEBUG" 222 | ) 223 | 224 | parser.add_argument( 225 | "--typecheck", 226 | action="store_true", 227 | help="whether to enable dynamic type checking", 228 | ) 229 | 230 | args, extras = parser.parse_known_args() 231 | 232 | if args.gradio: 233 | # FIXME: no effect, stdout is not captured 234 | with contextlib.redirect_stdout(sys.stderr): 235 | main(args, extras) 236 | else: 237 | main(args, extras) 238 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning==2.0.0 2 | omegaconf==2.3.0 3 | jaxtyping 4 | typeguard 5 | git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 6 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 7 | diffusers 8 | transformers 9 | accelerate 10 | opencv-python 11 | tensorboard 12 | matplotlib 13 | imageio>=2.28.0 14 | imageio[ffmpeg] 15 | git+https://github.com/NVlabs/nvdiffrast.git 16 | libigl 17 | xatlas 18 | trimesh[easy] 19 | networkx 20 | pysdf 21 | PyMCubes 22 | wandb 23 | gradio 24 | git+https://github.com/ashawkey/envlight.git 25 | torchmetrics 26 | 27 | # deepfloyd 28 | xformers 29 | bitsandbytes 30 | sentencepiece 31 | safetensors 32 | huggingface_hub 33 | 34 | # for zero123 35 | einops 36 | kornia 37 | taming-transformers-rom1504 38 | git+https://github.com/openai/CLIP.git 39 | 40 | #controlnet 41 | controlnet_aux 42 | -------------------------------------------------------------------------------- /threestudio/__init__.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | __modules__[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def find(name): 13 | return __modules__[name] 14 | 15 | 16 | ### grammar sugar for logging utilities ### 17 | import logging 18 | 19 | logger = logging.getLogger("pytorch_lightning") 20 | 21 | from pytorch_lightning.utilities.rank_zero import ( 22 | rank_zero_debug, 23 | rank_zero_info, 24 | rank_zero_only, 25 | ) 26 | 27 | debug = rank_zero_debug 28 | info = rank_zero_info 29 | 30 | 31 | @rank_zero_only 32 | def warn(*args, **kwargs): 33 | logger.warn(*args, **kwargs) 34 | 35 | 36 | from . import data, models, systems 37 | -------------------------------------------------------------------------------- /threestudio/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import co3d, image, multiview, uncond 2 | -------------------------------------------------------------------------------- /threestudio/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | background, 3 | exporters, 4 | geometry, 5 | guidance, 6 | materials, 7 | prompt_processors, 8 | renderers, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/background/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | neural_environment_map_background, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/background/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseBackground(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | 20 | def configure(self): 21 | pass 22 | 23 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /threestudio/models/background/neural_environment_map_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-environment-map-background") 16 | class NeuralEnvironmentMapBackground(BaseBackground): 17 | @dataclass 18 | class Config(BaseBackground.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "VanillaMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | random_aug: bool = False 33 | random_aug_prob: float = 0.5 34 | eval_color: Optional[Tuple[float, float, float]] = None 35 | 36 | cfg: Config 37 | 38 | def configure(self) -> None: 39 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 40 | self.network = get_mlp( 41 | self.encoding.n_output_dims, 42 | self.cfg.n_output_dims, 43 | self.cfg.mlp_network_config, 44 | ) 45 | 46 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 47 | if not self.training and self.cfg.eval_color is not None: 48 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 49 | dirs 50 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs) 51 | # viewdirs must be normalized before passing to this function 52 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 53 | dirs_embd = self.encoding(dirs.view(-1, 3)) 54 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) 55 | color = get_activation(self.cfg.color_activation)(color) 56 | if ( 57 | self.training 58 | and self.cfg.random_aug 59 | and random.random() < self.cfg.random_aug_prob 60 | ): 61 | # use random background color with probability 
random_aug_prob 62 | color = color * 0 + ( # prevent checking for unused parameters in DDP 63 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 64 | .to(dirs) 65 | .expand(*dirs.shape[:-1], -1) 66 | ) 67 | return color 68 | -------------------------------------------------------------------------------- /threestudio/models/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, mesh_exporter 2 | -------------------------------------------------------------------------------- /threestudio/models/exporters/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import threestudio 4 | from threestudio.models.background.base import BaseBackground 5 | from threestudio.models.geometry.base import BaseImplicitGeometry 6 | from threestudio.models.materials.base import BaseMaterial 7 | from threestudio.utils.base import BaseObject 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @dataclass 12 | class ExporterOutput: 13 | save_name: str 14 | save_type: str 15 | params: Dict[str, Any] 16 | 17 | 18 | class Exporter(BaseObject): 19 | @dataclass 20 | class Config(BaseObject.Config): 21 | save_video: bool = False 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | @dataclass 32 | class SubModules: 33 | geometry: BaseImplicitGeometry 34 | material: BaseMaterial 35 | background: BaseBackground 36 | 37 | self.sub_modules = SubModules(geometry, material, background) 38 | 39 | @property 40 | def geometry(self) -> BaseImplicitGeometry: 41 | return self.sub_modules.geometry 42 | 43 | @property 44 | def material(self) -> BaseMaterial: 45 | return self.sub_modules.material 46 | 47 | @property 48 | def background(self) -> BaseBackground: 49 | return self.sub_modules.background 50 | 51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 52 | raise NotImplementedError 53 | 54 | 55 | @threestudio.register("dummy-exporter") 56 | class DummyExporter(Exporter): 57 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 58 | # DummyExporter does not export anything 59 | return [] 60 | -------------------------------------------------------------------------------- /threestudio/models/exporters/mesh_exporter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.exporters.base import Exporter, ExporterOutput 10 | from threestudio.models.geometry.base import BaseImplicitGeometry 11 | from threestudio.models.materials.base import BaseMaterial 12 | from threestudio.models.mesh import Mesh 13 | from threestudio.utils.rasterize import NVDiffRasterizerContext 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("mesh-exporter") 18 | class MeshExporter(Exporter): 19 | @dataclass 20 | class Config(Exporter.Config): 21 | fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx 22 | save_name: str = "model" 23 | save_normal: bool = False 24 | save_uv: bool = True 25 | save_texture: bool = True 26 | texture_size: int = 1024 27 | texture_format: str = "jpg" 28 | xatlas_chart_options: dict = field(default_factory=dict) 29 | xatlas_pack_options: dict = 
field(default_factory=dict) 30 | context_type: str = "gl" 31 | 32 | cfg: Config 33 | 34 | def configure( 35 | self, 36 | geometry: BaseImplicitGeometry, 37 | material: BaseMaterial, 38 | background: BaseBackground, 39 | ) -> None: 40 | super().configure(geometry, material, background) 41 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device) 42 | 43 | def __call__(self) -> List[ExporterOutput]: 44 | mesh: Mesh = self.geometry.isosurface() 45 | 46 | if self.cfg.fmt == "obj-mtl": 47 | return self.export_obj_with_mtl(mesh) 48 | elif self.cfg.fmt == "obj": 49 | return self.export_obj(mesh) 50 | else: 51 | raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}") 52 | 53 | def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]: 54 | params = { 55 | "mesh": mesh, 56 | "save_mat": True, 57 | "save_normal": self.cfg.save_normal, 58 | "save_uv": self.cfg.save_uv, 59 | "save_vertex_color": False, 60 | "map_Kd": None, # Base Color 61 | "map_Ks": None, # Specular 62 | "map_Bump": None, # Normal 63 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering 64 | "map_Pm": None, # Metallic 65 | "map_Pr": None, # Roughness 66 | "map_format": self.cfg.texture_format, 67 | } 68 | 69 | if self.cfg.save_uv: 70 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 71 | 72 | if self.cfg.save_texture: 73 | threestudio.info("Exporting textures ...") 74 | assert self.cfg.save_uv, "save_uv must be True when save_texture is True" 75 | # clip space transform 76 | uv_clip = mesh.v_tex * 2.0 - 1.0 77 | # pad to four component coordinate 78 | uv_clip4 = torch.cat( 79 | ( 80 | uv_clip, 81 | torch.zeros_like(uv_clip[..., 0:1]), 82 | torch.ones_like(uv_clip[..., 0:1]), 83 | ), 84 | dim=-1, 85 | ) 86 | # rasterize 87 | rast, _ = self.ctx.rasterize_one( 88 | uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size) 89 | ) 90 | 91 | hole_mask = ~(rast[:, :, 3] > 0) 92 | 93 | def uv_padding(image): 94 | uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2) 95 | inpaint_image = ( 96 | cv2.inpaint( 97 | (image.detach().cpu().numpy() * 255).astype(np.uint8), 98 | (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), 99 | uv_padding_size, 100 | cv2.INPAINT_TELEA, 101 | ) 102 | / 255.0 103 | ) 104 | return torch.from_numpy(inpaint_image).to(image) 105 | 106 | # Interpolate world space position 107 | gb_pos, _ = self.ctx.interpolate_one( 108 | mesh.v_pos, rast[None, ...], mesh.t_pos_idx 109 | ) 110 | gb_pos = gb_pos[0] 111 | 112 | # Sample out textures from MLP 113 | geo_out = self.geometry.export(points=gb_pos) 114 | mat_out = self.material.export(points=gb_pos, **geo_out) 115 | 116 | threestudio.info( 117 | "Perform UV padding on texture maps to avoid seams, may take a while ..." 
118 | ) 119 | 120 | if "albedo" in mat_out: 121 | params["map_Kd"] = uv_padding(mat_out["albedo"]) 122 | else: 123 | threestudio.warn( 124 | "save_texture is True but no albedo texture found, using default white texture" 125 | ) 126 | if "metallic" in mat_out: 127 | params["map_Pm"] = uv_padding(mat_out["metallic"]) 128 | if "roughness" in mat_out: 129 | params["map_Pr"] = uv_padding(mat_out["roughness"]) 130 | if "bump" in mat_out: 131 | params["map_Bump"] = uv_padding(mat_out["bump"]) 132 | # TODO: map_Ks 133 | return [ 134 | ExporterOutput( 135 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 136 | ) 137 | ] 138 | 139 | def export_obj(self, mesh: Mesh) -> List[ExporterOutput]: 140 | params = { 141 | "mesh": mesh, 142 | "save_mat": False, 143 | "save_normal": self.cfg.save_normal, 144 | "save_uv": self.cfg.save_uv, 145 | "save_vertex_color": False, 146 | "map_Kd": None, # Base Color 147 | "map_Ks": None, # Specular 148 | "map_Bump": None, # Normal 149 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering 150 | "map_Pm": None, # Metallic 151 | "map_Pr": None, # Roughness 152 | "map_format": self.cfg.texture_format, 153 | } 154 | 155 | if self.cfg.save_uv: 156 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) 157 | 158 | if self.cfg.save_texture: 159 | threestudio.info("Exporting textures ...") 160 | geo_out = self.geometry.export(points=mesh.v_pos) 161 | mat_out = self.material.export(points=mesh.v_pos, **geo_out) 162 | 163 | if "albedo" in mat_out: 164 | mesh.set_vertex_color(mat_out["albedo"]) 165 | params["save_vertex_color"] = True 166 | else: 167 | threestudio.warn( 168 | "save_texture is True but no albedo texture found, not saving vertex color" 169 | ) 170 | 171 | return [ 172 | ExporterOutput( 173 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params 174 | ) 175 | ] 176 | -------------------------------------------------------------------------------- /threestudio/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | implicit_volume, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/geometry/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.isosurface import ( 10 | IsosurfaceHelper, 11 | MarchingCubeCPUHelper, 12 | MarchingTetrahedraHelper, 13 | ) 14 | from threestudio.models.mesh import Mesh 15 | from threestudio.utils.base import BaseModule 16 | from threestudio.utils.ops import chunk_batch, scale_tensor 17 | from threestudio.utils.typing import * 18 | 19 | 20 | def contract_to_unisphere( 21 | x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False 22 | ) -> Float[Tensor, "... 
3"]: 23 | if unbounded: 24 | x = scale_tensor(x, bbox, (0, 1)) 25 | x = x * 2 - 1 # aabb is at [-1, 1] 26 | mag = x.norm(dim=-1, keepdim=True) 27 | mask = mag.squeeze(-1) > 1 28 | x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) 29 | x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] 30 | else: 31 | x = scale_tensor(x, bbox, (0, 1)) 32 | return x 33 | 34 | 35 | class BaseGeometry(BaseModule): 36 | @dataclass 37 | class Config(BaseModule.Config): 38 | pass 39 | 40 | cfg: Config 41 | 42 | @staticmethod 43 | def create_from( 44 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs 45 | ) -> "BaseGeometry": 46 | raise TypeError( 47 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" 48 | ) 49 | 50 | def export(self, *args, **kwargs) -> Dict[str, Any]: 51 | return {} 52 | 53 | 54 | class BaseImplicitGeometry(BaseGeometry): 55 | @dataclass 56 | class Config(BaseGeometry.Config): 57 | radius: float = 1.0 58 | isosurface: bool = True 59 | isosurface_method: str = "mt" 60 | isosurface_resolution: int = 128 61 | isosurface_threshold: Union[float, str] = 0.0 62 | isosurface_chunk: int = 0 63 | isosurface_coarse_to_fine: bool = True 64 | isosurface_deformable_grid: bool = False 65 | isosurface_remove_outliers: bool = True 66 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 67 | 68 | cfg: Config 69 | 70 | def configure(self) -> None: 71 | self.bbox: Float[Tensor, "2 3"] 72 | self.register_buffer( 73 | "bbox", 74 | torch.as_tensor( 75 | [ 76 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 77 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 78 | ], 79 | dtype=torch.float32, 80 | ), 81 | ) 82 | self.isosurface_helper: Optional[IsosurfaceHelper] = None 83 | self.unbounded: bool = False 84 | 85 | def _initilize_isosurface_helper(self): 86 | if self.cfg.isosurface and self.isosurface_helper is None: 87 | if self.cfg.isosurface_method == "mc-cpu": 88 | self.isosurface_helper = MarchingCubeCPUHelper( 89 | self.cfg.isosurface_resolution 90 | ).to(self.device) 91 | elif self.cfg.isosurface_method == "mt": 92 | self.isosurface_helper = MarchingTetrahedraHelper( 93 | self.cfg.isosurface_resolution, 94 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", 95 | ).to(self.device) 96 | else: 97 | raise AttributeError( 98 | "Unknown isosurface method {self.cfg.isosurface_method}" 99 | ) 100 | 101 | def forward( 102 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 103 | ) -> Dict[str, Float[Tensor, "..."]]: 104 | raise NotImplementedError 105 | 106 | def forward_field( 107 | self, points: Float[Tensor, "*N Di"] 108 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 109 | # return the value of the implicit field, could be density / signed distance 110 | # also return a deformation field if the grid vertices can be optimized 111 | raise NotImplementedError 112 | 113 | def forward_level( 114 | self, field: Float[Tensor, "*N 1"], threshold: float 115 | ) -> Float[Tensor, "*N 1"]: 116 | # return the value of the implicit field, where the zero level set represents the surface 117 | raise NotImplementedError 118 | 119 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: 120 | def batch_func(x): 121 | # scale to bbox as the input vertices are in [0, 1] 122 | field, deformation = self.forward_field( 123 | scale_tensor( 124 | x.to(bbox.device), self.isosurface_helper.points_range, bbox 125 | ), 126 | ) 127 | field = field.to( 128 | x.device 129 | ) # move to the same device as 
the input (could be CPU) 130 | if deformation is not None: 131 | deformation = deformation.to(x.device) 132 | return field, deformation 133 | 134 | assert self.isosurface_helper is not None 135 | 136 | field, deformation = chunk_batch( 137 | batch_func, 138 | self.cfg.isosurface_chunk, 139 | self.isosurface_helper.grid_vertices, 140 | ) 141 | 142 | threshold: float 143 | 144 | if isinstance(self.cfg.isosurface_threshold, float): 145 | threshold = self.cfg.isosurface_threshold 146 | elif self.cfg.isosurface_threshold == "auto": 147 | eps = 1.0e-5 148 | threshold = field[field > eps].mean().item() 149 | threestudio.info( 150 | f"Automatically determined isosurface threshold: {threshold}" 151 | ) 152 | else: 153 | raise TypeError( 154 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" 155 | ) 156 | 157 | level = self.forward_level(field, threshold) 158 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation) 159 | mesh.v_pos = scale_tensor( 160 | mesh.v_pos, self.isosurface_helper.points_range, bbox 161 | ) # scale to bbox as the grid vertices are in [0, 1] 162 | mesh.add_extra("bbox", bbox) 163 | 164 | if self.cfg.isosurface_remove_outliers: 165 | # remove outliers components with small number of faces 166 | # only enabled when the mesh is not differentiable 167 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) 168 | 169 | return mesh 170 | 171 | def isosurface(self) -> Mesh: 172 | if not self.cfg.isosurface: 173 | raise NotImplementedError( 174 | "Isosurface is not enabled in the current configuration" 175 | ) 176 | self._initilize_isosurface_helper() 177 | if self.cfg.isosurface_coarse_to_fine: 178 | threestudio.debug("First run isosurface to get a tight bounding box ...") 179 | with torch.no_grad(): 180 | mesh_coarse = self._isosurface(self.bbox) 181 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) 182 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) 183 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) 184 | threestudio.debug("Run isosurface again with the tight bounding box ...") 185 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) 186 | else: 187 | mesh = self._isosurface(self.bbox) 188 | return mesh 189 | 190 | 191 | class BaseExplicitGeometry(BaseGeometry): 192 | @dataclass 193 | class Config(BaseGeometry.Config): 194 | radius: float = 1.0 195 | 196 | cfg: Config 197 | 198 | def configure(self) -> None: 199 | self.bbox: Float[Tensor, "2 3"] 200 | self.register_buffer( 201 | "bbox", 202 | torch.as_tensor( 203 | [ 204 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 205 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 206 | ], 207 | dtype=torch.float32, 208 | ), 209 | ) 210 | -------------------------------------------------------------------------------- /threestudio/models/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | stable_diffusion_guidance, 3 | stable_diffusion_vsd_guidance, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/isosurface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.mesh import Mesh 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class IsosurfaceHelper(nn.Module): 12 | points_range: Tuple[float, float] = (0, 1) 13 | 14 | @property 15 | def grid_vertices(self) -> Float[Tensor, "N 3"]: 16 | raise NotImplementedError 17 | 18 | 19 | class MarchingCubeCPUHelper(IsosurfaceHelper): 20 | def __init__(self, resolution: int) -> None: 21 | super().__init__() 22 | self.resolution = resolution 23 | import mcubes 24 | 25 | self.mc_func: Callable = mcubes.marching_cubes 26 | self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None 27 | self._dummy: Float[Tensor, "..."] 28 | self.register_buffer( 29 | "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False 30 | ) 31 | 32 | @property 33 | def grid_vertices(self) -> Float[Tensor, "N3 3"]: 34 | if self._grid_vertices is None: 35 | # keep the vertices on CPU so that we can support very large resolution 36 | x, y, z = ( 37 | torch.linspace(*self.points_range, self.resolution), 38 | torch.linspace(*self.points_range, self.resolution), 39 | torch.linspace(*self.points_range, self.resolution), 40 | ) 41 | x, y, z = torch.meshgrid(x, y, z, indexing="ij") 42 | verts = torch.cat( 43 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 44 | ).reshape(-1, 3) 45 | self._grid_vertices = verts 46 | return self._grid_vertices 47 | 48 | def forward( 49 | self, 50 | level: Float[Tensor, "N3 1"], 51 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 52 | ) -> Mesh: 53 | if deformation is not None: 54 | threestudio.warn( 55 | f"{self.__class__.__name__} does not support deformation. Ignoring." 
56 | ) 57 | level = -level.view(self.resolution, self.resolution, self.resolution) 58 | v_pos, t_pos_idx = self.mc_func( 59 | level.detach().cpu().numpy(), 0.0 60 | ) # transform to numpy 61 | v_pos, t_pos_idx = ( 62 | torch.from_numpy(v_pos).float().to(self._dummy.device), 63 | torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device), 64 | ) # transform back to torch tensor on CUDA 65 | v_pos = v_pos / (self.resolution - 1.0) 66 | return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) 67 | 68 | 69 | class MarchingTetrahedraHelper(IsosurfaceHelper): 70 | def __init__(self, resolution: int, tets_path: str): 71 | super().__init__() 72 | self.resolution = resolution 73 | self.tets_path = tets_path 74 | 75 | self.triangle_table: Float[Tensor, "..."] 76 | self.register_buffer( 77 | "triangle_table", 78 | torch.as_tensor( 79 | [ 80 | [-1, -1, -1, -1, -1, -1], 81 | [1, 0, 2, -1, -1, -1], 82 | [4, 0, 3, -1, -1, -1], 83 | [1, 4, 2, 1, 3, 4], 84 | [3, 1, 5, -1, -1, -1], 85 | [2, 3, 0, 2, 5, 3], 86 | [1, 4, 0, 1, 5, 4], 87 | [4, 2, 5, -1, -1, -1], 88 | [4, 5, 2, -1, -1, -1], 89 | [4, 1, 0, 4, 5, 1], 90 | [3, 2, 0, 3, 5, 2], 91 | [1, 3, 5, -1, -1, -1], 92 | [4, 1, 2, 4, 3, 1], 93 | [3, 0, 4, -1, -1, -1], 94 | [2, 0, 1, -1, -1, -1], 95 | [-1, -1, -1, -1, -1, -1], 96 | ], 97 | dtype=torch.long, 98 | ), 99 | persistent=False, 100 | ) 101 | self.num_triangles_table: Integer[Tensor, "..."] 102 | self.register_buffer( 103 | "num_triangles_table", 104 | torch.as_tensor( 105 | [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long 106 | ), 107 | persistent=False, 108 | ) 109 | self.base_tet_edges: Integer[Tensor, "..."] 110 | self.register_buffer( 111 | "base_tet_edges", 112 | torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long), 113 | persistent=False, 114 | ) 115 | 116 | tets = np.load(self.tets_path) 117 | self._grid_vertices: Float[Tensor, "..."] 118 | self.register_buffer( 119 | "_grid_vertices", 120 | torch.from_numpy(tets["vertices"]).float(), 121 | persistent=False, 122 | ) 123 | self.indices: Integer[Tensor, "..."] 124 | self.register_buffer( 125 | "indices", torch.from_numpy(tets["indices"]).long(), persistent=False 126 | ) 127 | 128 | self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None 129 | 130 | def normalize_grid_deformation( 131 | self, grid_vertex_offsets: Float[Tensor, "Nv 3"] 132 | ) -> Float[Tensor, "Nv 3"]: 133 | return ( 134 | (self.points_range[1] - self.points_range[0]) 135 | / (self.resolution) # half tet size is approximately 1 / self.resolution 136 | * torch.tanh(grid_vertex_offsets) 137 | ) # FIXME: hard-coded activation 138 | 139 | @property 140 | def grid_vertices(self) -> Float[Tensor, "Nv 3"]: 141 | return self._grid_vertices 142 | 143 | @property 144 | def all_edges(self) -> Integer[Tensor, "Ne 2"]: 145 | if self._all_edges is None: 146 | # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation) 147 | edges = torch.tensor( 148 | [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], 149 | dtype=torch.long, 150 | device=self.indices.device, 151 | ) 152 | _all_edges = self.indices[:, edges].reshape(-1, 2) 153 | _all_edges_sorted = torch.sort(_all_edges, dim=1)[0] 154 | _all_edges = torch.unique(_all_edges_sorted, dim=0) 155 | self._all_edges = _all_edges 156 | return self._all_edges 157 | 158 | def sort_edges(self, edges_ex2): 159 | with torch.no_grad(): 160 | order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() 161 | order = order.unsqueeze(dim=1) 162 | 163 | a = torch.gather(input=edges_ex2, index=order, dim=1) 164 | b = 
torch.gather(input=edges_ex2, index=1 - order, dim=1) 165 | 166 | return torch.stack([a, b], -1) 167 | 168 | def _forward(self, pos_nx3, sdf_n, tet_fx4): 169 | with torch.no_grad(): 170 | occ_n = sdf_n > 0 171 | occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) 172 | occ_sum = torch.sum(occ_fx4, -1) 173 | valid_tets = (occ_sum > 0) & (occ_sum < 4) 174 | occ_sum = occ_sum[valid_tets] 175 | 176 | # find all vertices 177 | all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) 178 | all_edges = self.sort_edges(all_edges) 179 | unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) 180 | 181 | unique_edges = unique_edges.long() 182 | mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 183 | mapping = ( 184 | torch.ones( 185 | (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device 186 | ) 187 | * -1 188 | ) 189 | mapping[mask_edges] = torch.arange( 190 | mask_edges.sum(), dtype=torch.long, device=pos_nx3.device 191 | ) 192 | idx_map = mapping[idx_map] # map edges to verts 193 | 194 | interp_v = unique_edges[mask_edges] 195 | edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) 196 | edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) 197 | edges_to_interp_sdf[:, -1] *= -1 198 | 199 | denominator = edges_to_interp_sdf.sum(1, keepdim=True) 200 | 201 | edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator 202 | verts = (edges_to_interp * edges_to_interp_sdf).sum(1) 203 | 204 | idx_map = idx_map.reshape(-1, 6) 205 | 206 | v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device)) 207 | tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) 208 | num_triangles = self.num_triangles_table[tetindex] 209 | 210 | # Generate triangle indices 211 | faces = torch.cat( 212 | ( 213 | torch.gather( 214 | input=idx_map[num_triangles == 1], 215 | dim=1, 216 | index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], 217 | ).reshape(-1, 3), 218 | torch.gather( 219 | input=idx_map[num_triangles == 2], 220 | dim=1, 221 | index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], 222 | ).reshape(-1, 3), 223 | ), 224 | dim=0, 225 | ) 226 | 227 | return verts, faces 228 | 229 | def forward( 230 | self, 231 | level: Float[Tensor, "N3 1"], 232 | deformation: Optional[Float[Tensor, "N3 3"]] = None, 233 | ) -> Mesh: 234 | if deformation is not None: 235 | grid_vertices = self.grid_vertices + self.normalize_grid_deformation( 236 | deformation 237 | ) 238 | else: 239 | grid_vertices = self.grid_vertices 240 | 241 | v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices) 242 | 243 | mesh = Mesh( 244 | v_pos=v_pos, 245 | t_pos_idx=t_pos_idx, 246 | # extras 247 | grid_vertices=grid_vertices, 248 | tet_edges=self.all_edges, 249 | grid_level=level, 250 | grid_deformation=deformation, 251 | ) 252 | 253 | return mesh 254 | -------------------------------------------------------------------------------- /threestudio/models/materials/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | no_material, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/materials/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseMaterial(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | requires_normal: bool = False 20 | requires_tangent: bool = False 21 | 22 | def configure(self): 23 | pass 24 | 25 | def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: 26 | raise NotImplementedError 27 | 28 | def export(self, *args, **kwargs) -> Dict[str, Any]: 29 | return {} 30 | -------------------------------------------------------------------------------- /threestudio/models/materials/no_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("no-material") 16 | class NoMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | input_feature_dims: Optional[int] = None 22 | mlp_network_config: Optional[dict] = None 23 | 24 | cfg: Config 25 | 26 | def configure(self) -> None: 27 | self.use_network = False 28 | if ( 29 | self.cfg.input_feature_dims is not None 30 | and self.cfg.mlp_network_config is not None 31 | ): 32 | self.network = get_mlp( 33 | self.cfg.input_feature_dims, 34 | self.cfg.n_output_dims, 35 | self.cfg.mlp_network_config, 36 | ) 37 | self.use_network = True 38 | 39 | def forward( 40 | self, features: Float[Tensor, "B ... Nf"], **kwargs 41 | ) -> Float[Tensor, "B ... Nc"]: 42 | if not self.use_network: 43 | assert ( 44 | features.shape[-1] == self.cfg.n_output_dims 45 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 
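# Note: no MLP head is configured in this branch, so the per-point geometry features are
# used directly as the color and only squashed by the configured activation (sigmoid by default).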
46 | color = get_activation(self.cfg.color_activation)(features) 47 | else: 48 | color = self.network(features.view(-1, features.shape[-1])).view( 49 | *features.shape[:-1], self.cfg.n_output_dims 50 | ) 51 | color = get_activation(self.cfg.color_activation)(color) 52 | return color 53 | 54 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 55 | color = self(features, **kwargs).clamp(0, 1) 56 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels" 57 | if color.shape[-1] > 3: 58 | threestudio.warn( 59 | "Output color has >3 channels, treating the first 3 as RGB" 60 | ) 61 | return {"albedo": color[..., :3]} 62 | -------------------------------------------------------------------------------- /threestudio/models/mesh.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.utils.ops import dot 9 | from threestudio.utils.typing import * 10 | 11 | 12 | class Mesh: 13 | def __init__( 14 | self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs 15 | ) -> None: 16 | self.v_pos: Float[Tensor, "Nv 3"] = v_pos 17 | self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx 18 | self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None 19 | self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None 20 | self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None 21 | self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None 22 | self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None 23 | self._edges: Optional[Integer[Tensor, "Ne 2"]] = None 24 | self.extras: Dict[str, Any] = {} 25 | for k, v in kwargs.items(): 26 | self.add_extra(k, v) 27 | 28 | def add_extra(self, k, v) -> None: 29 | self.extras[k] = v 30 | 31 | def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh: 32 | if self.requires_grad: 33 | threestudio.debug("Mesh is differentiable, not removing outliers") 34 | return self 35 | 36 | # use trimesh to first split the mesh into connected components 37 | # then remove the components with less than n_face_threshold faces 38 | import trimesh 39 | 40 | # construct a trimesh object 41 | mesh = trimesh.Trimesh( 42 | vertices=self.v_pos.detach().cpu().numpy(), 43 | faces=self.t_pos_idx.detach().cpu().numpy(), 44 | ) 45 | 46 | # split the mesh into connected components 47 | components = mesh.split(only_watertight=False) 48 | # log the number of faces in each component 49 | threestudio.debug( 50 | "Mesh has {} components, with faces: {}".format( 51 | len(components), [c.faces.shape[0] for c in components] 52 | ) 53 | ) 54 | 55 | n_faces_threshold: int 56 | if isinstance(outlier_n_faces_threshold, float): 57 | # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold 58 | n_faces_threshold = int( 59 | max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold 60 | ) 61 | else: 62 | # set the threshold directly to outlier_n_faces_threshold 63 | n_faces_threshold = outlier_n_faces_threshold 64 | 65 | # log the threshold 66 | threestudio.debug( 67 | "Removing components with less than {} faces".format(n_faces_threshold) 68 | ) 69 | 70 | # remove the components with less than n_face_threshold faces 71 | components = [c for c in components if c.faces.shape[0] >= n_faces_threshold] 72 | 73 | # log the number of faces in each component after removing outliers 74 | threestudio.debug( 75 | 
"Mesh has {} components after removing outliers, with faces: {}".format( 76 | len(components), [c.faces.shape[0] for c in components] 77 | ) 78 | ) 79 | # merge the components 80 | mesh = trimesh.util.concatenate(components) 81 | 82 | # convert back to our mesh format 83 | v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos) 84 | t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx) 85 | 86 | clean_mesh = Mesh(v_pos, t_pos_idx) 87 | # keep the extras unchanged 88 | 89 | if len(self.extras) > 0: 90 | clean_mesh.extras = self.extras 91 | threestudio.debug( 92 | f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" 93 | ) 94 | return clean_mesh 95 | 96 | @property 97 | def requires_grad(self): 98 | return self.v_pos.requires_grad 99 | 100 | @property 101 | def v_nrm(self): 102 | if self._v_nrm is None: 103 | self._v_nrm = self._compute_vertex_normal() 104 | return self._v_nrm 105 | 106 | @property 107 | def v_tng(self): 108 | if self._v_tng is None: 109 | self._v_tng = self._compute_vertex_tangent() 110 | return self._v_tng 111 | 112 | @property 113 | def v_tex(self): 114 | if self._v_tex is None: 115 | self._v_tex, self._t_tex_idx = self._unwrap_uv() 116 | return self._v_tex 117 | 118 | @property 119 | def t_tex_idx(self): 120 | if self._t_tex_idx is None: 121 | self._v_tex, self._t_tex_idx = self._unwrap_uv() 122 | return self._t_tex_idx 123 | 124 | @property 125 | def v_rgb(self): 126 | return self._v_rgb 127 | 128 | @property 129 | def edges(self): 130 | if self._edges is None: 131 | self._edges = self._compute_edges() 132 | return self._edges 133 | 134 | def _compute_vertex_normal(self): 135 | i0 = self.t_pos_idx[:, 0] 136 | i1 = self.t_pos_idx[:, 1] 137 | i2 = self.t_pos_idx[:, 2] 138 | 139 | v0 = self.v_pos[i0, :] 140 | v1 = self.v_pos[i1, :] 141 | v2 = self.v_pos[i2, :] 142 | 143 | face_normals = torch.cross(v1 - v0, v2 - v0) 144 | 145 | # Splat face normals to vertices 146 | v_nrm = torch.zeros_like(self.v_pos) 147 | v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) 148 | v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) 149 | v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) 150 | 151 | # Normalize, replace zero (degenerated) normals with some default value 152 | v_nrm = torch.where( 153 | dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) 154 | ) 155 | v_nrm = F.normalize(v_nrm, dim=1) 156 | 157 | if torch.is_anomaly_enabled(): 158 | assert torch.all(torch.isfinite(v_nrm)) 159 | 160 | return v_nrm 161 | 162 | def _compute_vertex_tangent(self): 163 | vn_idx = [None] * 3 164 | pos = [None] * 3 165 | tex = [None] * 3 166 | for i in range(0, 3): 167 | pos[i] = self.v_pos[self.t_pos_idx[:, i]] 168 | tex[i] = self.v_tex[self.t_tex_idx[:, i]] 169 | # t_nrm_idx is always the same as t_pos_idx 170 | vn_idx[i] = self.t_pos_idx[:, i] 171 | 172 | tangents = torch.zeros_like(self.v_nrm) 173 | tansum = torch.zeros_like(self.v_nrm) 174 | 175 | # Compute tangent space for each triangle 176 | uve1 = tex[1] - tex[0] 177 | uve2 = tex[2] - tex[0] 178 | pe1 = pos[1] - pos[0] 179 | pe2 = pos[2] - pos[0] 180 | 181 | nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] 182 | denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] 183 | 184 | # Avoid division by zero for degenerated texture coordinates 185 | tang = nom / torch.where( 186 | denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6) 187 | ) 188 | 189 | # Update all 3 vertices 190 | for i in 
range(0, 3): 191 | idx = vn_idx[i][:, None].repeat(1, 3) 192 | tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang 193 | tansum.scatter_add_( 194 | 0, idx, torch.ones_like(tang) 195 | ) # tansum[n_i] = tansum[n_i] + 1 196 | tangents = tangents / tansum 197 | 198 | # Normalize and make sure tangent is perpendicular to normal 199 | tangents = F.normalize(tangents, dim=1) 200 | tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm) 201 | 202 | if torch.is_anomaly_enabled(): 203 | assert torch.all(torch.isfinite(tangents)) 204 | 205 | return tangents 206 | 207 | def _unwrap_uv( 208 | self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} 209 | ): 210 | threestudio.info("Using xatlas to perform UV unwrapping, may take a while ...") 211 | 212 | import xatlas 213 | 214 | atlas = xatlas.Atlas() 215 | atlas.add_mesh( 216 | self.v_pos.detach().cpu().numpy(), 217 | self.t_pos_idx.cpu().numpy(), 218 | ) 219 | co = xatlas.ChartOptions() 220 | po = xatlas.PackOptions() 221 | for k, v in xatlas_chart_options.items(): 222 | setattr(co, k, v) 223 | for k, v in xatlas_pack_options.items(): 224 | setattr(po, k, v) 225 | atlas.generate(co, po) 226 | vmapping, indices, uvs = atlas.get_mesh(0) 227 | vmapping = ( 228 | torch.from_numpy( 229 | vmapping.astype(np.uint64, casting="same_kind").view(np.int64) 230 | ) 231 | .to(self.v_pos.device) 232 | .long() 233 | ) 234 | uvs = torch.from_numpy(uvs).to(self.v_pos.device).float() 235 | indices = ( 236 | torch.from_numpy( 237 | indices.astype(np.uint64, casting="same_kind").view(np.int64) 238 | ) 239 | .to(self.v_pos.device) 240 | .long() 241 | ) 242 | return uvs, indices 243 | 244 | def unwrap_uv( 245 | self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} 246 | ): 247 | self._v_tex, self._t_tex_idx = self._unwrap_uv( 248 | xatlas_chart_options, xatlas_pack_options 249 | ) 250 | 251 | def set_vertex_color(self, v_rgb): 252 | assert v_rgb.shape[0] == self.v_pos.shape[0] 253 | self._v_rgb = v_rgb 254 | 255 | def _compute_edges(self): 256 | # Compute edges 257 | edges = torch.cat( 258 | [ 259 | self.t_pos_idx[:, [0, 1]], 260 | self.t_pos_idx[:, [1, 2]], 261 | self.t_pos_idx[:, [2, 0]], 262 | ], 263 | dim=0, 264 | ) 265 | edges = edges.sort()[0] 266 | edges = torch.unique(edges, dim=0) 267 | return edges 268 | 269 | def normal_consistency(self) -> Float[Tensor, ""]: 270 | edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges] 271 | nc = ( 272 | 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1) 273 | ).mean() 274 | return nc 275 | 276 | def _laplacian_uniform(self): 277 | # from stable-dreamfusion 278 | # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224 279 | verts, faces = self.v_pos, self.t_pos_idx 280 | 281 | V = verts.shape[0] 282 | F = faces.shape[0] 283 | 284 | # Neighbor indices 285 | ii = faces[:, [1, 2, 0]].flatten() 286 | jj = faces[:, [2, 0, 1]].flatten() 287 | adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique( 288 | dim=1 289 | ) 290 | adj_values = torch.ones(adj.shape[1]).to(verts) 291 | 292 | # Diagonal indices 293 | diag_idx = adj[0] 294 | 295 | # Build the sparse matrix 296 | idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) 297 | values = torch.cat((-adj_values, adj_values)) 298 | 299 | # The coalesce operation sums the duplicate indices, resulting in the 300 | # correct diagonal 301 | return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() 
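# The coalesced (V, V) sparse matrix built above is consumed by laplacian() below,
# which scores mesh smoothness as the mean norm of L @ v_pos over all vertices.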
302 | 303 | def laplacian(self) -> Float[Tensor, ""]: 304 | with torch.no_grad(): 305 | L = self._laplacian_uniform() 306 | loss = L.mm(self.v_pos) 307 | loss = loss.norm(dim=1) 308 | loss = loss.mean() 309 | return loss 310 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | stable_diffusion_prompt_processor, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoTokenizer, CLIPTextModel 8 | 9 | import threestudio 10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 11 | from threestudio.utils.misc import cleanup 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("stable-diffusion-prompt-processor") 16 | class StableDiffusionPromptProcessor(PromptProcessor): 17 | @dataclass 18 | class Config(PromptProcessor.Config): 19 | pass 20 | 21 | cfg: Config 22 | 23 | ### these functions are unused, kept for debugging ### 24 | def configure_text_encoder(self) -> None: 25 | self.tokenizer = AutoTokenizer.from_pretrained( 26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 27 | ) 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | self.text_encoder = CLIPTextModel.from_pretrained( 30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 31 | ).to(self.device) 32 | 33 | for p in self.text_encoder.parameters(): 34 | p.requires_grad_(False) 35 | 36 | def destroy_text_encoder(self) -> None: 37 | del self.tokenizer 38 | del self.text_encoder 39 | cleanup() 40 | 41 | def get_text_embeddings( 42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 44 | if isinstance(prompt, str): 45 | prompt = [prompt] 46 | if isinstance(negative_prompt, str): 47 | negative_prompt = [negative_prompt] 48 | # Tokenize text and get embeddings 49 | tokens = self.tokenizer( 50 | prompt, 51 | padding="max_length", 52 | max_length=self.tokenizer.model_max_length, 53 | return_tensors="pt", 54 | ) 55 | uncond_tokens = self.tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | 62 | with torch.no_grad(): 63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] 64 | uncond_text_embeddings = self.text_encoder( 65 | uncond_tokens.input_ids.to(self.device) 66 | )[0] 67 | 68 | return text_embeddings, uncond_text_embeddings 69 | 70 | ### 71 | 72 | @staticmethod 73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir): 74 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | pretrained_model_name_or_path, subfolder="tokenizer" 77 | ) 78 | text_encoder = CLIPTextModel.from_pretrained( 79 | pretrained_model_name_or_path, 80 | subfolder="text_encoder", 81 | device_map="auto", 82 | ) 83 | 84 | with torch.no_grad(): 85 | tokens = tokenizer( 86 | prompts, 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | return_tensors="pt", 90 | ) 
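# All prompts are tokenized as one padded batch; the resulting embeddings are saved
# per-prompt into cache_dir below, so the text encoder can be freed once this spawned
# process finishes.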
91 | text_embeddings = text_encoder(tokens.input_ids.cuda())[0] 92 | 93 | for prompt, embedding in zip(prompts, text_embeddings): 94 | torch.save( 95 | embedding, 96 | os.path.join( 97 | cache_dir, 98 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 99 | ), 100 | ) 101 | 102 | del text_encoder 103 | -------------------------------------------------------------------------------- /threestudio/models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | nerf_volume_renderer, 4 | ) 5 | -------------------------------------------------------------------------------- /threestudio/models/renderers/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.utils.base import BaseModule 12 | from threestudio.utils.typing import * 13 | 14 | 15 | class Renderer(BaseModule): 16 | @dataclass 17 | class Config(BaseModule.Config): 18 | radius: float = 1.0 19 | 20 | cfg: Config 21 | 22 | def configure( 23 | self, 24 | geometry: BaseImplicitGeometry, 25 | material: BaseMaterial, 26 | background: BaseBackground, 27 | ) -> None: 28 | # keep references to submodules using namedtuple, avoid being registered as modules 29 | @dataclass 30 | class SubModules: 31 | geometry: BaseImplicitGeometry 32 | material: BaseMaterial 33 | background: BaseBackground 34 | 35 | self.sub_modules = SubModules(geometry, material, background) 36 | 37 | # set up bounding box 38 | self.bbox: Float[Tensor, "2 3"] 39 | self.register_buffer( 40 | "bbox", 41 | torch.as_tensor( 42 | [ 43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 45 | ], 46 | dtype=torch.float32, 47 | ), 48 | ) 49 | 50 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 51 | raise NotImplementedError 52 | 53 | @property 54 | def geometry(self) -> BaseImplicitGeometry: 55 | return self.sub_modules.geometry 56 | 57 | @property 58 | def material(self) -> BaseMaterial: 59 | return self.sub_modules.material 60 | 61 | @property 62 | def background(self) -> BaseBackground: 63 | return self.sub_modules.background 64 | 65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None: 66 | self.sub_modules.geometry = geometry 67 | 68 | def set_material(self, material: BaseMaterial) -> None: 69 | self.sub_modules.material = material 70 | 71 | def set_background(self, background: BaseBackground) -> None: 72 | self.sub_modules.background = background 73 | 74 | 75 | class VolumeRenderer(Renderer): 76 | pass 77 | 78 | 79 | class Rasterizer(Renderer): 80 | pass 81 | -------------------------------------------------------------------------------- /threestudio/models/renderers/nerf_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from 
threestudio.models.renderers.base import VolumeRenderer 12 | from threestudio.utils.ops import chunk_batch, validate_empty_rays 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("nerf-volume-renderer") 17 | class NeRFVolumeRenderer(VolumeRenderer): 18 | @dataclass 19 | class Config(VolumeRenderer.Config): 20 | num_samples_per_ray: int = 512 21 | randomized: bool = True 22 | eval_chunk_size: int = 160000 23 | grid_prune: bool = True 24 | prune_alpha_threshold: bool = True 25 | return_comp_normal: bool = False 26 | return_normal_perturb: bool = False 27 | 28 | cfg: Config 29 | 30 | def configure( 31 | self, 32 | geometry: BaseImplicitGeometry, 33 | material: BaseMaterial, 34 | background: BaseBackground, 35 | ) -> None: 36 | super().configure(geometry, material, background) 37 | self.estimator = nerfacc.OccGridEstimator( 38 | roi_aabb=self.bbox.view(-1), resolution=32, levels=1 39 | ) 40 | if not self.cfg.grid_prune: 41 | self.estimator.occs.fill_(True) 42 | self.estimator.binaries.fill_(True) 43 | self.render_step_size = ( 44 | 1.732 * 2 * self.cfg.radius / self.cfg.num_samples_per_ray 45 | ) 46 | self.randomized = self.cfg.randomized 47 | 48 | 49 | def forward( 50 | self, 51 | batch_idx: int, 52 | rays_o: Float[Tensor, "B H W 3"], 53 | rays_d: Float[Tensor, "B H W 3"], 54 | light_positions: Float[Tensor, "B 3"], 55 | bg_color: Optional[Tensor] = None, 56 | **kwargs, 57 | ) -> Dict[str, Float[Tensor, "..."]]: 58 | batch_size, height, width = rays_o.shape[:3] 59 | rays_o_flatten: Float[Tensor, "Nr 3"] = rays_o.reshape(-1, 3) 60 | rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.reshape(-1, 3) 61 | light_positions_flatten: Float[Tensor, "Nr 3"] = ( 62 | light_positions.reshape(-1, 1, 1, 3) 63 | .expand(-1, height, width, -1) 64 | .reshape(-1, 3) 65 | ) 66 | n_rays = rays_o_flatten.shape[0] 67 | 68 | def sigma_fn(t_starts, t_ends, ray_indices, idx=batch_idx): 69 | t_starts, t_ends = t_starts[..., None], t_ends[..., None] 70 | t_origins = rays_o_flatten[ray_indices] 71 | t_positions = (t_starts + t_ends) / 2.0 72 | t_dirs = rays_d_flatten[ray_indices] 73 | positions = t_origins + t_dirs * t_positions 74 | if self.training: 75 | sigma = self.geometry.forward_density(positions, idx=batch_idx)[..., 0] 76 | else: 77 | sigma = chunk_batch( 78 | self.geometry.forward_density, 79 | self.cfg.eval_chunk_size, 80 | positions, 81 | )[..., 0] 82 | return sigma 83 | 84 | if not self.cfg.grid_prune: 85 | with torch.no_grad(): 86 | ray_indices, t_starts_, t_ends_ = self.estimator.sampling( 87 | rays_o_flatten, 88 | rays_d_flatten, 89 | sigma_fn=None, 90 | render_step_size=self.render_step_size, 91 | alpha_thre=0.0, 92 | stratified=self.randomized, 93 | cone_angle=0.0, 94 | early_stop_eps=0, 95 | ) 96 | else: 97 | with torch.no_grad(): 98 | ray_indices, t_starts_, t_ends_ = self.estimator.sampling( 99 | rays_o_flatten, 100 | rays_d_flatten, 101 | sigma_fn=sigma_fn if self.cfg.prune_alpha_threshold else None, 102 | render_step_size=self.render_step_size, 103 | alpha_thre=0.01 if self.cfg.prune_alpha_threshold else 0.0, 104 | stratified=self.randomized, 105 | cone_angle=0.0, 106 | ) 107 | ray_indices, t_starts_, t_ends_ = validate_empty_rays( 108 | ray_indices, t_starts_, t_ends_ 109 | ) 110 | ray_indices = ray_indices.long() 111 | t_starts, t_ends = t_starts_[..., None], t_ends_[..., None] 112 | t_origins = rays_o_flatten[ray_indices] 113 | t_dirs = rays_d_flatten[ray_indices] 114 | t_light_positions = light_positions_flatten[ray_indices] 115 | t_positions = (t_starts + t_ends) / 
2.0 116 | positions = t_origins + t_dirs * t_positions 117 | t_intervals = t_ends - t_starts 118 | 119 | if self.training: 120 | geo_out = self.geometry( 121 | positions, output_normal=True, idx=batch_idx 122 | ) 123 | rgb_fg_all = self.material( 124 | viewdirs=t_dirs, 125 | positions=positions, 126 | light_positions=t_light_positions, 127 | **geo_out, 128 | **kwargs 129 | ) 130 | comp_rgb_bg = self.background(dirs=rays_d) 131 | 132 | else: 133 | geo_out = chunk_batch( 134 | self.geometry, 135 | self.cfg.eval_chunk_size, 136 | positions, 137 | output_normal=True, 138 | idx=batch_idx, 139 | ) 140 | rgb_fg_all = chunk_batch( 141 | self.material, 142 | self.cfg.eval_chunk_size, 143 | viewdirs=t_dirs, 144 | positions=positions, 145 | light_positions=t_light_positions, 146 | **geo_out 147 | ) 148 | comp_rgb_bg = chunk_batch( 149 | self.background, self.cfg.eval_chunk_size, dirs=rays_d 150 | ) 151 | 152 | weights: Float[Tensor, "Nr 1"] 153 | weights_, _, _ = nerfacc.render_weight_from_density( 154 | t_starts[..., 0], 155 | t_ends[..., 0], 156 | geo_out["density"][..., 0], 157 | ray_indices=ray_indices, 158 | n_rays=n_rays, 159 | ) 160 | weights = weights_[..., None] 161 | opacity: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( 162 | weights[..., 0], values=None, ray_indices=ray_indices, n_rays=n_rays 163 | ) 164 | depth: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( 165 | weights[..., 0], values=t_positions, ray_indices=ray_indices, n_rays=n_rays 166 | ) 167 | comp_rgb_fg: Float[Tensor, "Nr Nc"] = nerfacc.accumulate_along_rays( 168 | weights[..., 0], values=rgb_fg_all, ray_indices=ray_indices, n_rays=n_rays 169 | ) 170 | 171 | # populate depth and opacity to each point 172 | t_depth = depth[ray_indices] 173 | z_variance = nerfacc.accumulate_along_rays( 174 | weights[..., 0], 175 | values=(t_positions - t_depth) ** 2, 176 | ray_indices=ray_indices, 177 | n_rays=n_rays, 178 | ) 179 | 180 | if bg_color is None: 181 | bg_color = comp_rgb_bg 182 | else: 183 | if bg_color.shape[:-1] == (batch_size,): 184 | # e.g. 
constant random color used for Zero123 185 | # [bs,3] -> [bs, 1, 1, 3]): 186 | bg_color = bg_color.unsqueeze(1).unsqueeze(1) 187 | # -> [bs, height, width, 3]): 188 | bg_color = bg_color.expand(-1, height, width, -1) 189 | 190 | if bg_color.shape[:-1] == (batch_size, height, width): 191 | bg_color = bg_color.reshape(batch_size * height * width, -1) 192 | 193 | comp_rgb = comp_rgb_fg + bg_color * (1.0 - opacity) 194 | 195 | out = { 196 | "comp_rgb": comp_rgb.view(batch_size, height, width, -1), 197 | "comp_rgb_fg": comp_rgb_fg.view(batch_size, height, width, -1), 198 | "comp_rgb_bg": comp_rgb_bg.view(batch_size, height, width, -1), 199 | "opacity": opacity.view(batch_size, height, width, 1), 200 | "depth": depth.view(batch_size, height, width, 1), 201 | "z_variance": z_variance.view(batch_size, height, width, 1), 202 | } 203 | 204 | if self.training: 205 | out.update( 206 | { 207 | "weights": weights, 208 | "t_points": t_positions, 209 | "t_intervals": t_intervals, 210 | "t_dirs": t_dirs, 211 | "ray_indices": ray_indices, 212 | "points": positions, 213 | **geo_out, 214 | } 215 | ) 216 | if "normal" in geo_out: 217 | if self.cfg.return_comp_normal: 218 | comp_normal: Float[Tensor, "Nr 3"] = nerfacc.accumulate_along_rays( 219 | weights[..., 0], 220 | values=geo_out["normal"], 221 | ray_indices=ray_indices, 222 | n_rays=n_rays, 223 | ) 224 | comp_normal = F.normalize(comp_normal, dim=-1) 225 | comp_normal = ( 226 | (comp_normal + 1.0) / 2.0 * opacity 227 | ) # for visualization 228 | out.update( 229 | { 230 | "comp_normal": comp_normal.view( 231 | batch_size, height, width, 3 232 | ), 233 | } 234 | ) 235 | if self.cfg.return_normal_perturb: 236 | normal_perturb = self.geometry( 237 | positions + torch.randn_like(positions) * 1e-2, 238 | output_normal=True, 239 | )["normal"] 240 | out.update({"normal_perturb": normal_perturb}) 241 | else: 242 | if "normal" in geo_out: 243 | comp_normal = nerfacc.accumulate_along_rays( 244 | weights[..., 0], 245 | values=geo_out["normal"], 246 | ray_indices=ray_indices, 247 | n_rays=n_rays, 248 | ) 249 | comp_normal = F.normalize(comp_normal, dim=-1) 250 | comp_normal = (comp_normal + 1.0) / 2.0 * opacity # for visualization 251 | out.update( 252 | { 253 | "comp_normal": comp_normal.view(batch_size, height, width, 3), 254 | } 255 | ) 256 | 257 | 258 | return out 259 | 260 | def update_step( 261 | self, epoch: int, global_step: int, on_load_weights: bool = False 262 | ) -> None: 263 | 264 | # print(global_step) 265 | if self.cfg.grid_prune: 266 | def occ_eval_fn(x): 267 | density = self.geometry.forward_density(x) 268 | # approximate for 1 - torch.exp(-density * self.render_step_size) based on taylor series 269 | return density * self.render_step_size 270 | 271 | if self.training and not on_load_weights: 272 | self.estimator.update_every_n_steps( 273 | step=global_step, occ_eval_fn=occ_eval_fn 274 | ) 275 | def train(self, mode=True): 276 | self.randomized = mode and self.cfg.randomized 277 | return super().train(mode=mode) 278 | 279 | def eval(self): 280 | self.randomized = False 281 | return super().eval() 282 | -------------------------------------------------------------------------------- /threestudio/systems/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | dreamavatar, 3 | ) 4 | -------------------------------------------------------------------------------- /threestudio/systems/dreamavatar.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | 6 | import threestudio 7 | from threestudio.systems.base import BaseLift3DSystem 8 | from threestudio.utils.misc import cleanup, get_device 9 | from threestudio.utils.ops import binary_cross_entropy, dot 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("dreamavatar-system") 14 | class DreamAvatar(BaseLift3DSystem): 15 | @dataclass 16 | class Config(BaseLift3DSystem.Config): 17 | # in ['coarse', 'geometry', 'texture'] 18 | stage: str = "coarse" 19 | visualize_samples: bool = False 20 | 21 | cfg: Config 22 | 23 | def configure(self) -> None: 24 | # set up geometry, material, background, renderer 25 | super().configure() 26 | 27 | self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance) 28 | self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)( 29 | self.cfg.prompt_processor 30 | ) 31 | self.prompt_utils = self.prompt_processor() 32 | 33 | def forward(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]: 34 | 35 | render_out = self.renderer(batch_idx, **batch) 36 | return { 37 | **render_out, 38 | } 39 | 40 | def on_fit_start(self) -> None: 41 | super().on_fit_start() 42 | 43 | def training_step(self, batch, batch_idx): 44 | out = self(batch, batch_idx) 45 | guidance_inp = out["comp_rgb"] 46 | guidance_out = self.guidance( 47 | guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False 48 | ) 49 | 50 | loss = 0.0 51 | 52 | for name, value in guidance_out.items(): 53 | self.log(f"train/{name}", value) 54 | if name.startswith("loss_"): 55 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) 56 | 57 | if self.C(self.cfg.loss.lambda_orient) > 0: 58 | if "normal" not in out: 59 | raise ValueError( 60 | "Normal is required for orientation loss, no normal is found in the output." 
61 | ) 62 | loss_orient = ( 63 | out["weights"].detach() 64 | * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 65 | ).sum() / (out["opacity"] > 0).sum() 66 | self.log("train/loss_orient", loss_orient) 67 | loss += loss_orient * self.C(self.cfg.loss.lambda_orient) 68 | 69 | loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() 70 | self.log("train/loss_sparsity", loss_sparsity) 71 | loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) 72 | 73 | opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) 74 | loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) 75 | self.log("train/loss_opaque", loss_opaque) 76 | loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) 77 | 78 | # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766 79 | # helps reduce floaters and produce solid geometry 80 | loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() 81 | self.log("train/loss_z_variance", loss_z_variance) 82 | loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) 83 | 84 | for name, value in self.cfg.loss.items(): 85 | self.log(f"train_params/{name}", self.C(value)) 86 | 87 | return {"loss": loss} 88 | 89 | def validation_step(self, batch, batch_idx): 90 | out = self(batch, batch_idx) 91 | self.save_image_grid( 92 | f"it{self.true_global_step}-{batch['index'][0]}.png", 93 | ( 94 | [ 95 | { 96 | "type": "rgb", 97 | "img": out["comp_rgb"][0], 98 | "kwargs": {"data_format": "HWC"}, 99 | }, 100 | ] 101 | if "comp_rgb" in out 102 | else [] 103 | ) 104 | + ( 105 | [ 106 | { 107 | "type": "rgb", 108 | "img": out["comp_normal"][0], 109 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, 110 | } 111 | ] 112 | if "comp_normal" in out 113 | else [] 114 | ) 115 | + [ 116 | { 117 | "type": "grayscale", 118 | "img": out["opacity"][0, :, :, 0], 119 | "kwargs": {"cmap": None, "data_range": (0, 1)}, 120 | }, 121 | ], 122 | name="validation_step", 123 | step=self.true_global_step, 124 | ) 125 | 126 | if self.cfg.visualize_samples: 127 | self.save_image_grid( 128 | f"it{self.true_global_step}-{batch['index'][0]}-sample.png", 129 | [ 130 | { 131 | "type": "rgb", 132 | "img": self.guidance.sample( 133 | self.prompt_utils, **batch, seed=self.global_step 134 | )[0], 135 | "kwargs": {"data_format": "HWC"}, 136 | }, 137 | { 138 | "type": "rgb", 139 | "img": self.guidance.sample_lora(self.prompt_utils, **batch)[0], 140 | "kwargs": {"data_format": "HWC"}, 141 | }, 142 | ], 143 | name="validation_step_samples", 144 | step=self.true_global_step, 145 | ) 146 | 147 | def on_validation_epoch_end(self): 148 | pass 149 | 150 | def test_step(self, batch, batch_idx): 151 | out = self(batch, batch_idx) 152 | self.save_image_grid( 153 | f"it{self.true_global_step}-test/{batch['index'][0]}.png", 154 | ( 155 | [ 156 | { 157 | "type": "rgb", 158 | "img": out["comp_rgb"][0], 159 | "kwargs": {"data_format": "HWC"}, 160 | }, 161 | ] 162 | if "comp_rgb" in out 163 | else [] 164 | ), 165 | name="test_step", 166 | step=self.true_global_step, 167 | ) 168 | 169 | def on_test_epoch_end(self): 170 | self.save_img_sequence( 171 | f"it{self.true_global_step}-test", 172 | f"it{self.true_global_step}-test", 173 | "(\d+)\.png", 174 | save_format="mp4", 175 | fps=30, 176 | name="test", 177 | step=self.true_global_step, 178 | ) 179 | -------------------------------------------------------------------------------- /threestudio/systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect 
import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import threestudio 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | 33 | 34 | def parse_optimizer(config, model): 35 | if hasattr(config, "params"): 36 | params = [ 37 | {"params": get_parameters(model, name), "name": name, **args} 38 | for name, args in config.params.items() 39 | ] 40 | threestudio.debug(f"Specify optimizer params: {config.params}") 41 | else: 42 | params = model.parameters() 43 | if config.name in ["FusedAdam"]: 44 | import apex 45 | 46 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 47 | elif config.name in ["Adan"]: 48 | from threestudio.systems import optimizers 49 | 50 | optim = getattr(optimizers, config.name)(params, **config.args) 51 | else: 52 | optim = getattr(torch.optim, config.name)(params, **config.args) 53 | return optim 54 | 55 | 56 | def parse_scheduler(config, optimizer): 57 | interval = config.get("interval", "epoch") 58 | assert interval in ["epoch", "step"] 59 | if config.name == "SequentialLR": 60 | scheduler = { 61 | "scheduler": lr_scheduler.SequentialLR( 62 | optimizer, 63 | [ 64 | parse_scheduler(conf, optimizer)["scheduler"] 65 | for conf in config.schedulers 66 | ], 67 | milestones=config.milestones, 68 | ), 69 | "interval": interval, 70 | } 71 | elif config.name == "ChainedScheduler": 72 | scheduler = { 73 | "scheduler": lr_scheduler.ChainedScheduler( 74 | [ 75 | parse_scheduler(conf, optimizer)["scheduler"] 76 | for conf in config.schedulers 77 | ] 78 | ), 79 | "interval": interval, 80 | } 81 | else: 82 | scheduler = { 83 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 84 | "interval": interval, 85 | } 86 | return scheduler 87 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from torch import einsum, nn 8 | 9 | from threestudio.utils.GAN.network_util import checkpoint 10 | 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | 16 | def uniq(arr): 17 | return {el: True for el in arr}.keys() 18 | 19 | 20 | def default(val, d): 21 | if exists(val): 22 | return val 23 | return d() if isfunction(d) else d 24 | 25 | 26 | def max_neg_value(t): 27 | return -torch.finfo(t.dtype).max 28 | 29 | 30 | def init_(tensor): 31 | dim = tensor.shape[-1] 32 | std = 1 / math.sqrt(dim) 33 | tensor.uniform_(-std, std) 34 | return tensor 35 | 36 | 37 | # feedforward 38 | class GEGLU(nn.Module): 39 | def __init__(self, dim_in, dim_out): 40 | super().__init__() 41 | self.proj = nn.Linear(dim_in, dim_out * 2) 42 | 43 | def forward(self, x): 44 | x, gate = self.proj(x).chunk(2, dim=-1) 45 | return x * F.gelu(gate) 46 | 47 | 48 | class FeedForward(nn.Module): 49 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 50 
| super().__init__() 51 | inner_dim = int(dim * mult) 52 | dim_out = default(dim_out, dim) 53 | project_in = ( 54 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 55 | if not glu 56 | else GEGLU(dim, inner_dim) 57 | ) 58 | 59 | self.net = nn.Sequential( 60 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm( 78 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 79 | ) 80 | 81 | 82 | class LinearAttention(nn.Module): 83 | def __init__(self, dim, heads=4, dim_head=32): 84 | super().__init__() 85 | self.heads = heads 86 | hidden_dim = dim_head * heads 87 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 88 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 89 | 90 | def forward(self, x): 91 | b, c, h, w = x.shape 92 | qkv = self.to_qkv(x) 93 | q, k, v = rearrange( 94 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 95 | ) 96 | k = k.softmax(dim=-1) 97 | context = torch.einsum("bhdn,bhen->bhde", k, v) 98 | out = torch.einsum("bhde,bhdn->bhen", context, q) 99 | out = rearrange( 100 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 101 | ) 102 | return self.to_out(out) 103 | 104 | 105 | class SpatialSelfAttention(nn.Module): 106 | def __init__(self, in_channels): 107 | super().__init__() 108 | self.in_channels = in_channels 109 | 110 | self.norm = Normalize(in_channels) 111 | self.q = torch.nn.Conv2d( 112 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 113 | ) 114 | self.k = torch.nn.Conv2d( 115 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 116 | ) 117 | self.v = torch.nn.Conv2d( 118 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 119 | ) 120 | self.proj_out = torch.nn.Conv2d( 121 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 122 | ) 123 | 124 | def forward(self, x): 125 | h_ = x 126 | h_ = self.norm(h_) 127 | q = self.q(h_) 128 | k = self.k(h_) 129 | v = self.v(h_) 130 | 131 | # compute attention 132 | b, c, h, w = q.shape 133 | q = rearrange(q, "b c h w -> b (h w) c") 134 | k = rearrange(k, "b c h w -> b c (h w)") 135 | w_ = torch.einsum("bij,bjk->bik", q, k) 136 | 137 | w_ = w_ * (int(c) ** (-0.5)) 138 | w_ = torch.nn.functional.softmax(w_, dim=2) 139 | 140 | # attend to values 141 | v = rearrange(v, "b c h w -> b c (h w)") 142 | w_ = rearrange(w_, "b i j -> b j i") 143 | h_ = torch.einsum("bij,bjk->bik", v, w_) 144 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 145 | h_ = self.proj_out(h_) 146 | 147 | return x + h_ 148 | 149 | 150 | class CrossAttention(nn.Module): 151 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 152 | super().__init__() 153 | inner_dim = dim_head * heads 154 | context_dim = default(context_dim, query_dim) 155 | 156 | self.scale = dim_head**-0.5 157 | self.heads = heads 158 | 159 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 160 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 161 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 162 | 163 | self.to_out = nn.Sequential( 164 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 165 | ) 166 | 167 | def forward(self, x, context=None, mask=None): 168 | h = self.heads 169 | 170 | 
q = self.to_q(x) 171 | context = default(context, x) 172 | k = self.to_k(context) 173 | v = self.to_v(context) 174 | 175 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 176 | 177 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 178 | 179 | if exists(mask): 180 | mask = rearrange(mask, "b ... -> b (...)") 181 | max_neg_value = -torch.finfo(sim.dtype).max 182 | mask = repeat(mask, "b j -> (b h) () j", h=h) 183 | sim.masked_fill_(~mask, max_neg_value) 184 | 185 | # attention, what we cannot get enough of 186 | attn = sim.softmax(dim=-1) 187 | 188 | out = einsum("b i j, b j d -> b i d", attn, v) 189 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 190 | return self.to_out(out) 191 | 192 | 193 | class BasicTransformerBlock(nn.Module): 194 | def __init__( 195 | self, 196 | dim, 197 | n_heads, 198 | d_head, 199 | dropout=0.0, 200 | context_dim=None, 201 | gated_ff=True, 202 | checkpoint=True, 203 | ): 204 | super().__init__() 205 | self.attn1 = CrossAttention( 206 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 207 | ) # is a self-attention 208 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 209 | self.attn2 = CrossAttention( 210 | query_dim=dim, 211 | context_dim=context_dim, 212 | heads=n_heads, 213 | dim_head=d_head, 214 | dropout=dropout, 215 | ) # is self-attn if context is none 216 | self.norm1 = nn.LayerNorm(dim) 217 | self.norm2 = nn.LayerNorm(dim) 218 | self.norm3 = nn.LayerNorm(dim) 219 | self.checkpoint = checkpoint 220 | 221 | def forward(self, x, context=None): 222 | return checkpoint( 223 | self._forward, (x, context), self.parameters(), self.checkpoint 224 | ) 225 | 226 | def _forward(self, x, context=None): 227 | x = self.attn1(self.norm1(x)) + x 228 | x = self.attn2(self.norm2(x), context=context) + x 229 | x = self.ff(self.norm3(x)) + x 230 | return x 231 | 232 | 233 | class SpatialTransformer(nn.Module): 234 | """ 235 | Transformer block for image-like data. 236 | First, project the input (aka embedding) 237 | and reshape to b, t, d. 238 | Then apply standard transformer action. 
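# Annotation (not part of attention.py): a hypothetical usage sketch of the
# CrossAttention module defined above, assuming the class and its helpers
# (default, exists, rearrange, einsum) are in scope. All sizes are illustrative.
import torch

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)         # queries: batch 2, 64 spatial tokens
context = torch.randn(2, 77, 768)   # e.g. a CLIP text sequence as conditioning
out = attn(x, context=context)      # -> torch.Size([2, 64, 320]), same shape as x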
239 | Finally, reshape to image 240 | """ 241 | 242 | def __init__( 243 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 244 | ): 245 | super().__init__() 246 | self.in_channels = in_channels 247 | inner_dim = n_heads * d_head 248 | self.norm = Normalize(in_channels) 249 | 250 | self.proj_in = nn.Conv2d( 251 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 252 | ) 253 | 254 | self.transformer_blocks = nn.ModuleList( 255 | [ 256 | BasicTransformerBlock( 257 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 258 | ) 259 | for d in range(depth) 260 | ] 261 | ) 262 | 263 | self.proj_out = zero_module( 264 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 265 | ) 266 | 267 | def forward(self, x, context=None): 268 | # note: if no context is given, cross-attention defaults to self-attention 269 | b, c, h, w = x.shape 270 | x_in = x 271 | x = self.norm(x) 272 | x = self.proj_in(x) 273 | x = rearrange(x, "b c h w -> b (h w) c") 274 | for block in self.transformer_blocks: 275 | x = block(x, context=context) 276 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 277 | x = self.proj_out(x) 278 | return x + x_in 279 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def count_params(model): 8 | total_params = sum(p.numel() for p in model.parameters()) 9 | return total_params 10 | 11 | 12 | class ActNorm(nn.Module): 13 | def __init__( 14 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 15 | ): 16 | assert affine 17 | super().__init__() 18 | self.logdet = logdet 19 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 20 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 21 | self.allow_reverse_init = allow_reverse_init 22 | 23 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 24 | 25 | def initialize(self, input): 26 | with torch.no_grad(): 27 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 28 | mean = ( 29 | flatten.mean(1) 30 | .unsqueeze(1) 31 | .unsqueeze(2) 32 | .unsqueeze(3) 33 | .permute(1, 0, 2, 3) 34 | ) 35 | std = ( 36 | flatten.std(1) 37 | .unsqueeze(1) 38 | .unsqueeze(2) 39 | .unsqueeze(3) 40 | .permute(1, 0, 2, 3) 41 | ) 42 | 43 | self.loc.data.copy_(-mean) 44 | self.scale.data.copy_(1 / (std + 1e-6)) 45 | 46 | def forward(self, input, reverse=False): 47 | if reverse: 48 | return self.reverse(input) 49 | if len(input.shape) == 2: 50 | input = input[:, :, None, None] 51 | squeeze = True 52 | else: 53 | squeeze = False 54 | 55 | _, _, height, width = input.shape 56 | 57 | if self.training and self.initialized.item() == 0: 58 | self.initialize(input) 59 | self.initialized.fill_(1) 60 | 61 | h = self.scale * (input + self.loc) 62 | 63 | if squeeze: 64 | h = h.squeeze(-1).squeeze(-1) 65 | 66 | if self.logdet: 67 | log_abs = torch.log(torch.abs(self.scale)) 68 | logdet = height * width * torch.sum(log_abs) 69 | logdet = logdet * torch.ones(input.shape[0]).to(input) 70 | return h, logdet 71 | 72 | return h 73 | 74 | def reverse(self, output): 75 | if self.training and self.initialized.item() == 0: 76 | if not self.allow_reverse_init: 77 | raise RuntimeError( 78 | "Initializing ActNorm in reverse direction is " 79 | "disabled by default. Use allow_reverse_init=True to enable." 
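# Annotation (not part of discriminator.py): a hypothetical sketch of the ActNorm
# layer defined above. On its first training-mode forward pass it performs
# data-dependent initialization, setting loc = -mean and scale = 1/std from the
# batch statistics; the sizes below are illustrative.
import torch

norm = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
norm.train()
y = norm(x)                                # first call triggers initialize()
print(y.mean().item(), y.std().item())     # approximately 0 and 1 after initialization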
80 | ) 81 | else: 82 | self.initialize(output) 83 | self.initialized.fill_(1) 84 | 85 | if len(output.shape) == 2: 86 | output = output[:, :, None, None] 87 | squeeze = True 88 | else: 89 | squeeze = False 90 | 91 | h = output / self.scale - self.loc 92 | 93 | if squeeze: 94 | h = h.squeeze(-1).squeeze(-1) 95 | return h 96 | 97 | 98 | class AbstractEncoder(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | 102 | def encode(self, *args, **kwargs): 103 | raise NotImplementedError 104 | 105 | 106 | class Labelator(AbstractEncoder): 107 | """Net2Net Interface for Class-Conditional Model""" 108 | 109 | def __init__(self, n_classes, quantize_interface=True): 110 | super().__init__() 111 | self.n_classes = n_classes 112 | self.quantize_interface = quantize_interface 113 | 114 | def encode(self, c): 115 | c = c[:, None] 116 | if self.quantize_interface: 117 | return c, None, [None, None, c.long()] 118 | return c 119 | 120 | 121 | class SOSProvider(AbstractEncoder): 122 | # for unconditional training 123 | def __init__(self, sos_token, quantize_interface=True): 124 | super().__init__() 125 | self.sos_token = sos_token 126 | self.quantize_interface = quantize_interface 127 | 128 | def encode(self, x): 129 | # get batch size from data and replicate sos_token 130 | c = torch.ones(x.shape[0], 1) * self.sos_token 131 | c = c.long().to(x.device) 132 | if self.quantize_interface: 133 | return c, None, [None, None, c] 134 | return c 135 | 136 | 137 | def weights_init(m): 138 | classname = m.__class__.__name__ 139 | if classname.find("Conv") != -1: 140 | nn.init.normal_(m.weight.data, 0.0, 0.02) 141 | elif classname.find("BatchNorm") != -1: 142 | nn.init.normal_(m.weight.data, 1.0, 0.02) 143 | nn.init.constant_(m.bias.data, 0) 144 | 145 | 146 | class NLayerDiscriminator(nn.Module): 147 | """Defines a PatchGAN discriminator as in Pix2Pix 148 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 149 | """ 150 | 151 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 152 | """Construct a PatchGAN discriminator 153 | Parameters: 154 | input_nc (int) -- the number of channels in input images 155 | ndf (int) -- the number of filters in the last conv layer 156 | n_layers (int) -- the number of conv layers in the discriminator 157 | norm_layer -- normalization layer 158 | """ 159 | super(NLayerDiscriminator, self).__init__() 160 | if not use_actnorm: 161 | norm_layer = nn.BatchNorm2d 162 | else: 163 | norm_layer = ActNorm 164 | if ( 165 | type(norm_layer) == functools.partial 166 | ): # no need to use bias as BatchNorm2d has affine parameters 167 | use_bias = norm_layer.func != nn.BatchNorm2d 168 | else: 169 | use_bias = norm_layer != nn.BatchNorm2d 170 | 171 | kw = 4 172 | padw = 1 173 | sequence = [ 174 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 175 | nn.LeakyReLU(0.2, True), 176 | ] 177 | nf_mult = 1 178 | nf_mult_prev = 1 179 | for n in range(1, n_layers): # gradually increase the number of filters 180 | nf_mult_prev = nf_mult 181 | nf_mult = min(2**n, 8) 182 | sequence += [ 183 | nn.Conv2d( 184 | ndf * nf_mult_prev, 185 | ndf * nf_mult, 186 | kernel_size=kw, 187 | stride=2, 188 | padding=padw, 189 | bias=use_bias, 190 | ), 191 | norm_layer(ndf * nf_mult), 192 | nn.LeakyReLU(0.2, True), 193 | ] 194 | 195 | nf_mult_prev = nf_mult 196 | nf_mult = min(2**n_layers, 8) 197 | sequence += [ 198 | nn.Conv2d( 199 | ndf * nf_mult_prev, 200 | ndf * nf_mult, 201 | kernel_size=kw, 202 | stride=1, 203 | padding=padw, 204 
| bias=use_bias, 205 | ), 206 | norm_layer(ndf * nf_mult), 207 | nn.LeakyReLU(0.2, True), 208 | ] 209 | 210 | sequence += [ 211 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 212 | ] # output 1 channel prediction map 213 | self.main = nn.Sequential(*sequence) 214 | 215 | def forward(self, input): 216 | """Standard forward.""" 217 | return self.main(input) 218 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
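# Annotation (not part of distribution.py): a hypothetical usage sketch of the
# DiagonalGaussianDistribution defined above. Its parameters tensor carries the
# mean and log-variance concatenated along dim=1; the sizes are illustrative.
import torch

params = torch.randn(4, 8, 16, 16)          # 4 mean channels + 4 logvar channels
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                            # [4, 4, 16, 16] reparameterized sample
kl = dist.kl()                               # [4] KL(q || N(0, I)), summed over C, H, W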
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def generator_loss(discriminator, inputs, reconstructions, cond=None): 6 | if cond is None: 7 | logits_fake = discriminator(reconstructions.contiguous()) 8 | else: 9 | logits_fake = discriminator( 10 | torch.cat((reconstructions.contiguous(), cond), dim=1) 11 | ) 12 | g_loss = -torch.mean(logits_fake) 13 | return g_loss 14 | 15 | 16 | def hinge_d_loss(logits_real, logits_fake): 17 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 18 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 19 | d_loss = 0.5 * (loss_real + loss_fake) 20 | return d_loss 21 | 22 | 23 | def discriminator_loss(discriminator, inputs, reconstructions, cond=None): 24 | if cond is None: 25 | logits_real = discriminator(inputs.contiguous().detach()) 26 | logits_fake = discriminator(reconstructions.contiguous().detach()) 27 | else: 28 | logits_real = discriminator( 29 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 30 | ) 31 | logits_fake = discriminator( 32 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 33 | ) 34 | d_loss = hinge_d_loss(logits_real, logits_fake).mean() 35 | return d_loss 36 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ["MobileNetV3", "mobilenetv3"] 6 | 7 | 8 | def conv_bn( 9 | inp, 10 | oup, 11 | stride, 12 | conv_layer=nn.Conv2d, 13 | norm_layer=nn.BatchNorm2d, 14 | nlin_layer=nn.ReLU, 15 | ): 16 | return nn.Sequential( 17 | conv_layer(inp, oup, 3, stride, 1, bias=False), 18 | norm_layer(oup), 19 | nlin_layer(inplace=True), 20 | ) 21 | 22 | 23 | def conv_1x1_bn( 24 | inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU 25 | ): 26 | return nn.Sequential( 27 | conv_layer(inp, oup, 1, 1, 0, bias=False), 28 | norm_layer(oup), 29 | nlin_layer(inplace=True), 30 | ) 31 | 32 | 33 | class Hswish(nn.Module): 34 | def __init__(self, inplace=True): 35 | super(Hswish, self).__init__() 36 | self.inplace = inplace 37 | 38 | def forward(self, x): 39 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 40 | 41 | 42 | class Hsigmoid(nn.Module): 43 | def __init__(self, inplace=True): 44 | super(Hsigmoid, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return F.relu6(x + 3.0, inplace=self.inplace) / 6.0 49 | 50 | 51 | class SEModule(nn.Module): 52 | def __init__(self, channel, reduction=4): 53 | super(SEModule, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, channel // reduction, bias=False), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(channel // reduction, channel, bias=False), 59 | Hsigmoid() 60 | # nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1, 1) 67 | return x 
* y.expand_as(x) 68 | 69 | 70 | class Identity(nn.Module): 71 | def __init__(self, channel): 72 | super(Identity, self).__init__() 73 | 74 | def forward(self, x): 75 | return x 76 | 77 | 78 | def make_divisible(x, divisible_by=8): 79 | import numpy as np 80 | 81 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) 82 | 83 | 84 | class MobileBottleneck(nn.Module): 85 | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl="RE"): 86 | super(MobileBottleneck, self).__init__() 87 | assert stride in [1, 2] 88 | assert kernel in [3, 5] 89 | padding = (kernel - 1) // 2 90 | self.use_res_connect = stride == 1 and inp == oup 91 | 92 | conv_layer = nn.Conv2d 93 | norm_layer = nn.BatchNorm2d 94 | if nl == "RE": 95 | nlin_layer = nn.ReLU # or ReLU6 96 | elif nl == "HS": 97 | nlin_layer = Hswish 98 | else: 99 | raise NotImplementedError 100 | if se: 101 | SELayer = SEModule 102 | else: 103 | SELayer = Identity 104 | 105 | self.conv = nn.Sequential( 106 | # pw 107 | conv_layer(inp, exp, 1, 1, 0, bias=False), 108 | norm_layer(exp), 109 | nlin_layer(inplace=True), 110 | # dw 111 | conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), 112 | norm_layer(exp), 113 | SELayer(exp), 114 | nlin_layer(inplace=True), 115 | # pw-linear 116 | conv_layer(exp, oup, 1, 1, 0, bias=False), 117 | norm_layer(oup), 118 | ) 119 | 120 | def forward(self, x): 121 | if self.use_res_connect: 122 | return x + self.conv(x) 123 | else: 124 | return self.conv(x) 125 | 126 | 127 | class MobileNetV3(nn.Module): 128 | def __init__( 129 | self, n_class=1000, input_size=224, dropout=0.0, mode="small", width_mult=1.0 130 | ): 131 | super(MobileNetV3, self).__init__() 132 | input_channel = 16 133 | last_channel = 1280 134 | if mode == "large": 135 | # refer to Table 1 in paper 136 | mobile_setting = [ 137 | # k, exp, c, se, nl, s, 138 | [3, 16, 16, False, "RE", 1], 139 | [3, 64, 24, False, "RE", 2], 140 | [3, 72, 24, False, "RE", 1], 141 | [5, 72, 40, True, "RE", 2], 142 | [5, 120, 40, True, "RE", 1], 143 | [5, 120, 40, True, "RE", 1], 144 | [3, 240, 80, False, "HS", 2], 145 | [3, 200, 80, False, "HS", 1], 146 | [3, 184, 80, False, "HS", 1], 147 | [3, 184, 80, False, "HS", 1], 148 | [3, 480, 112, True, "HS", 1], 149 | [3, 672, 112, True, "HS", 1], 150 | [5, 672, 160, True, "HS", 2], 151 | [5, 960, 160, True, "HS", 1], 152 | [5, 960, 160, True, "HS", 1], 153 | ] 154 | elif mode == "small": 155 | # refer to Table 2 in paper 156 | mobile_setting = [ 157 | # k, exp, c, se, nl, s, 158 | [3, 16, 16, True, "RE", 2], 159 | [3, 72, 24, False, "RE", 2], 160 | [3, 88, 24, False, "RE", 1], 161 | [5, 96, 40, True, "HS", 2], 162 | [5, 240, 40, True, "HS", 1], 163 | [5, 240, 40, True, "HS", 1], 164 | [5, 120, 48, True, "HS", 1], 165 | [5, 144, 48, True, "HS", 1], 166 | [5, 288, 96, True, "HS", 2], 167 | [5, 576, 96, True, "HS", 1], 168 | [5, 576, 96, True, "HS", 1], 169 | ] 170 | else: 171 | raise NotImplementedError 172 | 173 | # building first layer 174 | assert input_size % 32 == 0 175 | last_channel = ( 176 | make_divisible(last_channel * width_mult) 177 | if width_mult > 1.0 178 | else last_channel 179 | ) 180 | self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] 181 | self.classifier = [] 182 | 183 | # building mobile blocks 184 | for k, exp, c, se, nl, s in mobile_setting: 185 | output_channel = make_divisible(c * width_mult) 186 | exp_channel = make_divisible(exp * width_mult) 187 | self.features.append( 188 | MobileBottleneck( 189 | input_channel, output_channel, k, s, exp_channel, se, nl 190 | ) 
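# Annotation (not part of mobilenet.py): a hypothetical sketch of one inverted-
# residual block taken from the "small" setting above ([5, 240, 40, True, "HS", 1]):
# 5x5 depthwise kernel, expansion 240, 40 -> 40 channels, SE enabled, hard-swish.
# The spatial size is illustrative.
import torch

block = MobileBottleneck(inp=40, oup=40, kernel=5, stride=1, exp=240, se=True, nl="HS")
x = torch.randn(1, 40, 28, 28)
y = block(x)          # stride == 1 and inp == oup, so the residual branch is used
print(y.shape)        # torch.Size([1, 40, 28, 28])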
191 | ) 192 | input_channel = output_channel 193 | 194 | # building last several layers 195 | if mode == "large": 196 | last_conv = make_divisible(960 * width_mult) 197 | self.features.append( 198 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) 199 | ) 200 | self.features.append(nn.AdaptiveAvgPool2d(1)) 201 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 202 | self.features.append(Hswish(inplace=True)) 203 | elif mode == "small": 204 | last_conv = make_divisible(576 * width_mult) 205 | self.features.append( 206 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) 207 | ) 208 | # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake 209 | self.features.append(nn.AdaptiveAvgPool2d(1)) 210 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 211 | self.features.append(Hswish(inplace=True)) 212 | else: 213 | raise NotImplementedError 214 | 215 | # make it nn.Sequential 216 | self.features = nn.Sequential(*self.features) 217 | 218 | # building classifier 219 | self.classifier = nn.Sequential( 220 | nn.Dropout(p=dropout), # refer to paper section 6 221 | nn.Linear(last_channel, n_class), 222 | ) 223 | 224 | self._initialize_weights() 225 | 226 | def forward(self, x): 227 | x = self.features(x) 228 | x = x.mean(3).mean(2) 229 | x = self.classifier(x) 230 | return x 231 | 232 | def _initialize_weights(self): 233 | # weight initialization 234 | for m in self.modules(): 235 | if isinstance(m, nn.Conv2d): 236 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 237 | if m.bias is not None: 238 | nn.init.zeros_(m.bias) 239 | elif isinstance(m, nn.BatchNorm2d): 240 | nn.init.ones_(m.weight) 241 | nn.init.zeros_(m.bias) 242 | elif isinstance(m, nn.Linear): 243 | nn.init.normal_(m.weight, 0, 0.01) 244 | if m.bias is not None: 245 | nn.init.zeros_(m.bias) 246 | 247 | 248 | def mobilenetv3(pretrained=False, **kwargs): 249 | model = MobileNetV3(**kwargs) 250 | if pretrained: 251 | state_dict = torch.load("mobilenetv3_small_67.4.pth.tar") 252 | model.load_state_dict(state_dict, strict=True) 253 | # raise NotImplementedError 254 | return model 255 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/network_util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
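# Annotation (not part of this file): a hypothetical usage sketch of the
# mobilenetv3 factory defined at the end of mobilenet.py above. pretrained=False
# is used because the checkpoint path hard-coded there may not exist locally.
import torch

model = mobilenetv3(pretrained=False, n_class=1000, input_size=224, mode="small")
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])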
9 | 10 | 11 | import math 12 | import os 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | from einops import repeat 18 | 19 | from threestudio.utils.GAN.util import instantiate_from_config 20 | 21 | 22 | def make_beta_schedule( 23 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 24 | ): 25 | if schedule == "linear": 26 | betas = ( 27 | torch.linspace( 28 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 29 | ) 30 | ** 2 31 | ) 32 | 33 | elif schedule == "cosine": 34 | timesteps = ( 35 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 36 | ) 37 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 38 | alphas = torch.cos(alphas).pow(2) 39 | alphas = alphas / alphas[0] 40 | betas = 1 - alphas[1:] / alphas[:-1] 41 | betas = np.clip(betas, a_min=0, a_max=0.999) 42 | 43 | elif schedule == "sqrt_linear": 44 | betas = torch.linspace( 45 | linear_start, linear_end, n_timestep, dtype=torch.float64 46 | ) 47 | elif schedule == "sqrt": 48 | betas = ( 49 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 50 | ** 0.5 51 | ) 52 | else: 53 | raise ValueError(f"schedule '{schedule}' unknown.") 54 | return betas.numpy() 55 | 56 | 57 | def make_ddim_timesteps( 58 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 59 | ): 60 | if ddim_discr_method == "uniform": 61 | c = num_ddpm_timesteps // num_ddim_timesteps 62 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 63 | elif ddim_discr_method == "quad": 64 | ddim_timesteps = ( 65 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 66 | ).astype(int) 67 | else: 68 | raise NotImplementedError( 69 | f'There is no ddim discretization method called "{ddim_discr_method}"' 70 | ) 71 | 72 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 73 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 74 | steps_out = ddim_timesteps + 1 75 | if verbose: 76 | print(f"Selected timesteps for ddim sampler: {steps_out}") 77 | return steps_out 78 | 79 | 80 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 81 | # select alphas for computing the variance schedule 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt( 87 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 88 | ) 89 | if verbose: 90 | print( 91 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 92 | ) 93 | print( 94 | f"For the chosen value of eta, which is {eta}, " 95 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 96 | ) 97 | return sigmas, alphas, alphas_prev 98 | 99 | 100 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 101 | """ 102 | Create a beta schedule that discretizes the given alpha_t_bar function, 103 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 104 | :param num_diffusion_timesteps: the number of betas to produce. 105 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 106 | produces the cumulative product of (1-beta) up to that 107 | part of the diffusion process. 108 | :param max_beta: the maximum beta to use; use values lower than 1 to 109 | prevent singularities. 
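# Annotation (not part of network_util.py): a hypothetical check of the
# make_beta_schedule helper defined above. The "linear" branch interpolates in
# sqrt(beta) space, so the endpoints match linear_start and linear_end.
betas = make_beta_schedule("linear", n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
print(betas.shape, betas[0], betas[-1])    # (1000,) ~1e-4 ~2e-2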
110 | """ 111 | betas = [] 112 | for i in range(num_diffusion_timesteps): 113 | t1 = i / num_diffusion_timesteps 114 | t2 = (i + 1) / num_diffusion_timesteps 115 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 116 | return np.array(betas) 117 | 118 | 119 | def extract_into_tensor(a, t, x_shape): 120 | b, *_ = t.shape 121 | out = a.gather(-1, t) 122 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 123 | 124 | 125 | def checkpoint(func, inputs, params, flag): 126 | """ 127 | Evaluate a function without caching intermediate activations, allowing for 128 | reduced memory at the expense of extra compute in the backward pass. 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | 149 | with torch.no_grad(): 150 | output_tensors = ctx.run_function(*ctx.input_tensors) 151 | return output_tensors 152 | 153 | @staticmethod 154 | def backward(ctx, *output_grads): 155 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 156 | with torch.enable_grad(): 157 | # Fixes a bug where the first op in run_function modifies the 158 | # Tensor storage in place, which is not allowed for detach()'d 159 | # Tensors. 160 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 161 | output_tensors = ctx.run_function(*shallow_copies) 162 | input_grads = torch.autograd.grad( 163 | output_tensors, 164 | ctx.input_tensors + ctx.input_params, 165 | output_grads, 166 | allow_unused=True, 167 | ) 168 | del ctx.input_tensors 169 | del ctx.input_params 170 | del output_tensors 171 | return (None, None) + input_grads 172 | 173 | 174 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 175 | """ 176 | Create sinusoidal timestep embeddings. 177 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 178 | These may be fractional. 179 | :param dim: the dimension of the output. 180 | :param max_period: controls the minimum frequency of the embeddings. 181 | :return: an [N x dim] Tensor of positional embeddings. 182 | """ 183 | if not repeat_only: 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) 187 | * torch.arange(start=0, end=half, dtype=torch.float32) 188 | / half 189 | ).to(device=timesteps.device) 190 | args = timesteps[:, None].float() * freqs[None] 191 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 192 | if dim % 2: 193 | embedding = torch.cat( 194 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 195 | ) 196 | else: 197 | embedding = repeat(timesteps, "b -> b d", d=dim) 198 | return embedding 199 | 200 | 201 | def zero_module(module): 202 | """ 203 | Zero out the parameters of a module and return it. 204 | """ 205 | for p in module.parameters(): 206 | p.detach().zero_() 207 | return module 208 | 209 | 210 | def scale_module(module, scale): 211 | """ 212 | Scale the parameters of a module and return it. 
213 | """ 214 | for p in module.parameters(): 215 | p.detach().mul_(scale) 216 | return module 217 | 218 | 219 | def mean_flat(tensor): 220 | """ 221 | Take the mean over all non-batch dimensions. 222 | """ 223 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 224 | 225 | 226 | def normalization(channels): 227 | """ 228 | Make a standard normalization layer. 229 | :param channels: number of input channels. 230 | :return: an nn.Module for normalization. 231 | """ 232 | return GroupNorm32(32, channels) 233 | 234 | 235 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 236 | class SiLU(nn.Module): 237 | def forward(self, x): 238 | return x * torch.sigmoid(x) 239 | 240 | 241 | class GroupNorm32(nn.GroupNorm): 242 | def forward(self, x): 243 | return super().forward(x.float()).type(x.dtype) 244 | 245 | 246 | def conv_nd(dims, *args, **kwargs): 247 | """ 248 | Create a 1D, 2D, or 3D convolution module. 249 | """ 250 | if dims == 1: 251 | return nn.Conv1d(*args, **kwargs) 252 | elif dims == 2: 253 | return nn.Conv2d(*args, **kwargs) 254 | elif dims == 3: 255 | return nn.Conv3d(*args, **kwargs) 256 | raise ValueError(f"unsupported dimensions: {dims}") 257 | 258 | 259 | def linear(*args, **kwargs): 260 | """ 261 | Create a linear module. 262 | """ 263 | return nn.Linear(*args, **kwargs) 264 | 265 | 266 | def avg_pool_nd(dims, *args, **kwargs): 267 | """ 268 | Create a 1D, 2D, or 3D average pooling module. 269 | """ 270 | if dims == 1: 271 | return nn.AvgPool1d(*args, **kwargs) 272 | elif dims == 2: 273 | return nn.AvgPool2d(*args, **kwargs) 274 | elif dims == 3: 275 | return nn.AvgPool3d(*args, **kwargs) 276 | raise ValueError(f"unsupported dimensions: {dims}") 277 | 278 | 279 | class HybridConditioner(nn.Module): 280 | def __init__(self, c_concat_config, c_crossattn_config): 281 | super().__init__() 282 | self.concat_conditioner = instantiate_from_config(c_concat_config) 283 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 284 | 285 | def forward(self, c_concat, c_crossattn): 286 | c_concat = self.concat_conditioner(c_concat) 287 | c_crossattn = self.crossattn_conditioner(c_crossattn) 288 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 289 | 290 | 291 | def noise_like(shape, device, repeat=False): 292 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 293 | shape[0], *((1,) * (len(shape) - 1)) 294 | ) 295 | noise = lambda: torch.randn(shape, device=device) 296 | return repeat_noise() if repeat else noise() 297 | -------------------------------------------------------------------------------- /threestudio/utils/GAN/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import multiprocessing as mp 3 | from collections import abc 4 | from functools import partial 5 | from inspect import isfunction 6 | from queue import Queue 7 | from threading import Thread 8 | 9 | import numpy as np 10 | import torch 11 | from einops import rearrange 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | 15 | def log_txt_as_img(wh, xc, size=10): 16 | # wh a tuple of (width, height) 17 | # xc a list of captions to plot 18 | b = len(xc) 19 | txts = list() 20 | for bi in range(b): 21 | txt = Image.new("RGB", wh, color="white") 22 | draw = ImageDraw.Draw(txt) 23 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 24 | nc = int(40 * (wh[0] / 256)) 25 | lines = "\n".join( 26 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 27 | ) 28 | 29 | try: 30 | 
draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == "__is_first_stage__": 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, 110 | data, 111 | n_proc, 112 | target_data_type="ndarray", 113 | cpu_intensive=True, 114 | use_worker_id=False, 115 | ): 116 | # if target_data_type not in ["ndarray", "list"]: 117 | # raise ValueError( 118 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 119 | # ) 120 | if isinstance(data, np.ndarray) and target_data_type == "list": 121 | raise ValueError("list expected but function got ndarray.") 122 | elif isinstance(data, abc.Iterable): 123 | if isinstance(data, dict): 124 | print( 125 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 126 | ) 127 | data = list(data.values()) 128 | if target_data_type == "ndarray": 129 | data = np.asarray(data) 130 | else: 131 | data = list(data) 132 | else: 133 | raise TypeError( 134 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
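# Annotation (not part of util.py): a hypothetical sketch of the configuration
# convention handled by instantiate_from_config above -- a dict with a dotted
# "target" import path and optional keyword "params". torch.nn.Linear is only an
# illustrative target class.
config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)   # equivalent to torch.nn.Linear(in_features=8, out_features=4)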
135 | ) 136 | 137 | if cpu_intensive: 138 | Q = mp.Queue(1000) 139 | proc = mp.Process 140 | else: 141 | Q = Queue(1000) 142 | proc = Thread 143 | # spawn processes 144 | if target_data_type == "ndarray": 145 | arguments = [ 146 | [func, Q, part, i, use_worker_id] 147 | for i, part in enumerate(np.array_split(data, n_proc)) 148 | ] 149 | else: 150 | step = ( 151 | int(len(data) / n_proc + 1) 152 | if len(data) % n_proc != 0 153 | else int(len(data) / n_proc) 154 | ) 155 | arguments = [ 156 | [func, Q, part, i, use_worker_id] 157 | for i, part in enumerate( 158 | [data[i : i + step] for i in range(0, len(data), step)] 159 | ) 160 | ] 161 | processes = [] 162 | for i in range(n_proc): 163 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 164 | processes += [p] 165 | 166 | # start processes 167 | print(f"Start prefetching...") 168 | import time 169 | 170 | start = time.time() 171 | gather_res = [[] for _ in range(n_proc)] 172 | try: 173 | for p in processes: 174 | p.start() 175 | 176 | k = 0 177 | while k < n_proc: 178 | # get result 179 | res = Q.get() 180 | if res == "Done": 181 | k += 1 182 | else: 183 | gather_res[res[0]] = res[1] 184 | 185 | except Exception as e: 186 | print("Exception: ", e) 187 | for p in processes: 188 | p.terminate() 189 | 190 | raise e 191 | finally: 192 | for p in processes: 193 | p.join() 194 | print(f"Prefetching complete. [{time.time() - start} sec.]") 195 | 196 | if target_data_type == "ndarray": 197 | if not isinstance(gather_res[0], np.ndarray): 198 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 199 | 200 | # order outputs 201 | return np.concatenate(gather_res, axis=0) 202 | elif target_data_type == "list": 203 | out = [] 204 | for r in gather_res: 205 | out.extend(r) 206 | return out 207 | else: 208 | return gather_res 209 | -------------------------------------------------------------------------------- /threestudio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | -------------------------------------------------------------------------------- /threestudio/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from threestudio.utils.config import parse_structured 7 | from threestudio.utils.misc import get_device, load_module_weights 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 
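# Attributes backed by properties or lazily built modules may raise when
# accessed through getattr; the bare except above skips them so that the
# recursive traversal below only descends into children that are Updateable.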
32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 39 | # override this method to implement custom update logic 40 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 41 | # as the models and tensors are not guarenteed to be on the same device 42 | pass 43 | 44 | 45 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 46 | if isinstance(module, Updateable): 47 | module.do_update_step(epoch, global_step) 48 | 49 | 50 | class BaseObject(Updateable): 51 | @dataclass 52 | class Config: 53 | pass 54 | 55 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 56 | 57 | def __init__( 58 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 59 | ) -> None: 60 | super().__init__() 61 | self.cfg = parse_structured(self.Config, cfg) 62 | self.device = get_device() 63 | self.configure(*args, **kwargs) 64 | 65 | def configure(self, *args, **kwargs) -> None: 66 | pass 67 | 68 | 69 | class BaseModule(nn.Module, Updateable): 70 | @dataclass 71 | class Config: 72 | weights: Optional[str] = None 73 | 74 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 75 | 76 | def __init__( 77 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 78 | ) -> None: 79 | super().__init__() 80 | self.cfg = parse_structured(self.Config, cfg) 81 | self.device = get_device() 82 | self.configure(*args, **kwargs) 83 | if self.cfg.weights is not None: 84 | # format: path/to/weights:module_name 85 | weights_path, module_name = self.cfg.weights.split(":") 86 | state_dict, epoch, global_step = load_module_weights( 87 | weights_path, module_name=module_name, map_location="cpu" 88 | ) 89 | self.load_state_dict(state_dict) 90 | self.do_update_step( 91 | epoch, global_step, on_load_weights=True 92 | ) # restore states 93 | # dummy tensor to indicate model state 94 | self._dummy: Float[Tensor, "..."] 95 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 96 | 97 | def configure(self, *args, **kwargs) -> None: 98 | pass 99 | -------------------------------------------------------------------------------- /threestudio/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | 5 | import pytorch_lightning 6 | 7 | from threestudio.utils.config import dump_config 8 | from threestudio.utils.misc import parse_version 9 | 10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): 11 | from pytorch_lightning.callbacks import Callback 12 | else: 13 | from pytorch_lightning.callbacks.base import Callback 14 | 15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 17 | 18 | 19 | class VersionedCallback(Callback): 20 | def __init__(self, save_root, version=None, use_version=True): 21 | self.save_root = save_root 22 | self._version = version 23 | self.use_version = use_version 24 | 25 | @property 26 | def version(self) -> int: 27 | """Get the experiment version. 28 | 29 | Returns: 30 | The experiment version if specified else the next version. 
31 | """ 32 | if self._version is None: 33 | self._version = self._get_next_version() 34 | return self._version 35 | 36 | def _get_next_version(self): 37 | existing_versions = [] 38 | if os.path.isdir(self.save_root): 39 | for f in os.listdir(self.save_root): 40 | bn = os.path.basename(f) 41 | if bn.startswith("version_"): 42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 43 | existing_versions.append(int(dir_ver)) 44 | if len(existing_versions) == 0: 45 | return 0 46 | return max(existing_versions) + 1 47 | 48 | @property 49 | def savedir(self): 50 | if not self.use_version: 51 | return self.save_root 52 | return os.path.join( 53 | self.save_root, 54 | self.version 55 | if isinstance(self.version, str) 56 | else f"version_{self.version}", 57 | ) 58 | 59 | 60 | class CodeSnapshotCallback(VersionedCallback): 61 | def __init__(self, save_root, version=None, use_version=True): 62 | super().__init__(save_root, version, use_version) 63 | 64 | def get_file_list(self): 65 | return [ 66 | b.decode() 67 | for b in set( 68 | subprocess.check_output( 69 | 'git ls-files -- ":!:load/*"', shell=True 70 | ).splitlines() 71 | ) 72 | | set( # hard code, TODO: use config to exclude folders or files 73 | subprocess.check_output( 74 | "git ls-files --others --exclude-standard", shell=True 75 | ).splitlines() 76 | ) 77 | ] 78 | 79 | @rank_zero_only 80 | def save_code_snapshot(self): 81 | os.makedirs(self.savedir, exist_ok=True) 82 | for f in self.get_file_list(): 83 | if not os.path.exists(f) or os.path.isdir(f): 84 | continue 85 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 86 | shutil.copyfile(f, os.path.join(self.savedir, f)) 87 | 88 | def on_fit_start(self, trainer, pl_module): 89 | try: 90 | self.save_code_snapshot() 91 | except: 92 | rank_zero_warn( 93 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
94 | ) 95 | 96 | 97 | class ConfigSnapshotCallback(VersionedCallback): 98 | def __init__(self, config_path, config, save_root, version=None, use_version=True): 99 | super().__init__(save_root, version, use_version) 100 | self.config_path = config_path 101 | self.config = config 102 | 103 | @rank_zero_only 104 | def save_config_snapshot(self): 105 | os.makedirs(self.savedir, exist_ok=True) 106 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) 107 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) 108 | 109 | def on_fit_start(self, trainer, pl_module): 110 | self.save_config_snapshot() 111 | 112 | 113 | class CustomProgressBar(TQDMProgressBar): 114 | def get_metrics(self, *args, **kwargs): 115 | # don't show the version number 116 | items = super().get_metrics(*args, **kwargs) 117 | items.pop("v_num", None) 118 | return items 119 | 120 | 121 | class ProgressCallback(Callback): 122 | def __init__(self, save_path): 123 | super().__init__() 124 | self.save_path = save_path 125 | self._file_handle = None 126 | 127 | @property 128 | def file_handle(self): 129 | if self._file_handle is None: 130 | self._file_handle = open(self.save_path, "w") 131 | return self._file_handle 132 | 133 | @rank_zero_only 134 | def write(self, msg: str) -> None: 135 | self.file_handle.seek(0) 136 | self.file_handle.truncate() 137 | self.file_handle.write(msg) 138 | self.file_handle.flush() 139 | 140 | @rank_zero_only 141 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): 142 | self.write( 143 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" 144 | ) 145 | 146 | @rank_zero_only 147 | def on_validation_start(self, trainer, pl_module): 148 | self.write(f"Rendering validation image ...") 149 | 150 | @rank_zero_only 151 | def on_test_start(self, trainer, pl_module): 152 | self.write(f"Rendering video ...") 153 | 154 | @rank_zero_only 155 | def on_predict_start(self, trainer, pl_module): 156 | self.write(f"Exporting mesh assets ...") 157 | -------------------------------------------------------------------------------- /threestudio/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | import threestudio 8 | from threestudio.utils.typing import * 9 | 10 | # ============ Register OmegaConf Recolvers ============= # 11 | OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) 24 | OmegaConf.register_new_resolver("not", lambda s: not s) 25 | OmegaConf.register_new_resolver( 26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 27 | ) 28 | # ======================================================= # 29 
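# The resolvers above are referenced from the YAML configs through OmegaConf's
# interpolation syntax. A hypothetical snippet (field names are illustrative only):
#   eval_width: ${idiv:${data.width},2}
#   tag: ${rmspace:${system.prompt_processor.prompt},_}
#   use_guidance: ${cmaxgt0:${system.loss.lambda_sds}}
# OmegaConf.resolve(), called in load_config further down, replaces these
# expressions with concrete values when the config is parsed.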
| 30 | 31 | def C_max(value: Any) -> float: 32 | if isinstance(value, int) or isinstance(value, float): 33 | pass 34 | else: 35 | value = config_to_primitive(value) 36 | if not isinstance(value, list): 37 | raise TypeError("Scalar specification only supports list, got", type(value)) 38 | if len(value) == 3: 39 | value = [0] + value 40 | assert len(value) == 4 41 | start_step, start_value, end_value, end_step = value 42 | value = max(start_value, end_value) 43 | return value 44 | 45 | 46 | @dataclass 47 | class ExperimentConfig: 48 | name: str = "default" 49 | description: str = "" 50 | tag: str = "" 51 | seed: int = 0 52 | use_timestamp: bool = True 53 | timestamp: Optional[str] = None 54 | exp_root_dir: str = "outputs" 55 | 56 | ### these shouldn't be set manually 57 | exp_dir: str = "outputs/default" 58 | trial_name: str = "exp" 59 | trial_dir: str = "outputs/default/exp" 60 | n_gpus: int = 1 61 | ### 62 | 63 | resume: Optional[str] = None 64 | 65 | data_type: str = "" 66 | data: dict = field(default_factory=dict) 67 | 68 | system_type: str = "" 69 | system: dict = field(default_factory=dict) 70 | 71 | # accept pytorch-lightning trainer parameters 72 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 73 | trainer: dict = field(default_factory=dict) 74 | 75 | # accept pytorch-lightning checkpoint callback parameters 76 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 77 | checkpoint: dict = field(default_factory=dict) 78 | 79 | def __post_init__(self): 80 | if not self.tag and not self.use_timestamp: 81 | raise ValueError("Either tag is specified or use_timestamp is True.") 82 | self.trial_name = self.tag 83 | # if resume from an existing config, self.timestamp should not be None 84 | if self.timestamp is None: 85 | self.timestamp = "" 86 | if self.use_timestamp: 87 | if self.n_gpus > 1: 88 | threestudio.warn( 89 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 
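# The resulting layout is exp_root_dir/name/tag[timestamp]; with multiple GPUs
# the timestamp is dropped (presumably because each rank would compute its own),
# so a unique tag is needed to keep runs from overwriting each other. The trial
# directory is created eagerly below so callbacks can write into it immediately.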
90 | ) 91 | else: 92 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 93 | self.trial_name += self.timestamp 94 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 95 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 96 | os.makedirs(self.trial_dir, exist_ok=True) 97 | 98 | 99 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: 100 | if from_string: 101 | yaml_confs = [OmegaConf.create(s) for s in yamls] 102 | else: 103 | yaml_confs = [OmegaConf.load(f) for f in yamls] 104 | cli_conf = OmegaConf.from_cli(cli_args) 105 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 106 | OmegaConf.resolve(cfg) 107 | assert isinstance(cfg, DictConfig) 108 | scfg = parse_structured(ExperimentConfig, cfg) 109 | return scfg 110 | 111 | 112 | def config_to_primitive(config, resolve: bool = True) -> Any: 113 | return OmegaConf.to_container(config, resolve=resolve) 114 | 115 | 116 | def dump_config(path: str, config) -> None: 117 | with open(path, "w") as fp: 118 | OmegaConf.save(config=config, f=fp) 119 | 120 | 121 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 122 | scfg = OmegaConf.structured(fields(**cfg)) 123 | return scfg 124 | -------------------------------------------------------------------------------- /threestudio/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import re 4 | 5 | import tinycudann as tcnn 6 | import torch 7 | from packaging import version 8 | 9 | from threestudio.utils.config import config_to_primitive 10 | from threestudio.utils.typing import * 11 | 12 | 13 | def parse_version(ver: str): 14 | return version.parse(ver) 15 | 16 | 17 | def get_rank(): 18 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 19 | # therefore LOCAL_RANK needs to be checked first 20 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 21 | for key in rank_keys: 22 | rank = os.environ.get(key) 23 | if rank is not None: 24 | return int(rank) 25 | return 0 26 | 27 | 28 | def get_device(): 29 | return torch.device(f"cuda:{get_rank()}") 30 | 31 | 32 | def load_module_weights( 33 | path, module_name=None, ignore_modules=None, map_location=None 34 | ) -> Tuple[dict, int, int]: 35 | if module_name is not None and ignore_modules is not None: 36 | raise ValueError("module_name and ignore_modules cannot be both set") 37 | if map_location is None: 38 | map_location = get_device() 39 | 40 | ckpt = torch.load(path, map_location=map_location) 41 | state_dict = ckpt["state_dict"] 42 | state_dict_to_load = state_dict 43 | 44 | if ignore_modules is not None: 45 | state_dict_to_load = {} 46 | for k, v in state_dict.items(): 47 | ignore = any( 48 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 49 | ) 50 | if ignore: 51 | continue 52 | state_dict_to_load[k] = v 53 | 54 | if module_name is not None: 55 | state_dict_to_load = {} 56 | for k, v in state_dict.items(): 57 | m = re.match(rf"^{module_name}\.(.*)$", k) 58 | if m is None: 59 | continue 60 | state_dict_to_load[m.group(1)] = v 61 | 62 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 63 | 64 | 65 | def C(value: Any, epoch: int, global_step: int) -> float: 66 | if isinstance(value, int) or isinstance(value, float): 67 | pass 68 | else: 69 | value = config_to_primitive(value) 70 | if not isinstance(value, list): 71 | raise TypeError("Scalar specification only supports list, got", type(value)) 
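# A list spec is read as [start_step, start_value, end_value, end_step]; a
# 3-element list omits start_step, which defaults to 0 below. The returned
# value interpolates linearly from start_value to end_value and is clamped
# outside the [start_step, end_step] window. An int end_step schedules against
# global_step, a float end_step against epoch. Hypothetical example:
# C([0, 0.0, 1.0, 5000], epoch, global_step) ramps a weight from 0 to 1 over
# the first 5000 steps.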
72 | if len(value) == 3: 73 | value = [0] + value 74 | assert len(value) == 4 75 | start_step, start_value, end_value, end_step = value 76 | if isinstance(end_step, int): 77 | current_step = global_step 78 | value = start_value + (end_value - start_value) * max( 79 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 80 | ) 81 | elif isinstance(end_step, float): 82 | current_step = epoch 83 | value = start_value + (end_value - start_value) * max( 84 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 85 | ) 86 | return value 87 | 88 | 89 | def cleanup(): 90 | gc.collect() 91 | torch.cuda.empty_cache() 92 | tcnn.free_temporary_memory() 93 | 94 | 95 | def finish_with_cleanup(func: Callable): 96 | def wrapper(*args, **kwargs): 97 | out = func(*args, **kwargs) 98 | cleanup() 99 | return out 100 | 101 | return wrapper 102 | 103 | 104 | def _distributed_available(): 105 | return torch.distributed.is_available() and torch.distributed.is_initialized() 106 | 107 | 108 | def barrier(): 109 | if not _distributed_available(): 110 | return 111 | else: 112 | torch.distributed.barrier() 113 | 114 | 115 | def broadcast(tensor, src=0): 116 | if not _distributed_available(): 117 | return tensor 118 | else: 119 | torch.distributed.broadcast(tensor, src=src) 120 | return tensor 121 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual import PerceptualLoss 2 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/perceptual.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from threestudio.utils.perceptual.utils import get_ckpt_path 10 | 11 | 12 | class PerceptualLoss(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "threestudio/utils/lpips") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | 
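# Both inputs are shifted/scaled by the fixed ScalingLayer, run through the
# frozen VGG16 slices, and channel-normalized; the squared per-layer feature
# differences are then weighted by the learned 1x1 "lin" heads, averaged
# spatially, and summed across the five layers to give the final LPIPS value.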
outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if (use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | from tqdm import tqdm 6 | 7 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 8 | 9 | 
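# URL_MAP above gives the download source for each checkpoint name; CKPT_MAP
# below maps it to a local filename and MD5_MAP to the expected checksum,
# which get_ckpt_path() verifies when called with check=True (and always after
# a fresh download).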
CKPT_MAP = {"vgg_lpips": "vgg.pth"} 10 | 11 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 12 | 13 | 14 | def download(url, local_path, chunk_size=1024): 15 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 16 | with requests.get(url, stream=True) as r: 17 | total_size = int(r.headers.get("content-length", 0)) 18 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 19 | with open(local_path, "wb") as f: 20 | for data in r.iter_content(chunk_size=chunk_size): 21 | if data: 22 | f.write(data) 23 | pbar.update(chunk_size) 24 | 25 | 26 | def md5_hash(path): 27 | with open(path, "rb") as f: 28 | content = f.read() 29 | return hashlib.md5(content).hexdigest() 30 | 31 | 32 | def get_ckpt_path(name, root, check=False): 33 | assert name in URL_MAP 34 | path = os.path.join(root, CKPT_MAP[name]) 35 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 36 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 37 | download(URL_MAP[name], path) 38 | md5 = md5_hash(path) 39 | assert md5 == MD5_MAP[name], md5 40 | return path 41 | 42 | 43 | class KeyNotFoundError(Exception): 44 | def __init__(self, cause, keys=None, visited=None): 45 | self.cause = cause 46 | self.keys = keys 47 | self.visited = visited 48 | messages = list() 49 | if keys is not None: 50 | messages.append("Key not found: {}".format(keys)) 51 | if visited is not None: 52 | messages.append("Visited: {}".format(visited)) 53 | messages.append("Cause:\n{}".format(cause)) 54 | message = "\n".join(messages) 55 | super().__init__(message) 56 | 57 | 58 | def retrieve( 59 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 60 | ): 61 | """Given a nested list or dict return the desired value at key expanding 62 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 63 | is done in-place. 64 | 65 | Parameters 66 | ---------- 67 | list_or_dict : list or dict 68 | Possibly nested list or dictionary. 69 | key : str 70 | key/to/value, path like string describing all keys necessary to 71 | consider to get to the desired value. List indices can also be 72 | passed here. 73 | splitval : str 74 | String that defines the delimiter between keys of the 75 | different depth levels in `key`. 76 | default : obj 77 | Value returned if :attr:`key` is not found. 78 | expand : bool 79 | Whether to expand callable nodes on the path or not. 80 | 81 | Returns 82 | ------- 83 | The desired value or if :attr:`default` is not ``None`` and the 84 | :attr:`key` is not found returns ``default``. 85 | 86 | Raises 87 | ------ 88 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 89 | ``None``. 90 | """ 91 | 92 | keys = key.split(splitval) 93 | 94 | success = True 95 | try: 96 | visited = [] 97 | parent = None 98 | last_key = None 99 | for key in keys: 100 | if callable(list_or_dict): 101 | if not expand: 102 | raise KeyNotFoundError( 103 | ValueError( 104 | "Trying to get past callable node with expand=False." 
105 | ), 106 | keys=keys, 107 | visited=visited, 108 | ) 109 | list_or_dict = list_or_dict() 110 | parent[last_key] = list_or_dict 111 | 112 | last_key = key 113 | parent = list_or_dict 114 | 115 | try: 116 | if isinstance(list_or_dict, dict): 117 | list_or_dict = list_or_dict[key] 118 | else: 119 | list_or_dict = list_or_dict[int(key)] 120 | except (KeyError, IndexError, ValueError) as e: 121 | raise KeyNotFoundError(e, keys=keys, visited=visited) 122 | 123 | visited += [key] 124 | # final expansion of retrieved value 125 | if expand and callable(list_or_dict): 126 | list_or_dict = list_or_dict() 127 | parent[last_key] = list_or_dict 128 | except KeyNotFoundError as e: 129 | if default is None: 130 | raise e 131 | else: 132 | list_or_dict = default 133 | success = False 134 | 135 | if not pass_success: 136 | return list_or_dict 137 | else: 138 | return list_or_dict, success 139 | 140 | 141 | if __name__ == "__main__": 142 | config = { 143 | "keya": "a", 144 | "keyb": "b", 145 | "keyc": { 146 | "cc1": 1, 147 | "cc2": 2, 148 | }, 149 | } 150 | from omegaconf import OmegaConf 151 | 152 | config = OmegaConf.create(config) 153 | print(config) 154 | retrieve(config, "keya") 155 | -------------------------------------------------------------------------------- /threestudio/utils/rasterize.py: -------------------------------------------------------------------------------- 1 | import nvdiffrast.torch as dr 2 | import torch 3 | 4 | from threestudio.utils.typing import * 5 | 6 | 7 | class NVDiffRasterizerContext: 8 | def __init__(self, context_type: str, device: torch.device) -> None: 9 | self.device = device 10 | self.ctx = self.initialize_context(context_type, device) 11 | 12 | def initialize_context( 13 | self, context_type: str, device: torch.device 14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: 15 | if context_type == "gl": 16 | return dr.RasterizeGLContext(device=device) 17 | elif context_type == "cuda": 18 | return dr.RasterizeCudaContext(device=device) 19 | else: 20 | raise ValueError(f"Unknown rasterizer context type: {context_type}") 21 | 22 | def vertex_transform( 23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] 24 | ) -> Float[Tensor, "B Nv 4"]: 25 | verts_homo = torch.cat( 26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 27 | ) 28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) 29 | 30 | def rasterize( 31 | self, 32 | pos: Float[Tensor, "B Nv 4"], 33 | tri: Integer[Tensor, "Nf 3"], 34 | resolution: Union[int, Tuple[int, int]], 35 | ): 36 | # rasterize in instance mode (single topology) 37 | return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) 38 | 39 | def rasterize_one( 40 | self, 41 | pos: Float[Tensor, "Nv 4"], 42 | tri: Integer[Tensor, "Nf 3"], 43 | resolution: Union[int, Tuple[int, int]], 44 | ): 45 | # rasterize one single mesh under a single viewpoint 46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) 47 | return rast[0], rast_db[0] 48 | 49 | def antialias( 50 | self, 51 | color: Float[Tensor, "B H W C"], 52 | rast: Float[Tensor, "B H W 4"], 53 | pos: Float[Tensor, "B Nv 4"], 54 | tri: Integer[Tensor, "Nf 3"], 55 | ) -> Float[Tensor, "B H W C"]: 56 | return dr.antialias(color.float(), rast, pos.float(), tri.int()) 57 | 58 | def interpolate( 59 | self, 60 | attr: Float[Tensor, "B Nv C"], 61 | rast: Float[Tensor, "B H W 4"], 62 | tri: Integer[Tensor, "Nf 3"], 63 | rast_db=None, 64 | diff_attrs=None, 65 | ) -> Float[Tensor, "B H W C"]: 66 | return dr.interpolate( 67 | 
attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs 68 | ) 69 | 70 | def interpolate_one( 71 | self, 72 | attr: Float[Tensor, "Nv C"], 73 | rast: Float[Tensor, "B H W 4"], 74 | tri: Integer[Tensor, "Nf 3"], 75 | rast_db=None, 76 | diff_attrs=None, 77 | ) -> Float[Tensor, "B H W C"]: 78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) 79 | -------------------------------------------------------------------------------- /threestudio/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | --------------------------------------------------------------------------------