├── 2D_experiments ├── requirements.txt ├── generate.py └── guidance.py ├── threestudio-sds-bridge ├── __init__.py ├── .pre-commit-config.yaml ├── configs │ └── sds-bridge.yaml ├── .gitignore ├── systems │ └── sds_bridge.py ├── guidance │ └── sds_bridge_guidance.py └── prompt_processors │ └── stable_diffusion_sds_bridge_prompt_processor.py ├── LICENSE └── README.md /2D_experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.1 2 | torchvision>=0.14.1 3 | diffusers==0.25.1 4 | transformers==4.37.0 5 | matplotlib 6 | jaxtyping 7 | tqdm 8 | numpy 9 | imageio[pyav] 10 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/__init__.py: -------------------------------------------------------------------------------- 1 | import threestudio 2 | from packaging.version import Version 3 | 4 | if hasattr(threestudio, "__version__") and Version(threestudio.__version__) >= Version( 5 | "0.2.0" 6 | ): 7 | pass 8 | else: 9 | if hasattr(threestudio, "__version__"): 10 | print(f"[INFO] threestudio version: {threestudio.__version__}") 11 | raise ValueError( 12 | "threestudio version must be >= 0.2.0, please update threestudio by pulling the latest version from github" 13 | ) 14 | 15 | from .guidance import sds_bridge_guidance 16 | from .systems import sds_bridge 17 | from .prompt_processors import stable_diffusion_sds_bridge_prompt_processor 18 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: check-ast 10 | - id: check-merge-conflict 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | args: [--markdown-linebreak-ext=md] 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 23.3.0 18 | hooks: 19 | - id: black 20 | language_version: python3 21 | 22 | - repo: https://github.com/pycqa/isort 23 | rev: 5.12.0 24 | hooks: 25 | - id: isort 26 | exclude: README.md 27 | args: ["--profile", "black"] 28 | 29 | # temporarily disable static type checking 30 | # - repo: https://github.com/pre-commit/mirrors-mypy 31 | # rev: v1.2.0 32 | # hooks: 33 | # - id: mypy 34 | # args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 David McAllister 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/configs/sds-bridge.yaml: -------------------------------------------------------------------------------- 1 | name: "sds-bridge" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "random-camera-datamodule" 7 | data: 8 | batch_size: [1, 1] 9 | # 0-4999: 64x64, >=5000: 512x512 10 | # this drastically reduces VRAM usage as empty space is pruned in early training 11 | width: [64, 512] 12 | height: [64, 512] 13 | resolution_milestones: [5000] 14 | camera_distance_range: [1.0, 1.5] 15 | fovy_range: [40, 70] 16 | elevation_range: [-10, 45] 17 | camera_perturb: 0. 18 | center_perturb: 0. 19 | up_perturb: 0. 20 | eval_camera_distance: 1.5 21 | eval_fovy_deg: 70. 22 | 23 | system_type: "sds-bridge-system" 24 | system: 25 | stage: coarse 26 | geometry_type: "implicit-volume" 27 | geometry: 28 | radius: 1.0 29 | normal_type: null 30 | 31 | density_bias: "blob_magic3d" 32 | density_activation: softplus 33 | density_blob_scale: 10. 34 | density_blob_std: 0.5 35 | 36 | pos_encoding_config: 37 | otype: HashGrid 38 | n_levels: 16 39 | n_features_per_level: 2 40 | log2_hashmap_size: 19 41 | base_resolution: 16 42 | per_level_scale: 1.447269237440378 # max resolution 4096 43 | 44 | material_type: "no-material" 45 | material: 46 | n_output_dims: 3 47 | color_activation: sigmoid 48 | 49 | background_type: "neural-environment-map-background" 50 | background: 51 | color_activation: sigmoid 52 | random_aug: true 53 | 54 | renderer_type: "nerf-volume-renderer" 55 | renderer: 56 | radius: ${system.geometry.radius} 57 | num_samples_per_ray: 512 58 | 59 | prompt_processor_type: "stable-diffusion-sds-bridge-prompt-processor" 60 | prompt_processor: 61 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 62 | prompt: ??? 63 | src_modifier: "oversaturated, smooth, pixelated, cartoon, foggy, hazy, blurry, bad structure, noisy, malformed" 64 | tgt_modifier: "." 65 | use_modifier_only: false 66 | front_threshold: 30. 67 | back_threshold: 30. 68 | 69 | guidance_type: "stable-diffusion-sds-bridge-guidance" 70 | guidance: 71 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 72 | guidance_scale: 100. 73 | weighting_strategy: uniform 74 | min_step_percent: 0.02 75 | max_step_percent: [5000, 0.98, 0.5, 5001] # annealed to 0.5 after 5000 steps 76 | half_precision_weights: true 77 | sqrt_anneal: false 78 | stage_one_weight: 1. 79 | stage_two_weight: 1. 80 | stage_two_start_step: 20000 81 | 82 | loggers: 83 | wandb: 84 | enable: false 85 | project: "threestudio" 86 | name: None 87 | 88 | loss: 89 | lambda_sds: 1. 90 | lambda_orient: 0. 91 | lambda_sparsity: 10. 92 | lambda_opaque: [10000, 0.0, 1000.0, 10001] 93 | lambda_z_variance: 0. 
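# Note (added): list-valued entries above such as lambda_opaque and max_step_percent appear to follow
# threestudio's schedule syntax [start_step, start_value, end_value, end_step], with the value
# interpolated over that step range (inferred from the "annealed to 0.5 after 5000 steps" comment above;
# treat this as an assumption, not documented behavior).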
94 | optimizer: 95 | name: AdamW 96 | args: 97 | betas: [0.9, 0.99] 98 | eps: 1.e-15 99 | params: 100 | geometry.encoding: 101 | lr: 0.01 102 | geometry.density_network: 103 | lr: 0.001 104 | geometry.feature_network: 105 | lr: 0.001 106 | background: 107 | lr: 0.001 108 | guidance: 109 | lr: 0.0001 110 | 111 | trainer: 112 | max_steps: 25000 113 | log_every_n_steps: 1 114 | num_sanity_val_steps: 0 115 | val_check_interval: 200 116 | enable_progress_bar: true 117 | precision: 32 118 | 119 | checkpoint: 120 | save_last: true 121 | save_top_k: -1 122 | every_n_train_steps: ${trainer.max_steps} 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SDS-Bridge 2 | 3 | ### [Project Page](https://sds-bridge.github.io/) | [Paper](https://arxiv.org/abs/2406.09417) 4 | 5 | **TLDR:** A unified framework to explain SDS and its variants, plus a new method that is fast & high-quality. 6 | 7 | https://github.com/davidmcall/SDS-Bridge/assets/50497963/a5af3b0a-8edb-4acf-8c89-02a14451257a 8 | 9 | 10 | ## Experimenting in 3D 11 | 12 | We provide our code for text-based NeRF optimization as an extension in Threestudio. To use it, please first install threestudio following the [official instructions](https://github.com/threestudio-project/threestudio?tab=readme-ov-file#installation). 13 | 14 | ### Extension Installation 15 | 16 | ```bash 17 | cp -r ./threestudio-sds-bridge ../threestudio/custom/ 18 | cd ../threestudio 19 | ``` 20 | 21 | ### Run 3D Optimization 22 | 23 | In the `threestudio` repo... 24 | 25 | ```bash 26 | python launch.py --config custom/threestudio-sds-bridge/configs/sds-bridge.yaml --train --gpu 0 system.prompt_processor.prompt="a pineapple" 27 | ``` 28 | 29 | Some options to play with for sds-bridge guidance are: 30 | * `system.guidance.stage_two_start_step` The step at which to switch to the second stage. 31 | * `system.guidance.stage_two_weight` The weight of the second stage. 32 | * `system.prompt_processor.src_modifier` The prompt modifier that describes the current source distribution, e.g. "oversaturated, smooth, pixelated, cartoon, foggy, hazy, blurry, bad structure, noisy, malformed." 33 | * `system.prompt_processor.tgt_modifier` The prompt modifier that describes the target distribution, e.g. " detailed, high resolution, high quality, sharp." 34 | 35 | 36 | ## Experimenting in 2D 37 | 38 | We offer a simpler installation than Threestudio with minimal dependencies if you just want to run experiments in 2D. This installation guide is adapted from [Nerfstudio](https://github.com/nerfstudio-project/nerfstudio). 39 | 40 | ### Prerequisites 41 | 42 | You must have an NVIDIA video card with CUDA installed on the system. This project has been tested with version 11.8 of CUDA. You can find more information about installing CUDA [here](https://docs.nvidia.com/cuda/cuda-quick-start-guide/index.html). 43 | 44 | ### Create Environment 45 | 46 | This repository requires `python >= 3.8`. We recommend using conda to manage dependencies. Make sure to install [Conda](https://docs.conda.io/miniconda.html) before proceeding. 
47 | 48 | ```bash 49 | conda create --name bridge -y python=3.8 50 | conda activate bridge 51 | pip install --upgrade pip 52 | ``` 53 | 54 | ### Dependencies 55 | 56 | Install PyTorch with CUDA 57 | 58 | For CUDA 11.8: 59 | 60 | ```bash 61 | pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 62 | ``` 63 | 64 | Install other dependencies with pip: 65 | 66 | ```bash 67 | cd 2D_experiments 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | ### Run 2D Optimization 72 | 73 | In the `2D_experiments` directory... 74 | 75 | ```bash 76 | python generate.py 77 | ``` 78 | 79 | See `generate.py` for more options, including but not limited to: 80 | * `--mode` Choose between SDS-like loss functions [bridge (ours)](https://sds-bridge.github.io/), [SDS](https://dreamfusion3d.github.io), [NFSD](https://orenkatzir.github.io/nfsd/), [VSD](https://ml.cs.tsinghua.edu.cn/prolificdreamer/) 81 | * `--seed` Random seed 82 | * `--lr` Learning rate 83 | * `--cfg_scale` Scale of classifier-free guidance computation 84 | 85 | 86 | 87 | ## Citation 88 | 89 | ``` bibtex 90 | @article{mcallister2024rethinking, 91 | title={Rethinking Score Distillation as a Bridge Between Image Distributions}, 92 | author={David McAllister and Songwei Ge and Jia-Bin Huang and David W. Jacobs and Alexei A. Efros and Aleksander Holynski and Angjoo Kanazawa}, 93 | journal={arXiv preprint arXiv:2406.09417}, 94 | year={2024} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | .vscode/ 179 | .threestudio_cache/ 180 | outputs/ 181 | outputs-gradio/ 182 | 183 | # pretrained model weights 184 | *.ckpt 185 | *.pt 186 | *.pth 187 | 188 | # wandb 189 | wandb/ 190 | 191 | custom/* 192 | 193 | load/tets/256_tets.npz 194 | -------------------------------------------------------------------------------- /2D_experiments/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import imageio 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | from guidance import Guidance, GuidanceConfig 10 | from tqdm import tqdm 11 | 12 | device = torch.device("cuda") 13 | 14 | 15 | def seed_everything(seed): 16 | random.seed(seed) 17 | os.environ["PYTHONHASHSEED"] = str(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--prompt", type=str, default="a DSLR photo of a dolphin") 25 | parser.add_argument( 26 | "--extra_src_prompt", 27 | type=str, 28 | default=", oversaturated, smooth, pixelated, cartoon, foggy, hazy, blurry, bad structure, noisy, malformed", 29 | ) 30 | parser.add_argument( 31 | "--extra_tgt_prompt", 32 | type=str, 33 | default=", detailed high resolution, high quality, sharp", 34 | ) 35 | parser.add_argument("--init_image_fn", type=str, default=None) 36 | parser.add_argument( 37 | "--mode", type=str, default="bridge", choices=["bridge", "sds", "nfsd", "vsd"] 38 | ) 39 | parser.add_argument("--cfg_scale", type=float, default=40) 40 | parser.add_argument("--lr", type=float, default=0.01) 41 | parser.add_argument("--seed", type=int, default=0) 42 | parser.add_argument("--n_steps", type=int, default=1000) 43 | parser.add_argument("--stage_two_start_step", type=int, default=500) 44 | args = parser.parse_args() 45 | 46 | init_image_fn = args.init_image_fn 47 | 48 | guidance = Guidance( 49 | GuidanceConfig(sd_pretrained_model_or_path="stabilityai/stable-diffusion-2-1-base"), 50 | use_lora=(args.mode == "vsd"), 51 | ) 52 | 53 | if init_image_fn is not None: 54 | reference = torch.tensor(plt.imread(init_image_fn))[..., :3] 55 | reference = reference.permute(2, 0, 1)[None, ...] 
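    # (added comment) reorder HWC -> 1 x C x H x W; guidance.encode_image below assumes pixel values in [0, 1]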
56 | reference = reference.to(guidance.unet.device) 57 | 58 | reference_latent = guidance.encode_image(reference) 59 | im = reference_latent 60 | else: 61 | # Initialize with low-magnitude noise, zeros also works 62 | im = torch.randn((1, 4, 64, 64), device=guidance.unet.device) 63 | 64 | save_dir = "results/%s_gen/%s_lr%.3f_seed%d_scale%.1f" % ( 65 | args.mode, 66 | args.prompt.replace(" ", "_"), 67 | args.lr, 68 | args.seed, 69 | args.cfg_scale, 70 | ) 71 | os.makedirs(save_dir, exist_ok=True) 72 | print("Save dir:", save_dir) 73 | 74 | seed_everything(args.seed) 75 | 76 | 77 | def decode_latent(latent): 78 | latent = latent.detach().to(device) 79 | with torch.no_grad(): 80 | rgb = guidance.decode_latent(latent) 81 | rgb = rgb.float().cpu().permute(0, 2, 3, 1) 82 | rgb = rgb.permute(1, 0, 2, 3) 83 | rgb = rgb.flatten(start_dim=1, end_dim=2) 84 | return rgb 85 | 86 | 87 | batch_size = 1 88 | 89 | im.requires_grad_(True) 90 | im.retain_grad() 91 | 92 | im_optimizer = torch.optim.AdamW([im], lr=args.lr, betas=(0.9, 0.99), eps=1e-15) 93 | if args.mode == "vsd": 94 | lora_optimizer = torch.optim.AdamW( 95 | [ 96 | {"params": guidance.unet_lora.parameters(), "lr": 3e-4}, 97 | ], 98 | weight_decay=0, 99 | ) 100 | 101 | im_opts = [] 102 | 103 | for step in tqdm(range(args.n_steps)): 104 | 105 | guidance.config.guidance_scale = args.cfg_scale 106 | if args.mode == "bridge": 107 | if step < args.stage_two_start_step: 108 | loss_dict = guidance.sds_loss( 109 | im=im, prompt=args.prompt, cfg_scale=args.cfg_scale, return_dict=True 110 | ) 111 | else: 112 | loss_dict = guidance.bridge_stage_two( 113 | im=im, prompt=args.prompt, cfg_scale=args.cfg_scale, return_dict=True 114 | ) 115 | 116 | elif args.mode == "sds": 117 | loss_dict = guidance.sds_loss( 118 | im=im, prompt=args.prompt, cfg_scale=args.cfg_scale, return_dict=True 119 | ) 120 | elif args.mode == "nfsd": 121 | loss_dict = guidance.nfsd_loss( 122 | im=im, prompt=args.prompt, cfg_scale=args.cfg_scale, return_dict=True 123 | ) 124 | elif args.mode == "vsd": 125 | loss_dict = guidance.vsd_loss( 126 | im=im, prompt=args.prompt, cfg_scale=7.5, return_dict=True 127 | ) 128 | lora_loss = loss_dict["lora_loss"] 129 | lora_loss.backward() 130 | lora_optimizer.step() 131 | lora_optimizer.zero_grad() 132 | else: 133 | raise ValueError(args.mode) 134 | 135 | grad = loss_dict["grad"] 136 | src_x0 = loss_dict["src_x0"] if "src_x0" in loss_dict else grad 137 | 138 | im.backward(gradient=grad) 139 | im_optimizer.step() 140 | im_optimizer.zero_grad() 141 | 142 | if step % 10 == 0: 143 | decoded = decode_latent(im.detach()).cpu().numpy() 144 | im_opts.append(decoded) 145 | plt.imsave(os.path.join(save_dir, "debug_image.png"), decoded) 146 | 147 | if step % 100 == 0: 148 | imageio.mimwrite( 149 | os.path.join(save_dir, "debug_optimization.mp4"), 150 | np.stack(im_opts).astype(np.float32) * 255, 151 | fps=10, 152 | codec="libx264", 153 | ) 154 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/systems/sds_bridge.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import imageio 5 | import numpy as np 6 | 7 | import threestudio 8 | from threestudio.systems.base import BaseLift3DSystem 9 | from threestudio.utils.ops import binary_cross_entropy, dot 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("sds-bridge-system") 14 | class SDSBridge(BaseLift3DSystem): 15 | @dataclass 16 | class 
Config(BaseLift3DSystem.Config): 17 | stage: str = "coarse" 18 | visualize_samples: bool = False 19 | 20 | cfg: Config 21 | 22 | def configure(self) -> None: 23 | # set up geometry, material, background, renderer 24 | super().configure() 25 | self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance) 26 | self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)( 27 | self.cfg.prompt_processor 28 | ) 29 | self.prompt_utils = self.prompt_processor() 30 | 31 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 32 | if self.cfg.stage == "geometry": 33 | render_out = self.renderer(**batch, render_rgb=False) 34 | else: 35 | render_out = self.renderer(**batch) 36 | return { 37 | **render_out, 38 | } 39 | 40 | def on_fit_start(self) -> None: 41 | super().on_fit_start() 42 | 43 | def training_step(self, batch, batch_idx): 44 | out = self(batch) 45 | 46 | if self.true_global_step == self.guidance.cfg.stage_two_start_step: 47 | threestudio.info(f"Moving to stage 2 at step {self.true_global_step}") 48 | self.guidance.phase_id = 2 49 | 50 | 51 | if self.cfg.stage == "geometry": 52 | guidance_inp = out["comp_normal"] 53 | guidance_out = self.guidance( 54 | guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False 55 | ) 56 | else: 57 | guidance_inp = out["comp_rgb"] 58 | guidance_out = self.guidance( 59 | guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False 60 | ) 61 | 62 | loss = 0.0 63 | 64 | for name, value in guidance_out.items(): 65 | if not (type(value) is torch.Tensor and value.numel() > 1): 66 | self.log(f"train/{name}", value) 67 | if name.startswith("loss_"): 68 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) 69 | 70 | if self.cfg.stage == "coarse": 71 | if self.C(self.cfg.loss.lambda_orient) > 0: 72 | if "normal" not in out: 73 | raise ValueError( 74 | "Normal is required for orientation loss, no normal is found in the output." 
75 | ) 76 | loss_orient = ( 77 | out["weights"].detach() 78 | * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 79 | ).sum() / (out["opacity"] > 0).sum() 80 | self.log("train/loss_orient", loss_orient) 81 | loss += loss_orient * self.C(self.cfg.loss.lambda_orient) 82 | 83 | loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() 84 | self.log("train/loss_sparsity", loss_sparsity) 85 | loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) 86 | 87 | opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) 88 | loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) 89 | self.log("train/loss_opaque", loss_opaque) 90 | loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) 91 | 92 | # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766 93 | # helps reduce floaters and produce solid geometry 94 | if "z_variance" in out: 95 | loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() 96 | self.log("train/loss_z_variance", loss_z_variance) 97 | loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) 98 | 99 | # sdf loss 100 | if "sdf_grad" in out: 101 | loss_eikonal = ( 102 | (torch.linalg.norm(out["sdf_grad"], ord=2, dim=-1) - 1.0) ** 2 103 | ).mean() 104 | self.log("train/loss_eikonal", loss_eikonal) 105 | loss += loss_eikonal * self.C(self.cfg.loss.lambda_eikonal) 106 | self.log("train/inv_std", out["inv_std"], prog_bar=True) 107 | 108 | elif self.cfg.stage == "geometry": 109 | loss_normal_consistency = out["mesh"].normal_consistency() 110 | self.log("train/loss_normal_consistency", loss_normal_consistency) 111 | loss += loss_normal_consistency * self.C( 112 | self.cfg.loss.lambda_normal_consistency 113 | ) 114 | 115 | if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0: 116 | loss_laplacian_smoothness = out["mesh"].laplacian() 117 | self.log("train/loss_laplacian_smoothness", loss_laplacian_smoothness) 118 | loss += loss_laplacian_smoothness * self.C( 119 | self.cfg.loss.lambda_laplacian_smoothness 120 | ) 121 | elif self.cfg.stage == "texture": 122 | pass 123 | else: 124 | raise ValueError(f"Unknown stage {self.cfg.stage}") 125 | 126 | for name, value in self.cfg.loss.items(): 127 | self.log(f"train_params/{name}", self.C(value)) 128 | 129 | return {"loss": loss} 130 | 131 | def validation_step(self, batch, batch_idx): 132 | out = self(batch) 133 | self.save_image_grid( 134 | f"it{self.true_global_step}-{batch['index'][0]}.png", 135 | ( 136 | [ 137 | { 138 | "type": "rgb", 139 | "img": out["comp_rgb"][0], 140 | "kwargs": {"data_format": "HWC"}, 141 | }, 142 | ] 143 | if "comp_rgb" in out 144 | else [] 145 | ) 146 | + ( 147 | [ 148 | { 149 | "type": "rgb", 150 | "img": out["comp_normal"][0], 151 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, 152 | } 153 | ] 154 | if "comp_normal" in out 155 | else [] 156 | ) 157 | + [ 158 | { 159 | "type": "grayscale", 160 | "img": out["opacity"][0, :, :, 0], 161 | "kwargs": {"cmap": None, "data_range": (0, 1)}, 162 | }, 163 | ], 164 | name="validation_step", 165 | step=self.true_global_step, 166 | ) 167 | 168 | if self.cfg.visualize_samples: 169 | self.save_image_grid( 170 | f"it{self.true_global_step}-{batch['index'][0]}-sample.png", 171 | [ 172 | { 173 | "type": "rgb", 174 | "img": self.guidance.sample( 175 | self.prompt_utils, **batch, seed=self.global_step 176 | )[0], 177 | "kwargs": {"data_format": "HWC"}, 178 | }, 179 | { 180 | "type": "rgb", 181 | "img": self.guidance.sample_lora(self.prompt_utils, **batch)[0], 182 | "kwargs": {"data_format": "HWC"}, 183 | }, 184 | ], 
185 | name="validation_step_samples", 186 | step=self.true_global_step, 187 | ) 188 | 189 | def on_validation_epoch_end(self): 190 | pass 191 | 192 | def test_step(self, batch, batch_idx): 193 | out = self(batch) 194 | self.save_image_grid( 195 | f"it{self.true_global_step}-test/{batch['index'][0]}.png", 196 | ( 197 | [ 198 | { 199 | "type": "rgb", 200 | "img": out["comp_rgb"][0], 201 | "kwargs": {"data_format": "HWC"}, 202 | }, 203 | ] 204 | if "comp_rgb" in out 205 | else [] 206 | ) 207 | + ( 208 | [ 209 | { 210 | "type": "rgb", 211 | "img": out["comp_normal"][0], 212 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, 213 | } 214 | ] 215 | if "comp_normal" in out 216 | else [] 217 | ) 218 | + [ 219 | { 220 | "type": "grayscale", 221 | "img": out["opacity"][0, :, :, 0], 222 | "kwargs": {"cmap": None, "data_range": (0, 1)}, 223 | }, 224 | ], 225 | name="test_step", 226 | step=self.true_global_step, 227 | ) 228 | 229 | def on_test_epoch_end(self): 230 | self.save_img_sequence( 231 | f"it{self.true_global_step}-test", 232 | f"it{self.true_global_step}-test", 233 | "(\d+)\.png", 234 | save_format="mp4", 235 | fps=30, 236 | name="test", 237 | step=self.true_global_step, 238 | ) 239 | -------------------------------------------------------------------------------- /2D_experiments/guidance.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from contextlib import contextmanager 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from diffusers import DDIMScheduler, DiffusionPipeline 9 | from diffusers.loaders import AttnProcsLayers 10 | from diffusers.models.attention_processor import LoRAAttnProcessor 11 | from diffusers.models.embeddings import TimestepEmbedding 12 | from jaxtyping import Float 13 | 14 | 15 | def cleanup(): 16 | gc.collect() 17 | torch.cuda.empty_cache() 18 | 19 | 20 | class ToWeightsDType(torch.nn.Module): 21 | def __init__(self, module: torch.nn.Module, dtype: torch.dtype): 22 | super().__init__() 23 | self.module = module 24 | self.dtype = dtype 25 | 26 | def forward(self, x: Float[torch.Tensor, "..."]) -> Float[torch.Tensor, "..."]: 27 | return self.module(x).to(self.dtype) 28 | 29 | 30 | @dataclass 31 | class GuidanceConfig: 32 | sd_pretrained_model_or_path: str = "runwayml/stable-diffusion-v2-1-base" 33 | sd_pretrained_model_or_path_lora: str = "stabilityai/stable-diffusion-2-1" 34 | 35 | num_inference_steps: int = 500 36 | min_step_ratio: float = 0.02 37 | max_step_ratio: float = 0.98 38 | 39 | src_prompt: str = "" 40 | tgt_prompt: str = "" 41 | 42 | guidance_scale: float = 30 43 | guidance_scale_lora: float = 1.0 44 | sdedit_guidance_scale: float = 15 45 | device: torch.device = torch.device("cuda") 46 | lora_n_timestamp_samples: int = 1 47 | 48 | sync_noise_and_t: bool = True 49 | lora_cfg_training: bool = True 50 | 51 | 52 | class Guidance(object): 53 | def __init__(self, config: GuidanceConfig, use_lora: bool = False): 54 | self.config = config 55 | self.device = torch.device(config.device) 56 | 57 | self.pipe = DiffusionPipeline.from_pretrained( 58 | config.sd_pretrained_model_or_path 59 | ).to(self.device) 60 | 61 | self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) 62 | self.scheduler.set_timesteps(config.num_inference_steps) 63 | self.pipe.scheduler = self.scheduler 64 | 65 | self.unet = self.pipe.unet 66 | self.tokenizer = self.pipe.tokenizer 67 | self.text_encoder = self.pipe.text_encoder 68 | self.vae = 
self.pipe.vae 69 | 70 | self.unet.requires_grad_(False) 71 | self.text_encoder.requires_grad_(False) 72 | self.vae.requires_grad_(False) 73 | 74 | ## construct text features beforehand. 75 | self.src_prompt = self.config.src_prompt 76 | self.tgt_prompt = self.config.tgt_prompt 77 | 78 | self.update_text_features( 79 | src_prompt=self.src_prompt, tgt_prompt=self.tgt_prompt 80 | ) 81 | self.null_text_feature = self.encode_text("") 82 | 83 | if use_lora: 84 | self.pipe_lora = DiffusionPipeline.from_pretrained( 85 | config.sd_pretrained_model_or_path_lora 86 | ).to(self.device) 87 | self.single_model = False 88 | del self.pipe_lora.vae 89 | del self.pipe_lora.text_encoder 90 | cleanup() 91 | self.vae_lora = self.pipe_lora.vae = self.pipe.vae 92 | self.unet_lora = self.pipe_lora.unet 93 | for p in self.unet_lora.parameters(): 94 | p.requires_grad_(False) 95 | # FIXME: hard-coded dims 96 | self.camera_embedding = TimestepEmbedding(16, 1280).to(self.device) 97 | self.unet_lora.class_embedding = self.camera_embedding 98 | self.scheduler_lora = DDIMScheduler.from_config( 99 | self.pipe_lora.scheduler.config 100 | ) 101 | self.scheduler_lora.set_timesteps(config.num_inference_steps) 102 | self.pipe_lora.scheduler = self.scheduler_lora 103 | 104 | # set up LoRA layers 105 | lora_attn_procs = {} 106 | for name in self.unet_lora.attn_processors.keys(): 107 | cross_attention_dim = ( 108 | None 109 | if name.endswith("attn1.processor") 110 | else self.unet_lora.config.cross_attention_dim 111 | ) 112 | if name.startswith("mid_block"): 113 | hidden_size = self.unet_lora.config.block_out_channels[-1] 114 | elif name.startswith("up_blocks"): 115 | block_id = int(name[len("up_blocks.")]) 116 | hidden_size = list( 117 | reversed(self.unet_lora.config.block_out_channels) 118 | )[block_id] 119 | elif name.startswith("down_blocks"): 120 | block_id = int(name[len("down_blocks.")]) 121 | hidden_size = self.unet_lora.config.block_out_channels[block_id] 122 | 123 | lora_attn_procs[name] = LoRAAttnProcessor( 124 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim 125 | ) 126 | 127 | self.unet_lora.set_attn_processor(lora_attn_procs) 128 | 129 | self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to( 130 | self.device 131 | ) 132 | self.lora_layers._load_state_dict_pre_hooks.clear() 133 | self.lora_layers._state_dict_hooks.clear() 134 | 135 | def encode_image(self, img_tensor: Float[torch.Tensor, "B C H W"]): 136 | x = img_tensor 137 | x = 2 * x - 1 138 | x = x.float() 139 | return self.vae.encode(x).latent_dist.sample() * 0.18215 140 | 141 | def encode_text(self, prompt): 142 | text_input = self.tokenizer( 143 | prompt, 144 | padding="max_length", 145 | max_length=self.pipe.tokenizer.model_max_length, 146 | truncation=True, 147 | return_tensors="pt", 148 | ) 149 | text_encoding = self.text_encoder(text_input.input_ids.to(self.device))[0] 150 | return text_encoding 151 | 152 | def decode_latent(self, latent): 153 | x = self.vae.decode(latent / 0.18215).sample 154 | x = (x / 2 + 0.5).clamp(0, 1) 155 | return x 156 | 157 | def update_text_features(self, src_prompt=None, tgt_prompt=None): 158 | if getattr(self, "src_text_feature", None) is None: 159 | assert src_prompt is not None 160 | self.src_prompt = src_prompt 161 | self.src_text_feature = self.encode_text(src_prompt) 162 | else: 163 | if src_prompt is not None and src_prompt != self.src_prompt: 164 | self.src_prompt = src_prompt 165 | self.src_text_feature = self.encode_text(src_prompt) 166 | 167 | if getattr(self, "tgt_text_feature", 
None) is None: 168 | assert tgt_prompt is not None 169 | self.tgt_prompt = tgt_prompt 170 | self.tgt_text_feature = self.encode_text(tgt_prompt) 171 | else: 172 | if tgt_prompt is not None and tgt_prompt != self.tgt_prompt: 173 | self.tgt_prompt = tgt_prompt 174 | self.tgt_text_feature = self.encode_text(tgt_prompt) 175 | 176 | def sample_timestep(self, batch_size): 177 | self.scheduler.set_timesteps(self.config.num_inference_steps) 178 | timesteps = reversed(self.scheduler.timesteps) 179 | 180 | min_step = ( 181 | 1 182 | if self.config.min_step_ratio <= 0 183 | else int(len(timesteps) * self.config.min_step_ratio) 184 | ) 185 | max_step = ( 186 | len(timesteps) 187 | if self.config.max_step_ratio >= 1 188 | else int(len(timesteps) * self.config.max_step_ratio) 189 | ) 190 | max_step = max(max_step, min_step + 1) 191 | idx = torch.randint( 192 | min_step, 193 | max_step, 194 | [batch_size], 195 | dtype=torch.long, 196 | device="cpu", 197 | ) 198 | t = timesteps[idx].cpu() 199 | t_prev = timesteps[idx - 1].cpu() 200 | 201 | return t, t_prev 202 | 203 | def sds_loss( 204 | self, 205 | im, 206 | prompt=None, 207 | reduction="mean", 208 | cfg_scale=100, 209 | noise=None, 210 | return_dict=False, 211 | ): 212 | device = self.device 213 | scheduler = self.scheduler 214 | 215 | # process text. 216 | self.update_text_features(tgt_prompt=prompt) 217 | tgt_text_embedding = self.tgt_text_feature 218 | uncond_embedding = self.null_text_feature 219 | 220 | batch_size = im.shape[0] 221 | t, _ = self.sample_timestep(batch_size) 222 | 223 | if noise is None: 224 | noise = torch.randn_like(im) 225 | 226 | latents_noisy = scheduler.add_noise(im, noise, t) 227 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 228 | text_embeddings = torch.cat([tgt_text_embedding, uncond_embedding], dim=0) 229 | noise_pred = self.unet.forward( 230 | latent_model_input, 231 | torch.cat([t] * 2).to(device), 232 | encoder_hidden_states=text_embeddings, 233 | ).sample 234 | noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) 235 | noise_pred = noise_pred_uncond + cfg_scale * ( 236 | noise_pred_text - noise_pred_uncond 237 | ) 238 | 239 | w = 1 - scheduler.alphas_cumprod[t].to(device) 240 | grad = w * (noise_pred - noise) 241 | grad = torch.nan_to_num(grad) 242 | target = (im - grad).detach() 243 | loss = 0.5 * F.mse_loss(im, target, reduction=reduction) / batch_size 244 | if return_dict: 245 | dic = {"loss": loss, "grad": grad, "t": t} 246 | return dic 247 | else: 248 | return loss 249 | 250 | def bridge_stage_two( 251 | self, 252 | im, 253 | prompt=None, 254 | reduction="mean", 255 | cfg_scale=30, 256 | extra_tgt_prompts=", detailed high resolution, high quality, sharp", 257 | extra_src_prompts=", oversaturated, smooth, pixelated, cartoon, foggy, hazy, blurry, bad structure, noisy, malformed", 258 | noise=None, 259 | return_dict=False, 260 | ): 261 | device = self.device 262 | scheduler = self.scheduler 263 | 264 | # process text. 
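        # (added comment) Stage-two "bridge" update: both UNet passes below are conditioned on the prompt,
        # once with the target modifier and once with the source modifier, and the gradient becomes
        # w(t) * cfg_scale * (eps_tgt - eps_src) rather than the (eps_cfg - noise) residual used by sds_loss above.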
265 | self.update_text_features( 266 | tgt_prompt=prompt + extra_tgt_prompts, src_prompt=prompt + extra_src_prompts 267 | ) 268 | tgt_text_embedding = self.tgt_text_feature 269 | src_text_embedding = self.src_text_feature 270 | 271 | batch_size = im.shape[0] 272 | t, _ = self.sample_timestep(batch_size) 273 | 274 | if noise is None: 275 | noise = torch.randn_like(im) 276 | 277 | latents_noisy = scheduler.add_noise(im, noise, t) 278 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 279 | text_embeddings = torch.cat([tgt_text_embedding, src_text_embedding], dim=0) 280 | noise_pred = self.unet.forward( 281 | latent_model_input, 282 | torch.cat([t] * 2).to(device), 283 | encoder_hidden_states=text_embeddings, 284 | ).sample 285 | noise_pred_tgt, noise_pred_src = noise_pred.chunk(2) 286 | 287 | w = 1 - scheduler.alphas_cumprod[t].to(device) 288 | grad = w * cfg_scale * (noise_pred_tgt - noise_pred_src) 289 | grad = torch.nan_to_num(grad) 290 | target = (im - grad).detach() 291 | loss = 0.5 * F.mse_loss(im, target, reduction=reduction) / batch_size 292 | if return_dict: 293 | dic = {"loss": loss, "grad": grad, "t": t} 294 | return dic 295 | else: 296 | return loss 297 | 298 | def nfsd_loss( 299 | self, 300 | im, 301 | prompt=None, 302 | reduction="mean", 303 | cfg_scale=100, 304 | return_dict=False, 305 | ): 306 | device = self.device 307 | scheduler = self.scheduler 308 | 309 | batch_size = im.shape[0] 310 | t, _ = self.sample_timestep(batch_size) 311 | 312 | noise = torch.randn_like(im) 313 | 314 | latents_noisy = scheduler.add_noise(im, noise, t) 315 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 316 | 317 | # process text. 318 | self.update_text_features(tgt_prompt=prompt) 319 | tgt_text_embedding = self.tgt_text_feature 320 | uncond_embedding = self.null_text_feature 321 | with torch.no_grad(): 322 | text_embeddings = torch.cat([tgt_text_embedding, uncond_embedding], dim=0) 323 | noise_pred = self.unet.forward( 324 | latent_model_input, 325 | torch.cat([t] * 2).to(device), 326 | encoder_hidden_states=text_embeddings, 327 | ).sample 328 | noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) 329 | delta_C = cfg_scale * (noise_pred_text - noise_pred_uncond) 330 | 331 | self.update_text_features( 332 | tgt_prompt="unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, gloomy" 333 | ) 334 | tgt_text_embedding = self.tgt_text_feature 335 | uncond_embedding = self.null_text_feature 336 | with torch.no_grad(): 337 | text_embeddings = torch.cat([tgt_text_embedding, uncond_embedding], dim=0) 338 | noise_pred = self.unet.forward( 339 | latent_model_input, 340 | torch.cat([t] * 2).to(device), 341 | encoder_hidden_states=text_embeddings, 342 | ).sample 343 | noise_pred_text_neg, _ = noise_pred.chunk(2) 344 | 345 | delta_D = ( 346 | noise_pred_uncond if t < 200 else (noise_pred_uncond - noise_pred_text_neg) 347 | ) 348 | 349 | w = 1 - scheduler.alphas_cumprod[t].to(device) 350 | grad = w * (delta_C + delta_D) 351 | grad = torch.nan_to_num(grad) 352 | target = (im - grad).detach() 353 | loss = 0.5 * F.mse_loss(im, target, reduction=reduction) / batch_size 354 | if return_dict: 355 | dic = {"loss": loss, "grad": grad, "t": t} 356 | return dic 357 | else: 358 | return loss 359 | 360 | def get_variance(self, timestep, scheduler=None): 361 | 362 | if scheduler is None: 363 | scheduler = self.scheduler 364 | 365 | prev_timestep = ( 366 | timestep 367 | - scheduler.config.num_train_timesteps // scheduler.num_inference_steps 368 | ) 369 | 
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 370 | alpha_prod_t_prev = ( 371 | scheduler.alphas_cumprod[prev_timestep] 372 | if prev_timestep >= 0 373 | else scheduler.final_alpha_cumprod 374 | ) 375 | beta_prod_t = 1 - alpha_prod_t 376 | beta_prod_t_prev = 1 - alpha_prod_t_prev 377 | variance = (beta_prod_t_prev / beta_prod_t) * ( 378 | 1 - alpha_prod_t / alpha_prod_t_prev 379 | ) 380 | return variance 381 | 382 | @contextmanager 383 | def disable_unet_class_embedding(self, unet): 384 | class_embedding = unet.class_embedding 385 | try: 386 | unet.class_embedding = None 387 | yield unet 388 | finally: 389 | unet.class_embedding = class_embedding 390 | 391 | def vsd_loss( 392 | self, 393 | im, 394 | prompt=None, 395 | reduction="mean", 396 | cfg_scale=100, 397 | return_dict=False, 398 | ): 399 | device = self.device 400 | scheduler = self.scheduler 401 | 402 | # process text. 403 | self.update_text_features(tgt_prompt=prompt) 404 | tgt_text_embedding = self.tgt_text_feature 405 | uncond_embedding = self.null_text_feature 406 | 407 | batch_size = im.shape[0] 408 | camera_condition = torch.zeros([batch_size, 4, 4], device=device) 409 | 410 | with torch.no_grad(): 411 | # random timestamp 412 | t = torch.randint( 413 | 20, 414 | 980 + 1, 415 | [batch_size], 416 | dtype=torch.long, 417 | device=self.device, 418 | ) 419 | 420 | noise = torch.randn_like(im) 421 | 422 | latents_noisy = scheduler.add_noise(im, noise, t) 423 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 424 | text_embeddings = torch.cat([tgt_text_embedding, uncond_embedding], dim=0) 425 | with self.disable_unet_class_embedding(self.unet) as unet: 426 | cross_attention_kwargs = {"scale": 0.0} if self.single_model else None 427 | noise_pred_pretrain = unet.forward( 428 | latent_model_input, 429 | torch.cat([t] * 2).to(device), 430 | encoder_hidden_states=text_embeddings, 431 | cross_attention_kwargs=cross_attention_kwargs, 432 | ) 433 | 434 | # use view-independent text embeddings in LoRA 435 | noise_pred_est = self.unet_lora.forward( 436 | latent_model_input, 437 | torch.cat([t] * 2).to(device), 438 | encoder_hidden_states=torch.cat([tgt_text_embedding] * 2), 439 | class_labels=torch.cat( 440 | [ 441 | camera_condition.view(batch_size, -1), 442 | camera_condition.view(batch_size, -1), 443 | ], 444 | dim=0, 445 | ), 446 | cross_attention_kwargs={"scale": 1.0}, 447 | ).sample 448 | 449 | ( 450 | noise_pred_pretrain_text, 451 | noise_pred_pretrain_uncond, 452 | ) = noise_pred_pretrain.sample.chunk(2) 453 | 454 | # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance 455 | noise_pred_pretrain = noise_pred_pretrain_uncond + cfg_scale * ( 456 | noise_pred_pretrain_text - noise_pred_pretrain_uncond 457 | ) 458 | assert self.scheduler.config.prediction_type == "epsilon" 459 | if self.scheduler_lora.config.prediction_type == "v_prediction": 460 | alphas_cumprod = self.scheduler_lora.alphas_cumprod.to( 461 | device=latents_noisy.device, dtype=latents_noisy.dtype 462 | ) 463 | alpha_t = alphas_cumprod[t] ** 0.5 464 | sigma_t = (1 - alphas_cumprod[t]) ** 0.5 465 | 466 | noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view( 467 | -1, 1, 1, 1 468 | ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1) 469 | 470 | ( 471 | noise_pred_est_camera, 472 | noise_pred_est_uncond, 473 | ) = noise_pred_est.chunk(2) 474 | 475 | # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance 476 | noise_pred_est = 
noise_pred_est_uncond + self.config.guidance_scale_lora * ( 477 | noise_pred_est_camera - noise_pred_est_uncond 478 | ) 479 | 480 | w = (1 - scheduler.alphas_cumprod[t.cpu()]).view(-1, 1, 1, 1).to(device) 481 | grad = w * (noise_pred_pretrain - noise_pred_est) 482 | 483 | grad = torch.nan_to_num(grad) 484 | target = (im - grad).detach() 485 | loss = 0.5 * F.mse_loss(im, target, reduction=reduction) / batch_size 486 | loss_lora = self.train_lora(im, text_embeddings, camera_condition) 487 | if return_dict: 488 | dic = {"loss": loss, "lora_loss": loss_lora, "grad": grad, "t": t} 489 | return dic 490 | else: 491 | return loss 492 | 493 | def train_lora( 494 | self, 495 | latents: Float[torch.Tensor, "B 4 64 64"], 496 | text_embeddings: Float[torch.Tensor, "BB 77 768"], 497 | camera_condition: Float[torch.Tensor, "B 4 4"], 498 | ): 499 | scheduler = self.scheduler_lora 500 | 501 | B = latents.shape[0] 502 | latents = latents.detach().repeat(self.config.lora_n_timestamp_samples, 1, 1, 1) 503 | 504 | t = torch.randint( 505 | int(scheduler.num_train_timesteps * 0.0), 506 | int(scheduler.num_train_timesteps * 1.0), 507 | [B * self.config.lora_n_timestamp_samples], 508 | dtype=torch.long, 509 | device=self.device, 510 | ) 511 | 512 | noise = torch.randn_like(latents) 513 | noisy_latents = self.scheduler_lora.add_noise(latents, noise, t) 514 | if self.scheduler_lora.config.prediction_type == "epsilon": 515 | target = noise 516 | elif self.scheduler_lora.config.prediction_type == "v_prediction": 517 | target = self.scheduler_lora.get_velocity(latents, noise, t) 518 | else: 519 | raise ValueError( 520 | f"Unknown prediction type {self.scheduler_lora.config.prediction_type}" 521 | ) 522 | # use view-independent text embeddings in LoRA 523 | text_embeddings_cond, _ = text_embeddings.chunk(2) 524 | if self.config.lora_cfg_training and np.random.random() < 0.1: 525 | camera_condition = torch.zeros_like(camera_condition) 526 | noise_pred = self.unet_lora.forward( 527 | noisy_latents, 528 | t, 529 | encoder_hidden_states=text_embeddings_cond.repeat( 530 | self.config.lora_n_timestamp_samples, 1, 1 531 | ), 532 | class_labels=camera_condition.view(B, -1).repeat( 533 | self.config.lora_n_timestamp_samples, 1 534 | ), 535 | cross_attention_kwargs={"scale": 1.0}, 536 | ).sample 537 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") 538 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/guidance/sds_bridge_guidance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from diffusers import DDIMScheduler, StableDiffusionPipeline 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from tqdm import tqdm 10 | 11 | import threestudio 12 | from threestudio.models.prompt_processors.base import PromptProcessorOutput 13 | from threestudio.utils.base import BaseObject 14 | from threestudio.utils.misc import C, cleanup, parse_version 15 | from threestudio.utils.ops import perpendicular_component 16 | from threestudio.utils.typing import * 17 | 18 | 19 | @threestudio.register("stable-diffusion-sds-bridge-guidance") 20 | class SDSBridgeGuidance(BaseObject): 21 | @dataclass 22 | class Config(BaseObject.Config): 23 | pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5" 24 | enable_memory_efficient_attention: bool = False 25 | 
enable_sequential_cpu_offload: bool = False 26 | enable_attention_slicing: bool = False 27 | enable_channels_last_format: bool = False 28 | grad_clip: Optional[ 29 | Any 30 | ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) 31 | half_precision_weights: bool = True 32 | 33 | min_step_percent: float = 0.02 34 | max_step_percent: float = 0.98 35 | sqrt_anneal: bool = False # sqrt anneal proposed in HiFA: https://hifa-team.github.io/HiFA-site/ 36 | trainer_max_steps: int = 25000 37 | use_img_loss: bool = False # image-space SDS proposed in HiFA: https://hifa-team.github.io/HiFA-site/ 38 | 39 | var_red: bool = True 40 | weighting_strategy: str = "sds" 41 | 42 | token_merging: bool = False 43 | token_merging_params: Optional[dict] = field(default_factory=dict) 44 | 45 | view_dependent_prompting: bool = True 46 | 47 | """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items.""" 48 | max_items_eval: int = 4 49 | 50 | """Configs for SDS-Bridges""" 51 | num_inference_steps: int = 500 52 | guidance_scale: float = 100.0 53 | stage_one_weight: float = 1. 54 | stage_two_weight: float = 100. 55 | stage_two_start_step: int = 20000 56 | 57 | 58 | cfg: Config 59 | 60 | def configure(self) -> None: 61 | threestudio.info(f"Loading Stable Diffusion ...") 62 | 63 | self.weights_dtype = ( 64 | torch.float16 if self.cfg.half_precision_weights else torch.float32 65 | ) 66 | 67 | pipe_kwargs = { 68 | "tokenizer": None, 69 | "safety_checker": None, 70 | "feature_extractor": None, 71 | "requires_safety_checker": False, 72 | "torch_dtype": self.weights_dtype, 73 | } 74 | self.pipe = StableDiffusionPipeline.from_pretrained( 75 | self.cfg.pretrained_model_name_or_path, 76 | **pipe_kwargs, 77 | ).to(self.device) 78 | 79 | if self.cfg.enable_memory_efficient_attention: 80 | if parse_version(torch.__version__) >= parse_version("2"): 81 | threestudio.info( 82 | "PyTorch2.0 uses memory efficient attention by default." 83 | ) 84 | elif not is_xformers_available(): 85 | threestudio.warn( 86 | "xformers is not available, memory efficient attention is not enabled." 
87 | ) 88 | else: 89 | self.pipe.enable_xformers_memory_efficient_attention() 90 | 91 | if self.cfg.enable_sequential_cpu_offload: 92 | self.pipe.enable_sequential_cpu_offload() 93 | 94 | if self.cfg.enable_attention_slicing: 95 | self.pipe.enable_attention_slicing(1) 96 | 97 | if self.cfg.enable_channels_last_format: 98 | self.pipe.unet.to(memory_format=torch.channels_last) 99 | 100 | del self.pipe.text_encoder 101 | cleanup() 102 | 103 | # Create model 104 | self.vae = self.pipe.vae.eval() 105 | self.unet = self.pipe.unet.eval() 106 | 107 | for p in self.vae.parameters(): 108 | p.requires_grad_(False) 109 | for p in self.unet.parameters(): 110 | p.requires_grad_(False) 111 | 112 | if self.cfg.token_merging: 113 | import tomesd 114 | 115 | tomesd.apply_patch(self.unet, **self.cfg.token_merging_params) 116 | 117 | self.scheduler = DDIMScheduler.from_pretrained( 118 | self.cfg.pretrained_model_name_or_path, 119 | subfolder="scheduler", 120 | torch_dtype=self.weights_dtype, 121 | ) 122 | 123 | 124 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 125 | self.set_min_max_steps() # set to default value 126 | 127 | self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( 128 | self.device 129 | ) 130 | self.grad_clip_val: Optional[float] = None 131 | 132 | self.phase_id = 1 133 | 134 | threestudio.info(f"Loaded Stable Diffusion!") 135 | 136 | @torch.cuda.amp.autocast(enabled=False) 137 | def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): 138 | self.min_step = int(self.num_train_timesteps * min_step_percent) 139 | self.max_step = int(self.num_train_timesteps * max_step_percent) 140 | 141 | @torch.cuda.amp.autocast(enabled=False) 142 | def forward_unet( 143 | self, 144 | latents: Float[Tensor, "..."], 145 | t: Float[Tensor, "..."], 146 | encoder_hidden_states: Float[Tensor, "..."], 147 | ) -> Float[Tensor, "..."]: 148 | input_dtype = latents.dtype 149 | return self.unet( 150 | latents.to(self.weights_dtype), 151 | t.to(self.weights_dtype), 152 | encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), 153 | ).sample.to(input_dtype) 154 | 155 | @torch.cuda.amp.autocast(enabled=False) 156 | def encode_images( 157 | self, imgs: Float[Tensor, "B 3 512 512"] 158 | ) -> Float[Tensor, "B 4 64 64"]: 159 | input_dtype = imgs.dtype 160 | imgs = imgs * 2.0 - 1.0 161 | posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist 162 | latents = posterior.sample() * self.vae.config.scaling_factor 163 | return latents.to(input_dtype) 164 | 165 | 166 | @torch.cuda.amp.autocast(enabled=False) 167 | def decode_latents( 168 | self, 169 | latents: Float[Tensor, "B 4 H W"], 170 | latent_height: int = 64, 171 | latent_width: int = 64, 172 | ) -> Float[Tensor, "B 3 512 512"]: 173 | input_dtype = latents.dtype 174 | latents = F.interpolate( 175 | latents, (latent_height, latent_width), mode="bilinear", align_corners=False 176 | ) 177 | latents = 1 / self.vae.config.scaling_factor * latents 178 | image = self.vae.decode(latents.to(self.weights_dtype)).sample 179 | image = (image * 0.5 + 0.5).clamp(0, 1) 180 | return image.to(input_dtype) 181 | 182 | def compute_grad_sds_bridge( 183 | self, 184 | latents: Float[Tensor, "B 4 64 64"], 185 | image: Float[Tensor, "B 3 512 512"], 186 | t: Int[Tensor, "B"], 187 | prompt_utils: PromptProcessorOutput, 188 | elevation: Float[Tensor, "B"], 189 | azimuth: Float[Tensor, "B"], 190 | camera_distances: Float[Tensor, "B"], 191 | ): 192 | batch_size = elevation.shape[0] 193 | 194 | device = latents.device 195 | 
neg_guidance_weights = None 196 | text_embeddings_all = prompt_utils.get_text_embeddings( 197 | elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting 198 | ) 199 | src_text_embeddings, tgt_text_embeddings, base_text_embeddings = text_embeddings_all 200 | src_text_embedding = src_text_embeddings[:1] 201 | tgt_text_embedding = tgt_text_embeddings[:1] 202 | base_text_embedding = base_text_embeddings[:1] 203 | uncond_text_embedding = src_text_embeddings[1:] 204 | if self.phase_id == 1: 205 | text_embeddings = torch.cat([uncond_text_embedding, tgt_text_embedding]) 206 | else: 207 | text_embeddings = torch.cat([src_text_embedding, tgt_text_embedding]) 208 | 209 | noise = torch.randn_like(latents) 210 | 211 | with torch.no_grad(): 212 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 213 | 214 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 215 | noise_pred = self.forward_unet(latent_model_input, torch.cat([t] * 2).to(device), text_embeddings) 216 | noise_pred_text_src, noise_pred_text_tgt = noise_pred.chunk(2) 217 | if self.phase_id == 1: 218 | w = (1 - self.alphas[t]).view(-1, 1, 1, 1) 219 | noise_pred_sds = noise_pred_text_src + self.cfg.guidance_scale * (noise_pred_text_tgt - noise_pred_text_src) 220 | noise_pred = self.cfg.stage_one_weight * w * noise_pred_sds 221 | noise = self.cfg.stage_one_weight * w * noise 222 | elif self.phase_id == 2: 223 | noise_pred = self.cfg.stage_two_weight * noise_pred_text_tgt 224 | noise = self.cfg.stage_two_weight * noise_pred_text_src 225 | 226 | 227 | if self.cfg.weighting_strategy == "sds": 228 | # w(t), sigma_t^2 229 | w = (1 - self.alphas[t]).view(-1, 1, 1, 1) 230 | elif self.cfg.weighting_strategy == "uniform": 231 | w = 1 232 | elif self.cfg.weighting_strategy == "fantasia3d": 233 | w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1) 234 | else: 235 | raise ValueError( 236 | f"Unknown weighting strategy: {self.cfg.weighting_strategy}" 237 | ) 238 | 239 | alpha = (self.alphas[t] ** 0.5).view(-1, 1, 1, 1) 240 | sigma = ((1 - self.alphas[t]) ** 0.5).view(-1, 1, 1, 1) 241 | latents_denoised = (latents_noisy - sigma * noise_pred) / alpha 242 | image_denoised = self.decode_latents(latents_denoised) 243 | 244 | grad = w * (noise_pred - noise) 245 | # image-space SDS proposed in HiFA: https://hifa-team.github.io/HiFA-site/ 246 | if self.cfg.use_img_loss: 247 | grad_img = w * (image - image_denoised) * alpha / sigma 248 | else: 249 | grad_img = None 250 | 251 | guidance_eval_utils = { 252 | "use_perp_neg": prompt_utils.use_perp_neg, 253 | "neg_guidance_weights": neg_guidance_weights, 254 | "text_embeddings": text_embeddings, 255 | "t_orig": t, 256 | "latents_noisy": latents_noisy, 257 | "noise_pred": noise_pred, 258 | } 259 | 260 | return grad, grad_img, guidance_eval_utils 261 | 262 | def compute_posterior_mean(self, xt, noise_pred, t, t_prev): 263 | """ 264 | Computes an estimated posterior mean \mu_\phi(x_t, y; \epsilon_\phi). 
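        Concretely, matching the coefficients computed below: with
        pred_x0 = (x_t - sqrt(1 - alpha_bar_t) * \epsilon_\phi) / sqrt(alpha_bar_t),
        the returned mean is c0 * pred_x0 + c1 * x_t, where
        c0 = sqrt(alpha_bar_t_prev) * beta_t / (1 - alpha_bar_t) and
        c1 = sqrt(alpha_t) * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t),
        i.e. the standard DDPM posterior mean between timesteps t and t_prev.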
265 | """ 266 | device = self.device 267 | beta_t = self.scheduler.betas.to(device)[t][:, None, None, None] 268 | alpha_t = self.scheduler.alphas.to(device)[t][:, None, None, None] 269 | alpha_bar_t = self.scheduler.alphas_cumprod.to(device)[t][:, None, None, None] 270 | alpha_bar_t_prev = self.scheduler.alphas_cumprod.to(device)[t_prev][:, None, None, None] 271 | 272 | pred_x0 = (xt - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t) 273 | c0 = torch.sqrt(alpha_bar_t_prev) * beta_t / (1 - alpha_bar_t) 274 | c1 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t) 275 | 276 | mean_func = c0 * pred_x0 + c1 * xt 277 | return mean_func 278 | 279 | def __call__( 280 | self, 281 | rgb: Float[Tensor, "B H W C"], 282 | prompt_utils: PromptProcessorOutput, 283 | elevation: Float[Tensor, "B"], 284 | azimuth: Float[Tensor, "B"], 285 | camera_distances: Float[Tensor, "B"], 286 | rgb_as_latents=False, 287 | guidance_eval=False, 288 | **kwargs, 289 | ): 290 | batch_size = rgb.shape[0] 291 | 292 | rgb_BCHW = rgb.permute(0, 3, 1, 2) 293 | latents: Float[Tensor, "B 4 64 64"] 294 | rgb_BCHW_512 = F.interpolate( 295 | rgb_BCHW, (512, 512), mode="bilinear", align_corners=False 296 | ) 297 | if rgb_as_latents: 298 | latents = F.interpolate( 299 | rgb_BCHW, (64, 64), mode="bilinear", align_corners=False 300 | ) 301 | else: 302 | # encode image into latents with vae 303 | latents = self.encode_images(rgb_BCHW_512) 304 | 305 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 306 | t = torch.randint( 307 | self.min_step, 308 | self.max_step + 1, 309 | [batch_size], 310 | dtype=torch.long, 311 | device=self.device, 312 | ) 313 | 314 | grad, grad_img, guidance_eval_utils = self.compute_grad_sds_bridge( 315 | latents, 316 | rgb_BCHW_512, 317 | t, 318 | prompt_utils, 319 | elevation, 320 | azimuth, 321 | camera_distances, 322 | ) 323 | 324 | grad = torch.nan_to_num(grad) 325 | 326 | # clip grad for stable training? 
327 | if self.grad_clip_val is not None: 328 | grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) 329 | 330 | # loss = SpecifyGradient.apply(latents, grad) 331 | # SpecifyGradient is not straightforward, use a reparameterization trick instead 332 | target = (latents - grad).detach() 333 | # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad 334 | loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size 335 | 336 | guidance_out = { 337 | "loss_sds": loss_sds, 338 | "grad_norm": grad.norm(), 339 | "min_step": self.min_step, 340 | "max_step": self.max_step, 341 | } 342 | 343 | if self.cfg.use_img_loss: 344 | grad_img = torch.nan_to_num(grad_img) 345 | if self.grad_clip_val is not None: 346 | grad_img = grad_img.clamp(-self.grad_clip_val, self.grad_clip_val) 347 | target_img = (rgb_BCHW_512 - grad_img).detach() 348 | loss_sds_img = ( 349 | 0.5 * F.mse_loss(rgb_BCHW_512, target_img, reduction="sum") / batch_size 350 | ) 351 | guidance_out["loss_sds_img"] = loss_sds_img 352 | 353 | if guidance_eval: 354 | guidance_eval_out = self.guidance_eval(**guidance_eval_utils) 355 | texts = [] 356 | for n, e, a, c in zip( 357 | guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances 358 | ): 359 | texts.append( 360 | f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}" 361 | ) 362 | guidance_eval_out.update({"texts": texts}) 363 | guidance_out.update({"eval": guidance_eval_out}) 364 | 365 | return guidance_out 366 | 367 | @torch.cuda.amp.autocast(enabled=False) 368 | @torch.no_grad() 369 | def get_noise_pred( 370 | self, 371 | latents_noisy, 372 | t, 373 | text_embeddings, 374 | use_perp_neg=False, 375 | neg_guidance_weights=None, 376 | ): 377 | batch_size = latents_noisy.shape[0] 378 | 379 | if use_perp_neg: 380 | # pred noise 381 | latent_model_input = torch.cat([latents_noisy] * 4, dim=0) 382 | noise_pred = self.forward_unet( 383 | latent_model_input, 384 | torch.cat([t.reshape(1)] * 4).to(self.device), 385 | encoder_hidden_states=text_embeddings, 386 | ) # (4B, 4, 64, 64) 387 | 388 | noise_pred_text = noise_pred[:batch_size] 389 | noise_pred_uncond = noise_pred[batch_size : batch_size * 2] 390 | noise_pred_neg = noise_pred[batch_size * 2 :] 391 | 392 | e_pos = noise_pred_text - noise_pred_uncond 393 | accum_grad = 0 394 | n_negative_prompts = neg_guidance_weights.shape[-1] 395 | for i in range(n_negative_prompts): 396 | e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond 397 | accum_grad += neg_guidance_weights[:, i].view( 398 | -1, 1, 1, 1 399 | ) * perpendicular_component(e_i_neg, e_pos) 400 | 401 | noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( 402 | e_pos + accum_grad 403 | ) 404 | else: 405 | # pred noise 406 | latent_model_input = torch.cat([latents_noisy] * 2, dim=0) 407 | noise_pred = self.forward_unet( 408 | latent_model_input, 409 | torch.cat([t.reshape(1)] * 2).to(self.device), 410 | encoder_hidden_states=text_embeddings, 411 | ) 412 | # perform guidance (high scale from paper!) 
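# note: eps_text + s * (eps_text - eps_uncond) is algebraically the standard classifier-free guidance form eps_uncond + (s + 1) * (eps_text - eps_uncond), i.e. an effective scale of guidance_scale + 1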
413 | noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) 414 | noise_pred = noise_pred_text + self.cfg.guidance_scale * ( 415 | noise_pred_text - noise_pred_uncond 416 | ) 417 | 418 | return noise_pred 419 | 420 | @torch.cuda.amp.autocast(enabled=False) 421 | @torch.no_grad() 422 | def guidance_eval( 423 | self, 424 | t_orig, 425 | text_embeddings, 426 | latents_noisy, 427 | noise_pred, 428 | use_perp_neg=False, 429 | neg_guidance_weights=None, 430 | ): 431 | # use only 50 timesteps, and find nearest of those to t 432 | self.scheduler.set_timesteps(50) 433 | self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) 434 | bs = ( 435 | min(self.cfg.max_items_eval, latents_noisy.shape[0]) 436 | if self.cfg.max_items_eval > 0 437 | else latents_noisy.shape[0] 438 | ) # batch size 439 | large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[ 440 | :bs 441 | ].unsqueeze( 442 | -1 443 | ) # sized [bs,50] > [bs,1] 444 | idxs = torch.min(large_enough_idxs, dim=1)[1] 445 | t = self.scheduler.timesteps_gpu[idxs] 446 | 447 | fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) 448 | imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1) 449 | 450 | # get prev latent 451 | latents_1step = [] 452 | pred_1orig = [] 453 | for b in range(bs): 454 | step_output = self.scheduler.step( 455 | noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1 456 | ) 457 | latents_1step.append(step_output["prev_sample"]) 458 | pred_1orig.append(step_output["pred_original_sample"]) 459 | latents_1step = torch.cat(latents_1step) 460 | pred_1orig = torch.cat(pred_1orig) 461 | imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1) 462 | imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1) 463 | 464 | latents_final = [] 465 | for b, i in enumerate(idxs): 466 | latents = latents_1step[b : b + 1] 467 | text_emb = ( 468 | text_embeddings[ 469 | [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ... 470 | ] 471 | if use_perp_neg 472 | else text_embeddings[[b, b + len(idxs)], ...] 
473 | ) 474 | neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None 475 | for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False): 476 | # pred noise 477 | noise_pred = self.get_noise_pred( 478 | latents, t, text_emb, use_perp_neg, neg_guid 479 | ) 480 | # get prev latent 481 | latents = self.scheduler.step(noise_pred, t, latents, eta=1)[ 482 | "prev_sample" 483 | ] 484 | latents_final.append(latents) 485 | 486 | latents_final = torch.cat(latents_final) 487 | imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1) 488 | 489 | return { 490 | "bs": bs, 491 | "noise_levels": fracs, 492 | "imgs_noisy": imgs_noisy, 493 | "imgs_1step": imgs_1step, 494 | "imgs_1orig": imgs_1orig, 495 | "imgs_final": imgs_final, 496 | } 497 | 498 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 499 | # clip grad for stable training as demonstrated in 500 | # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation 501 | # http://arxiv.org/abs/2303.15413 502 | if self.cfg.grad_clip is not None: 503 | self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) 504 | 505 | if self.cfg.sqrt_anneal: 506 | percentage = ( 507 | float(global_step) / self.cfg.trainer_max_steps 508 | ) ** 0.5 # progress percentage 509 | if type(self.cfg.max_step_percent) not in [float, int]: 510 | max_step_percent = self.cfg.max_step_percent[1] 511 | else: 512 | max_step_percent = self.cfg.max_step_percent 513 | curr_percent = ( 514 | max_step_percent - C(self.cfg.min_step_percent, epoch, global_step) 515 | ) * (1 - percentage) + C(self.cfg.min_step_percent, epoch, global_step) 516 | self.set_min_max_steps( 517 | min_step_percent=curr_percent, 518 | max_step_percent=curr_percent, 519 | ) 520 | else: 521 | self.set_min_max_steps( 522 | min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), 523 | max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), 524 | ) 525 | -------------------------------------------------------------------------------- /threestudio-sds-bridge/prompt_processors/stable_diffusion_sds_bridge_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass, field 4 | 5 | import torch 6 | import torch.multiprocessing as mp 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 10 | from transformers import AutoTokenizer, BertForMaskedLM, CLIPTextModel 11 | 12 | import threestudio 13 | from threestudio.utils.base import BaseObject 14 | from threestudio.utils.misc import barrier, cleanup, get_rank 15 | from threestudio.utils.ops import shifted_cosine_decay, shifted_expotional_decay 16 | from threestudio.utils.typing import * 17 | 18 | import safetensors 19 | from diffusers.utils import _get_model_file 20 | 21 | def hash_prompt(model: str, prompt: str) -> str: 22 | import hashlib 23 | 24 | identifier = f"{model}-{prompt}" 25 | return hashlib.md5(identifier.encode()).hexdigest() 26 | 27 | 28 | @dataclass 29 | class DirectionConfig: 30 | name: str 31 | prompt: Callable[[str], str] 32 | negative_prompt: Callable[[str], str] 33 | condition: Callable[ 34 | [Float[Tensor, "B"], Float[Tensor, "B"], Float[Tensor, "B"]], 35 | Float[Tensor, "B"], 36 | ] 37 | 38 | 39 | @dataclass 40 | class PromptProcessorOutput: 41 | base_text_embeddings: Float[Tensor, "N Nf"] 42 | src_text_embeddings: Float[Tensor, "N Nf"] 43 | tgt_text_embeddings: 
Float[Tensor, "N Nf"] 44 | uncond_text_embeddings: Float[Tensor, "N Nf"] 45 | base_text_embeddings_vd: Float[Tensor, "Nv N Nf"] 46 | src_text_embeddings_vd: Float[Tensor, "Nv N Nf"] 47 | tgt_text_embeddings_vd: Float[Tensor, "Nv N Nf"] 48 | uncond_text_embeddings_vd: Float[Tensor, "Nv N Nf"] 49 | directions: List[DirectionConfig] 50 | direction2idx: Dict[str, int] 51 | use_perp_neg: bool 52 | perp_neg_f_sb: Tuple[float, float, float] 53 | perp_neg_f_fsb: Tuple[float, float, float] 54 | perp_neg_f_fs: Tuple[float, float, float] 55 | perp_neg_f_sf: Tuple[float, float, float] 56 | prompt: str 57 | base_prompts_vd: List[str] 58 | src_prompts_vd: List[str] 59 | tgt_prompts_vd: List[str] 60 | 61 | def get_text_embeddings( 62 | self, 63 | elevation: Float[Tensor, "B"], 64 | azimuth: Float[Tensor, "B"], 65 | camera_distances: Float[Tensor, "B"], 66 | view_dependent_prompting: bool = True, 67 | ) -> Float[Tensor, "BB N Nf"]: 68 | batch_size = elevation.shape[0] 69 | 70 | if view_dependent_prompting: 71 | # Get direction 72 | direction_idx = torch.zeros_like(elevation, dtype=torch.long) 73 | for d in self.directions: 74 | direction_idx[ 75 | d.condition(elevation, azimuth, camera_distances) 76 | ] = self.direction2idx[d.name] 77 | 78 | # Get text embeddings 79 | base_text_embeddings = self.base_text_embeddings_vd[direction_idx] # type: ignore 80 | src_text_embeddings = self.src_text_embeddings_vd[direction_idx] # type: ignore 81 | tgt_text_embeddings = self.tgt_text_embeddings_vd[direction_idx] # type: ignore 82 | uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx] # type: ignore 83 | else: 84 | base_text_embeddings = self.base_text_embeddings.expand(batch_size, -1, -1) # type: ignore 85 | src_text_embeddings = self.src_text_embeddings.expand(batch_size, -1, -1) # type: ignore 86 | tgt_text_embeddings = self.tgt_text_embeddings.expand(batch_size, -1, -1) # type: ignore 87 | uncond_text_embeddings = self.uncond_text_embeddings.expand( # type: ignore 88 | batch_size, -1, -1 89 | ) 90 | 91 | # IMPORTANT: we return (cond, uncond), which is in different order than other implementations! 
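# each returned tensor is the conditional embeddings followed by the unconditional ones along the batch dimension; e.g. compute_grad_sds_bridge in the guidance module takes [:1] as the conditional and [1:] as the unconditional part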
92 | return torch.cat([src_text_embeddings, uncond_text_embeddings], dim=0), torch.cat([tgt_text_embeddings, uncond_text_embeddings], dim=0), torch.cat([base_text_embeddings, uncond_text_embeddings], dim=0) 93 | 94 | def get_text_embeddings_perp_neg( 95 | self, 96 | elevation: Float[Tensor, "B"], 97 | azimuth: Float[Tensor, "B"], 98 | camera_distances: Float[Tensor, "B"], 99 | view_dependent_prompting: bool = True, 100 | ) -> Tuple[Float[Tensor, "BBBB N Nf"], Float[Tensor, "B 2"]]: 101 | assert ( 102 | view_dependent_prompting 103 | ), "Perp-Neg only works with view-dependent prompting" 104 | 105 | batch_size = elevation.shape[0] 106 | 107 | direction_idx = torch.zeros_like(elevation, dtype=torch.long) 108 | for d in self.directions: 109 | direction_idx[ 110 | d.condition(elevation, azimuth, camera_distances) 111 | ] = self.direction2idx[d.name] 112 | # 0 - side view 113 | # 1 - front view 114 | # 2 - back view 115 | # 3 - overhead view 116 | 117 | pos_text_embeddings = [] 118 | neg_text_embeddings = [] 119 | neg_guidance_weights = [] 120 | uncond_text_embeddings = [] 121 | 122 | side_emb = self.text_embeddings_vd[0] 123 | front_emb = self.text_embeddings_vd[1] 124 | back_emb = self.text_embeddings_vd[2] 125 | overhead_emb = self.text_embeddings_vd[3] 126 | 127 | for idx, ele, azi, dis in zip( 128 | direction_idx, elevation, azimuth, camera_distances 129 | ): 130 | azi = shift_azimuth_deg(azi) # to (-180, 180) 131 | uncond_text_embeddings.append( 132 | self.uncond_text_embeddings_vd[idx] 133 | ) # should be "" 134 | if idx.item() == 3: # overhead view 135 | pos_text_embeddings.append(overhead_emb) # side view 136 | # dummy 137 | neg_text_embeddings += [ 138 | self.uncond_text_embeddings_vd[idx], 139 | self.uncond_text_embeddings_vd[idx], 140 | ] 141 | neg_guidance_weights += [0.0, 0.0] 142 | else: # interpolating views 143 | if torch.abs(azi) < 90: 144 | # front-side interpolation 145 | # 0 - complete side, 1 - complete front 146 | r_inter = 1 - torch.abs(azi) / 90 147 | pos_text_embeddings.append( 148 | r_inter * front_emb + (1 - r_inter) * side_emb 149 | ) 150 | neg_text_embeddings += [front_emb, side_emb] 151 | neg_guidance_weights += [ 152 | -shifted_expotional_decay(*self.perp_neg_f_fs, r_inter), 153 | -shifted_expotional_decay(*self.perp_neg_f_sf, 1 - r_inter), 154 | ] 155 | else: 156 | # side-back interpolation 157 | # 0 - complete back, 1 - complete side 158 | r_inter = 2.0 - torch.abs(azi) / 90 159 | pos_text_embeddings.append( 160 | r_inter * side_emb + (1 - r_inter) * back_emb 161 | ) 162 | neg_text_embeddings += [side_emb, front_emb] 163 | neg_guidance_weights += [ 164 | -shifted_expotional_decay(*self.perp_neg_f_sb, r_inter), 165 | -shifted_expotional_decay(*self.perp_neg_f_fsb, r_inter), 166 | ] 167 | 168 | text_embeddings = torch.cat( 169 | [ 170 | torch.stack(pos_text_embeddings, dim=0), 171 | torch.stack(uncond_text_embeddings, dim=0), 172 | torch.stack(neg_text_embeddings, dim=0), 173 | ], 174 | dim=0, 175 | ) 176 | 177 | return text_embeddings, torch.as_tensor( 178 | neg_guidance_weights, device=elevation.device 179 | ).reshape(batch_size, 2) 180 | 181 | 182 | def shift_azimuth_deg(azimuth: Float[Tensor, "..."]) -> Float[Tensor, "..."]: 183 | # shift azimuth angle (in degrees), to [-180, 180] 184 | return (azimuth + 180) % 360 - 180 185 | 186 | 187 | class PromptProcessor(BaseObject): 188 | @dataclass 189 | class Config(BaseObject.Config): 190 | prompt: str = "a hamburger" 191 | src_modifier: str = 'oversaturated, smooth, pixelated, cartoon, foggy, hazy, blurry, bad 
structure, noisy, malformed' 192 | tgt_modifier: str = '.' 193 | texture_inversion_embedding: str = '' 194 | 195 | # manually assigned view-dependent prompts 196 | src_prompt_front: Optional[str] = None 197 | src_prompt_side: Optional[str] = None 198 | src_prompt_back: Optional[str] = None 199 | src_prompt_overhead: Optional[str] = None 200 | tgt_prompt_front: Optional[str] = None 201 | tgt_prompt_side: Optional[str] = None 202 | tgt_prompt_back: Optional[str] = None 203 | tgt_prompt_overhead: Optional[str] = None 204 | 205 | negative_prompt: str = "" 206 | pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5" 207 | overhead_threshold: float = 60.0 208 | front_threshold: float = 45.0 209 | back_threshold: float = 45.0 210 | view_dependent_prompt_front: bool = False 211 | use_cache: bool = True 212 | spawn: bool = True 213 | 214 | # perp neg 215 | use_perp_neg: bool = False 216 | # a*e(-b*r) + c 217 | # a * e(-b) + c = 0 218 | perp_neg_f_sb: Tuple[float, float, float] = (1, 0.5, -0.606) 219 | perp_neg_f_fsb: Tuple[float, float, float] = (1, 0.5, +0.967) 220 | perp_neg_f_fs: Tuple[float, float, float] = ( 221 | 4, 222 | 0.5, 223 | -2.426, 224 | ) # f_fs(1) = 0, a, b > 0 225 | perp_neg_f_sf: Tuple[float, float, float] = (4, 0.5, -2.426) 226 | 227 | # prompt debiasing 228 | use_prompt_debiasing: bool = False 229 | pretrained_model_name_or_path_prompt_debiasing: str = "bert-base-uncased" 230 | # index of words that can potentially be removed 231 | prompt_debiasing_mask_ids: Optional[List[int]] = None 232 | 233 | use_modifier_only: bool = True 234 | 235 | cfg: Config 236 | 237 | @rank_zero_only 238 | def configure_text_encoder(self) -> None: 239 | raise NotImplementedError 240 | 241 | @rank_zero_only 242 | def destroy_text_encoder(self) -> None: 243 | raise NotImplementedError 244 | 245 | def configure(self) -> None: 246 | self._cache_dir = ".threestudio_cache/text_embeddings" # FIXME: hard-coded path 247 | 248 | # view-dependent text embeddings 249 | self.directions: List[DirectionConfig] 250 | if self.cfg.view_dependent_prompt_front: 251 | self.directions = [ 252 | DirectionConfig( 253 | "side", 254 | lambda s: f"side view of {s}", 255 | lambda s: s, 256 | lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool), 257 | ), 258 | DirectionConfig( 259 | "front", 260 | lambda s: f"front view of {s}", 261 | lambda s: s, 262 | lambda ele, azi, dis: ( 263 | shift_azimuth_deg(azi) > -self.cfg.front_threshold 264 | ) 265 | & (shift_azimuth_deg(azi) < self.cfg.front_threshold), 266 | ), 267 | DirectionConfig( 268 | "back", 269 | lambda s: f"backside view of {s}", 270 | lambda s: s, 271 | lambda ele, azi, dis: ( 272 | shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold 273 | ) 274 | | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold), 275 | ), 276 | DirectionConfig( 277 | "overhead", 278 | lambda s: f"overhead view of {s}", 279 | lambda s: s, 280 | lambda ele, azi, dis: ele > self.cfg.overhead_threshold, 281 | ), 282 | ] 283 | else: 284 | self.directions = [ 285 | DirectionConfig( 286 | "side", 287 | lambda s: f"{s}, side view", 288 | lambda s: s, 289 | lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool), 290 | ), 291 | DirectionConfig( 292 | "front", 293 | lambda s: f"{s}, front view", 294 | lambda s: s, 295 | lambda ele, azi, dis: ( 296 | shift_azimuth_deg(azi) > -self.cfg.front_threshold 297 | ) 298 | & (shift_azimuth_deg(azi) < self.cfg.front_threshold), 299 | ), 300 | DirectionConfig( 301 | "back", 302 | lambda s: f"{s}, back view", 303 | lambda s: s, 304 | 
lambda ele, azi, dis: ( 305 | shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold 306 | ) 307 | | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold), 308 | ), 309 | DirectionConfig( 310 | "overhead", 311 | lambda s: f"{s}, overhead view", 312 | lambda s: s, 313 | lambda ele, azi, dis: ele > self.cfg.overhead_threshold, 314 | ), 315 | ] 316 | 317 | self.direction2idx = {d.name: i for i, d in enumerate(self.directions)} 318 | 319 | if os.path.exists("load/prompt_library.json"): 320 | with open(os.path.join("load/prompt_library.json"), "r") as f: 321 | self.prompt_library = json.load(f) 322 | else: 323 | self.prompt_library = {} 324 | # use provided prompt or find prompt in library 325 | self.prompt = self.preprocess_prompt(self.cfg.prompt) 326 | # use provided negative prompt 327 | self.negative_prompt = self.cfg.negative_prompt 328 | 329 | # process sds bridge source and target prompt 330 | if self.cfg.use_modifier_only: 331 | self.src_prompt = self.cfg.src_modifier 332 | else: 333 | self.src_prompt = self.prompt + ', ' + self.cfg.src_modifier 334 | 335 | self.tgt_prompt = self.prompt + ', ' + self.cfg.tgt_modifier 336 | 337 | threestudio.info( 338 | f"Using prompt [{self.prompt}] and negative prompt [{self.negative_prompt}]" 339 | ) 340 | 341 | # view-dependent prompting 342 | if self.cfg.use_prompt_debiasing: 343 | # Warning: not implemented with sds bridge yet 344 | assert ( 345 | self.cfg.prompt_side is None 346 | and self.cfg.prompt_back is None 347 | and self.cfg.prompt_overhead is None 348 | ), "Do not manually assign prompt_side, prompt_back or prompt_overhead when using prompt debiasing" 349 | prompts = self.get_debiased_prompt(self.prompt) 350 | self.prompts_vd = [ 351 | d.prompt(prompt) for d, prompt in zip(self.directions, prompts) 352 | ] 353 | else: 354 | self.base_prompts_vd = [ 355 | self.cfg.get(f"src_prompt_{d.name}", None) or d.prompt(self.prompt) # type: ignore 356 | for d in self.directions 357 | ] 358 | self.src_prompts_vd = [ 359 | self.cfg.get(f"src_prompt_{d.name}", None) or d.prompt(self.src_prompt) # type: ignore 360 | for d in self.directions 361 | ] 362 | self.tgt_prompts_vd = [ 363 | self.cfg.get(f"tgt_prompt_{d.name}", None) or d.prompt(self.tgt_prompt) # type: ignore 364 | for d in self.directions 365 | ] 366 | 367 | prompts_vd_display = " ".join( 368 | [ 369 | f"[{d.name}]:[{prompt}]" 370 | for prompt, d in zip(self.src_prompts_vd, self.directions) 371 | ] 372 | ) 373 | threestudio.info(f"Using source view-dependent prompts {prompts_vd_display}") 374 | 375 | prompts_vd_display = " ".join( 376 | [ 377 | f"[{d.name}]:[{prompt}]" 378 | for prompt, d in zip(self.tgt_prompts_vd, self.directions) 379 | ] 380 | ) 381 | threestudio.info(f"Using target view-dependent prompts {prompts_vd_display}") 382 | 383 | self.negative_prompts_vd = [ 384 | d.negative_prompt(self.negative_prompt) for d in self.directions 385 | ] 386 | 387 | self.prepare_text_embeddings() 388 | self.load_text_embeddings() 389 | 390 | def spawn_func(self, pretrained_model_name_or_path, prompts, cache_dir): 391 | raise NotImplementedError 392 | 393 | @rank_zero_only 394 | def prepare_text_embeddings(self): 395 | os.makedirs(self._cache_dir, exist_ok=True) 396 | 397 | all_prompts = ( 398 | [self.src_prompt, self.tgt_prompt, self.prompt] 399 | + [self.negative_prompt] 400 | + self.src_prompts_vd 401 | + self.tgt_prompts_vd 402 | + self.base_prompts_vd 403 | + self.negative_prompts_vd 404 | ) 405 | prompts_to_process = [] 406 | for prompt in all_prompts: 407 | if self.cfg.use_cache and '' not 
in prompt: 408 | # some text embeddings are already in cache 409 | # do not process them 410 | cache_path = os.path.join( 411 | self._cache_dir, 412 | f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt", 413 | ) 414 | if os.path.exists(cache_path): 415 | threestudio.debug( 416 | f"Text embeddings for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] are already in cache, skip processing." 417 | ) 418 | continue 419 | prompts_to_process.append(prompt) 420 | 421 | if len(prompts_to_process) > 0: 422 | if False: # deprecated for now 423 | ctx = mp.get_context("spawn") 424 | subprocess = ctx.Process( 425 | target=self.spawn_func, 426 | args=( 427 | self.cfg.pretrained_model_name_or_path, 428 | prompts_to_process, 429 | self._cache_dir, 430 | ), 431 | ) 432 | subprocess.start() 433 | subprocess.join() 434 | assert subprocess.exitcode == 0, "prompt embedding process failed!" 435 | else: 436 | self.spawn_func( 437 | self.cfg.pretrained_model_name_or_path, 438 | prompts_to_process, 439 | self._cache_dir, 440 | ) 441 | cleanup() 442 | 443 | def load_text_embeddings(self): 444 | # synchronize, to ensure the text embeddings have been computed and saved to cache 445 | barrier() 446 | self.base_text_embeddings = self.load_from_cache(self.prompt)[None, ...] 447 | self.src_text_embeddings = self.load_from_cache(self.src_prompt)[None, ...] 448 | self.tgt_text_embeddings = self.load_from_cache(self.tgt_prompt)[None, ...] 449 | self.uncond_text_embeddings = self.load_from_cache(self.negative_prompt)[ 450 | None, ... 451 | ] 452 | self.base_text_embeddings_vd = torch.stack( 453 | [self.load_from_cache(prompt) for prompt in self.base_prompts_vd], dim=0 454 | ) 455 | self.src_text_embeddings_vd = torch.stack( 456 | [self.load_from_cache(prompt) for prompt in self.src_prompts_vd], dim=0 457 | ) 458 | self.tgt_text_embeddings_vd = torch.stack( 459 | [self.load_from_cache(prompt) for prompt in self.tgt_prompts_vd], dim=0 460 | ) 461 | self.uncond_text_embeddings_vd = torch.stack( 462 | [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd], dim=0 463 | ) 464 | threestudio.debug(f"Loaded text embeddings.") 465 | 466 | def load_from_cache(self, prompt): 467 | cache_path = os.path.join( 468 | self._cache_dir, 469 | f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt", 470 | ) 471 | if not os.path.exists(cache_path): 472 | raise FileNotFoundError( 473 | f"Text embedding file {cache_path} for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] not found." 
474 | ) 475 | return torch.load(cache_path, map_location=self.device) 476 | 477 | def preprocess_prompt(self, prompt: str) -> str: 478 | if prompt.startswith("lib:"): 479 | # find matches in the library 480 | candidate = None 481 | keywords = prompt[4:].lower().split("_") 482 | for prompt in self.prompt_library["dreamfusion"]: 483 | if all([k in prompt.lower() for k in keywords]): 484 | if candidate is not None: 485 | raise ValueError( 486 | f"Multiple prompts matched with keywords {keywords} in library" 487 | ) 488 | candidate = prompt 489 | if candidate is None: 490 | raise ValueError( 491 | f"Cannot find prompt with keywords {keywords} in library" 492 | ) 493 | threestudio.info("Find matched prompt in library: " + candidate) 494 | return candidate 495 | else: 496 | return prompt 497 | 498 | def get_text_embeddings( 499 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 500 | ) -> Tuple[Float[Tensor, "B ..."], Float[Tensor, "B ..."]]: 501 | raise NotImplementedError 502 | 503 | def get_debiased_prompt(self, prompt: str) -> List[str]: 504 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 505 | 506 | tokenizer = AutoTokenizer.from_pretrained( 507 | self.cfg.pretrained_model_name_or_path_prompt_debiasing 508 | ) 509 | model = BertForMaskedLM.from_pretrained( 510 | self.cfg.pretrained_model_name_or_path_prompt_debiasing 511 | ) 512 | 513 | views = [d.name for d in self.directions] 514 | view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0] 515 | view_ids = view_ids[1:5] 516 | 517 | def modulate(prompt): 518 | prompt_vd = f"This image is depicting a [MASK] view of {prompt}" 519 | tokens = tokenizer( 520 | prompt_vd, 521 | padding="max_length", 522 | truncation=True, 523 | add_special_tokens=True, 524 | return_tensors="pt", 525 | ) 526 | mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1] 527 | 528 | logits = model(**tokens).logits 529 | logits = F.softmax(logits[0, mask_idx], dim=-1) 530 | logits = logits[0, view_ids] 531 | probes = logits / logits.sum() 532 | return probes 533 | 534 | prompts = [prompt.split(" ") for _ in range(4)] 535 | full_probe = modulate(prompt) 536 | n_words = len(prompt.split(" ")) 537 | prompt_debiasing_mask_ids = ( 538 | self.cfg.prompt_debiasing_mask_ids 539 | if self.cfg.prompt_debiasing_mask_ids is not None 540 | else list(range(n_words)) 541 | ) 542 | words_to_debias = [prompt.split(" ")[idx] for idx in prompt_debiasing_mask_ids] 543 | threestudio.info(f"Words that can potentially be removed: {words_to_debias}") 544 | for idx in prompt_debiasing_mask_ids: 545 | words = prompt.split(" ") 546 | prompt_ = " ".join(words[:idx] + words[(idx + 1) :]) 547 | part_probe = modulate(prompt_) 548 | 549 | pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5) 550 | for i in range(pmi.shape[0]): 551 | if pmi[i].item() < 0.95: 552 | prompts[i][idx] = "" 553 | 554 | debiased_prompts = [" ".join([word for word in p if word]) for p in prompts] 555 | for d, debiased_prompt in zip(views, debiased_prompts): 556 | threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]") 557 | 558 | del tokenizer, model 559 | cleanup() 560 | 561 | return debiased_prompts 562 | 563 | def __call__(self) -> PromptProcessorOutput: 564 | return PromptProcessorOutput( 565 | base_text_embeddings=self.base_text_embeddings, 566 | src_text_embeddings=self.src_text_embeddings, 567 | tgt_text_embeddings=self.tgt_text_embeddings, 568 | uncond_text_embeddings=self.uncond_text_embeddings, 569 | prompt=self.prompt, 570 | 
base_text_embeddings_vd=self.base_text_embeddings_vd, 571 | src_text_embeddings_vd=self.src_text_embeddings_vd, 572 | tgt_text_embeddings_vd=self.tgt_text_embeddings_vd, 573 | uncond_text_embeddings_vd=self.uncond_text_embeddings_vd, 574 | base_prompts_vd=self.base_prompts_vd, 575 | src_prompts_vd=self.src_prompts_vd, 576 | tgt_prompts_vd=self.tgt_prompts_vd, 577 | directions=self.directions, 578 | direction2idx=self.direction2idx, 579 | use_perp_neg=self.cfg.use_perp_neg, 580 | perp_neg_f_sb=self.cfg.perp_neg_f_sb, 581 | perp_neg_f_fsb=self.cfg.perp_neg_f_fsb, 582 | perp_neg_f_fs=self.cfg.perp_neg_f_fs, 583 | perp_neg_f_sf=self.cfg.perp_neg_f_sf, 584 | ) 585 | 586 | 587 | def load_textual_inversion( 588 | self, 589 | pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], 590 | text_encoder = None, 591 | tokenizer = None, 592 | token: Optional[Union[str, List[str]]] = None, 593 | **kwargs, 594 | ): 595 | 596 | cache_dir = kwargs.pop("cache_dir", None) 597 | force_download = kwargs.pop("force_download", False) 598 | resume_download = kwargs.pop("resume_download", False) 599 | proxies = kwargs.pop("proxies", None) 600 | local_files_only = kwargs.pop("local_files_only", None) 601 | use_auth_token = kwargs.pop("use_auth_token", None) 602 | revision = kwargs.pop("revision", None) 603 | subfolder = kwargs.pop("subfolder", None) 604 | weight_name = kwargs.pop("weight_name", None) 605 | use_safetensors = kwargs.pop("use_safetensors", None) 606 | 607 | user_agent = { 608 | "file_type": "text_inversion", 609 | "framework": "pytorch", 610 | } 611 | 612 | if not isinstance(pretrained_model_name_or_path, list): 613 | pretrained_model_name_or_paths = [pretrained_model_name_or_path] 614 | else: 615 | pretrained_model_name_or_paths = pretrained_model_name_or_path 616 | 617 | if isinstance(token, str): 618 | tokens = [token] 619 | elif token is None: 620 | tokens = [None] * len(pretrained_model_name_or_paths) 621 | else: 622 | tokens = token 623 | 624 | if len(pretrained_model_name_or_paths) != len(tokens): 625 | raise ValueError( 626 | f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}" 627 | f"Make sure both lists have the same length." 628 | ) 629 | 630 | valid_tokens = [t for t in tokens if t is not None] 631 | if len(set(valid_tokens)) < len(valid_tokens): 632 | raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}") 633 | 634 | token_ids_and_embeddings = [] 635 | 636 | for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): 637 | if not isinstance(pretrained_model_name_or_path, dict): 638 | # 1. Load textual inversion file 639 | model_file = None 640 | if model_file is None: 641 | model_file = _get_model_file( 642 | pretrained_model_name_or_path, 643 | weights_name=weight_name, 644 | cache_dir=cache_dir, 645 | force_download=force_download, 646 | resume_download=resume_download, 647 | proxies=proxies, 648 | local_files_only=local_files_only, 649 | use_auth_token=use_auth_token, 650 | revision=revision, 651 | subfolder=subfolder, 652 | user_agent=user_agent, 653 | ) 654 | try: 655 | state_dict = safetensors.torch.load_file(model_file, device="cpu") 656 | except: 657 | state_dict = torch.load(model_file, map_location="cpu") 658 | else: 659 | state_dict = pretrained_model_name_or_path 660 | 661 | # 2. 
Load token and embedding correctly from file 662 | loaded_token = None 663 | if isinstance(state_dict, torch.Tensor): 664 | if token is None: 665 | raise ValueError( 666 | "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." 667 | ) 668 | embedding = state_dict 669 | elif len(state_dict) == 1: 670 | # diffusers 671 | loaded_token, embedding = next(iter(state_dict.items())) 672 | elif "string_to_param" in state_dict: 673 | # A1111 674 | loaded_token = state_dict["name"] 675 | embedding = state_dict["string_to_param"]["*"] 676 | 677 | if token is not None and loaded_token != token: 678 | print(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") 679 | else: 680 | token = loaded_token 681 | 682 | embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device) 683 | 684 | # 3. Make sure we don't mess up the tokenizer or text encoder 685 | vocab = tokenizer.get_vocab() 686 | if token in vocab: 687 | raise ValueError( 688 | f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." 689 | ) 690 | elif f"{token}_1" in vocab: 691 | multi_vector_tokens = [token] 692 | i = 1 693 | while f"{token}_{i}" in tokenizer.added_tokens_encoder: 694 | multi_vector_tokens.append(f"{token}_{i}") 695 | i += 1 696 | 697 | raise ValueError( 698 | f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." 699 | ) 700 | 701 | is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 702 | 703 | if is_multi_vector: 704 | tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] 705 | embeddings = [e for e in embedding] # noqa: C416 706 | else: 707 | tokens = [token] 708 | embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] 709 | 710 | # add tokens and get ids 711 | tokenizer.add_tokens(tokens) 712 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 713 | token_ids_and_embeddings += zip(token_ids, embeddings) 714 | 715 | # resize token embeddings and set all new embeddings 716 | text_encoder.resize_token_embeddings(len(tokenizer)) 717 | for token_id, embedding in token_ids_and_embeddings: 718 | text_encoder.get_input_embeddings().weight.data[token_id] = embedding 719 | 720 | return text_encoder, tokenizer 721 | 722 | 723 | 724 | @threestudio.register("stable-diffusion-sds-bridge-prompt-processor") 725 | class SDSBridgePromptProcessor(PromptProcessor): 726 | @dataclass 727 | class Config(PromptProcessor.Config): 728 | pass 729 | 730 | cfg: Config 731 | 732 | ### these functions are unused, kept for debugging ### 733 | def configure_text_encoder(self) -> None: 734 | self.tokenizer = AutoTokenizer.from_pretrained( 735 | self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" 736 | ) 737 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 738 | self.text_encoder = CLIPTextModel.from_pretrained( 739 | self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" 740 | ).to(self.device) 741 | 742 | for p in self.text_encoder.parameters(): 743 | p.requires_grad_(False) 744 | 745 | def destroy_text_encoder(self) -> None: 746 | del self.tokenizer 747 | del self.text_encoder 748 | cleanup() 749 | 750 | def get_text_embeddings( 751 | self, prompt_src: Union[str, List[str]], 
prompt_tgt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 752 | ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: 753 | if isinstance(prompt_src, str): 754 | prompt_src = [prompt_src] 755 | if isinstance(negative_prompt, str): 756 | negative_prompt = [negative_prompt] 757 | # Tokenize text and get embeddings 758 | tokens_src = self.tokenizer( 759 | prompt_src, 760 | padding="max_length", 761 | max_length=self.tokenizer.model_max_length, 762 | return_tensors="pt", 763 | ) 764 | tokens_tgt = self.tokenizer( 765 | prompt_tgt, 766 | padding="max_length", 767 | max_length=self.tokenizer.model_max_length, 768 | return_tensors="pt", 769 | ) 770 | uncond_tokens = self.tokenizer( 771 | negative_prompt, 772 | padding="max_length", 773 | max_length=self.tokenizer.model_max_length, 774 | return_tensors="pt", 775 | ) 776 | 777 | with torch.no_grad(): 778 | text_embeddings_src = self.text_encoder(tokens_src.input_ids.to(self.device))[0] 779 | text_embeddings_tgt = self.text_encoder(tokens_tgt.input_ids.to(self.device))[0] 780 | uncond_text_embeddings = self.text_encoder( 781 | uncond_tokens.input_ids.to(self.device) 782 | )[0] 783 | 784 | return text_embeddings_src, text_embeddings_tgt, uncond_text_embeddings 785 | 786 | ### 787 | 788 | def spawn_func(self, pretrained_model_name_or_path, prompts, cache_dir): 789 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 790 | tokenizer = AutoTokenizer.from_pretrained( 791 | pretrained_model_name_or_path, subfolder="tokenizer" 792 | ) 793 | text_encoder = CLIPTextModel.from_pretrained( 794 | pretrained_model_name_or_path, 795 | subfolder="text_encoder", 796 | device_map="auto", 797 | ) 798 | if len(self.cfg.texture_inversion_embedding) > 0: 799 | text_encoder, tokenizer = self.load_textual_inversion(self.cfg.texture_inversion_embedding, text_encoder, tokenizer) 800 | 801 | with torch.no_grad(): 802 | tokens = tokenizer( 803 | prompts, 804 | padding="max_length", 805 | max_length=tokenizer.model_max_length, 806 | return_tensors="pt", 807 | ) 808 | text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] 809 | 810 | for prompt, embedding in zip(prompts, text_embeddings): 811 | torch.save( 812 | embedding, 813 | os.path.join( 814 | cache_dir, 815 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 816 | ), 817 | ) 818 | 819 | del text_encoder 820 | --------------------------------------------------------------------------------
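For orientation, the snippet below is a minimal, self-contained sketch (not one of the repository files above) of how the two training phases in compute_grad_sds_bridge combine UNet noise predictions into an SDS-style gradient when weighting_strategy is "uniform". The tensor names (eps_src, eps_tgt, eps_uncond, noise, w) are illustrative placeholders for the corresponding quantities in the guidance module.

import torch

def sds_bridge_grad(eps_src, eps_tgt, eps_uncond, noise, w,
                    phase_id=1, guidance_scale=100.0,
                    stage_one_weight=1.0, stage_two_weight=1.0):
    if phase_id == 1:
        # phase 1: classifier-free-guidance-style prediction against the
        # unconditional branch, weighted by w = 1 - alpha_bar_t and compared
        # against the noise that was actually injected
        eps_cfg = eps_uncond + guidance_scale * (eps_tgt - eps_uncond)
        return stage_one_weight * w * (eps_cfg - noise)
    # phase 2: the bridge residual between target- and source-conditioned
    # predictions replaces (eps - noise)
    return stage_two_weight * (eps_tgt - eps_src)

# example with dummy latent-shaped tensors
eps_src, eps_tgt, eps_uncond, noise = (torch.randn(1, 4, 64, 64) for _ in range(4))
grad = sds_bridge_grad(eps_src, eps_tgt, eps_uncond, noise, w=0.5)

In the guidance module this gradient is then turned into loss_sds through the (latents - grad).detach() reparameterization in __call__, so backpropagating the MSE reproduces grad on the latents.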