├── util ├── __init__.py ├── detection │ ├── __init__.py │ ├── p_head_v1.npz │ ├── w_head_v1.npz │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── nsfw_and_watermark_dectection.cpython-39.pyc │ │ └── nsfw_and_watermark_dectection.cpython-310.pyc │ └── nsfw_and_watermark_dectection.py └── __pycache__ │ └── __init__.cpython-310.pyc ├── sgm ├── modules │ ├── encoders │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── modules.cpython-310.pyc │ │ │ └── __init__.cpython-310.pyc │ ├── autoencoding │ │ ├── __init__.py │ │ ├── lpips │ │ │ ├── __init__.py │ │ │ ├── loss │ │ │ │ ├── __init__.py │ │ │ │ ├── LICENSE │ │ │ │ └── lpips.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ └── LICENSE │ │ │ ├── vqperceptual.py │ │ │ └── util.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── temporal_ae.cpython-310.pyc │ │ ├── regularizers │ │ │ ├── __pycache__ │ │ │ │ ├── base.cpython-310.pyc │ │ │ │ └── __init__.cpython-310.pyc │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── quantize.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ └── discriminator_loss.py │ │ └── temporal_ae.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── distributions.cpython-310.pyc │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── util.cpython-310.pyc │ │ │ ├── model.cpython-310.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── denoiser.cpython-310.pyc │ │ │ ├── guiders.cpython-310.pyc │ │ │ ├── sampling.cpython-310.pyc │ │ │ ├── wrappers.cpython-310.pyc │ │ │ ├── discretizer.cpython-310.pyc │ │ │ ├── openaimodel.cpython-310.pyc │ │ │ ├── video_model.cpython-310.pyc │ │ │ ├── sampling_utils.cpython-310.pyc │ │ │ └── denoiser_scaling.cpython-310.pyc │ │ ├── denoiser_weighting.py │ │ ├── loss_weighting.py │ │ ├── sigma_sampling.py │ │ ├── wrappers.py │ │ ├── sampling_utils.py │ │ ├── denoiser_scaling.py │ │ ├── discretizer.py │ │ ├── denoiser.py │ │ ├── loss.py │ │ ├── guiders.py │ │ ├── sampling.py │ │ └── util.py │ ├── __pycache__ │ │ ├── ema.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── attention.cpython-310.pyc │ │ └── video_attention.cpython-310.pyc │ ├── __init__.py │ ├── ema.py │ └── video_attention.py ├── data │ ├── __init__.py │ ├── cifar10.py │ ├── mnist.py │ └── dataset.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── diffusion.cpython-310.pyc │ │ └── autoencoder.cpython-310.pyc │ └── diffusion.py ├── __pycache__ │ ├── util.cpython-310.pyc │ └── __init__.cpython-310.pyc ├── inference │ ├── __pycache__ │ │ └── helpers.cpython-310.pyc │ └── helpers.py ├── __init__.py ├── lr_scheduler.py ├── svd.yaml └── util.py ├── assets ├── images │ ├── cat.jpg │ ├── rocket.png │ ├── street.jpg │ └── waterfall.jpg ├── trajectory │ └── complex_4.pth └── outputs │ ├── 000001_tilt_30_14_up_i_10_seed_1.gif │ ├── 000002_zoom_1_14_in_i_10_seed_1.gif │ ├── 000003_rotate_30_14_clockwise_i_10_seed_1.gif │ └── 000004_hybrid_30_14_anticlockwise_i_12_seed_1.gif ├── requirement.txt ├── README.md └── sampling.py /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/detection/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import StableDataModuleFromConfig 2 | -------------------------------------------------------------------------------- /assets/images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/images/cat.jpg -------------------------------------------------------------------------------- /assets/images/rocket.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/images/rocket.png -------------------------------------------------------------------------------- /assets/images/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/images/street.jpg -------------------------------------------------------------------------------- /assets/images/waterfall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/images/waterfall.jpg -------------------------------------------------------------------------------- /util/detection/p_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/p_head_v1.npz -------------------------------------------------------------------------------- /util/detection/w_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/w_head_v1.npz -------------------------------------------------------------------------------- /assets/trajectory/complex_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/trajectory/complex_4.pth 
-------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/models/__pycache__/diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/models/__pycache__/diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /assets/outputs/000001_tilt_30_14_up_i_10_seed_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/outputs/000001_tilt_30_14_up_i_10_seed_1.gif -------------------------------------------------------------------------------- /assets/outputs/000002_zoom_1_14_in_i_10_seed_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/outputs/000002_zoom_1_14_in_i_10_seed_1.gif -------------------------------------------------------------------------------- /sgm/inference/__pycache__/helpers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/inference/__pycache__/helpers.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/models/__pycache__/autoencoder.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/models/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /util/detection/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /util/detection/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/video_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/__pycache__/video_attention.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/encoders/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /assets/outputs/000003_rotate_30_14_clockwise_i_10_seed_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/outputs/000003_rotate_30_14_clockwise_i_10_seed_1.gif -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc 
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /assets/outputs/000004_hybrid_30_14_anticlockwise_i_12_seed_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/assets/outputs/000004_hybrid_30_14_anticlockwise_i_12_seed_1.gif -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc 
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/video_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/video_model.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc -------------------------------------------------------------------------------- /util/detection/__pycache__/nsfw_and_watermark_dectection.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/__pycache__/nsfw_and_watermark_dectection.cpython-39.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc -------------------------------------------------------------------------------- /util/detection/__pycache__/nsfw_and_watermark_dectection.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/util/detection/__pycache__/nsfw_and_watermark_dectection.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAARRRY/CamTrol/HEAD/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": 
"sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | black==23.7.0 2 | chardet==5.1.0 3 | clip @ git+https://github.com/openai/CLIP.git 4 | einops>=0.6.1 5 | fairscale>=0.4.13 6 | fire>=0.5.0 7 | fsspec>=2023.6.0 8 | invisible-watermark>=0.2.0 9 | kornia==0.6.9 10 | matplotlib>=3.7.2 11 | natsort>=8.4.0 12 | 
ninja>=1.11.1 13 | numpy>=1.24.4 14 | omegaconf>=2.3.0 15 | open-clip-torch>=2.20.0 16 | opencv-python==4.6.0.66 17 | pandas>=2.0.3 18 | pillow>=9.5.0 19 | pudb>=2022.1.3 20 | pytorch-lightning==2.0.1 21 | pyyaml>=6.0.1 22 | rembg 23 | scipy>=1.10.1 24 | streamlit>=0.73.1 25 | tensorboardx==2.6 26 | tokenizers==0.12.1 27 | torch>=2.0.1 28 | torchaudio>=2.0.2 29 | torchdata==0.6.1 30 | torchmetrics>=1.0.1 31 | torchvision>=0.15.2 32 | tqdm>=4.65.0 33 | # transformers==4.19.1 34 | transformers 35 | triton==2.0.0 36 | urllib3<1.27,>=1.25.4 37 | wandb>=0.15.6 38 | webdataset>=0.2.33 39 | wheel>=0.41.0 40 | xformers == 0.0.22 41 | gradio 42 | streamlit-keyup==0.2.0 43 | imageio-ffmpeg 44 | pyav 45 | accelerate 46 | diffusers 47 | nltk 48 | decord 49 | ffmpeg-python 50 | timm==0.6.7 -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 
| if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | context=c.get("crossattn", None), 32 | y=c.get("vector", None), 33 | **kwargs, 34 | ) 35 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 8 | if order - 1 > i: 9 | raise ValueError(f"Order {order} too high for step {i}") 10 | 11 | def fn(tau): 12 | prod = 1.0 13 | for k in range(order): 14 | if j == k: 15 | continue 16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 17 | return prod 18 | 19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 20 | 21 | 22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 23 | if not eta: 24 | return sigma_to, 0.0 25 | sigma_up = torch.minimum( 26 | sigma_to, 27 | eta 28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 29 | ) 30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 31 | return sigma_down, sigma_up 32 | 33 | 34 | def to_d(x, sigma, denoised): 35 | return (x - denoised) / append_dims(sigma, x.ndim) 36 | 37 | 38 | def to_neg_log_sigma(sigma): 39 | return sigma.log().neg() 40 | 41 | 42 | def to_sigma(neg_log_sigma): 43 | return neg_log_sigma.neg().exp() 44 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__( 10 | self, sigma: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 12 | pass 13 | 14 | 15 | class EDMScaling: 16 | def __init__(self, sigma_data: float = 0.5): 17 | self.sigma_data = sigma_data 18 | 19 | def __call__( 20 | self, sigma: torch.Tensor 21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 22 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 25 | c_noise = 0.25 * sigma.log() 26 | return c_skip, c_out, c_in, c_noise 27 | 28 | 29 | class EpsScaling: 30 | def __call__( 31 | self, sigma: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | c_skip = torch.ones_like(sigma, device=sigma.device) 34 | c_out = -sigma 35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 36 | c_noise = sigma.clone() 37 | return c_skip, c_out, c_in, c_noise 38 | 39 | 40 | class VScaling: 41 | def __call__( 42 | self, sigma: torch.Tensor 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 44 | c_skip = 1.0 / (sigma**2 + 1.0) 45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 46 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 47 | c_noise = sigma.clone() 48 | return c_skip, c_out, c_in, c_noise 49 | 50 | 51 | class VScalingWithEDMcNoise(DenoiserScaling): 52 | def __call__( 53 | self, sigma: torch.Tensor 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | c_skip = 1.0 / (sigma**2 + 1.0) 56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 58 | c_noise = 0.25 * sigma.log() 59 | return c_skip, c_out, c_in, c_noise 60 | -------------------------------------------------------------------------------- /sgm/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class CIFAR10DataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class CIFAR10Loader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.shuffle = shuffle 31 | self.train_dataset = CIFAR10DataDictWrapper( 32 | torchvision.datasets.CIFAR10( 33 | root=".data/", train=True, download=True, transform=transform 34 | ) 35 | ) 36 | self.test_dataset = CIFAR10DataDictWrapper( 37 | torchvision.datasets.CIFAR10( 38 | root=".data/", train=False, download=True, transform=transform 39 | ) 40 | ) 41 | 42 | def prepare_data(self): 43 | pass 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.train_dataset, 48 | 
batch_size=self.batch_size, 49 | shuffle=self.shuffle, 50 | num_workers=self.num_workers, 51 | ) 52 | 53 | def test_dataloader(self): 54 | return DataLoader( 55 | self.test_dataset, 56 | batch_size=self.batch_size, 57 | shuffle=self.shuffle, 58 | num_workers=self.num_workers, 59 | ) 60 | 61 | def val_dataloader(self): 62 | return DataLoader( 63 | self.test_dataset, 64 | batch_size=self.batch_size, 65 | shuffle=self.shuffle, 66 | num_workers=self.num_workers, 67 | ) 68 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | from .denoiser_scaling import DenoiserScaling 8 | from .discretizer import Discretization 9 | 10 | 11 | class Denoiser(nn.Module): 12 | def __init__(self, scaling_config: Dict): 13 | super().__init__() 14 | 
15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) 16 | 17 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 21 | return c_noise 22 | 23 | def forward( 24 | self, 25 | network: nn.Module, 26 | input: torch.Tensor, 27 | sigma: torch.Tensor, 28 | cond: Dict, 29 | **additional_model_inputs, 30 | ) -> torch.Tensor: 31 | sigma = self.possibly_quantize_sigma(sigma) 32 | sigma_shape = sigma.shape 33 | sigma = append_dims(sigma, input.ndim) 34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 36 | return ( 37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 38 | + input * c_skip 39 | ) 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | scaling_config: Dict, 46 | num_idx: int, 47 | discretization_config: Dict, 48 | do_append_zero: bool = False, 49 | quantize_c_noise: bool = True, 50 | flip: bool = True, 51 | ): 52 | super().__init__(scaling_config) 53 | self.discretization: Discretization = instantiate_from_config( 54 | discretization_config 55 | ) 56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) 57 | self.register_buffer("sigmas", sigmas) 58 | self.quantize_c_noise = quantize_c_noise 59 | self.num_idx = num_idx 60 | 61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 62 | dists = sigma - self.sigmas[:, None] 63 | return dists.abs().argmin(dim=0).view(sigma.shape) 64 | 65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: 66 | return self.sigmas[idx] 67 | 68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 69 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 70 | 71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 72 | if self.quantize_c_noise: 73 | return self.sigma_to_idx(c_noise) 74 | else: 75 | return c_noise 76 | -------------------------------------------------------------------------------- /sgm/data/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class MNISTDataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class MNISTLoader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 31 | self.shuffle = shuffle 32 | self.train_dataset = MNISTDataDictWrapper( 33 | torchvision.datasets.MNIST( 34 | root=".data/", train=True, download=True, transform=transform 35 | ) 36 | ) 37 | self.test_dataset = MNISTDataDictWrapper( 38 | torchvision.datasets.MNIST( 39 | root=".data/", train=False, download=True, transform=transform 40 | ) 41 | ) 42 | 43 | def prepare_data(self): 44 | pass 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 
48 | self.train_dataset, 49 | batch_size=self.batch_size, 50 | shuffle=self.shuffle, 51 | num_workers=self.num_workers, 52 | prefetch_factor=self.prefetch_factor, 53 | ) 54 | 55 | def test_dataloader(self): 56 | return DataLoader( 57 | self.test_dataset, 58 | batch_size=self.batch_size, 59 | shuffle=self.shuffle, 60 | num_workers=self.num_workers, 61 | prefetch_factor=self.prefetch_factor, 62 | ) 63 | 64 | def val_dataloader(self): 65 | return DataLoader( 66 | self.test_dataset, 67 | batch_size=self.batch_size, 68 | shuffle=self.shuffle, 69 | num_workers=self.num_workers, 70 | prefetch_factor=self.prefetch_factor, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | dset = MNISTDataDictWrapper( 76 | torchvision.datasets.MNIST( 77 | root=".data/", 78 | train=False, 79 | download=True, 80 | transform=transforms.Compose( 81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 82 | ), 83 | ) 84 | ) 85 | ex = dset[0] 86 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss( 41 | image_targets.contiguous(), image_reconstructions.contiguous() 42 | ) 43 | loss = ( 44 | self.latent_weight * loss.mean() 45 | + self.perceptual_weight * perceptual_loss.mean() 46 | ) 47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 48 | 49 | if self.perceptual_weight_on_inputs > 0.0: 50 | image_reconstructions = default( 51 | image_reconstructions, self.decoder.decode(latent_predictions) 52 | ) 53 | if self.scale_input_to_tgt_size: 54 | image_inputs = torch.nn.functional.interpolate( 55 | image_inputs, 56 | image_reconstructions.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | elif self.scale_tgt_to_input_size: 61 | image_reconstructions = torch.nn.functional.interpolate( 62 | image_reconstructions, 63 | image_inputs.shape[2:], 64 | mode="bicubic", 65 | antialias=True, 66 | ) 67 | 68 | perceptual_loss2 = self.perceptual_loss( 69 | image_inputs.contiguous(), image_reconstructions.contiguous() 70 | ) 71 | loss = loss + 
self.perceptual_weight_on_inputs * perceptual_loss2.mean() 72 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() 73 | return loss, log 74 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchdata.datapipes.iter 4 | import webdataset as wds 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule 7 | 8 | try: 9 | from sdata import create_dataset, create_dummy_dataset, create_loader 10 | except ImportError as e: 11 | print("#" * 100) 12 | print("Datasets not yet available") 13 | print("to enable, we need to add stable-datasets as a submodule") 14 | print("please use ``git submodule update --init 
--recursive``") 15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo") 16 | print("#" * 100) 17 | exit(1) 18 | 19 | 20 | class StableDataModuleFromConfig(LightningDataModule): 21 | def __init__( 22 | self, 23 | train: DictConfig, 24 | validation: Optional[DictConfig] = None, 25 | test: Optional[DictConfig] = None, 26 | skip_val_loader: bool = False, 27 | dummy: bool = False, 28 | ): 29 | super().__init__() 30 | self.train_config = train 31 | assert ( 32 | "datapipeline" in self.train_config and "loader" in self.train_config 33 | ), "train config requires the fields `datapipeline` and `loader`" 34 | 35 | self.val_config = validation 36 | if not skip_val_loader: 37 | if self.val_config is not None: 38 | assert ( 39 | "datapipeline" in self.val_config and "loader" in self.val_config 40 | ), "validation config requires the fields `datapipeline` and `loader`" 41 | else: 42 | print( 43 | "Warning: No Validation datapipeline defined, using that one from training" 44 | ) 45 | self.val_config = train 46 | 47 | self.test_config = test 48 | if self.test_config is not None: 49 | assert ( 50 | "datapipeline" in self.test_config and "loader" in self.test_config 51 | ), "test config requires the fields `datapipeline` and `loader`" 52 | 53 | self.dummy = dummy 54 | if self.dummy: 55 | print("#" * 100) 56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") 57 | print("#" * 100) 58 | 59 | def setup(self, stage: str) -> None: 60 | print("Preparing datasets") 61 | if self.dummy: 62 | data_fn = create_dummy_dataset 63 | else: 64 | data_fn = create_dataset 65 | 66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline) 67 | if self.val_config: 68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline) 69 | if self.test_config: 70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline) 71 | 72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: 73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader) 74 | return loader 75 | 76 | def val_dataloader(self) -> wds.DataPipeline: 77 | return create_loader(self.val_datapipeline, **self.val_config.loader) 78 | 79 | def test_dataloader(self) -> wds.DataPipeline: 80 | return create_loader(self.test_datapipeline, **self.test_config.loader) 81 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if 
self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 
2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CamTrol: Training-free Camera Control for Video Generation 2 | 3 | --- 4 | ### ‼️CogVideoX version/Any Video Model version CamTrol‼️ 5 | We now have CamTrol implemented on **[diffusers-based](https://github.com/huggingface/diffusers/tree/main)** video models, which makes it easier to adapt the code to the more powerful video models in diffusers. 6 | 7 | Some results of CogVideoX+CamTrol can be found on the CamTrol [page](https://lifedecoder.github.io/CamTrol/). 8 | 9 | The code: [https://github.com/LAARRRY/CamTrol-CogVideoX-Diffusers](https://github.com/LAARRRY/CamTrol-CogVideoX-Diffusers). 10 | 11 | --- 12 | This repository is an unofficial implementation of [CamTrol: Training-free Camera Control for Video Generation](https://lifedecoder.github.io/CamTrol/), based on SVD. 13 | 14 | Some videos generated with SVD: 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | ## Setup 25 | 26 | 1. `pip install -r requirement.txt` 27 | 28 | 2. Download the SVD checkpoint [svd.safetensors](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/tree/main) and set its path as `ckpt_path` in `sgm/svd.yaml`. 29 | 30 | 3. Clone the depth estimation model: `git clone https://github.com/isl-org/ZoeDepth.git` 31 | 32 | 33 | The code downloads [stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) and [open-clip](https://github.com/mlfoundations/open_clip) automatically; you can point these to local paths if they are already downloaded. 34 | 35 | ## Sampling 36 | ``` 37 | CUDA_VISIBLE_DEVICES=0 python3 sampling.py \ 38 | --input_path "assets/images/street.jpg" \ 39 | --prompt "a vivid anime street, wind blows." \ 40 | --neg_prompt " " \ 41 | --pcd_mode "hybrid default 14 out_left_up_down" \ 42 | --add_index 12 \ 43 | --seed 1 \ 44 | --save_warps False \ 45 | --load_warps None 46 | ``` 47 | 48 | - `pcd_mode`: camera motion for point cloud rendering, a string concatenated from four elements: the first defines the camera motion type, the second the moving distance or angle, the third the number of frames, and the last the moving direction (e.g. `hybrid default 14 out_left_up_down` in the command above). You can load arbitrary camera extrinsic matrices in `complex` mode, and set a bigger `add_index` for better motion alignment. 49 | - `prompt`, `neg_prompt`: since SVD doesn't support text input, these mainly serve the stable diffusion inpainting step. 50 | - `add_index`: t_0 in the paper, balancing the trade-off between motion fidelity and video diversity. Set it between `0` and `num_frames`; the larger the value, the more faithfully the video follows the camera motion. 51 | - `save_warps`: whether to save the multi-view renderings; you can reload already-rendered images later, as this process might take some time. Use low-res images to boost speed. 52 | - `load_warps`: whether to load the renderings saved via `save_warps`. 53 | 54 | 55 | ## Backbones 56 | This repository uses SVD as the backbone. You can apply the method to your own customized video diffusion model. 57 | 58 | ## Acknowledgement 59 | The code is mainly built on [SVD](https://github.com/Stability-AI/generative-models/tree/main) and [LucidDreamer](https://github.com/luciddreamer-cvlab/LucidDreamer). 60 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | 11 | 12 | class StandardDiffusionLoss(nn.Module): 13 | def __init__( 14 | self, 15 | sigma_sampler_config: dict, 16 | loss_weighting_config: dict, 17 | loss_type: str = "l2", 18 | offset_noise_level: float = 0.0, 19 | batch2model_keys: Optional[Union[str, List[str]]] = None, 20 | ): 21 | super().__init__() 22 | 23 | assert loss_type in ["l2", "l1", "lpips"] 24 | 25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 26 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 27 | 28 | self.loss_type = loss_type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if loss_type == "lpips": 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def get_noised_input( 43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 44 | ) -> torch.Tensor: 45 | noised_input = input + noise * sigmas_bc 46 | return noised_input 47 | 48 | def forward( 49 | self, 50 | network: nn.Module, 51 | denoiser: Denoiser, 52 | conditioner: GeneralConditioner, 53 | input: torch.Tensor, 54 | batch: Dict, 55 | ) -> torch.Tensor: 56 | cond = conditioner(batch) 57 | return self._forward(network, denoiser, cond, input, batch) 58 | 59 | def _forward( 60 | self, 61 | network: nn.Module, 62 | denoiser: Denoiser, 63 | cond: Dict, 64 | input: torch.Tensor, 65 | batch: Dict, 66 | ) -> Tuple[torch.Tensor, Dict]: 67 | additional_model_inputs = { 68 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 69 | } 70 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 71 | 72 | noise = torch.randn_like(input) 73 | if self.offset_noise_level > 0.0: 74 | offset_shape = ( 75 | (input.shape[0], 1, input.shape[2]) 76 | if self.n_frames is not None 77 | else (input.shape[0], input.shape[1]) 78 | ) 79 | noise = noise + self.offset_noise_level * append_dims( 80 | torch.randn(offset_shape, device=input.device), 81 | input.ndim, 82 | ) 83 | sigmas_bc = append_dims(sigmas, input.ndim) 84 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 85 | 86 | model_output = denoiser( 87 | network, noised_input, sigmas, cond, **additional_model_inputs 88 | ) 89 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 90 | return self.get_loss(model_output, input, w) 91 | 92 | def get_loss(self, model_output, target, w): 93 | if self.loss_type == "l2": 94 | return torch.mean( 95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 96 | ) 97 | elif self.loss_type == "l1": 98 | return torch.mean( 99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 100 | ) 101 | elif self.loss_type == "lpips": 102 | loss = self.lpips(model_output, target).reshape(-1) 103 | return loss 104 | else: 105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 106 | 
-------------------------------------------------------------------------------- /util/detection/nsfw_and_watermark_dectection.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import clip 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | 9 | RESOURCES_ROOT = "util/detection/" 10 | 11 | 12 | def predict_proba(X, weights, biases): 13 | logits = X @ weights.T + biases 14 | proba = np.where( 15 | logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)) 16 | ) 17 | return proba.T 18 | 19 | 20 | def load_model_weights(path: str): 21 | model_weights = np.load(path) 22 | return model_weights["weights"], model_weights["biases"] 23 | 24 | 25 | def clip_process_images(images: torch.Tensor) -> torch.Tensor: 26 | min_size = min(images.shape[-2:]) 27 | return T.Compose( 28 | [ 29 | T.CenterCrop(min_size), # TODO: this might affect the watermark, check this 30 | T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), 31 | T.Normalize( 32 | (0.48145466, 0.4578275, 0.40821073), 33 | (0.26862954, 0.26130258, 0.27577711), 34 | ), 35 | ] 36 | )(images) 37 | 38 | 39 | class DeepFloydDataFiltering(object): 40 | def __init__( 41 | self, verbose: bool = False, device: torch.device = torch.device("cpu") 42 | ): 43 | super().__init__() 44 | self.verbose = verbose 45 | self._device = None 46 | self.clip_model, _ = clip.load("ViT-L/14", device=device) 47 | self.clip_model.eval() 48 | 49 | self.cpu_w_weights, self.cpu_w_biases = load_model_weights( 50 | os.path.join(RESOURCES_ROOT, "w_head_v1.npz") 51 | ) 52 | self.cpu_p_weights, self.cpu_p_biases = load_model_weights( 53 | os.path.join(RESOURCES_ROOT, "p_head_v1.npz") 54 | ) 55 | self.w_threshold, self.p_threshold = 0.5, 0.5 56 | 57 | @torch.inference_mode() 58 | def __call__(self, images: torch.Tensor) -> torch.Tensor: 59 | imgs = clip_process_images(images) 60 | if self._device is None: 61 | self._device = next(p for p in self.clip_model.parameters()).device 62 | image_features = self.clip_model.encode_image(imgs.to(self._device)) 63 | image_features = image_features.detach().cpu().numpy().astype(np.float16) 64 | p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) 65 | w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) 66 | print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None 67 | query = p_pred > self.p_threshold 68 | if query.sum() > 0: 69 | print(f"Hit for p_threshold: {p_pred}") if self.verbose else None 70 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 71 | query = w_pred > self.w_threshold 72 | if query.sum() > 0: 73 | print(f"Hit for w_threshold: {w_pred}") if self.verbose else None 74 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 75 | return images 76 | 77 | 78 | def load_img(path: str) -> torch.Tensor: 79 | image = Image.open(path) 80 | if not image.mode == "RGB": 81 | image = image.convert("RGB") 82 | image_transforms = T.Compose( 83 | [ 84 | T.ToTensor(), 85 | ] 86 | ) 87 | return image_transforms(image)[None, ...] 
88 | 89 | 90 | def test(root): 91 | from einops import rearrange 92 | 93 | filter = DeepFloydDataFiltering(verbose=True) 94 | for p in os.listdir((root)): 95 | print(f"running on {p}...") 96 | img = load_img(os.path.join(root, p)) 97 | filtered_img = filter(img) 98 | filtered_img = rearrange( 99 | 255.0 * (filtered_img.numpy())[0], "c h w -> h w c" 100 | ).astype(np.uint8) 101 | Image.fromarray(filtered_img).save( 102 | os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg") 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | import fire 108 | 109 | fire.Fire(test) 110 | print("done.") 111 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 
| if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Literal, Optional, Tuple, Union 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | from ...util import append_dims, default 9 | 10 | logpy = logging.getLogger(__name__) 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs( 19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 20 | ) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG(Guider): 25 | def __init__(self, scale: float): 26 | self.scale = scale 27 | 28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 29 | x_u, x_c = x.chunk(2) 30 | x_pred = x_u + self.scale * (x_c - x_u) 31 | return x_pred 32 | 33 | def prepare_inputs(self, x, s, c, uc): 34 | c_out = dict() 35 | 36 | for k in c: 37 | if k in ["vector", "crossattn", "concat"]: 38 | c_out[k] = torch.cat((uc[k], c[k]), 0) 39 | else: 40 | assert c[k] == uc[k] 41 | c_out[k] = c[k] 42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 43 | 44 | 45 | class IdentityGuider(Guider): 46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 47 | return x 48 | 49 | def prepare_inputs( 50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 51 | ) -> Tuple[torch.Tensor, float, Dict]: 52 | c_out = dict() 53 | 54 | for k in c: 55 | c_out[k] = c[k] 56 | 57 | return x, s, c_out 58 | 59 | 60 | class LinearPredictionGuider(Guider): 61 | def __init__( 62 | self, 63 | max_scale: float, 64 | num_frames: int, 65 | min_scale: float = 1.0, 66 | additional_cond_keys: Optional[Union[List[str], str]] = None, 67 | ): 68 | self.min_scale = min_scale 69 | self.max_scale = max_scale 70 | self.num_frames = num_frames 71 | self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) 72 | 73 | additional_cond_keys = default(additional_cond_keys, []) 74 | if isinstance(additional_cond_keys, str): 75 | additional_cond_keys = [additional_cond_keys] 76 | self.additional_cond_keys = additional_cond_keys 77 | 78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 79 | x_u, x_c = x.chunk(2) 80 | 81 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) 82 | x_c = rearrange(x_c, "(b t) ... 
-> b t ...", t=self.num_frames) 83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) 84 | scale = append_dims(scale, x_u.ndim).to(x_u.device) 85 | 86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") 87 | 88 | def prepare_inputs( 89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]: 91 | c_out = dict() 92 | 93 | for k in c: 94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: 95 | c_out[k] = torch.cat((uc[k], c[k]), 0) 96 | else: 97 | assert c[k] == uc[k] 98 | c_out[k] = c[k] 99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 100 | 101 | 102 | class TrianglePredictionGuider(LinearPredictionGuider): 103 | def __init__( 104 | self, 105 | max_scale: float, 106 | num_frames: int, 107 | min_scale: float = 1.0, 108 | period: float | List[float] = 1.0, 109 | period_fusing: Literal["mean", "multiply", "max"] = "max", 110 | additional_cond_keys: Optional[Union[List[str], str]] = None, 111 | ): 112 | super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) 113 | values = torch.linspace(0, 1, num_frames) 114 | # Constructs a triangle wave 115 | if isinstance(period, float): 116 | period = [period] 117 | 118 | scales = [] 119 | for p in period: 120 | scales.append(self.triangle_wave(values, p)) 121 | 122 | if period_fusing == "mean": 123 | scale = sum(scales) / len(period) 124 | elif period_fusing == "multiply": 125 | scale = torch.prod(torch.stack(scales), dim=0) 126 | elif period_fusing == "max": 127 | scale = torch.max(torch.stack(scales), dim=0).values 128 | self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) 129 | 130 | def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: 131 | return 2 * (values / period - torch.floor(values / period + 0.5)).abs() 132 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 
55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/svd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: svd.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: 
learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: sgm.modules.diffusionmodules.model.Encoder 106 | params: 107 | attn_type: vanilla 108 | double_z: True 109 | z_channels: 4 110 | resolution: 256 111 | in_channels: 3 112 | out_ch: 3 113 | ch: 128 114 | ch_mult: [1, 2, 4, 4] 115 | num_res_blocks: 2 116 | attn_resolutions: [] 117 | dropout: 0.0 118 | decoder_config: 119 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 120 | params: 121 | attn_type: vanilla 122 | double_z: True 123 | z_channels: 4 124 | resolution: 256 125 | in_channels: 3 126 | out_ch: 3 127 | ch: 128 128 | ch_mult: [1, 2, 4, 4] 129 | num_res_blocks: 2 130 | attn_resolutions: [] 131 | dropout: 0.0 132 | video_kernel_size: [3, 1, 1] 133 | 134 | sampler_config: 135 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 136 | params: 137 | discretization_config: 138 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 139 | params: 140 | sigma_max: 700.0 141 | 142 | guider_config: 143 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 144 | params: 145 | max_scale: 2.5 146 | min_scale: 1.0 -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | 
def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if (use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in 
range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /sgm/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import os 4 | from functools import partial 5 | from inspect import isfunction 6 | 7 | import fsspec 8 | import numpy as np 9 | import torch 10 | from PIL import Image, ImageDraw, ImageFont 11 | from safetensors.torch import load_file as load_safetensors 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | 20 | def get_string_from_tuple(s): 21 | try: 22 | # Check if the string starts and ends with parentheses 23 | if s[0] == "(" and s[-1] == ")": 24 | # Convert the string to a tuple 25 | t = eval(s) 26 | # Check if the type of t is tuple 27 | if type(t) == tuple: 28 | return t[0] 29 | else: 30 | pass 31 | except: 32 | pass 33 | return s 34 | 35 | 36 | def is_power_of_two(n): 37 | """ 38 | chat.openai.com/chat 39 | Return True if n is a power of 2, otherwise return False. 40 | 41 | The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. 42 | The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. 43 | If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. 44 | Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 
45 | 46 | """ 47 | if n <= 0: 48 | return False 49 | return (n & (n - 1)) == 0 50 | 51 | 52 | def autocast(f, enabled=True): 53 | def do_autocast(*args, **kwargs): 54 | with torch.cuda.amp.autocast( 55 | enabled=enabled, 56 | dtype=torch.get_autocast_gpu_dtype(), 57 | cache_enabled=torch.is_autocast_cache_enabled(), 58 | ): 59 | return f(*args, **kwargs) 60 | 61 | return do_autocast 62 | 63 | 64 | def load_partial_from_config(config): 65 | return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) 66 | 67 | 68 | def log_txt_as_img(wh, xc, size=10): 69 | # wh a tuple of (width, height) 70 | # xc a list of captions to plot 71 | b = len(xc) 72 | txts = list() 73 | for bi in range(b): 74 | txt = Image.new("RGB", wh, color="white") 75 | draw = ImageDraw.Draw(txt) 76 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 77 | nc = int(40 * (wh[0] / 256)) 78 | if isinstance(xc[bi], list): 79 | text_seq = xc[bi][0] 80 | else: 81 | text_seq = xc[bi] 82 | lines = "\n".join( 83 | text_seq[start : start + nc] for start in range(0, len(text_seq), nc) 84 | ) 85 | 86 | try: 87 | draw.text((0, 0), lines, fill="black", font=font) 88 | except UnicodeEncodeError: 89 | print("Cant encode string for logging. Skipping.") 90 | 91 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 92 | txts.append(txt) 93 | txts = np.stack(txts) 94 | txts = torch.tensor(txts) 95 | return txts 96 | 97 | 98 | def partialclass(cls, *args, **kwargs): 99 | class NewCls(cls): 100 | __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) 101 | 102 | return NewCls 103 | 104 | 105 | def make_path_absolute(path): 106 | fs, p = fsspec.core.url_to_fs(path) 107 | if fs.protocol == "file": 108 | return os.path.abspath(p) 109 | return path 110 | 111 | 112 | def ismap(x): 113 | if not isinstance(x, torch.Tensor): 114 | return False 115 | return (len(x.shape) == 4) and (x.shape[1] > 3) 116 | 117 | 118 | def isimage(x): 119 | if not isinstance(x, torch.Tensor): 120 | return False 121 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 122 | 123 | 124 | def isheatmap(x): 125 | if not isinstance(x, torch.Tensor): 126 | return False 127 | 128 | return x.ndim == 2 129 | 130 | 131 | def isneighbors(x): 132 | if not isinstance(x, torch.Tensor): 133 | return False 134 | return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) 135 | 136 | 137 | def exists(x): 138 | return x is not None 139 | 140 | 141 | def expand_dims_like(x, y): 142 | while x.dim() != y.dim(): 143 | x = x.unsqueeze(-1) 144 | return x 145 | 146 | 147 | def default(val, d): 148 | if exists(val): 149 | return val 150 | return d() if isfunction(d) else d 151 | 152 | 153 | def mean_flat(tensor): 154 | """ 155 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 156 | Take the mean over all non-batch dimensions. 
157 | """ 158 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 159 | 160 | 161 | def count_params(model, verbose=False): 162 | total_params = sum(p.numel() for p in model.parameters()) 163 | if verbose: 164 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 165 | return total_params 166 | 167 | 168 | def instantiate_from_config(config): 169 | if not "target" in config: 170 | if config == "__is_first_stage__": 171 | return None 172 | elif config == "__is_unconditional__": 173 | return None 174 | raise KeyError("Expected key `target` to instantiate.") 175 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 176 | 177 | 178 | def get_obj_from_str(string, reload=False, invalidate_cache=True): 179 | module, cls = string.rsplit(".", 1) 180 | if invalidate_cache: 181 | importlib.invalidate_caches() 182 | if reload: 183 | module_imp = importlib.import_module(module) 184 | importlib.reload(module_imp) 185 | return getattr(importlib.import_module(module, package=None), cls) 186 | 187 | 188 | def append_zero(x): 189 | return torch.cat([x, x.new_zeros([1])]) 190 | 191 | 192 | def append_dims(x, target_dims): 193 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 194 | dims_to_append = target_dims - x.ndim 195 | if dims_to_append < 0: 196 | raise ValueError( 197 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 198 | ) 199 | return x[(...,) + (None,) * dims_to_append] 200 | 201 | 202 | def load_model_from_config(config, ckpt, verbose=True, freeze=True): 203 | print(f"Loading model from {ckpt}") 204 | if ckpt.endswith("ckpt"): 205 | pl_sd = torch.load(ckpt, map_location="cpu") 206 | if "global_step" in pl_sd: 207 | print(f"Global Step: {pl_sd['global_step']}") 208 | sd = pl_sd["state_dict"] 209 | elif ckpt.endswith("safetensors"): 210 | sd = load_safetensors(ckpt) 211 | else: 212 | raise NotImplementedError 213 | 214 | model = instantiate_from_config(config.model) 215 | 216 | m, u = model.load_state_dict(sd, strict=False) 217 | 218 | if len(m) > 0 and verbose: 219 | print("missing keys:") 220 | print(m) 221 | if len(u) > 0 and verbose: 222 | print("unexpected keys:") 223 | print(u) 224 | 225 | if freeze: 226 | for param in model.parameters(): 227 | param.requires_grad = False 228 | 229 | model.eval() 230 | return model 231 | 232 | 233 | def get_configs_path() -> str: 234 | """ 235 | Get the `configs` directory. 236 | For a working copy, this is the one in the root of the repository, 237 | but for an installed copy, it's in the `sgm` package (see pyproject.toml). 238 | """ 239 | this_dir = os.path.dirname(__file__) 240 | candidates = ( 241 | os.path.join(this_dir, "configs"), 242 | os.path.join(this_dir, "..", "configs"), 243 | ) 244 | for candidate in candidates: 245 | candidate = os.path.abspath(candidate) 246 | if os.path.isdir(candidate): 247 | return candidate 248 | raise FileNotFoundError(f"Could not find SGM configs in {candidates}") 249 | 250 | 251 | def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): 252 | """ 253 | Will return the result of a recursive get attribute call. 254 | E.g.: 255 | a.b.c 256 | = getattr(getattr(a, "b"), "c") 257 | = get_nested_attribute(a, "b.c") 258 | If any part of the attribute call is an integer x with current obj a, will 259 | try to call a[x] instead of a.x first. 
260 | """ 261 | attributes = attribute_path.split(".") 262 | if depth is not None and depth > 0: 263 | attributes = attributes[:depth] 264 | assert len(attributes) > 0, "At least one attribute should be selected" 265 | current_attribute = obj 266 | current_key = None 267 | for level, attribute in enumerate(attributes): 268 | current_key = ".".join(attributes[: level + 1]) 269 | try: 270 | id_ = int(attribute) 271 | current_attribute = current_attribute[id_] 272 | except ValueError: 273 | current_attribute = getattr(current_attribute, attribute) 274 | 275 | return (current_attribute, current_key) if return_key else current_attribute 276 | -------------------------------------------------------------------------------- /sgm/modules/video_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..modules.attention import * 4 | from ..modules.diffusionmodules.util import (AlphaBlender, linear, 5 | timestep_embedding) 6 | 7 | 8 | class TimeMixSequential(nn.Sequential): 9 | def forward(self, x, context=None, timesteps=None): 10 | for layer in self: 11 | x = layer(x, context, timesteps) 12 | 13 | return x 14 | 15 | 16 | class VideoTransformerBlock(nn.Module): 17 | ATTENTION_MODES = { 18 | "softmax": CrossAttention, 19 | "softmax-xformers": MemoryEfficientCrossAttention, 20 | } 21 | 22 | def __init__( 23 | self, 24 | dim, 25 | n_heads, 26 | d_head, 27 | dropout=0.0, 28 | context_dim=None, 29 | gated_ff=True, 30 | checkpoint=True, 31 | timesteps=None, 32 | ff_in=False, 33 | inner_dim=None, 34 | attn_mode="softmax", 35 | disable_self_attn=False, 36 | disable_temporal_crossattention=False, 37 | switch_temporal_ca_to_sa=False, 38 | ): 39 | super().__init__() 40 | 41 | attn_cls = self.ATTENTION_MODES[attn_mode] 42 | 43 | self.ff_in = ff_in or inner_dim is not None 44 | if inner_dim is None: 45 | inner_dim = dim 46 | 47 | assert int(n_heads * d_head) == inner_dim 48 | 49 | self.is_res = inner_dim == dim 50 | 51 | if self.ff_in: 52 | self.norm_in = nn.LayerNorm(dim) 53 | self.ff_in = FeedForward( 54 | dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff 55 | ) 56 | 57 | self.timesteps = timesteps 58 | self.disable_self_attn = disable_self_attn 59 | if self.disable_self_attn: 60 | self.attn1 = attn_cls( 61 | query_dim=inner_dim, 62 | heads=n_heads, 63 | dim_head=d_head, 64 | context_dim=context_dim, 65 | dropout=dropout, 66 | ) # is a cross-attention 67 | else: 68 | self.attn1 = attn_cls( 69 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 70 | ) # is a self-attention 71 | 72 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) 73 | 74 | if disable_temporal_crossattention: 75 | if switch_temporal_ca_to_sa: 76 | raise ValueError 77 | else: 78 | self.attn2 = None 79 | else: 80 | self.norm2 = nn.LayerNorm(inner_dim) 81 | if switch_temporal_ca_to_sa: 82 | self.attn2 = attn_cls( 83 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 84 | ) # is a self-attention 85 | else: 86 | self.attn2 = attn_cls( 87 | query_dim=inner_dim, 88 | context_dim=context_dim, 89 | heads=n_heads, 90 | dim_head=d_head, 91 | dropout=dropout, 92 | ) # is self-attn if context is none 93 | 94 | self.norm1 = nn.LayerNorm(inner_dim) 95 | self.norm3 = nn.LayerNorm(inner_dim) 96 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 97 | 98 | self.checkpoint = checkpoint 99 | if self.checkpoint: 100 | print(f"{self.__class__.__name__} is using checkpointing") 101 | 102 | def forward( 103 | self, x: 
torch.Tensor, context: torch.Tensor = None, timesteps: int = None 104 | ) -> torch.Tensor: 105 | if self.checkpoint: 106 | return checkpoint(self._forward, x, context, timesteps) 107 | else: 108 | return self._forward(x, context, timesteps=timesteps) 109 | 110 | def _forward(self, x, context=None, timesteps=None): 111 | assert self.timesteps or timesteps 112 | assert not (self.timesteps and timesteps) or self.timesteps == timesteps 113 | timesteps = self.timesteps or timesteps 114 | B, S, C = x.shape 115 | x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) 116 | 117 | if self.ff_in: 118 | x_skip = x 119 | x = self.ff_in(self.norm_in(x)) 120 | if self.is_res: 121 | x += x_skip 122 | 123 | if self.disable_self_attn: 124 | x = self.attn1(self.norm1(x), context=context) + x 125 | else: 126 | x = self.attn1(self.norm1(x)) + x 127 | 128 | if self.attn2 is not None: 129 | if self.switch_temporal_ca_to_sa: 130 | x = self.attn2(self.norm2(x)) + x 131 | else: 132 | x = self.attn2(self.norm2(x), context=context) + x 133 | x_skip = x 134 | x = self.ff(self.norm3(x)) 135 | if self.is_res: 136 | x += x_skip 137 | 138 | x = rearrange( 139 | x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps 140 | ) 141 | return x 142 | 143 | def get_last_layer(self): 144 | return self.ff.net[-1].weight 145 | 146 | 147 | class SpatialVideoTransformer(SpatialTransformer): 148 | def __init__( 149 | self, 150 | in_channels, 151 | n_heads, 152 | d_head, 153 | depth=1, 154 | dropout=0.0, 155 | use_linear=False, 156 | context_dim=None, 157 | use_spatial_context=False, 158 | timesteps=None, 159 | merge_strategy: str = "fixed", 160 | merge_factor: float = 0.5, 161 | time_context_dim=None, 162 | ff_in=False, 163 | checkpoint=False, 164 | time_depth=1, 165 | attn_mode="softmax", 166 | disable_self_attn=False, 167 | disable_temporal_crossattention=False, 168 | max_time_embed_period: int = 10000, 169 | ): 170 | super().__init__( 171 | in_channels, 172 | n_heads, 173 | d_head, 174 | depth=depth, 175 | dropout=dropout, 176 | attn_type=attn_mode, 177 | use_checkpoint=checkpoint, 178 | context_dim=context_dim, 179 | use_linear=use_linear, 180 | disable_self_attn=disable_self_attn, 181 | ) 182 | self.time_depth = time_depth 183 | self.depth = depth 184 | self.max_time_embed_period = max_time_embed_period 185 | 186 | time_mix_d_head = d_head 187 | n_time_mix_heads = n_heads 188 | 189 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) 190 | 191 | inner_dim = n_heads * d_head 192 | if use_spatial_context: 193 | time_context_dim = context_dim 194 | 195 | self.time_stack = nn.ModuleList( 196 | [ 197 | VideoTransformerBlock( 198 | inner_dim, 199 | n_time_mix_heads, 200 | time_mix_d_head, 201 | dropout=dropout, 202 | context_dim=time_context_dim, 203 | timesteps=timesteps, 204 | checkpoint=checkpoint, 205 | ff_in=ff_in, 206 | inner_dim=time_mix_inner_dim, 207 | attn_mode=attn_mode, 208 | disable_self_attn=disable_self_attn, 209 | disable_temporal_crossattention=disable_temporal_crossattention, 210 | ) 211 | for _ in range(self.depth) 212 | ] 213 | ) 214 | 215 | assert len(self.time_stack) == len(self.transformer_blocks) 216 | 217 | self.use_spatial_context = use_spatial_context 218 | self.in_channels = in_channels 219 | 220 | time_embed_dim = self.in_channels * 4 221 | self.time_pos_embed = nn.Sequential( 222 | linear(self.in_channels, time_embed_dim), 223 | nn.SiLU(), 224 | linear(time_embed_dim, self.in_channels), 225 | ) 226 | 227 | self.time_mixer = AlphaBlender( 228 | alpha=merge_factor, 
merge_strategy=merge_strategy 229 | ) 230 | 231 | def forward( 232 | self, 233 | x: torch.Tensor, 234 | context: Optional[torch.Tensor] = None, 235 | time_context: Optional[torch.Tensor] = None, 236 | timesteps: Optional[int] = None, 237 | image_only_indicator: Optional[torch.Tensor] = None, 238 | ) -> torch.Tensor: 239 | _, _, h, w = x.shape 240 | x_in = x 241 | spatial_context = None 242 | if exists(context): 243 | spatial_context = context 244 | 245 | if self.use_spatial_context: 246 | assert ( 247 | context.ndim == 3 248 | ), f"n dims of spatial context should be 3 but are {context.ndim}" 249 | 250 | time_context = context 251 | time_context_first_timestep = time_context[::timesteps] 252 | time_context = repeat( 253 | time_context_first_timestep, "b ... -> (b n) ...", n=h * w 254 | ) 255 | elif time_context is not None and not self.use_spatial_context: 256 | time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) 257 | if time_context.ndim == 2: 258 | time_context = rearrange(time_context, "b c -> b 1 c") 259 | 260 | x = self.norm(x) 261 | if not self.use_linear: 262 | x = self.proj_in(x) 263 | x = rearrange(x, "b c h w -> b (h w) c") 264 | if self.use_linear: 265 | x = self.proj_in(x) 266 | 267 | num_frames = torch.arange(timesteps, device=x.device) 268 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 269 | num_frames = rearrange(num_frames, "b t -> (b t)") 270 | t_emb = timestep_embedding( 271 | num_frames, 272 | self.in_channels, 273 | repeat_only=False, 274 | max_period=self.max_time_embed_period, 275 | ) 276 | emb = self.time_pos_embed(t_emb) 277 | emb = emb[:, None, :] 278 | 279 | for it_, (block, mix_block) in enumerate( 280 | zip(self.transformer_blocks, self.time_stack) 281 | ): 282 | x = block( 283 | x, 284 | context=spatial_context, 285 | ) 286 | 287 | x_mix = x 288 | x_mix = x_mix + emb 289 | 290 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) 291 | x = self.time_mixer( 292 | x_spatial=x, 293 | x_temporal=x_mix, 294 | image_only_indicator=image_only_indicator, 295 | ) 296 | if self.use_linear: 297 | x = self.proj_out(x) 298 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 299 | if not self.use_linear: 300 | x = self.proj_out(x) 301 | out = x + x_in 302 | return out 303 | -------------------------------------------------------------------------------- /sgm/inference/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from imwatermark import WatermarkEncoder 9 | from omegaconf import ListConfig 10 | from PIL import Image 11 | from torch import autocast 12 | 13 | from sgm.util import append_dims 14 | 15 | 16 | class WatermarkEmbedder: 17 | def __init__(self, watermark): 18 | self.watermark = watermark 19 | self.num_bits = len(WATERMARK_BITS) 20 | self.encoder = WatermarkEncoder() 21 | self.encoder.set_watermark("bits", self.watermark) 22 | 23 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 24 | """ 25 | Adds a predefined watermark to the input image 26 | 27 | Args: 28 | image: ([N,] B, RGB, H, W) in range [0, 1] 29 | 30 | Returns: 31 | same as input but watermarked 32 | """ 33 | squeeze = len(image.shape) == 4 34 | if squeeze: 35 | image = image[None, ...] 
36 | n = image.shape[0] 37 | image_np = rearrange( 38 | (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" 39 | ).numpy()[:, :, :, ::-1] 40 | # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] 41 | # watermarking libary expects input as cv2 BGR format 42 | for k in range(image_np.shape[0]): 43 | image_np[k] = self.encoder.encode(image_np[k], "dwtDct") 44 | image = torch.from_numpy( 45 | rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) 46 | ).to(image.device) 47 | image = torch.clamp(image / 255, min=0.0, max=1.0) 48 | if squeeze: 49 | image = image[0] 50 | return image 51 | 52 | 53 | # A fixed 48-bit message that was choosen at random 54 | # WATERMARK_MESSAGE = 0xB3EC907BB19E 55 | WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 56 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 57 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 58 | embed_watermark = WatermarkEmbedder(WATERMARK_BITS) 59 | 60 | 61 | def get_unique_embedder_keys_from_conditioner(conditioner): 62 | return list({x.input_key for x in conditioner.embedders}) 63 | 64 | 65 | def perform_save_locally(save_path, samples): 66 | os.makedirs(os.path.join(save_path), exist_ok=True) 67 | base_count = len(os.listdir(os.path.join(save_path))) 68 | samples = embed_watermark(samples) 69 | for sample in samples: 70 | sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") 71 | Image.fromarray(sample.astype(np.uint8)).save( 72 | os.path.join(save_path, f"{base_count:09}.png") 73 | ) 74 | base_count += 1 75 | 76 | 77 | class Img2ImgDiscretizationWrapper: 78 | """ 79 | wraps a discretizer, and prunes the sigmas 80 | params: 81 | strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) 82 | """ 83 | 84 | def __init__(self, discretization, strength: float = 1.0): 85 | self.discretization = discretization 86 | self.strength = strength 87 | assert 0.0 <= self.strength <= 1.0 88 | 89 | def __call__(self, *args, **kwargs): 90 | # sigmas start large first, and decrease then 91 | sigmas = self.discretization(*args, **kwargs) 92 | print(f"sigmas after discretization, before pruning img2img: ", sigmas) 93 | sigmas = torch.flip(sigmas, (0,)) 94 | sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] 95 | print("prune index:", max(int(self.strength * len(sigmas)), 1)) 96 | sigmas = torch.flip(sigmas, (0,)) 97 | print(f"sigmas after pruning: ", sigmas) 98 | return sigmas 99 | 100 | 101 | def do_sample( 102 | model, 103 | sampler, 104 | value_dict, 105 | num_samples, 106 | H, 107 | W, 108 | C, 109 | F, 110 | force_uc_zero_embeddings: Optional[List] = None, 111 | batch2model_input: Optional[List] = None, 112 | return_latents=False, 113 | filter=None, 114 | device="cuda", 115 | ): 116 | if force_uc_zero_embeddings is None: 117 | force_uc_zero_embeddings = [] 118 | if batch2model_input is None: 119 | batch2model_input = [] 120 | 121 | with torch.no_grad(): 122 | with autocast(device) as precision_scope: 123 | with model.ema_scope(): 124 | num_samples = [num_samples] 125 | batch, batch_uc = get_batch( 126 | get_unique_embedder_keys_from_conditioner(model.conditioner), 127 | value_dict, 128 | num_samples, 129 | ) 130 | for key in batch: 131 | if isinstance(batch[key], torch.Tensor): 132 | print(key, batch[key].shape) 133 | elif isinstance(batch[key], list): 134 | print(key, [len(l) for l in batch[key]]) 135 | else: 136 | print(key, batch[key]) 137 | c, uc = model.conditioner.get_unconditional_conditioning( 138 | batch, 
139 | batch_uc=batch_uc, 140 | force_uc_zero_embeddings=force_uc_zero_embeddings, 141 | ) 142 | 143 | for k in c: 144 | if not k == "crossattn": 145 | c[k], uc[k] = map( 146 | lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) 147 | ) 148 | 149 | additional_model_inputs = {} 150 | for k in batch2model_input: 151 | additional_model_inputs[k] = batch[k] 152 | 153 | shape = (math.prod(num_samples), C, H // F, W // F) 154 | randn = torch.randn(shape).to(device) 155 | 156 | def denoiser(input, sigma, c): 157 | return model.denoiser( 158 | model.model, input, sigma, c, **additional_model_inputs 159 | ) 160 | 161 | samples_z = sampler(denoiser, randn, cond=c, uc=uc) 162 | samples_x = model.decode_first_stage(samples_z) 163 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) 164 | 165 | if filter is not None: 166 | samples = filter(samples) 167 | 168 | if return_latents: 169 | return samples, samples_z 170 | return samples 171 | 172 | 173 | def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): 174 | # Hardcoded demo setups; might undergo some changes in the future 175 | 176 | batch = {} 177 | batch_uc = {} 178 | 179 | for key in keys: 180 | if key == "txt": 181 | batch["txt"] = ( 182 | np.repeat([value_dict["prompt"]], repeats=math.prod(N)) 183 | .reshape(N) 184 | .tolist() 185 | ) 186 | batch_uc["txt"] = ( 187 | np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) 188 | .reshape(N) 189 | .tolist() 190 | ) 191 | elif key == "original_size_as_tuple": 192 | batch["original_size_as_tuple"] = ( 193 | torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) 194 | .to(device) 195 | .repeat(*N, 1) 196 | ) 197 | elif key == "crop_coords_top_left": 198 | batch["crop_coords_top_left"] = ( 199 | torch.tensor( 200 | [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] 201 | ) 202 | .to(device) 203 | .repeat(*N, 1) 204 | ) 205 | elif key == "aesthetic_score": 206 | batch["aesthetic_score"] = ( 207 | torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) 208 | ) 209 | batch_uc["aesthetic_score"] = ( 210 | torch.tensor([value_dict["negative_aesthetic_score"]]) 211 | .to(device) 212 | .repeat(*N, 1) 213 | ) 214 | 215 | elif key == "target_size_as_tuple": 216 | batch["target_size_as_tuple"] = ( 217 | torch.tensor([value_dict["target_height"], value_dict["target_width"]]) 218 | .to(device) 219 | .repeat(*N, 1) 220 | ) 221 | else: 222 | batch[key] = value_dict[key] 223 | 224 | for key in batch.keys(): 225 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 226 | batch_uc[key] = torch.clone(batch[key]) 227 | return batch, batch_uc 228 | 229 | 230 | def get_input_image_tensor(image: Image.Image, device="cuda"): 231 | w, h = image.size 232 | print(f"loaded input image of size ({w}, {h})") 233 | width, height = map( 234 | lambda x: x - x % 64, (w, h) 235 | ) # resize to integer multiple of 64 236 | image = image.resize((width, height)) 237 | image_array = np.array(image.convert("RGB")) 238 | image_array = image_array[None].transpose(0, 3, 1, 2) 239 | image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 240 | return image_tensor.to(device) 241 | 242 | 243 | def do_img2img( 244 | img, 245 | model, 246 | sampler, 247 | value_dict, 248 | num_samples, 249 | force_uc_zero_embeddings=[], 250 | additional_kwargs={}, 251 | offset_noise_level: float = 0.0, 252 | return_latents=False, 253 | skip_encode=False, 254 | filter=None, 255 | device="cuda", 256 | ): 257 | with torch.no_grad(): 258 | 
with autocast(device) as precision_scope: 259 | with model.ema_scope(): 260 | batch, batch_uc = get_batch( 261 | get_unique_embedder_keys_from_conditioner(model.conditioner), 262 | value_dict, 263 | [num_samples], 264 | ) 265 | c, uc = model.conditioner.get_unconditional_conditioning( 266 | batch, 267 | batch_uc=batch_uc, 268 | force_uc_zero_embeddings=force_uc_zero_embeddings, 269 | ) 270 | 271 | for k in c: 272 | c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) 273 | 274 | for k in additional_kwargs: 275 | c[k] = uc[k] = additional_kwargs[k] 276 | if skip_encode: 277 | z = img 278 | else: 279 | z = model.encode_first_stage(img) 280 | noise = torch.randn_like(z) 281 | sigmas = sampler.discretization(sampler.num_steps) 282 | sigma = sigmas[0].to(z.device) 283 | 284 | if offset_noise_level > 0.0: 285 | noise = noise + offset_noise_level * append_dims( 286 | torch.randn(z.shape[0], device=z.device), z.ndim 287 | ) 288 | noised_z = z + noise * append_dims(sigma, z.ndim) 289 | noised_z = noised_z / torch.sqrt( 290 | 1.0 + sigmas[0] ** 2.0 291 | ) # Note: hardcoded to DDPM-like scaling. need to generalize later. 292 | 293 | def denoiser(x, sigma, c): 294 | return model.denoiser(model.model, x, sigma, c) 295 | 296 | samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) 297 | samples_x = model.decode_first_stage(samples_z) 298 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) 299 | 300 | if filter is not None: 301 | samples = filter(samples) 302 | 303 | if return_latents: 304 | return samples, samples_z 305 | return samples 306 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from glob import glob 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import cv2 9 | import imageio 10 | import numpy as np 11 | import torch 12 | from einops import rearrange, repeat 13 | from fire import Fire 14 | from omegaconf import OmegaConf 15 | from PIL import Image 16 | from rembg import remove 17 | from util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering 18 | from sgm.inference.helpers import embed_watermark 19 | from sgm.util import default, instantiate_from_config 20 | from torchvision.transforms import ToTensor 21 | from tqdm import tqdm 22 | from camera import Warper 23 | 24 | def sample( 25 | input_path: str = "assets/images/cat.jpg", # Can either be image file or folder with image files 26 | prompt: str="a cat wandering in garden", 27 | neg_prompt: str=" ", 28 | pcd_mode: str = 'complex default 14 mode_4', 29 | add_index: int = 10, 30 | num_frames: int = 14, 31 | num_steps: Optional[int] = 25, 32 | fps_id: int = 6, 33 | motion_bucket_id: int = 127, 34 | version: str = 'svd', 35 | cond_aug: float = 0.02, 36 | seed: int = 1, 37 | decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 38 | device: str = "cuda", 39 | output_folder: Optional[str] = None, 40 | verbose: Optional[bool] = False, 41 | save_warps: Optional[bool] = False, 42 | load_warps: Optional[str] = None, 43 | ): 44 | """ 45 | Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each 46 | image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. 
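For reference, a minimal example invocation through the Fire entry point at the bottom of this file; the flag values below simply restate the defaults above and are illustrative only:
python sampling.py --input_path assets/images/cat.jpg --prompt "a cat wandering in garden" --pcd_mode "complex default 14 mode_4" --add_index 10 --num_frames 14 --seed 1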
47 | """ 48 | pcd_mode = pcd_mode.split(' ') 49 | num_frames = default(num_frames, 14) 50 | num_steps = default(num_steps, 25) 51 | output_folder = default(output_folder, "outputs") 52 | model_config = "sgm/svd.yaml" 53 | pcd_dir = os.path.join(output_folder,'renderings') 54 | os.makedirs(output_folder, exist_ok=True) 55 | if save_warps: 56 | os.makedirs(pcd_dir, exist_ok=True) 57 | 58 | model, filter = load_model( 59 | model_config, 60 | device, 61 | num_frames, 62 | num_steps, 63 | verbose, 64 | ) 65 | torch.manual_seed(seed) 66 | 67 | path = Path(input_path) 68 | all_img_paths = [] 69 | if path.is_file(): 70 | if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]): 71 | all_img_paths = [input_path] 72 | else: 73 | raise ValueError("Path is not a valid image file.") 74 | elif path.is_dir(): 75 | all_img_paths = sorted( 76 | [ 77 | f 78 | for f in path.iterdir() 79 | if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] 80 | ] 81 | ) 82 | if len(all_img_paths) == 0: 83 | raise ValueError("Folder does not contain any images.") 84 | else: 85 | raise ValueError(f"Input path {input_path} is neither an image file nor a directory.") 86 | 87 | for input_img_path in all_img_paths: 88 | with Image.open(input_img_path) as image: 89 | input_image = image.convert("RGB") 90 | w, h = image.size 91 | if h % 64 != 0 or w % 64 != 0: 92 | width, height = map(lambda x: x - x % 64, (w, h)) 93 | input_image = input_image.resize((width, height)) 94 | print( 95 | f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" 96 | ) 97 | 98 | image = ToTensor()(input_image) 99 | image = image * 2.0 - 1.0 100 | 101 | image = image.unsqueeze(0).to(device) 102 | H, W = image.shape[2:] 103 | assert image.shape[1] == 3 104 | F = 8 105 | C = 4 106 | shape = (num_frames, C, H // F, W // F) 107 | if (H, W) != (576, 1024) and "sv3d" not in version: 108 | print( 109 | "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as the model was only trained on 576x1024. Consider increasing `cond_aug`." 110 | ) 111 | if motion_bucket_id > 255: 112 | print( 113 | "WARNING: High motion bucket! This may lead to suboptimal performance." 114 | ) 115 | if fps_id < 5: 116 | print("WARNING: Small fps value! This may lead to suboptimal performance.") 117 | 118 | if fps_id > 30: 119 | print("WARNING: Large fps value! This may lead to suboptimal performance.") 120 | 121 | value_dict = {} 122 | value_dict["cond_frames_without_noise"] = image 123 | value_dict["motion_bucket_id"] = motion_bucket_id 124 | value_dict["fps_id"] = fps_id 125 | value_dict["cond_aug"] = cond_aug 126 | value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) 127 | 128 | with torch.no_grad(): 129 | with torch.autocast(device): 130 | batch, batch_uc = get_batch( 131 | get_unique_embedder_keys_from_conditioner(model.conditioner), 132 | value_dict, 133 | [1, num_frames], 134 | T=num_frames, 135 | device=device, 136 | ) 137 | c, uc = model.conditioner.get_unconditional_conditioning( 138 | batch, 139 | batch_uc=batch_uc, 140 | force_uc_zero_embeddings=[ 141 | "cond_frames", 142 | "cond_frames_without_noise", 143 | ], 144 | ) 145 | 146 | for k in ["crossattn", "concat"]: 147 | uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) 148 | uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) 149 | c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) 150 | c[k] = rearrange(c[k], "b t ...
-> (b t) ...", t=num_frames) 151 | 152 | additional_model_inputs = {} 153 | additional_model_inputs["image_only_indicator"] = torch.zeros( 154 | 2, num_frames 155 | ).to(device) 156 | additional_model_inputs["num_video_frames"] = batch["num_video_frames"] 157 | 158 | def denoiser(input, sigma, c): 159 | return model.denoiser( 160 | model.model, input, sigma, c, **additional_model_inputs 161 | ) 162 | 163 | if load_warps != None: 164 | print('warp path provided, reading from folder') 165 | images = concat_warp_start(image, num_frames, load_warps) 166 | else: 167 | warper = Warper(H, W) 168 | images = warper.generate_pcd(input_image, prompt, neg_prompt, pcd_mode, seed, num_steps, pcd_dir, save_warps) 169 | latent_images = model.encode_first_stage(images) 170 | 171 | # samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) 172 | randn = torch.randn(shape, device=device) 173 | _, s_in, sigmas, num_sigmas, cond, uc = model.sampler.prepare_sampling_loop(randn, cond=c, uc=uc, num_steps=num_steps) 174 | 175 | noise = torch.randn(shape, device=device) 176 | x = latent_images + noise * sigmas[add_index] 177 | 178 | for i in tqdm(model.sampler.get_sigma_gen(num_sigmas)[add_index:]): 179 | gamma = ( 180 | min(model.sampler.s_churn / (num_sigmas - 1), 2**0.5 - 1) 181 | if model.sampler.s_tmin <= sigmas[i] <= model.sampler.s_tmax 182 | else 0.0 183 | ) 184 | 185 | x = model.sampler.sampler_step( 186 | s_in * sigmas[i], 187 | s_in * sigmas[i + 1], 188 | denoiser, 189 | x, 190 | cond, 191 | uc, 192 | gamma, 193 | ) 194 | 195 | model.en_and_decode_n_samples_a_time = decoding_t 196 | samples_x = model.decode_first_stage(x) 197 | if "sv3d" in version: 198 | samples_x[-1:] = value_dict["cond_frames_without_noise"] 199 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) 200 | 201 | base_count = len(glob(os.path.join(output_folder, "*.gif"))) 202 | 203 | samples = embed_watermark(samples) 204 | samples = filter(samples) 205 | vid = ( 206 | (rearrange(samples, "t c h w -> t h w c") * 255) 207 | .cpu() 208 | .numpy() 209 | .astype(np.uint8) 210 | ) 211 | video_path = os.path.join(output_folder, f"{base_count:06d}_{'_'.join(pcd_mode)}_i_{add_index}_seed_{seed}.gif") 212 | imageio.mimwrite(video_path, vid) 213 | 214 | def concat_warp_start(image, num_frames, concat_path, device = 'cuda'): 215 | images = torch.Tensor([]).to(device) 216 | h, w = image.shape[2:] 217 | for i in range(num_frames): 218 | if i == 0: 219 | new_image = image 220 | else: 221 | new_image = Image.open(f'{concat_path}/{i}_concat.png').resize((w, h)) 222 | new_image = ToTensor()(new_image) 223 | new_image = new_image * 2.0 - 1.0 224 | new_image = new_image.unsqueeze(0).to(device) 225 | images = torch.cat([images, new_image]) 226 | 227 | return images 228 | 229 | def get_unique_embedder_keys_from_conditioner(conditioner): 230 | return list(set([x.input_key for x in conditioner.embedders])) 231 | 232 | 233 | def get_batch(keys, value_dict, N, T, device): 234 | batch = {} 235 | batch_uc = {} 236 | 237 | for key in keys: 238 | if key == "fps_id": 239 | batch[key] = ( 240 | torch.tensor([value_dict["fps_id"]]) 241 | .to(device) 242 | .repeat(int(math.prod(N))) 243 | ) 244 | elif key == "motion_bucket_id": 245 | batch[key] = ( 246 | torch.tensor([value_dict["motion_bucket_id"]]) 247 | .to(device) 248 | .repeat(int(math.prod(N))) 249 | ) 250 | elif key == "cond_aug": 251 | batch[key] = repeat( 252 | torch.tensor([value_dict["cond_aug"]]).to(device), 253 | "1 -> b", 254 | b=math.prod(N), 255 | ) 256 | elif key == "cond_frames" or key 
== "cond_frames_without_noise": 257 | batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) 258 | elif key == "polars_rad" or key == "azimuths_rad": 259 | batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) 260 | else: 261 | batch[key] = value_dict[key] 262 | 263 | if T is not None: 264 | batch["num_video_frames"] = T 265 | 266 | for key in batch.keys(): 267 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 268 | batch_uc[key] = torch.clone(batch[key]) 269 | return batch, batch_uc 270 | 271 | 272 | def load_model( 273 | config: str, 274 | device: str, 275 | num_frames: int, 276 | num_steps: int, 277 | verbose: bool = False, 278 | ): 279 | config = OmegaConf.load(config) 280 | if device == "cuda": 281 | config.model.params.conditioner_config.params.emb_models[ 282 | 0 283 | ].params.open_clip_embedding_config.params.init_device = device 284 | 285 | config.model.params.sampler_config.params.verbose = verbose 286 | config.model.params.sampler_config.params.num_steps = num_steps 287 | config.model.params.sampler_config.params.guider_config.params.num_frames = ( 288 | num_frames 289 | ) 290 | if device == "cuda": 291 | with torch.device(device): 292 | model = instantiate_from_config(config.model).to(device).eval() 293 | else: 294 | model = instantiate_from_config(config.model).to(device).eval() 295 | 296 | filter = DeepFloydDataFiltering(verbose=False, device=device) 297 | return model, filter 298 | 299 | 300 | if __name__ == "__main__": 301 | Fire(sample) -------------------------------------------------------------------------------- /sgm/modules/autoencoding/temporal_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Union 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | 6 | from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE, 7 | AttnBlock, Decoder, 8 | MemoryEfficientAttnBlock, 9 | ResnetBlock) 10 | from sgm.modules.diffusionmodules.openaimodel import (ResBlock, 11 | timestep_embedding) 12 | from sgm.modules.video_attention import VideoTransformerBlock 13 | from sgm.util import partialclass 14 | 15 | 16 | class VideoResBlock(ResnetBlock): 17 | def __init__( 18 | self, 19 | out_channels, 20 | *args, 21 | dropout=0.0, 22 | video_kernel_size=3, 23 | alpha=0.0, 24 | merge_strategy="learned", 25 | **kwargs, 26 | ): 27 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) 28 | if video_kernel_size is None: 29 | video_kernel_size = [3, 1, 1] 30 | self.time_stack = ResBlock( 31 | channels=out_channels, 32 | emb_channels=0, 33 | dropout=dropout, 34 | dims=3, 35 | use_scale_shift_norm=False, 36 | use_conv=False, 37 | up=False, 38 | down=False, 39 | kernel_size=video_kernel_size, 40 | use_checkpoint=False, 41 | skip_t_emb=True, 42 | ) 43 | 44 | self.merge_strategy = merge_strategy 45 | if self.merge_strategy == "fixed": 46 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 47 | elif self.merge_strategy == "learned": 48 | self.register_parameter( 49 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 50 | ) 51 | else: 52 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 53 | 54 | def get_alpha(self, bs): 55 | if self.merge_strategy == "fixed": 56 | return self.mix_factor 57 | elif self.merge_strategy == "learned": 58 | return torch.sigmoid(self.mix_factor) 59 | else: 60 | raise NotImplementedError() 61 | 62 | def forward(self, x, temb, skip_video=False, timesteps=None): 63 | if timesteps is 
None: 64 | timesteps = self.timesteps 65 | 66 | b, c, h, w = x.shape 67 | 68 | x = super().forward(x, temb) 69 | 70 | if not skip_video: 71 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 72 | 73 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 74 | 75 | x = self.time_stack(x, temb) 76 | 77 | alpha = self.get_alpha(bs=b // timesteps) 78 | x = alpha * x + (1.0 - alpha) * x_mix 79 | 80 | x = rearrange(x, "b c t h w -> (b t) c h w") 81 | return x 82 | 83 | 84 | class AE3DConv(torch.nn.Conv2d): 85 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): 86 | super().__init__(in_channels, out_channels, *args, **kwargs) 87 | if isinstance(video_kernel_size, Iterable): 88 | padding = [int(k // 2) for k in video_kernel_size] 89 | else: 90 | padding = int(video_kernel_size // 2) 91 | 92 | self.time_mix_conv = torch.nn.Conv3d( 93 | in_channels=out_channels, 94 | out_channels=out_channels, 95 | kernel_size=video_kernel_size, 96 | padding=padding, 97 | ) 98 | 99 | def forward(self, input, timesteps, skip_video=False): 100 | x = super().forward(input) 101 | if skip_video: 102 | return x 103 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 104 | x = self.time_mix_conv(x) 105 | return rearrange(x, "b c t h w -> (b t) c h w") 106 | 107 | 108 | class VideoBlock(AttnBlock): 109 | def __init__( 110 | self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" 111 | ): 112 | super().__init__(in_channels) 113 | # no context, single headed, as in base class 114 | self.time_mix_block = VideoTransformerBlock( 115 | dim=in_channels, 116 | n_heads=1, 117 | d_head=in_channels, 118 | checkpoint=False, 119 | ff_in=True, 120 | attn_mode="softmax", 121 | ) 122 | 123 | time_embed_dim = self.in_channels * 4 124 | self.video_time_embed = torch.nn.Sequential( 125 | torch.nn.Linear(self.in_channels, time_embed_dim), 126 | torch.nn.SiLU(), 127 | torch.nn.Linear(time_embed_dim, self.in_channels), 128 | ) 129 | 130 | self.merge_strategy = merge_strategy 131 | if self.merge_strategy == "fixed": 132 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 133 | elif self.merge_strategy == "learned": 134 | self.register_parameter( 135 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 136 | ) 137 | else: 138 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 139 | 140 | def forward(self, x, timesteps, skip_video=False): 141 | if skip_video: 142 | return super().forward(x) 143 | 144 | x_in = x 145 | x = self.attention(x) 146 | h, w = x.shape[2:] 147 | x = rearrange(x, "b c h w -> b (h w) c") 148 | 149 | x_mix = x 150 | num_frames = torch.arange(timesteps, device=x.device) 151 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 152 | num_frames = rearrange(num_frames, "b t -> (b t)") 153 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 154 | emb = self.video_time_embed(t_emb) # b, n_channels 155 | emb = emb[:, None, :] 156 | x_mix = x_mix + emb 157 | 158 | alpha = self.get_alpha() 159 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 160 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 161 | 162 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 163 | x = self.proj_out(x) 164 | 165 | return x_in + x 166 | 167 | def get_alpha( 168 | self, 169 | ): 170 | if self.merge_strategy == "fixed": 171 | return self.mix_factor 172 | elif self.merge_strategy == "learned": 173 | return torch.sigmoid(self.mix_factor) 174 | else: 175 | raise NotImplementedError(f"unknown merge 
strategy {self.merge_strategy}") 176 | 177 | 178 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): 179 | def __init__( 180 | self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" 181 | ): 182 | super().__init__(in_channels) 183 | # no context, single headed, as in base class 184 | self.time_mix_block = VideoTransformerBlock( 185 | dim=in_channels, 186 | n_heads=1, 187 | d_head=in_channels, 188 | checkpoint=False, 189 | ff_in=True, 190 | attn_mode="softmax-xformers", 191 | ) 192 | 193 | time_embed_dim = self.in_channels * 4 194 | self.video_time_embed = torch.nn.Sequential( 195 | torch.nn.Linear(self.in_channels, time_embed_dim), 196 | torch.nn.SiLU(), 197 | torch.nn.Linear(time_embed_dim, self.in_channels), 198 | ) 199 | 200 | self.merge_strategy = merge_strategy 201 | if self.merge_strategy == "fixed": 202 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 203 | elif self.merge_strategy == "learned": 204 | self.register_parameter( 205 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 206 | ) 207 | else: 208 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 209 | 210 | def forward(self, x, timesteps, skip_time_block=False): 211 | if skip_time_block: 212 | return super().forward(x) 213 | 214 | x_in = x 215 | x = self.attention(x) 216 | h, w = x.shape[2:] 217 | x = rearrange(x, "b c h w -> b (h w) c") 218 | 219 | x_mix = x 220 | num_frames = torch.arange(timesteps, device=x.device) 221 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 222 | num_frames = rearrange(num_frames, "b t -> (b t)") 223 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 224 | emb = self.video_time_embed(t_emb) # b, n_channels 225 | emb = emb[:, None, :] 226 | x_mix = x_mix + emb 227 | 228 | alpha = self.get_alpha() 229 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 230 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 231 | 232 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 233 | x = self.proj_out(x) 234 | 235 | return x_in + x 236 | 237 | def get_alpha( 238 | self, 239 | ): 240 | if self.merge_strategy == "fixed": 241 | return self.mix_factor 242 | elif self.merge_strategy == "learned": 243 | return torch.sigmoid(self.mix_factor) 244 | else: 245 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 246 | 247 | 248 | def make_time_attn( 249 | in_channels, 250 | attn_type="vanilla", 251 | attn_kwargs=None, 252 | alpha: float = 0, 253 | merge_strategy: str = "learned", 254 | ): 255 | assert attn_type in [ 256 | "vanilla", 257 | "vanilla-xformers", 258 | ], f"attn_type {attn_type} not supported for spatio-temporal attention" 259 | print( 260 | f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" 261 | ) 262 | if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": 263 | print( 264 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " 265 | f"This is not a problem in Pytorch >= 2.0. 
FYI, you are running with PyTorch version {torch.__version__}" 266 | ) 267 | attn_type = "vanilla" 268 | 269 | if attn_type == "vanilla": 270 | assert attn_kwargs is None 271 | return partialclass( 272 | VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy 273 | ) 274 | elif attn_type == "vanilla-xformers": 275 | print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") 276 | return partialclass( 277 | MemoryEfficientVideoBlock, 278 | in_channels, 279 | alpha=alpha, 280 | merge_strategy=merge_strategy, 281 | ) 282 | else: 283 | raise NotImplementedError(f"attn_type {attn_type} not supported for spatio-temporal attention") 284 | 285 | 286 | class Conv2DWrapper(torch.nn.Conv2d): 287 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 288 | return super().forward(input) 289 | 290 | 291 | class VideoDecoder(Decoder): 292 | available_time_modes = ["all", "conv-only", "attn-only"] 293 | 294 | def __init__( 295 | self, 296 | *args, 297 | video_kernel_size: Union[int, list] = 3, 298 | alpha: float = 0.0, 299 | merge_strategy: str = "learned", 300 | time_mode: str = "conv-only", 301 | **kwargs, 302 | ): 303 | self.video_kernel_size = video_kernel_size 304 | self.alpha = alpha 305 | self.merge_strategy = merge_strategy 306 | self.time_mode = time_mode 307 | assert ( 308 | self.time_mode in self.available_time_modes 309 | ), f"time_mode parameter has to be in {self.available_time_modes}" 310 | super().__init__(*args, **kwargs) 311 | 312 | def get_last_layer(self, skip_time_mix=False, **kwargs): 313 | if self.time_mode == "attn-only": 314 | raise NotImplementedError("TODO") 315 | else: 316 | return ( 317 | self.conv_out.time_mix_conv.weight 318 | if not skip_time_mix 319 | else self.conv_out.weight 320 | ) 321 | 322 | def _make_attn(self) -> Callable: 323 | if self.time_mode not in ["conv-only", "only-last-conv"]: 324 | return partialclass( 325 | make_time_attn, 326 | alpha=self.alpha, 327 | merge_strategy=self.merge_strategy, 328 | ) 329 | else: 330 | return super()._make_attn() 331 | 332 | def _make_conv(self) -> Callable: 333 | if self.time_mode != "attn-only": 334 | return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) 335 | else: 336 | return Conv2DWrapper 337 | 338 | def _make_resblock(self) -> Callable: 339 | if self.time_mode not in ["attn-only", "only-last-conv"]: 340 | return partialclass( 341 | VideoResBlock, 342 | video_kernel_size=self.video_kernel_size, 343 | alpha=self.alpha, 344 | merge_strategy=self.merge_strategy, 345 | ) 346 | else: 347 | return super()._make_resblock() 348 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/discriminator_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from einops import rearrange 8 | from matplotlib import colormaps 9 | from matplotlib import pyplot as plt 10 | 11 | from ....util import default, instantiate_from_config 12 | from ..lpips.loss.lpips import LPIPS 13 | from ..lpips.model.model import weights_init 14 | from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss 15 | 16 | 17 | class GeneralLPIPSWithDiscriminator(nn.Module): 18 | def __init__( 19 | self, 20 | disc_start: int, 21 | logvar_init: float = 0.0, 22 | disc_num_layers: int = 3, 23 | disc_in_channels: int = 3, 24 | disc_factor: float = 1.0, 25 | disc_weight: float = 1.0, 26 | perceptual_weight: float = 1.0, 27
| disc_loss: str = "hinge", 28 | scale_input_to_tgt_size: bool = False, 29 | dims: int = 2, 30 | learn_logvar: bool = False, 31 | regularization_weights: Union[None, Dict[str, float]] = None, 32 | additional_log_keys: Optional[List[str]] = None, 33 | discriminator_config: Optional[Dict] = None, 34 | ): 35 | super().__init__() 36 | self.dims = dims 37 | if self.dims > 2: 38 | print( 39 | f"running with dims={dims}. This means that for perceptual loss " 40 | f"calculation, the LPIPS loss will be applied to each frame " 41 | f"independently." 42 | ) 43 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 44 | assert disc_loss in ["hinge", "vanilla"] 45 | self.perceptual_loss = LPIPS().eval() 46 | self.perceptual_weight = perceptual_weight 47 | # output log variance 48 | self.logvar = nn.Parameter( 49 | torch.full((), logvar_init), requires_grad=learn_logvar 50 | ) 51 | self.learn_logvar = learn_logvar 52 | 53 | discriminator_config = default( 54 | discriminator_config, 55 | { 56 | "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", 57 | "params": { 58 | "input_nc": disc_in_channels, 59 | "n_layers": disc_num_layers, 60 | "use_actnorm": False, 61 | }, 62 | }, 63 | ) 64 | 65 | self.discriminator = instantiate_from_config(discriminator_config).apply( 66 | weights_init 67 | ) 68 | self.discriminator_iter_start = disc_start 69 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 70 | self.disc_factor = disc_factor 71 | self.discriminator_weight = disc_weight 72 | self.regularization_weights = default(regularization_weights, {}) 73 | 74 | self.forward_keys = [ 75 | "optimizer_idx", 76 | "global_step", 77 | "last_layer", 78 | "split", 79 | "regularization_log", 80 | ] 81 | 82 | self.additional_log_keys = set(default(additional_log_keys, [])) 83 | self.additional_log_keys.update(set(self.regularization_weights.keys())) 84 | 85 | def get_trainable_parameters(self) -> Iterator[nn.Parameter]: 86 | return self.discriminator.parameters() 87 | 88 | def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: 89 | if self.learn_logvar: 90 | yield self.logvar 91 | yield from () 92 | 93 | @torch.no_grad() 94 | def log_images( 95 | self, inputs: torch.Tensor, reconstructions: torch.Tensor 96 | ) -> Dict[str, torch.Tensor]: 97 | # calc logits of real/fake 98 | logits_real = self.discriminator(inputs.contiguous().detach()) 99 | if len(logits_real.shape) < 4: 100 | # Non patch-discriminator 101 | return dict() 102 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 103 | # -> (b, 1, h, w) 104 | 105 | # parameters for colormapping 106 | high = max(logits_fake.abs().max(), logits_real.abs().max()).item() 107 | cmap = colormaps["PiYG"] # diverging colormap 108 | 109 | def to_colormap(logits: torch.Tensor) -> torch.Tensor: 110 | """(b, 1, ...) -> (b, 3, ...)""" 111 | logits = (logits + high) / (2 * high) 112 | logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel 113 | # -> (b, 1, ..., 3) 114 | logits = torch.from_numpy(logits_np).to(logits.device) 115 | return rearrange(logits, "b 1 ... 
c -> b c ...") 116 | 117 | logits_real = torch.nn.functional.interpolate( 118 | logits_real, 119 | size=inputs.shape[-2:], 120 | mode="nearest", 121 | antialias=False, 122 | ) 123 | logits_fake = torch.nn.functional.interpolate( 124 | logits_fake, 125 | size=reconstructions.shape[-2:], 126 | mode="nearest", 127 | antialias=False, 128 | ) 129 | 130 | # alpha value of logits for overlay 131 | alpha_real = torch.abs(logits_real) / high 132 | alpha_fake = torch.abs(logits_fake) / high 133 | # -> (b, 1, h, w) in range [0, 0.5] 134 | # alpha value of lines don't really matter, since the values are the same 135 | # for both images and logits anyway 136 | grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) 137 | grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) 138 | grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) 139 | # -> (1, h, w) 140 | # blend logits and images together 141 | 142 | # prepare logits for plotting 143 | logits_real = to_colormap(logits_real) 144 | logits_fake = to_colormap(logits_fake) 145 | # resize logits 146 | # -> (b, 3, h, w) 147 | 148 | # make some grids 149 | # add all logits to one plot 150 | logits_real = torchvision.utils.make_grid(logits_real, nrow=4) 151 | logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) 152 | # I just love how torchvision calls the number of columns `nrow` 153 | grid_logits = torch.cat((logits_real, logits_fake), dim=1) 154 | # -> (3, h, w) 155 | 156 | grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) 157 | grid_images_fake = torchvision.utils.make_grid( 158 | 0.5 * reconstructions + 0.5, nrow=4 159 | ) 160 | grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) 161 | # -> (3, h, w) in range [0, 1] 162 | 163 | grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images 164 | 165 | # Create labeled colorbar 166 | dpi = 100 167 | height = 128 / dpi 168 | width = grid_logits.shape[2] / dpi 169 | fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) 170 | img = ax.imshow(np.array([[-high, high]]), cmap=cmap) 171 | plt.colorbar( 172 | img, 173 | cax=ax, 174 | orientation="horizontal", 175 | fraction=0.9, 176 | aspect=width / height, 177 | pad=0.0, 178 | ) 179 | img.set_visible(False) 180 | fig.tight_layout() 181 | fig.canvas.draw() 182 | # manually convert figure to numpy 183 | cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 184 | cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 185 | cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 186 | cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) 187 | 188 | # Add colorbar to plot 189 | annotated_grid = torch.cat((grid_logits, cbar), dim=1) 190 | blended_grid = torch.cat((grid_blend, cbar), dim=1) 191 | return { 192 | "vis_logits": 2 * annotated_grid[None, ...] - 1, 193 | "vis_logits_blended": 2 * blended_grid[None, ...] 
- 1, 194 | } 195 | 196 | def calculate_adaptive_weight( 197 | self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor 198 | ) -> torch.Tensor: 199 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 200 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 201 | 202 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 203 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 204 | d_weight = d_weight * self.discriminator_weight 205 | return d_weight 206 | 207 | def forward( 208 | self, 209 | inputs: torch.Tensor, 210 | reconstructions: torch.Tensor, 211 | *, # added because I changed the order here 212 | regularization_log: Dict[str, torch.Tensor], 213 | optimizer_idx: int, 214 | global_step: int, 215 | last_layer: torch.Tensor, 216 | split: str = "train", 217 | weights: Union[None, float, torch.Tensor] = None, 218 | ) -> Tuple[torch.Tensor, dict]: 219 | if self.scale_input_to_tgt_size: 220 | inputs = torch.nn.functional.interpolate( 221 | inputs, reconstructions.shape[2:], mode="bicubic", antialias=True 222 | ) 223 | 224 | if self.dims > 2: 225 | inputs, reconstructions = map( 226 | lambda x: rearrange(x, "b c t h w -> (b t) c h w"), 227 | (inputs, reconstructions), 228 | ) 229 | 230 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 231 | if self.perceptual_weight > 0: 232 | p_loss = self.perceptual_loss( 233 | inputs.contiguous(), reconstructions.contiguous() 234 | ) 235 | rec_loss = rec_loss + self.perceptual_weight * p_loss 236 | 237 | nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) 238 | 239 | # now the GAN part 240 | if optimizer_idx == 0: 241 | # generator update 242 | if global_step >= self.discriminator_iter_start or not self.training: 243 | logits_fake = self.discriminator(reconstructions.contiguous()) 244 | g_loss = -torch.mean(logits_fake) 245 | if self.training: 246 | d_weight = self.calculate_adaptive_weight( 247 | nll_loss, g_loss, last_layer=last_layer 248 | ) 249 | else: 250 | d_weight = torch.tensor(1.0) 251 | else: 252 | d_weight = torch.tensor(0.0) 253 | g_loss = torch.tensor(0.0, requires_grad=True) 254 | 255 | loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss 256 | log = dict() 257 | for k in regularization_log: 258 | if k in self.regularization_weights: 259 | loss = loss + self.regularization_weights[k] * regularization_log[k] 260 | if k in self.additional_log_keys: 261 | log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() 262 | 263 | log.update( 264 | { 265 | f"{split}/loss/total": loss.clone().detach().mean(), 266 | f"{split}/loss/nll": nll_loss.detach().mean(), 267 | f"{split}/loss/rec": rec_loss.detach().mean(), 268 | f"{split}/loss/g": g_loss.detach().mean(), 269 | f"{split}/scalars/logvar": self.logvar.detach(), 270 | f"{split}/scalars/d_weight": d_weight.detach(), 271 | } 272 | ) 273 | 274 | return loss, log 275 | elif optimizer_idx == 1: 276 | # second pass for discriminator update 277 | logits_real = self.discriminator(inputs.contiguous().detach()) 278 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 279 | 280 | if global_step >= self.discriminator_iter_start or not self.training: 281 | d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) 282 | else: 283 | d_loss = torch.tensor(0.0, requires_grad=True) 284 | 285 | log = { 286 | f"{split}/loss/disc": d_loss.clone().detach().mean(), 287 | f"{split}/logits/real": logits_real.detach().mean(), 288 | 
f"{split}/logits/fake": logits_fake.detach().mean(), 289 | } 290 | return d_loss, log 291 | else: 292 | raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") 293 | 294 | def get_nll_loss( 295 | self, 296 | rec_loss: torch.Tensor, 297 | weights: Optional[Union[float, torch.Tensor]] = None, 298 | ) -> Tuple[torch.Tensor, torch.Tensor]: 299 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 300 | weighted_nll_loss = nll_loss 301 | if weights is not None: 302 | weighted_nll_loss = weights * nll_loss 303 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 304 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 305 | 306 | return nll_loss, weighted_nll_loss 307 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py 3 | """ 4 | 5 | 6 | from typing import Dict, Union 7 | 8 | import torch 9 | from omegaconf import ListConfig, OmegaConf 10 | from tqdm import tqdm 11 | 12 | from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step, 13 | linear_multistep_coeff, 14 | to_d, to_neg_log_sigma, 15 | to_sigma) 16 | from ...util import append_dims, default, instantiate_from_config 17 | 18 | DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} 19 | 20 | 21 | class BaseDiffusionSampler: 22 | def __init__( 23 | self, 24 | discretization_config: Union[Dict, ListConfig, OmegaConf], 25 | num_steps: Union[int, None] = None, 26 | guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, 27 | verbose: bool = False, 28 | device: str = "cuda", 29 | ): 30 | self.num_steps = num_steps 31 | self.discretization = instantiate_from_config(discretization_config) 32 | self.guider = instantiate_from_config( 33 | default( 34 | guider_config, 35 | DEFAULT_GUIDER, 36 | ) 37 | ) 38 | self.verbose = verbose 39 | self.device = device 40 | 41 | def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): 42 | sigmas = self.discretization( 43 | self.num_steps if num_steps is None else num_steps, device=self.device 44 | ) 45 | uc = default(uc, cond) 46 | 47 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) 48 | num_sigmas = len(sigmas) 49 | 50 | s_in = x.new_ones([x.shape[0]]) 51 | 52 | return x, s_in, sigmas, num_sigmas, cond, uc 53 | 54 | def denoise(self, x, denoiser, sigma, cond, uc): 55 | denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) 56 | denoised = self.guider(denoised, sigma) 57 | return denoised 58 | 59 | def get_sigma_gen(self, num_sigmas): 60 | sigma_generator = range(num_sigmas - 1) 61 | if self.verbose: 62 | print("#" * 30, " Sampling setting ", "#" * 30) 63 | print(f"Sampler: {self.__class__.__name__}") 64 | print(f"Discretization: {self.discretization.__class__.__name__}") 65 | print(f"Guider: {self.guider.__class__.__name__}") 66 | sigma_generator = tqdm( 67 | sigma_generator, 68 | total=num_sigmas, 69 | desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", 70 | ) 71 | return sigma_generator 72 | 73 | 74 | class SingleStepDiffusionSampler(BaseDiffusionSampler): 75 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def euler_step(self, x, d, dt): 79 | return x + dt * d 80 | 81 | 82 | class EDMSampler(SingleStepDiffusionSampler): 83 | def __init__( 
84 | self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs 85 | ): 86 | super().__init__(*args, **kwargs) 87 | 88 | self.s_churn = s_churn 89 | self.s_tmin = s_tmin 90 | self.s_tmax = s_tmax 91 | self.s_noise = s_noise 92 | 93 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): 94 | sigma_hat = sigma * (gamma + 1.0) 95 | if gamma > 0: 96 | eps = torch.randn_like(x) * self.s_noise 97 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 98 | 99 | denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) 100 | d = to_d(x, sigma_hat, denoised) 101 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 102 | 103 | euler_step = self.euler_step(x, d, dt) 104 | x = self.possible_correction_step( 105 | euler_step, x, d, dt, next_sigma, denoiser, cond, uc 106 | ) 107 | return x 108 | 109 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): 110 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 111 | x, cond, uc, num_steps 112 | ) 113 | 114 | for i in self.get_sigma_gen(num_sigmas): 115 | gamma = ( 116 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) 117 | if self.s_tmin <= sigmas[i] <= self.s_tmax 118 | else 0.0 119 | ) 120 | x = self.sampler_step( 121 | s_in * sigmas[i], 122 | s_in * sigmas[i + 1], 123 | denoiser, 124 | x, 125 | cond, 126 | uc, 127 | gamma, 128 | ) 129 | 130 | return x 131 | 132 | 133 | class AncestralSampler(SingleStepDiffusionSampler): 134 | def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): 135 | super().__init__(*args, **kwargs) 136 | 137 | self.eta = eta 138 | self.s_noise = s_noise 139 | self.noise_sampler = lambda x: torch.randn_like(x) 140 | 141 | def ancestral_euler_step(self, x, denoised, sigma, sigma_down): 142 | d = to_d(x, sigma, denoised) 143 | dt = append_dims(sigma_down - sigma, x.ndim) 144 | 145 | return self.euler_step(x, d, dt) 146 | 147 | def ancestral_step(self, x, sigma, next_sigma, sigma_up): 148 | x = torch.where( 149 | append_dims(next_sigma, x.ndim) > 0.0, 150 | x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), 151 | x, 152 | ) 153 | return x 154 | 155 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): 156 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 157 | x, cond, uc, num_steps 158 | ) 159 | 160 | for i in self.get_sigma_gen(num_sigmas): 161 | x = self.sampler_step( 162 | s_in * sigmas[i], 163 | s_in * sigmas[i + 1], 164 | denoiser, 165 | x, 166 | cond, 167 | uc, 168 | ) 169 | 170 | return x 171 | 172 | 173 | class LinearMultistepSampler(BaseDiffusionSampler): 174 | def __init__( 175 | self, 176 | order=4, 177 | *args, 178 | **kwargs, 179 | ): 180 | super().__init__(*args, **kwargs) 181 | 182 | self.order = order 183 | 184 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 185 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 186 | x, cond, uc, num_steps 187 | ) 188 | 189 | ds = [] 190 | sigmas_cpu = sigmas.detach().cpu().numpy() 191 | for i in self.get_sigma_gen(num_sigmas): 192 | sigma = s_in * sigmas[i] 193 | denoised = denoiser( 194 | *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs 195 | ) 196 | denoised = self.guider(denoised, sigma) 197 | d = to_d(x, sigma, denoised) 198 | ds.append(d) 199 | if len(ds) > self.order: 200 | ds.pop(0) 201 | cur_order = min(i + 1, self.order) 202 | coeffs = [ 203 | linear_multistep_coeff(cur_order, sigmas_cpu, i, j) 204 | for j in range(cur_order) 205 | ] 206 | x = x + sum(coeff * d for coeff, 
d in zip(coeffs, reversed(ds))) 207 | 208 | return x 209 | 210 | 211 | class EulerEDMSampler(EDMSampler): 212 | def possible_correction_step( 213 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 214 | ): 215 | return euler_step 216 | 217 | 218 | class HeunEDMSampler(EDMSampler): 219 | def possible_correction_step( 220 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 221 | ): 222 | if torch.sum(next_sigma) < 1e-14: 223 | # Save a network evaluation if all noise levels are 0 224 | return euler_step 225 | else: 226 | denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) 227 | d_new = to_d(euler_step, next_sigma, denoised) 228 | d_prime = (d + d_new) / 2.0 229 | 230 | # apply correction if noise level is not 0 231 | x = torch.where( 232 | append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step 233 | ) 234 | return x 235 | 236 | 237 | class EulerAncestralSampler(AncestralSampler): 238 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): 239 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 240 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 241 | x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 242 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 243 | 244 | return x 245 | 246 | 247 | class DPMPP2SAncestralSampler(AncestralSampler): 248 | def get_variables(self, sigma, sigma_down): 249 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] 250 | h = t_next - t 251 | s = t + 0.5 * h 252 | return h, s, t, t_next 253 | 254 | def get_mult(self, h, s, t, t_next): 255 | mult1 = to_sigma(s) / to_sigma(t) 256 | mult2 = (-0.5 * h).expm1() 257 | mult3 = to_sigma(t_next) / to_sigma(t) 258 | mult4 = (-h).expm1() 259 | 260 | return mult1, mult2, mult3, mult4 261 | 262 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): 263 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 264 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 265 | x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 266 | 267 | if torch.sum(sigma_down) < 1e-14: 268 | # Save a network evaluation if all noise levels are 0 269 | x = x_euler 270 | else: 271 | h, s, t, t_next = self.get_variables(sigma, sigma_down) 272 | mult = [ 273 | append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) 274 | ] 275 | 276 | x2 = mult[0] * x - mult[1] * denoised 277 | denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) 278 | x_dpmpp2s = mult[2] * x - mult[3] * denoised2 279 | 280 | # apply correction if noise level is not 0 281 | x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) 282 | 283 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 284 | return x 285 | 286 | 287 | class DPMPP2MSampler(BaseDiffusionSampler): 288 | def get_variables(self, sigma, next_sigma, previous_sigma=None): 289 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] 290 | h = t_next - t 291 | 292 | if previous_sigma is not None: 293 | h_last = t - to_neg_log_sigma(previous_sigma) 294 | r = h_last / h 295 | return h, r, t, t_next 296 | else: 297 | return h, None, t, t_next 298 | 299 | def get_mult(self, h, r, t, t_next, previous_sigma): 300 | mult1 = to_sigma(t_next) / to_sigma(t) 301 | mult2 = (-h).expm1() 302 | 303 | if previous_sigma is not None: 304 | mult3 = 1 + 1 / (2 * r) 305 | mult4 = 1 / (2 * r) 306 | return mult1, mult2, mult3, mult4 307 | else: 308 | return mult1, mult2 309 | 310 | def sampler_step( 311 
| self, 312 | old_denoised, 313 | previous_sigma, 314 | sigma, 315 | next_sigma, 316 | denoiser, 317 | x, 318 | cond, 319 | uc=None, 320 | ): 321 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 322 | 323 | h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) 324 | mult = [ 325 | append_dims(mult, x.ndim) 326 | for mult in self.get_mult(h, r, t, t_next, previous_sigma) 327 | ] 328 | 329 | x_standard = mult[0] * x - mult[1] * denoised 330 | if old_denoised is None or torch.sum(next_sigma) < 1e-14: 331 | # Save a network evaluation if all noise levels are 0 or on the first step 332 | return x_standard, denoised 333 | else: 334 | denoised_d = mult[2] * denoised - mult[3] * old_denoised 335 | x_advanced = mult[0] * x - mult[1] * denoised_d 336 | 337 | # apply correction if noise level is not 0 and not first step 338 | x = torch.where( 339 | append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard 340 | ) 341 | 342 | return x, denoised 343 | 344 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 345 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 346 | x, cond, uc, num_steps 347 | ) 348 | 349 | old_denoised = None 350 | for i in self.get_sigma_gen(num_sigmas): 351 | x, old_denoised = self.sampler_step( 352 | old_denoised, 353 | None if i == 0 else s_in * sigmas[i - 1], 354 | s_in * sigmas[i], 355 | s_in * sigmas[i + 1], 356 | denoiser, 357 | x, 358 | cond, 359 | uc=uc, 360 | ) 361 | 362 | return x 363 | -------------------------------------------------------------------------------- /sgm/models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import contextmanager 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from omegaconf import ListConfig, OmegaConf 8 | from safetensors.torch import load_file as load_safetensors 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | from ..modules import UNCONDITIONAL_CONFIG 12 | from ..modules.autoencoding.temporal_ae import VideoDecoder 13 | from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER 14 | from ..modules.ema import LitEma 15 | from ..util import (default, disabled_train, get_obj_from_str, 16 | instantiate_from_config, log_txt_as_img) 17 | 18 | 19 | class DiffusionEngine(pl.LightningModule): 20 | def __init__( 21 | self, 22 | network_config, 23 | denoiser_config, 24 | first_stage_config, 25 | conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, 26 | sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, 27 | optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, 28 | scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, 29 | loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, 30 | network_wrapper: Union[None, str] = None, 31 | ckpt_path: Union[None, str] = None, 32 | use_ema: bool = False, 33 | ema_decay_rate: float = 0.9999, 34 | scale_factor: float = 1.0, 35 | disable_first_stage_autocast=False, 36 | input_key: str = "jpg", 37 | log_keys: Union[List, None] = None, 38 | no_cond_log: bool = False, 39 | compile_model: bool = False, 40 | en_and_decode_n_samples_a_time: Optional[int] = None, 41 | ): 42 | super().__init__() 43 | self.log_keys = log_keys 44 | self.input_key = input_key 45 | self.optimizer_config = default( 46 | optimizer_config, {"target": "torch.optim.AdamW"} 47 | ) 48 | model = instantiate_from_config(network_config) 49 | 
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( 50 | model, compile_model=compile_model 51 | ) 52 | 53 | self.denoiser = instantiate_from_config(denoiser_config) 54 | self.sampler = ( 55 | instantiate_from_config(sampler_config) 56 | if sampler_config is not None 57 | else None 58 | ) 59 | self.conditioner = instantiate_from_config( 60 | default(conditioner_config, UNCONDITIONAL_CONFIG) 61 | ) 62 | self.scheduler_config = scheduler_config 63 | self._init_first_stage(first_stage_config) 64 | 65 | self.loss_fn = ( 66 | instantiate_from_config(loss_fn_config) 67 | if loss_fn_config is not None 68 | else None 69 | ) 70 | 71 | self.use_ema = use_ema 72 | if self.use_ema: 73 | self.model_ema = LitEma(self.model, decay=ema_decay_rate) 74 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 75 | 76 | self.scale_factor = scale_factor 77 | self.disable_first_stage_autocast = disable_first_stage_autocast 78 | self.no_cond_log = no_cond_log 79 | 80 | if ckpt_path is not None: 81 | self.init_from_ckpt(ckpt_path) 82 | 83 | self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time 84 | 85 | def init_from_ckpt( 86 | self, 87 | path: str, 88 | ) -> None: 89 | if path.endswith("ckpt"): 90 | sd = torch.load(path, map_location="cpu")["state_dict"] 91 | elif path.endswith("safetensors"): 92 | sd = load_safetensors(path) 93 | else: 94 | raise NotImplementedError 95 | 96 | missing, unexpected = self.load_state_dict(sd, strict=False) 97 | print( 98 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 99 | ) 100 | if len(missing) > 0: 101 | print(f"Missing Keys: {missing}") 102 | if len(unexpected) > 0: 103 | print(f"Unexpected Keys: {unexpected}") 104 | 105 | def _init_first_stage(self, config): 106 | model = instantiate_from_config(config).eval() 107 | model.train = disabled_train 108 | for param in model.parameters(): 109 | param.requires_grad = False 110 | self.first_stage_model = model 111 | 112 | def get_input(self, batch): 113 | # assuming unified data format, dataloader returns a dict. 114 | # image tensors should be scaled to -1 ... 
1 and in bchw format 115 | return batch[self.input_key] 116 | 117 | @torch.no_grad() 118 | def decode_first_stage(self, z): 119 | z = 1.0 / self.scale_factor * z 120 | n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) 121 | 122 | n_rounds = math.ceil(z.shape[0] / n_samples) 123 | all_out = [] 124 | with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): 125 | for n in range(n_rounds): 126 | if isinstance(self.first_stage_model.decoder, VideoDecoder): 127 | kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} 128 | else: 129 | kwargs = {} 130 | out = self.first_stage_model.decode( 131 | z[n * n_samples : (n + 1) * n_samples], **kwargs 132 | ) 133 | all_out.append(out) 134 | out = torch.cat(all_out, dim=0) 135 | return out 136 | 137 | @torch.no_grad() 138 | def encode_first_stage(self, x): 139 | n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) 140 | n_rounds = math.ceil(x.shape[0] / n_samples) 141 | all_out = [] 142 | with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): 143 | for n in range(n_rounds): 144 | out = self.first_stage_model.encode( 145 | x[n * n_samples : (n + 1) * n_samples] 146 | ) 147 | all_out.append(out) 148 | z = torch.cat(all_out, dim=0) 149 | z = self.scale_factor * z 150 | return z 151 | 152 | def forward(self, x, batch): 153 | loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) 154 | loss_mean = loss.mean() 155 | loss_dict = {"loss": loss_mean} 156 | return loss_mean, loss_dict 157 | 158 | def shared_step(self, batch: Dict) -> Any: 159 | x = self.get_input(batch) 160 | x = self.encode_first_stage(x) 161 | batch["global_step"] = self.global_step 162 | loss, loss_dict = self(x, batch) 163 | return loss, loss_dict 164 | 165 | def training_step(self, batch, batch_idx): 166 | loss, loss_dict = self.shared_step(batch) 167 | 168 | self.log_dict( 169 | loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False 170 | ) 171 | 172 | self.log( 173 | "global_step", 174 | self.global_step, 175 | prog_bar=True, 176 | logger=True, 177 | on_step=True, 178 | on_epoch=False, 179 | ) 180 | 181 | if self.scheduler_config is not None: 182 | lr = self.optimizers().param_groups[0]["lr"] 183 | self.log( 184 | "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False 185 | ) 186 | 187 | return loss 188 | 189 | def on_train_start(self, *args, **kwargs): 190 | if self.sampler is None or self.loss_fn is None: 191 | raise ValueError("Sampler and loss function need to be set for training.") 192 | 193 | def on_train_batch_end(self, *args, **kwargs): 194 | if self.use_ema: 195 | self.model_ema(self.model) 196 | 197 | @contextmanager 198 | def ema_scope(self, context=None): 199 | if self.use_ema: 200 | self.model_ema.store(self.model.parameters()) 201 | self.model_ema.copy_to(self.model) 202 | if context is not None: 203 | print(f"{context}: Switched to EMA weights") 204 | try: 205 | yield None 206 | finally: 207 | if self.use_ema: 208 | self.model_ema.restore(self.model.parameters()) 209 | if context is not None: 210 | print(f"{context}: Restored training weights") 211 | 212 | def instantiate_optimizer_from_config(self, params, lr, cfg): 213 | return get_obj_from_str(cfg["target"])( 214 | params, lr=lr, **cfg.get("params", dict()) 215 | ) 216 | 217 | def configure_optimizers(self): 218 | lr = self.learning_rate 219 | params = list(self.model.parameters()) 220 | for embedder in self.conditioner.embedders: 221 | if embedder.is_trainable: 222 | params = params + 
list(embedder.parameters()) 223 | opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) 224 | if self.scheduler_config is not None: 225 | scheduler = instantiate_from_config(self.scheduler_config) 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), 230 | "interval": "step", 231 | "frequency": 1, 232 | } 233 | ] 234 | return [opt], scheduler 235 | return opt 236 | 237 | @torch.no_grad() 238 | def sample( 239 | self, 240 | cond: Dict, 241 | uc: Union[Dict, None] = None, 242 | batch_size: int = 16, 243 | shape: Union[None, Tuple, List] = None, 244 | **kwargs, 245 | ): 246 | randn = torch.randn(batch_size, *shape).to(self.device) 247 | 248 | denoiser = lambda input, sigma, c: self.denoiser( 249 | self.model, input, sigma, c, **kwargs 250 | ) 251 | samples = self.sampler(denoiser, randn, cond, uc=uc) 252 | return samples 253 | 254 | @torch.no_grad() 255 | def log_conditionings(self, batch: Dict, n: int) -> Dict: 256 | """ 257 | Defines heuristics to log different conditionings. 258 | These can be lists of strings (text-to-image), tensors, ints, ... 259 | """ 260 | image_h, image_w = batch[self.input_key].shape[2:] 261 | log = dict() 262 | 263 | for embedder in self.conditioner.embedders: 264 | if ( 265 | (self.log_keys is None) or (embedder.input_key in self.log_keys) 266 | ) and not self.no_cond_log: 267 | x = batch[embedder.input_key][:n] 268 | if isinstance(x, torch.Tensor): 269 | if x.dim() == 1: 270 | # class-conditional, convert integer to string 271 | x = [str(x[i].item()) for i in range(x.shape[0])] 272 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) 273 | elif x.dim() == 2: 274 | # size and crop cond and the like 275 | x = [ 276 | "x".join([str(xx) for xx in x[i].tolist()]) 277 | for i in range(x.shape[0]) 278 | ] 279 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) 280 | else: 281 | raise NotImplementedError() 282 | elif isinstance(x, (List, ListConfig)): 283 | if isinstance(x[0], str): 284 | # strings 285 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) 286 | else: 287 | raise NotImplementedError() 288 | else: 289 | raise NotImplementedError() 290 | log[embedder.input_key] = xc 291 | return log 292 | 293 | @torch.no_grad() 294 | def log_images( 295 | self, 296 | batch: Dict, 297 | N: int = 8, 298 | sample: bool = True, 299 | ucg_keys: List[str] = None, 300 | **kwargs, 301 | ) -> Dict: 302 | conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] 303 | if ucg_keys: 304 | assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( 305 | "Each defined ucg key for sampling must be in the provided conditioner input keys," 306 | f"but we have {ucg_keys} vs. 
{conditioner_input_keys}" 307 | ) 308 | else: 309 | ucg_keys = conditioner_input_keys 310 | log = dict() 311 | 312 | x = self.get_input(batch) 313 | 314 | c, uc = self.conditioner.get_unconditional_conditioning( 315 | batch, 316 | force_uc_zero_embeddings=ucg_keys 317 | if len(self.conditioner.embedders) > 0 318 | else [], 319 | ) 320 | 321 | sampling_kwargs = {} 322 | 323 | N = min(x.shape[0], N) 324 | x = x.to(self.device)[:N] 325 | log["inputs"] = x 326 | z = self.encode_first_stage(x) 327 | log["reconstructions"] = self.decode_first_stage(z) 328 | log.update(self.log_conditionings(batch, N)) 329 | 330 | for k in c: 331 | if isinstance(c[k], torch.Tensor): 332 | c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) 333 | 334 | if sample: 335 | with self.ema_scope("Plotting"): 336 | samples = self.sample( 337 | c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs 338 | ) 339 | samples = self.decode_first_stage(samples) 340 | log["samples"] = samples 341 | return log 342 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | partially adopted from 3 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 4 | and 5 | https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 6 | and 7 | https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 8 | 9 | thanks! 10 | """ 11 | 12 | import math 13 | from typing import Optional 14 | 15 | import torch 16 | import torch.nn as nn 17 | from einops import rearrange, repeat 18 | 19 | 20 | def make_beta_schedule( 21 | schedule, 22 | n_timestep, 23 | linear_start=1e-4, 24 | linear_end=2e-2, 25 | ): 26 | if schedule == "linear": 27 | betas = ( 28 | torch.linspace( 29 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 30 | ) 31 | ** 2 32 | ) 33 | return betas.numpy() 34 | 35 | 36 | def extract_into_tensor(a, t, x_shape): 37 | b, *_ = t.shape 38 | out = a.gather(-1, t) 39 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 40 | 41 | 42 | def mixed_checkpoint(func, inputs: dict, params, flag): 43 | """ 44 | Evaluate a function without caching intermediate activations, allowing for 45 | reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function 46 | borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that 47 | it also works with non-tensor inputs 48 | :param func: the function to evaluate. 49 | :param inputs: the argument dictionary to pass to `func`. 50 | :param params: a sequence of parameters `func` depends on but does not 51 | explicitly take as arguments. 52 | :param flag: if False, disable gradient checkpointing. 
53 | """ 54 | if flag: 55 | tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] 56 | tensor_inputs = [ 57 | inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) 58 | ] 59 | non_tensor_keys = [ 60 | key for key in inputs if not isinstance(inputs[key], torch.Tensor) 61 | ] 62 | non_tensor_inputs = [ 63 | inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) 64 | ] 65 | args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) 66 | return MixedCheckpointFunction.apply( 67 | func, 68 | len(tensor_inputs), 69 | len(non_tensor_inputs), 70 | tensor_keys, 71 | non_tensor_keys, 72 | *args, 73 | ) 74 | else: 75 | return func(**inputs) 76 | 77 | 78 | class MixedCheckpointFunction(torch.autograd.Function): 79 | @staticmethod 80 | def forward( 81 | ctx, 82 | run_function, 83 | length_tensors, 84 | length_non_tensors, 85 | tensor_keys, 86 | non_tensor_keys, 87 | *args, 88 | ): 89 | ctx.end_tensors = length_tensors 90 | ctx.end_non_tensors = length_tensors + length_non_tensors 91 | ctx.gpu_autocast_kwargs = { 92 | "enabled": torch.is_autocast_enabled(), 93 | "dtype": torch.get_autocast_gpu_dtype(), 94 | "cache_enabled": torch.is_autocast_cache_enabled(), 95 | } 96 | assert ( 97 | len(tensor_keys) == length_tensors 98 | and len(non_tensor_keys) == length_non_tensors 99 | ) 100 | 101 | ctx.input_tensors = { 102 | key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) 103 | } 104 | ctx.input_non_tensors = { 105 | key: val 106 | for (key, val) in zip( 107 | non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) 108 | ) 109 | } 110 | ctx.run_function = run_function 111 | ctx.input_params = list(args[ctx.end_non_tensors :]) 112 | 113 | with torch.no_grad(): 114 | output_tensors = ctx.run_function( 115 | **ctx.input_tensors, **ctx.input_non_tensors 116 | ) 117 | return output_tensors 118 | 119 | @staticmethod 120 | def backward(ctx, *output_grads): 121 | # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} 122 | ctx.input_tensors = { 123 | key: ctx.input_tensors[key].detach().requires_grad_(True) 124 | for key in ctx.input_tensors 125 | } 126 | 127 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 128 | # Fixes a bug where the first op in run_function modifies the 129 | # Tensor storage in place, which is not allowed for detach()'d 130 | # Tensors. 131 | shallow_copies = { 132 | key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) 133 | for key in ctx.input_tensors 134 | } 135 | # shallow_copies.update(additional_args) 136 | output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) 137 | input_grads = torch.autograd.grad( 138 | output_tensors, 139 | list(ctx.input_tensors.values()) + ctx.input_params, 140 | output_grads, 141 | allow_unused=True, 142 | ) 143 | del ctx.input_tensors 144 | del ctx.input_params 145 | del output_tensors 146 | return ( 147 | (None, None, None, None, None) 148 | + input_grads[: ctx.end_tensors] 149 | + (None,) * (ctx.end_non_tensors - ctx.end_tensors) 150 | + input_grads[ctx.end_tensors :] 151 | ) 152 | 153 | 154 | def checkpoint(func, inputs, params, flag): 155 | """ 156 | Evaluate a function without caching intermediate activations, allowing for 157 | reduced memory at the expense of extra compute in the backward pass. 158 | :param func: the function to evaluate. 159 | :param inputs: the argument sequence to pass to `func`. 
160 | :param params: a sequence of parameters `func` depends on but does not 161 | explicitly take as arguments. 162 | :param flag: if False, disable gradient checkpointing. 163 | """ 164 | if flag: 165 | args = tuple(inputs) + tuple(params) 166 | return CheckpointFunction.apply(func, len(inputs), *args) 167 | else: 168 | return func(*inputs) 169 | 170 | 171 | class CheckpointFunction(torch.autograd.Function): 172 | @staticmethod 173 | def forward(ctx, run_function, length, *args): 174 | ctx.run_function = run_function 175 | ctx.input_tensors = list(args[:length]) 176 | ctx.input_params = list(args[length:]) 177 | ctx.gpu_autocast_kwargs = { 178 | "enabled": torch.is_autocast_enabled(), 179 | "dtype": torch.get_autocast_gpu_dtype(), 180 | "cache_enabled": torch.is_autocast_cache_enabled(), 181 | } 182 | with torch.no_grad(): 183 | output_tensors = ctx.run_function(*ctx.input_tensors) 184 | return output_tensors 185 | 186 | @staticmethod 187 | def backward(ctx, *output_grads): 188 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 189 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 190 | # Fixes a bug where the first op in run_function modifies the 191 | # Tensor storage in place, which is not allowed for detach()'d 192 | # Tensors. 193 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 194 | output_tensors = ctx.run_function(*shallow_copies) 195 | input_grads = torch.autograd.grad( 196 | output_tensors, 197 | ctx.input_tensors + ctx.input_params, 198 | output_grads, 199 | allow_unused=True, 200 | ) 201 | del ctx.input_tensors 202 | del ctx.input_params 203 | del output_tensors 204 | return (None, None) + input_grads 205 | 206 | 207 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 208 | """ 209 | Create sinusoidal timestep embeddings. 210 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 211 | These may be fractional. 212 | :param dim: the dimension of the output. 213 | :param max_period: controls the minimum frequency of the embeddings. 214 | :return: an [N x dim] Tensor of positional embeddings. 215 | """ 216 | if not repeat_only: 217 | half = dim // 2 218 | freqs = torch.exp( 219 | -math.log(max_period) 220 | * torch.arange(start=0, end=half, dtype=torch.float32) 221 | / half 222 | ).to(device=timesteps.device) 223 | args = timesteps[:, None].float() * freqs[None] 224 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 225 | if dim % 2: 226 | embedding = torch.cat( 227 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 228 | ) 229 | else: 230 | embedding = repeat(timesteps, "b -> b d", d=dim) 231 | return embedding 232 | 233 | 234 | def zero_module(module): 235 | """ 236 | Zero out the parameters of a module and return it. 237 | """ 238 | for p in module.parameters(): 239 | p.detach().zero_() 240 | return module 241 | 242 | 243 | def scale_module(module, scale): 244 | """ 245 | Scale the parameters of a module and return it. 246 | """ 247 | for p in module.parameters(): 248 | p.detach().mul_(scale) 249 | return module 250 | 251 | 252 | def mean_flat(tensor): 253 | """ 254 | Take the mean over all non-batch dimensions. 255 | """ 256 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 257 | 258 | 259 | def normalization(channels): 260 | """ 261 | Make a standard normalization layer. 262 | :param channels: number of input channels. 263 | :return: an nn.Module for normalization. 
264 | """ 265 | return GroupNorm32(32, channels) 266 | 267 | 268 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 269 | class SiLU(nn.Module): 270 | def forward(self, x): 271 | return x * torch.sigmoid(x) 272 | 273 | 274 | class GroupNorm32(nn.GroupNorm): 275 | def forward(self, x): 276 | return super().forward(x.float()).type(x.dtype) 277 | 278 | 279 | def conv_nd(dims, *args, **kwargs): 280 | """ 281 | Create a 1D, 2D, or 3D convolution module. 282 | """ 283 | if dims == 1: 284 | return nn.Conv1d(*args, **kwargs) 285 | elif dims == 2: 286 | return nn.Conv2d(*args, **kwargs) 287 | elif dims == 3: 288 | return nn.Conv3d(*args, **kwargs) 289 | raise ValueError(f"unsupported dimensions: {dims}") 290 | 291 | 292 | def linear(*args, **kwargs): 293 | """ 294 | Create a linear module. 295 | """ 296 | return nn.Linear(*args, **kwargs) 297 | 298 | 299 | def avg_pool_nd(dims, *args, **kwargs): 300 | """ 301 | Create a 1D, 2D, or 3D average pooling module. 302 | """ 303 | if dims == 1: 304 | return nn.AvgPool1d(*args, **kwargs) 305 | elif dims == 2: 306 | return nn.AvgPool2d(*args, **kwargs) 307 | elif dims == 3: 308 | return nn.AvgPool3d(*args, **kwargs) 309 | raise ValueError(f"unsupported dimensions: {dims}") 310 | 311 | 312 | class AlphaBlender(nn.Module): 313 | strategies = ["learned", "fixed", "learned_with_images"] 314 | 315 | def __init__( 316 | self, 317 | alpha: float, 318 | merge_strategy: str = "learned_with_images", 319 | rearrange_pattern: str = "b t -> (b t) 1 1", 320 | ): 321 | super().__init__() 322 | self.merge_strategy = merge_strategy 323 | self.rearrange_pattern = rearrange_pattern 324 | 325 | assert ( 326 | merge_strategy in self.strategies 327 | ), f"merge_strategy needs to be in {self.strategies}" 328 | 329 | if self.merge_strategy == "fixed": 330 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 331 | elif ( 332 | self.merge_strategy == "learned" 333 | or self.merge_strategy == "learned_with_images" 334 | ): 335 | self.register_parameter( 336 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 337 | ) 338 | else: 339 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 340 | 341 | def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: 342 | if self.merge_strategy == "fixed": 343 | alpha = self.mix_factor 344 | elif self.merge_strategy == "learned": 345 | alpha = torch.sigmoid(self.mix_factor) 346 | elif self.merge_strategy == "learned_with_images": 347 | assert image_only_indicator is not None, "need image_only_indicator ..." 348 | alpha = torch.where( 349 | image_only_indicator.bool(), 350 | torch.ones(1, 1, device=image_only_indicator.device), 351 | rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), 352 | ) 353 | alpha = rearrange(alpha, self.rearrange_pattern) 354 | else: 355 | raise NotImplementedError 356 | return alpha 357 | 358 | def forward( 359 | self, 360 | x_spatial: torch.Tensor, 361 | x_temporal: torch.Tensor, 362 | image_only_indicator: Optional[torch.Tensor] = None, 363 | ) -> torch.Tensor: 364 | alpha = self.get_alpha(image_only_indicator) 365 | x = ( 366 | alpha.to(x_spatial.dtype) * x_spatial 367 | + (1.0 - alpha).to(x_spatial.dtype) * x_temporal 368 | ) 369 | return x 370 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/quantize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import abstractmethod 3 | from typing import Dict, Iterator, Literal, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from torch import einsum 11 | 12 | from .base import AbstractRegularizer, measure_perplexity 13 | 14 | logpy = logging.getLogger(__name__) 15 | 16 | 17 | class AbstractQuantizer(AbstractRegularizer): 18 | def __init__(self): 19 | super().__init__() 20 | # Define these in your init 21 | # shape (N,) 22 | self.used: Optional[torch.Tensor] 23 | self.re_embed: int 24 | self.unknown_index: Union[Literal["random"], int] 25 | 26 | def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: 27 | assert self.used is not None, "You need to define used indices for remap" 28 | ishape = inds.shape 29 | assert len(ishape) > 1 30 | inds = inds.reshape(ishape[0], -1) 31 | used = self.used.to(inds) 32 | match = (inds[:, :, None] == used[None, None, ...]).long() 33 | new = match.argmax(-1) 34 | unknown = match.sum(2) < 1 35 | if self.unknown_index == "random": 36 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( 37 | device=new.device 38 | ) 39 | else: 40 | new[unknown] = self.unknown_index 41 | return new.reshape(ishape) 42 | 43 | def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: 44 | assert self.used is not None, "You need to define used indices for remap" 45 | ishape = inds.shape 46 | assert len(ishape) > 1 47 | inds = inds.reshape(ishape[0], -1) 48 | used = self.used.to(inds) 49 | if self.re_embed > self.used.shape[0]: # extra token 50 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 51 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 52 | return back.reshape(ishape) 53 | 54 | @abstractmethod 55 | def get_codebook_entry( 56 | self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None 57 | ) -> torch.Tensor: 58 | raise NotImplementedError() 59 | 60 | def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: 61 | yield from self.parameters() 62 | 63 | 64 | class GumbelQuantizer(AbstractQuantizer): 65 | """ 66 | credit to @karpathy: 67 | https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 68 | Gumbel Softmax trick quantizer 69 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 70 | https://arxiv.org/abs/1611.01144 71 | """ 72 | 73 | def __init__( 74 | self, 75 | num_hiddens: int, 76 | embedding_dim: int, 77 | n_embed: int, 78 | straight_through: bool = True, 79 | kl_weight: float = 5e-4, 80 | temp_init: float = 1.0, 81 | remap: Optional[str] = None, 82 | unknown_index: str = "random", 83 | loss_key: str = "loss/vq", 84 | ) -> None: 85 | super().__init__() 86 | 87 | self.loss_key = loss_key 88 | self.embedding_dim = embedding_dim 89 | self.n_embed = n_embed 90 | 91 | self.straight_through = straight_through 92 | self.temperature = temp_init 93 | self.kl_weight = kl_weight 94 | 95 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 96 | self.embed = nn.Embedding(n_embed, embedding_dim) 97 | 98 | self.remap = remap 99 | if self.remap is not None: 100 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 101 | self.re_embed = self.used.shape[0] 102 | else: 103 | self.used = None 104 | self.re_embed = n_embed 105 | if unknown_index == "extra": 106 | self.unknown_index = self.re_embed 107 | self.re_embed = self.re_embed + 1 108 | else: 109 | assert unknown_index == "random" or isinstance( 110 | unknown_index, int 111 | ), "unknown index needs to be 'random', 'extra' or any integer" 112 | self.unknown_index = unknown_index # "random" or "extra" or integer 113 | if self.remap is not None: 114 | logpy.info( 115 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 116 | f"Using {self.unknown_index} for unknown indices." 117 | ) 118 | 119 | def forward( 120 | self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False 121 | ) -> Tuple[torch.Tensor, Dict]: 122 | # force hard = True when we are in eval mode, as we must quantize. 123 | # actually, always true seems to work 124 | hard = self.straight_through if self.training else True 125 | temp = self.temperature if temp is None else temp 126 | out_dict = {} 127 | logits = self.proj(z) 128 | if self.remap is not None: 129 | # continue only with used logits 130 | full_zeros = torch.zeros_like(logits) 131 | logits = logits[:, self.used, ...] 132 | 133 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 134 | if self.remap is not None: 135 | # go back to all entries but unused set to zero 136 | full_zeros[:, self.used, ...] 
= soft_one_hot 137 | soft_one_hot = full_zeros 138 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) 139 | 140 | # + kl divergence to the prior loss 141 | qy = F.softmax(logits, dim=1) 142 | diff = ( 143 | self.kl_weight 144 | * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 145 | ) 146 | out_dict[self.loss_key] = diff 147 | 148 | ind = soft_one_hot.argmax(dim=1) 149 | out_dict["indices"] = ind 150 | if self.remap is not None: 151 | ind = self.remap_to_used(ind) 152 | 153 | if return_logits: 154 | out_dict["logits"] = logits 155 | 156 | return z_q, out_dict 157 | 158 | def get_codebook_entry(self, indices, shape): 159 | # TODO: shape not yet optional 160 | b, h, w, c = shape 161 | assert b * h * w == indices.shape[0] 162 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) 163 | if self.remap is not None: 164 | indices = self.unmap_to_all(indices) 165 | one_hot = ( 166 | F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 167 | ) 168 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) 169 | return z_q 170 | 171 | 172 | class VectorQuantizer(AbstractQuantizer): 173 | """ 174 | ____________________________________________ 175 | Discretization bottleneck part of the VQ-VAE. 176 | Inputs: 177 | - n_e : number of embeddings 178 | - e_dim : dimension of embedding 179 | - beta : commitment cost used in loss term, 180 | beta * ||z_e(x)-sg[e]||^2 181 | _____________________________________________ 182 | """ 183 | 184 | def __init__( 185 | self, 186 | n_e: int, 187 | e_dim: int, 188 | beta: float = 0.25, 189 | remap: Optional[str] = None, 190 | unknown_index: str = "random", 191 | sane_index_shape: bool = False, 192 | log_perplexity: bool = False, 193 | embedding_weight_norm: bool = False, 194 | loss_key: str = "loss/vq", 195 | ): 196 | super().__init__() 197 | self.n_e = n_e 198 | self.e_dim = e_dim 199 | self.beta = beta 200 | self.loss_key = loss_key 201 | 202 | if not embedding_weight_norm: 203 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 204 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 205 | else: 206 | self.embedding = torch.nn.utils.weight_norm( 207 | nn.Embedding(self.n_e, self.e_dim), dim=1 208 | ) 209 | 210 | self.remap = remap 211 | if self.remap is not None: 212 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 213 | self.re_embed = self.used.shape[0] 214 | else: 215 | self.used = None 216 | self.re_embed = n_e 217 | if unknown_index == "extra": 218 | self.unknown_index = self.re_embed 219 | self.re_embed = self.re_embed + 1 220 | else: 221 | assert unknown_index == "random" or isinstance( 222 | unknown_index, int 223 | ), "unknown index needs to be 'random', 'extra' or any integer" 224 | self.unknown_index = unknown_index # "random" or "extra" or integer 225 | if self.remap is not None: 226 | logpy.info( 227 | f"Remapping {self.n_e} indices to {self.re_embed} indices. " 228 | f"Using {self.unknown_index} for unknown indices." 
229 | ) 230 | 231 | self.sane_index_shape = sane_index_shape 232 | self.log_perplexity = log_perplexity 233 | 234 | def forward( 235 | self, 236 | z: torch.Tensor, 237 | ) -> Tuple[torch.Tensor, Dict]: 238 | do_reshape = z.ndim == 4 239 | if do_reshape: 240 | # # reshape z -> (batch, height, width, channel) and flatten 241 | z = rearrange(z, "b c h w -> b h w c").contiguous() 242 | 243 | else: 244 | assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" 245 | z = z.contiguous() 246 | 247 | z_flattened = z.view(-1, self.e_dim) 248 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 249 | 250 | d = ( 251 | torch.sum(z_flattened**2, dim=1, keepdim=True) 252 | + torch.sum(self.embedding.weight**2, dim=1) 253 | - 2 254 | * torch.einsum( 255 | "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") 256 | ) 257 | ) 258 | 259 | min_encoding_indices = torch.argmin(d, dim=1) 260 | z_q = self.embedding(min_encoding_indices).view(z.shape) 261 | loss_dict = {} 262 | if self.log_perplexity: 263 | perplexity, cluster_usage = measure_perplexity( 264 | min_encoding_indices.detach(), self.n_e 265 | ) 266 | loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) 267 | 268 | # compute loss for embedding 269 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( 270 | (z_q - z.detach()) ** 2 271 | ) 272 | loss_dict[self.loss_key] = loss 273 | 274 | # preserve gradients 275 | z_q = z + (z_q - z).detach() 276 | 277 | # reshape back to match original input shape 278 | if do_reshape: 279 | z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() 280 | 281 | if self.remap is not None: 282 | min_encoding_indices = min_encoding_indices.reshape( 283 | z.shape[0], -1 284 | ) # add batch axis 285 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 286 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 287 | 288 | if self.sane_index_shape: 289 | if do_reshape: 290 | min_encoding_indices = min_encoding_indices.reshape( 291 | z_q.shape[0], z_q.shape[2], z_q.shape[3] 292 | ) 293 | else: 294 | min_encoding_indices = rearrange( 295 | min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] 296 | ) 297 | 298 | loss_dict["min_encoding_indices"] = min_encoding_indices 299 | 300 | return z_q, loss_dict 301 | 302 | def get_codebook_entry( 303 | self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None 304 | ) -> torch.Tensor: 305 | # shape specifying (batch, height, width, channel) 306 | if self.remap is not None: 307 | assert shape is not None, "Need to give shape for remap" 308 | indices = indices.reshape(shape[0], -1) # add batch axis 309 | indices = self.unmap_to_all(indices) 310 | indices = indices.reshape(-1) # flatten again 311 | 312 | # get quantized latent vectors 313 | z_q = self.embedding(indices) 314 | 315 | if shape is not None: 316 | z_q = z_q.view(shape) 317 | # reshape back to match original input shape 318 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 319 | 320 | return z_q 321 | 322 | 323 | class EmbeddingEMA(nn.Module): 324 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): 325 | super().__init__() 326 | self.decay = decay 327 | self.eps = eps 328 | weight = torch.randn(num_tokens, codebook_dim) 329 | self.weight = nn.Parameter(weight, requires_grad=False) 330 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) 331 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) 332 | self.update = True 333 | 334 | def 
forward(self, embed_id): 335 | return F.embedding(embed_id, self.weight) 336 | 337 | def cluster_size_ema_update(self, new_cluster_size): 338 | self.cluster_size.data.mul_(self.decay).add_( 339 | new_cluster_size, alpha=1 - self.decay 340 | ) 341 | 342 | def embed_avg_ema_update(self, new_embed_avg): 343 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 344 | 345 | def weight_update(self, num_tokens): 346 | n = self.cluster_size.sum() 347 | smoothed_cluster_size = ( 348 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 349 | ) 350 | # normalize embedding average with smoothed cluster size 351 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 352 | self.weight.data.copy_(embed_normalized) 353 | 354 | 355 | class EMAVectorQuantizer(AbstractQuantizer): 356 | def __init__( 357 | self, 358 | n_embed: int, 359 | embedding_dim: int, 360 | beta: float, 361 | decay: float = 0.99, 362 | eps: float = 1e-5, 363 | remap: Optional[str] = None, 364 | unknown_index: str = "random", 365 | loss_key: str = "loss/vq", 366 | ): 367 | super().__init__() 368 | self.codebook_dim = embedding_dim 369 | self.num_tokens = n_embed 370 | self.beta = beta 371 | self.loss_key = loss_key 372 | 373 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) 374 | 375 | self.remap = remap 376 | if self.remap is not None: 377 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 378 | self.re_embed = self.used.shape[0] 379 | else: 380 | self.used = None 381 | self.re_embed = n_embed 382 | if unknown_index == "extra": 383 | self.unknown_index = self.re_embed 384 | self.re_embed = self.re_embed + 1 385 | else: 386 | assert unknown_index == "random" or isinstance( 387 | unknown_index, int 388 | ), "unknown index needs to be 'random', 'extra' or any integer" 389 | self.unknown_index = unknown_index # "random" or "extra" or integer 390 | if self.remap is not None: 391 | logpy.info( 392 | f"Remapping {self.num_tokens} indices to {self.re_embed} indices. " 393 | f"Using {self.unknown_index} for unknown indices." 
394 | ) 395 | 396 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: 397 | # reshape z -> (batch, height, width, channel) and flatten 398 | # z, 'b c h w -> b h w c' 399 | z = rearrange(z, "b c h w -> b h w c") 400 | z_flattened = z.reshape(-1, self.codebook_dim) 401 | 402 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 403 | d = ( 404 | z_flattened.pow(2).sum(dim=1, keepdim=True) 405 | + self.embedding.weight.pow(2).sum(dim=1) 406 | - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) 407 | ) # 'n d -> d n' 408 | 409 | encoding_indices = torch.argmin(d, dim=1) 410 | 411 | z_q = self.embedding(encoding_indices).view(z.shape) 412 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 413 | avg_probs = torch.mean(encodings, dim=0) 414 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 415 | 416 | if self.training and self.embedding.update: 417 | # EMA cluster size 418 | encodings_sum = encodings.sum(0) 419 | self.embedding.cluster_size_ema_update(encodings_sum) 420 | # EMA embedding average 421 | embed_sum = encodings.transpose(0, 1) @ z_flattened 422 | self.embedding.embed_avg_ema_update(embed_sum) 423 | # normalize embed_avg and update weight 424 | self.embedding.weight_update(self.num_tokens) 425 | 426 | # compute loss for embedding 427 | loss = self.beta * F.mse_loss(z_q.detach(), z) 428 | 429 | # preserve gradients 430 | z_q = z + (z_q - z).detach() 431 | 432 | # reshape back to match original input shape 433 | # z_q, 'b h w c -> b c h w' 434 | z_q = rearrange(z_q, "b h w c -> b c h w") 435 | 436 | out_dict = { 437 | self.loss_key: loss, 438 | "encodings": encodings, 439 | "encoding_indices": encoding_indices, 440 | "perplexity": perplexity, 441 | } 442 | 443 | return z_q, out_dict 444 | 445 | 446 | class VectorQuantizerWithInputProjection(VectorQuantizer): 447 | def __init__( 448 | self, 449 | input_dim: int, 450 | n_codes: int, 451 | codebook_dim: int, 452 | beta: float = 1.0, 453 | output_dim: Optional[int] = None, 454 | **kwargs, 455 | ): 456 | super().__init__(n_codes, codebook_dim, beta, **kwargs) 457 | self.proj_in = nn.Linear(input_dim, codebook_dim) 458 | self.output_dim = output_dim 459 | if output_dim is not None: 460 | self.proj_out = nn.Linear(codebook_dim, output_dim) 461 | else: 462 | self.proj_out = nn.Identity() 463 | 464 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: 465 | rearr = False 466 | in_shape = z.shape 467 | 468 | if z.ndim > 3: 469 | rearr = self.output_dim is not None 470 | z = rearrange(z, "b c ... -> b (...) c") 471 | z = self.proj_in(z) 472 | z_q, loss_dict = super().forward(z) 473 | 474 | z_q = self.proj_out(z_q) 475 | if rearr: 476 | if len(in_shape) == 4: 477 | z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) 478 | elif len(in_shape) == 5: 479 | z_q = rearrange( 480 | z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] 481 | ) 482 | else: 483 | raise NotImplementedError( 484 | f"rearranging not available for {len(in_shape)}-dimensional input." 485 | ) 486 | 487 | return z_q, loss_dict 488 | --------------------------------------------------------------------------------
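The quantizers defined in quantize.py are self-contained, so a short usage sketch may be helpful. The snippet below is not part of the repository; it is a minimal example, assuming only that the repository root is on PYTHONPATH (so `sgm` is importable) and that PyTorch is installed. It runs `VectorQuantizer` on a dummy latent to show the straight-through output, the default `"loss/vq"` loss entry, and the effect of `sane_index_shape`.

import torch

from sgm.modules.autoencoding.regularizers.quantize import VectorQuantizer

# Codebook with 512 entries of dimension 4; e_dim must match the latent's channel dim.
quantizer = VectorQuantizer(n_e=512, e_dim=4, beta=0.25, sane_index_shape=True)

z = torch.randn(2, 4, 8, 8, requires_grad=True)  # (b, c, h, w) latent
z_q, out = quantizer(z)

assert z_q.shape == z.shape                # straight-through: z_q = z + (z_q - z).detach()
print(out["loss/vq"])                      # commitment (beta-weighted) + codebook loss
print(out["min_encoding_indices"].shape)   # torch.Size([2, 8, 8]) with sane_index_shape=True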