├── core ├── datasets │ ├── __init__.py │ ├── .DS_Store │ ├── augmentation.py │ ├── turbulence │ │ ├── torch_dataset.py │ │ └── data_module.py │ └── normalizer.py ├── ldm │ ├── data │ │ ├── __init__.py │ │ └── util.py │ ├── models │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ └── sampler.py │ │ │ └── sampling_util.py │ ├── modules │ │ ├── encoders │ │ │ └── __init__.py │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── midas │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── dpt_depth.py │ │ │ │ └── midas_net_custom.py │ │ │ └── utils.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── upscaling.py │ │ │ └── warping_unet.py │ │ ├── losses │ │ │ └── __init__.py │ │ ├── image_degradation │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── __init__.py │ │ └── ema.py │ └── .DS_Store ├── metrics │ └── __init__.py ├── sgm │ ├── modules │ │ ├── encoders │ │ │ └── __init__.py │ │ ├── autoencoding │ │ │ ├── __init__.py │ │ │ ├── lpips │ │ │ │ ├── __init__.py │ │ │ │ ├── loss │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ └── lpips.py │ │ │ │ ├── model │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model.py │ │ │ │ │ └── LICENSE │ │ │ │ ├── vqperceptual.py │ │ │ │ └── util.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ └── lpips.py │ │ │ └── regularizers │ │ │ │ ├── __init__.py │ │ │ │ └── base.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── loss_weighting.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── wrappers.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── discretizer.py │ │ │ ├── denoiser.py │ │ │ └── loss.py │ │ ├── __init__.py │ │ └── ema.py │ ├── data │ │ ├── __init__.py │ │ ├── cifar10.py │ │ ├── mnist.py │ │ └── dataset.py │ ├── .DS_Store │ ├── models │ │ └── __init__.py │ ├── __init__.py │ └── lr_scheduler.py ├── .DS_Store ├── dydiff │ ├── .DS_Store │ └── datasets │ │ ├── bair.py │ │ └── turbulence.py ├── logger │ ├── .DS_Store │ ├── visualization_turbulence.py │ └── logger_turbulence.py ├── models │ ├── .DS_Store │ └── turbulence │ │ ├── video_dit_B_baseline_1st_mask_svd_lr1e-4.yaml │ │ ├── video_dit_B_ratio05_1st_mask_svd_lr1e-4.yaml │ │ ├── video_dit_S_baseline_1st_mask_svd_lr1e-4.yaml │ │ ├── video_dit_S_ratio05_1st_mask_svd_lr1e-4.yaml │ │ ├── baseline_1st_mask_svd_lr1e-4.yaml │ │ ├── dfdiff_ema_cosine_ratio01_1st_mask_svd_lr1e-4.yaml │ │ ├── dfdiff_ema_cosine_ratio09_1st_mask_svd_lr1e-4.yaml │ │ ├── dydiff_ema_cosine_ratio05_1st_mask_svd_lr1e-4.yaml │ │ └── dydiff_ema_cosine_ratio10_1st_mask_svd_lr1e-4.yaml ├── taming │ ├── .DS_Store │ ├── modules │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── segmentation.py │ │ │ └── lpips.py │ │ ├── autoencoder │ │ │ └── lpips │ │ │ │ └── vgg.pth │ │ ├── misc │ │ │ └── coord.py │ │ ├── discriminator │ │ │ └── model.py │ │ └── util.py │ ├── models │ │ └── dummy_cond_stage.py │ ├── data │ │ ├── custom.py │ │ ├── helper_types.py │ │ ├── base.py │ │ ├── conditional_builder │ │ │ ├── objects_bbox.py │ │ │ └── utils.py │ │ ├── sflckr.py │ │ ├── image_transforms.py │ │ ├── faceshq.py │ │ └── ade20k.py │ ├── lr_scheduler.py │ └── util.py └── train_turbulence_dydiff.py ├── environment.yaml ├── README.md └── .gitignore /core/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /core/ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /core/sgm/data/__init__.py: -------------------------------------------------------------------------------- 
1 | from .dataset import StableDataModuleFromConfig 2 | -------------------------------------------------------------------------------- /core/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /core/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/.DS_Store -------------------------------------------------------------------------------- /core/ldm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/ldm/.DS_Store -------------------------------------------------------------------------------- /core/sgm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/sgm/.DS_Store -------------------------------------------------------------------------------- /core/dydiff/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/dydiff/.DS_Store -------------------------------------------------------------------------------- /core/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /core/logger/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/logger/.DS_Store -------------------------------------------------------------------------------- /core/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/models/.DS_Store -------------------------------------------------------------------------------- /core/taming/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/taming/.DS_Store -------------------------------------------------------------------------------- /core/datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/datasets/.DS_Store -------------------------------------------------------------------------------- /core/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /core/sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /core/taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /core/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/dynamical-diffusion/HEAD/core/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /core/sgm/__init__.py: -------------------------------------------------------------------------------- 1 | # from .models import AutoencodingEngine, DiffusionEngine 2 | # from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /core/sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | -------------------------------------------------------------------------------- /core/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /core/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /core/logger/visualization_turbulence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def save_plots(fig_names, outputs, save_root): 7 | os.makedirs(save_root, exist_ok=True) 8 | h, w = outputs.shape[-2:] 9 | outputs = outputs.reshape(-1, 2, h, w) 10 | 11 | for (fig_name, output) in zip(fig_names, outputs): 12 | x, y = output 13 | plt.imshow(x, cmap='coolwarm', vmin=-1, vmax=1) 14 | plt.savefig(os.path.join(save_root, 'x_' + fig_name)) 15 | plt.close() 16 | plt.imshow(y, cmap='coolwarm', vmin=-1, vmax=1) 17 | plt.savefig(os.path.join(save_root, 'y_' + fig_name)) 18 | plt.close() 19 | 20 | -------------------------------------------------------------------------------- /core/ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 
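# Added note (not in the original file): self.transform below is the MiDaS preprocessing transform for the given model_type; __call__ applies it to the numpy version of the [-1, 1] 'jpg' tensor.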
9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /core/datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import random 3 | from torch import nn 4 | import torchvision.transforms.functional as TF 5 | 6 | 7 | class TransformsFixRotation(nn.Module): 8 | r""" 9 | Rotate by one of the given angles. 10 | 11 | Example: `rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])` 12 | """ 13 | 14 | def __init__(self, angles): 15 | super(TransformsFixRotation, self).__init__() 16 | if not isinstance(angles, Sequence): 17 | angles = [angles, ] 18 | self.angles = angles 19 | 20 | def forward(self, x): 21 | angle = random.choice(self.angles) 22 | return TF.rotate(x, angle) 23 | 24 | def __repr__(self) -> str: 25 | return f"{self.__class__.__name__}(angles={self.angles})" 26 | -------------------------------------------------------------------------------- /core/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /core/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | 
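# Added note (not in the original file): this is a conda environment spec; a typical setup is `conda env create -f environment.yaml`.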
name: dydiff 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.16.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 36 | - xformers==0.0.21 -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /core/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 
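# Added note (not in the original file): c is expected in [0, 1]; it is downsampled by down_factor and quantized into n_embed integer bins so the return value mimics a VQ encoder's (quantized, loss, (_, _, indices)) tuple.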
13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dynamical-diffusion 2 | Code release for "Dynamical Diffusion: Learning Temporal Dynamics with Diffusion Models" (ICLR 2025) 3 | 4 | ## Usage 5 | 6 | ### Training 7 | 8 | To train a model (e.g., DyDiff for Turbulence), run: 9 | 10 | ```bash 11 | cd core; python train_turbulence_dydiff.py --config_file models/turbulence/dydiff_ema_cosine_ratio05_1st_mask_svd_lr1e-4.yaml 12 | ``` 13 | 14 | **Note:** Before running, update the config file with the data `root` and the `ckpt_path` of the pretrained VAE model. 15 | 16 | ### Sampling 17 | 18 | To sample with a trained model, run: 19 | 20 | ```bash 21 | cd core; python train_turbulence_dydiff.py --config_file models/turbulence/dydiff_ema_cosine_ratio05_1st_mask_svd_lr1e-4.yaml --resume ${model_ckpt} --test 22 | ``` 23 | 24 | This will generate samples in the `logs` directory. 25 | 26 | ### Evaluation 27 | 28 | After generating samples, evaluate them using the following command: 29 | 30 | ```bash 31 | python core/evaluation/evaluate_turbulence.py --model_output_root logs/turbulence/dydiff_ema_cosine_ratio05_1st_mask_svd_lr1e-4/output_for_evaluation --i3d_model_path ${pretrained_i3d_model} 32 | ``` 33 | -------------------------------------------------------------------------------- /core/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 8 | if order - 1 > i: 9 | raise ValueError(f"Order 
{order} too high for step {i}") 10 | 11 | def fn(tau): 12 | prod = 1.0 13 | for k in range(order): 14 | if j == k: 15 | continue 16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 17 | return prod 18 | 19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 20 | 21 | 22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 23 | if not eta: 24 | return sigma_to, 0.0 25 | sigma_up = torch.minimum( 26 | sigma_to, 27 | eta 28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 29 | ) 30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 31 | return sigma_down, sigma_up 32 | 33 | 34 | def to_d(x, sigma, denoised): 35 | return (x - denoised) / append_dims(sigma, x.ndim) 36 | 37 | 38 | def to_neg_log_sigma(sigma): 39 | return sigma.log().neg() 40 | 41 | 42 | def to_sigma(neg_log_sigma): 43 | return neg_log_sigma.neg().exp() 44 | -------------------------------------------------------------------------------- /core/datasets/turbulence/torch_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from time import time 5 | import os 6 | 7 | 8 | class Turbulence(Dataset): 9 | def __init__(self, root, filename, total_length, n_subsample=None): 10 | self.root = os.path.join(root, filename) 11 | self.total_length = total_length 12 | self.data_cube = np.load(self.root).astype(np.float32) 13 | n, l, c, h, w = self.data_cube.shape 14 | self.shape = (self.total_length, c, h, w) 15 | self.num_samples = n 16 | if n_subsample is not None: 17 | self.num_samples = min(self.num_samples, n_subsample) 18 | 19 | def __getitem__(self, index): 20 | # self.cal_stair(self.data_cube[index, :self.total_length].reshape(self.shape)) 21 | # data = torch.from_numpy(self.data_cube[index, :self.total_length].reshape(self.shape)) / 4 22 | # print("channel 0:", data[:, 0].min(), data[:, 0].max(), data[:, 0].mean(), data[:, 0].std()) 23 | # print("channel 1:", data[:, 1].min(), data[:, 1].max(), data[:, 1].mean(), data[:, 1].std()) 24 | return torch.from_numpy(self.data_cube[index, :self.total_length].reshape(self.shape)) / 4 25 | 26 | def __len__(self): 27 | return self.num_samples 28 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | 33 | 34 | class ZeroSampler: 35 | def __call__( 36 | self, n_samples: int, rand: Optional[torch.Tensor] = None 37 | ) -> torch.Tensor: 38 | return 
torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5 39 | -------------------------------------------------------------------------------- /core/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /core/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if 
(version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | if "cond_view" in c: 29 | return self.diffusion_model( 30 | x, 31 | timesteps=t, 32 | context=c.get("crossattn", None), 33 | y=c.get("vector", None), 34 | cond_view=c.get("cond_view", None), 35 | cond_motion=c.get("cond_motion", None), 36 | **kwargs, 37 | ) 38 | else: 39 | return self.diffusion_model( 40 | x, 41 | timesteps=t, 42 | context=c.get("crossattn", None), 43 | y=c.get("vector", None), 44 | **kwargs, 45 | ) 46 | -------------------------------------------------------------------------------- /core/datasets/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class NullNormalizer: 6 | """ 7 | Identity normalizer. 8 | """ 9 | @staticmethod 10 | def normalize(x: torch.Tensor): 11 | return x 12 | 13 | @staticmethod 14 | def denormalize(v: torch.Tensor): 15 | return v 16 | 17 | 18 | class PrecipitationNormalizer: 19 | """ 20 | First transform input x into log_10(1+x), then normalize into [vmin, vmax]. 21 | """ 22 | xmin = 0. 23 | xmax = 128. 24 | log_xmin = np.log10(1 + xmin) 25 | log_xmax = np.log10(1 + xmax) 26 | vmin = -1. 27 | vmax = 1. 28 | 29 | @classmethod 30 | def normalize(cls, x: torch.Tensor, affine: bool = True): 31 | x = torch.clamp(x, cls.xmin, cls.xmax) 32 | log_x = torch.log10(x + 1) 33 | if affine: 34 | return (log_x - cls.log_xmin) / (cls.log_xmax - cls.log_xmin) * (cls.vmax - cls.vmin) + cls.vmin 35 | else: 36 | return log_x 37 | 38 | @classmethod 39 | def denormalize(cls, v: torch.Tensor, affine: bool = True): 40 | v = torch.clamp(v, cls.vmin, cls.vmax) 41 | if affine: 42 | log_x = (v - cls.vmin) / (cls.vmax - cls.vmin) * (cls.log_xmax - cls.log_xmin) + cls.log_xmin 43 | else: 44 | log_x = v 45 | return 10 ** log_x - 1 46 | 47 | 48 | class RGBNormalizer: 49 | """ 50 | Normalize RGB images from [0, 1] to [-1, 1]. 51 | """ 52 | x_min = 0. 53 | x_max = 1. 54 | v_min = -1. 55 | v_max = 1. 
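# Added note (not in the original file): normalize maps [0, 1] pixels to [-1, 1]; denormalize inverts it, clamping to the valid range first.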
56 | 57 | @staticmethod 58 | def normalize(x: torch.Tensor): 59 | x = torch.clamp(x, RGBNormalizer.x_min, RGBNormalizer.x_max) 60 | return x * 2 - 1 61 | 62 | @staticmethod 63 | def denormalize(v: torch.Tensor): 64 | v = torch.clamp(v, RGBNormalizer.v_min, RGBNormalizer.v_max) 65 | return (v + 1) / 2 66 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__( 10 | self, sigma: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 12 | pass 13 | 14 | 15 | class EDMScaling: 16 | def __init__(self, sigma_data: float = 0.5): 17 | self.sigma_data = sigma_data 18 | 19 | def __call__( 20 | self, sigma: torch.Tensor 21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 22 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 25 | c_noise = 0.25 * sigma.log() 26 | return c_skip, c_out, c_in, c_noise 27 | 28 | 29 | class EpsScaling: 30 | def __call__( 31 | self, sigma: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | c_skip = torch.ones_like(sigma, device=sigma.device) 34 | c_out = -sigma 35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 36 | c_noise = sigma.clone() 37 | return c_skip, c_out, c_in, c_noise 38 | 39 | 40 | class VScaling: 41 | def __call__( 42 | self, sigma: torch.Tensor 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 44 | c_skip = 1.0 / (sigma**2 + 1.0) 45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 46 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 47 | c_noise = sigma.clone() 48 | return c_skip, c_out, c_in, c_noise 49 | 50 | 51 | class VScalingWithEDMcNoise(DenoiserScaling): 52 | def __call__( 53 | self, sigma: torch.Tensor 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | c_skip = 1.0 / (sigma**2 + 1.0) 56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 58 | c_noise = 0.25 * sigma.log() 59 | return c_skip, c_out, c_in, c_noise 60 | -------------------------------------------------------------------------------- /core/sgm/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class CIFAR10DataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class CIFAR10Loader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.shuffle = shuffle 31 | self.train_dataset = CIFAR10DataDictWrapper( 32 | torchvision.datasets.CIFAR10( 33 | root=".data/", train=True, download=True, transform=transform 34 
| ) 35 | ) 36 | self.test_dataset = CIFAR10DataDictWrapper( 37 | torchvision.datasets.CIFAR10( 38 | root=".data/", train=False, download=True, transform=transform 39 | ) 40 | ) 41 | 42 | def prepare_data(self): 43 | pass 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.train_dataset, 48 | batch_size=self.batch_size, 49 | shuffle=self.shuffle, 50 | num_workers=self.num_workers, 51 | ) 52 | 53 | def test_dataloader(self): 54 | return DataLoader( 55 | self.test_dataset, 56 | batch_size=self.batch_size, 57 | shuffle=self.shuffle, 58 | num_workers=self.num_workers, 59 | ) 60 | 61 | def val_dataloader(self): 62 | return DataLoader( 63 | self.test_dataset, 64 | batch_size=self.batch_size, 65 | shuffle=self.shuffle, 66 | num_workers=self.num_workers, 67 | ) 68 | -------------------------------------------------------------------------------- /core/datasets/turbulence/data_module.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from pytorch_lightning import LightningDataModule 3 | 4 | from .torch_dataset import Turbulence 5 | 6 | 7 | class TurbulenceDataModule(LightningDataModule): 8 | 9 | def __init__(self, dataset_config, batch_size, num_workers): 10 | super(TurbulenceDataModule, self).__init__() 11 | self.train_dataset, self.val_dataset, self.test_dataset = None, None, None 12 | self.save_hyperparameters() 13 | 14 | def prepare_data(self): 15 | pass 16 | 17 | def setup(self, stage=None): 18 | dataset_config = self.hparams.dataset_config 19 | self.train_dataset = Turbulence( 20 | root=dataset_config.root, 21 | total_length=dataset_config.total_length, 22 | **dataset_config.train 23 | ) 24 | self.test_dataset = Turbulence( 25 | root=dataset_config.root, 26 | total_length=dataset_config.total_length, 27 | **dataset_config.test 28 | ) 29 | 30 | def train_dataloader(self): 31 | return DataLoader( 32 | self.train_dataset, 33 | batch_size=self.hparams.batch_size, 34 | pin_memory=True, 35 | num_workers=self.hparams.num_workers, 36 | persistent_workers=True, 37 | ) 38 | 39 | def val_dataloader(self): 40 | """Currently use the test set for validation. Will modify later.""" 41 | return DataLoader( 42 | self.test_dataset, 43 | batch_size=self.hparams.batch_size, 44 | pin_memory=True, 45 | num_workers=self.hparams.num_workers, 46 | persistent_workers=True, 47 | shuffle=False 48 | ) 49 | 50 | def test_dataloader(self): 51 | return DataLoader( 52 | self.test_dataset, 53 | batch_size=self.hparams.batch_size, 54 | pin_memory=True, 55 | num_workers=self.hparams.num_workers, 56 | shuffle=False 57 | ) 58 | 59 | @property 60 | def num_train_samples(self): 61 | return len(self.train_dataset) 62 | 63 | @property 64 | def num_val_samples(self): 65 | """Currently use the test set for validation. 
Will modify later.""" 66 | return len(self.test_dataset) 67 | 68 | @property 69 | def num_test_samples(self): 70 | return len(self.test_dataset) 71 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | from .denoiser_scaling import DenoiserScaling 8 | from .discretizer import Discretization 9 | 10 | 11 | class Denoiser(nn.Module): 12 | def __init__(self, scaling_config: Dict): 13 | super().__init__() 14 | 15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) 16 | 17 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 21 | return c_noise 22 | 23 | def forward( 24 | 
self, 25 | network: nn.Module, 26 | input: torch.Tensor, 27 | sigma: torch.Tensor, 28 | cond: Dict, 29 | **additional_model_inputs, 30 | ) -> torch.Tensor: 31 | sigma = self.possibly_quantize_sigma(sigma) 32 | sigma_shape = sigma.shape 33 | sigma = append_dims(sigma, input.ndim) 34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 36 | return ( 37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 38 | + input * c_skip 39 | ) 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | scaling_config: Dict, 46 | num_idx: int, 47 | discretization_config: Dict, 48 | do_append_zero: bool = False, 49 | quantize_c_noise: bool = True, 50 | flip: bool = True, 51 | ): 52 | super().__init__(scaling_config) 53 | self.discretization: Discretization = instantiate_from_config( 54 | discretization_config 55 | ) 56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) 57 | self.register_buffer("sigmas", sigmas) 58 | self.quantize_c_noise = quantize_c_noise 59 | self.num_idx = num_idx 60 | 61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 62 | dists = sigma - self.sigmas[:, None] 63 | return dists.abs().argmin(dim=0).view(sigma.shape) 64 | 65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: 66 | return self.sigmas[idx] 67 | 68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 69 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 70 | 71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 72 | if self.quantize_c_noise: 73 | return self.sigma_to_idx(c_noise) 74 | else: 75 | return c_noise 76 | -------------------------------------------------------------------------------- /core/sgm/data/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class MNISTDataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class MNISTLoader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 31 | self.shuffle = shuffle 32 | self.train_dataset = MNISTDataDictWrapper( 33 | torchvision.datasets.MNIST( 34 | root=".data/", train=True, download=True, transform=transform 35 | ) 36 | ) 37 | self.test_dataset = MNISTDataDictWrapper( 38 | torchvision.datasets.MNIST( 39 | root=".data/", train=False, download=True, transform=transform 40 | ) 41 | ) 42 | 43 | def prepare_data(self): 44 | pass 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 48 | self.train_dataset, 49 | batch_size=self.batch_size, 50 | shuffle=self.shuffle, 51 | num_workers=self.num_workers, 52 | prefetch_factor=self.prefetch_factor, 53 | ) 54 | 55 | def test_dataloader(self): 56 | return DataLoader( 57 | self.test_dataset, 58 | batch_size=self.batch_size, 59 | 
shuffle=self.shuffle, 60 | num_workers=self.num_workers, 61 | prefetch_factor=self.prefetch_factor, 62 | ) 63 | 64 | def val_dataloader(self): 65 | return DataLoader( 66 | self.test_dataset, 67 | batch_size=self.batch_size, 68 | shuffle=self.shuffle, 69 | num_workers=self.num_workers, 70 | prefetch_factor=self.prefetch_factor, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | dset = MNISTDataDictWrapper( 76 | torchvision.datasets.MNIST( 77 | root=".data/", 78 | train=False, 79 | download=True, 80 | transform=transforms.Compose( 81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 82 | ), 83 | ) 84 | ) 85 | ex = dset[0] 86 | -------------------------------------------------------------------------------- /core/models/turbulence/video_dit_B_baseline_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-0.0" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: dit.video_dit.DiT_B_4 75 | params: 76 | num_video_frames: 0 77 | input_size: 16 78 | in_channels: 0 79 | out_channels: 0 80 | 81 | validate_kwargs: 82 | ddim: True 83 | ddim_steps: 50 84 | ddim_eta: 0. 
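# Added note (not in the original file): validation sampling uses 50 DDIM steps with ddim_eta = 0, i.e. deterministic DDIM.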
85 | 86 | training: 87 | max_iterations: 1000005 88 | model_attrs: 89 | learning_rate: 1e-4 90 | batch_size: 16 91 | num_workers: 16 92 | validation_freq: null 93 | accumulate_grad_batches: 1 94 | logger: 95 | save_dir: ../logs/turbulence 96 | logger_freq: 5000 97 | checkpoint_freq: 10000 98 | 99 | eval: 100 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/video_dit_B_ratio05_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-0.5" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: dit.video_dit.DiT_B_4 75 | params: 76 | num_video_frames: 0 77 | input_size: 16 78 | in_channels: 0 79 | out_channels: 0 80 | 81 | validate_kwargs: 82 | ddim: True 83 | ddim_steps: 50 84 | ddim_eta: 0. 85 | 86 | training: 87 | max_iterations: 1000005 88 | model_attrs: 89 | learning_rate: 1e-4 90 | batch_size: 16 91 | num_workers: 16 92 | validation_freq: null 93 | accumulate_grad_batches: 1 94 | logger: 95 | save_dir: ../logs/turbulence 96 | logger_freq: 5000 97 | checkpoint_freq: 10000 98 | 99 | eval: 100 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/video_dit_S_baseline_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 
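# Added note (not in the original file), assumption: the numeric suffix of gamma_schedule appears to track the ratio in the config filename ("cosine-0.0" here in the baseline config, "cosine-0.5" in the ratio05 variants).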
20 | gamma_schedule: "cosine-0.0" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: dit.video_dit.DiT_S_4 75 | params: 76 | num_video_frames: 0 77 | input_size: 16 78 | in_channels: 0 79 | out_channels: 0 80 | 81 | validate_kwargs: 82 | ddim: True 83 | ddim_steps: 50 84 | ddim_eta: 0. 85 | 86 | training: 87 | max_iterations: 1000005 88 | model_attrs: 89 | learning_rate: 1e-4 90 | batch_size: 16 91 | num_workers: 16 92 | validation_freq: null 93 | accumulate_grad_batches: 1 94 | logger: 95 | save_dir: ../logs/turbulence 96 | logger_freq: 5000 97 | checkpoint_freq: 10000 98 | 99 | eval: 100 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/video_dit_S_ratio05_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-0.5" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 
27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: dit.video_dit.DiT_S_4 75 | params: 76 | num_video_frames: 0 77 | input_size: 16 78 | in_channels: 0 79 | out_channels: 0 80 | 81 | validate_kwargs: 82 | ddim: True 83 | ddim_steps: 50 84 | ddim_eta: 0. 85 | 86 | training: 87 | max_iterations: 1000005 88 | model_attrs: 89 | learning_rate: 1e-4 90 | batch_size: 16 91 | num_workers: 16 92 | validation_freq: null 93 | accumulate_grad_batches: 1 94 | logger: 95 | save_dir: ../logs/turbulence 96 | logger_freq: 5000 97 | checkpoint_freq: 10000 98 | 99 | eval: 100 | num_vis: 5 -------------------------------------------------------------------------------- /core/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | 
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /core/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/midas/midas_net.py: 
-------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss( 41 | image_targets.contiguous(), image_reconstructions.contiguous() 42 | ) 43 | loss = ( 44 | self.latent_weight * loss.mean() 45 | + self.perceptual_weight * perceptual_loss.mean() 46 | ) 47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 48 | 49 | if self.perceptual_weight_on_inputs > 0.0: 50 | image_reconstructions = default( 51 | image_reconstructions, self.decoder.decode(latent_predictions) 52 | ) 53 | if self.scale_input_to_tgt_size: 54 | image_inputs = torch.nn.functional.interpolate( 55 | image_inputs, 56 | image_reconstructions.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | elif self.scale_tgt_to_input_size: 61 | image_reconstructions = torch.nn.functional.interpolate( 62 | image_reconstructions, 63 | image_inputs.shape[2:], 64 | mode="bicubic", 65 | antialias=True, 66 | ) 67 | 68 | perceptual_loss2 = self.perceptual_loss( 69 | image_inputs.contiguous(), image_reconstructions.contiguous() 70 | ) 71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 72 | log[f"{split}/perceptual_loss_on_inputs"] = 
perceptual_loss2.mean().detach() 73 | return loss, log 74 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /core/models/turbulence/baseline_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 
20 | gamma_schedule: "cosine-0.0" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 75 | params: 76 | num_video_frames: 0 77 | # use_checkpoint: True 78 | in_channels: 0 79 | out_channels: 0 80 | model_channels: 64 81 | attention_resolutions: [4, 2, 1] 82 | num_res_blocks: 2 83 | channel_mult: [1, 2, 4, 4] 84 | num_heads: 4 85 | transformer_depth: 1 86 | # spatial_transformer_attn_type: softmax-xformers 87 | spatial_transformer_attn_type: softmax 88 | extra_ff_mix_layer: True 89 | merge_strategy: learned 90 | video_kernel_size: [3, 1, 1] 91 | 92 | validate_kwargs: 93 | ddim: True 94 | ddim_steps: 50 95 | ddim_eta: 0. 
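# --- editorial note (not part of the original config) --------------------------
# `ddim_eta: 0.` selects the deterministic DDIM limit. In DDIM sampling the
# per-step noise scale is
#   sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1})
# where abar_t denotes the cumulative alphas, so eta = 0 injects no fresh noise
# and eta = 1 approaches ancestral (DDPM-style) sampling; `ddim_steps: 50` fixes
# the number of denoising steps used at validation time.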
96 | 97 | training: 98 | max_iterations: 1000005 99 | model_attrs: 100 | learning_rate: 1e-4 101 | batch_size: 16 102 | num_workers: 16 103 | validation_freq: null 104 | accumulate_grad_batches: 1 105 | logger: 106 | save_dir: ../logs/turbulence 107 | logger_freq: 5000 108 | checkpoint_freq: 10000 109 | 110 | eval: 111 | num_vis: 5 -------------------------------------------------------------------------------- /core/sgm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchdata.datapipes.iter 4 | import webdataset as wds 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule 7 | 8 | try: 9 | from sdata import create_dataset, create_dummy_dataset, create_loader 10 | except ImportError as e: 11 | print("#" * 100) 12 | print("Datasets not yet available") 13 | print("to enable, we need to add stable-datasets as a submodule") 14 | print("please use ``git submodule update --init --recursive``") 15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo") 16 | print("#" * 100) 17 | exit(1) 18 | 19 | 20 | class StableDataModuleFromConfig(LightningDataModule): 21 | def __init__( 22 | self, 23 | train: DictConfig, 24 | validation: Optional[DictConfig] = None, 25 | test: Optional[DictConfig] = None, 26 | skip_val_loader: bool = False, 27 | dummy: bool = False, 28 | ): 29 | super().__init__() 30 | self.train_config = train 31 | assert ( 32 | "datapipeline" in self.train_config and "loader" in self.train_config 33 | ), "train config requires the fields `datapipeline` and `loader`" 34 | 35 | self.val_config = validation 36 | if not skip_val_loader: 37 | if self.val_config is not None: 38 | assert ( 39 | "datapipeline" in self.val_config and "loader" in self.val_config 40 | ), "validation config requires the fields `datapipeline` and `loader`" 41 | else: 42 | print( 43 | "Warning: No Validation datapipeline defined, using that one from training" 44 | ) 45 | self.val_config = train 46 | 47 | self.test_config = test 48 | if self.test_config is not None: 49 | assert ( 50 | "datapipeline" in self.test_config and "loader" in self.test_config 51 | ), "test config requires the fields `datapipeline` and `loader`" 52 | 53 | self.dummy = dummy 54 | if self.dummy: 55 | print("#" * 100) 56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") 57 | print("#" * 100) 58 | 59 | def setup(self, stage: str) -> None: 60 | print("Preparing datasets") 61 | if self.dummy: 62 | data_fn = create_dummy_dataset 63 | else: 64 | data_fn = create_dataset 65 | 66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline) 67 | if self.val_config: 68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline) 69 | if self.test_config: 70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline) 71 | 72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: 73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader) 74 | return loader 75 | 76 | def val_dataloader(self) -> wds.DataPipeline: 77 | return create_loader(self.val_datapipeline, **self.val_config.loader) 78 | 79 | def test_dataloader(self) -> wds.DataPipeline: 80 | return create_loader(self.test_datapipeline, **self.test_config.loader) 81 | -------------------------------------------------------------------------------- /core/models/turbulence/dfdiff_ema_cosine_ratio01_1st_mask_svd_lr1e-4.yaml: 
-------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-0.1" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 75 | params: 76 | num_video_frames: 0 77 | # use_checkpoint: True 78 | in_channels: 0 79 | out_channels: 0 80 | model_channels: 64 81 | attention_resolutions: [4, 2, 1] 82 | num_res_blocks: 2 83 | channel_mult: [1, 2, 4, 4] 84 | num_heads: 4 85 | transformer_depth: 1 86 | # spatial_transformer_attn_type: softmax-xformers 87 | spatial_transformer_attn_type: softmax 88 | extra_ff_mix_layer: True 89 | merge_strategy: learned 90 | video_kernel_size: [3, 1, 1] 91 | 92 | validate_kwargs: 93 | ddim: True 94 | ddim_steps: 50 95 | ddim_eta: 0. 96 | 97 | training: 98 | max_iterations: 1000005 99 | model_attrs: 100 | learning_rate: 1e-4 101 | batch_size: 16 102 | num_workers: 16 103 | validation_freq: null 104 | accumulate_grad_batches: 1 105 | logger: 106 | save_dir: ../logs/turbulence 107 | logger_freq: 5000 108 | checkpoint_freq: 10000 109 | 110 | eval: 111 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/dfdiff_ema_cosine_ratio09_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 
20 | gamma_schedule: "cosine-0.9" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 75 | params: 76 | num_video_frames: 0 77 | # use_checkpoint: True 78 | in_channels: 0 79 | out_channels: 0 80 | model_channels: 64 81 | attention_resolutions: [4, 2, 1] 82 | num_res_blocks: 2 83 | channel_mult: [1, 2, 4, 4] 84 | num_heads: 4 85 | transformer_depth: 1 86 | # spatial_transformer_attn_type: softmax-xformers 87 | spatial_transformer_attn_type: softmax 88 | extra_ff_mix_layer: True 89 | merge_strategy: learned 90 | video_kernel_size: [3, 1, 1] 91 | 92 | validate_kwargs: 93 | ddim: True 94 | ddim_steps: 50 95 | ddim_eta: 0. 96 | 97 | training: 98 | max_iterations: 1000005 99 | model_attrs: 100 | learning_rate: 1e-4 101 | batch_size: 16 102 | num_workers: 16 103 | validation_freq: null 104 | accumulate_grad_batches: 1 105 | logger: 106 | save_dir: ../logs/turbulence 107 | logger_freq: 5000 108 | checkpoint_freq: 10000 109 | 110 | eval: 111 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/dydiff_ema_cosine_ratio05_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-0.5" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 
27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 75 | params: 76 | num_video_frames: 0 77 | # use_checkpoint: True 78 | in_channels: 0 79 | out_channels: 0 80 | model_channels: 64 81 | attention_resolutions: [4, 2, 1] 82 | num_res_blocks: 2 83 | channel_mult: [1, 2, 4, 4] 84 | num_heads: 4 85 | transformer_depth: 1 86 | # spatial_transformer_attn_type: softmax-xformers 87 | spatial_transformer_attn_type: softmax 88 | extra_ff_mix_layer: True 89 | merge_strategy: learned 90 | video_kernel_size: [3, 1, 1] 91 | 92 | validate_kwargs: 93 | ddim: True 94 | ddim_steps: 50 95 | ddim_eta: 0. 96 | 97 | training: 98 | max_iterations: 1000005 99 | model_attrs: 100 | learning_rate: 1e-4 101 | batch_size: 16 102 | num_workers: 16 103 | validation_freq: null 104 | accumulate_grad_batches: 1 105 | logger: 106 | save_dir: ../logs/turbulence 107 | logger_freq: 5000 108 | checkpoint_freq: 10000 109 | 110 | eval: 111 | num_vis: 5 -------------------------------------------------------------------------------- /core/models/turbulence/dydiff_ema_cosine_ratio10_1st_mask_svd_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: null 3 | total_length: 15 4 | input_length: 4 5 | train: 6 | filename: train.npz.npy 7 | test: 8 | filename: test.npz.npy 9 | n_subsample: 20 10 | 11 | model: 12 | target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition 13 | params: 14 | image_size: 64 15 | linear_start: 0.00085 16 | linear_end: 0.0120 17 | beta_schedule: cosine 18 | linear_start_gamma: 0. 19 | linear_end_gamma: 0. 20 | gamma_schedule: "cosine-1.0" 21 | log_every_t: 100 22 | timesteps: 1000 23 | first_stage_key: "image" 24 | first_stage_key_prev: "prev" 25 | scale_factor: 0.18215 26 | shift_factor: 0. 
27 | # scale_by_std: True 28 | channels: none # align with input_length 29 | monitor: val/loss_simple_ema 30 | use_ema: False 31 | num_timesteps_cond: 1 # add noise in eval for # steps 32 | parameterization: eps # v not supported 33 | x_channels: 2 34 | z_channels: 3 35 | new_prev_ema: True 36 | # frame_weighting: True 37 | use_x_ema: True 38 | 39 | # concat 40 | cond_stage_key: "cond" 41 | conditioning_key: "concat-video-mask-1st" 42 | cond_stage_config: 43 | target: torch.nn.Identity 44 | 45 | unconditional_guidance_scale: 1.0 46 | visualize_intermediates: True 47 | # rollout: 20 48 | 49 | first_stage_config: 50 | model: 51 | # base_learning_rate: 4.5e-6 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | ckpt_path: null 55 | monitor: "val/rec_loss" 56 | embed_dim: 3 57 | lossconfig: 58 | target: torch.nn.Identity 59 | 60 | ddconfig: 61 | double_z: True 62 | z_channels: 3 63 | resolution: 64 64 | in_channels: 2 65 | out_ch: 2 66 | ch: 128 67 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [ ] 70 | dropout: 0.0 71 | 72 | # unet 73 | unet_config: 74 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 75 | params: 76 | num_video_frames: 0 77 | # use_checkpoint: True 78 | in_channels: 0 79 | out_channels: 0 80 | model_channels: 64 81 | attention_resolutions: [4, 2, 1] 82 | num_res_blocks: 2 83 | channel_mult: [1, 2, 4, 4] 84 | num_heads: 4 85 | transformer_depth: 1 86 | # spatial_transformer_attn_type: softmax-xformers 87 | spatial_transformer_attn_type: softmax 88 | extra_ff_mix_layer: True 89 | merge_strategy: learned 90 | video_kernel_size: [3, 1, 1] 91 | 92 | validate_kwargs: 93 | ddim: True 94 | ddim_steps: 50 95 | ddim_eta: 0. 96 | 97 | training: 98 | max_iterations: 1000005 99 | model_attrs: 100 | learning_rate: 1e-4 101 | batch_size: 16 102 | num_workers: 16 103 | validation_freq: null 104 | accumulate_grad_batches: 1 105 | logger: 106 | save_dir: ../logs/turbulence 107 | logger_freq: 5000 108 | checkpoint_freq: 10000 109 | 110 | eval: 111 | num_vis: 5 -------------------------------------------------------------------------------- /core/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /core/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /core/dydiff/datasets/bair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets.bair.torch_dataset import BairSimpleDataset 3 | from datasets.moving_mnist.data_module import MovingMNISTDataModule 4 | from datasets.normalizer import RGBNormalizer 5 | 6 | from einops import rearrange 7 | 8 | class BairForDyDiff(BairSimpleDataset): 9 | def __init__(self, base_path, total_length, input_length, no_oracle=False, n_subsample=None, train=True, last_only=False): 10 | super().__init__(base_path, train, total_length) 11 | self.total_length = total_length 12 | self.input_length = input_length 13 | self.no_oracle = no_oracle 14 | self.last_only = last_only 15 | self.num_samples = super().__len__() 16 | if n_subsample is not None: 17 | self.num_samples = min(self.num_samples, n_subsample) 18 | self.indices = torch.randperm(self.num_samples, generator=torch.Generator().manual_seed(42)).tolist() 19 | 20 | def __len__(self): 21 | return self.num_samples 22 | 23 | def __getitem__(self, index): 24 | imgs, actions, states = super().__getitem__(self.indices[index]) # (T, 1, H, W, C) 25 | imgs = imgs / 255.0 26 | imgs = torch.from_numpy(imgs) 27 | imgs = RGBNormalizer.normalize(imgs) 28 | imgs = rearrange(imgs, 'T H W C -> T C H W') 29 | 30 | # data = data.permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 31 | cond = imgs[:self.input_length] 32 | pred = imgs[self.input_length:] 33 | # if self.no_oracle: 34 | # prev = imgs[self.input_length-1:self.input_length].repeat(self.total_length-self.input_length, 1, 1, 1) 35 | # else: 36 | # prev = imgs[self.input_length-1:-1] 37 | if self.last_only: 38 | prev = imgs[self.input_length-1:self.input_length] 39 | else: 40 | prev = imgs[:self.input_length] 41 | cond = cond.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 42 | pred = pred.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 43 | prev = prev.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 44 | total = imgs.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 45 | # cond_all = torch.cat([cond, topography], dim=-1) 46 | T, C = actions.shape 47 | actions_pad = torch.zeros(T+1, C) 48 | actions_pad[1:] = torch.from_numpy(actions) 49 | return dict(image=pred, prev=prev, cond=cond, total=total, actions=actions_pad) 50 | 51 | 52 | class BairDataModuleForDyDiff(MovingMNISTDataModule): 53 | def 
setup(self, stage=None): 54 | dataset_config = self.hparams.dataset_config 55 | self.train_dataset = BairForDyDiff( 56 | base_path=dataset_config.base_path, 57 | total_length=dataset_config.total_length, 58 | input_length=dataset_config.input_length, 59 | train=True, 60 | **dataset_config.train 61 | ) 62 | self.test_dataset = BairForDyDiff( 63 | base_path=dataset_config.base_path, 64 | total_length=dataset_config.total_length, 65 | input_length=dataset_config.input_length, 66 | train=False, 67 | **dataset_config.test 68 | ) 69 | 70 | -------------------------------------------------------------------------------- /core/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /core/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 
61 | -------------------------------------------------------------------------------- /core/sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /core/sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 
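# --- editorial sketch (not part of the upstream MiDaS file) --------------------
# DPT/MiDaS-style models predict relative (inverse) depth with an arbitrary scale
# and shift, so callers that use the output as conditioning typically min-max
# normalize each prediction first. A minimal, hypothetical helper for that step:
def _normalize_relative_depth(d, eps=1e-6):
    """Rescale a (B, H, W) relative-depth prediction to [-1, 1] per sample."""
    d_min = d.amin(dim=(1, 2), keepdim=True)
    d_max = d.amax(dim=(1, 2), keepdim=True)
    return 2.0 * (d - d_min) / (d_max - d_min + eps) - 1.0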
110 | -------------------------------------------------------------------------------- /core/ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 
38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /core/dydiff/datasets/turbulence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets.turbulence.torch_dataset import Turbulence 3 | from datasets.turbulence.data_module import TurbulenceDataModule 4 | from datasets.normalizer import NullNormalizer 5 | 6 | 7 | class TurbulenceForDyDiff(Turbulence): 8 | def __init__(self, root, filename, total_length=20, input_length=10, no_oracle=False, n_subsample=None): 9 | self.input_length = input_length 10 | self.no_oracle = no_oracle 11 | super().__init__(root, filename, total_length, n_subsample) 12 | 13 | def __getitem__(self, index): 14 | data = super().__getitem__(index) 15 | data = NullNormalizer.normalize(data) # (T, C, H, W) 16 | 17 | # data = data.permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 18 | cond = data[:self.input_length] 19 | pred = data[self.input_length:] 20 | # if self.no_oracle: 21 | # prev = data[self.input_length-1:self.input_length].repeat(self.total_length-self.input_length, 1, 1, 1) 22 | # else: 23 | # prev = data[self.input_length-1:-1] 24 | prev = data[:self.input_length] 25 | cond = cond.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 26 | pred = pred.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 27 | prev = prev.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 28 | # cond_all = torch.cat([cond, topography], dim=-1) 29 | return dict(image=pred, prev=prev, cond=cond) 30 | 31 | class TurbulenceForDyDiffTest(Turbulence): 32 | def __init__(self, root, filename, total_length=20, input_length=10, no_oracle=False, n_subsample=None): 33 | self.input_length = input_length 34 | self.no_oracle = no_oracle 35 | # print(self.force_last_frame) 36 | super().__init__(root, filename, total_length, n_subsample) 37 | 38 | def __getitem__(self, index): 39 | data = super().__getitem__(index) 40 | data = NullNormalizer.normalize(data) # (T, C, H, W) 41 | 42 | # data = data.permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 43 | cond = data[:self.input_length] 44 | pred = data[self.input_length:self.total_length] 45 | # if self.no_oracle: 46 | # prev = data[self.input_length-1:self.input_length].repeat(self.total_length-self.input_length, 1, 1, 1) 47 | # else: 48 | # prev = data[self.input_length-1:-1] 49 | prev = data[:self.input_length] 50 | cond = cond.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 51 | pred = pred.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 52 | prev = prev.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 53 | total = data.reshape(-1, 64, 64).permute(1, 2, 0).to(memory_format=torch.contiguous_format).float() 54 | # cond_all = torch.cat([cond, topography], dim=-1) 55 | return dict(image=pred, prev=prev, cond=cond, total=total) 56 | 57 | 58 | class TurbulenceDataModuleForDyDiff(TurbulenceDataModule): 59 | def setup(self, stage=None): 60 | dataset_config = self.hparams.dataset_config 61 | self.train_dataset = TurbulenceForDyDiff( 62 | root=dataset_config.root, 63 | total_length=dataset_config.total_length, 64 | input_length=dataset_config.input_length, 65 | **dataset_config.train 66 | ) 67 | self.test_dataset = TurbulenceForDyDiffTest( 68 | root=dataset_config.root, 69 | total_length=dataset_config.total_length, 
70 | input_length=dataset_config.input_length, 71 | **dataset_config.test 72 | ) 73 | print(f"Train dataset: {len(self.train_dataset)} samples") 74 | print(f"Test dataset: {len(self.test_dataset)} samples") 75 | -------------------------------------------------------------------------------- /core/sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | 11 | 12 | class StandardDiffusionLoss(nn.Module): 13 | def __init__( 14 | self, 15 | sigma_sampler_config: dict, 16 | loss_weighting_config: dict, 17 | loss_type: str = "l2", 18 | offset_noise_level: float = 0.0, 19 | batch2model_keys: Optional[Union[str, List[str]]] = None, 20 | ): 21 | super().__init__() 22 | 23 | assert loss_type in ["l2", "l1", "lpips"] 24 | 25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 26 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 27 | 28 | self.loss_type = loss_type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if loss_type == "lpips": 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def get_noised_input( 43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 44 | ) -> torch.Tensor: 45 | noised_input = input + noise * sigmas_bc 46 | return noised_input 47 | 48 | def forward( 49 | self, 50 | network: nn.Module, 51 | denoiser: Denoiser, 52 | conditioner: GeneralConditioner, 53 | input: torch.Tensor, 54 | batch: Dict, 55 | ) -> torch.Tensor: 56 | cond = conditioner(batch) 57 | return self._forward(network, denoiser, cond, input, batch) 58 | 59 | def _forward( 60 | self, 61 | network: nn.Module, 62 | denoiser: Denoiser, 63 | cond: Dict, 64 | input: torch.Tensor, 65 | batch: Dict, 66 | ) -> Tuple[torch.Tensor, Dict]: 67 | additional_model_inputs = { 68 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 69 | } 70 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 71 | 72 | noise = torch.randn_like(input) 73 | if self.offset_noise_level > 0.0: 74 | offset_shape = ( 75 | (input.shape[0], 1, input.shape[2]) 76 | if self.n_frames is not None 77 | else (input.shape[0], input.shape[1]) 78 | ) 79 | noise = noise + self.offset_noise_level * append_dims( 80 | torch.randn(offset_shape, device=input.device), 81 | input.ndim, 82 | ) 83 | sigmas_bc = append_dims(sigmas, input.ndim) 84 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 85 | 86 | model_output = denoiser( 87 | network, noised_input, sigmas, cond, **additional_model_inputs 88 | ) 89 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 90 | return self.get_loss(model_output, input, w) 91 | 92 | def get_loss(self, model_output, target, w): 93 | if self.loss_type == "l2": 94 | return torch.mean( 95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 96 | ) 97 | elif self.loss_type == "l1": 98 | return torch.mean( 99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 100 | ) 101 | elif self.loss_type == "lpips": 102 | 
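# LPIPS is a VGG-based perceptual metric, so this branch assumes image-like
# (B, C, H, W) tensors roughly in [-1, 1]; latent-space model outputs would have
# to be decoded to pixel space before this comparison is meaningful.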
loss = self.lpips(model_output, target).reshape(-1) 103 | return loss 104 | else: 105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 106 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /core/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | 
image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /core/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 
20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /core/taming/modules/util.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | 
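# Data-dependent initialization: on the first batch, loc and scale are set so that
# the normalized activations are approximately zero-mean and unit-variance per channel.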
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /core/logger/logger_turbulence.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | from datasets.normalizer import NullNormalizer 11 | from logger.visualization_turbulence import save_plots 12 | 13 | 14 | def save_img(visualization_data, path, filename_prefix): 15 | # visualization_data: (c, h, W) 16 | filenames = [f"{filename_prefix}_{i:0>2d}.png" for i in range(visualization_data.shape[0])] 17 | save_plots(filenames, visualization_data, path) 18 | 19 | 20 | class ImageLoggerWithKeyToConcat(Callback): 21 | def __init__(self, save_dir='results', keys_to_concat=None, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 22 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 23 | log_images_kwargs=None): 24 | super().__init__() 25 | self.save_dir = save_dir 26 | self.keys_to_concat = keys_to_concat 27 | self.rescale = rescale 28 | self.batch_freq = batch_frequency 29 | self.max_images = max_images 30 | if not increase_log_steps: 31 | self.log_steps = [self.batch_freq] 32 | self.clamp = clamp 33 | self.disabled = disabled 34 | self.log_on_batch_idx = log_on_batch_idx 35 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 36 | self.log_first_step = log_first_step 37 | 38 | @rank_zero_only 39 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 40 | root = os.path.join(self.save_dir, "image_log", split, 
"step-{:06}".format(global_step)) 41 | if self.keys_to_concat is not None: 42 | data_to_plot = torch.cat([images[k] for k in self.keys_to_concat], dim=3) 43 | data_to_plot = NullNormalizer.denormalize(data_to_plot) 44 | for idx, datum_to_plot in enumerate(data_to_plot): 45 | _, h, w = datum_to_plot.shape 46 | datum_to_plot = datum_to_plot.reshape(-1, 2, h, w) 47 | sample_idx = batch_idx * data_to_plot.shape[0] + idx 48 | batch_root = os.path.join(root, str(sample_idx)) 49 | filename = "{:06}".format(sample_idx) 50 | save_img(datum_to_plot.cpu().numpy(), batch_root, filename) 51 | 52 | for k in images: 53 | if k != "diffusion_row": 54 | data_to_plot = images[k] 55 | data_to_plot = NullNormalizer.denormalize(data_to_plot) 56 | for idx, datum_to_plot in enumerate(data_to_plot): 57 | sample_idx = batch_idx * data_to_plot.shape[0] + idx 58 | batch_root = os.path.join(root, k, str(sample_idx)) 59 | filename = "{:06}".format(sample_idx) 60 | save_img(datum_to_plot.cpu().numpy(), batch_root, filename) 61 | 62 | def log_img(self, pl_module, batch, batch_idx, split="train"): 63 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 64 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 65 | hasattr(pl_module, "log_images") and 66 | callable(pl_module.log_images) and 67 | self.max_images > 0): 68 | logger = type(pl_module.logger) 69 | 70 | is_train = pl_module.training 71 | if is_train: 72 | pl_module.eval() 73 | 74 | with torch.no_grad(): 75 | images = pl_module.log_images(batch, split=split, plot_progressive_rows=False, **self.log_images_kwargs) 76 | images = {k:v for k, v in images.items() if not k.startswith("z_")} 77 | 78 | for k in images: 79 | N = min(images[k].shape[0], self.max_images) 80 | images[k] = images[k][:N] 81 | if isinstance(images[k], torch.Tensor): 82 | images[k] = images[k].detach().cpu() 83 | if self.clamp: 84 | images[k] = torch.clamp(images[k], -1., 1.) 
85 | 86 | self.log_local(self.save_dir, split, images, 87 | pl_module.global_step, pl_module.current_epoch, batch_idx) 88 | 89 | if is_train: 90 | pl_module.train() 91 | 92 | def check_frequency(self, check_idx): 93 | return check_idx % self.batch_freq == 0 94 | 95 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 96 | if not self.disabled: 97 | self.log_img(pl_module, batch, batch_idx, split="train") 98 | -------------------------------------------------------------------------------- /core/sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 
55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /core/ldm/modules/diffusionmodules/warping_unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Optional 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from ldm.modules.diffusionmodules.util import ( 11 | checkpoint, 12 | conv_nd, 13 | linear, 14 | avg_pool_nd, 15 | zero_module, 16 | normalization, 17 | timestep_embedding, 18 | ) 19 | from ldm.modules.attention import SpatialTransformer 20 | from ldm.modules.video_attention import SpatialVideoTransformer 21 | from ldm.util import exists 22 | 23 | from ldm.modules.diffusionmodules.openaimodel import UNetModel 24 | from einops import rearrange 25 | 26 | 27 | def make_grid(input): 28 | B, C, H, W = input.size() 29 | xx = th.arange(0, W, device=input.device).view(1, -1).repeat(H, 1) 30 | yy = th.arange(0, H, device=input.device).view(-1, 1).repeat(1, W) 31 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 32 | yy = yy.view(1, 
1, H, W).repeat(B, 1, 1, 1) 33 | grid = th.cat((xx, yy), 1).float() 34 | return grid 35 | 36 | def warp(input, flow, grid, mode="bilinear", padding_mode="border"): 37 | B, C, H, W = input.size() 38 | vgrid = grid + flow 39 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 40 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 41 | vgrid = vgrid.permute(0, 2, 3, 1) 42 | # output = torch.nn.functional.grid_sample(input, vgrid) 43 | output = F.grid_sample(input, vgrid, padding_mode=padding_mode, mode=mode, align_corners=True) 44 | return output 45 | 46 | class WarpingUNetModel(UNetModel): 47 | def __init__( 48 | self, 49 | image_size, 50 | in_channels, 51 | model_channels, 52 | out_channels, 53 | num_res_blocks, 54 | attention_resolutions, 55 | dropout=0, 56 | channel_mult=(1, 2, 4, 8), 57 | conv_resample=True, 58 | dims=2, 59 | num_classes=None, # duplicated in ControlNet 60 | use_checkpoint=False, 61 | use_fp16=False, 62 | num_heads=-1, 63 | num_head_channels=-1, 64 | num_heads_upsample=-1, 65 | use_scale_shift_norm=False, 66 | resblock_updown=False, 67 | use_new_attention_order=False, 68 | use_spatial_transformer=False, # custom transformer support 69 | transformer_depth=1, # custom transformer support 70 | context_dim=None, # custom transformer support 71 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 72 | legacy=True, 73 | disable_self_attentions=None, 74 | num_attention_blocks=None, 75 | disable_middle_self_attn=False, 76 | use_linear_in_transformer=False, 77 | x_channels=3, 78 | ): 79 | self.motion_channels = out_channels // x_channels * 2 80 | self.density_channels = out_channels 81 | self.x_channels = x_channels 82 | self.num_frames = out_channels // x_channels 83 | super().__init__( 84 | image_size, 85 | in_channels, 86 | model_channels, 87 | self.motion_channels + self.density_channels, 88 | num_res_blocks, 89 | attention_resolutions, 90 | dropout, 91 | channel_mult, 92 | conv_resample, 93 | dims, 94 | num_classes, 95 | use_checkpoint, 96 | use_fp16, 97 | num_heads, 98 | num_head_channels, 99 | num_heads_upsample, 100 | use_scale_shift_norm, 101 | resblock_updown, 102 | use_new_attention_order, 103 | use_spatial_transformer, 104 | transformer_depth, 105 | context_dim, 106 | n_embed, 107 | legacy, 108 | disable_self_attentions, 109 | num_attention_blocks, 110 | disable_middle_self_attn, 111 | use_linear_in_transformer, 112 | ) 113 | 114 | def forward(self, x, timesteps=None, context=None, y=None,**kwargs): 115 | out = super().forward(x, timesteps, context, y, **kwargs) 116 | # print(out.shape, self.motion_channels) 117 | motion, residual = out[:, :self.motion_channels], out[:, self.motion_channels:] 118 | last_frame = x[:, -self.x_channels:] 119 | grid = make_grid(last_frame) 120 | output_list = [] 121 | for i in range(self.num_frames): 122 | motion_i = motion[:, i * 2: (i + 1) * 2] 123 | residual_i = residual[:, i * self.x_channels: (i + 1) * self.x_channels] 124 | # print(last_frame.shape, motion_i.shape, residual_i.shape, grid.shape) 125 | last_frame = warp(last_frame, motion_i, grid) + residual_i 126 | output_list.append(last_frame) 127 | output = th.cat(output_list, 1) 128 | return output -------------------------------------------------------------------------------- /core/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 
| from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 
109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /core/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = 
albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /core/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return 
model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /core/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": 
"vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 
65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 
167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /core/train_turbulence_dydiff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | from dydiff.datasets.turbulence import TurbulenceDataModuleForDyDiff 4 | 5 | from logger.logger_turbulence import ImageLoggerWithKeyToConcat 6 | 7 | import argparse 8 | from omegaconf import OmegaConf 9 | 10 | from ldm.util import instantiate_from_config 11 | 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | 14 | 15 | def create_data_module(cfg): 16 | data_module = TurbulenceDataModuleForDyDiff(cfg.data, cfg.training.batch_size, cfg.training.num_workers) 17 | return data_module 18 | 19 | def create_model(cfg): 20 | cfg.model.params.total_length = cfg.data.total_length 21 | cfg.model.params.input_length = cfg.data.input_length 22 | cfg.model.params.num_vis = cfg.eval.num_vis 23 | cfg.model.params.validation_save_dir = cfg.training.logger.save_dir 24 | 25 | cfg.model.params.channels = cfg.data.total_length - cfg.data.input_length 26 | cfg.model.params.unet_config.params.in_channels += cfg.data.total_length 27 | cfg.model.params.unet_config.params.out_channels = cfg.data.total_length - cfg.data.input_length 28 | 29 | if "ensemble" in cfg.model.params.keys() and cfg.model.params.ensemble == True: 30 | # cfg.model.params.channels *= 2 31 | cfg.model.params.unet_config.params.in_channels += cfg.model.params.unet_config.params.out_channels 32 | 33 | cfg.model.params.channels *= 3 34 | cfg.model.params.unet_config.params.in_channels *= 3 35 | cfg.model.params.unet_config.params.out_channels *= 3 36 | if cfg.model.params.conditioning_key == "mcvd": 37 | cfg.model.params.unet_config.params.cond_channels *= 3 38 | 39 | if cfg.model.params.conditioning_key == "concat-video": 40 | cfg.model.params.unet_config.params.in_channels = (cfg.data.input_length + 1) * 3 41 | cfg.model.params.unet_config.params.out_channels = 3 42 | cfg.model.params.unet_config.params.num_video_frames = cfg.data.total_length - cfg.data.input_length 43 | elif cfg.model.params.conditioning_key.startswith("concat-video-mask"): 44 | cfg.model.params.unet_config.params.in_channels = 3 * 2 + 1 # 3 * 2 for video&concat, 1 for mask 45 | if "1st" in cfg.model.params.conditioning_key: 46 | cfg.model.params.unet_config.params.in_channels += 1 47 | cfg.model.params.unet_config.params.out_channels = 3 48 | cfg.model.params.unet_config.params.num_video_frames = cfg.data.total_length 49 | 50 | cfg.model.params.ckpt_path = cfg.ckpt_path 51 | 52 | model = instantiate_from_config(cfg.model) 53 | for k, v in cfg.training.model_attrs.items(): 54 | setattr(model, k, v) 55 | return model 56 | 57 | def create_loggers(cfg): 58 | image_logger = ImageLoggerWithKeyToConcat( 59 | batch_frequency=cfg.training.logger.logger_freq, 60 | save_dir=cfg.training.logger.save_dir,
61 | keys_to_concat=["inputs", "samples"], 62 | log_images_kwargs=dict(cfg.model.params.validate_kwargs) 63 | ) 64 | checkpoint_logger = ModelCheckpoint( 65 | dirpath=cfg.training.logger.save_dir, 66 | every_n_train_steps=cfg.training.logger.checkpoint_freq, 67 | save_top_k=-1 68 | ) 69 | return [image_logger, checkpoint_logger] 70 | 71 | 72 | if __name__=='__main__': 73 | # Parse arguments 74 | def str2bool(v): 75 | if isinstance(v, bool): 76 | return v 77 | if v.lower() in ("yes", "true", "t", "y", "1"): 78 | return True 79 | elif v.lower() in ("no", "false", "f", "n", "0"): 80 | return False 81 | else: 82 | raise argparse.ArgumentTypeError("Boolean value expected.") 83 | 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--config_file', type=str, default='') 86 | parser.add_argument('--n_gpu', type=int, default=1) 87 | parser.add_argument('--data_root', type=str) 88 | parser.add_argument('--model_root', type=str) 89 | parser.add_argument('--ckpt_path', type=str, default=None) 90 | parser.add_argument("--test", type=str2bool, nargs="?", const=True, default=False) 91 | parser.add_argument("--resume", type=str, default=None) 92 | args = parser.parse_args() 93 | 94 | cfg = OmegaConf.load(args.config_file) 95 | cfg = OmegaConf.merge(cfg, vars(args)) 96 | config_file_name = os.path.basename(args.config_file) 97 | config_file_name = config_file_name.split('.')[0] 98 | 99 | cfg.training.logger.save_dir = os.path.join(cfg.training.logger.save_dir, config_file_name) 100 | 101 | print(OmegaConf.to_yaml(cfg)) 102 | # model 103 | model = create_model(cfg) 104 | 105 | # data module 106 | data_module = create_data_module(cfg) 107 | 108 | # loggers 109 | loggers = create_loggers(cfg) 110 | 111 | # trainer 112 | trainer = pl.Trainer(gpus=args.n_gpu, precision=32, 113 | callbacks=loggers, 114 | default_root_dir=cfg.training.logger.save_dir, 115 | max_steps=cfg.training.max_iterations, 116 | accumulate_grad_batches=cfg.training.accumulate_grad_batches, 117 | val_check_interval=int(cfg.training.validation_freq) if cfg.training.validation_freq is not None else float(1.), 118 | strategy="ddp") 119 | # Train! 120 | if not args.test: 121 | trainer.fit(model, datamodule=data_module, ckpt_path=args.resume) 122 | else: 123 | trainer.test(model, datamodule=data_module, ckpt_path=args.resume) 124 | -------------------------------------------------------------------------------- /core/ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to efficientnet_lite3. 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass.
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /core/sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 
34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if (use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 
147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /core/taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | 
if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | --------------------------------------------------------------------------------
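A minimal usage sketch for the LPIPS perceptual metric defined in core/taming/modules/losses/lpips.py (the copy in core/sgm/modules/autoencoding/lpips/loss/lpips.py exposes the same forward interface). This sketch is not part of the repository; it assumes core/ is on PYTHONPATH so that the taming package is importable, that the pretrained VGG-LPIPS checkpoint is present or can be fetched by get_ckpt_path, and that inputs are RGB batches scaled to roughly [-1, 1] as expected by ScalingLayer.

import torch
from taming.modules.losses.lpips import LPIPS

# Two batches of RGB images in [-1, 1], shape (N, 3, H, W).
x = torch.rand(2, 3, 256, 256) * 2 - 1
y = torch.rand(2, 3, 256, 256) * 2 - 1

lpips = LPIPS().eval()  # loads the pretrained VGG-LPIPS weights via get_ckpt_path
with torch.no_grad():
    dist = lpips(x, y)  # per-sample perceptual distance, shape (N, 1, 1, 1)
print(dist.flatten())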