├── DOCUMENTATION.md
├── LICENSE
├── README.md
├── SMPL_poses
│   ├── 0004_smpl.pkl
│   ├── 0069_smpl.pkl
│   ├── 0073_smpl.pkl
│   ├── 0159_smpl.pkl
│   ├── 0167_smpl.pkl
│   ├── 0188_smpl.pkl
│   ├── 0279_smpl.pkl
│   ├── 0308_smpl.pkl
│   ├── 0392_smpl.pkl
│   ├── 0484_smpl.pkl
│   └── 0501_smpl.pkl
├── configs
│   └── dreamavatar.yaml
├── docker
│   ├── Dockerfile
│   └── compose.yaml
├── docs
│   ├── DreamAvatar-supp.pdf
│   ├── index.html
│   └── static
│       ├── css
│       │   ├── bulma-carousel.min.css
│       │   ├── bulma-slider.min.css
│       │   ├── bulma.css.map.txt
│       │   ├── bulma.min.css
│       │   ├── fontawesome.all.min.css
│       │   └── index.css
│       ├── gif
│       │   ├── clown.gif
│       │   ├── deadpool.gif
│       │   ├── joker.gif
│       │   └── link.gif
│       ├── js
│       │   ├── bulma-carousel.js
│       │   ├── bulma-carousel.min.js
│       │   ├── bulma-slider.js
│       │   ├── bulma-slider.min.js
│       │   ├── fontawesome.all.min.js
│       │   ├── index.js
│       │   └── result.js
│       └── video
│           ├── Pipeline-n.png
│           ├── canonical
│           │   ├── Alien.mp4
│           │   ├── Buddhist_monk.mp4
│           │   ├── C-3PO.mp4
│           │   ├── Crystal_maiden.mp4
│           │   ├── Electro.mp4
│           │   ├── Flash.mp4
│           │   ├── Groot.mp4
│           │   ├── Joker.mp4
│           │   ├── Link.mp4
│           │   ├── Link_2.mp4
│           │   ├── Luffy.mp4
│           │   ├── Spiderman.mp4
│           │   ├── Track_field_athlete.mp4
│           │   ├── Wonder_woman.mp4
│           │   ├── Woody.mp4
│           │   ├── Woody_in_joker.mp4
│           │   ├── body_builder.mp4
│           │   ├── clown.mp4
│           │   ├── hipster_man.mp4
│           │   ├── kakashi.mp4
│           │   ├── sasuke.mp4
│           │   └── woman_hippie.mp4
│           ├── integration
│           │   ├── clown.mp4
│           │   ├── deadpool.mp4
│           │   ├── joker.mp4
│           │   └── link.mp4
│           ├── poses
│           │   ├── Flash-00576.mp4
│           │   ├── Flash-00596.mp4
│           │   ├── groot-00084.mp4
│           │   ├── groot-00230.mp4
│           │   ├── groot-00308.mp4
│           │   ├── groot-00350.mp4
│           │   ├── joker-00296.mp4
│           │   ├── joker-00510.mp4
│           │   ├── joker-00530.mp4
│           │   ├── joker-00536.mp4
│           │   ├── spiderman-00028.mp4
│           │   └── spiderman-0279.mp4
│           ├── shapes
│           │   ├── groot-0-1.mp4
│           │   ├── groot-0-3.mp4
│           │   ├── groot-1+2.mp4
│           │   └── groot-1-2.mp4
│           └── text_manipulation
│               ├── joker-black.mp4
│               ├── joker-green.mp4
│               ├── joker-pink.mp4
│               └── joker-texudo.mp4
├── environment.yml
├── launch.py
├── load
│   ├── make_prompt_library.py
│   └── prompt_library.json
├── requirements.txt
└── threestudio
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── co3d.py
    │   ├── image.py
    │   ├── multiview.py
    │   └── uncond.py
    ├── models
    │   ├── __init__.py
    │   ├── background
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   └── neural_environment_map_background.py
    │   ├── exporters
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   └── mesh_exporter.py
    │   ├── geometry
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── implicit_volume.py
    │   │   └── inv_deformation.py
    │   ├── guidance
    │   │   ├── __init__.py
    │   │   ├── stable_diffusion_guidance.py
    │   │   └── stable_diffusion_vsd_guidance.py
    │   ├── isosurface.py
    │   ├── materials
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   └── no_material.py
    │   ├── mesh.py
    │   ├── networks.py
    │   ├── prompt_processors
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   └── stable_diffusion_prompt_processor.py
    │   └── renderers
    │       ├── __init__.py
    │       ├── base.py
    │       └── nerf_volume_renderer.py
    ├── systems
    │   ├── __init__.py
    │   ├── base.py
    │   ├── dreamavatar.py
    │   ├── optimizers.py
    │   └── utils.py
    └── utils
        ├── GAN
        │   ├── attention.py
        │   ├── discriminator.py
        │   ├── distribution.py
        │   ├── loss.py
        │   ├── mobilenet.py
        │   ├── network_util.py
        │   ├── util.py
        │   └── vae.py
        ├── __init__.py
        ├── base.py
        ├── callbacks.py
        ├── config.py
        ├── misc.py
        ├── ops.py
        ├── perceptual
        │   ├── __init__.py
        │   ├── perceptual.py
        │   └── utils.py
        ├── rasterize.py
        ├── saving.py
        └── typing.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DreamAvatar: Text-and-Shape Guided 3D Human Avatar Generation via Diffusion Models
4 |
5 | Yukang Cao\*,
6 | Yan-Pei Cao\*,
7 | Kai Han†,
8 | Ying Shan,
9 | Kwan-Yee K. Wong†
10 |
11 |
12 | [](https://arxiv.org/abs/2304.00916)
13 | 
14 |
15 | 
16 | 
17 | 
18 | 
19 |
20 | Please refer to our project webpage for more visualizations.
21 |
22 |
23 | ## Abstract
24 | We present **DreamAvatar**, a text-and-shape guided framework for generating high-quality 3D human avatars with controllable poses. While encouraging results have been reported by recent methods on text-guided 3D common object generation, generating high-quality human avatars remains an open challenge due to the complexity of the human body's shape, pose, and appearance. To tackle this challenge, we propose DreamAvatar, which utilizes a trainable NeRF to predict density and color for 3D points and pretrained text-to-image diffusion models to provide 2D self-supervision. Specifically, we leverage the SMPL model to provide shape and pose guidance for the generation. We introduce a dual-observation-space design that involves the joint optimization of a canonical space and a posed space that are related by a learnable deformation field. This facilitates the generation of more complete textures and geometry faithful to the target pose. We also jointly optimize the losses computed from the full body and from the zoomed-in 3D head to alleviate the common multi-face "Janus" problem and improve facial details in the generated avatars. Extensive evaluations demonstrate that DreamAvatar significantly outperforms existing methods, establishing a new state-of-the-art for text-and-shape guided 3D human avatar generation.
25 |
26 |
27 | 
28 |
29 |
30 | ## Installation
31 |
32 | See [installation.md](docs/installation.md) for additional information, including installation via Docker.
33 |
34 | The following steps have been tested on Ubuntu 20.04.
35 |
36 | - You must have an NVIDIA graphics card with at least 48 GB of VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed.
37 | - Install `Python >= 3.8`.
38 | - (Optional, Recommended) Create a virtual environment:
39 |
40 | ```sh
41 | python3 -m virtualenv venv
42 | . venv/bin/activate
43 |
44 | # Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x.
45 | # For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later.
46 | python3 -m pip install --upgrade pip
47 | ```
48 |
49 | - Install `PyTorch >= 1.12`. We have tested `torch1.12.1+cu113` and `torch2.0.0+cu118`; other versions should also work.
50 |
51 | ```sh
52 | # torch1.12.1+cu113
53 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
54 | # or torch2.0.0+cu118
55 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
56 | ```
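
Before building the CUDA extensions in the following steps, it is worth a quick sanity check (not part of the repo) that the CUDA build of PyTorch is active:

```python
# Quick check: confirm the installed wheel was built against CUDA and can see a GPU.
import torch

print(torch.__version__)          # e.g. 1.12.1+cu113 or 2.0.0+cu118
print(torch.version.cuda)         # CUDA version the wheel was built against
print(torch.cuda.is_available())  # should print True
```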
57 |
58 | - Install [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) and [kaolin](https://kaolin.readthedocs.io/en/latest/notes/installation.html).
59 |
60 | ```sh
61 | pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.1
62 | pip install git+https://github.com/NVIDIAGameWorks/kaolin.git
63 | ```
64 |
65 | - (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions:
66 |
67 | ```sh
68 | pip install ninja
69 | ```
70 |
71 | - Install dependencies:
72 |
73 | ```sh
74 | pip install -r requirements.txt
75 | ```
76 |
77 | ## Download SMPL-X model
78 |
79 | * Please follow the instructions at [SMPL-X](https://smpl-x.is.tue.mpg.de/) to download the required models and organize them as follows:
80 | ```
81 | ./smpl_data
82 | ├── SMPLX_NEUTRAL.pkl
83 | ├── SMPLX_FEMALE.pkl
84 | └── SMPLX_MALE.pkl
85 | ```
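
After downloading, a minimal loading check (illustrative, not part of the repo) can catch path or version problems early. It assumes the layout above and uses the `smplx` package from the requirements:

```python
# Sanity check (illustrative, not the repo's code) for the SMPL-X models
# and the provided SMPL_poses pickles.
import pickle

import smplx

# smplx.create accepts a direct path to a model file as well as a folder
model = smplx.create("./smpl_data/SMPLX_NEUTRAL.pkl", model_type="smplx")
output = model()  # forward pass with default (neutral) parameters
print(output.vertices.shape)  # (1, 10475, 3) vertices for SMPL-X

# encoding="latin1" is commonly needed for SMPL-family pickles saved under Python 2
with open("SMPL_poses/0004_smpl.pkl", "rb") as f:
    pose = pickle.load(f, encoding="latin1")
print(type(pose))
```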
86 |
87 | ## Training canonical DreamAvatar
88 |
89 | ```sh
90 | # avatar generation with 512x512 NeRF rendering, ~48GB VRAM
91 | python launch.py --config configs/dreamavatar.yaml --train --gpu 0 system.prompt_processor.prompt="Wonder Woman"
92 | # if you don't have enough VRAM, try training with 64x64 NeRF rendering
93 | python launch.py --config configs/dreamavatar.yaml --train --gpu 0 system.prompt_processor.prompt="Wonder Woman" data.width=64 data.height=64 data.batch_size=1
94 | ```
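
Any `key=value` pair appended to the command line overrides the corresponding entry in the YAML config. A minimal sketch of the mechanism (the actual loader is `threestudio/utils/config.py`; this only illustrates the OmegaConf dotlist merge it builds on):

```python
# Illustrative sketch, not the repo's code: how CLI overrides such as
# system.prompt_processor.prompt="Wonder Woman" are merged into the config.
from omegaconf import OmegaConf

base = OmegaConf.load("configs/dreamavatar.yaml")
cli = OmegaConf.from_dotlist(["system.prompt_processor.prompt=Wonder Woman"])
cfg = OmegaConf.merge(base, cli)
print(cfg.system.prompt_processor.prompt)  # -> Wonder Woman
```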
95 |
96 | ### Resume from checkpoints
97 |
98 | If you want to resume from a checkpoint, do:
99 |
100 | ```sh
101 | # resume training from the last checkpoint; you may replace last.ckpt with any other checkpoint
102 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt
103 | # if the training has completed, you can still continue training for a longer time by setting trainer.max_steps
104 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt trainer.max_steps=20000
105 | # you can also perform testing using resumed checkpoints
106 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --test --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt
107 | # note that the above commands use parsed configuration files from previous trials
108 | # which will continue using the same trial directory
109 | # if you want to save to a new trial directory, replace parsed.yaml with raw.yaml in the command
110 |
111 | # only load weights from a saved checkpoint but don't resume training (i.e. don't load the optimizer state):
112 | python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 system.weights=path/to/trial/dir/ckpts/last.ckpt
113 | ```
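
The difference between `resume=` and `system.weights=` comes down to what the checkpoint stores: Lightning checkpoints bundle optimizer state and step counters alongside the model weights. A quick way to inspect one (a sketch; the path is a placeholder):

```python
# Inspect a trial checkpoint (illustrative, not the repo's code).
import torch

ckpt = torch.load("path/to/trial/dir/ckpts/last.ckpt", map_location="cpu")
# Lightning checkpoints typically contain keys such as
# 'epoch', 'global_step', 'state_dict', and 'optimizer_states'.
print(sorted(ckpt.keys()))
print(len(ckpt["state_dict"]))  # number of saved parameter/buffer tensors
```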
114 |
115 | ## BibTeX
116 | If you want to cite our work, please use the following BibTeX entry:
117 | ```
118 | @inproceedings{cao2024dreamavatar,
119 | title={DreamAvatar: Text-and-shape guided 3D human avatar generation via diffusion models},
120 | author={Cao, Yukang and Cao, Yan-Pei and Han, Kai and Shan, Ying and Wong, Kwan-Yee~K.},
121 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
122 | pages={958--968},
123 | year={2024}
124 | }
125 | ```
126 |
127 | ## Acknowledgement
128 | Thanks to the brilliant works of [Threestudio](https://github.com/threestudio-project/threestudio) and [Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion).
129 |
--------------------------------------------------------------------------------
/SMPL_poses/0004_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0004_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0069_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0069_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0073_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0073_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0159_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0159_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0167_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0167_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0188_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0188_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0279_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0279_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0308_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0308_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0392_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0392_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0484_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0484_smpl.pkl
--------------------------------------------------------------------------------
/SMPL_poses/0501_smpl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/SMPL_poses/0501_smpl.pkl
--------------------------------------------------------------------------------
/configs/dreamavatar.yaml:
--------------------------------------------------------------------------------
1 | name: "dreamavatar"
2 | tag: "${rmspace:${system.prompt_processor.prompt},_}"
3 | exp_root_dir: "outputs"
4 | seed: 0
5 |
6 | data_type: "random-camera-datamodule"
7 | data:
8 | batch_size: [1, 1]
9 | # 0-4999: 64x64, >=5000: 512x512
10 | # this drastically reduces VRAM usage as empty space is pruned in early training
11 | width: [64, 512]
12 | height: [64, 512]
13 | resolution_milestones: [5000]
14 | camera_distance_range: [1.0, 1.5]
15 | fovy_range: [40, 70]
16 | elevation_range: [-10, 45]
17 | camera_perturb: 0.
18 | center_perturb: 0.
19 | up_perturb: 0.
20 | eval_camera_distance: 1.5
21 | eval_fovy_deg: 70.
22 |
23 | system_type: "dreamavatar-system"
24 | system:
25 | stage: coarse
26 | geometry_type: "implicit-volume"
27 | geometry:
28 | radius: 1.0
29 | normal_type: "pred"
30 |
31 | density_bias: "blob_magic3d"
32 | density_activation: softplus
33 | density_blob_scale: 10.
34 | density_blob_std: 0.5
35 |
36 | pos_encoding_config:
37 | otype: HashGrid
38 | n_levels: 16
39 | n_features_per_level: 2
40 | log2_hashmap_size: 19
41 | base_resolution: 16
42 | per_level_scale: 1.447269237440378 # max resolution 4096
43 |
44 | material_type: "no-material"
45 | material:
46 | n_output_dims: 3
47 | color_activation: sigmoid
48 |
49 | background_type: "neural-environment-map-background"
50 | background:
51 | color_activation: sigmoid
52 | random_aug: true
53 |
54 | renderer_type: "nerf-volume-renderer"
55 | renderer:
56 | radius: ${system.geometry.radius}
57 | num_samples_per_ray: 512
58 |
59 | # prompt_processor_type: "stable-diffusion-prompt-processor"
60 | # prompt_processor:
61 | # pretrained_model_name_or_path: "/root/autodl-tmp/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
62 | # prompt: ???
63 |
64 | # guidance_type: "stable-diffusion-guidance"
65 | # guidance:
66 | # pretrained_model_name_or_path: "/root/autodl-tmp/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06"
67 | # guidance_scale: 100.
68 | # weighting_strategy: sds
69 |
70 | prompt_processor_type: "stable-diffusion-prompt-processor"
71 | prompt_processor:
72 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
73 | prompt: ???
74 | front_threshold: 30.
75 | back_threshold: 30.
76 |
77 | guidance_type: "stable-diffusion-vsd-guidance"
78 | guidance:
79 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
80 | pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
81 | guidance_scale: 7.5
82 | min_step_percent: 0.02
83 | max_step_percent: [5000, 0.98, 0.5, 5001] # annealed to 0.5 after 5000 steps
84 |
85 | loggers:
86 | wandb:
87 | enable: false
88 | project: "threestudio"
89 | name: None
90 |
91 | loss:
92 | lambda_vsd: 1.
93 | lambda_lora: 1.
94 | lambda_orient: 0.
95 | lambda_sparsity: 10.
96 | lambda_opaque: [10000, 0.0, 1000.0, 10001]
97 | lambda_z_variance: 0.
98 | optimizer:
99 | name: AdamW
100 | args:
101 | betas: [0.9, 0.99]
102 | eps: 1.e-15
103 | params:
104 | geometry.encoding:
105 | lr: 0.01
106 | geometry.density_network:
107 | lr: 0.001
108 | geometry.feature_network:
109 | lr: 0.001
110 | background:
111 | lr: 0.001
112 | guidance:
113 | lr: 0.0001
114 |
115 | trainer:
116 | max_steps: 10000
117 | log_every_n_steps: 1
118 | num_sanity_val_steps: 0
119 | val_check_interval: 200
120 | enable_progress_bar: true
121 | precision: 32
122 |
123 | checkpoint:
124 | save_last: true
125 | save_top_k: -1
126 | every_n_train_steps: ${trainer.max_steps}
127 |
--------------------------------------------------------------------------------
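A note on the bracketed values in the config above (e.g. `max_step_percent: [5000, 0.98, 0.5, 5001]` and `lambda_opaque: [10000, 0.0, 1000.0, 10001]`): these are step schedules. The sketch below shows the interpretation suggested by the inline comment "annealed to 0.5 after 5000 steps", assuming a `[start_step, start_value, end_value, end_step]` convention; threestudio's own scheduling code is the reference, and this is only illustrative.

```python
# Illustrative reading of a [start_step, start_value, end_value, end_step] schedule.
def scheduled_value(schedule, step):
    start_step, start_value, end_value, end_step = schedule
    if step < start_step:
        return start_value
    if step >= end_step:
        return end_value
    t = (step - start_step) / (end_step - start_step)
    return start_value + t * (end_value - start_value)

print(scheduled_value([5000, 0.98, 0.5, 5001], 100))   # 0.98 (before the milestone)
print(scheduled_value([5000, 0.98, 0.5, 5001], 8000))  # 0.5  (after annealing)
```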
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Reference:
2 | # https://github.com/cvpaperchallenge/Ascender
3 | # https://github.com/nerfstudio-project/nerfstudio
4 |
5 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
6 |
7 | ARG USER_NAME=dreamer
8 | ARG GROUP_NAME=dreamers
9 | ARG UID=1000
10 | ARG GID=1000
11 |
12 | # Set compute capability for nerfacc and tiny-cuda-nn
13 | # See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
14 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
15 | ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
16 | # Speed-up build for RTX 30xx
17 | # ENV TORCH_CUDA_ARCH_LIST="8.6"
18 | # ENV TCNN_CUDA_ARCHITECTURES=86
19 | # Speed-up build for RTX 40xx
20 | # ENV TORCH_CUDA_ARCH_LIST="8.9"
21 | # ENV TCNN_CUDA_ARCHITECTURES=89
22 |
23 | ENV CUDA_HOME=/usr/local/cuda
24 | ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
25 | ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
26 | ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
27 |
28 | # apt install by root user
29 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
30 | build-essential \
31 | curl \
32 | git \
33 | libegl1-mesa-dev \
34 | libgl1-mesa-dev \
35 | libgles2-mesa-dev \
36 | libglib2.0-0 \
37 | libsm6 \
38 | libxext6 \
39 | libxrender1 \
40 | python-is-python3 \
41 | python3.10-dev \
42 | python3-pip \
43 | wget \
44 | && rm -rf /var/lib/apt/lists/*
45 |
46 | # Change user to non-root user
47 | RUN groupadd -g ${GID} ${GROUP_NAME} \
48 | && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME}
49 | USER ${USER_NAME}
50 |
51 | RUN pip install --upgrade pip setuptools ninja
52 | RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
53 | # Install nerfacc and tiny-cuda-nn before installing requirements.txt
54 | # because these two installations are time consuming and error prone
55 | RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
56 | RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch
57 |
58 | COPY requirements.txt /tmp
59 | RUN cd /tmp && pip install -r requirements.txt
60 | WORKDIR /home/${USER_NAME}/threestudio
61 |
--------------------------------------------------------------------------------
/docker/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | threestudio:
3 | build:
4 | context: ../
5 | dockerfile: docker/Dockerfile
6 | args:
7 | # you can set environment variables, otherwise default values will be used
8 | USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
9 | GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
10 | UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u)
11 | GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g)
12 | shm_size: '4gb'
13 | environment:
14 | NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error`
15 | tty: true
16 | volumes:
17 | - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio
18 | deploy:
19 | resources:
20 | reservations:
21 | devices:
22 | - driver: nvidia
23 | capabilities: [gpu]
24 |
--------------------------------------------------------------------------------
/docs/DreamAvatar-supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/DreamAvatar-supp.pdf
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 |
5 |
6 | .footer .icon-link {
7 | font-size: 25px;
8 | color: #000;
9 | }
10 |
11 | .link-block a {
12 | margin-top: 5px;
13 | margin-bottom: 5px;
14 | }
15 |
16 | .dnerf {
17 | font-variant: small-caps;
18 | }
19 |
20 |
21 | .teaser .hero-body {
22 | padding-top: 0;
23 | padding-bottom: 3rem;
24 | }
25 |
26 | .teaser {
27 | font-family: 'Google Sans', sans-serif;
28 | }
29 |
30 |
31 | .publication-title {
32 | }
33 |
34 | .publication-banner {
35 | max-height: parent;
36 |
37 | }
38 |
39 | .publication-banner video {
40 | position: relative;
41 | left: auto;
42 | top: auto;
43 | transform: none;
44 | object-fit: fit;
45 | }
46 |
47 | .publication-header .hero-body {
48 | }
49 |
50 | .publication-title {
51 | font-family: 'Google Sans', sans-serif;
52 | }
53 |
54 | .publication-authors {
55 | font-family: 'Google Sans', sans-serif;
56 | }
57 |
58 | .publication-venue {
59 | color: #555;
60 | width: fit-content;
61 | font-weight: bold;
62 | }
63 |
64 | .publication-awards {
65 | color: #ff3860;
66 | width: fit-content;
67 | font-weight: bolder;
68 | }
69 |
70 | .publication-authors {
71 | }
72 |
73 | .publication-authors a {
74 | color: hsl(204, 86%, 53%) !important;
75 | }
76 |
77 | .publication-authors a:hover {
78 | text-decoration: underline;
79 | }
80 |
81 | .author-block {
82 | display: inline-block;
83 | }
84 |
85 | .publication-banner img {
86 | }
87 |
88 | .publication-authors {
89 | /*color: #4286f4;*/
90 | }
91 |
92 | .publication-video {
93 | position: relative;
94 | width: 100%;
95 | height: 0;
96 | padding-bottom: 56.25%;
97 |
98 | overflow: hidden;
99 | border-radius: 10px !important;
100 | }
101 |
102 | .publication-video iframe {
103 | position: absolute;
104 | top: 0;
105 | left: 0;
106 | width: 100%;
107 | height: 100%;
108 | }
109 |
110 | .publication-body img {
111 | }
112 |
113 | .results-carousel {
114 | overflow: hidden;
115 | }
116 |
117 | .results-carousel .item {
118 | margin: 5px;
119 | overflow: hidden;
120 | border: 1px solid #bbb;
121 | border-radius: 10px;
122 | padding: 0;
123 | font-size: 0;
124 | }
125 |
126 | .results-carousel video {
127 | margin: 0;
128 | }
129 |
130 |
131 | .interpolation-panel {
132 | background: #f5f5f5;
133 | border-radius: 10px;
134 | }
135 |
136 | .interpolation-panel .interpolation-image {
137 | width: 100%;
138 | border-radius: 5px;
139 | }
140 |
141 | .interpolation-video-column {
142 | }
143 |
144 | .interpolation-panel .slider {
145 | margin: 0 !important;
146 | }
147 |
148 | .interpolation-panel .slider {
149 | margin: 0 !important;
150 | }
151 |
152 | #interpolation-image-wrapper {
153 | width: 100%;
154 | }
155 | #interpolation-image-wrapper img {
156 | border-radius: 5px;
157 | }
158 |
--------------------------------------------------------------------------------
/docs/static/gif/clown.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/clown.gif
--------------------------------------------------------------------------------
/docs/static/gif/deadpool.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/deadpool.gif
--------------------------------------------------------------------------------
/docs/static/gif/joker.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/joker.gif
--------------------------------------------------------------------------------
/docs/static/gif/link.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/gif/link.gif
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 | var INTERP_BASE = "./static/interpolation/stacked";
4 | var NUM_INTERP_FRAMES = 240;
5 |
6 | var interp_images = [];
7 | function preloadInterpolationImages() {
8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) {
9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg';
10 | interp_images[i] = new Image();
11 | interp_images[i].src = path;
12 | }
13 | }
14 |
15 | function setInterpolationImage(i) {
16 | var image = interp_images[i];
17 | image.ondragstart = function() { return false; };
18 | image.oncontextmenu = function() { return false; };
19 | $('#interpolation-image-wrapper').empty().append(image);
20 | }
21 |
22 |
23 | $(document).ready(function() {
24 | // Check for click events on the navbar burger icon
25 | $(".navbar-burger").click(function() {
26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
27 | $(".navbar-burger").toggleClass("is-active");
28 | $(".navbar-menu").toggleClass("is-active");
29 |
30 | });
31 |
32 | var options = {
33 | slidesToScroll: 1,
34 | slidesToShow: 3,
35 | loop: true,
36 | infinite: true,
37 | autoplay: false,
38 | autoplaySpeed: 3000,
39 | }
40 |
41 | // Initialize all div with carousel class
42 | var carousels = bulmaCarousel.attach('.carousel', options);
43 |
44 | // Loop on each carousel initialized
45 | for(var i = 0; i < carousels.length; i++) {
46 | // Add listener to event
47 | carousels[i].on('before:show', state => {
48 | console.log(state);
49 | });
50 | }
51 |
52 | // Access to bulmaCarousel instance of an element
53 | var element = document.querySelector('#my-element');
54 | if (element && element.bulmaCarousel) {
55 | // bulmaCarousel instance is available as element.bulmaCarousel
56 | element.bulmaCarousel.on('before-show', function(state) {
57 | console.log(state);
58 | });
59 | }
60 |
61 | /*var player = document.getElementById('interpolation-video');
62 | player.addEventListener('loadedmetadata', function() {
63 | $('#interpolation-slider').on('input', function(event) {
64 | console.log(this.value, player.duration);
65 | player.currentTime = player.duration / 100 * this.value;
66 | })
67 | }, false);*/
68 | preloadInterpolationImages();
69 |
70 | $('#interpolation-slider').on('input', function(event) {
71 | setInterpolationImage(this.value);
72 | });
73 | setInterpolationImage(0);
74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
75 |
76 | bulmaSlider.attach();
77 |
78 | })
79 |
--------------------------------------------------------------------------------
/docs/static/js/result.js:
--------------------------------------------------------------------------------
1 | const container = document.querySelector('.container');
2 | const divider = document.querySelector('.divider');
3 | const leftGif = document.querySelector('.gif:first-child');
4 | const rightGif = document.querySelector('.gif:last-child');
5 | let isDragging = false;
6 |
7 | divider.addEventListener('mousedown', function(e) {
8 | isDragging = true;
9 | });
10 |
11 | container.addEventListener('mousemove', function(e) {
12 | if (!isDragging) return;
13 | const containerRect = container.getBoundingClientRect();
14 | const mousePosition = e.clientX - containerRect.left;
15 | const containerWidth = containerRect.width;
16 | const dividerWidth = divider.offsetWidth;
17 | const minLeft = 0;
18 | const maxLeft = containerWidth - dividerWidth;
19 | const newLeft = Math.max(minLeft, Math.min(maxLeft, mousePosition));
20 | const leftWidth = newLeft / containerWidth * 100;
21 | const rightWidth = 100 - leftWidth;
22 |
23 | divider.style.left = leftWidth + '%';
24 | leftGif.style.width = leftWidth + '%';
25 | rightGif.style.left = leftWidth + '%';
26 | rightGif.style.width = rightWidth + '%';
27 | });
28 |
29 | document.addEventListener('mouseup', function(e) {
30 | isDragging = false;
31 | });
32 |
--------------------------------------------------------------------------------
/docs/static/video/Pipeline-n.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/Pipeline-n.png
--------------------------------------------------------------------------------
/docs/static/video/canonical/Alien.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Alien.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Buddhist_monk.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Buddhist_monk.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/C-3PO.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/C-3PO.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Crystal_maiden.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Crystal_maiden.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Electro.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Electro.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Flash.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Flash.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Groot.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Groot.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Joker.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Joker.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Link.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Link.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Link_2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Link_2.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Luffy.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Luffy.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Spiderman.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Spiderman.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Track_field_athlete.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Track_field_athlete.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Wonder_woman.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Wonder_woman.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Woody.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Woody.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/Woody_in_joker.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/Woody_in_joker.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/body_builder.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/body_builder.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/clown.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/clown.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/hipster_man.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/hipster_man.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/kakashi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/kakashi.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/sasuke.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/sasuke.mp4
--------------------------------------------------------------------------------
/docs/static/video/canonical/woman_hippie.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/canonical/woman_hippie.mp4
--------------------------------------------------------------------------------
/docs/static/video/integration/clown.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/clown.mp4
--------------------------------------------------------------------------------
/docs/static/video/integration/deadpool.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/deadpool.mp4
--------------------------------------------------------------------------------
/docs/static/video/integration/joker.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/joker.mp4
--------------------------------------------------------------------------------
/docs/static/video/integration/link.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/integration/link.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/Flash-00576.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/Flash-00576.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/Flash-00596.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/Flash-00596.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/groot-00084.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00084.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/groot-00230.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00230.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/groot-00308.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00308.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/groot-00350.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/groot-00350.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/joker-00296.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00296.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/joker-00510.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00510.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/joker-00530.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00530.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/joker-00536.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/joker-00536.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/spiderman-00028.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/spiderman-00028.mp4
--------------------------------------------------------------------------------
/docs/static/video/poses/spiderman-0279.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/poses/spiderman-0279.mp4
--------------------------------------------------------------------------------
/docs/static/video/shapes/groot-0-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-0-1.mp4
--------------------------------------------------------------------------------
/docs/static/video/shapes/groot-0-3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-0-3.mp4
--------------------------------------------------------------------------------
/docs/static/video/shapes/groot-1+2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-1+2.mp4
--------------------------------------------------------------------------------
/docs/static/video/shapes/groot-1-2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/shapes/groot-1-2.mp4
--------------------------------------------------------------------------------
/docs/static/video/text_manipulation/joker-black.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-black.mp4
--------------------------------------------------------------------------------
/docs/static/video/text_manipulation/joker-green.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-green.mp4
--------------------------------------------------------------------------------
/docs/static/video/text_manipulation/joker-pink.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-pink.mp4
--------------------------------------------------------------------------------
/docs/static/video/text_manipulation/joker-texudo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukangcao/DreamAvatar/4d6bdbbb75638f7fb87829ee8f2c7f6a6055c629/docs/static/video/text_manipulation/joker-texudo.mp4
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: threestudio
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.01.10=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=1.1.1t=h7f8727e_0
15 | - pip=23.0.1=py38h06a4308_0
16 | - python=3.8.16=h7a1cb2a_3
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=66.0.0=py38h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.38.4=py38h06a4308_0
22 | - xz=5.4.2=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - absl-py==1.4.0
26 | - accelerate==0.19.0
27 | - aiofiles==23.1.0
28 | - aiohttp==3.8.4
29 | - aiosignal==1.3.1
30 | - altair==5.0.1
31 | - antlr4-python3-runtime==4.9.3
32 | - anyio==3.6.2
33 | - appdirs==1.4.4
34 | - arrow==1.2.3
35 | - asttokens==2.2.1
36 | - astunparse==1.6.3
37 | - async-timeout==4.0.2
38 | - attrs==23.1.0
39 | - backcall==0.2.0
40 | - beautifulsoup4==4.12.2
41 | - bitsandbytes==0.39.0
42 | - blessed==1.20.0
43 | - blinker==1.6.2
44 | - cachetools==5.3.0
45 | - certifi==2023.5.7
46 | - chardet==5.1.0
47 | - charset-normalizer==3.1.0
48 | - chumpy==0.70
49 | - click==8.1.3
50 | - clip==1.0
51 | - colorlog==6.7.0
52 | - comm==0.1.3
53 | - contourpy==1.0.7
54 | - controlnet-aux==0.0.6
55 | - croniter==1.3.14
56 | - cycler==0.11.0
57 | - dateutils==0.6.12
58 | - debugpy==1.6.7
59 | - decorator==5.1.1
60 | - deepdiff==6.3.0
61 | - diffusers==0.16.1
62 | - distro==1.8.0
63 | - docker-pycreds==0.4.0
64 | - einops==0.6.1
65 | - entrypoints==0.4
66 | - envlight==0.1.0
67 | - executing==1.2.0
68 | - fastapi==0.88.0
69 | - ffmpy==0.3.1
70 | - filelock==3.12.0
71 | - flask==2.3.2
72 | - flatbuffers==23.5.26
73 | - fonttools==4.39.4
74 | - frozenlist==1.3.3
75 | - fsspec==2023.5.0
76 | - ftfy==6.1.1
77 | - fvcore==0.1.5.post20221221
78 | - gast==0.4.0
79 | - gitdb==4.0.10
80 | - gitpython==3.1.31
81 | - google-auth==2.18.1
82 | - google-auth-oauthlib==1.0.0
83 | - google-pasta==0.2.0
84 | - gradio==3.38.0
85 | - gradio-client==0.2.10
86 | - grpcio==1.54.2
87 | - h11==0.14.0
88 | - h5py==3.9.0
89 | - httpcore==0.17.3
90 | - httpx==0.24.1
91 | - huggingface-hub==0.14.1
92 | - idna==3.4
93 | - imageio==2.29.0
94 | - imageio-ffmpeg==0.4.8
95 | - importlib-metadata==6.6.0
96 | - importlib-resources==5.12.0
97 | - inquirer==3.1.3
98 | - iopath==0.1.10
99 | - ipycanvas==0.13.1
100 | - ipyevents==2.0.1
101 | - ipykernel==6.23.1
102 | - ipython==8.12.2
103 | - ipywidgets==8.0.6
104 | - itsdangerous==2.1.2
105 | - jaxtyping==0.2.19
106 | - jedi==0.18.2
107 | - jinja2==3.1.2
108 | - jsonschema==4.18.4
109 | - jsonschema-specifications==2023.7.1
110 | - jupyter-client==7.4.9
111 | - jupyter-core==5.3.0
112 | - jupyterlab-widgets==3.0.7
113 | - kaolin==0.14.0a0
114 | - keras==2.13.1
115 | - kiwisolver==1.4.4
116 | - kornia==0.6.12
117 | - lazy-loader==0.3
118 | - libclang==16.0.6
119 | - libigl==2.4.1
120 | - lightning==2.0.0
121 | - lightning-cloud==0.5.36
122 | - lightning-utilities==0.8.0
123 | - linkify-it-py==2.0.2
124 | - lxml==4.9.3
125 | - mapbox-earcut==1.0.1
126 | - markdown==3.4.3
127 | - markdown-it-py==2.2.0
128 | - markupsafe==2.1.2
129 | - matplotlib==3.7.1
130 | - matplotlib-inline==0.1.6
131 | - mdit-py-plugins==0.3.3
132 | - mdurl==0.1.2
133 | - mpmath==1.3.0
134 | - multidict==6.0.4
135 | - mypy-extensions==1.0.0
136 | - nerfacc==0.5.2
137 | - nest-asyncio==1.5.6
138 | - networkx==3.1
139 | - ninja==1.11.1
140 | - numpy==1.23.0
141 | - nvdiffrast==0.3.1
142 | - oauthlib==3.2.2
143 | - omegaconf==2.3.0
144 | - opencv-python==4.7.0.72
145 | - opt-einsum==3.3.0
146 | - ordered-set==4.1.0
147 | - orjson==3.9.2
148 | - packaging==23.1
149 | - pandas==2.0.3
150 | - parso==0.8.3
151 | - pathtools==0.1.2
152 | - pexpect==4.8.0
153 | - pickleshare==0.7.5
154 | - pillow==9.5.0
155 | - pkgutil-resolve-name==1.3.10
156 | - platformdirs==3.5.1
157 | - plyfile==0.9
158 | - portalocker==2.7.0
159 | - prompt-toolkit==3.0.38
160 | - protobuf==4.23.1
161 | - psutil==5.9.5
162 | - ptyprocess==0.7.0
163 | - pure-eval==0.2.2
164 | - pyasn1==0.5.0
165 | - pyasn1-modules==0.3.0
166 | - pybind11==2.10.4
167 | - pycollada==0.7.2
168 | - pydantic==1.10.8
169 | - pydub==0.25.1
170 | - pygments==2.15.1
171 | - pyjwt==2.7.0
172 | - pymcubes==0.1.4
173 | - pyparsing==3.0.9
174 | - pyre-extensions==0.0.29
175 | - pysdf==0.1.9
176 | - python-dateutil==2.8.2
177 | - python-editor==1.0.4
178 | - python-multipart==0.0.6
179 | - pytorch-lightning==2.0.2
180 | - pytorch3d==0.7.4
181 | - pytz==2023.3
182 | - pywavelets==1.4.1
183 | - pyyaml==6.0
184 | - pyzmq==24.0.1
185 | - readchar==4.0.5
186 | - referencing==0.30.0
187 | - regex==2023.5.5
188 | - requests==2.31.0
189 | - requests-oauthlib==1.3.1
190 | - rich==13.3.5
191 | - rpds-py==0.9.2
192 | - rsa==4.9
193 | - rtree==1.0.1
194 | - safetensors==0.3.1
195 | - scikit-build==0.17.5
196 | - scikit-image==0.21.0
197 | - scipy==1.10.1
198 | - semantic-version==2.10.0
199 | - sentencepiece==0.1.99
200 | - sentry-sdk==1.25.0
201 | - setproctitle==1.3.2
202 | - shapely==2.0.1
203 | - six==1.16.0
204 | - smmap==5.0.0
205 | - smplx==0.1.28
206 | - sniffio==1.3.0
207 | - soupsieve==2.4.1
208 | - stack-data==0.6.2
209 | - starlette==0.22.0
210 | - starsessions==1.3.0
211 | - svg-path==6.3
212 | - sympy==1.12
213 | - tabulate==0.9.0
214 | - taming-transformers-rom1504==0.0.6
215 | - tensorboard==2.13.0
216 | - tensorboard-data-server==0.7.0
217 | - tensorflow==2.13.0
218 | - tensorflow-estimator==2.13.0
219 | - tensorflow-io-gcs-filesystem==0.32.0
220 | - termcolor==2.3.0
221 | - tifffile==2023.7.10
222 | - timm==0.9.2
223 | - tinycudann==1.7
224 | - tokenizers==0.13.3
225 | - tomli==2.0.1
226 | - toolz==0.12.0
227 | - torch==1.12.1+cu113
228 | - torchmetrics==0.11.4
229 | - torchvision==0.13.1+cu113
230 | - tornado==6.3.2
231 | - tqdm==4.65.0
232 | - traitlets==5.9.0
233 | - transformers==4.29.2
234 | - trimesh==3.21.7
235 | - typeguard==4.0.0
236 | - typing-extensions==4.5.0
237 | - typing-inspect==0.9.0
238 | - tzdata==2023.3
239 | - uc-micro-py==1.0.2
240 | - urllib3==1.26.16
241 | - usd-core==23.5
242 | - uvicorn==0.22.0
243 | - wandb==0.15.3
244 | - wcwidth==0.2.6
245 | - websocket-client==1.5.2
246 | - websockets==11.0.3
247 | - werkzeug==2.3.4
248 | - widgetsnbextension==4.0.7
249 | - wrapt==1.15.0
250 | - xatlas==0.0.7
251 | - xformers==0.0.21+efcd789.d20230525
252 | - xxhash==3.2.0
253 | - yacs==0.1.8
254 | - yarl==1.9.2
255 | - zipp==3.15.0
256 | prefix: /data2/ykcao/anaconda3/envs/threestudio
257 |
--------------------------------------------------------------------------------
/launch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import contextlib
3 | import logging
4 | import os
5 | import sys
6 |
7 |
8 | class ColoredFilter(logging.Filter):
9 | """
10 | A logging filter to add color to certain log levels.
11 | """
12 |
13 | RESET = "\033[0m"
14 | RED = "\033[31m"
15 | GREEN = "\033[32m"
16 | YELLOW = "\033[33m"
17 | BLUE = "\033[34m"
18 | MAGENTA = "\033[35m"
19 | CYAN = "\033[36m"
20 |
21 | COLORS = {
22 | "WARNING": YELLOW,
23 | "INFO": GREEN,
24 | "DEBUG": BLUE,
25 | "CRITICAL": MAGENTA,
26 | "ERROR": RED,
27 | }
28 |
29 |
30 |
31 | def __init__(self):
32 | super().__init__()
33 |
34 | def filter(self, record):
35 | if record.levelname in self.COLORS:
36 | color_start = self.COLORS[record.levelname]
37 | record.levelname = f"{color_start}[{record.levelname}]"
38 | record.msg = f"{record.msg}{self.RESET}"
39 | return True
40 |
41 |
42 | def main(args, extras) -> None:
43 | # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
44 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
45 | env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
46 | env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
47 | selected_gpus = [0]
48 |
49 | # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
50 |     # As far as PyTorch Lightning is concerned, we always use all available GPUs
51 | # (possibly filtered by CUDA_VISIBLE_DEVICES).
52 | devices = -1
53 | if len(env_gpus) > 0:
54 | # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
55 | n_gpus = len(env_gpus)
56 | else:
57 | selected_gpus = list(args.gpu.split(","))
58 | n_gpus = len(selected_gpus)
59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
60 |
61 | import pytorch_lightning as pl
62 | import torch
63 | from pytorch_lightning import Trainer
64 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
65 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
66 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
67 |
68 | if args.typecheck:
69 | from jaxtyping import install_import_hook
70 |
71 | install_import_hook("threestudio", "typeguard.typechecked")
72 |
73 | import threestudio
74 | from threestudio.systems.base import BaseSystem
75 | from threestudio.utils.callbacks import (
76 | CodeSnapshotCallback,
77 | ConfigSnapshotCallback,
78 | CustomProgressBar,
79 | ProgressCallback,
80 | )
81 | from threestudio.utils.config import ExperimentConfig, load_config
82 | from threestudio.utils.misc import get_rank
83 | from threestudio.utils.typing import Optional
84 |
85 | logger = logging.getLogger("pytorch_lightning")
86 | if args.verbose:
87 | logger.setLevel(logging.DEBUG)
88 |
89 | for handler in logger.handlers:
90 | if handler.stream == sys.stderr: # type: ignore
91 | if not args.gradio:
92 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
93 | handler.addFilter(ColoredFilter())
94 | else:
95 | handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
96 |
97 | # parse YAML config to OmegaConf
98 | cfg: ExperimentConfig
99 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
100 |
101 | # set a different seed for each device
102 | pl.seed_everything(cfg.seed + get_rank(), workers=True)
103 |
104 | dm = threestudio.find(cfg.data_type)(cfg.data)
105 | system: BaseSystem = threestudio.find(cfg.system_type)(
106 | cfg.system, resumed=cfg.resume is not None
107 | )
108 | system.set_save_dir(os.path.join(cfg.trial_dir, "save"))
109 |
110 | if args.gradio:
111 | fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
112 | fh.setLevel(logging.INFO)
113 | if args.verbose:
114 | fh.setLevel(logging.DEBUG)
115 | fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
116 | logger.addHandler(fh)
117 |
118 | callbacks = []
119 | if args.train:
120 | callbacks += [
121 | ModelCheckpoint(
122 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
123 | ),
124 | LearningRateMonitor(logging_interval="step"),
125 | CodeSnapshotCallback(
126 | os.path.join(cfg.trial_dir, "code"), use_version=False
127 | ),
128 | ConfigSnapshotCallback(
129 | args.config,
130 | cfg,
131 | os.path.join(cfg.trial_dir, "configs"),
132 | use_version=False,
133 | ),
134 | ]
135 | if args.gradio:
136 | callbacks += [
137 | ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
138 | ]
139 | else:
140 | callbacks += [CustomProgressBar(refresh_rate=1)]
141 |
142 | def write_to_text(file, lines):
143 | with open(file, "w") as f:
144 | for line in lines:
145 | f.write(line + "\n")
146 |
147 | loggers = []
148 | if args.train:
149 | # make tensorboard logging dir to suppress warning
150 | rank_zero_only(
151 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
152 | )()
153 | loggers += [
154 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
155 | CSVLogger(cfg.trial_dir, name="csv_logs"),
156 | ] + system.get_loggers()
157 | rank_zero_only(
158 | lambda: write_to_text(
159 | os.path.join(cfg.trial_dir, "cmd.txt"),
160 | ["python " + " ".join(sys.argv), str(args)],
161 | )
162 | )()
163 |
164 | trainer = Trainer(
165 | callbacks=callbacks,
166 | logger=loggers,
167 | inference_mode=False,
168 | accelerator="gpu",
169 | devices=devices,
170 | **cfg.trainer,
171 | )
172 |
173 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
174 | if ckpt_path is None:
175 | return
176 | ckpt = torch.load(ckpt_path, map_location="cpu")
177 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"])
178 |
179 | if args.train:
180 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
181 | trainer.test(system, datamodule=dm)
182 | if args.gradio:
183 | # also export assets if in gradio mode
184 | trainer.predict(system, datamodule=dm)
185 | elif args.validate:
186 | # manually set epoch and global_step as they cannot be automatically resumed
187 | set_system_status(system, cfg.resume)
188 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
189 | elif args.test:
190 | # manually set epoch and global_step as they cannot be automatically resumed
191 | set_system_status(system, cfg.resume)
192 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
193 | elif args.export:
194 | set_system_status(system, cfg.resume)
195 | trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
196 |
197 |
198 | if __name__ == "__main__":
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument("--config", required=True, help="path to config file")
201 | parser.add_argument(
202 | "--gpu",
203 | default="0",
204 | help="GPU(s) to be used. 0 means use the 1st available GPU. "
205 |         "1,2 means use the 2nd and 3rd available GPUs. "
206 | "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
207 | "this argument is ignored and all available GPUs are always used.",
208 | )
209 |
210 | group = parser.add_mutually_exclusive_group(required=True)
211 | group.add_argument("--train", action="store_true")
212 | group.add_argument("--validate", action="store_true")
213 | group.add_argument("--test", action="store_true")
214 | group.add_argument("--export", action="store_true")
215 |
216 | parser.add_argument(
217 | "--gradio", action="store_true", help="if true, run in gradio mode"
218 | )
219 |
220 | parser.add_argument(
221 | "--verbose", action="store_true", help="if true, set logging level to DEBUG"
222 | )
223 |
224 | parser.add_argument(
225 | "--typecheck",
226 | action="store_true",
227 | help="whether to enable dynamic type checking",
228 | )
229 |
230 | args, extras = parser.parse_known_args()
231 |
232 | if args.gradio:
233 | # FIXME: no effect, stdout is not captured
234 | with contextlib.redirect_stdout(sys.stderr):
235 | main(args, extras)
236 | else:
237 | main(args, extras)
238 |
--------------------------------------------------------------------------------
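
A note on usage: `launch.py` keeps only the flags it declares (`--config`, `--gpu`, one of `--train`/`--validate`/`--test`/`--export`, plus `--gradio`, `--verbose`, `--typecheck`) and forwards everything else to `load_config` as OmegaConf-style overrides; if `CUDA_VISIBLE_DEVICES` is already set, `--gpu` is ignored. A minimal sketch of that argument split (the `system.prompt` override key is a hypothetical example, not taken from the shipped config):

```python
# Sketch: how launch.py separates its own flags from config overrides.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True)
parser.add_argument("--gpu", default="0")
parser.add_argument("--train", action="store_true")

argv = ["--config", "configs/dreamavatar.yaml", "--train", "--gpu", "0",
        "system.prompt=a clown"]  # hypothetical override key
args, extras = parser.parse_known_args(argv)
print(args.config)  # configs/dreamavatar.yaml
print(extras)       # ['system.prompt=a clown'] -> load_config(cli_args=extras)
```
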
/requirements.txt:
--------------------------------------------------------------------------------
1 | lightning==2.0.0
2 | omegaconf==2.3.0
3 | jaxtyping
4 | typeguard
5 | git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
6 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
7 | diffusers
8 | transformers
9 | accelerate
10 | opencv-python
11 | tensorboard
12 | matplotlib
13 | imageio>=2.28.0
14 | imageio[ffmpeg]
15 | git+https://github.com/NVlabs/nvdiffrast.git
16 | libigl
17 | xatlas
18 | trimesh[easy]
19 | networkx
20 | pysdf
21 | PyMCubes
22 | wandb
23 | gradio
24 | git+https://github.com/ashawkey/envlight.git
25 | torchmetrics
26 |
27 | # deepfloyd
28 | xformers
29 | bitsandbytes
30 | sentencepiece
31 | safetensors
32 | huggingface_hub
33 |
34 | # for zero123
35 | einops
36 | kornia
37 | taming-transformers-rom1504
38 | git+https://github.com/openai/CLIP.git
39 |
40 | # controlnet
41 | controlnet_aux
42 |
--------------------------------------------------------------------------------
/threestudio/__init__.py:
--------------------------------------------------------------------------------
1 | __modules__ = {}
2 |
3 |
4 | def register(name):
5 | def decorator(cls):
6 | __modules__[name] = cls
7 | return cls
8 |
9 | return decorator
10 |
11 |
12 | def find(name):
13 | return __modules__[name]
14 |
15 |
16 | ### syntactic sugar for logging utilities ###
17 | import logging
18 |
19 | logger = logging.getLogger("pytorch_lightning")
20 |
21 | from pytorch_lightning.utilities.rank_zero import (
22 | rank_zero_debug,
23 | rank_zero_info,
24 | rank_zero_only,
25 | )
26 |
27 | debug = rank_zero_debug
28 | info = rank_zero_info
29 |
30 |
31 | @rank_zero_only
32 | def warn(*args, **kwargs):
33 |     logger.warning(*args, **kwargs)
34 |
35 |
36 | from . import data, models, systems
37 |
--------------------------------------------------------------------------------
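
The `register`/`find` pair above is the plug-in mechanism the whole codebase builds on: modules decorate their classes with a string key, and `launch.py` resolves `cfg.system_type` / `cfg.data_type` back to classes via `threestudio.find`. A self-contained sketch of the same pattern (`"toy-module"` is a hypothetical key for illustration):

```python
# Sketch of the registry pattern defined in threestudio/__init__.py.
__modules__ = {}

def register(name):
    def decorator(cls):
        __modules__[name] = cls
        return cls
    return decorator

def find(name):
    return __modules__[name]

@register("toy-module")  # hypothetical registry key
class ToyModule:
    pass

assert find("toy-module") is ToyModule
```
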
/threestudio/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import co3d, image, multiview, uncond
2 |
--------------------------------------------------------------------------------
/threestudio/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | background,
3 | exporters,
4 | geometry,
5 | guidance,
6 | materials,
7 | prompt_processors,
8 | renderers,
9 | )
10 |
--------------------------------------------------------------------------------
/threestudio/models/background/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base,
3 | neural_environment_map_background,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/background/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import threestudio
9 | from threestudio.utils.base import BaseModule
10 | from threestudio.utils.typing import *
11 |
12 |
13 | class BaseBackground(BaseModule):
14 | @dataclass
15 | class Config(BaseModule.Config):
16 | pass
17 |
18 | cfg: Config
19 |
20 | def configure(self):
21 | pass
22 |
23 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/threestudio/models/background/neural_environment_map_background.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import threestudio
9 | from threestudio.models.background.base import BaseBackground
10 | from threestudio.models.networks import get_encoding, get_mlp
11 | from threestudio.utils.ops import get_activation
12 | from threestudio.utils.typing import *
13 |
14 |
15 | @threestudio.register("neural-environment-map-background")
16 | class NeuralEnvironmentMapBackground(BaseBackground):
17 | @dataclass
18 | class Config(BaseBackground.Config):
19 | n_output_dims: int = 3
20 | color_activation: str = "sigmoid"
21 | dir_encoding_config: dict = field(
22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3}
23 | )
24 | mlp_network_config: dict = field(
25 | default_factory=lambda: {
26 | "otype": "VanillaMLP",
27 | "activation": "ReLU",
28 | "n_neurons": 16,
29 | "n_hidden_layers": 2,
30 | }
31 | )
32 | random_aug: bool = False
33 | random_aug_prob: float = 0.5
34 | eval_color: Optional[Tuple[float, float, float]] = None
35 |
36 | cfg: Config
37 |
38 | def configure(self) -> None:
39 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config)
40 | self.network = get_mlp(
41 | self.encoding.n_output_dims,
42 | self.cfg.n_output_dims,
43 | self.cfg.mlp_network_config,
44 | )
45 |
46 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
47 | if not self.training and self.cfg.eval_color is not None:
48 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
49 | dirs
50 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs)
51 | # viewdirs must be normalized before passing to this function
52 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1)
53 | dirs_embd = self.encoding(dirs.view(-1, 3))
54 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims)
55 | color = get_activation(self.cfg.color_activation)(color)
56 | if (
57 | self.training
58 | and self.cfg.random_aug
59 | and random.random() < self.cfg.random_aug_prob
60 | ):
61 | # use random background color with probability random_aug_prob
62 | color = color * 0 + ( # prevent checking for unused parameters in DDP
63 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims)
64 | .to(dirs)
65 | .expand(*dirs.shape[:-1], -1)
66 | )
67 | return color
68 |
--------------------------------------------------------------------------------
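
The background's forward pass assumes unit view directions, remaps them from (-1, 1) to (0, 1), encodes them (spherical harmonics by default), and decodes a color through a small MLP followed by a sigmoid. A hedged sketch with a plain `torch` MLP standing in for the tiny-cuda-nn encoding and `VanillaMLP`:

```python
# Sketch: direction-conditioned background color; a torch MLP stands in
# for the SphericalHarmonics encoding + tiny-cuda-nn network.
import torch
import torch.nn as nn
import torch.nn.functional as F

mlp = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 3))
dirs = F.normalize(torch.randn(1, 4, 4, 3), dim=-1)  # unit view directions
x = (dirs + 1.0) / 2.0                               # (-1, 1) -> (0, 1)
color = torch.sigmoid(mlp(x.view(-1, 3))).view(*dirs.shape[:-1], 3)
print(color.shape)  # torch.Size([1, 4, 4, 3])
```
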
/threestudio/models/exporters/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base, mesh_exporter
2 |
--------------------------------------------------------------------------------
/threestudio/models/exporters/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import threestudio
4 | from threestudio.models.background.base import BaseBackground
5 | from threestudio.models.geometry.base import BaseImplicitGeometry
6 | from threestudio.models.materials.base import BaseMaterial
7 | from threestudio.utils.base import BaseObject
8 | from threestudio.utils.typing import *
9 |
10 |
11 | @dataclass
12 | class ExporterOutput:
13 | save_name: str
14 | save_type: str
15 | params: Dict[str, Any]
16 |
17 |
18 | class Exporter(BaseObject):
19 | @dataclass
20 | class Config(BaseObject.Config):
21 | save_video: bool = False
22 |
23 | cfg: Config
24 |
25 | def configure(
26 | self,
27 | geometry: BaseImplicitGeometry,
28 | material: BaseMaterial,
29 | background: BaseBackground,
30 | ) -> None:
31 | @dataclass
32 | class SubModules:
33 | geometry: BaseImplicitGeometry
34 | material: BaseMaterial
35 | background: BaseBackground
36 |
37 | self.sub_modules = SubModules(geometry, material, background)
38 |
39 | @property
40 | def geometry(self) -> BaseImplicitGeometry:
41 | return self.sub_modules.geometry
42 |
43 | @property
44 | def material(self) -> BaseMaterial:
45 | return self.sub_modules.material
46 |
47 | @property
48 | def background(self) -> BaseBackground:
49 | return self.sub_modules.background
50 |
51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
52 | raise NotImplementedError
53 |
54 |
55 | @threestudio.register("dummy-exporter")
56 | class DummyExporter(Exporter):
57 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
58 | # DummyExporter does not export anything
59 | return []
60 |
--------------------------------------------------------------------------------
/threestudio/models/exporters/mesh_exporter.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | import threestudio
8 | from threestudio.models.background.base import BaseBackground
9 | from threestudio.models.exporters.base import Exporter, ExporterOutput
10 | from threestudio.models.geometry.base import BaseImplicitGeometry
11 | from threestudio.models.materials.base import BaseMaterial
12 | from threestudio.models.mesh import Mesh
13 | from threestudio.utils.rasterize import NVDiffRasterizerContext
14 | from threestudio.utils.typing import *
15 |
16 |
17 | @threestudio.register("mesh-exporter")
18 | class MeshExporter(Exporter):
19 | @dataclass
20 | class Config(Exporter.Config):
21 | fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx
22 | save_name: str = "model"
23 | save_normal: bool = False
24 | save_uv: bool = True
25 | save_texture: bool = True
26 | texture_size: int = 1024
27 | texture_format: str = "jpg"
28 | xatlas_chart_options: dict = field(default_factory=dict)
29 | xatlas_pack_options: dict = field(default_factory=dict)
30 | context_type: str = "gl"
31 |
32 | cfg: Config
33 |
34 | def configure(
35 | self,
36 | geometry: BaseImplicitGeometry,
37 | material: BaseMaterial,
38 | background: BaseBackground,
39 | ) -> None:
40 | super().configure(geometry, material, background)
41 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
42 |
43 | def __call__(self) -> List[ExporterOutput]:
44 | mesh: Mesh = self.geometry.isosurface()
45 |
46 | if self.cfg.fmt == "obj-mtl":
47 | return self.export_obj_with_mtl(mesh)
48 | elif self.cfg.fmt == "obj":
49 | return self.export_obj(mesh)
50 | else:
51 | raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}")
52 |
53 | def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]:
54 | params = {
55 | "mesh": mesh,
56 | "save_mat": True,
57 | "save_normal": self.cfg.save_normal,
58 | "save_uv": self.cfg.save_uv,
59 | "save_vertex_color": False,
60 | "map_Kd": None, # Base Color
61 | "map_Ks": None, # Specular
62 | "map_Bump": None, # Normal
63 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
64 | "map_Pm": None, # Metallic
65 | "map_Pr": None, # Roughness
66 | "map_format": self.cfg.texture_format,
67 | }
68 |
69 | if self.cfg.save_uv:
70 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
71 |
72 | if self.cfg.save_texture:
73 | threestudio.info("Exporting textures ...")
74 | assert self.cfg.save_uv, "save_uv must be True when save_texture is True"
75 | # clip space transform
76 | uv_clip = mesh.v_tex * 2.0 - 1.0
77 | # pad to four component coordinate
78 | uv_clip4 = torch.cat(
79 | (
80 | uv_clip,
81 | torch.zeros_like(uv_clip[..., 0:1]),
82 | torch.ones_like(uv_clip[..., 0:1]),
83 | ),
84 | dim=-1,
85 | )
86 | # rasterize
87 | rast, _ = self.ctx.rasterize_one(
88 | uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size)
89 | )
90 |
91 | hole_mask = ~(rast[:, :, 3] > 0)
92 |
93 | def uv_padding(image):
94 | uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2)
95 | inpaint_image = (
96 | cv2.inpaint(
97 | (image.detach().cpu().numpy() * 255).astype(np.uint8),
98 | (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
99 | uv_padding_size,
100 | cv2.INPAINT_TELEA,
101 | )
102 | / 255.0
103 | )
104 | return torch.from_numpy(inpaint_image).to(image)
105 |
106 | # Interpolate world space position
107 | gb_pos, _ = self.ctx.interpolate_one(
108 | mesh.v_pos, rast[None, ...], mesh.t_pos_idx
109 | )
110 | gb_pos = gb_pos[0]
111 |
112 | # Sample out textures from MLP
113 | geo_out = self.geometry.export(points=gb_pos)
114 | mat_out = self.material.export(points=gb_pos, **geo_out)
115 |
116 | threestudio.info(
117 |                 "Performing UV padding on texture maps to avoid seams; this may take a while ..."
118 | )
119 |
120 | if "albedo" in mat_out:
121 | params["map_Kd"] = uv_padding(mat_out["albedo"])
122 | else:
123 | threestudio.warn(
124 | "save_texture is True but no albedo texture found, using default white texture"
125 | )
126 | if "metallic" in mat_out:
127 | params["map_Pm"] = uv_padding(mat_out["metallic"])
128 | if "roughness" in mat_out:
129 | params["map_Pr"] = uv_padding(mat_out["roughness"])
130 | if "bump" in mat_out:
131 | params["map_Bump"] = uv_padding(mat_out["bump"])
132 | # TODO: map_Ks
133 | return [
134 | ExporterOutput(
135 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
136 | )
137 | ]
138 |
139 | def export_obj(self, mesh: Mesh) -> List[ExporterOutput]:
140 | params = {
141 | "mesh": mesh,
142 | "save_mat": False,
143 | "save_normal": self.cfg.save_normal,
144 | "save_uv": self.cfg.save_uv,
145 | "save_vertex_color": False,
146 | "map_Kd": None, # Base Color
147 | "map_Ks": None, # Specular
148 | "map_Bump": None, # Normal
149 | # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
150 | "map_Pm": None, # Metallic
151 | "map_Pr": None, # Roughness
152 | "map_format": self.cfg.texture_format,
153 | }
154 |
155 | if self.cfg.save_uv:
156 | mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
157 |
158 | if self.cfg.save_texture:
159 | threestudio.info("Exporting textures ...")
160 | geo_out = self.geometry.export(points=mesh.v_pos)
161 | mat_out = self.material.export(points=mesh.v_pos, **geo_out)
162 |
163 | if "albedo" in mat_out:
164 | mesh.set_vertex_color(mat_out["albedo"])
165 | params["save_vertex_color"] = True
166 | else:
167 | threestudio.warn(
168 | "save_texture is True but no albedo texture found, not saving vertex color"
169 | )
170 |
171 | return [
172 | ExporterOutput(
173 | save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
174 | )
175 | ]
176 |
--------------------------------------------------------------------------------
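
The `uv_padding` helper above exists because texels just outside a UV chart would otherwise keep the background color and bleed into the texture under bilinear sampling at chart seams. A standalone sketch of the same `cv2.inpaint` call on a toy texture:

```python
# Sketch: inpaint texels outside the rasterized UV charts (the hole mask).
import cv2
import numpy as np

texture = np.zeros((64, 64, 3), dtype=np.uint8)
texture[16:48, 16:48] = (255, 128, 0)           # a filled "chart"
hole_mask = np.full((64, 64), 255, dtype=np.uint8)
hole_mask[16:48, 16:48] = 0                     # holes = everything outside
padded = cv2.inpaint(texture, hole_mask, 2, cv2.INPAINT_TELEA)
```
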
/threestudio/models/geometry/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base,
3 | implicit_volume,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/geometry/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import threestudio
9 | from threestudio.models.isosurface import (
10 | IsosurfaceHelper,
11 | MarchingCubeCPUHelper,
12 | MarchingTetrahedraHelper,
13 | )
14 | from threestudio.models.mesh import Mesh
15 | from threestudio.utils.base import BaseModule
16 | from threestudio.utils.ops import chunk_batch, scale_tensor
17 | from threestudio.utils.typing import *
18 |
19 |
20 | def contract_to_unisphere(
21 | x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False
22 | ) -> Float[Tensor, "... 3"]:
23 | if unbounded:
24 | x = scale_tensor(x, bbox, (0, 1))
25 | x = x * 2 - 1 # aabb is at [-1, 1]
26 | mag = x.norm(dim=-1, keepdim=True)
27 | mask = mag.squeeze(-1) > 1
28 | x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
29 | x = x / 4 + 0.5 # [-inf, inf] is at [0, 1]
30 | else:
31 | x = scale_tensor(x, bbox, (0, 1))
32 | return x
33 |
34 |
35 | class BaseGeometry(BaseModule):
36 | @dataclass
37 | class Config(BaseModule.Config):
38 | pass
39 |
40 | cfg: Config
41 |
42 | @staticmethod
43 | def create_from(
44 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
45 | ) -> "BaseGeometry":
46 | raise TypeError(
47 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
48 | )
49 |
50 | def export(self, *args, **kwargs) -> Dict[str, Any]:
51 | return {}
52 |
53 |
54 | class BaseImplicitGeometry(BaseGeometry):
55 | @dataclass
56 | class Config(BaseGeometry.Config):
57 | radius: float = 1.0
58 | isosurface: bool = True
59 | isosurface_method: str = "mt"
60 | isosurface_resolution: int = 128
61 | isosurface_threshold: Union[float, str] = 0.0
62 | isosurface_chunk: int = 0
63 | isosurface_coarse_to_fine: bool = True
64 | isosurface_deformable_grid: bool = False
65 | isosurface_remove_outliers: bool = True
66 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
67 |
68 | cfg: Config
69 |
70 | def configure(self) -> None:
71 | self.bbox: Float[Tensor, "2 3"]
72 | self.register_buffer(
73 | "bbox",
74 | torch.as_tensor(
75 | [
76 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
77 | [self.cfg.radius, self.cfg.radius, self.cfg.radius],
78 | ],
79 | dtype=torch.float32,
80 | ),
81 | )
82 | self.isosurface_helper: Optional[IsosurfaceHelper] = None
83 | self.unbounded: bool = False
84 |
85 |     def _initialize_isosurface_helper(self):
86 | if self.cfg.isosurface and self.isosurface_helper is None:
87 | if self.cfg.isosurface_method == "mc-cpu":
88 | self.isosurface_helper = MarchingCubeCPUHelper(
89 | self.cfg.isosurface_resolution
90 | ).to(self.device)
91 | elif self.cfg.isosurface_method == "mt":
92 | self.isosurface_helper = MarchingTetrahedraHelper(
93 | self.cfg.isosurface_resolution,
94 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
95 | ).to(self.device)
96 | else:
97 | raise AttributeError(
98 |                     f"Unknown isosurface method {self.cfg.isosurface_method}"
99 | )
100 |
101 | def forward(
102 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False
103 | ) -> Dict[str, Float[Tensor, "..."]]:
104 | raise NotImplementedError
105 |
106 | def forward_field(
107 | self, points: Float[Tensor, "*N Di"]
108 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
109 | # return the value of the implicit field, could be density / signed distance
110 | # also return a deformation field if the grid vertices can be optimized
111 | raise NotImplementedError
112 |
113 | def forward_level(
114 | self, field: Float[Tensor, "*N 1"], threshold: float
115 | ) -> Float[Tensor, "*N 1"]:
116 | # return the value of the implicit field, where the zero level set represents the surface
117 | raise NotImplementedError
118 |
119 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
120 | def batch_func(x):
121 | # scale to bbox as the input vertices are in [0, 1]
122 | field, deformation = self.forward_field(
123 | scale_tensor(
124 | x.to(bbox.device), self.isosurface_helper.points_range, bbox
125 | ),
126 | )
127 | field = field.to(
128 | x.device
129 | ) # move to the same device as the input (could be CPU)
130 | if deformation is not None:
131 | deformation = deformation.to(x.device)
132 | return field, deformation
133 |
134 | assert self.isosurface_helper is not None
135 |
136 | field, deformation = chunk_batch(
137 | batch_func,
138 | self.cfg.isosurface_chunk,
139 | self.isosurface_helper.grid_vertices,
140 | )
141 |
142 | threshold: float
143 |
144 | if isinstance(self.cfg.isosurface_threshold, float):
145 | threshold = self.cfg.isosurface_threshold
146 | elif self.cfg.isosurface_threshold == "auto":
147 | eps = 1.0e-5
148 | threshold = field[field > eps].mean().item()
149 | threestudio.info(
150 | f"Automatically determined isosurface threshold: {threshold}"
151 | )
152 | else:
153 | raise TypeError(
154 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
155 | )
156 |
157 | level = self.forward_level(field, threshold)
158 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
159 | mesh.v_pos = scale_tensor(
160 | mesh.v_pos, self.isosurface_helper.points_range, bbox
161 | ) # scale to bbox as the grid vertices are in [0, 1]
162 | mesh.add_extra("bbox", bbox)
163 |
164 | if self.cfg.isosurface_remove_outliers:
165 |             # remove outlier components with a small number of faces
166 | # only enabled when the mesh is not differentiable
167 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
168 |
169 | return mesh
170 |
171 | def isosurface(self) -> Mesh:
172 | if not self.cfg.isosurface:
173 | raise NotImplementedError(
174 | "Isosurface is not enabled in the current configuration"
175 | )
176 |         self._initialize_isosurface_helper()
177 | if self.cfg.isosurface_coarse_to_fine:
178 | threestudio.debug("First run isosurface to get a tight bounding box ...")
179 | with torch.no_grad():
180 | mesh_coarse = self._isosurface(self.bbox)
181 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
182 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
183 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
184 | threestudio.debug("Run isosurface again with the tight bounding box ...")
185 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
186 | else:
187 | mesh = self._isosurface(self.bbox)
188 | return mesh
189 |
190 |
191 | class BaseExplicitGeometry(BaseGeometry):
192 | @dataclass
193 | class Config(BaseGeometry.Config):
194 | radius: float = 1.0
195 |
196 | cfg: Config
197 |
198 | def configure(self) -> None:
199 | self.bbox: Float[Tensor, "2 3"]
200 | self.register_buffer(
201 | "bbox",
202 | torch.as_tensor(
203 | [
204 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
205 | [self.cfg.radius, self.cfg.radius, self.cfg.radius],
206 | ],
207 | dtype=torch.float32,
208 | ),
209 | )
210 |
--------------------------------------------------------------------------------
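
The coarse-to-fine branch of `isosurface()` extracts a mesh over the full bounding box, then re-runs extraction over the coarse mesh's extent padded by 10% and clamped to the original box, which spends the fixed grid resolution where the surface actually is. The tightening step in isolation:

```python
# Sketch: tighten the bbox around coarse-mesh vertices with 10% padding.
import torch

bbox = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
v_pos = torch.rand(100, 3) * 0.6 - 0.3          # toy coarse-mesh vertices
vmin, vmax = v_pos.amin(dim=0), v_pos.amax(dim=0)
vmin_ = (vmin - (vmax - vmin) * 0.1).max(bbox[0])  # pad, clamp to bbox min
vmax_ = (vmax + (vmax - vmin) * 0.1).min(bbox[1])  # pad, clamp to bbox max
tight_bbox = torch.stack([vmin_, vmax_], dim=0)    # fed to the fine pass
```
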
/threestudio/models/guidance/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | stable_diffusion_guidance,
3 | stable_diffusion_vsd_guidance,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/isosurface.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import threestudio
7 | from threestudio.models.mesh import Mesh
8 | from threestudio.utils.typing import *
9 |
10 |
11 | class IsosurfaceHelper(nn.Module):
12 | points_range: Tuple[float, float] = (0, 1)
13 |
14 | @property
15 | def grid_vertices(self) -> Float[Tensor, "N 3"]:
16 | raise NotImplementedError
17 |
18 |
19 | class MarchingCubeCPUHelper(IsosurfaceHelper):
20 | def __init__(self, resolution: int) -> None:
21 | super().__init__()
22 | self.resolution = resolution
23 | import mcubes
24 |
25 | self.mc_func: Callable = mcubes.marching_cubes
26 | self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
27 | self._dummy: Float[Tensor, "..."]
28 | self.register_buffer(
29 | "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
30 | )
31 |
32 | @property
33 | def grid_vertices(self) -> Float[Tensor, "N3 3"]:
34 | if self._grid_vertices is None:
35 | # keep the vertices on CPU so that we can support very large resolution
36 | x, y, z = (
37 | torch.linspace(*self.points_range, self.resolution),
38 | torch.linspace(*self.points_range, self.resolution),
39 | torch.linspace(*self.points_range, self.resolution),
40 | )
41 | x, y, z = torch.meshgrid(x, y, z, indexing="ij")
42 | verts = torch.cat(
43 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
44 | ).reshape(-1, 3)
45 | self._grid_vertices = verts
46 | return self._grid_vertices
47 |
48 | def forward(
49 | self,
50 | level: Float[Tensor, "N3 1"],
51 | deformation: Optional[Float[Tensor, "N3 3"]] = None,
52 | ) -> Mesh:
53 | if deformation is not None:
54 | threestudio.warn(
55 | f"{self.__class__.__name__} does not support deformation. Ignoring."
56 | )
57 | level = -level.view(self.resolution, self.resolution, self.resolution)
58 | v_pos, t_pos_idx = self.mc_func(
59 | level.detach().cpu().numpy(), 0.0
60 | ) # transform to numpy
61 | v_pos, t_pos_idx = (
62 | torch.from_numpy(v_pos).float().to(self._dummy.device),
63 | torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device),
64 | ) # transform back to torch tensor on CUDA
65 | v_pos = v_pos / (self.resolution - 1.0)
66 | return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
67 |
68 |
69 | class MarchingTetrahedraHelper(IsosurfaceHelper):
70 | def __init__(self, resolution: int, tets_path: str):
71 | super().__init__()
72 | self.resolution = resolution
73 | self.tets_path = tets_path
74 |
75 | self.triangle_table: Float[Tensor, "..."]
76 | self.register_buffer(
77 | "triangle_table",
78 | torch.as_tensor(
79 | [
80 | [-1, -1, -1, -1, -1, -1],
81 | [1, 0, 2, -1, -1, -1],
82 | [4, 0, 3, -1, -1, -1],
83 | [1, 4, 2, 1, 3, 4],
84 | [3, 1, 5, -1, -1, -1],
85 | [2, 3, 0, 2, 5, 3],
86 | [1, 4, 0, 1, 5, 4],
87 | [4, 2, 5, -1, -1, -1],
88 | [4, 5, 2, -1, -1, -1],
89 | [4, 1, 0, 4, 5, 1],
90 | [3, 2, 0, 3, 5, 2],
91 | [1, 3, 5, -1, -1, -1],
92 | [4, 1, 2, 4, 3, 1],
93 | [3, 0, 4, -1, -1, -1],
94 | [2, 0, 1, -1, -1, -1],
95 | [-1, -1, -1, -1, -1, -1],
96 | ],
97 | dtype=torch.long,
98 | ),
99 | persistent=False,
100 | )
101 | self.num_triangles_table: Integer[Tensor, "..."]
102 | self.register_buffer(
103 | "num_triangles_table",
104 | torch.as_tensor(
105 | [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
106 | ),
107 | persistent=False,
108 | )
109 | self.base_tet_edges: Integer[Tensor, "..."]
110 | self.register_buffer(
111 | "base_tet_edges",
112 | torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
113 | persistent=False,
114 | )
115 |
116 | tets = np.load(self.tets_path)
117 | self._grid_vertices: Float[Tensor, "..."]
118 | self.register_buffer(
119 | "_grid_vertices",
120 | torch.from_numpy(tets["vertices"]).float(),
121 | persistent=False,
122 | )
123 | self.indices: Integer[Tensor, "..."]
124 | self.register_buffer(
125 | "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
126 | )
127 |
128 | self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
129 |
130 | def normalize_grid_deformation(
131 | self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
132 | ) -> Float[Tensor, "Nv 3"]:
133 | return (
134 | (self.points_range[1] - self.points_range[0])
135 | / (self.resolution) # half tet size is approximately 1 / self.resolution
136 | * torch.tanh(grid_vertex_offsets)
137 | ) # FIXME: hard-coded activation
138 |
139 | @property
140 | def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
141 | return self._grid_vertices
142 |
143 | @property
144 | def all_edges(self) -> Integer[Tensor, "Ne 2"]:
145 | if self._all_edges is None:
146 |             # compute edges on GPU, otherwise it is very slow (mainly due to the unique operation)
147 | edges = torch.tensor(
148 | [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
149 | dtype=torch.long,
150 | device=self.indices.device,
151 | )
152 | _all_edges = self.indices[:, edges].reshape(-1, 2)
153 | _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
154 | _all_edges = torch.unique(_all_edges_sorted, dim=0)
155 | self._all_edges = _all_edges
156 | return self._all_edges
157 |
158 | def sort_edges(self, edges_ex2):
159 | with torch.no_grad():
160 | order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
161 | order = order.unsqueeze(dim=1)
162 |
163 | a = torch.gather(input=edges_ex2, index=order, dim=1)
164 | b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
165 |
166 | return torch.stack([a, b], -1)
167 |
168 | def _forward(self, pos_nx3, sdf_n, tet_fx4):
169 | with torch.no_grad():
170 | occ_n = sdf_n > 0
171 | occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
172 | occ_sum = torch.sum(occ_fx4, -1)
173 | valid_tets = (occ_sum > 0) & (occ_sum < 4)
174 | occ_sum = occ_sum[valid_tets]
175 |
176 | # find all vertices
177 | all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
178 | all_edges = self.sort_edges(all_edges)
179 | unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
180 |
181 | unique_edges = unique_edges.long()
182 | mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
183 | mapping = (
184 | torch.ones(
185 | (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
186 | )
187 | * -1
188 | )
189 | mapping[mask_edges] = torch.arange(
190 | mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
191 | )
192 | idx_map = mapping[idx_map] # map edges to verts
193 |
194 | interp_v = unique_edges[mask_edges]
195 | edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
196 | edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
197 | edges_to_interp_sdf[:, -1] *= -1
198 |
199 | denominator = edges_to_interp_sdf.sum(1, keepdim=True)
200 |
201 | edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
202 | verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
203 |
204 | idx_map = idx_map.reshape(-1, 6)
205 |
206 | v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
207 | tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
208 | num_triangles = self.num_triangles_table[tetindex]
209 |
210 | # Generate triangle indices
211 | faces = torch.cat(
212 | (
213 | torch.gather(
214 | input=idx_map[num_triangles == 1],
215 | dim=1,
216 | index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
217 | ).reshape(-1, 3),
218 | torch.gather(
219 | input=idx_map[num_triangles == 2],
220 | dim=1,
221 | index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
222 | ).reshape(-1, 3),
223 | ),
224 | dim=0,
225 | )
226 |
227 | return verts, faces
228 |
229 | def forward(
230 | self,
231 | level: Float[Tensor, "N3 1"],
232 | deformation: Optional[Float[Tensor, "N3 3"]] = None,
233 | ) -> Mesh:
234 | if deformation is not None:
235 | grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
236 | deformation
237 | )
238 | else:
239 | grid_vertices = self.grid_vertices
240 |
241 | v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
242 |
243 | mesh = Mesh(
244 | v_pos=v_pos,
245 | t_pos_idx=t_pos_idx,
246 | # extras
247 | grid_vertices=grid_vertices,
248 | tet_edges=self.all_edges,
249 | grid_level=level,
250 | grid_deformation=deformation,
251 | )
252 |
253 | return mesh
254 |
--------------------------------------------------------------------------------
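
`MarchingCubeCPUHelper.forward` negates the level field, runs `mcubes.marching_cubes` at the zero crossing on the CPU, and rescales vertices from grid units back to [0, 1]. The same steps on a toy sphere SDF:

```python
# Sketch: what MarchingCubeCPUHelper.forward does, on a 32^3 sphere SDF.
import numpy as np
import mcubes

res = 32
xs = np.linspace(0.0, 1.0, res)
x, y, z = np.meshgrid(xs, xs, xs, indexing="ij")
level = 0.3 - np.sqrt((x - 0.5) ** 2 + (y - 0.5) ** 2 + (z - 0.5) ** 2)
v_pos, t_pos_idx = mcubes.marching_cubes(-level, 0.0)  # helper negates level
v_pos = v_pos / (res - 1.0)                            # grid units -> [0, 1]
```
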
/threestudio/models/materials/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base,
3 | no_material,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/materials/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import threestudio
9 | from threestudio.utils.base import BaseModule
10 | from threestudio.utils.typing import *
11 |
12 |
13 | class BaseMaterial(BaseModule):
14 | @dataclass
15 | class Config(BaseModule.Config):
16 | pass
17 |
18 | cfg: Config
19 | requires_normal: bool = False
20 | requires_tangent: bool = False
21 |
22 | def configure(self):
23 | pass
24 |
25 | def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]:
26 | raise NotImplementedError
27 |
28 | def export(self, *args, **kwargs) -> Dict[str, Any]:
29 | return {}
30 |
--------------------------------------------------------------------------------
/threestudio/models/materials/no_material.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import threestudio
9 | from threestudio.models.materials.base import BaseMaterial
10 | from threestudio.models.networks import get_encoding, get_mlp
11 | from threestudio.utils.ops import dot, get_activation
12 | from threestudio.utils.typing import *
13 |
14 |
15 | @threestudio.register("no-material")
16 | class NoMaterial(BaseMaterial):
17 | @dataclass
18 | class Config(BaseMaterial.Config):
19 | n_output_dims: int = 3
20 | color_activation: str = "sigmoid"
21 | input_feature_dims: Optional[int] = None
22 | mlp_network_config: Optional[dict] = None
23 |
24 | cfg: Config
25 |
26 | def configure(self) -> None:
27 | self.use_network = False
28 | if (
29 | self.cfg.input_feature_dims is not None
30 | and self.cfg.mlp_network_config is not None
31 | ):
32 | self.network = get_mlp(
33 | self.cfg.input_feature_dims,
34 | self.cfg.n_output_dims,
35 | self.cfg.mlp_network_config,
36 | )
37 | self.use_network = True
38 |
39 | def forward(
40 | self, features: Float[Tensor, "B ... Nf"], **kwargs
41 | ) -> Float[Tensor, "B ... Nc"]:
42 | if not self.use_network:
43 | assert (
44 | features.shape[-1] == self.cfg.n_output_dims
45 |             ), f"Expected {self.cfg.n_output_dims} output dims, but got {features.shape[-1]}."
46 | color = get_activation(self.cfg.color_activation)(features)
47 | else:
48 | color = self.network(features.view(-1, features.shape[-1])).view(
49 | *features.shape[:-1], self.cfg.n_output_dims
50 | )
51 | color = get_activation(self.cfg.color_activation)(color)
52 | return color
53 |
54 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
55 | color = self(features, **kwargs).clamp(0, 1)
56 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels"
57 | if color.shape[-1] > 3:
58 | threestudio.warn(
59 | "Output color has >3 channels, treating the first 3 as RGB"
60 | )
61 | return {"albedo": color[..., :3]}
62 |
--------------------------------------------------------------------------------
/threestudio/models/mesh.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import threestudio
8 | from threestudio.utils.ops import dot
9 | from threestudio.utils.typing import *
10 |
11 |
12 | class Mesh:
13 | def __init__(
14 | self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
15 | ) -> None:
16 | self.v_pos: Float[Tensor, "Nv 3"] = v_pos
17 | self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
18 | self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
19 | self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
20 | self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
21 | self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
22 | self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
23 | self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
24 | self.extras: Dict[str, Any] = {}
25 | for k, v in kwargs.items():
26 | self.add_extra(k, v)
27 |
28 | def add_extra(self, k, v) -> None:
29 | self.extras[k] = v
30 |
31 | def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh:
32 | if self.requires_grad:
33 | threestudio.debug("Mesh is differentiable, not removing outliers")
34 | return self
35 |
36 | # use trimesh to first split the mesh into connected components
37 |         # then remove the components with fewer than n_face_threshold faces
38 | import trimesh
39 |
40 | # construct a trimesh object
41 | mesh = trimesh.Trimesh(
42 | vertices=self.v_pos.detach().cpu().numpy(),
43 | faces=self.t_pos_idx.detach().cpu().numpy(),
44 | )
45 |
46 | # split the mesh into connected components
47 | components = mesh.split(only_watertight=False)
48 | # log the number of faces in each component
49 | threestudio.debug(
50 | "Mesh has {} components, with faces: {}".format(
51 | len(components), [c.faces.shape[0] for c in components]
52 | )
53 | )
54 |
55 | n_faces_threshold: int
56 | if isinstance(outlier_n_faces_threshold, float):
57 | # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
58 | n_faces_threshold = int(
59 | max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
60 | )
61 | else:
62 | # set the threshold directly to outlier_n_faces_threshold
63 | n_faces_threshold = outlier_n_faces_threshold
64 |
65 | # log the threshold
66 | threestudio.debug(
67 |             "Removing components with fewer than {} faces".format(n_faces_threshold)
68 | )
69 |
70 |         # remove the components with fewer than n_face_threshold faces
71 | components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
72 |
73 | # log the number of faces in each component after removing outliers
74 | threestudio.debug(
75 | "Mesh has {} components after removing outliers, with faces: {}".format(
76 | len(components), [c.faces.shape[0] for c in components]
77 | )
78 | )
79 | # merge the components
80 | mesh = trimesh.util.concatenate(components)
81 |
82 | # convert back to our mesh format
83 | v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
84 | t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
85 |
86 | clean_mesh = Mesh(v_pos, t_pos_idx)
87 | # keep the extras unchanged
88 |
89 | if len(self.extras) > 0:
90 | clean_mesh.extras = self.extras
91 | threestudio.debug(
92 | f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
93 | )
94 | return clean_mesh
95 |
96 | @property
97 | def requires_grad(self):
98 | return self.v_pos.requires_grad
99 |
100 | @property
101 | def v_nrm(self):
102 | if self._v_nrm is None:
103 | self._v_nrm = self._compute_vertex_normal()
104 | return self._v_nrm
105 |
106 | @property
107 | def v_tng(self):
108 | if self._v_tng is None:
109 | self._v_tng = self._compute_vertex_tangent()
110 | return self._v_tng
111 |
112 | @property
113 | def v_tex(self):
114 | if self._v_tex is None:
115 | self._v_tex, self._t_tex_idx = self._unwrap_uv()
116 | return self._v_tex
117 |
118 | @property
119 | def t_tex_idx(self):
120 | if self._t_tex_idx is None:
121 | self._v_tex, self._t_tex_idx = self._unwrap_uv()
122 | return self._t_tex_idx
123 |
124 | @property
125 | def v_rgb(self):
126 | return self._v_rgb
127 |
128 | @property
129 | def edges(self):
130 | if self._edges is None:
131 | self._edges = self._compute_edges()
132 | return self._edges
133 |
134 | def _compute_vertex_normal(self):
135 | i0 = self.t_pos_idx[:, 0]
136 | i1 = self.t_pos_idx[:, 1]
137 | i2 = self.t_pos_idx[:, 2]
138 |
139 | v0 = self.v_pos[i0, :]
140 | v1 = self.v_pos[i1, :]
141 | v2 = self.v_pos[i2, :]
142 |
143 |         face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
144 |
145 | # Splat face normals to vertices
146 | v_nrm = torch.zeros_like(self.v_pos)
147 | v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
148 | v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
149 | v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
150 |
151 | # Normalize, replace zero (degenerated) normals with some default value
152 | v_nrm = torch.where(
153 | dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
154 | )
155 | v_nrm = F.normalize(v_nrm, dim=1)
156 |
157 | if torch.is_anomaly_enabled():
158 | assert torch.all(torch.isfinite(v_nrm))
159 |
160 | return v_nrm
161 |
162 | def _compute_vertex_tangent(self):
163 | vn_idx = [None] * 3
164 | pos = [None] * 3
165 | tex = [None] * 3
166 | for i in range(0, 3):
167 | pos[i] = self.v_pos[self.t_pos_idx[:, i]]
168 | tex[i] = self.v_tex[self.t_tex_idx[:, i]]
169 | # t_nrm_idx is always the same as t_pos_idx
170 | vn_idx[i] = self.t_pos_idx[:, i]
171 |
172 | tangents = torch.zeros_like(self.v_nrm)
173 | tansum = torch.zeros_like(self.v_nrm)
174 |
175 | # Compute tangent space for each triangle
176 | uve1 = tex[1] - tex[0]
177 | uve2 = tex[2] - tex[0]
178 | pe1 = pos[1] - pos[0]
179 | pe2 = pos[2] - pos[0]
180 |
181 | nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
182 | denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]
183 |
184 | # Avoid division by zero for degenerated texture coordinates
185 | tang = nom / torch.where(
186 | denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
187 | )
188 |
189 | # Update all 3 vertices
190 | for i in range(0, 3):
191 | idx = vn_idx[i][:, None].repeat(1, 3)
192 | tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
193 | tansum.scatter_add_(
194 | 0, idx, torch.ones_like(tang)
195 | ) # tansum[n_i] = tansum[n_i] + 1
196 | tangents = tangents / tansum
197 |
198 | # Normalize and make sure tangent is perpendicular to normal
199 | tangents = F.normalize(tangents, dim=1)
200 | tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
201 |
202 | if torch.is_anomaly_enabled():
203 | assert torch.all(torch.isfinite(tangents))
204 |
205 | return tangents
206 |
207 | def _unwrap_uv(
208 | self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
209 | ):
210 | threestudio.info("Using xatlas to perform UV unwrapping, may take a while ...")
211 |
212 | import xatlas
213 |
214 | atlas = xatlas.Atlas()
215 | atlas.add_mesh(
216 | self.v_pos.detach().cpu().numpy(),
217 | self.t_pos_idx.cpu().numpy(),
218 | )
219 | co = xatlas.ChartOptions()
220 | po = xatlas.PackOptions()
221 | for k, v in xatlas_chart_options.items():
222 | setattr(co, k, v)
223 | for k, v in xatlas_pack_options.items():
224 | setattr(po, k, v)
225 | atlas.generate(co, po)
226 | vmapping, indices, uvs = atlas.get_mesh(0)
227 | vmapping = (
228 | torch.from_numpy(
229 | vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
230 | )
231 | .to(self.v_pos.device)
232 | .long()
233 | )
234 | uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
235 | indices = (
236 | torch.from_numpy(
237 | indices.astype(np.uint64, casting="same_kind").view(np.int64)
238 | )
239 | .to(self.v_pos.device)
240 | .long()
241 | )
242 | return uvs, indices
243 |
244 | def unwrap_uv(
245 | self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
246 | ):
247 | self._v_tex, self._t_tex_idx = self._unwrap_uv(
248 | xatlas_chart_options, xatlas_pack_options
249 | )
250 |
251 | def set_vertex_color(self, v_rgb):
252 | assert v_rgb.shape[0] == self.v_pos.shape[0]
253 | self._v_rgb = v_rgb
254 |
255 | def _compute_edges(self):
256 | # Compute edges
257 | edges = torch.cat(
258 | [
259 | self.t_pos_idx[:, [0, 1]],
260 | self.t_pos_idx[:, [1, 2]],
261 | self.t_pos_idx[:, [2, 0]],
262 | ],
263 | dim=0,
264 | )
265 | edges = edges.sort()[0]
266 | edges = torch.unique(edges, dim=0)
267 | return edges
268 |
269 | def normal_consistency(self) -> Float[Tensor, ""]:
270 | edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
271 | nc = (
272 | 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
273 | ).mean()
274 | return nc
275 |
276 | def _laplacian_uniform(self):
277 | # from stable-dreamfusion
278 | # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
279 | verts, faces = self.v_pos, self.t_pos_idx
280 |
281 | V = verts.shape[0]
282 | F = faces.shape[0]
283 |
284 | # Neighbor indices
285 | ii = faces[:, [1, 2, 0]].flatten()
286 | jj = faces[:, [2, 0, 1]].flatten()
287 | adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
288 | dim=1
289 | )
290 | adj_values = torch.ones(adj.shape[1]).to(verts)
291 |
292 | # Diagonal indices
293 | diag_idx = adj[0]
294 |
295 | # Build the sparse matrix
296 | idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
297 | values = torch.cat((-adj_values, adj_values))
298 |
299 | # The coalesce operation sums the duplicate indices, resulting in the
300 | # correct diagonal
301 | return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
302 |
303 | def laplacian(self) -> Float[Tensor, ""]:
304 | with torch.no_grad():
305 | L = self._laplacian_uniform()
306 | loss = L.mm(self.v_pos)
307 | loss = loss.norm(dim=1)
308 | loss = loss.mean()
309 | return loss
310 |
--------------------------------------------------------------------------------
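
`normal_consistency()` above averages `1 - cos(angle)` between the vertex normals on either end of every unique edge: zero for a flat sheet, growing as the surface folds. A toy check, assuming the `threestudio` package and its dependencies are installed so `Mesh` can resolve its own imports:

```python
# Sketch: normal consistency of two triangles folded 90 degrees
# along the shared edge (0, 1).
import torch
from threestudio.models.mesh import Mesh

v_pos = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
)
t_pos_idx = torch.tensor([[0, 1, 2], [0, 3, 1]])
mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
print(mesh.normal_consistency())  # > 0: the two faces are not coplanar
```
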
/threestudio/models/prompt_processors/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base,
3 | stable_diffusion_prompt_processor,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | import torch.nn as nn
7 | from transformers import AutoTokenizer, CLIPTextModel
8 |
9 | import threestudio
10 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt
11 | from threestudio.utils.misc import cleanup
12 | from threestudio.utils.typing import *
13 |
14 |
15 | @threestudio.register("stable-diffusion-prompt-processor")
16 | class StableDiffusionPromptProcessor(PromptProcessor):
17 | @dataclass
18 | class Config(PromptProcessor.Config):
19 | pass
20 |
21 | cfg: Config
22 |
23 | ### these functions are unused, kept for debugging ###
24 | def configure_text_encoder(self) -> None:
25 | self.tokenizer = AutoTokenizer.from_pretrained(
26 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer"
27 | )
28 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
29 | self.text_encoder = CLIPTextModel.from_pretrained(
30 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder"
31 | ).to(self.device)
32 |
33 | for p in self.text_encoder.parameters():
34 | p.requires_grad_(False)
35 |
36 | def destroy_text_encoder(self) -> None:
37 | del self.tokenizer
38 | del self.text_encoder
39 | cleanup()
40 |
41 | def get_text_embeddings(
42 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
43 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]:
44 | if isinstance(prompt, str):
45 | prompt = [prompt]
46 | if isinstance(negative_prompt, str):
47 | negative_prompt = [negative_prompt]
48 | # Tokenize text and get embeddings
49 | tokens = self.tokenizer(
50 | prompt,
51 | padding="max_length",
52 | max_length=self.tokenizer.model_max_length,
53 | return_tensors="pt",
54 | )
55 | uncond_tokens = self.tokenizer(
56 | negative_prompt,
57 | padding="max_length",
58 | max_length=self.tokenizer.model_max_length,
59 | return_tensors="pt",
60 | )
61 |
62 | with torch.no_grad():
63 | text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0]
64 | uncond_text_embeddings = self.text_encoder(
65 | uncond_tokens.input_ids.to(self.device)
66 | )[0]
67 |
68 | return text_embeddings, uncond_text_embeddings
69 |
70 | ###
71 |
72 | @staticmethod
73 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir):
74 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
75 | tokenizer = AutoTokenizer.from_pretrained(
76 | pretrained_model_name_or_path, subfolder="tokenizer"
77 | )
78 | text_encoder = CLIPTextModel.from_pretrained(
79 | pretrained_model_name_or_path,
80 | subfolder="text_encoder",
81 | device_map="auto",
82 | )
83 |
84 | with torch.no_grad():
85 | tokens = tokenizer(
86 | prompts,
87 | padding="max_length",
88 | max_length=tokenizer.model_max_length,
89 | return_tensors="pt",
90 | )
91 | text_embeddings = text_encoder(tokens.input_ids.cuda())[0]
92 |
93 | for prompt, embedding in zip(prompts, text_embeddings):
94 | torch.save(
95 | embedding,
96 | os.path.join(
97 | cache_dir,
98 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt",
99 | ),
100 | )
101 |
102 | del text_encoder
103 |
--------------------------------------------------------------------------------
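
`spawn_func` is the part that actually runs (in a separate process, judging by the name; `base.py` is not shown here): it tokenizes each prompt to the fixed 77-token window and caches one embedding tensor per `(model, prompt)` pair. Its tokenize-then-encode core, stated standalone; the checkpoint name below is an assumption for illustration (the real default lives in `PromptProcessor.Config`), and running this downloads model weights:

```python
# Sketch: the CLIP text-embedding step inside spawn_func.
import torch
from transformers import AutoTokenizer, CLIPTextModel

name = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint, for illustration
tokenizer = AutoTokenizer.from_pretrained(name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(name, subfolder="text_encoder")

tokens = tokenizer(
    ["a clown"],
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
with torch.no_grad():
    embeddings = text_encoder(tokens.input_ids)[0]
print(embeddings.shape)  # (1, 77, 768) for SD 1.x text encoders
```
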
/threestudio/models/renderers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base,
3 | nerf_volume_renderer,
4 | )
5 |
--------------------------------------------------------------------------------
/threestudio/models/renderers/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import nerfacc
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import threestudio
8 | from threestudio.models.background.base import BaseBackground
9 | from threestudio.models.geometry.base import BaseImplicitGeometry
10 | from threestudio.models.materials.base import BaseMaterial
11 | from threestudio.utils.base import BaseModule
12 | from threestudio.utils.typing import *
13 |
14 |
15 | class Renderer(BaseModule):
16 | @dataclass
17 | class Config(BaseModule.Config):
18 | radius: float = 1.0
19 |
20 | cfg: Config
21 |
22 | def configure(
23 | self,
24 | geometry: BaseImplicitGeometry,
25 | material: BaseMaterial,
26 | background: BaseBackground,
27 | ) -> None:
28 |         # keep references to submodules in a plain dataclass, so they are not registered as child modules
29 | @dataclass
30 | class SubModules:
31 | geometry: BaseImplicitGeometry
32 | material: BaseMaterial
33 | background: BaseBackground
34 |
35 | self.sub_modules = SubModules(geometry, material, background)
36 |
37 | # set up bounding box
38 | self.bbox: Float[Tensor, "2 3"]
39 | self.register_buffer(
40 | "bbox",
41 | torch.as_tensor(
42 | [
43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius],
45 | ],
46 | dtype=torch.float32,
47 | ),
48 | )
49 |
50 | def forward(self, *args, **kwargs) -> Dict[str, Any]:
51 | raise NotImplementedError
52 |
53 | @property
54 | def geometry(self) -> BaseImplicitGeometry:
55 | return self.sub_modules.geometry
56 |
57 | @property
58 | def material(self) -> BaseMaterial:
59 | return self.sub_modules.material
60 |
61 | @property
62 | def background(self) -> BaseBackground:
63 | return self.sub_modules.background
64 |
65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None:
66 | self.sub_modules.geometry = geometry
67 |
68 | def set_material(self, material: BaseMaterial) -> None:
69 | self.sub_modules.material = material
70 |
71 | def set_background(self, background: BaseBackground) -> None:
72 | self.sub_modules.background = background
73 |
74 |
75 | class VolumeRenderer(Renderer):
76 | pass
77 |
78 |
79 | class Rasterizer(Renderer):
80 | pass
81 |
--------------------------------------------------------------------------------
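
The `SubModules` dataclass above is what keeps geometry, material, and background out of the renderer's own module tree, as the in-code comment notes. A minimal sketch of the difference, with made-up class names:

```python
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Holder:
    inner: nn.Module  # a plain dataclass is not an nn.Module, so nothing gets registered

class WithAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)  # direct attribute: registered as a child module

class WithHolder(nn.Module):
    def __init__(self):
        super().__init__()
        self.holder = Holder(nn.Linear(4, 4))  # hidden inside the dataclass container

print(sum(p.numel() for p in WithAttr().parameters()))    # 20
print(sum(p.numel() for p in WithHolder().parameters()))  # 0
```

This is why the renderer can hold references to its submodules without double-counting their parameters in the optimizer.
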
/threestudio/models/renderers/nerf_volume_renderer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import nerfacc
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import threestudio
8 | from threestudio.models.background.base import BaseBackground
9 | from threestudio.models.geometry.base import BaseImplicitGeometry
10 | from threestudio.models.materials.base import BaseMaterial
11 | from threestudio.models.renderers.base import VolumeRenderer
12 | from threestudio.utils.ops import chunk_batch, validate_empty_rays
13 | from threestudio.utils.typing import *
14 |
15 |
16 | @threestudio.register("nerf-volume-renderer")
17 | class NeRFVolumeRenderer(VolumeRenderer):
18 | @dataclass
19 | class Config(VolumeRenderer.Config):
20 | num_samples_per_ray: int = 512
21 | randomized: bool = True
22 | eval_chunk_size: int = 160000
23 | grid_prune: bool = True
24 | prune_alpha_threshold: bool = True
25 | return_comp_normal: bool = False
26 | return_normal_perturb: bool = False
27 |
28 | cfg: Config
29 |
30 | def configure(
31 | self,
32 | geometry: BaseImplicitGeometry,
33 | material: BaseMaterial,
34 | background: BaseBackground,
35 | ) -> None:
36 | super().configure(geometry, material, background)
37 | self.estimator = nerfacc.OccGridEstimator(
38 | roi_aabb=self.bbox.view(-1), resolution=32, levels=1
39 | )
40 | if not self.cfg.grid_prune:
41 | self.estimator.occs.fill_(True)
42 | self.estimator.binaries.fill_(True)
43 | self.render_step_size = (
44 | 1.732 * 2 * self.cfg.radius / self.cfg.num_samples_per_ray
45 | )
46 | self.randomized = self.cfg.randomized
47 |
48 |
49 | def forward(
50 | self,
51 | batch_idx: int,
52 | rays_o: Float[Tensor, "B H W 3"],
53 | rays_d: Float[Tensor, "B H W 3"],
54 | light_positions: Float[Tensor, "B 3"],
55 | bg_color: Optional[Tensor] = None,
56 | **kwargs,
57 | ) -> Dict[str, Float[Tensor, "..."]]:
58 | batch_size, height, width = rays_o.shape[:3]
59 | rays_o_flatten: Float[Tensor, "Nr 3"] = rays_o.reshape(-1, 3)
60 | rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.reshape(-1, 3)
61 | light_positions_flatten: Float[Tensor, "Nr 3"] = (
62 | light_positions.reshape(-1, 1, 1, 3)
63 | .expand(-1, height, width, -1)
64 | .reshape(-1, 3)
65 | )
66 | n_rays = rays_o_flatten.shape[0]
67 |
68 | def sigma_fn(t_starts, t_ends, ray_indices, idx=batch_idx):
69 | t_starts, t_ends = t_starts[..., None], t_ends[..., None]
70 | t_origins = rays_o_flatten[ray_indices]
71 | t_positions = (t_starts + t_ends) / 2.0
72 | t_dirs = rays_d_flatten[ray_indices]
73 | positions = t_origins + t_dirs * t_positions
74 | if self.training:
75 | sigma = self.geometry.forward_density(positions, idx=idx)[..., 0]
76 | else:
77 | sigma = chunk_batch(
78 | self.geometry.forward_density,
79 | self.cfg.eval_chunk_size,
80 | positions,
81 | )[..., 0]
82 | return sigma
83 |
84 | if not self.cfg.grid_prune:
85 | with torch.no_grad():
86 | ray_indices, t_starts_, t_ends_ = self.estimator.sampling(
87 | rays_o_flatten,
88 | rays_d_flatten,
89 | sigma_fn=None,
90 | render_step_size=self.render_step_size,
91 | alpha_thre=0.0,
92 | stratified=self.randomized,
93 | cone_angle=0.0,
94 | early_stop_eps=0,
95 | )
96 | else:
97 | with torch.no_grad():
98 | ray_indices, t_starts_, t_ends_ = self.estimator.sampling(
99 | rays_o_flatten,
100 | rays_d_flatten,
101 | sigma_fn=sigma_fn if self.cfg.prune_alpha_threshold else None,
102 | render_step_size=self.render_step_size,
103 | alpha_thre=0.01 if self.cfg.prune_alpha_threshold else 0.0,
104 | stratified=self.randomized,
105 | cone_angle=0.0,
106 | )
107 | ray_indices, t_starts_, t_ends_ = validate_empty_rays(
108 | ray_indices, t_starts_, t_ends_
109 | )
110 | ray_indices = ray_indices.long()
111 | t_starts, t_ends = t_starts_[..., None], t_ends_[..., None]
112 | t_origins = rays_o_flatten[ray_indices]
113 | t_dirs = rays_d_flatten[ray_indices]
114 | t_light_positions = light_positions_flatten[ray_indices]
115 | t_positions = (t_starts + t_ends) / 2.0
116 | positions = t_origins + t_dirs * t_positions
117 | t_intervals = t_ends - t_starts
118 |
119 | if self.training:
120 | geo_out = self.geometry(
121 | positions, output_normal=True, idx=batch_idx
122 | )
123 | rgb_fg_all = self.material(
124 | viewdirs=t_dirs,
125 | positions=positions,
126 | light_positions=t_light_positions,
127 | **geo_out,
128 | **kwargs
129 | )
130 | comp_rgb_bg = self.background(dirs=rays_d)
131 |
132 | else:
133 | geo_out = chunk_batch(
134 | self.geometry,
135 | self.cfg.eval_chunk_size,
136 | positions,
137 | output_normal=True,
138 | idx=batch_idx,
139 | )
140 | rgb_fg_all = chunk_batch(
141 | self.material,
142 | self.cfg.eval_chunk_size,
143 | viewdirs=t_dirs,
144 | positions=positions,
145 | light_positions=t_light_positions,
146 | **geo_out
147 | )
148 | comp_rgb_bg = chunk_batch(
149 | self.background, self.cfg.eval_chunk_size, dirs=rays_d
150 | )
151 |
152 | weights: Float[Tensor, "Nr 1"]
153 | weights_, _, _ = nerfacc.render_weight_from_density(
154 | t_starts[..., 0],
155 | t_ends[..., 0],
156 | geo_out["density"][..., 0],
157 | ray_indices=ray_indices,
158 | n_rays=n_rays,
159 | )
160 | weights = weights_[..., None]
161 | opacity: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays(
162 | weights[..., 0], values=None, ray_indices=ray_indices, n_rays=n_rays
163 | )
164 | depth: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays(
165 | weights[..., 0], values=t_positions, ray_indices=ray_indices, n_rays=n_rays
166 | )
167 | comp_rgb_fg: Float[Tensor, "Nr Nc"] = nerfacc.accumulate_along_rays(
168 | weights[..., 0], values=rgb_fg_all, ray_indices=ray_indices, n_rays=n_rays
169 | )
170 |
171 | # broadcast per-ray depth back to each sample to compute the z-variance
172 | t_depth = depth[ray_indices]
173 | z_variance = nerfacc.accumulate_along_rays(
174 | weights[..., 0],
175 | values=(t_positions - t_depth) ** 2,
176 | ray_indices=ray_indices,
177 | n_rays=n_rays,
178 | )
179 |
180 | if bg_color is None:
181 | bg_color = comp_rgb_bg
182 | else:
183 | if bg_color.shape[:-1] == (batch_size,):
184 | # e.g. constant random color used for Zero123
185 | # [bs, 3] -> [bs, 1, 1, 3]
186 | bg_color = bg_color.unsqueeze(1).unsqueeze(1)
187 | # -> [bs, height, width, 3]
188 | bg_color = bg_color.expand(-1, height, width, -1)
189 |
190 | if bg_color.shape[:-1] == (batch_size, height, width):
191 | bg_color = bg_color.reshape(batch_size * height * width, -1)
192 |
193 | comp_rgb = comp_rgb_fg + bg_color * (1.0 - opacity)
194 |
195 | out = {
196 | "comp_rgb": comp_rgb.view(batch_size, height, width, -1),
197 | "comp_rgb_fg": comp_rgb_fg.view(batch_size, height, width, -1),
198 | "comp_rgb_bg": comp_rgb_bg.view(batch_size, height, width, -1),
199 | "opacity": opacity.view(batch_size, height, width, 1),
200 | "depth": depth.view(batch_size, height, width, 1),
201 | "z_variance": z_variance.view(batch_size, height, width, 1),
202 | }
203 |
204 | if self.training:
205 | out.update(
206 | {
207 | "weights": weights,
208 | "t_points": t_positions,
209 | "t_intervals": t_intervals,
210 | "t_dirs": t_dirs,
211 | "ray_indices": ray_indices,
212 | "points": positions,
213 | **geo_out,
214 | }
215 | )
216 | if "normal" in geo_out:
217 | if self.cfg.return_comp_normal:
218 | comp_normal: Float[Tensor, "Nr 3"] = nerfacc.accumulate_along_rays(
219 | weights[..., 0],
220 | values=geo_out["normal"],
221 | ray_indices=ray_indices,
222 | n_rays=n_rays,
223 | )
224 | comp_normal = F.normalize(comp_normal, dim=-1)
225 | comp_normal = (
226 | (comp_normal + 1.0) / 2.0 * opacity
227 | ) # for visualization
228 | out.update(
229 | {
230 | "comp_normal": comp_normal.view(
231 | batch_size, height, width, 3
232 | ),
233 | }
234 | )
235 | if self.cfg.return_normal_perturb:
236 | normal_perturb = self.geometry(
237 | positions + torch.randn_like(positions) * 1e-2,
238 | output_normal=True,
239 | )["normal"]
240 | out.update({"normal_perturb": normal_perturb})
241 | else:
242 | if "normal" in geo_out:
243 | comp_normal = nerfacc.accumulate_along_rays(
244 | weights[..., 0],
245 | values=geo_out["normal"],
246 | ray_indices=ray_indices,
247 | n_rays=n_rays,
248 | )
249 | comp_normal = F.normalize(comp_normal, dim=-1)
250 | comp_normal = (comp_normal + 1.0) / 2.0 * opacity # for visualization
251 | out.update(
252 | {
253 | "comp_normal": comp_normal.view(batch_size, height, width, 3),
254 | }
255 | )
256 |
257 |
258 | return out
259 |
260 | def update_step(
261 | self, epoch: int, global_step: int, on_load_weights: bool = False
262 | ) -> None:
263 |
264 | # print(global_step)
265 | if self.cfg.grid_prune:
266 | def occ_eval_fn(x):
267 | density = self.geometry.forward_density(x)
268 | # first-order Taylor approximation of 1 - torch.exp(-density * self.render_step_size)
269 | return density * self.render_step_size
270 |
271 | if self.training and not on_load_weights:
272 | self.estimator.update_every_n_steps(
273 | step=global_step, occ_eval_fn=occ_eval_fn
274 | )
275 | def train(self, mode=True):
276 | self.randomized = mode and self.cfg.randomized
277 | return super().train(mode=mode)
278 |
279 | def eval(self):
280 | self.randomized = False
281 | return super().eval()
282 |
--------------------------------------------------------------------------------
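
`nerfacc.render_weight_from_density` above implements standard volume-rendering quadrature, and `occ_eval_fn` uses the first-order Taylor shortcut `density * step ≈ 1 - exp(-density * step)`. A small pure-torch sketch of the weight computation for a single ray (the densities are made up):

```python
import torch

sigma = torch.tensor([0.0, 0.5, 2.0, 2.0, 0.1])  # per-sample densities along one ray
delta = torch.full((5,), 0.02)                   # interval lengths t_ends - t_starts

alpha = 1.0 - torch.exp(-sigma * delta)          # per-sample opacity
trans = torch.cumprod(                           # transmittance T_i up to each sample
    torch.cat([torch.ones(1), 1.0 - alpha[:-1]]), dim=0
)
weights = trans * alpha                          # w_i = T_i * alpha_i
opacity = weights.sum()                          # what accumulate_along_rays reduces to
print(weights, opacity)
```
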
/threestudio/systems/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | dreamavatar,
3 | )
4 |
--------------------------------------------------------------------------------
/threestudio/systems/dreamavatar.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 |
6 | import threestudio
7 | from threestudio.systems.base import BaseLift3DSystem
8 | from threestudio.utils.misc import cleanup, get_device
9 | from threestudio.utils.ops import binary_cross_entropy, dot
10 | from threestudio.utils.typing import *
11 |
12 |
13 | @threestudio.register("dreamavatar-system")
14 | class DreamAvatar(BaseLift3DSystem):
15 | @dataclass
16 | class Config(BaseLift3DSystem.Config):
17 | # in ['coarse', 'geometry', 'texture']
18 | stage: str = "coarse"
19 | visualize_samples: bool = False
20 |
21 | cfg: Config
22 |
23 | def configure(self) -> None:
24 | # set up geometry, material, background, renderer
25 | super().configure()
26 |
27 | self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
28 | self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
29 | self.cfg.prompt_processor
30 | )
31 | self.prompt_utils = self.prompt_processor()
32 |
33 | def forward(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
34 |
35 | render_out = self.renderer(batch_idx, **batch)
36 | return {
37 | **render_out,
38 | }
39 |
40 | def on_fit_start(self) -> None:
41 | super().on_fit_start()
42 |
43 | def training_step(self, batch, batch_idx):
44 | out = self(batch, batch_idx)
45 | guidance_inp = out["comp_rgb"]
46 | guidance_out = self.guidance(
47 | guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False
48 | )
49 |
50 | loss = 0.0
51 |
52 | for name, value in guidance_out.items():
53 | self.log(f"train/{name}", value)
54 | if name.startswith("loss_"):
55 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
56 |
57 | if self.C(self.cfg.loss.lambda_orient) > 0:
58 | if "normal" not in out:
59 | raise ValueError(
60 | "Normal is required for the orientation loss, but no normal was found in the output."
61 | )
62 | loss_orient = (
63 | out["weights"].detach()
64 | * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
65 | ).sum() / (out["opacity"] > 0).sum()
66 | self.log("train/loss_orient", loss_orient)
67 | loss += loss_orient * self.C(self.cfg.loss.lambda_orient)
68 |
69 | loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
70 | self.log("train/loss_sparsity", loss_sparsity)
71 | loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)
72 |
73 | opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
74 | loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
75 | self.log("train/loss_opaque", loss_opaque)
76 | loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)
77 |
78 | # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766
79 | # helps reduce floaters and produce solid geometry
80 | loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean()
81 | self.log("train/loss_z_variance", loss_z_variance)
82 | loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance)
83 |
84 | for name, value in self.cfg.loss.items():
85 | self.log(f"train_params/{name}", self.C(value))
86 |
87 | return {"loss": loss}
88 |
89 | def validation_step(self, batch, batch_idx):
90 | out = self(batch, batch_idx)
91 | self.save_image_grid(
92 | f"it{self.true_global_step}-{batch['index'][0]}.png",
93 | (
94 | [
95 | {
96 | "type": "rgb",
97 | "img": out["comp_rgb"][0],
98 | "kwargs": {"data_format": "HWC"},
99 | },
100 | ]
101 | if "comp_rgb" in out
102 | else []
103 | )
104 | + (
105 | [
106 | {
107 | "type": "rgb",
108 | "img": out["comp_normal"][0],
109 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
110 | }
111 | ]
112 | if "comp_normal" in out
113 | else []
114 | )
115 | + [
116 | {
117 | "type": "grayscale",
118 | "img": out["opacity"][0, :, :, 0],
119 | "kwargs": {"cmap": None, "data_range": (0, 1)},
120 | },
121 | ],
122 | name="validation_step",
123 | step=self.true_global_step,
124 | )
125 |
126 | if self.cfg.visualize_samples:
127 | self.save_image_grid(
128 | f"it{self.true_global_step}-{batch['index'][0]}-sample.png",
129 | [
130 | {
131 | "type": "rgb",
132 | "img": self.guidance.sample(
133 | self.prompt_utils, **batch, seed=self.global_step
134 | )[0],
135 | "kwargs": {"data_format": "HWC"},
136 | },
137 | {
138 | "type": "rgb",
139 | "img": self.guidance.sample_lora(self.prompt_utils, **batch)[0],
140 | "kwargs": {"data_format": "HWC"},
141 | },
142 | ],
143 | name="validation_step_samples",
144 | step=self.true_global_step,
145 | )
146 |
147 | def on_validation_epoch_end(self):
148 | pass
149 |
150 | def test_step(self, batch, batch_idx):
151 | out = self(batch, batch_idx)
152 | self.save_image_grid(
153 | f"it{self.true_global_step}-test/{batch['index'][0]}.png",
154 | (
155 | [
156 | {
157 | "type": "rgb",
158 | "img": out["comp_rgb"][0],
159 | "kwargs": {"data_format": "HWC"},
160 | },
161 | ]
162 | if "comp_rgb" in out
163 | else []
164 | ),
165 | name="test_step",
166 | step=self.true_global_step,
167 | )
168 |
169 | def on_test_epoch_end(self):
170 | self.save_img_sequence(
171 | f"it{self.true_global_step}-test",
172 | f"it{self.true_global_step}-test",
173 | "(\d+)\.png",
174 | save_format="mp4",
175 | fps=30,
176 | name="test",
177 | step=self.true_global_step,
178 | )
179 |
--------------------------------------------------------------------------------
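
The opacity regularizers in `training_step` are easy to check numerically: `loss_opaque` is the binary entropy of the opacity itself (largest at 0.5, vanishing as opacity approaches 0 or 1), and `loss_sparsity` is a smooth magnitude surrogate. A sketch, with `binary_cross_entropy` written inline under its usual definition so it stays self-contained (assumed to match the helper in threestudio.utils.ops):

```python
import torch

def binary_cross_entropy(input, target):
    # standard elementwise BCE, averaged; assumption: matches threestudio.utils.ops
    return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean()

opacity = torch.tensor([0.5, 0.9, 0.999])
o = opacity.clamp(1.0e-3, 1.0 - 1.0e-3)

loss_opaque = binary_cross_entropy(o, o)             # pushes pixels to be fully on/off
loss_sparsity = (opacity ** 2 + 0.01).sqrt().mean()  # discourages stray density
print(loss_opaque.item(), loss_sparsity.item())
```
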
/threestudio/systems/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 | from bisect import bisect_right
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.optim import lr_scheduler
8 |
9 | import threestudio
10 |
11 |
12 | def get_scheduler(name):
13 | if hasattr(lr_scheduler, name):
14 | return getattr(lr_scheduler, name)
15 | else:
16 | raise NotImplementedError
17 |
18 |
19 | def getattr_recursive(m, attr):
20 | for name in attr.split("."):
21 | m = getattr(m, name)
22 | return m
23 |
24 |
25 | def get_parameters(model, name):
26 | module = getattr_recursive(model, name)
27 | if isinstance(module, nn.Module):
28 | return module.parameters()
29 | elif isinstance(module, nn.Parameter):
30 | return module
31 | return []
32 |
33 |
34 | def parse_optimizer(config, model):
35 | if hasattr(config, "params"):
36 | params = [
37 | {"params": get_parameters(model, name), "name": name, **args}
38 | for name, args in config.params.items()
39 | ]
40 | threestudio.debug(f"Specify optimizer params: {config.params}")
41 | else:
42 | params = model.parameters()
43 | if config.name in ["FusedAdam"]:
44 | import apex
45 |
46 | optim = getattr(apex.optimizers, config.name)(params, **config.args)
47 | elif config.name in ["Adan"]:
48 | from threestudio.systems import optimizers
49 |
50 | optim = getattr(optimizers, config.name)(params, **config.args)
51 | else:
52 | optim = getattr(torch.optim, config.name)(params, **config.args)
53 | return optim
54 |
55 |
56 | def parse_scheduler(config, optimizer):
57 | interval = config.get("interval", "epoch")
58 | assert interval in ["epoch", "step"]
59 | if config.name == "SequentialLR":
60 | scheduler = {
61 | "scheduler": lr_scheduler.SequentialLR(
62 | optimizer,
63 | [
64 | parse_scheduler(conf, optimizer)["scheduler"]
65 | for conf in config.schedulers
66 | ],
67 | milestones=config.milestones,
68 | ),
69 | "interval": interval,
70 | }
71 | elif config.name == "ChainedScheduler":
72 | scheduler = {
73 | "scheduler": lr_scheduler.ChainedScheduler(
74 | [
75 | parse_scheduler(conf, optimizer)["scheduler"]
76 | for conf in config.schedulers
77 | ]
78 | ),
79 | "interval": interval,
80 | }
81 | else:
82 | scheduler = {
83 | "scheduler": get_scheduler(config.name)(optimizer, **config.args),
84 | "interval": interval,
85 | }
86 | return scheduler
87 |
--------------------------------------------------------------------------------
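
`parse_optimizer` above turns a config's `params` mapping into per-submodule parameter groups. A minimal sketch of the call it ends up making, with a made-up two-part model and learning rates:

```python
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "geometry": nn.Linear(8, 8),
    "background": nn.Linear(8, 3),
})

# equivalent to a config with params: {geometry: {lr: 1e-2}, background: {lr: 1e-3}}
params = [
    {"params": model["geometry"].parameters(), "name": "geometry", "lr": 1e-2},
    {"params": model["background"].parameters(), "name": "background", "lr": 1e-3},
]
optim = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.99))  # lr here is only the default

for group in optim.param_groups:
    print(group["name"], group["lr"])  # geometry 0.01 / background 0.001
```

PyTorch keeps unknown keys such as "name" on the param group, which is how per-group names survive into logging.
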
/threestudio/utils/GAN/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | from inspect import isfunction
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 | from torch import einsum, nn
8 |
9 | from threestudio.utils.GAN.network_util import checkpoint
10 |
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 |
16 | def uniq(arr):
17 | return {el: True for el in arr}.keys()
18 |
19 |
20 | def default(val, d):
21 | if exists(val):
22 | return val
23 | return d() if isfunction(d) else d
24 |
25 |
26 | def max_neg_value(t):
27 | return -torch.finfo(t.dtype).max
28 |
29 |
30 | def init_(tensor):
31 | dim = tensor.shape[-1]
32 | std = 1 / math.sqrt(dim)
33 | tensor.uniform_(-std, std)
34 | return tensor
35 |
36 |
37 | # feedforward
38 | class GEGLU(nn.Module):
39 | def __init__(self, dim_in, dim_out):
40 | super().__init__()
41 | self.proj = nn.Linear(dim_in, dim_out * 2)
42 |
43 | def forward(self, x):
44 | x, gate = self.proj(x).chunk(2, dim=-1)
45 | return x * F.gelu(gate)
46 |
47 |
48 | class FeedForward(nn.Module):
49 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
50 | super().__init__()
51 | inner_dim = int(dim * mult)
52 | dim_out = default(dim_out, dim)
53 | project_in = (
54 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
55 | if not glu
56 | else GEGLU(dim, inner_dim)
57 | )
58 |
59 | self.net = nn.Sequential(
60 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(
78 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
79 | )
80 |
81 |
82 | class LinearAttention(nn.Module):
83 | def __init__(self, dim, heads=4, dim_head=32):
84 | super().__init__()
85 | self.heads = heads
86 | hidden_dim = dim_head * heads
87 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
88 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
89 |
90 | def forward(self, x):
91 | b, c, h, w = x.shape
92 | qkv = self.to_qkv(x)
93 | q, k, v = rearrange(
94 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
95 | )
96 | k = k.softmax(dim=-1)
97 | context = torch.einsum("bhdn,bhen->bhde", k, v)
98 | out = torch.einsum("bhde,bhdn->bhen", context, q)
99 | out = rearrange(
100 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
101 | )
102 | return self.to_out(out)
103 |
104 |
105 | class SpatialSelfAttention(nn.Module):
106 | def __init__(self, in_channels):
107 | super().__init__()
108 | self.in_channels = in_channels
109 |
110 | self.norm = Normalize(in_channels)
111 | self.q = torch.nn.Conv2d(
112 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
113 | )
114 | self.k = torch.nn.Conv2d(
115 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
116 | )
117 | self.v = torch.nn.Conv2d(
118 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
119 | )
120 | self.proj_out = torch.nn.Conv2d(
121 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
122 | )
123 |
124 | def forward(self, x):
125 | h_ = x
126 | h_ = self.norm(h_)
127 | q = self.q(h_)
128 | k = self.k(h_)
129 | v = self.v(h_)
130 |
131 | # compute attention
132 | b, c, h, w = q.shape
133 | q = rearrange(q, "b c h w -> b (h w) c")
134 | k = rearrange(k, "b c h w -> b c (h w)")
135 | w_ = torch.einsum("bij,bjk->bik", q, k)
136 |
137 | w_ = w_ * (int(c) ** (-0.5))
138 | w_ = torch.nn.functional.softmax(w_, dim=2)
139 |
140 | # attend to values
141 | v = rearrange(v, "b c h w -> b c (h w)")
142 | w_ = rearrange(w_, "b i j -> b j i")
143 | h_ = torch.einsum("bij,bjk->bik", v, w_)
144 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
145 | h_ = self.proj_out(h_)
146 |
147 | return x + h_
148 |
149 |
150 | class CrossAttention(nn.Module):
151 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
152 | super().__init__()
153 | inner_dim = dim_head * heads
154 | context_dim = default(context_dim, query_dim)
155 |
156 | self.scale = dim_head**-0.5
157 | self.heads = heads
158 |
159 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
160 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
161 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
162 |
163 | self.to_out = nn.Sequential(
164 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
165 | )
166 |
167 | def forward(self, x, context=None, mask=None):
168 | h = self.heads
169 |
170 | q = self.to_q(x)
171 | context = default(context, x)
172 | k = self.to_k(context)
173 | v = self.to_v(context)
174 |
175 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
176 |
177 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
178 |
179 | if exists(mask):
180 | mask = rearrange(mask, "b ... -> b (...)")
181 | max_neg_value = -torch.finfo(sim.dtype).max
182 | mask = repeat(mask, "b j -> (b h) () j", h=h)
183 | sim.masked_fill_(~mask, max_neg_value)
184 |
185 | # attention, what we cannot get enough of
186 | attn = sim.softmax(dim=-1)
187 |
188 | out = einsum("b i j, b j d -> b i d", attn, v)
189 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
190 | return self.to_out(out)
191 |
192 |
193 | class BasicTransformerBlock(nn.Module):
194 | def __init__(
195 | self,
196 | dim,
197 | n_heads,
198 | d_head,
199 | dropout=0.0,
200 | context_dim=None,
201 | gated_ff=True,
202 | checkpoint=True,
203 | ):
204 | super().__init__()
205 | self.attn1 = CrossAttention(
206 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
207 | ) # is a self-attention
208 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
209 | self.attn2 = CrossAttention(
210 | query_dim=dim,
211 | context_dim=context_dim,
212 | heads=n_heads,
213 | dim_head=d_head,
214 | dropout=dropout,
215 | ) # is self-attn if context is none
216 | self.norm1 = nn.LayerNorm(dim)
217 | self.norm2 = nn.LayerNorm(dim)
218 | self.norm3 = nn.LayerNorm(dim)
219 | self.checkpoint = checkpoint
220 |
221 | def forward(self, x, context=None):
222 | return checkpoint(
223 | self._forward, (x, context), self.parameters(), self.checkpoint
224 | )
225 |
226 | def _forward(self, x, context=None):
227 | x = self.attn1(self.norm1(x)) + x
228 | x = self.attn2(self.norm2(x), context=context) + x
229 | x = self.ff(self.norm3(x)) + x
230 | return x
231 |
232 |
233 | class SpatialTransformer(nn.Module):
234 | """
235 | Transformer block for image-like data.
236 | First, project the input (aka embedding)
237 | and reshape to b, t, d.
238 | Then apply standard transformer action.
239 | Finally, reshape back to an image.
240 | """
241 |
242 | def __init__(
243 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
244 | ):
245 | super().__init__()
246 | self.in_channels = in_channels
247 | inner_dim = n_heads * d_head
248 | self.norm = Normalize(in_channels)
249 |
250 | self.proj_in = nn.Conv2d(
251 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
252 | )
253 |
254 | self.transformer_blocks = nn.ModuleList(
255 | [
256 | BasicTransformerBlock(
257 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
258 | )
259 | for d in range(depth)
260 | ]
261 | )
262 |
263 | self.proj_out = zero_module(
264 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
265 | )
266 |
267 | def forward(self, x, context=None):
268 | # note: if no context is given, cross-attention defaults to self-attention
269 | b, c, h, w = x.shape
270 | x_in = x
271 | x = self.norm(x)
272 | x = self.proj_in(x)
273 | x = rearrange(x, "b c h w -> b (h w) c")
274 | for block in self.transformer_blocks:
275 | x = block(x, context=context)
276 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
277 | x = self.proj_out(x)
278 | return x + x_in
279 |
--------------------------------------------------------------------------------
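
The einsum pattern in `CrossAttention.forward` is easiest to follow on concrete shapes; a short sketch with made-up dimensions:

```python
import torch
from einops import rearrange
from torch import einsum

b, n, m, h, d = 2, 16, 77, 8, 64   # batch, query tokens, context tokens, heads, head dim
q = torch.randn(b, n, h * d)
k = torch.randn(b, m, h * d)
v = torch.randn(b, m, h * d)

# fold heads into the batch dimension, as in CrossAttention.forward
q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))

sim = einsum("b i d, b j d -> b i j", q, k) * d ** -0.5  # [(b h), n, m]
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)           # [(b h), n, d]
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
print(out.shape)  # torch.Size([2, 16, 512])
```
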
/threestudio/utils/GAN/discriminator.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def count_params(model):
8 | total_params = sum(p.numel() for p in model.parameters())
9 | return total_params
10 |
11 |
12 | class ActNorm(nn.Module):
13 | def __init__(
14 | self, num_features, logdet=False, affine=True, allow_reverse_init=False
15 | ):
16 | assert affine
17 | super().__init__()
18 | self.logdet = logdet
19 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
20 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
21 | self.allow_reverse_init = allow_reverse_init
22 |
23 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
24 |
25 | def initialize(self, input):
26 | with torch.no_grad():
27 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
28 | mean = (
29 | flatten.mean(1)
30 | .unsqueeze(1)
31 | .unsqueeze(2)
32 | .unsqueeze(3)
33 | .permute(1, 0, 2, 3)
34 | )
35 | std = (
36 | flatten.std(1)
37 | .unsqueeze(1)
38 | .unsqueeze(2)
39 | .unsqueeze(3)
40 | .permute(1, 0, 2, 3)
41 | )
42 |
43 | self.loc.data.copy_(-mean)
44 | self.scale.data.copy_(1 / (std + 1e-6))
45 |
46 | def forward(self, input, reverse=False):
47 | if reverse:
48 | return self.reverse(input)
49 | if len(input.shape) == 2:
50 | input = input[:, :, None, None]
51 | squeeze = True
52 | else:
53 | squeeze = False
54 |
55 | _, _, height, width = input.shape
56 |
57 | if self.training and self.initialized.item() == 0:
58 | self.initialize(input)
59 | self.initialized.fill_(1)
60 |
61 | h = self.scale * (input + self.loc)
62 |
63 | if squeeze:
64 | h = h.squeeze(-1).squeeze(-1)
65 |
66 | if self.logdet:
67 | log_abs = torch.log(torch.abs(self.scale))
68 | logdet = height * width * torch.sum(log_abs)
69 | logdet = logdet * torch.ones(input.shape[0]).to(input)
70 | return h, logdet
71 |
72 | return h
73 |
74 | def reverse(self, output):
75 | if self.training and self.initialized.item() == 0:
76 | if not self.allow_reverse_init:
77 | raise RuntimeError(
78 | "Initializing ActNorm in reverse direction is "
79 | "disabled by default. Use allow_reverse_init=True to enable."
80 | )
81 | else:
82 | self.initialize(output)
83 | self.initialized.fill_(1)
84 |
85 | if len(output.shape) == 2:
86 | output = output[:, :, None, None]
87 | squeeze = True
88 | else:
89 | squeeze = False
90 |
91 | h = output / self.scale - self.loc
92 |
93 | if squeeze:
94 | h = h.squeeze(-1).squeeze(-1)
95 | return h
96 |
97 |
98 | class AbstractEncoder(nn.Module):
99 | def __init__(self):
100 | super().__init__()
101 |
102 | def encode(self, *args, **kwargs):
103 | raise NotImplementedError
104 |
105 |
106 | class Labelator(AbstractEncoder):
107 | """Net2Net Interface for Class-Conditional Model"""
108 |
109 | def __init__(self, n_classes, quantize_interface=True):
110 | super().__init__()
111 | self.n_classes = n_classes
112 | self.quantize_interface = quantize_interface
113 |
114 | def encode(self, c):
115 | c = c[:, None]
116 | if self.quantize_interface:
117 | return c, None, [None, None, c.long()]
118 | return c
119 |
120 |
121 | class SOSProvider(AbstractEncoder):
122 | # for unconditional training
123 | def __init__(self, sos_token, quantize_interface=True):
124 | super().__init__()
125 | self.sos_token = sos_token
126 | self.quantize_interface = quantize_interface
127 |
128 | def encode(self, x):
129 | # get batch size from data and replicate sos_token
130 | c = torch.ones(x.shape[0], 1) * self.sos_token
131 | c = c.long().to(x.device)
132 | if self.quantize_interface:
133 | return c, None, [None, None, c]
134 | return c
135 |
136 |
137 | def weights_init(m):
138 | classname = m.__class__.__name__
139 | if classname.find("Conv") != -1:
140 | nn.init.normal_(m.weight.data, 0.0, 0.02)
141 | elif classname.find("BatchNorm") != -1:
142 | nn.init.normal_(m.weight.data, 1.0, 0.02)
143 | nn.init.constant_(m.bias.data, 0)
144 |
145 |
146 | class NLayerDiscriminator(nn.Module):
147 | """Defines a PatchGAN discriminator as in Pix2Pix
148 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
149 | """
150 |
151 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
152 | """Construct a PatchGAN discriminator
153 | Parameters:
154 | input_nc (int) -- the number of channels in input images
155 | ndf (int) -- the number of filters in the last conv layer
156 | n_layers (int) -- the number of conv layers in the discriminator
157 | use_actnorm (bool) -- use ActNorm instead of BatchNorm2d
158 | """
159 | super(NLayerDiscriminator, self).__init__()
160 | if not use_actnorm:
161 | norm_layer = nn.BatchNorm2d
162 | else:
163 | norm_layer = ActNorm
164 | if (
165 | type(norm_layer) == functools.partial
166 | ): # no need to use bias as BatchNorm2d has affine parameters
167 | use_bias = norm_layer.func != nn.BatchNorm2d
168 | else:
169 | use_bias = norm_layer != nn.BatchNorm2d
170 |
171 | kw = 4
172 | padw = 1
173 | sequence = [
174 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
175 | nn.LeakyReLU(0.2, True),
176 | ]
177 | nf_mult = 1
178 | nf_mult_prev = 1
179 | for n in range(1, n_layers): # gradually increase the number of filters
180 | nf_mult_prev = nf_mult
181 | nf_mult = min(2**n, 8)
182 | sequence += [
183 | nn.Conv2d(
184 | ndf * nf_mult_prev,
185 | ndf * nf_mult,
186 | kernel_size=kw,
187 | stride=2,
188 | padding=padw,
189 | bias=use_bias,
190 | ),
191 | norm_layer(ndf * nf_mult),
192 | nn.LeakyReLU(0.2, True),
193 | ]
194 |
195 | nf_mult_prev = nf_mult
196 | nf_mult = min(2**n_layers, 8)
197 | sequence += [
198 | nn.Conv2d(
199 | ndf * nf_mult_prev,
200 | ndf * nf_mult,
201 | kernel_size=kw,
202 | stride=1,
203 | padding=padw,
204 | bias=use_bias,
205 | ),
206 | norm_layer(ndf * nf_mult),
207 | nn.LeakyReLU(0.2, True),
208 | ]
209 |
210 | sequence += [
211 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
212 | ] # output 1 channel prediction map
213 | self.main = nn.Sequential(*sequence)
214 |
215 | def forward(self, input):
216 | """Standard forward."""
217 | return self.main(input)
218 |
--------------------------------------------------------------------------------
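
`NLayerDiscriminator` outputs a patch map rather than a single logit. A quick smoke test, assuming the module is importable from this repo as laid out above:

```python
import torch
from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
img = torch.randn(1, 3, 256, 256)
logits = disc(img)
print(logits.shape)  # torch.Size([1, 1, 30, 30]) -- one logit per (roughly 70x70) input patch
```
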
/threestudio/utils/GAN/distribution.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
--------------------------------------------------------------------------------
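
`DiagonalGaussianDistribution.kl` with `other=None` is the closed-form KL to a standard normal; a sketch checking it against torch.distributions (the tensor shapes are made up):

```python
import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8).clamp(-30.0, 20.0)
var, std = logvar.exp(), (0.5 * logvar).exp()

# closed form used above, summed over the non-batch dims
kl_closed = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])

# reference via torch.distributions
kl_ref = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=[1, 2, 3])
print(torch.allclose(kl_closed, kl_ref, atol=1e-5))  # True
```
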
/threestudio/utils/GAN/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def generator_loss(discriminator, inputs, reconstructions, cond=None):
6 | if cond is None:
7 | logits_fake = discriminator(reconstructions.contiguous())
8 | else:
9 | logits_fake = discriminator(
10 | torch.cat((reconstructions.contiguous(), cond), dim=1)
11 | )
12 | g_loss = -torch.mean(logits_fake)
13 | return g_loss
14 |
15 |
16 | def hinge_d_loss(logits_real, logits_fake):
17 | loss_real = torch.mean(F.relu(1.0 - logits_real))
18 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
19 | d_loss = 0.5 * (loss_real + loss_fake)
20 | return d_loss
21 |
22 |
23 | def discriminator_loss(discriminator, inputs, reconstructions, cond=None):
24 | if cond is None:
25 | logits_real = discriminator(inputs.contiguous().detach())
26 | logits_fake = discriminator(reconstructions.contiguous().detach())
27 | else:
28 | logits_real = discriminator(
29 | torch.cat((inputs.contiguous().detach(), cond), dim=1)
30 | )
31 | logits_fake = discriminator(
32 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
33 | )
34 | d_loss = hinge_d_loss(logits_real, logits_fake).mean()
35 | return d_loss
36 |
--------------------------------------------------------------------------------
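
The hinge objective above only penalizes logits on the wrong side of the ±1 margin; a short numeric check with made-up logits:

```python
import torch
import torch.nn.functional as F

logits_real = torch.tensor([1.5, 0.2, -0.5])   # want >= +1; the first entry is already safe
logits_fake = torch.tensor([-2.0, -0.9, 0.3])  # want <= -1; the first entry is already safe

loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
print(d_loss.item())  # only the in-margin logits contribute
```
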
/threestudio/utils/GAN/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ["MobileNetV3", "mobilenetv3"]
6 |
7 |
8 | def conv_bn(
9 | inp,
10 | oup,
11 | stride,
12 | conv_layer=nn.Conv2d,
13 | norm_layer=nn.BatchNorm2d,
14 | nlin_layer=nn.ReLU,
15 | ):
16 | return nn.Sequential(
17 | conv_layer(inp, oup, 3, stride, 1, bias=False),
18 | norm_layer(oup),
19 | nlin_layer(inplace=True),
20 | )
21 |
22 |
23 | def conv_1x1_bn(
24 | inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU
25 | ):
26 | return nn.Sequential(
27 | conv_layer(inp, oup, 1, 1, 0, bias=False),
28 | norm_layer(oup),
29 | nlin_layer(inplace=True),
30 | )
31 |
32 |
33 | class Hswish(nn.Module):
34 | def __init__(self, inplace=True):
35 | super(Hswish, self).__init__()
36 | self.inplace = inplace
37 |
38 | def forward(self, x):
39 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
40 |
41 |
42 | class Hsigmoid(nn.Module):
43 | def __init__(self, inplace=True):
44 | super(Hsigmoid, self).__init__()
45 | self.inplace = inplace
46 |
47 | def forward(self, x):
48 | return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
49 |
50 |
51 | class SEModule(nn.Module):
52 | def __init__(self, channel, reduction=4):
53 | super(SEModule, self).__init__()
54 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
55 | self.fc = nn.Sequential(
56 | nn.Linear(channel, channel // reduction, bias=False),
57 | nn.ReLU(inplace=True),
58 | nn.Linear(channel // reduction, channel, bias=False),
59 | Hsigmoid()
60 | # nn.Sigmoid()
61 | )
62 |
63 | def forward(self, x):
64 | b, c, _, _ = x.size()
65 | y = self.avg_pool(x).view(b, c)
66 | y = self.fc(y).view(b, c, 1, 1)
67 | return x * y.expand_as(x)
68 |
69 |
70 | class Identity(nn.Module):
71 | def __init__(self, channel):
72 | super(Identity, self).__init__()
73 |
74 | def forward(self, x):
75 | return x
76 |
77 |
78 | def make_divisible(x, divisible_by=8):
79 | import numpy as np
80 |
81 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by)
82 |
83 |
84 | class MobileBottleneck(nn.Module):
85 | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl="RE"):
86 | super(MobileBottleneck, self).__init__()
87 | assert stride in [1, 2]
88 | assert kernel in [3, 5]
89 | padding = (kernel - 1) // 2
90 | self.use_res_connect = stride == 1 and inp == oup
91 |
92 | conv_layer = nn.Conv2d
93 | norm_layer = nn.BatchNorm2d
94 | if nl == "RE":
95 | nlin_layer = nn.ReLU # or ReLU6
96 | elif nl == "HS":
97 | nlin_layer = Hswish
98 | else:
99 | raise NotImplementedError
100 | if se:
101 | SELayer = SEModule
102 | else:
103 | SELayer = Identity
104 |
105 | self.conv = nn.Sequential(
106 | # pw
107 | conv_layer(inp, exp, 1, 1, 0, bias=False),
108 | norm_layer(exp),
109 | nlin_layer(inplace=True),
110 | # dw
111 | conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False),
112 | norm_layer(exp),
113 | SELayer(exp),
114 | nlin_layer(inplace=True),
115 | # pw-linear
116 | conv_layer(exp, oup, 1, 1, 0, bias=False),
117 | norm_layer(oup),
118 | )
119 |
120 | def forward(self, x):
121 | if self.use_res_connect:
122 | return x + self.conv(x)
123 | else:
124 | return self.conv(x)
125 |
126 |
127 | class MobileNetV3(nn.Module):
128 | def __init__(
129 | self, n_class=1000, input_size=224, dropout=0.0, mode="small", width_mult=1.0
130 | ):
131 | super(MobileNetV3, self).__init__()
132 | input_channel = 16
133 | last_channel = 1280
134 | if mode == "large":
135 | # refer to Table 1 in paper
136 | mobile_setting = [
137 | # k, exp, c, se, nl, s,
138 | [3, 16, 16, False, "RE", 1],
139 | [3, 64, 24, False, "RE", 2],
140 | [3, 72, 24, False, "RE", 1],
141 | [5, 72, 40, True, "RE", 2],
142 | [5, 120, 40, True, "RE", 1],
143 | [5, 120, 40, True, "RE", 1],
144 | [3, 240, 80, False, "HS", 2],
145 | [3, 200, 80, False, "HS", 1],
146 | [3, 184, 80, False, "HS", 1],
147 | [3, 184, 80, False, "HS", 1],
148 | [3, 480, 112, True, "HS", 1],
149 | [3, 672, 112, True, "HS", 1],
150 | [5, 672, 160, True, "HS", 2],
151 | [5, 960, 160, True, "HS", 1],
152 | [5, 960, 160, True, "HS", 1],
153 | ]
154 | elif mode == "small":
155 | # refer to Table 2 in paper
156 | mobile_setting = [
157 | # k, exp, c, se, nl, s,
158 | [3, 16, 16, True, "RE", 2],
159 | [3, 72, 24, False, "RE", 2],
160 | [3, 88, 24, False, "RE", 1],
161 | [5, 96, 40, True, "HS", 2],
162 | [5, 240, 40, True, "HS", 1],
163 | [5, 240, 40, True, "HS", 1],
164 | [5, 120, 48, True, "HS", 1],
165 | [5, 144, 48, True, "HS", 1],
166 | [5, 288, 96, True, "HS", 2],
167 | [5, 576, 96, True, "HS", 1],
168 | [5, 576, 96, True, "HS", 1],
169 | ]
170 | else:
171 | raise NotImplementedError
172 |
173 | # building first layer
174 | assert input_size % 32 == 0
175 | last_channel = (
176 | make_divisible(last_channel * width_mult)
177 | if width_mult > 1.0
178 | else last_channel
179 | )
180 | self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)]
181 | self.classifier = []
182 |
183 | # building mobile blocks
184 | for k, exp, c, se, nl, s in mobile_setting:
185 | output_channel = make_divisible(c * width_mult)
186 | exp_channel = make_divisible(exp * width_mult)
187 | self.features.append(
188 | MobileBottleneck(
189 | input_channel, output_channel, k, s, exp_channel, se, nl
190 | )
191 | )
192 | input_channel = output_channel
193 |
194 | # building last several layers
195 | if mode == "large":
196 | last_conv = make_divisible(960 * width_mult)
197 | self.features.append(
198 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)
199 | )
200 | self.features.append(nn.AdaptiveAvgPool2d(1))
201 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
202 | self.features.append(Hswish(inplace=True))
203 | elif mode == "small":
204 | last_conv = make_divisible(576 * width_mult)
205 | self.features.append(
206 | conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)
207 | )
208 | # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake
209 | self.features.append(nn.AdaptiveAvgPool2d(1))
210 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
211 | self.features.append(Hswish(inplace=True))
212 | else:
213 | raise NotImplementedError
214 |
215 | # make it nn.Sequential
216 | self.features = nn.Sequential(*self.features)
217 |
218 | # building classifier
219 | self.classifier = nn.Sequential(
220 | nn.Dropout(p=dropout), # refer to paper section 6
221 | nn.Linear(last_channel, n_class),
222 | )
223 |
224 | self._initialize_weights()
225 |
226 | def forward(self, x):
227 | x = self.features(x)
228 | x = x.mean(3).mean(2)
229 | x = self.classifier(x)
230 | return x
231 |
232 | def _initialize_weights(self):
233 | # weight initialization
234 | for m in self.modules():
235 | if isinstance(m, nn.Conv2d):
236 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
237 | if m.bias is not None:
238 | nn.init.zeros_(m.bias)
239 | elif isinstance(m, nn.BatchNorm2d):
240 | nn.init.ones_(m.weight)
241 | nn.init.zeros_(m.bias)
242 | elif isinstance(m, nn.Linear):
243 | nn.init.normal_(m.weight, 0, 0.01)
244 | if m.bias is not None:
245 | nn.init.zeros_(m.bias)
246 |
247 |
248 | def mobilenetv3(pretrained=False, **kwargs):
249 | model = MobileNetV3(**kwargs)
250 | if pretrained:
251 | state_dict = torch.load("mobilenetv3_small_67.4.pth.tar")
252 | model.load_state_dict(state_dict, strict=True)
253 | # raise NotImplementedError
254 | return model
255 |
--------------------------------------------------------------------------------
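
A quick smoke test for the classifier above, assuming the module is importable from this repo (the input size must be a multiple of 32):

```python
import torch
from threestudio.utils.GAN.mobilenet import mobilenetv3

model = mobilenetv3(pretrained=False, n_class=1000, input_size=224, mode="small")
x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([2, 1000])
```
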
/threestudio/utils/GAN/network_util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import math
12 | import os
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | from einops import repeat
18 |
19 | from threestudio.utils.GAN.util import instantiate_from_config
20 |
21 |
22 | def make_beta_schedule(
23 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
24 | ):
25 | if schedule == "linear":
26 | betas = (
27 | torch.linspace(
28 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29 | )
30 | ** 2
31 | )
32 |
33 | elif schedule == "cosine":
34 | timesteps = (
35 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
36 | )
37 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
38 | alphas = torch.cos(alphas).pow(2)
39 | alphas = alphas / alphas[0]
40 | betas = 1 - alphas[1:] / alphas[:-1]
41 | betas = np.clip(betas, a_min=0, a_max=0.999)
42 |
43 | elif schedule == "sqrt_linear":
44 | betas = torch.linspace(
45 | linear_start, linear_end, n_timestep, dtype=torch.float64
46 | )
47 | elif schedule == "sqrt":
48 | betas = (
49 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
50 | ** 0.5
51 | )
52 | else:
53 | raise ValueError(f"schedule '{schedule}' unknown.")
54 | return betas.numpy()
55 |
56 |
57 | def make_ddim_timesteps(
58 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
59 | ):
60 | if ddim_discr_method == "uniform":
61 | c = num_ddpm_timesteps // num_ddim_timesteps
62 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
63 | elif ddim_discr_method == "quad":
64 | ddim_timesteps = (
65 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
66 | ).astype(int)
67 | else:
68 | raise NotImplementedError(
69 | f'There is no ddim discretization method called "{ddim_discr_method}"'
70 | )
71 |
72 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
73 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
74 | steps_out = ddim_timesteps + 1
75 | if verbose:
76 | print(f"Selected timesteps for ddim sampler: {steps_out}")
77 | return steps_out
78 |
79 |
80 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
81 | # select alphas for computing the variance schedule
82 | alphas = alphacums[ddim_timesteps]
83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
84 |
85 | # according to the formula provided in https://arxiv.org/abs/2010.02502
86 | sigmas = eta * np.sqrt(
87 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
88 | )
89 | if verbose:
90 | print(
91 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
92 | )
93 | print(
94 | f"For the chosen value of eta, which is {eta}, "
95 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
96 | )
97 | return sigmas, alphas, alphas_prev
98 |
99 |
100 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
101 | """
102 | Create a beta schedule that discretizes the given alpha_t_bar function,
103 | which defines the cumulative product of (1-beta) over time from t = [0,1].
104 | :param num_diffusion_timesteps: the number of betas to produce.
105 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
106 | produces the cumulative product of (1-beta) up to that
107 | part of the diffusion process.
108 | :param max_beta: the maximum beta to use; use values lower than 1 to
109 | prevent singularities.
110 | """
111 | betas = []
112 | for i in range(num_diffusion_timesteps):
113 | t1 = i / num_diffusion_timesteps
114 | t2 = (i + 1) / num_diffusion_timesteps
115 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
116 | return np.array(betas)
117 |
118 |
119 | def extract_into_tensor(a, t, x_shape):
120 | b, *_ = t.shape
121 | out = a.gather(-1, t)
122 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
123 |
124 |
125 | def checkpoint(func, inputs, params, flag):
126 | """
127 | Evaluate a function without caching intermediate activations, allowing for
128 | reduced memory at the expense of extra compute in the backward pass.
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(torch.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 |
149 | with torch.no_grad():
150 | output_tensors = ctx.run_function(*ctx.input_tensors)
151 | return output_tensors
152 |
153 | @staticmethod
154 | def backward(ctx, *output_grads):
155 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
156 | with torch.enable_grad():
157 | # Fixes a bug where the first op in run_function modifies the
158 | # Tensor storage in place, which is not allowed for detach()'d
159 | # Tensors.
160 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
161 | output_tensors = ctx.run_function(*shallow_copies)
162 | input_grads = torch.autograd.grad(
163 | output_tensors,
164 | ctx.input_tensors + ctx.input_params,
165 | output_grads,
166 | allow_unused=True,
167 | )
168 | del ctx.input_tensors
169 | del ctx.input_params
170 | del output_tensors
171 | return (None, None) + input_grads
172 |
173 |
174 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
175 | """
176 | Create sinusoidal timestep embeddings.
177 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
178 | These may be fractional.
179 | :param dim: the dimension of the output.
180 | :param max_period: controls the minimum frequency of the embeddings.
181 | :return: an [N x dim] Tensor of positional embeddings.
182 | """
183 | if not repeat_only:
184 | half = dim // 2
185 | freqs = torch.exp(
186 | -math.log(max_period)
187 | * torch.arange(start=0, end=half, dtype=torch.float32)
188 | / half
189 | ).to(device=timesteps.device)
190 | args = timesteps[:, None].float() * freqs[None]
191 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
192 | if dim % 2:
193 | embedding = torch.cat(
194 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
195 | )
196 | else:
197 | embedding = repeat(timesteps, "b -> b d", d=dim)
198 | return embedding
199 |
200 |
201 | def zero_module(module):
202 | """
203 | Zero out the parameters of a module and return it.
204 | """
205 | for p in module.parameters():
206 | p.detach().zero_()
207 | return module
208 |
209 |
210 | def scale_module(module, scale):
211 | """
212 | Scale the parameters of a module and return it.
213 | """
214 | for p in module.parameters():
215 | p.detach().mul_(scale)
216 | return module
217 |
218 |
219 | def mean_flat(tensor):
220 | """
221 | Take the mean over all non-batch dimensions.
222 | """
223 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
224 |
225 |
226 | def normalization(channels):
227 | """
228 | Make a standard normalization layer.
229 | :param channels: number of input channels.
230 | :return: an nn.Module for normalization.
231 | """
232 | return GroupNorm32(32, channels)
233 |
234 |
235 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
236 | class SiLU(nn.Module):
237 | def forward(self, x):
238 | return x * torch.sigmoid(x)
239 |
240 |
241 | class GroupNorm32(nn.GroupNorm):
242 | def forward(self, x):
243 | return super().forward(x.float()).type(x.dtype)
244 |
245 |
246 | def conv_nd(dims, *args, **kwargs):
247 | """
248 | Create a 1D, 2D, or 3D convolution module.
249 | """
250 | if dims == 1:
251 | return nn.Conv1d(*args, **kwargs)
252 | elif dims == 2:
253 | return nn.Conv2d(*args, **kwargs)
254 | elif dims == 3:
255 | return nn.Conv3d(*args, **kwargs)
256 | raise ValueError(f"unsupported dimensions: {dims}")
257 |
258 |
259 | def linear(*args, **kwargs):
260 | """
261 | Create a linear module.
262 | """
263 | return nn.Linear(*args, **kwargs)
264 |
265 |
266 | def avg_pool_nd(dims, *args, **kwargs):
267 | """
268 | Create a 1D, 2D, or 3D average pooling module.
269 | """
270 | if dims == 1:
271 | return nn.AvgPool1d(*args, **kwargs)
272 | elif dims == 2:
273 | return nn.AvgPool2d(*args, **kwargs)
274 | elif dims == 3:
275 | return nn.AvgPool3d(*args, **kwargs)
276 | raise ValueError(f"unsupported dimensions: {dims}")
277 |
278 |
279 | class HybridConditioner(nn.Module):
280 | def __init__(self, c_concat_config, c_crossattn_config):
281 | super().__init__()
282 | self.concat_conditioner = instantiate_from_config(c_concat_config)
283 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
284 |
285 | def forward(self, c_concat, c_crossattn):
286 | c_concat = self.concat_conditioner(c_concat)
287 | c_crossattn = self.crossattn_conditioner(c_crossattn)
288 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
289 |
290 |
291 | def noise_like(shape, device, repeat=False):
292 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
293 | shape[0], *((1,) * (len(shape) - 1))
294 | )
295 | noise = lambda: torch.randn(shape, device=device)
296 | return repeat_noise() if repeat else noise()
297 |
--------------------------------------------------------------------------------
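
`timestep_embedding` above produces the standard sinusoidal encoding used by diffusion U-Nets; a short sketch of its shape and range, assuming this repo's module path:

```python
import torch
from threestudio.utils.GAN.network_util import timestep_embedding

t = torch.tensor([0, 10, 500, 999])   # one timestep index per batch element
emb = timestep_embedding(t, dim=320)  # [N, dim]: cosine half then sine half
print(emb.shape, float(emb.min()) >= -1.0, float(emb.max()) <= 1.0)
```
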
/threestudio/utils/GAN/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import multiprocessing as mp
3 | from collections import abc
4 | from functools import partial
5 | from inspect import isfunction
6 | from queue import Queue
7 | from threading import Thread
8 |
9 | import numpy as np
10 | import torch
11 | from einops import rearrange
12 | from PIL import Image, ImageDraw, ImageFont
13 |
14 |
15 | def log_txt_as_img(wh, xc, size=10):
16 | # wh a tuple of (width, height)
17 | # xc a list of captions to plot
18 | b = len(xc)
19 | txts = list()
20 | for bi in range(b):
21 | txt = Image.new("RGB", wh, color="white")
22 | draw = ImageDraw.Draw(txt)
23 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
24 | nc = int(40 * (wh[0] / 256))
25 | lines = "\n".join(
26 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
27 | )
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 |     if "target" not in config:
80 | if config == "__is_first_stage__":
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 |     # worker entry point: run func on this worker's data chunk,
98 |     # optionally passing the worker id, then push the result and a
99 |     # "Done" sentinel to the shared queue
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable,
110 | data,
111 | n_proc,
112 | target_data_type="ndarray",
113 | cpu_intensive=True,
114 | use_worker_id=False,
115 | ):
116 | # if target_data_type not in ["ndarray", "list"]:
117 | # raise ValueError(
118 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
119 | # )
120 | if isinstance(data, np.ndarray) and target_data_type == "list":
121 | raise ValueError("list expected but function got ndarray.")
122 | elif isinstance(data, abc.Iterable):
123 | if isinstance(data, dict):
124 | print(
125 |                 'WARNING: "data" argument passed to parallel_data_prefetch is a dict; using only its values and disregarding keys.'
126 | )
127 | data = list(data.values())
128 | if target_data_type == "ndarray":
129 | data = np.asarray(data)
130 | else:
131 | data = list(data)
132 | else:
133 | raise TypeError(
134 |             f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
135 | )
136 |
137 | if cpu_intensive:
138 | Q = mp.Queue(1000)
139 | proc = mp.Process
140 | else:
141 | Q = Queue(1000)
142 | proc = Thread
143 | # spawn processes
144 | if target_data_type == "ndarray":
145 | arguments = [
146 | [func, Q, part, i, use_worker_id]
147 | for i, part in enumerate(np.array_split(data, n_proc))
148 | ]
149 | else:
150 | step = (
151 | int(len(data) / n_proc + 1)
152 | if len(data) % n_proc != 0
153 | else int(len(data) / n_proc)
154 | )
155 | arguments = [
156 | [func, Q, part, i, use_worker_id]
157 | for i, part in enumerate(
158 | [data[i : i + step] for i in range(0, len(data), step)]
159 | )
160 | ]
161 | processes = []
162 | for i in range(n_proc):
163 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
164 | processes += [p]
165 |
166 | # start processes
167 |     print("Start prefetching...")
168 | import time
169 |
170 | start = time.time()
171 | gather_res = [[] for _ in range(n_proc)]
172 | try:
173 | for p in processes:
174 | p.start()
175 |
176 | k = 0
177 | while k < n_proc:
178 | # get result
179 | res = Q.get()
180 | if res == "Done":
181 | k += 1
182 | else:
183 | gather_res[res[0]] = res[1]
184 |
185 | except Exception as e:
186 | print("Exception: ", e)
187 | for p in processes:
188 | p.terminate()
189 |
190 | raise e
191 | finally:
192 | for p in processes:
193 | p.join()
194 | print(f"Prefetching complete. [{time.time() - start} sec.]")
195 |
196 | if target_data_type == "ndarray":
197 | if not isinstance(gather_res[0], np.ndarray):
198 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
199 |
200 | # order outputs
201 | return np.concatenate(gather_res, axis=0)
202 | elif target_data_type == "list":
203 | out = []
204 | for r in gather_res:
205 | out.extend(r)
206 | return out
207 | else:
208 | return gather_res
209 |
--------------------------------------------------------------------------------
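
A hedged usage sketch for `parallel_data_prefetch` above (the `square_chunk` worker is hypothetical); the worker function must be defined at module top level so it can be pickled for the worker processes.

```python
# Hypothetical example, not part of the repo: squaring a list of
# integers across 4 worker processes with parallel_data_prefetch.
from threestudio.utils.GAN.util import parallel_data_prefetch


def square_chunk(chunk):
    # each worker receives one contiguous chunk of the input list
    return [x * x for x in chunk]


if __name__ == "__main__":  # required on platforms that spawn processes
    out = parallel_data_prefetch(
        square_chunk, list(range(100)), n_proc=4, target_data_type="list"
    )
    print(out[:5], len(out))  # [0, 1, 4, 9, 16] 100
```
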
/threestudio/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 |
--------------------------------------------------------------------------------
/threestudio/utils/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from threestudio.utils.config import parse_structured
7 | from threestudio.utils.misc import get_device, load_module_weights
8 | from threestudio.utils.typing import *
9 |
10 |
11 | class Configurable:
12 | @dataclass
13 | class Config:
14 | pass
15 |
16 | def __init__(self, cfg: Optional[dict] = None) -> None:
17 | super().__init__()
18 | self.cfg = parse_structured(self.Config, cfg)
19 |
20 |
21 | class Updateable:
22 | def do_update_step(
23 | self, epoch: int, global_step: int, on_load_weights: bool = False
24 | ):
25 | for attr in self.__dir__():
26 | if attr.startswith("_"):
27 | continue
28 | try:
29 | module = getattr(self, attr)
30 |             except Exception:
31 |                 continue  # ignore attributes (e.g. properties) that can't be retrieved via getattr
32 | if isinstance(module, Updateable):
33 | module.do_update_step(
34 | epoch, global_step, on_load_weights=on_load_weights
35 | )
36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights)
37 |
38 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
39 | # override this method to implement custom update logic
40 | # if on_load_weights is True, you should be careful doing things related to model evaluations,
41 |         # as the models and tensors are not guaranteed to be on the same device
42 | pass
43 |
44 |
45 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
46 | if isinstance(module, Updateable):
47 | module.do_update_step(epoch, global_step)
48 |
49 |
50 | class BaseObject(Updateable):
51 | @dataclass
52 | class Config:
53 | pass
54 |
55 | cfg: Config # add this to every subclass of BaseObject to enable static type checking
56 |
57 | def __init__(
58 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
59 | ) -> None:
60 | super().__init__()
61 | self.cfg = parse_structured(self.Config, cfg)
62 | self.device = get_device()
63 | self.configure(*args, **kwargs)
64 |
65 | def configure(self, *args, **kwargs) -> None:
66 | pass
67 |
68 |
69 | class BaseModule(nn.Module, Updateable):
70 | @dataclass
71 | class Config:
72 | weights: Optional[str] = None
73 |
74 | cfg: Config # add this to every subclass of BaseModule to enable static type checking
75 |
76 | def __init__(
77 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
78 | ) -> None:
79 | super().__init__()
80 | self.cfg = parse_structured(self.Config, cfg)
81 | self.device = get_device()
82 | self.configure(*args, **kwargs)
83 | if self.cfg.weights is not None:
84 | # format: path/to/weights:module_name
85 | weights_path, module_name = self.cfg.weights.split(":")
86 | state_dict, epoch, global_step = load_module_weights(
87 | weights_path, module_name=module_name, map_location="cpu"
88 | )
89 | self.load_state_dict(state_dict)
90 | self.do_update_step(
91 | epoch, global_step, on_load_weights=True
92 | ) # restore states
93 | # dummy tensor to indicate model state
94 | self._dummy: Float[Tensor, "..."]
95 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False)
96 |
97 | def configure(self, *args, **kwargs) -> None:
98 | pass
99 |
--------------------------------------------------------------------------------
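
A minimal sketch of the configuration pattern `BaseModule` implements (the `ToyModule` below is hypothetical): subclasses declare a nested `Config` dataclass, which is parsed from a plain dict or `DictConfig` into `self.cfg` before `configure()` runs.

```python
from dataclasses import dataclass

from threestudio.utils.base import BaseModule


class ToyModule(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        hidden_dim: int = 64

    cfg: Config  # enables static type checking, as noted above

    def configure(self) -> None:
        # self.cfg is fully parsed by the time configure() is called
        print(f"hidden_dim = {self.cfg.hidden_dim}")


module = ToyModule({"hidden_dim": 128})  # prints: hidden_dim = 128
```
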
/threestudio/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 |
5 | import pytorch_lightning
6 |
7 | from threestudio.utils.config import dump_config
8 | from threestudio.utils.misc import parse_version
9 |
10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
11 | from pytorch_lightning.callbacks import Callback
12 | else:
13 | from pytorch_lightning.callbacks.base import Callback
14 |
15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar
16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
17 |
18 |
19 | class VersionedCallback(Callback):
20 | def __init__(self, save_root, version=None, use_version=True):
21 | self.save_root = save_root
22 | self._version = version
23 | self.use_version = use_version
24 |
25 | @property
26 | def version(self) -> int:
27 | """Get the experiment version.
28 |
29 | Returns:
30 | The experiment version if specified else the next version.
31 | """
32 | if self._version is None:
33 | self._version = self._get_next_version()
34 | return self._version
35 |
36 | def _get_next_version(self):
37 | existing_versions = []
38 | if os.path.isdir(self.save_root):
39 | for f in os.listdir(self.save_root):
40 | bn = os.path.basename(f)
41 | if bn.startswith("version_"):
42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
43 | existing_versions.append(int(dir_ver))
44 | if len(existing_versions) == 0:
45 | return 0
46 | return max(existing_versions) + 1
47 |
48 | @property
49 | def savedir(self):
50 | if not self.use_version:
51 | return self.save_root
52 | return os.path.join(
53 | self.save_root,
54 | self.version
55 | if isinstance(self.version, str)
56 | else f"version_{self.version}",
57 | )
58 |
59 |
60 | class CodeSnapshotCallback(VersionedCallback):
61 | def __init__(self, save_root, version=None, use_version=True):
62 | super().__init__(save_root, version, use_version)
63 |
64 | def get_file_list(self):
65 | return [
66 | b.decode()
67 | for b in set(
68 | subprocess.check_output(
69 | 'git ls-files -- ":!:load/*"', shell=True
70 | ).splitlines()
71 | )
72 | | set( # hard code, TODO: use config to exclude folders or files
73 | subprocess.check_output(
74 | "git ls-files --others --exclude-standard", shell=True
75 | ).splitlines()
76 | )
77 | ]
78 |
79 | @rank_zero_only
80 | def save_code_snapshot(self):
81 | os.makedirs(self.savedir, exist_ok=True)
82 | for f in self.get_file_list():
83 | if not os.path.exists(f) or os.path.isdir(f):
84 | continue
85 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
86 | shutil.copyfile(f, os.path.join(self.savedir, f))
87 |
88 | def on_fit_start(self, trainer, pl_module):
89 | try:
90 | self.save_code_snapshot()
91 |         except Exception:
92 | rank_zero_warn(
93 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
94 | )
95 |
96 |
97 | class ConfigSnapshotCallback(VersionedCallback):
98 | def __init__(self, config_path, config, save_root, version=None, use_version=True):
99 | super().__init__(save_root, version, use_version)
100 | self.config_path = config_path
101 | self.config = config
102 |
103 | @rank_zero_only
104 | def save_config_snapshot(self):
105 | os.makedirs(self.savedir, exist_ok=True)
106 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
107 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))
108 |
109 | def on_fit_start(self, trainer, pl_module):
110 | self.save_config_snapshot()
111 |
112 |
113 | class CustomProgressBar(TQDMProgressBar):
114 | def get_metrics(self, *args, **kwargs):
115 | # don't show the version number
116 | items = super().get_metrics(*args, **kwargs)
117 | items.pop("v_num", None)
118 | return items
119 |
120 |
121 | class ProgressCallback(Callback):
122 | def __init__(self, save_path):
123 | super().__init__()
124 | self.save_path = save_path
125 | self._file_handle = None
126 |
127 | @property
128 | def file_handle(self):
129 | if self._file_handle is None:
130 | self._file_handle = open(self.save_path, "w")
131 | return self._file_handle
132 |
133 | @rank_zero_only
134 | def write(self, msg: str) -> None:
135 | self.file_handle.seek(0)
136 | self.file_handle.truncate()
137 | self.file_handle.write(msg)
138 | self.file_handle.flush()
139 |
140 | @rank_zero_only
141 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
142 | self.write(
143 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
144 | )
145 |
146 | @rank_zero_only
147 | def on_validation_start(self, trainer, pl_module):
148 | self.write(f"Rendering validation image ...")
149 |
150 | @rank_zero_only
151 | def on_test_start(self, trainer, pl_module):
152 | self.write(f"Rendering video ...")
153 |
154 | @rank_zero_only
155 | def on_predict_start(self, trainer, pl_module):
156 | self.write(f"Exporting mesh assets ...")
157 |
--------------------------------------------------------------------------------
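
A hedged sketch of wiring these callbacks into a `pytorch_lightning.Trainer`; the paths and the stand-in config below are illustrative, not values from the repo.

```python
import pytorch_lightning as pl
from omegaconf import OmegaConf

from threestudio.utils.callbacks import (
    CodeSnapshotCallback,
    ConfigSnapshotCallback,
    CustomProgressBar,
    ProgressCallback,
)

cfg = OmegaConf.create({"name": "demo"})  # stand-in for a parsed config
callbacks = [
    CodeSnapshotCallback("outputs/demo/code"),  # snapshots git-tracked files
    ConfigSnapshotCallback(
        "configs/dreamavatar.yaml", cfg, "outputs/demo/configs"
    ),  # dumps parsed.yaml and raw.yaml on fit start
    CustomProgressBar(),
    ProgressCallback(save_path="outputs/demo/progress.txt"),
]
trainer = pl.Trainer(callbacks=callbacks, max_steps=10000)
```
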
/threestudio/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from datetime import datetime
4 |
5 | from omegaconf import OmegaConf
6 |
7 | import threestudio
8 | from threestudio.utils.typing import *
9 |
10 | # ============ Register OmegaConf Resolvers ============= #
11 | OmegaConf.register_new_resolver(
12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
13 | )
14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b)
15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b)
18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0)
24 | OmegaConf.register_new_resolver("not", lambda s: not s)
25 | OmegaConf.register_new_resolver(
26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0
27 | )
28 | # ======================================================= #
29 |
30 |
31 | def C_max(value: Any) -> float:
32 | if isinstance(value, int) or isinstance(value, float):
33 | pass
34 | else:
35 | value = config_to_primitive(value)
36 | if not isinstance(value, list):
37 |             raise TypeError(f"Scalar specification only supports list, got {type(value)}")
38 | if len(value) == 3:
39 | value = [0] + value
40 | assert len(value) == 4
41 | start_step, start_value, end_value, end_step = value
42 | value = max(start_value, end_value)
43 | return value
44 |
45 |
46 | @dataclass
47 | class ExperimentConfig:
48 | name: str = "default"
49 | description: str = ""
50 | tag: str = ""
51 | seed: int = 0
52 | use_timestamp: bool = True
53 | timestamp: Optional[str] = None
54 | exp_root_dir: str = "outputs"
55 |
56 | ### these shouldn't be set manually
57 | exp_dir: str = "outputs/default"
58 | trial_name: str = "exp"
59 | trial_dir: str = "outputs/default/exp"
60 | n_gpus: int = 1
61 | ###
62 |
63 | resume: Optional[str] = None
64 |
65 | data_type: str = ""
66 | data: dict = field(default_factory=dict)
67 |
68 | system_type: str = ""
69 | system: dict = field(default_factory=dict)
70 |
71 | # accept pytorch-lightning trainer parameters
72 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
73 | trainer: dict = field(default_factory=dict)
74 |
75 | # accept pytorch-lightning checkpoint callback parameters
76 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
77 | checkpoint: dict = field(default_factory=dict)
78 |
79 | def __post_init__(self):
80 | if not self.tag and not self.use_timestamp:
81 |             raise ValueError("Either tag must be specified or use_timestamp must be True.")
82 | self.trial_name = self.tag
83 | # if resume from an existing config, self.timestamp should not be None
84 | if self.timestamp is None:
85 | self.timestamp = ""
86 | if self.use_timestamp:
87 | if self.n_gpus > 1:
88 | threestudio.warn(
89 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
90 | )
91 | else:
92 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
93 | self.trial_name += self.timestamp
94 | self.exp_dir = os.path.join(self.exp_root_dir, self.name)
95 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name)
96 | os.makedirs(self.trial_dir, exist_ok=True)
97 |
98 |
99 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any:
100 | if from_string:
101 | yaml_confs = [OmegaConf.create(s) for s in yamls]
102 | else:
103 | yaml_confs = [OmegaConf.load(f) for f in yamls]
104 | cli_conf = OmegaConf.from_cli(cli_args)
105 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
106 | OmegaConf.resolve(cfg)
107 | assert isinstance(cfg, DictConfig)
108 | scfg = parse_structured(ExperimentConfig, cfg)
109 | return scfg
110 |
111 |
112 | def config_to_primitive(config, resolve: bool = True) -> Any:
113 | return OmegaConf.to_container(config, resolve=resolve)
114 |
115 |
116 | def dump_config(path: str, config) -> None:
117 | with open(path, "w") as fp:
118 | OmegaConf.save(config=config, f=fp)
119 |
120 |
121 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
122 |     scfg = OmegaConf.structured(fields(**(cfg if cfg is not None else {})))  # guard against the cfg=None default
123 | return scfg
124 |
--------------------------------------------------------------------------------
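
An illustrative check of the scalar-schedule convention parsed by `C_max` above (and by `C` in misc.py below): a spec is `[start_step, start_value, end_value, end_step]`, or a 3-element list with `start_step` implicitly 0; plain scalars pass through unchanged.

```python
from omegaconf import OmegaConf

from threestudio.utils.config import C_max

print(C_max(0.1))  # 0.1: plain ints/floats are returned as-is

# [start_step, start_value, end_value, end_step]
spec = OmegaConf.create([0, 0.0, 1.0, 5000])
print(C_max(spec))  # 1.0: the larger of start_value and end_value

spec3 = OmegaConf.create([0.0, 1.0, 5000])  # start_step defaults to 0
print(C_max(spec3))  # 1.0
```
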
/threestudio/utils/misc.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import re
4 |
5 | import tinycudann as tcnn
6 | import torch
7 | from packaging import version
8 |
9 | from threestudio.utils.config import config_to_primitive
10 | from threestudio.utils.typing import *
11 |
12 |
13 | def parse_version(ver: str):
14 | return version.parse(ver)
15 |
16 |
17 | def get_rank():
18 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
19 | # therefore LOCAL_RANK needs to be checked first
20 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
21 | for key in rank_keys:
22 | rank = os.environ.get(key)
23 | if rank is not None:
24 | return int(rank)
25 | return 0
26 |
27 |
28 | def get_device():
29 | return torch.device(f"cuda:{get_rank()}")
30 |
31 |
32 | def load_module_weights(
33 | path, module_name=None, ignore_modules=None, map_location=None
34 | ) -> Tuple[dict, int, int]:
35 | if module_name is not None and ignore_modules is not None:
36 |         raise ValueError("module_name and ignore_modules cannot both be set")
37 | if map_location is None:
38 | map_location = get_device()
39 |
40 | ckpt = torch.load(path, map_location=map_location)
41 | state_dict = ckpt["state_dict"]
42 | state_dict_to_load = state_dict
43 |
44 | if ignore_modules is not None:
45 | state_dict_to_load = {}
46 | for k, v in state_dict.items():
47 | ignore = any(
48 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules]
49 | )
50 | if ignore:
51 | continue
52 | state_dict_to_load[k] = v
53 |
54 | if module_name is not None:
55 | state_dict_to_load = {}
56 | for k, v in state_dict.items():
57 | m = re.match(rf"^{module_name}\.(.*)$", k)
58 | if m is None:
59 | continue
60 | state_dict_to_load[m.group(1)] = v
61 |
62 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
63 |
64 |
65 | def C(value: Any, epoch: int, global_step: int) -> float:
66 | if isinstance(value, int) or isinstance(value, float):
67 | pass
68 | else:
69 | value = config_to_primitive(value)
70 | if not isinstance(value, list):
71 |             raise TypeError(f"Scalar specification only supports list, got {type(value)}")
72 | if len(value) == 3:
73 | value = [0] + value
74 | assert len(value) == 4
75 | start_step, start_value, end_value, end_step = value
76 |         # int end_step schedules over global steps; float over epochs
77 |         if isinstance(end_step, int):
78 |             current_step = global_step
79 |         elif isinstance(end_step, float):
80 |             current_step = epoch
81 |         else:
82 |             raise TypeError(f"end_step must be int or float, got {type(end_step)}")
83 |         value = start_value + (end_value - start_value) * max(
84 |             min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
85 |         )
86 | return value
87 |
88 |
89 | def cleanup():
90 | gc.collect()
91 | torch.cuda.empty_cache()
92 | tcnn.free_temporary_memory()
93 |
94 |
95 | def finish_with_cleanup(func: Callable):
96 | def wrapper(*args, **kwargs):
97 | out = func(*args, **kwargs)
98 | cleanup()
99 | return out
100 |
101 | return wrapper
102 |
103 |
104 | def _distributed_available():
105 | return torch.distributed.is_available() and torch.distributed.is_initialized()
106 |
107 |
108 | def barrier():
109 | if not _distributed_available():
110 | return
111 | else:
112 | torch.distributed.barrier()
113 |
114 |
115 | def broadcast(tensor, src=0):
116 | if not _distributed_available():
117 | return tensor
118 | else:
119 | torch.distributed.broadcast(tensor, src=src)
120 | return tensor
121 |
--------------------------------------------------------------------------------
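
Illustrative values for the linear schedule implemented by `C()` above (this assumes the repo's dependencies, e.g. `tinycudann`, are importable): an `int` `end_step` interpolates over global steps, a `float` `end_step` over epochs, and the result is clamped to the endpoint values outside the window.

```python
from omegaconf import OmegaConf

from threestudio.utils.misc import C

# [start_step, start_value, end_value, end_step]
spec = OmegaConf.create([100, 0.0, 1.0, 1100])
print(C(spec, epoch=0, global_step=100))   # 0.0  (at start_step)
print(C(spec, epoch=0, global_step=600))   # 0.5  (halfway through the window)
print(C(spec, epoch=0, global_step=9999))  # 1.0  (clamped after end_step)
```
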
/threestudio/utils/perceptual/__init__.py:
--------------------------------------------------------------------------------
1 | from .perceptual import PerceptualLoss
2 |
--------------------------------------------------------------------------------
/threestudio/utils/perceptual/perceptual.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from threestudio.utils.perceptual.utils import get_ckpt_path
10 |
11 |
12 | class PerceptualLoss(nn.Module):
13 | # Learned perceptual metric
14 | def __init__(self, use_dropout=True):
15 | super().__init__()
16 | self.scaling_layer = ScalingLayer()
17 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
18 | self.net = vgg16(pretrained=True, requires_grad=False)
19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24 | self.load_from_pretrained()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def load_from_pretrained(self, name="vgg_lpips"):
29 | ckpt = get_ckpt_path(name, "threestudio/utils/lpips")
30 | self.load_state_dict(
31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32 | )
33 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
34 |
35 | @classmethod
36 | def from_pretrained(cls, name="vgg_lpips"):
37 | if name != "vgg_lpips":
38 | raise NotImplementedError
39 | model = cls()
40 | ckpt = get_ckpt_path(name)
41 | model.load_state_dict(
42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43 | )
44 | return model
45 |
46 | def forward(self, input, target):
47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
49 | feats0, feats1, diffs = {}, {}, {}
50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51 | for kk in range(len(self.chns)):
52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53 | outs1[kk]
54 | )
55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56 |
57 | res = [
58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59 | for kk in range(len(self.chns))
60 | ]
61 | val = res[0]
62 | for l in range(1, len(self.chns)):
63 | val += res[l]
64 | return val
65 |
66 |
67 | class ScalingLayer(nn.Module):
68 | def __init__(self):
69 | super(ScalingLayer, self).__init__()
70 | self.register_buffer(
71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72 | )
73 | self.register_buffer(
74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75 | )
76 |
77 | def forward(self, inp):
78 | return (inp - self.shift) / self.scale
79 |
80 |
81 | class NetLinLayer(nn.Module):
82 | """A single linear layer which does a 1x1 conv"""
83 |
84 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
85 | super(NetLinLayer, self).__init__()
86 | layers = (
87 | [
88 | nn.Dropout(),
89 | ]
90 | if (use_dropout)
91 | else []
92 | )
93 | layers += [
94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95 | ]
96 | self.model = nn.Sequential(*layers)
97 |
98 |
99 | class vgg16(torch.nn.Module):
100 | def __init__(self, requires_grad=False, pretrained=True):
101 | super(vgg16, self).__init__()
102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103 | self.slice1 = torch.nn.Sequential()
104 | self.slice2 = torch.nn.Sequential()
105 | self.slice3 = torch.nn.Sequential()
106 | self.slice4 = torch.nn.Sequential()
107 | self.slice5 = torch.nn.Sequential()
108 | self.N_slices = 5
109 | for x in range(4):
110 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(4, 9):
112 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(9, 16):
114 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(16, 23):
116 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
117 | for x in range(23, 30):
118 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
119 | if not requires_grad:
120 | for param in self.parameters():
121 | param.requires_grad = False
122 |
123 | def forward(self, X):
124 | h = self.slice1(X)
125 | h_relu1_2 = h
126 | h = self.slice2(h)
127 | h_relu2_2 = h
128 | h = self.slice3(h)
129 | h_relu3_3 = h
130 | h = self.slice4(h)
131 | h_relu4_3 = h
132 | h = self.slice5(h)
133 | h_relu5_3 = h
134 | vgg_outputs = namedtuple(
135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136 | )
137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138 | return out
139 |
140 |
141 | def normalize_tensor(x, eps=1e-10):
142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143 | return x / (norm_factor + eps)
144 |
145 |
146 | def spatial_average(x, keepdim=True):
147 | return x.mean([2, 3], keepdim=keepdim)
148 |
--------------------------------------------------------------------------------
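
A minimal sketch of using `PerceptualLoss` above; note that constructing it downloads the LPIPS linear weights on first use (see utils.py below), and inputs are expected roughly in `[-1, 1]`.

```python
import torch

from threestudio.utils.perceptual import PerceptualLoss

loss_fn = PerceptualLoss().eval()
pred = torch.rand(1, 3, 256, 256) * 2 - 1    # fake prediction in [-1, 1]
target = torch.rand(1, 3, 256, 256) * 2 - 1  # fake target in [-1, 1]
with torch.no_grad():
    d = loss_fn(pred, target)  # spatially averaged, shape (1, 1, 1, 1)
print(d.item())
```
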
/threestudio/utils/perceptual/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 |
4 | import requests
5 | from tqdm import tqdm
6 |
7 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
8 |
9 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
10 |
11 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
12 |
13 |
14 | def download(url, local_path, chunk_size=1024):
15 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
16 | with requests.get(url, stream=True) as r:
17 | total_size = int(r.headers.get("content-length", 0))
18 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
19 | with open(local_path, "wb") as f:
20 | for data in r.iter_content(chunk_size=chunk_size):
21 | if data:
22 | f.write(data)
23 |                         pbar.update(len(data))  # count actual bytes; the last chunk may be short
24 |
25 |
26 | def md5_hash(path):
27 | with open(path, "rb") as f:
28 | content = f.read()
29 | return hashlib.md5(content).hexdigest()
30 |
31 |
32 | def get_ckpt_path(name, root, check=False):
33 | assert name in URL_MAP
34 | path = os.path.join(root, CKPT_MAP[name])
35 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
36 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
37 | download(URL_MAP[name], path)
38 | md5 = md5_hash(path)
39 | assert md5 == MD5_MAP[name], md5
40 | return path
41 |
42 |
43 | class KeyNotFoundError(Exception):
44 | def __init__(self, cause, keys=None, visited=None):
45 | self.cause = cause
46 | self.keys = keys
47 | self.visited = visited
48 | messages = list()
49 | if keys is not None:
50 | messages.append("Key not found: {}".format(keys))
51 | if visited is not None:
52 | messages.append("Visited: {}".format(visited))
53 | messages.append("Cause:\n{}".format(cause))
54 | message = "\n".join(messages)
55 | super().__init__(message)
56 |
57 |
58 | def retrieve(
59 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
60 | ):
61 | """Given a nested list or dict return the desired value at key expanding
62 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
63 | is done in-place.
64 |
65 | Parameters
66 | ----------
67 | list_or_dict : list or dict
68 | Possibly nested list or dictionary.
69 | key : str
70 | key/to/value, path like string describing all keys necessary to
71 | consider to get to the desired value. List indices can also be
72 | passed here.
73 | splitval : str
74 | String that defines the delimiter between keys of the
75 | different depth levels in `key`.
76 | default : obj
77 | Value returned if :attr:`key` is not found.
78 | expand : bool
79 | Whether to expand callable nodes on the path or not.
80 |
81 | Returns
82 | -------
83 | The desired value or if :attr:`default` is not ``None`` and the
84 | :attr:`key` is not found returns ``default``.
85 |
86 | Raises
87 | ------
88 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
89 | ``None``.
90 | """
91 |
92 | keys = key.split(splitval)
93 |
94 | success = True
95 | try:
96 | visited = []
97 | parent = None
98 | last_key = None
99 | for key in keys:
100 | if callable(list_or_dict):
101 | if not expand:
102 | raise KeyNotFoundError(
103 | ValueError(
104 | "Trying to get past callable node with expand=False."
105 | ),
106 | keys=keys,
107 | visited=visited,
108 | )
109 | list_or_dict = list_or_dict()
110 | parent[last_key] = list_or_dict
111 |
112 | last_key = key
113 | parent = list_or_dict
114 |
115 | try:
116 | if isinstance(list_or_dict, dict):
117 | list_or_dict = list_or_dict[key]
118 | else:
119 | list_or_dict = list_or_dict[int(key)]
120 | except (KeyError, IndexError, ValueError) as e:
121 | raise KeyNotFoundError(e, keys=keys, visited=visited)
122 |
123 | visited += [key]
124 | # final expansion of retrieved value
125 | if expand and callable(list_or_dict):
126 | list_or_dict = list_or_dict()
127 | parent[last_key] = list_or_dict
128 | except KeyNotFoundError as e:
129 | if default is None:
130 | raise e
131 | else:
132 | list_or_dict = default
133 | success = False
134 |
135 | if not pass_success:
136 | return list_or_dict
137 | else:
138 | return list_or_dict, success
139 |
140 |
141 | if __name__ == "__main__":
142 | config = {
143 | "keya": "a",
144 | "keyb": "b",
145 | "keyc": {
146 | "cc1": 1,
147 | "cc2": 2,
148 | },
149 | }
150 | from omegaconf import OmegaConf
151 |
152 | config = OmegaConf.create(config)
153 | print(config)
154 | retrieve(config, "keya")
155 |
--------------------------------------------------------------------------------
/threestudio/utils/rasterize.py:
--------------------------------------------------------------------------------
1 | import nvdiffrast.torch as dr
2 | import torch
3 |
4 | from threestudio.utils.typing import *
5 |
6 |
7 | class NVDiffRasterizerContext:
8 | def __init__(self, context_type: str, device: torch.device) -> None:
9 | self.device = device
10 | self.ctx = self.initialize_context(context_type, device)
11 |
12 | def initialize_context(
13 | self, context_type: str, device: torch.device
14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
15 | if context_type == "gl":
16 | return dr.RasterizeGLContext(device=device)
17 | elif context_type == "cuda":
18 | return dr.RasterizeCudaContext(device=device)
19 | else:
20 | raise ValueError(f"Unknown rasterizer context type: {context_type}")
21 |
22 | def vertex_transform(
23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
24 | ) -> Float[Tensor, "B Nv 4"]:
25 | verts_homo = torch.cat(
26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
27 | )
28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
29 |
30 | def rasterize(
31 | self,
32 | pos: Float[Tensor, "B Nv 4"],
33 | tri: Integer[Tensor, "Nf 3"],
34 | resolution: Union[int, Tuple[int, int]],
35 | ):
36 | # rasterize in instance mode (single topology)
37 | return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)
38 |
39 | def rasterize_one(
40 | self,
41 | pos: Float[Tensor, "Nv 4"],
42 | tri: Integer[Tensor, "Nf 3"],
43 | resolution: Union[int, Tuple[int, int]],
44 | ):
45 |         # rasterize a single mesh under a single viewpoint
46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
47 | return rast[0], rast_db[0]
48 |
49 | def antialias(
50 | self,
51 | color: Float[Tensor, "B H W C"],
52 | rast: Float[Tensor, "B H W 4"],
53 | pos: Float[Tensor, "B Nv 4"],
54 | tri: Integer[Tensor, "Nf 3"],
55 | ) -> Float[Tensor, "B H W C"]:
56 | return dr.antialias(color.float(), rast, pos.float(), tri.int())
57 |
58 | def interpolate(
59 | self,
60 | attr: Float[Tensor, "B Nv C"],
61 | rast: Float[Tensor, "B H W 4"],
62 | tri: Integer[Tensor, "Nf 3"],
63 | rast_db=None,
64 | diff_attrs=None,
65 | ) -> Float[Tensor, "B H W C"]:
66 | return dr.interpolate(
67 | attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
68 | )
69 |
70 | def interpolate_one(
71 | self,
72 | attr: Float[Tensor, "Nv C"],
73 | rast: Float[Tensor, "B H W 4"],
74 | tri: Integer[Tensor, "Nf 3"],
75 | rast_db=None,
76 | diff_attrs=None,
77 | ) -> Float[Tensor, "B H W C"]:
78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
79 |
--------------------------------------------------------------------------------
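
A hedged sketch (requires a CUDA device plus nvdiffrast) of the rasterizer context above: transform a single triangle with an identity MVP matrix and read back its coverage mask.

```python
import torch

from threestudio.utils.rasterize import NVDiffRasterizerContext

device = torch.device("cuda")
ctx = NVDiffRasterizerContext("cuda", device)

verts = torch.tensor(
    [[-0.5, -0.5, 0.0], [0.5, -0.5, 0.0], [0.0, 0.5, 0.0]], device=device
)
tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)
mvp = torch.eye(4, device=device)[None, ...]  # (1, 4, 4) identity MVP

pos = ctx.vertex_transform(verts, mvp)         # (1, 3, 4) clip-space positions
rast, _ = ctx.rasterize(pos, tri, (128, 128))  # (1, H, W, 4) rasterizer output
mask = rast[..., 3] > 0  # last channel is triangle id + 1; 0 means background
print(int(mask.sum()))   # number of covered pixels
```
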
/threestudio/utils/typing.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains type annotations for the project, using
3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5 |
6 | Two types of typing checking can be used:
7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
9 | """
10 |
11 | # Basic types
12 | from typing import (
13 | Any,
14 | Callable,
15 | Dict,
16 | Iterable,
17 | List,
18 | Literal,
19 | NamedTuple,
20 | NewType,
21 | Optional,
22 | Sized,
23 | Tuple,
24 | Type,
25 | TypeVar,
26 | Union,
27 | )
28 |
29 | # Tensor dtype
30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
32 |
33 | # Config type
34 | from omegaconf import DictConfig
35 |
36 | # PyTorch Tensor type
37 | from torch import Tensor
38 |
39 | # Runtime type checking decorator
40 | from typeguard import typechecked as typechecker
41 |
--------------------------------------------------------------------------------
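
An illustrative sketch of the annotation style this module enables; the `normalize` helper below is hypothetical.

```python
import torch
from jaxtyping import Float
from torch import Tensor


def normalize(v: Float[Tensor, "B 3"]) -> Float[Tensor, "B 3"]:
    # shape annotations document intent and support static checking;
    # runtime enforcement additionally needs the typeguard decorator
    return v / v.norm(dim=-1, keepdim=True)


v = torch.randn(4, 3)
print(isinstance(v, Float[Tensor, "B 3"]))  # True: jaxtyping isinstance check
print(normalize(v).shape)                   # torch.Size([4, 3])
```
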