├── ldm ├── __init__.py ├── data │ ├── __init__.py │ ├── inpainting │ │ ├── __init__.py │ │ └── synthetic_mask.py │ ├── dummy.py │ ├── base.py │ ├── lsun.py │ ├── nerf_like.py │ ├── coco.py │ └── imagenet.py ├── models │ └── diffusion │ │ ├── __init__.py │ │ ├── sampling_util.py │ │ ├── classifier.py │ │ └── plms.py ├── modules │ ├── encoders │ │ └── __init__.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── util.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── image_degradation │ │ ├── utils │ │ │ └── test.png │ │ └── __init__.py │ ├── ema.py │ ├── evaluate │ │ ├── ssim.py │ │ ├── frechet_video_distance.py │ │ └── torch_frechet_video_distance.py │ └── attention.py ├── README.md ├── thirdp │ └── psp │ │ ├── id_loss.py │ │ ├── model_irse.py │ │ └── helpers.py ├── extras.py ├── guidance.py ├── lr_scheduler.py └── util.py ├── .gitignore ├── images ├── 0_hydrant.png ├── Zero123-Simple.png └── image_preprocess.png ├── __init__.py ├── requirements.txt ├── custom-node-list.json ├── .github └── workflows │ └── publish.yml ├── util.py ├── model-list.json ├── config └── zero123.yaml ├── sample └── simple_workflow.json ├── README_CN.md ├── util_preprocess.py ├── zero123.py ├── zero123_nodes.py ├── README.md ├── pyproject.toml └── LICENSE /ldm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/data/inpainting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/README.md: -------------------------------------------------------------------------------- 1 | Porting from https://github.com/CompVis/latent-diffusion -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints 3 | *.ckpt 4 | .env 5 | *.ipynb 6 | *.swp 7 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /images/0_hydrant.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kealiu/ComfyUI-Zero123-Porting/HEAD/images/0_hydrant.png -------------------------------------------------------------------------------- /images/Zero123-Simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kealiu/ComfyUI-Zero123-Porting/HEAD/images/Zero123-Simple.png -------------------------------------------------------------------------------- /images/image_preprocess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kealiu/ComfyUI-Zero123-Porting/HEAD/images/image_preprocess.png -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kealiu/ComfyUI-Zero123-Porting/HEAD/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.dirname(__file__)) 5 | 6 | from zero123_nodes import Zero123, Zero123Preprocess 7 | 8 | sys.path.pop(0) 9 | 10 | NODE_CLASS_MAPPINGS = { 11 | "Zero123: Image Rotate in 3D" : Zero123, 12 | "Zero123: Image Preprocess" : Zero123Preprocess 13 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.4.4 2 | opencv-python==4.9.0.80 3 | pudb==2024.1 4 | imageio==2.34.1 5 | imageio-ffmpeg==0.4.9 6 | pytorch-lightning==2.2.3 7 | torchmetrics==1.3.2 8 | omegaconf==2.3.0 9 | test_tube==0.7.5 10 | streamlit==1.33.0 11 | einops==0.8.0 12 | torch-fidelity==0.3.0 13 | transformers==4.40.1 14 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 15 | -e git+https://github.com/openai/CLIP.git@main#egg=clip 16 | -------------------------------------------------------------------------------- /custom-node-list.json: -------------------------------------------------------------------------------- 1 | { 2 | "custom_nodes": [ 3 | { 4 | "author": "kealiu", 5 | "title": "ComfyUI-Zero123-Porting", 6 | "reference": "https://github.com/kealiu/ComfyUI-Zero123-Porting", 7 | "files": [ 8 | "https://github.com/kealiu/ComfyUI-Zero123-Porting" 9 | ], 10 | "install_type": "git-clone", 11 | "description": "Zero-1-to-3: Zero-shot One Image to 3D Object, unofficial porting of original [Zero123](https://github.com/cvlab-columbia/zero123)" 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 
6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your GitHub Repository secrets and reference it here. 21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # 2 | # extract from original zero123 code : https://github.com/cvlab-columbia/zero123/ 3 | # 4 | import importlib 5 | import torch 6 | import numpy as np 7 | import os 8 | import math 9 | import torch.nn as nn 10 | from einops import repeat 11 | 12 | def instantiate_from_config(config): 13 | if not "target" in config: 14 | if config == '__is_first_stage__': 15 | return None 16 | elif config == "__is_unconditional__": 17 | return None 18 | raise KeyError("Expected key `target` to instantiate.") 19 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 20 | 21 | def get_obj_from_str(string, reload=False): 22 | module, cls = string.rsplit(".", 1) 23 | if reload: 24 | module_imp = importlib.import_module(module) 25 | importlib.reload(module_imp) 26 | return getattr(importlib.import_module(module, package=None), cls) 27 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /ldm/data/dummy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import string 4 | from torch.utils.data import Dataset, Subset 5 | 6 | class DummyData(Dataset): 7 | def __init__(self, length, size): 8 | self.length = length 9 | self.size = size 10 | 11 | def __len__(self): 12 | return self.length 13 | 14 | def __getitem__(self, i): 15 | x = np.random.randn(*self.size) 16 | letters = string.ascii_lowercase 17 | y = ''.join(random.choice(letters) for i in range(10)) 18 | return {"jpg": x, "txt": y} 19 | 20 | 21 | class DummyDataWithEmbeddings(Dataset): 22 | def __init__(self, length, size, emb_size): 23 | self.length = length 24 | self.size = size 25 | self.emb_size = emb_size 26 | 27 | def __len__(self): 28 | return self.length 29 | 30 | def __getitem__(self, i): 31 | x = np.random.randn(*self.size) 32 | y = 
np.random.randn(*self.emb_size).astype(np.float32) 33 | return {"jpg": x, "txt": y} 34 | 35 | -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from abc import abstractmethod 4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 5 | 6 | 7 | class Txt2ImgIterableBaseDataset(IterableDataset): 8 | ''' 9 | Define an interface to make the IterableDatasets for text2img data chainable 10 | ''' 11 | def __init__(self, num_records=0, valid_ids=None, size=256): 12 | super().__init__() 13 | self.num_records = num_records 14 | self.valid_ids = valid_ids 15 | self.sample_ids = valid_ids 16 | self.size = size 17 | 18 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 19 | 20 | def __len__(self): 21 | return self.num_records 22 | 23 | @abstractmethod 24 | def __iter__(self): 25 | pass 26 | 27 | 28 | class PRNGMixin(object): 29 | """ 30 | Adds a prng property which is a numpy RandomState which gets 31 | reinitialized whenever the pid changes to avoid synchronized sampling 32 | behavior when used in conjunction with multiprocessing. 33 | """ 34 | @property 35 | def prng(self): 36 | currentpid = os.getpid() 37 | if getattr(self, "_initpid", None) != currentpid: 38 | self._initpid = currentpid 39 | self._prng = np.random.RandomState() 40 | return self._prng 41 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from einops import rearrange 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... -> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(x0.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 
1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /model-list.json: -------------------------------------------------------------------------------- 1 | { 2 | "models": [ 3 | { 4 | "name": "Zero123 3D object Model", 5 | "type": "Zero123", 6 | "base": "Zero123", 7 | "save_path": "checkpoints/zero123", 8 | "description": "model that been trained on 10M+ 3D objects from Objaverse-XL, used for generated rotated CamView", 9 | "reference": "https://objaverse.allenai.org/docs/zero123-xl/", 10 | "filename": "zero123-xl.ckpt", 11 | "url": "https://huggingface.co/kealiu/zero123-xl/resolve/main/zero123-xl.ckpt" 12 | }, 13 | { 14 | "name": "Zero123 3D object Model", 15 | "type": "Zero123", 16 | "base": "Zero123", 17 | "save_path": "checkpoints/zero123", 18 | "description": "Stable Zero123 is a model for view-conditioned image generation based on [a/Zero123](https://github.com/cvlab-columbia/zero123).", 19 | "reference": "https://huggingface.co/stabilityai/stable-zero123", 20 | "filename": "stable_zero123.ckpt", 21 | "url": "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt" 22 | }, 23 | { 24 | "name": "Zero123 3D object Model", 25 | "type": "Zero123", 26 | "base": "Zero123", 27 | "save_path": "checkpoints/zero123", 28 | "description": "Zero123 original checkpoints in 105000 steps.", 29 | "reference": "https://huggingface.co/cvlab/zero123-weights", 30 | "filename": "zero123-105000.ckpt", 31 | "url": "https://huggingface.co/cvlab/zero123-weights/resolve/main/105000.ckpt" 32 | }, 33 | { 34 | "name": "Zero123 3D object Model", 35 | "type": "Zero123", 36 | "base": "Zero123", 37 | "save_path": "checkpoints/zero123", 38 | "description": "Zero123 original checkpoints in 165000 steps.", 39 | "reference": "https://huggingface.co/cvlab/zero123-weights", 40 | "filename": "zero123-165000.ckpt", 41 | "url": "https://huggingface.co/cvlab/zero123-weights/resolve/main/165000.ckpt" 42 | } 43 | ] 44 | } -------------------------------------------------------------------------------- /config/zero123.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | 19 | scheduler_config: # 10000 warmup steps 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: [ 100 ] 23 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 24 | f_start: [ 1.e-6 ] 25 | f_max: [ 1. ] 26 | f_min: [ 1. 
] 27 | 28 | unet_config: 29 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 30 | params: 31 | image_size: 32 # unused 32 | in_channels: 8 33 | out_channels: 4 34 | model_channels: 320 35 | attention_resolutions: [ 4, 2, 1 ] 36 | num_res_blocks: 2 37 | channel_mult: [ 1, 2, 4, 4 ] 38 | num_heads: 8 39 | use_spatial_transformer: True 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: True 43 | legacy: False 44 | 45 | first_stage_config: 46 | target: ldm.models.autoencoder.AutoencoderKL 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 70 | -------------------------------------------------------------------------------- /ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) 
interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:") 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = 
self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a ir_se-101 model.""" 79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 
61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /sample/simple_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 42, 3 | "last_link_id": 49, 4 | "nodes": [ 5 | { 6 | "id": 37, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 279, 10 | 200 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 45 25 | ], 26 | "shape": 3, 27 | "label": "IMAGE", 28 | "slot_index": 0 29 | }, 30 | { 31 | "name": "MASK", 32 | "type": "MASK", 33 | "links": null, 34 | "shape": 3, 35 | "label": "MASK" 36 | } 37 | ], 38 | "properties": { 39 | "Node name for S&R": "LoadImage" 40 | }, 41 | "widgets_values": [ 42 | "0_hydrant.png", 43 | "image" 44 | ] 45 | }, 46 | { 47 | "id": 42, 48 | "type": "PreviewImage", 49 | "pos": [ 50 | 1077, 51 | 201 52 | ], 53 | "size": [ 54 | 210, 55 | 246 56 | ], 57 | "flags": {}, 58 | "order": 2, 59 | "mode": 0, 60 | "inputs": [ 61 | { 62 | "name": "images", 63 | "type": "IMAGE", 64 | "link": 49, 65 | "label": "images" 66 | } 67 | ], 68 | "properties": { 69 | "Node name for S&R": "PreviewImage" 70 | } 71 | }, 72 | { 73 | "id": 39, 74 | "type": "Zero123: Image Rotate in 3D", 75 | "pos": [ 76 | 680, 77 | 200 78 | ], 79 | "size": { 80 | "0": 315, 81 | "1": 298 82 | }, 83 | "flags": {}, 84 | "order": 1, 85 | "mode": 0, 86 | "inputs": [ 87 | { 88 | "name": "image", 89 | "type": "IMAGE", 90 | "link": 45, 91 | "label": "image" 92 | } 93 | ], 94 | "outputs": [ 95 | { 96 | "name": "IMAGE", 97 | "type": "IMAGE", 98 | "links": [ 99 | 49 100 | ], 101 | "shape": 3, 102 | "label": "IMAGE", 103 | "slot_index": 0 104 | } 105 | ], 106 | "properties": { 107 | "Node name for S&R": "Zero123: Image Rotate in 3D" 108 | }, 109 | "widgets_values": [ 110 | -15, 111 | 90, 112 | 2.5, 113 | 75, 114 | 1, 115 | false, 116 | "zero123/zero123-xl.ckpt", 117 | "height=256", 118 | "width=256", 119 | "ddim", 120 | "ddim-uniform" 121 | ] 122 | } 123 | ], 124 | "links": [ 125 | [ 126 | 45, 127 | 37, 128 | 0, 129 | 39, 130 | 0, 131 | "IMAGE" 132 | ], 133 | [ 134 | 49, 135 | 39, 136 | 0, 137 | 42, 138 | 0, 139 | "IMAGE" 140 | ] 141 | ], 142 | "groups": [], 143 | "config": {}, 144 | "extra": { 145 | "workspace_info": { 146 | "id": "DiTa0nwCW2WLMjn9Ua3oZ" 147 | } 148 | }, 149 | "version": 0.4 150 | } -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class 
DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | 
-------------------------------------------------------------------------------- /ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # ComfyUI 自定义节点 Zero-1-to-3: Zero-shot 单张图片全角度3D重绘 2 | 3 | 
[英文](README.md) 4 | 5 | 这是一个非官方的 [Zero123 https://zero123.cs.columbia.edu/](https://zero123.cs.columbia.edu/) 移植 ComfyUI 自定义节点。实现使用单张 RGB 图像改变物体视角。 6 | 7 | 通过此移植,您可以在 ComfyUI 中生成 3D 旋转图像。 8 | 9 | ![Functions](https://github.com/cvlab-columbia/zero123/blob/main/teaser.png) 10 | 11 | # 简单上手 12 | 13 | 安装此节点后,下载 样例工作流 [sample workflow](sample/simple_workflow.json) 开始使用。 14 | 15 | 有任何问题或者建议,欢迎在[issue](https://github.com/kealiu/ComfyUI-Zero123-Porting/issues)中反馈。 16 | 17 | ## 节点和工作流 18 | 19 | ### 节点 `Zero123: Image Rotate in 3D` 20 | 21 | ![simple workflow](images/Zero123-Simple.png) 22 | 23 | ### 节点 `Zero123: Image Preprocess` 24 | 25 | ![simple image process](images/image_preprocess.png) 26 | 27 | ## 前提条件 28 | 29 | - 输入图像 `image` 必须是 `正方形` (宽=高),否则将强制自动转换 30 | - 输入图像 `image` 应该是一个具有 **`白色背景`** 的`物体`,可使用 `Zero123: Image Preprocess` 预处理图像。 31 | - 输出图像 `image` 目前仅支持 `256x256` (固定),可后期放大。 32 | 33 | # 说明 34 | 35 | ## 节点 `Zero123: Image Rotate in 3D` 参数输入与输出结果 36 | 37 | ### 输入 38 | 39 | - **_image_** : 输入图像,应为`正方形`图像,且为具有`白色背景`的`物体`。 40 | - **_polar_angle_** : `X` 轴的旋转角度,向上或向下转动 41 | - `<0.0` : 向上转动 42 | - `>0.0` : 向下转动 43 | - **_azimuth_angle_** : `Y` 轴的旋转角度,向左或向右转动 44 | - `<0.0` : 向左转动 45 | - `>0.0` : 向右转动 46 | - **_scale_** : `Z` 轴,`远`或`近` 47 | - `>1.0` : 更大、更近 48 | - `0<1<1.0`: 更小、更远 49 | - `1.0` : 不变 50 | - **_steps_** : 使用原始 `zero123` 代码库中的默认值 `75`,建议不小于 `75` 51 | - **_batch_size_** : 想要生成的图像数量 52 | - **_fp16_** : 是否以 `fp16` 加载模型。启用可以加速并节省 GPU 显存 53 | - **_checkpoint_** : 选择模型,`zero123-xl` 是当前最新的模型. `stable-zero123` 效果可能更好但商业需要许可。 54 | - **_height_** : 输出高度,固定为 256 不可变 55 | - **_width_** : 输出宽度,固定为 256 不可变 56 | - **_sampler_** : 固定不可变 57 | - **_scheduler_** : 固定不可变 58 | 59 | ### 输出 60 | 61 | - **_images_** : 输出图像 62 | 63 | ## 节点 `Zero123: Image Preprocess` 参数输入与输出结果 64 | 65 | ### 输入 66 | 67 | - **_image_** : 原始输入`图像`. 68 | - **_mask_** : 原始输入`图像`对应的`遮罩(Mask)`. 69 | - **_margin_** : 输出图像四周 `留白` 比例. 70 | 71 | ### 输出 72 | 73 | - **_image_** : 处理后的 `白底`、`方型` 、主体居中的 `图像`. 
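
下面给出一段粗略的示意代码(并非该节点的实际实现;仅使用本仓库 `util_preprocess.py` 中已有的 `mask2bbox`、`generate_pure_image`、`composite_new_image` 等辅助函数,函数名 `preprocess_sketch` 以及 `margin` 的具体用法均为假设),用来说明「按遮罩取包围盒 → 生成白底方形画布 → 将主体居中合成」的大致流程:

```python
# 示意:Zero123 图像预处理的大致思路(假设 image 形状为 [B, H, W, C],mask 形状为 [H, W])
import torch
from util_preprocess import mask2bbox, generate_pure_image, composite_new_image

def preprocess_sketch(image: torch.Tensor, mask: torch.Tensor, margin: float = 0.1):
    x, y, h, w, sidelen = mask2bbox(mask)            # 主体的正方形包围盒
    if sidelen < 0:                                  # 遮罩为空时直接返回原图
        return image
    canvas_len = int(sidelen * (1.0 + 2 * margin))   # 四周按 margin 比例留白
    canvas = generate_pure_image(canvas_len, canvas_len, color=0xFFFFFF)[0][..., :3]  # 白底画布(取 RGB 三通道)
    crop = image[:, y:y + sidelen, x:x + sidelen, :]                                  # 裁出主体区域
    offset = (canvas_len - sidelen) // 2
    return composite_new_image(canvas, crop, offset, offset, False)[0]                # 主体居中合成
```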
74 | 75 | ## 提示 76 | 77 | - 预处理图像时,识别主体,并移除所有背景。 78 | - 使用图像裁剪,来聚焦主体,并转为正方形图像 79 | - 尝试多张图像,选择最佳图像 80 | - 针对最终图像,进行放大处理 81 | 82 | # 安装 83 | 84 | ## 使用 ComfyUI Manager 85 | 86 | ### 自定义节点 87 | 88 | 搜索 `zero123` 选择本repo,进行安装。 89 | 90 | ### Models 91 | 92 | 搜索 `zero123`, 安装想要的模型。推荐 `zero123-xl.ckpt` 或 `stable-zero123` (商用需要许可) 93 | 94 | ## 手工安装 95 | 96 | ### 自定义节点 97 | 98 | ``` 99 | cd ComfyUI/custom_nodes 100 | git clone https://github.com/kealiu/ComfyUI-Zero123-Porting.git 101 | cd ComfyUI-Zero123-Porting 102 | pip install -r requirements.txt 103 | ``` 104 | 105 | 然后,重新启动 `ComfyUI`, 并刷新浏览器。 106 | 107 | ### 模型 108 | 109 | 打开 [`model-list.json`](model-list.json) ,获取模块下载 URL,并下载到 **`ComfyUI/models/checkpoints/zero123/`** 110 | 111 | # Zero123 related works 112 | 113 | - `zero123` : 原版 [zero123](https://zero123.cs.columbia.edu/), 也是本Repo来源版本。 114 | - `stable-zero123` : [StableAI](https://stability.ai/) 版本, 宣称由更多、更好的数据,以及更优的算法训练而来。[开源模型]((https://huggingface.co/stabilityai/stable-zero123))但商用有限制。 115 | - `zero123++` : [Sudo AI](https://sudo.ai) 版本。 [同样开源了模型](https://github.com/SUDO-AI-3D/zero123plus),但模型固定输出6张固定角度的图像。 116 | 117 | # Thanks to 118 | 119 | [Zero-1-to-3: Zero-shot One Image to 3D Object](https://github.com/cvlab-columbia/zero123),一款能够在diffusion模型中学习到相机视角控制机制的框架。 120 | 121 | ``` 122 | @misc{liu2023zero1to3, 123 | title={Zero-1-to-3: Zero-shot One Image to 3D Object}, 124 | author={Ruoshi Liu and Rundi Wu and Basile Van Hoorick and Pavel Tokmakov and Sergey Zakharov and Carl Vondrick}, 125 | year={2023}, 126 | eprint={2303.11328}, 127 | archivePrefix={arXiv}, 128 | primaryClass={cs.CV} 129 | } 130 | ``` 131 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not 
(mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. 
""" 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: 
use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /util_preprocess.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | import math 5 | def image_to_mask(image, channel='red'): 6 | channels = ["red", "green", "blue", "alpha"] 7 | mask = image[:, :, channels.index(channel)] 8 | return (mask,) 9 | 10 | # mask handle 11 | def tensor2PIL(image: torch.Tensor) -> list[Image.Image]: 12 | batch_count = image.size(0) if len(image.shape) > 3 else 1 13 | if batch_count > 1: 14 | out = [] 15 | for i in range(batch_count): 16 | out.extend(tensor2PIL(image[i])) 17 | return out 18 | 19 | return [ 20 | Image.fromarray( 21 | np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype( 22 | np.uint8 23 | ) 24 | ) 25 | ] 26 | 27 | def mask2bbox(mask: torch.Tensor): 28 | _mask = tensor2PIL(mask)[0] 29 | alpha_channel = np.array(_mask) 30 | 31 | non_zero_indices = np.nonzero(alpha_channel) 32 | 33 | try: 34 | min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) 35 | min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) 36 | except: 37 | return (-1, -1, -1, -1, -1) 38 | 39 | h = max_y - min_y 40 | w = max_x - min_x 41 | corpx = min_x 42 | corpy = min_y 43 | sidelen = h 44 | if (h > w): 45 | corpx = corpx - (h - w)//2 46 | elif (h < w ): 47 | sidelen = w 48 | corpy = corpy - (w - h)//2 49 | return (corpx, corpy, h, w, sidelen) 50 | 51 | # 52 | # code borrow from comfyui 53 | # 54 | def repeat_to_batch_size(tensor, batch_size): 55 | if tensor.shape[0] > batch_size: 56 | return tensor[:batch_size] 57 | elif tensor.shape[0] < batch_size: 58 | return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] 59 | return tensor 60 | 61 | def composite_image_with_mask(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): 62 | source = source.to(destination.device) 63 | if resize_source: 64 | source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") 65 | 66 | source = repeat_to_batch_size(source, destination.shape[0]) 67 | 68 | x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) 69 | y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) 70 | 71 | left, top = (x // multiplier, y // multiplier) 72 | right, bottom = (left + source.shape[3], top + source.shape[2],) 73 | 74 | if mask is None: 75 | mask = torch.ones_like(source) 76 | else: 77 | mask = mask.to(destination.device, copy=True) 78 | mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") 79 | mask = repeat_to_batch_size(mask, source.shape[0]) 80 | 81 | # calculate the bounds of the source that will be overlapping the destination 82 | # this prevents the source trying to overwrite latent pixels that are out of bounds 83 | # of the destination 84 | visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) 85 | 86 | mask = mask[:, :, :visible_height, :visible_width] 87 | inverse_mask = torch.ones_like(mask) - mask 88 | 89 | source_portion = mask * source[:, :, :visible_height, :visible_width] 90 | destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] 91 | 92 | destination[:, :, top:bottom, left:right] = source_portion + destination_portion 93 | return destination 94 | 95 | def composite_new_image(destination, source, x, y, resize_source, mask = None): 
96 | destination = destination.clone().movedim(-1, 1) 97 | output = composite_image_with_mask(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) 98 | return (output,) 99 | 100 | def generate_pure_image(width, height, batch_size=1, color=0): 101 | r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) 102 | g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) 103 | b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) 104 | a = torch.full([batch_size, height, width, 1], 1) 105 | return (torch.cat((r, g, b, a), dim=-1), ) 106 | -------------------------------------------------------------------------------- /zero123.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | from contextlib import nullcontext 6 | from einops import rearrange 7 | from ldm.util import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | from rich import print 12 | from torch import autocast 13 | from torchvision import transforms 14 | 15 | def load_model_from_config(config, ckpt, device, verbose=False): 16 | print(f'Loading model from {ckpt}') 17 | pl_sd = torch.load(ckpt, map_location='cpu') 18 | if 'global_step' in pl_sd: 19 | print(f'Global Step: {pl_sd["global_step"]}') 20 | sd = pl_sd['state_dict'] 21 | 22 | sys.path.insert(0, os.path.dirname(__file__)) 23 | model = instantiate_from_config(config.model) 24 | sys.path.pop(0) 25 | 26 | m, u = model.load_state_dict(sd, strict=False) 27 | if len(m) > 0 and verbose: 28 | print('missing keys:') 29 | print(m) 30 | if len(u) > 0 and verbose: 31 | print('unexpected keys:') 32 | print(u) 33 | 34 | model.to(device) 35 | model.eval() 36 | return model 37 | 38 | 39 | def init_model(device, ckpt, half_precision=False): 40 | config = os.path.join(os.path.dirname(__file__), 'config/zero123.yaml') 41 | config = OmegaConf.load(config) 42 | 43 | # Instantiate all models beforehand for efficiency. 
44 | if half_precision: 45 | model = torch.compile(load_model_from_config(config, ckpt, device=device)).half() 46 | else: 47 | model = torch.compile(load_model_from_config(config, ckpt, device=device)) 48 | 49 | return model 50 | 51 | @torch.no_grad() 52 | def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256): 53 | precision_scope = autocast if precision == 'autocast' else nullcontext 54 | with precision_scope("cuda"): 55 | with model.ema_scope(): 56 | c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1) 57 | T = [] 58 | for x, y in zip(xs, ys): 59 | T.append([np.radians(x), np.sin(np.radians(y)), np.cos(np.radians(y)), 0]) 60 | T = torch.tensor(np.array(T))[:, None, :].float().to(c.device) 61 | c = torch.cat([c, T], dim=-1) 62 | c = model.cc_projection(c) 63 | cond = {} 64 | cond['c_crossattn'] = [c] 65 | cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach() 66 | .repeat(n_samples, 1, 1, 1)] 67 | if scale != 1.0: 68 | uc = {} 69 | uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)] 70 | uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)] 71 | else: 72 | uc = None 73 | 74 | shape = [4, h // 8, w // 8] 75 | samples_ddim, _ = sampler.sample(S=ddim_steps, 76 | conditioning=cond, 77 | batch_size=n_samples, 78 | shape=shape, 79 | verbose=False, 80 | unconditional_guidance_scale=scale, 81 | unconditional_conditioning=uc, 82 | eta=ddim_eta, 83 | x_T=None) 84 | x_samples_ddim = model.decode_first_stage(samples_ddim) 85 | ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu() 86 | del cond, c, x_samples_ddim, samples_ddim, uc, input_im 87 | torch.cuda.empty_cache() 88 | return ret_imgs 89 | 90 | @torch.no_grad() 91 | def predict_cam(model, imnp, xs, ys, device="cuda", n_samples=1, ddim_steps=75, scale=3.0): 92 | input_im = transforms.ToTensor()(imnp).unsqueeze(0).to(device) 93 | input_im = input_im * 2 - 1 94 | 95 | sampler = DDIMSampler(model) 96 | sampleimg = sample_model_batch(model, sampler, input_im, xs, ys, n_samples=n_samples, ddim_steps=ddim_steps, scale=scale) 97 | 98 | out_images = [] 99 | for sample in sampleimg: 100 | image = torch.from_numpy(rearrange(sample.numpy(), 'c h w -> h w c'))[None,] 101 | out_images.append(image) 102 | return out_images 103 | -------------------------------------------------------------------------------- /zero123_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import folder_paths 4 | import numpy as np 5 | from PIL import Image 6 | from zero123 import init_model, predict_cam 7 | 8 | from comfy import model_management 9 | 10 | from util_preprocess import mask2bbox, composite_new_image, generate_pure_image 11 | 12 | g_model = None 13 | g_ckpt = None 14 | g_device = None 15 | g_hf = None 16 | def load_model(checkpoint, hf=True): 17 | global g_model 18 | global g_ckpt 19 | global g_device 20 | global g_hf 21 | if (g_ckpt == checkpoint) and (g_hf == hf) and g_model: 22 | return (g_model, g_device) 23 | # not init or ckpt changed 24 | if g_model: 25 | del g_model # may need reload model 26 | g_model = None 27 | torch.cuda.empty_cache() 28 | 29 | if not g_device: 30 | g_device = model_management.get_torch_device() 31 | if (not g_device) and torch.cuda.is_available(): 32 | gpu = torch.cuda.current_device() 33 | if gpu >= 0: 34 | g_device = f'cuda:{gpu}' 35 | else: 36 | g_device = 'cpu' 37 | g_model = init_model(g_device, 
checkpoint, half_precision=hf) 38 | g_ckpt = checkpoint 39 | g_hf = hf 40 | return (g_model, g_device) 41 | 42 | class Zero123Preprocess: 43 | @classmethod 44 | def INPUT_TYPES(s): 45 | return { 46 | "required": { 47 | "image": ("IMAGE",), 48 | "mask": ("MASK",), 49 | "margin": ("FLOAT", { "default": 0.05, "min": 0.01, "max": 1.0, "step": 0.01}) 50 | } 51 | } 52 | 53 | RETURN_TYPES = ("IMAGE",) 54 | OUTPUT_IS_LIST = (False, ) 55 | FUNCTION = "zero123_proprecess" 56 | CATEGORY = "image" 57 | 58 | def zero123_proprecess(self, image, mask, margin): 59 | # generate new image 60 | ox, oy, h, w, nl = mask2bbox(mask[0]) 61 | if nl <= 0: 62 | print("!!!ERROR: Empty Mask, no subject found! Please Check it") 63 | raise ValueError("!!!ERROR: Empty Mask, no subject found! Please Check it") 64 | return None 65 | 66 | bb_image = image[0][int(oy):int(oy+nl),int(ox):int(ox+nl), :].unsqueeze(0) 67 | bb_mask = mask[0][int(oy):int(oy+nl),int(ox):int(ox+nl)] 68 | if bb_image.shape[3] == 3: # RGB 69 | alpha = torch.ones(1, nl, nl, 1) 70 | bb_image = torch.cat((bb_image, alpha), 3) 71 | 72 | margin_nl = math.floor(nl*margin)+1 73 | pure_image = generate_pure_image(nl+margin_nl*2, nl+margin_nl*2, color=0xffffff)[0] 74 | return composite_new_image(pure_image, bb_image, margin_nl, margin_nl, False, mask = bb_mask) 75 | 76 | 77 | class Zero123: 78 | @classmethod 79 | def INPUT_TYPES(s): 80 | return { 81 | "required": { 82 | "image": ("IMAGE",), 83 | "polar_angle": ("INT", { "default": 0, "min": -180, "max": 180, "step": 1, "display": "number"}), 84 | "azimuth_angle": ("INT", { "default": 0, "min": -180, "max": 180, "step": 1, "display": "number"}), 85 | "scale": ("FLOAT", { "default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), 86 | "steps": ("INT", { "default": 75, "min": 1, "step": 1, "display": "number"}), 87 | "batch_size": ("INT", { "default": 1, "min": 1, "step": 1, "display": "number"}), 88 | "fp16": ("BOOLEAN", { "default": True }), 89 | "checkpoint": (list(filter(lambda k: 'zero123' in k, folder_paths.get_filename_list("checkpoints"))), ), 90 | }, 91 | "optional": { 92 | "height": (["height=256"],), 93 | "width": (["width=256"],), 94 | "sampler": (["ddim"],), 95 | "scheduler": (["ddim-uniform"],), 96 | } 97 | } 98 | 99 | RETURN_TYPES = ("IMAGE",) 100 | OUTPUT_IS_LIST = (False, ) 101 | FUNCTION = "moveCam" 102 | CATEGORY = "image" 103 | 104 | # 105 | # height, width, sample, scheduler cannot changed currently, just for show information 106 | # 107 | def moveCam(self, image, polar_angle, azimuth_angle, scale, steps, batch_size, fp16, checkpoint, *args, **kwargs): 108 | xs = [polar_angle]*batch_size 109 | ys = [azimuth_angle]*batch_size 110 | 111 | model, device = load_model(folder_paths.get_full_path("checkpoints", checkpoint), hf=fp16) 112 | 113 | # just for simplify 114 | if image.shape[3] > 3: 115 | image = image[:, :, :, :3] 116 | 117 | input_im = Image.fromarray((255. 
* image[0]).numpy().astype(np.uint8))
118 | w, h = input_im.size
119 | if (w != 256) or (h != 256):
120 | input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
121 | input_im = np.asarray(input_im, dtype=np.float32) / 255.0
122 | 
123 | # input_im = Image.open('input.png')
124 | # input_imnp = np.asarray(input_im, dtype=np.float32) / 255.0
125 | outputs = predict_cam(model, input_im, xs, ys, scale=scale, device=device, n_samples=batch_size, ddim_steps=steps)
126 | 
127 | return outputs
128 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # ComfyUI Node for Zero-1-to-3: Zero-shot One Image to 3D Object
2 | 
3 | [中文](README_CN.md)
4 | 
5 | This is an unofficial porting of [Zero123 https://zero123.cs.columbia.edu/](https://zero123.cs.columbia.edu/) for ComfyUI. Zero123 is a framework for changing the camera viewpoint of an object given just a single RGB image.
6 | 
7 | This port enables you to generate 3D-rotated images in ComfyUI.
8 | 
9 | ![Functions](https://github.com/cvlab-columbia/zero123/blob/main/teaser.png)
10 | 
11 | # Quick Start
12 | 
13 | After installing this node, download the [sample workflow](sample/simple_workflow.json) to start trying it out.
14 | 
15 | If you have any questions or suggestions, please don't hesitate to leave them in the [issue tracker](https://github.com/kealiu/ComfyUI-Zero123-Porting/issues).
16 | 
17 | ## Node and Workflow
18 | 
19 | ### Node `Zero123: Image Rotate in 3D`
20 | 
21 | ![simple workflow](images/Zero123-Simple.png)
22 | 
23 | ### Node `Zero123: Image Preprocess`
24 | 
25 | ![simple image process](images/image_preprocess.png)
26 | 
27 | ## PREREQUISITES
28 | 
29 | - INPUT `image` must be `square` (width=height); otherwise, this node will forcibly resize it.
30 | - INPUT `image` should be an `object` on a **white background**, which means you need to preprocess the image (use `Zero123: Image Preprocess`).
31 | - OUTPUT `image` currently only supports `256x256` (fixed); you can upscale it later.
32 | 
33 | # Explanation
34 | 
35 | ## Node `Zero123: Image Rotate in 3D` Input and Output
36 | 
37 | ### INPUT
38 | 
39 | - **_image_** : input image; should be a `square` image of an `object` on a `white background`.
40 | - **_polar_angle_** : angle around the `x` axis, turn up or down
41 | - `<0.0`: turn up
42 | - `>0.0`: turn down
43 | - **_azimuth_angle_** : angle around the `y` axis, turn left or right
44 | - `<0.0`: turn left
45 | - `>0.0`: turn right
46 | - **_scale_** : `z` axis, `far away` or `near`
47 | - `>1.0` : bigger, or `near`
48 | - `0.0 < scale < 1.0` : smaller, or `far away`
49 | - `1.0` : no change
50 | - **_steps_** : `75` is the default value from the original `zero123` repo; do not set it smaller than `75`.
51 | - **_batch_size_** : how many images you would like to generate.
52 | - **_fp16_** : whether to load the model in `fp16`. Enabling it can speed things up and save GPU memory.
53 | - **_checkpoint_** : the model you select; `zero123-xl` is the latest one, and `stable-zero123` claims to be the best, but a licence is required for commercial use.
54 | - **_height_** : output height, fixed to 256, information only
55 | - **_width_** : output width, fixed to 256, information only
56 | - **_sampler_** : cannot be changed, information only
57 | - **_scheduler_** : cannot be changed, information only
58 | 
59 | ### OUTPUT
60 | 
61 | - **_images_** : the output images
62 | 
63 | ## Node `Zero123: Image Preprocess` Input and Output
64 | 
65 | ### INPUT
66 | 
67 | - **_image_** : the original input `image`.
68 | - **_mask_** : the `mask` of the `image`.
69 | - **_margin_** : the `percentage (%)` margin of the output image.
70 | 
71 | ### OUTPUT
72 | 
73 | - **_image_** : the processed `white background`, `square` version of the input `image`, with the subject centered.
74 | 
75 | ## Tips
76 | 
77 | - to preprocess the image, segment out the subject and remove the whole background.
78 | - use image crop to focus on the main subject and make a square image.
79 | - try multiple images and select the best one.
80 | - upscale for final usage.
81 | 
82 | # Installation
83 | 
84 | ## By ComfyUI Manager
85 | 
86 | ### Custom Nodes
87 | 
88 | Search `zero123`, select this repo, and install it.
89 | 
90 | ### Models
91 | 
92 | Search `zero123` and install the model you like, such as `zero123-xl.ckpt` or `stable-zero123` (licence required for commercial use).
93 | 
94 | ## Manual Installation
95 | 
96 | ### Custom Nodes
97 | 
98 | ```
99 | cd ComfyUI/custom_nodes
100 | git clone https://github.com/kealiu/ComfyUI-Zero123-Porting.git
101 | cd ComfyUI-Zero123-Porting
102 | pip install -r requirements.txt
103 | ```
104 | 
105 | Then restart `ComfyUI` and refresh your browser.
106 | 
107 | ### Models
108 | 
109 | Check out [`model-list.json`](model-list.json) for model download URLs; the models should be placed under **`ComfyUI/models/checkpoints/zero123/`**
110 | 
111 | 
112 | # Zero123 related works
113 | 
114 | - `zero123` by [zero123](https://zero123.cs.columbia.edu/), the original one. This repo is ported from it.
115 | - `stable-zero123` by [Stability AI](https://stability.ai/), which trains [models](https://huggingface.co/stabilityai/stable-zero123) with more data and claims better output.
116 | - `zero123++` by [Sudo AI](https://sudo.ai), which [open-sources a model](https://github.com/SUDO-AI-3D/zero123plus) that always generates images at fixed angles.
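# Use it outside ComfyUI (sketch)

The node is a thin wrapper around `init_model` and `predict_cam` from `zero123.py`, so the sampler can also be driven from a plain Python script. Below is a minimal, untested sketch: run it from this node's directory (so `zero123` and `ldm` are importable), and treat the checkpoint and image paths as placeholders to adapt.

```
import numpy as np
from PIL import Image

from zero123 import init_model, predict_cam  # helpers from this repo

device = "cuda"
# placeholder path: point this at the checkpoint you downloaded
ckpt = "ComfyUI/models/checkpoints/zero123/zero123-xl.ckpt"
model = init_model(device, ckpt, half_precision=True)

# the model expects a 256x256 RGB object on a white background, scaled to [0, 1]
img = Image.open("object.png").convert("RGB").resize((256, 256), Image.Resampling.LANCZOS)
imnp = np.asarray(img, dtype=np.float32) / 255.0

# turn the view 20 degrees down (polar) and 45 degrees right (azimuth), one sample
outputs = predict_cam(model, imnp, xs=[20], ys=[45], device=device, n_samples=1, ddim_steps=75)

# each entry is a (1, 256, 256, 3) tensor in [0, 1], the same layout ComfyUI images use
print(outputs[0].shape)
```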
117 | 118 | # Thanks to 119 | 120 | [Zero-1-to-3: Zero-shot One Image to 3D Object](https://github.com/cvlab-columbia/zero123), which be able to learn control mechanisms that manipulate the camera viewpoint in large-scale diffusion models 121 | 122 | ``` 123 | @misc{liu2023zero1to3, 124 | title={Zero-1-to-3: Zero-shot One Image to 3D Object}, 125 | author={Ruoshi Liu and Rundi Wu and Basile Van Hoorick and Pavel Tokmakov and Sergey Zakharov and Carl Vondrick}, 126 | year={2023}, 127 | eprint={2303.11328}, 128 | archivePrefix={arXiv}, 129 | primaryClass={cs.CV} 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /ldm/data/inpainting/synthetic_mask.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import numpy as np 3 | 4 | settings = { 5 | "256narrow": { 6 | "p_irr": 1, 7 | "min_n_irr": 4, 8 | "max_n_irr": 50, 9 | "max_l_irr": 40, 10 | "max_w_irr": 10, 11 | "min_n_box": None, 12 | "max_n_box": None, 13 | "min_s_box": None, 14 | "max_s_box": None, 15 | "marg": None, 16 | }, 17 | "256train": { 18 | "p_irr": 0.5, 19 | "min_n_irr": 1, 20 | "max_n_irr": 5, 21 | "max_l_irr": 200, 22 | "max_w_irr": 100, 23 | "min_n_box": 1, 24 | "max_n_box": 4, 25 | "min_s_box": 30, 26 | "max_s_box": 150, 27 | "marg": 10, 28 | }, 29 | "512train": { # TODO: experimental 30 | "p_irr": 0.5, 31 | "min_n_irr": 1, 32 | "max_n_irr": 5, 33 | "max_l_irr": 450, 34 | "max_w_irr": 250, 35 | "min_n_box": 1, 36 | "max_n_box": 4, 37 | "min_s_box": 30, 38 | "max_s_box": 300, 39 | "marg": 10, 40 | }, 41 | "512train-large": { # TODO: experimental 42 | "p_irr": 0.5, 43 | "min_n_irr": 1, 44 | "max_n_irr": 5, 45 | "max_l_irr": 450, 46 | "max_w_irr": 400, 47 | "min_n_box": 1, 48 | "max_n_box": 4, 49 | "min_s_box": 75, 50 | "max_s_box": 450, 51 | "marg": 10, 52 | }, 53 | } 54 | 55 | 56 | def gen_segment_mask(mask, start, end, brush_width): 57 | mask = mask > 0 58 | mask = (255 * mask).astype(np.uint8) 59 | mask = Image.fromarray(mask) 60 | draw = ImageDraw.Draw(mask) 61 | draw.line([start, end], fill=255, width=brush_width, joint="curve") 62 | mask = np.array(mask) / 255 63 | return mask 64 | 65 | 66 | def gen_box_mask(mask, masked): 67 | x_0, y_0, w, h = masked 68 | mask[y_0:y_0 + h, x_0:x_0 + w] = 1 69 | return mask 70 | 71 | 72 | def gen_round_mask(mask, masked, radius): 73 | x_0, y_0, w, h = masked 74 | xy = [(x_0, y_0), (x_0 + w, y_0 + w)] 75 | 76 | mask = mask > 0 77 | mask = (255 * mask).astype(np.uint8) 78 | mask = Image.fromarray(mask) 79 | draw = ImageDraw.Draw(mask) 80 | draw.rounded_rectangle(xy, radius=radius, fill=255) 81 | mask = np.array(mask) / 255 82 | return mask 83 | 84 | 85 | def gen_large_mask(prng, img_h, img_w, 86 | marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, 87 | min_n_box, max_n_box, min_s_box, max_s_box): 88 | """ 89 | img_h: int, an image height 90 | img_w: int, an image width 91 | marg: int, a margin for a box starting coordinate 92 | p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask 93 | 94 | min_n_irr: int, min number of segments 95 | max_n_irr: int, max number of segments 96 | max_l_irr: max length of a segment in polygonal chain 97 | max_w_irr: max width of a segment in polygonal chain 98 | 99 | min_n_box: int, min bound for the number of box primitives 100 | max_n_box: int, max bound for the number of box primitives 101 | min_s_box: int, min length of a box side 102 | max_s_box: int, max length of a box side 103 | """ 104 | 105 | mask = 
np.zeros((img_h, img_w)) 106 | uniform = prng.randint 107 | 108 | if np.random.uniform(0, 1) < p_irr: # generate polygonal chain 109 | n = uniform(min_n_irr, max_n_irr) # sample number of segments 110 | 111 | for _ in range(n): 112 | y = uniform(0, img_h) # sample a starting point 113 | x = uniform(0, img_w) 114 | 115 | a = uniform(0, 360) # sample angle 116 | l = uniform(10, max_l_irr) # sample segment length 117 | w = uniform(5, max_w_irr) # sample a segment width 118 | 119 | # draw segment starting from (x,y) to (x_,y_) using brush of width w 120 | x_ = x + l * np.sin(a) 121 | y_ = y + l * np.cos(a) 122 | 123 | mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) 124 | x, y = x_, y_ 125 | else: # generate Box masks 126 | n = uniform(min_n_box, max_n_box) # sample number of rectangles 127 | 128 | for _ in range(n): 129 | h = uniform(min_s_box, max_s_box) # sample box shape 130 | w = uniform(min_s_box, max_s_box) 131 | 132 | x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box 133 | y_0 = uniform(marg, img_h - marg - h) 134 | 135 | if np.random.uniform(0, 1) < 0.5: 136 | mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) 137 | else: 138 | r = uniform(0, 60) # sample radius 139 | mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) 140 | return mask 141 | 142 | 143 | make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) 144 | make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) 145 | make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) 146 | make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) 147 | 148 | 149 | MASK_MODES = { 150 | "256train": make_lama_mask, 151 | "256narrow": make_narrow_lama_mask, 152 | "512train": make_512_lama_mask, 153 | "512train-large": make_512_lama_mask_large 154 | } 155 | 156 | if __name__ == "__main__": 157 | import sys 158 | 159 | out = sys.argv[1] 160 | 161 | prng = np.random.RandomState(1) 162 | kwargs = settings["256train"] 163 | mask = gen_large_mask(prng, 256, 256, **kwargs) 164 | mask = (255 * mask).astype(np.uint8) 165 | mask = Image.fromarray(mask) 166 | mask.save(out) 167 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 
22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: [batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. Values are expected to 42 | be in the range 0-255. 43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: [batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 71 | 72 | Args: 73 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: [batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 
121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: [num_samples, embedding_size] 141 | generated_activations: [num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-zero123-porting" 3 | description = "Zero-1-to-3: Zero-shot One Image to 3D Object, unofficial porting of original [Zero123](https://github.com/cvlab-columbia/zero123)" 4 | version = "1.0.0" 5 | license = "LICENSE" 6 | dependencies = ["absl-py==2.1.0", "aiofiles==23.2.1", "aiohttp==3.9.5", "aiosignal==1.3.1", "albumentations==1.4.4", "altair==5.3.0", "annotated-types==0.6.0", "antlr4-python3-runtime==4.9.3", "anyio==4.3.0", "argon2-cffi==23.1.0", "argon2-cffi-bindings==21.2.0", "arrow==1.3.0", "asttokens==2.4.1", "async-lru==2.0.4", "async-timeout==4.0.3", "attrs==23.2.0", "Babel==2.14.0", "beautifulsoup4==4.12.3", "bleach==6.1.0", "blinker==1.8.1", "bokeh==3.4.1", "braceexpand==0.1.7", "cachetools==5.3.3", "certifi==2024.2.2", "cffi==1.16.0", "charset-normalizer==3.3.2", "click==8.1.7", "clip==1.0", "colorama==0.4.6", "coloredlogs==15.0.1", "comm==0.2.2", "contourpy==1.2.1", "cycler==0.12.1", "datasets==2.19.0", "debugpy==1.8.1", "decorator==5.1.1", "defusedxml==0.7.1", "diffusers==0.27.2", "dill==0.3.8", "dl-ext==1.3.4", "einops==0.8.0", "exceptiongroup==1.2.1", "executing==2.0.1", "faiss-gpu==1.7.2", "fastapi==0.110.2", "fastjsonschema==2.19.1", "ffmpy==0.3.2", "filelock==3.13.4", "fire==0.6.0", "flatbuffers==24.3.25", "fonttools==4.51.0", "fqdn==1.5.1", "frozenlist==1.4.1", "fsspec==2024.3.1", "ftfy==6.2.0", "future==1.0.0", "gdown==5.1.0", "gitdb==4.0.11", "GitPython==3.1.43", "gradio==3.44.0", "gradio_client==0.5.0", "grpcio==1.62.2", "h11==0.14.0", "httpcore==1.0.5", "httpx==0.27.0", "huggingface-hub==0.22.2", "humanfriendly==10.0", "icecream==2.1.3", "idna==3.7", "imageio==2.34.1", "imageio-ffmpeg==0.4.9", "importlib_metadata==7.1.0", "importlib_resources==6.4.0", "ipykernel==6.29.4", "ipython==8.24.0", "ipywidgets==8.1.2", "isoduration==20.11.0", "jedi==0.19.1", "Jinja2==3.1.3", "joblib==1.4.0", "json5==0.9.25", "jsonpointer==2.4", "jsonschema==4.21.1", "jsonschema-specifications==2023.12.1", "jupyter==1.0.0", "jupyter-console==6.6.3", "jupyter-events==0.10.0", "jupyter-lsp==2.2.5", "jupyter_bokeh==4.0.4", "jupyter_client==8.6.1", "jupyter_core==5.7.2", "jupyter_server==2.14.0", "jupyter_server_terminals==0.5.3", "jupyterlab==4.1.8", "jupyterlab_pygments==0.3.0", "jupyterlab_server==2.27.1", "jupyterlab_widgets==3.0.10", "kiwisolver==1.4.5", 
"kornia==0.7.2", "kornia_rs==0.1.3", "lazy_loader==0.4", "lightning-utilities==0.11.2", "linkify-it-py==2.0.3", "llvmlite==0.42.0", "loguru==0.7.2", "Markdown==3.6", "markdown-it-py==3.0.0", "MarkupSafe==2.1.5", "matplotlib==3.8.4", "matplotlib-inline==0.1.7", "mdit-py-plugins==0.4.0", "mdurl==0.1.2", "mistune==3.0.2", "mpmath==1.3.0", "multidict==6.0.5", "multipledispatch==1.0.0", "multiprocess==0.70.16", "nbclient==0.10.0", "nbconvert==7.16.3", "nbformat==5.10.4", "nest-asyncio==1.6.0", "networkx==3.3", "ninja==1.11.1.1", "notebook==7.1.3", "notebook_shim==0.2.4", "numba==0.59.1", "numpy==1.26.4", "nvdiffrast==0.3.1", "nvidia-cublas-cu12==12.1.3.1", "nvidia-cuda-cupti-cu12==12.1.105", "nvidia-cuda-nvrtc-cu12==12.1.105", "nvidia-cuda-runtime-cu12==12.1.105", "nvidia-cudnn-cu12==8.9.2.26", "nvidia-cufft-cu12==11.0.2.54", "nvidia-curand-cu12==10.3.2.106", "nvidia-cusolver-cu12==11.4.5.107", "nvidia-cusparse-cu12==12.1.0.106", "nvidia-nccl-cu12==2.20.5", "nvidia-nvjitlink-cu12==12.4.127", "nvidia-nvtx-cu12==12.1.105", "omegaconf==2.3.0", "onnx==1.16.0", "onnxruntime==1.17.3", "opencv-python==4.9.0.80", "opencv-python-headless==4.9.0.80", "orjson==3.10.1", "overrides==7.7.0", "packaging==24.0", "pandas==2.2.2", "pandocfilters==1.5.1", "panel==1.4.2", "param==2.1.0", "parso==0.8.4", "pexpect==4.9.0", "pillow==10.3.0", "platformdirs==4.2.1", "plotly==5.21.0", "pooch==1.8.1", "prometheus_client==0.20.0", "prompt-toolkit==3.0.43", "protobuf==4.25.3", "psutil==5.9.8", "ptyprocess==0.7.0", "pudb==2024.1", "pure-eval==0.2.2", "pyarrow==16.0.0", "pyarrow-hotfix==0.6", "pycparser==2.22", "pydantic==2.7.1", "pydantic_core==2.18.2", "pydeck==0.9.0b1", "pydub==0.25.1", "Pygments==2.17.2", "pyhocon==0.3.60", "PyMatting==1.1.12", "PyMCubes==0.1.4", "pyparsing==3.1.2", "PySocks==1.7.1", "python-dateutil==2.9.0.post0", "python-json-logger==2.0.7", "python-multipart==0.0.9", "pytorch-lightning==2.2.3", "pytz==2024.1", "pyviz_comms==3.0.2", "PyYAML==6.0.1", "pyzmq==26.0.2", "qtconsole==5.5.1", "QtPy==2.4.1", "referencing==0.35.0", "regex==2024.4.28", "rembg==2.0.56", "requests==2.31.0", "rfc3339-validator==0.1.4", "rfc3986-validator==0.1.1", "rich==13.7.1", "rpds-py==0.18.0", "safetensors==0.4.3", "scikit-image==0.23.2", "scikit-learn==1.4.2", "scipy==1.13.0", "segment-anything==1.0", "semantic-version==2.10.0", "Send2Trash==1.8.3", "six==1.16.0", "smmap==5.0.1", "sniffio==1.3.1", "soupsieve==2.5", "stack-data==0.6.3", "starlette==0.37.2", "streamlit==1.33.0", "sympy==1.12", "-e git+https://github.com/CompVis/taming-transformers.git@3ba01b241669f5ade541ce990f7650a3b8f65318#egg=taming_transformers", "tenacity==8.2.3", "tensorboard==2.16.2", "tensorboard-data-server==0.7.2", "tensorboardX==2.6.2.2", "termcolor==2.4.0", "terminado==0.18.1", "test_tube==0.7.5", "threadpoolctl==3.4.0", "tifffile==2024.4.24", "tinycss2==1.3.0", "tokenizers==0.19.1", "toml==0.10.2", "tomli==2.0.1", "toolz==0.12.1", "torch==2.3.0", "torch-fidelity==0.3.0", "torchmetrics==1.3.2", "torchvision==0.18.0", "tornado==6.4", "tqdm==4.66.2", "traitlets==5.14.3", "transformers==4.40.1", "transforms3d==0.4.1", "trimesh==4.3.1", "triton==2.3.0", "types-python-dateutil==2.9.0.20240316", "typing_extensions==4.11.0", "tzdata==2024.1", "uc-micro-py==1.0.3", "uri-template==1.3.0", "urllib3==2.2.1", "urwid==2.6.11", "urwid-readline==0.14", "uvicorn==0.29.0", "watchdog==4.0.0", "wcwidth==0.2.13", "webcolors==1.13", "webdataset==0.2.86", "webencodings==0.5.1", "websocket-client==1.8.0", "websockets==11.0.3", "Werkzeug==3.0.2", 
"widgetsnbextension==4.0.10", "xxhash==3.4.1", "xyzservices==2024.4.0", "yacs==0.1.8", "yarl==1.9.4", "zipp==3.18.1"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kealiu/ComfyUI-Zero123-Porting" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "" 14 | DisplayName = "ComfyUI-Zero123-Porting" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = 
self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/data/nerf_like.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import imageio 7 | import math 8 | import cv2 9 | from torchvision import transforms 10 | 11 | def cartesian_to_spherical(xyz): 12 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) 13 | xy = xyz[:,0]**2 + xyz[:,1]**2 14 | z = np.sqrt(xy + xyz[:,2]**2) 15 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down 16 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up 17 | azimuth = np.arctan2(xyz[:,1], xyz[:,0]) 18 | return np.array([theta, azimuth, z]) 19 | 20 | 21 | def get_T(T_target, T_cond): 22 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) 23 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) 24 | 25 | d_theta = theta_target - theta_cond 26 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 27 | d_z = z_target - z_cond 28 | 29 | d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) 30 | return d_T 31 | 32 | def get_spherical(T_target, T_cond): 33 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) 34 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) 35 | 36 | d_theta = theta_target - 
theta_cond 37 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 38 | d_z = z_target - z_cond 39 | 40 | d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()]) 41 | return d_T 42 | 43 | class RTMV(Dataset): 44 | def __init__(self, root_dir='datasets/RTMV/google_scanned',\ 45 | first_K=64, resolution=256, load_target=False): 46 | self.root_dir = root_dir 47 | self.scene_list = sorted(next(os.walk(root_dir))[1]) 48 | self.resolution = resolution 49 | self.first_K = first_K 50 | self.load_target = load_target 51 | 52 | def __len__(self): 53 | return len(self.scene_list) 54 | 55 | def __getitem__(self, idx): 56 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 57 | with open(os.path.join(scene_dir, 'transforms.json'), "r") as f: 58 | meta = json.load(f) 59 | imgs = [] 60 | poses = [] 61 | for i_img in range(self.first_K): 62 | meta_img = meta['frames'][i_img] 63 | 64 | if i_img == 0 or self.load_target: 65 | img_path = os.path.join(scene_dir, meta_img['file_path']) 66 | img = imageio.imread(img_path) 67 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 68 | imgs.append(img) 69 | 70 | c2w = meta_img['transform_matrix'] 71 | poses.append(c2w) 72 | 73 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 74 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 75 | imgs = imgs * 2 - 1. # convert to stable diffusion range 76 | poses = torch.tensor(np.array(poses).astype(np.float32)) 77 | return imgs, poses 78 | 79 | def blend_rgba(self, img): 80 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB 81 | return img 82 | 83 | 84 | class GSO(Dataset): 85 | def __init__(self, root_dir='datasets/GoogleScannedObjects',\ 86 | split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'): 87 | self.root_dir = root_dir 88 | with open(os.path.join(root_dir, '%s.json' % split), "r") as f: 89 | self.scene_list = json.load(f) 90 | self.resolution = resolution 91 | self.first_K = first_K 92 | self.load_target = load_target 93 | self.name = name 94 | 95 | def __len__(self): 96 | return len(self.scene_list) 97 | 98 | def __getitem__(self, idx): 99 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 100 | with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f: 101 | meta = json.load(f) 102 | imgs = [] 103 | poses = [] 104 | for i_img in range(self.first_K): 105 | meta_img = meta['frames'][i_img] 106 | 107 | if i_img == 0 or self.load_target: 108 | img_path = os.path.join(scene_dir, meta_img['file_path']) 109 | img = imageio.imread(img_path) 110 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 111 | imgs.append(img) 112 | 113 | c2w = meta_img['transform_matrix'] 114 | poses.append(c2w) 115 | 116 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 117 | mask = imgs[:, :, :, -1] 118 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 119 | imgs = imgs * 2 - 1. # convert to stable diffusion range 120 | poses = torch.tensor(np.array(poses).astype(np.float32)) 121 | return imgs, poses 122 | 123 | def blend_rgba(self, img): 124 | img = img[..., :3] * img[..., -1:] + (1. 
- img[..., -1:]) # blend A to RGB 125 | return img 126 | 127 | class WILD(Dataset): 128 | def __init__(self, root_dir='data/nerf_wild',\ 129 | first_K=33, resolution=256, load_target=False): 130 | self.root_dir = root_dir 131 | self.scene_list = sorted(next(os.walk(root_dir))[1]) 132 | self.resolution = resolution 133 | self.first_K = first_K 134 | self.load_target = load_target 135 | 136 | def __len__(self): 137 | return len(self.scene_list) 138 | 139 | def __getitem__(self, idx): 140 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 141 | with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f: 142 | meta = json.load(f) 143 | imgs = [] 144 | poses = [] 145 | for i_img in range(self.first_K): 146 | meta_img = meta['frames'][i_img] 147 | 148 | if i_img == 0 or self.load_target: 149 | img_path = os.path.join(scene_dir, meta_img['file_path']) 150 | img = imageio.imread(img_path + '.png') 151 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 152 | imgs.append(img) 153 | 154 | c2w = meta_img['transform_matrix'] 155 | poses.append(c2w) 156 | 157 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 158 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 159 | imgs = imgs * 2 - 1. # convert to stable diffusion range 160 | poses = torch.tensor(np.array(poses).astype(np.float32)) 161 | return imgs, poses 162 | 163 | def blend_rgba(self, img): 164 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB 165 | return img -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | import os 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | import torch 15 | import time 16 | import cv2 17 | import PIL 18 | 19 | def pil_rectangle_crop(im): 20 | width, height = im.size # Get dimensions 21 | 22 | if width <= height: 23 | left = 0 24 | right = width 25 | top = (height - width)/2 26 | bottom = (height + width)/2 27 | else: 28 | 29 | top = 0 30 | bottom = height 31 | left = (width - height) / 2 32 | bottom = (width + height) / 2 33 | 34 | # Crop the center of the image 35 | im = im.crop((left, top, right, bottom)) 36 | return im 37 | 38 | def 
add_margin(pil_img, color, size=256): 39 | width, height = pil_img.size 40 | result = Image.new(pil_img.mode, (size, size), color) 41 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) 42 | return result 43 | 44 | def load_and_preprocess(interface, input_im): 45 | ''' 46 | :param input_im (PIL Image). 47 | :return image (H, W, 3) array in [0, 1]. 48 | ''' 49 | # See https://github.com/Ir1d/image-background-remove-tool 50 | image = input_im.convert('RGB') 51 | 52 | image_without_background = interface([image])[0] 53 | image_without_background = np.array(image_without_background) 54 | est_seg = image_without_background > 127 55 | image = np.array(image) 56 | foreground = est_seg[:, : , -1].astype(np.bool_) 57 | image[~foreground] = [255., 255., 255.] 58 | x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) 59 | image = image[y:y+h, x:x+w, :] 60 | image = PIL.Image.fromarray(np.array(image)) 61 | 62 | # resize image such that long edge is 512 63 | image.thumbnail([200, 200], Image.Resampling.LANCZOS) 64 | image = add_margin(image, (255, 255, 255), size=256) 65 | image = np.array(image) 66 | 67 | return image 68 | 69 | 70 | def log_txt_as_img(wh, xc, size=10): 71 | # wh a tuple of (width, height) 72 | # xc a list of captions to plot 73 | b = len(xc) 74 | txts = list() 75 | for bi in range(b): 76 | txt = Image.new("RGB", wh, color="white") 77 | draw = ImageDraw.Draw(txt) 78 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 79 | nc = int(40 * (wh[0] / 256)) 80 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 81 | 82 | try: 83 | draw.text((0, 0), lines, fill="black", font=font) 84 | except UnicodeEncodeError: 85 | print("Cant encode string for logging. Skipping.") 86 | 87 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 88 | txts.append(txt) 89 | txts = np.stack(txts) 90 | txts = torch.tensor(txts) 91 | return txts 92 | 93 | 94 | def ismap(x): 95 | if not isinstance(x, torch.Tensor): 96 | return False 97 | return (len(x.shape) == 4) and (x.shape[1] > 3) 98 | 99 | 100 | def isimage(x): 101 | if not isinstance(x,torch.Tensor): 102 | return False 103 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 104 | 105 | 106 | def exists(x): 107 | return x is not None 108 | 109 | 110 | def default(val, d): 111 | if exists(val): 112 | return val 113 | return d() if isfunction(d) else d 114 | 115 | 116 | def mean_flat(tensor): 117 | """ 118 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 119 | Take the mean over all non-batch dimensions. 
120 | """ 121 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 122 | 123 | 124 | def count_params(model, verbose=False): 125 | total_params = sum(p.numel() for p in model.parameters()) 126 | if verbose: 127 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 128 | return total_params 129 | 130 | 131 | def instantiate_from_config(config): 132 | if not "target" in config: 133 | if config == '__is_first_stage__': 134 | return None 135 | elif config == "__is_unconditional__": 136 | return None 137 | raise KeyError("Expected key `target` to instantiate.") 138 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 139 | 140 | 141 | def get_obj_from_str(string, reload=False): 142 | module, cls = string.rsplit(".", 1) 143 | if reload: 144 | module_imp = importlib.import_module(module) 145 | importlib.reload(module_imp) 146 | return getattr(importlib.import_module(module, package=None), cls) 147 | 148 | 149 | class AdamWwithEMAandWings(optim.Optimizer): 150 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 151 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 152 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 153 | ema_power=1., param_names=()): 154 | """AdamW that saves EMA versions of the parameters.""" 155 | if not 0.0 <= lr: 156 | raise ValueError("Invalid learning rate: {}".format(lr)) 157 | if not 0.0 <= eps: 158 | raise ValueError("Invalid epsilon value: {}".format(eps)) 159 | if not 0.0 <= betas[0] < 1.0: 160 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 161 | if not 0.0 <= betas[1] < 1.0: 162 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 163 | if not 0.0 <= weight_decay: 164 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 165 | if not 0.0 <= ema_decay <= 1.0: 166 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 167 | defaults = dict(lr=lr, betas=betas, eps=eps, 168 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 169 | ema_power=ema_power, param_names=param_names) 170 | super().__init__(params, defaults) 171 | 172 | def __setstate__(self, state): 173 | super().__setstate__(state) 174 | for group in self.param_groups: 175 | group.setdefault('amsgrad', False) 176 | 177 | @torch.no_grad() 178 | def step(self, closure=None): 179 | """Performs a single optimization step. 180 | Args: 181 | closure (callable, optional): A closure that reevaluates the model 182 | and returns the loss. 
183 | """ 184 | loss = None 185 | if closure is not None: 186 | with torch.enable_grad(): 187 | loss = closure() 188 | 189 | for group in self.param_groups: 190 | params_with_grad = [] 191 | grads = [] 192 | exp_avgs = [] 193 | exp_avg_sqs = [] 194 | ema_params_with_grad = [] 195 | state_sums = [] 196 | max_exp_avg_sqs = [] 197 | state_steps = [] 198 | amsgrad = group['amsgrad'] 199 | beta1, beta2 = group['betas'] 200 | ema_decay = group['ema_decay'] 201 | ema_power = group['ema_power'] 202 | 203 | for p in group['params']: 204 | if p.grad is None: 205 | continue 206 | params_with_grad.append(p) 207 | if p.grad.is_sparse: 208 | raise RuntimeError('AdamW does not support sparse gradients') 209 | grads.append(p.grad) 210 | 211 | state = self.state[p] 212 | 213 | # State initialization 214 | if len(state) == 0: 215 | state['step'] = 0 216 | # Exponential moving average of gradient values 217 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 218 | # Exponential moving average of squared gradient values 219 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 220 | if amsgrad: 221 | # Maintains max of all exp. moving avg. of sq. grad. values 222 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 223 | # Exponential moving average of parameter values 224 | state['param_exp_avg'] = p.detach().float().clone() 225 | 226 | exp_avgs.append(state['exp_avg']) 227 | exp_avg_sqs.append(state['exp_avg_sq']) 228 | ema_params_with_grad.append(state['param_exp_avg']) 229 | 230 | if amsgrad: 231 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 232 | 233 | # update the steps for each param group update 234 | state['step'] += 1 235 | # record the step after step update 236 | state_steps.append(state['step']) 237 | 238 | optim._functional.adamw(params_with_grad, 239 | grads, 240 | exp_avgs, 241 | exp_avg_sqs, 242 | max_exp_avg_sqs, 243 | state_steps, 244 | amsgrad=amsgrad, 245 | beta1=beta1, 246 | beta2=beta2, 247 | lr=group['lr'], 248 | weight_decay=group['weight_decay'], 249 | eps=group['eps'], 250 | maximize=False) 251 | 252 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 253 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 254 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 255 | 256 | return loss -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 
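# --- Added commentary (not part of the upstream attention.py) ----------------
# GEGLU above implements the gated-GELU feed-forward variant: the linear layer
# projects to 2 * dim_out, the result is chunked into a value half and a gate
# half, and the output is value * GELU(gate). A minimal, hypothetical usage
# sketch (shapes chosen only for illustration):
#
#     geglu = GEGLU(dim_in=320, dim_out=1280)
#     y = geglu(torch.randn(1, 77, 320))   # y.shape == (1, 77, 1280)
#
# FeedForward below selects GEGLU as its input projection when glu=True.
# ------------------------------------------------------------------------------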
| 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 
167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 198 | disable_self_attn=False): 199 | super().__init__() 200 | self.disable_self_attn = disable_self_attn 201 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 202 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 203 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 204 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 205 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 206 | self.norm1 = nn.LayerNorm(dim) 207 | self.norm2 = nn.LayerNorm(dim) 208 | self.norm3 = nn.LayerNorm(dim) 209 | self.checkpoint = checkpoint 210 | 211 | def forward(self, x, context=None): 212 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 213 | 214 | def _forward(self, x, context=None): 215 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 216 | x = self.attn2(self.norm2(x), context=context) + x 217 | x = self.ff(self.norm3(x)) + x 218 | return x 219 | 220 | 221 | class SpatialTransformer(nn.Module): 222 | """ 223 | Transformer block for image-like data. 224 | First, project the input (aka embedding) 225 | and reshape to b, t, d. 226 | Then apply standard transformer action. 
227 | Finally, reshape to image 228 | """ 229 | def __init__(self, in_channels, n_heads, d_head, 230 | depth=1, dropout=0., context_dim=None, 231 | disable_self_attn=False): 232 | super().__init__() 233 | self.in_channels = in_channels 234 | inner_dim = n_heads * d_head 235 | self.norm = Normalize(in_channels) 236 | 237 | self.proj_in = nn.Conv2d(in_channels, 238 | inner_dim, 239 | kernel_size=1, 240 | stride=1, 241 | padding=0) 242 | 243 | self.transformer_blocks = nn.ModuleList( 244 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, 245 | disable_self_attn=disable_self_attn) 246 | for d in range(depth)] 247 | ) 248 | 249 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 250 | in_channels, 251 | kernel_size=1, 252 | stride=1, 253 | padding=0)) 254 | 255 | def forward(self, x, context=None): 256 | # note: if no context is given, cross-attention defaults to self-attention 257 | b, c, h, w = x.shape 258 | x_in = x 259 | x = self.norm(x) 260 | x = self.proj_in(x) 261 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 262 | for block in self.transformer_blocks: 263 | x = block(x, context=context) 264 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 265 | x = self.proj_out(x) 266 | return x + x_in 267 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
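# Added commentary (not part of the upstream util.py): the DDIM sigmas computed
# in make_ddim_sampling_parameters below follow Eq. (16) of
# https://arxiv.org/abs/2010.02502,
#
#     sigma_t = eta * sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1}),
#
# where a_t is the cumulative alpha at the selected DDIM step; eta = 0 yields
# deterministic DDIM sampling, while eta = 1 recovers DDPM-like stochasticity.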
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class 
NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 
is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, 
batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/torch_frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! 
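# Added commentary (not part of the upstream file): this module computes FVD by
# embedding real and generated video batches with a TorchScript I3D detector,
# fitting a Gaussian (mu, sigma) to each feature set (compute_stats), and taking
# the Frechet distance between the two Gaussians (compute_frechet_distance):
#
#     FVD = ||mu_sample - mu_ref||^2
#           + Tr(sigma_sample + sigma_ref - 2 * sqrtm(sigma_sample @ sigma_ref))
#
# compute_fvd below accepts either directories of *.mp4 files or precomputed
# statistics stored as .npz, so reference statistics can be reused across runs.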
2 | import os 3 | import numpy as np 4 | import io 5 | import re 6 | import requests 7 | import html 8 | import hashlib 9 | import urllib 10 | import urllib.request 11 | import scipy.linalg 12 | import multiprocessing as mp 13 | import glob 14 | 15 | 16 | from tqdm import tqdm 17 | from typing import Any, List, Tuple, Union, Dict, Callable 18 | 19 | from torchvision.io import read_video 20 | import torch; torch.set_grad_enabled(False) 21 | from einops import rearrange 22 | 23 | from nitro.util import isvideo 24 | 25 | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: 26 | print('Calculate frechet distance...') 27 | m = np.square(mu_sample - mu_ref).sum() 28 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member 29 | fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) 30 | 31 | return float(fid) 32 | 33 | 34 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 35 | mu = feats.mean(axis=0) # [d] 36 | sigma = np.cov(feats, rowvar=False) # [d, d] 37 | 38 | return mu, sigma 39 | 40 | 41 | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: 42 | """Download the given URL and return a binary-mode file object to access the data.""" 43 | assert num_attempts >= 1 44 | 45 | # Doesn't look like an URL scheme so interpret it as a local filename. 46 | if not re.match('^[a-z]+://', url): 47 | return url if return_filename else open(url, "rb") 48 | 49 | # Handle file URLs. This code handles unusual file:// patterns that 50 | # arise on Windows: 51 | # 52 | # file:///c:/foo.txt 53 | # 54 | # which would translate to a local '/c:/foo.txt' filename that's 55 | # invalid. Drop the forward slash for such pathnames. 56 | # 57 | # If you touch this code path, you should test it on both Linux and 58 | # Windows. 59 | # 60 | # Some internet resources suggest using urllib.request.url2pathname() but 61 | # but that converts forward slashes to backslashes and this causes 62 | # its own set of problems. 63 | if url.startswith('file://'): 64 | filename = urllib.parse.urlparse(url).path 65 | if re.match(r'^/[a-zA-Z]:', filename): 66 | filename = filename[1:] 67 | return filename if return_filename else open(filename, "rb") 68 | 69 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 70 | 71 | # Download. 72 | url_name = None 73 | url_data = None 74 | with requests.Session() as session: 75 | if verbose: 76 | print("Downloading %s ..." 
% url, end="", flush=True) 77 | for attempts_left in reversed(range(num_attempts)): 78 | try: 79 | with session.get(url) as res: 80 | res.raise_for_status() 81 | if len(res.content) == 0: 82 | raise IOError("No data received") 83 | 84 | if len(res.content) < 8192: 85 | content_str = res.content.decode("utf-8") 86 | if "download_warning" in res.headers.get("Set-Cookie", ""): 87 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 88 | if len(links) == 1: 89 | url = requests.compat.urljoin(url, links[0]) 90 | raise IOError("Google Drive virus checker nag") 91 | if "Google Drive - Quota exceeded" in content_str: 92 | raise IOError("Google Drive download quota exceeded -- please try again later") 93 | 94 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 95 | url_name = match[1] if match else url 96 | url_data = res.content 97 | if verbose: 98 | print(" done") 99 | break 100 | except KeyboardInterrupt: 101 | raise 102 | except: 103 | if not attempts_left: 104 | if verbose: 105 | print(" failed") 106 | raise 107 | if verbose: 108 | print(".", end="", flush=True) 109 | 110 | # Return data as file object. 111 | assert not return_filename 112 | return io.BytesIO(url_data) 113 | 114 | def load_video(ip): 115 | vid, *_ = read_video(ip) 116 | vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) 117 | return vid 118 | 119 | def get_data_from_str(input_str,nprc = None): 120 | assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' 121 | vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) 122 | print(f'Found {len(vid_filelist)} videos in dir {input_str}') 123 | 124 | if nprc is None: 125 | try: 126 | nprc = mp.cpu_count() 127 | except NotImplementedError: 128 | print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') 129 | nprc = 1 130 | 131 | pool = mp.Pool(processes=nprc) 132 | 133 | vids = [] 134 | for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): 135 | vids.append(v) 136 | 137 | 138 | vids = torch.stack(vids,dim=0).float() 139 | 140 | return vids 141 | 142 | def get_stats(stats): 143 | assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' 144 | 145 | print(f'Using precomputed statistics under {stats}') 146 | stats = np.load(stats) 147 | stats = {key: stats[key] for key in stats.files} 148 | 149 | return stats 150 | 151 | 152 | 153 | 154 | @torch.no_grad() 155 | def compute_fvd(ref_input, sample_input, bs=32, 156 | ref_stats=None, 157 | sample_stats=None, 158 | nprc_load=None): 159 | 160 | 161 | 162 | calc_stats = ref_stats is None or sample_stats is None 163 | 164 | if calc_stats: 165 | 166 | only_ref = sample_stats is not None 167 | only_sample = ref_stats is not None 168 | 169 | 170 | if isinstance(ref_input,str) and not only_sample: 171 | ref_input = get_data_from_str(ref_input,nprc_load) 172 | 173 | if isinstance(sample_input, str) and not only_ref: 174 | sample_input = get_data_from_str(sample_input, nprc_load) 175 | 176 | stats = compute_statistics(sample_input,ref_input, 177 | device='cuda' if torch.cuda.is_available() else 'cpu', 178 | bs=bs, 179 | only_ref=only_ref, 180 | only_sample=only_sample) 181 | 182 | if only_ref: 183 | stats.update(get_stats(sample_stats)) 184 | elif only_sample: 185 | stats.update(get_stats(ref_stats)) 186 | 187 | 188 | 189 | else: 190 | stats = get_stats(sample_stats) 191 | stats.update(get_stats(ref_stats)) 192 | 193 | 
fvd = compute_frechet_distance(**stats) 194 | 195 | return {'FVD' : fvd,} 196 | 197 | 198 | @torch.no_grad() 199 | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: 200 | detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' 201 | detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer. 202 | 203 | with open_url(detector_url, verbose=False) as f: 204 | detector = torch.jit.load(f).eval().to(device) 205 | 206 | 207 | 208 | assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' 209 | 210 | ref_embed, sample_embed = [], [] 211 | 212 | info = f'Computing I3D activations for FVD score with batch size {bs}' 213 | 214 | if only_ref: 215 | 216 | if not isvideo(videos_real): 217 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 218 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 219 | print(videos_real.shape) 220 | 221 | if videos_real.shape[0] % bs == 0: 222 | n_secs = videos_real.shape[0] // bs 223 | else: 224 | n_secs = videos_real.shape[0] // bs + 1 225 | 226 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 227 | 228 | for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): 229 | 230 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 231 | ref_embed.append(feats_ref) 232 | 233 | elif only_sample: 234 | 235 | if not isvideo(videos_fake): 236 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 237 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 238 | print(videos_fake.shape) 239 | 240 | if videos_fake.shape[0] % bs == 0: 241 | n_secs = videos_fake.shape[0] // bs 242 | else: 243 | n_secs = videos_fake.shape[0] // bs + 1 244 | 245 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 246 | 247 | for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): 248 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 249 | sample_embed.append(feats_sample) 250 | 251 | 252 | else: 253 | 254 | if not isvideo(videos_real): 255 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 256 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 257 | 258 | if not isvideo(videos_fake): 259 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 260 | 261 | if videos_fake.shape[0] % bs == 0: 262 | n_secs = videos_fake.shape[0] // bs 263 | else: 264 | n_secs = videos_fake.shape[0] // bs + 1 265 | 266 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 267 | videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) 268 | 269 | for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): 270 | # print(ref_v.shape) 271 | # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) 272 | # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) 273 | 274 | 275 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 276 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 277 | sample_embed.append(feats_sample) 278 | 
ref_embed.append(feats_ref) 279 | 280 | out = dict() 281 | if len(sample_embed) > 0: 282 | sample_embed = np.concatenate(sample_embed,axis=0) 283 | mu_sample, sigma_sample = compute_stats(sample_embed) 284 | out.update({'mu_sample': mu_sample, 285 | 'sigma_sample': sigma_sample}) 286 | 287 | if len(ref_embed) > 0: 288 | ref_embed = np.concatenate(ref_embed,axis=0) 289 | mu_ref, sigma_ref = compute_stats(ref_embed) 290 | out.update({'mu_ref': mu_ref, 291 | 'sigma_ref': sigma_ref}) 292 | 293 | 294 | return out 295 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /ldm/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | from abc import abstractmethod 9 | 10 | 11 | class CocoBase(Dataset): 12 | """needed for (image, caption, segmentation) pairs""" 13 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 14 | crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None): 15 | self.split = self.get_split() 16 | self.size = size 17 | if crop_size is None: 18 | self.crop_size = size 19 | else: 20 | self.crop_size = crop_size 21 | 22 | assert crop_type in [None, 'random', 'center'] 23 | self.crop_type = crop_type 24 | self.use_segmenation = use_segmentation 25 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 26 | self.stuffthing = use_stuffthing # include thing in segmentation 27 | if self.onehot and not self.stuffthing: 28 | raise NotImplemented("One hot mode is only supported for the " 29 | "stuffthings version because labels are stored " 30 | "a bit different.") 31 | 32 | data_json = datajson 33 | with open(data_json) as json_file: 34 | self.json_data = json.load(json_file) 35 | self.img_id_to_captions = dict() 36 | self.img_id_to_filepath = dict() 37 | self.img_id_to_segmentation_filepath = dict() 38 | 39 | assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json", 40 | f"captions_val{self.year()}.json"] 41 | # TODO currently hardcoded paths, would be better to follow logic in 42 | # cocstuff pixelmaps 43 | if self.use_segmenation: 44 | if self.stuffthing: 45 | self.segmentation_prefix = ( 46 | f"data/cocostuffthings/val{self.year()}" if 47 | data_json.endswith(f"captions_val{self.year()}.json") else 48 | f"data/cocostuffthings/train{self.year()}") 49 | else: 50 | self.segmentation_prefix = ( 51 | f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if 52 | data_json.endswith(f"captions_val{self.year()}.json") else 53 | f"data/coco/annotations/stuff_train{self.year()}_pixelmaps") 54 | 55 | imagedirs = self.json_data["images"] 56 | self.labels = {"image_ids": list()} 57 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 58 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 59 | self.img_id_to_captions[imgdir["id"]] = list() 60 | pngfilename = imgdir["file_name"].replace("jpg", "png") 61 | if self.use_segmenation: 62 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 63 | self.segmentation_prefix, pngfilename) 64 | if given_files is not None: 65 | if pngfilename in given_files: 66 | self.labels["image_ids"].append(imgdir["id"]) 67 | else: 68 | self.labels["image_ids"].append(imgdir["id"]) 69 | 70 | capdirs = self.json_data["annotations"] 71 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 72 | # there are in average 5 captions per image 73 | #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 74 | self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"]) 75 | 76 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 77 | if self.split=="validation": 78 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 79 | else: 80 | # default option for train is random crop 81 | if self.crop_type 
in [None, 'random']: 82 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 83 | else: 84 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 85 | self.preprocessor = albumentations.Compose( 86 | [self.rescaler, self.cropper], 87 | additional_targets={"segmentation": "image"}) 88 | if force_no_crop: 89 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 90 | self.preprocessor = albumentations.Compose( 91 | [self.rescaler], 92 | additional_targets={"segmentation": "image"}) 93 | 94 | @abstractmethod 95 | def year(self): 96 | raise NotImplementedError() 97 | 98 | def __len__(self): 99 | return len(self.labels["image_ids"]) 100 | 101 | def preprocess_image(self, image_path, segmentation_path=None): 102 | image = Image.open(image_path) 103 | if not image.mode == "RGB": 104 | image = image.convert("RGB") 105 | image = np.array(image).astype(np.uint8) 106 | if segmentation_path: 107 | segmentation = Image.open(segmentation_path) 108 | if not self.onehot and not segmentation.mode == "RGB": 109 | segmentation = segmentation.convert("RGB") 110 | segmentation = np.array(segmentation).astype(np.uint8) 111 | if self.onehot: 112 | assert self.stuffthing 113 | # stored in caffe format: unlabeled==255. stuff and thing from 114 | # 0-181. to be compatible with the labels in 115 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 116 | # we shift stuffthing one to the right and put unlabeled in zero 117 | # as long as segmentation is uint8 shifting to right handles the 118 | # latter too 119 | assert segmentation.dtype == np.uint8 120 | segmentation = segmentation + 1 121 | 122 | processed = self.preprocessor(image=image, segmentation=segmentation) 123 | 124 | image, segmentation = processed["image"], processed["segmentation"] 125 | else: 126 | image = self.preprocessor(image=image,)['image'] 127 | 128 | image = (image / 127.5 - 1.0).astype(np.float32) 129 | if segmentation_path: 130 | if self.onehot: 131 | assert segmentation.dtype == np.uint8 132 | # make it one hot 133 | n_labels = 183 134 | flatseg = np.ravel(segmentation) 135 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 136 | onehot[np.arange(flatseg.size), flatseg] = True 137 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 138 | segmentation = onehot 139 | else: 140 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 141 | return image, segmentation 142 | else: 143 | return image 144 | 145 | def __getitem__(self, i): 146 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 147 | if self.use_segmenation: 148 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 149 | image, segmentation = self.preprocess_image(img_path, seg_path) 150 | else: 151 | image = self.preprocess_image(img_path) 152 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 153 | # randomly draw one of all available captions per image 154 | caption = captions[np.random.randint(0, len(captions))] 155 | example = {"image": image, 156 | #"caption": [str(caption[0])], 157 | "caption": caption, 158 | "img_path": img_path, 159 | "filename_": img_path.split(os.sep)[-1] 160 | } 161 | if self.use_segmenation: 162 | example.update({"seg_path": seg_path, 'segmentation': segmentation}) 163 | return example 164 | 165 | 166 | class CocoImagesAndCaptionsTrain2017(CocoBase): 167 | """returns a pair of (image, caption)""" 168 | def __init__(self, size, onehot_segmentation=False, 
use_stuffthing=False, crop_size=None, force_no_crop=False,): 169 | super().__init__(size=size, 170 | dataroot="data/coco/train2017", 171 | datajson="data/coco/annotations/captions_train2017.json", 172 | onehot_segmentation=onehot_segmentation, 173 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 174 | 175 | def get_split(self): 176 | return "train" 177 | 178 | def year(self): 179 | return '2017' 180 | 181 | 182 | class CocoImagesAndCaptionsValidation2017(CocoBase): 183 | """returns a pair of (image, caption)""" 184 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 185 | given_files=None): 186 | super().__init__(size=size, 187 | dataroot="data/coco/val2017", 188 | datajson="data/coco/annotations/captions_val2017.json", 189 | onehot_segmentation=onehot_segmentation, 190 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 191 | given_files=given_files) 192 | 193 | def get_split(self): 194 | return "validation" 195 | 196 | def year(self): 197 | return '2017' 198 | 199 | 200 | 201 | class CocoImagesAndCaptionsTrain2014(CocoBase): 202 | """returns a pair of (image, caption)""" 203 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'): 204 | super().__init__(size=size, 205 | dataroot="data/coco/train2014", 206 | datajson="data/coco/annotations2014/annotations/captions_train2014.json", 207 | onehot_segmentation=onehot_segmentation, 208 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 209 | use_segmentation=False, 210 | crop_type=crop_type) 211 | 212 | def get_split(self): 213 | return "train" 214 | 215 | def year(self): 216 | return '2014' 217 | 218 | class CocoImagesAndCaptionsValidation2014(CocoBase): 219 | """returns a pair of (image, caption)""" 220 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 221 | given_files=None,crop_type='center',**kwargs): 222 | super().__init__(size=size, 223 | dataroot="data/coco/val2014", 224 | datajson="data/coco/annotations2014/annotations/captions_val2014.json", 225 | onehot_segmentation=onehot_segmentation, 226 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 227 | given_files=given_files, 228 | use_segmentation=False, 229 | crop_type=crop_type) 230 | 231 | def get_split(self): 232 | return "validation" 233 | 234 | def year(self): 235 | return '2014' 236 | 237 | if __name__ == '__main__': 238 | with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file: 239 | json_data = json.load(json_file) 240 | capdirs = json_data["annotations"] 241 | import pudb; pudb.set_trace() 242 | #d2 = CocoImagesAndCaptionsTrain2014(size=256) 243 | d2 = CocoImagesAndCaptionsValidation2014(size=256) 244 | print("constructed dataset.") 245 | print(f"length of {d2.__class__.__name__}: {len(d2)}") 246 | 247 | ex2 = d2[0] 248 | # ex3 = d3[0] 249 | # print(ex1["image"].shape) 250 | print(ex2["image"].shape) 251 | # print(ex3["image"].shape) 252 | # print(ex1["segmentation"].shape) 253 | print(ex2["caption"].__class__.__name__) 254 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from 
functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | ctmp = conditioning[list(conditioning.keys())[0]] 87 | while isinstance(ctmp, list): ctmp = ctmp[0] 88 | cbs = ctmp.shape[0] 89 | if cbs != batch_size: 90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 91 | else: 92 | if conditioning.shape[0] != batch_size: 93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 94 | 95 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 96 | # sampling 97 | C, H, W = shape 98 | size = (batch_size, C, H, W) 99 | print(f'Data shape for PLMS sampling is {size}') 100 | 101 | samples, intermediates = self.plms_sampling(conditioning, size, 102 | callback=callback, 103 | img_callback=img_callback, 104 | quantize_denoised=quantize_x0, 105 | mask=mask, x0=x0, 106 | ddim_use_original_steps=False, 107 | noise_dropout=noise_dropout, 108 | temperature=temperature, 109 | score_corrector=score_corrector, 110 | corrector_kwargs=corrector_kwargs, 111 | x_T=x_T, 112 | log_every_t=log_every_t, 113 | unconditional_guidance_scale=unconditional_guidance_scale, 114 | unconditional_conditioning=unconditional_conditioning, 115 | dynamic_threshold=dynamic_threshold, 116 | ) 117 | return samples, intermediates 118 | 119 | @torch.no_grad() 120 | def plms_sampling(self, cond, shape, 121 | x_T=None, ddim_use_original_steps=False, 122 | callback=None, timesteps=None, quantize_denoised=False, 123 | mask=None, x0=None, img_callback=None, log_every_t=100, 124 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 125 | unconditional_guidance_scale=1., unconditional_conditioning=None, 126 | dynamic_threshold=None): 127 | device = self.model.betas.device 128 | b = shape[0] 129 | if x_T is None: 130 | img = torch.randn(shape, device=device) 131 | else: 132 | img = x_T 133 | 134 | if timesteps is None: 135 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 136 | elif timesteps is not None and not ddim_use_original_steps: 137 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 138 | timesteps = self.ddim_timesteps[:subset_end] 139 | 140 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 141 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 142 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 143 | print(f"Running PLMS Sampling with {total_steps} timesteps") 144 | 145 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 146 | old_eps = [] 147 | 148 | for i, step in enumerate(iterator): 149 | index = total_steps - i - 1 150 | ts = torch.full((b,), step, device=device, dtype=torch.long) 151 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 152 | 153 | if mask is not None: 154 | assert x0 is not None 155 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 156 | img = img_orig * mask + (1. 
- mask) * img 157 | 158 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 159 | quantize_denoised=quantize_denoised, temperature=temperature, 160 | noise_dropout=noise_dropout, score_corrector=score_corrector, 161 | corrector_kwargs=corrector_kwargs, 162 | unconditional_guidance_scale=unconditional_guidance_scale, 163 | unconditional_conditioning=unconditional_conditioning, 164 | old_eps=old_eps, t_next=ts_next, 165 | dynamic_threshold=dynamic_threshold) 166 | img, pred_x0, e_t = outs 167 | old_eps.append(e_t) 168 | if len(old_eps) >= 4: 169 | old_eps.pop(0) 170 | if callback: callback(i) 171 | if img_callback: img_callback(pred_x0, i) 172 | 173 | if index % log_every_t == 0 or index == total_steps - 1: 174 | intermediates['x_inter'].append(img) 175 | intermediates['pred_x0'].append(pred_x0) 176 | 177 | return img, intermediates 178 | 179 | @torch.no_grad() 180 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 181 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 182 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 183 | dynamic_threshold=None): 184 | b, *_, device = *x.shape, x.device 185 | 186 | def get_model_output(x, t): 187 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 188 | e_t = self.model.apply_model(x, t, c) 189 | else: 190 | x_in = torch.cat([x] * 2) 191 | t_in = torch.cat([t] * 2) 192 | if isinstance(c, dict): 193 | assert isinstance(unconditional_conditioning, dict) 194 | c_in = dict() 195 | for k in c: 196 | if isinstance(c[k], list): 197 | c_in[k] = [torch.cat([ 198 | unconditional_conditioning[k][i], 199 | c[k][i]]) for i in range(len(c[k]))] 200 | else: 201 | c_in[k] = torch.cat([ 202 | unconditional_conditioning[k], 203 | c[k]]) 204 | else: 205 | c_in = torch.cat([unconditional_conditioning, c]) 206 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 207 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 208 | 209 | if score_corrector is not None: 210 | assert self.model.parameterization == "eps" 211 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 212 | 213 | return e_t 214 | 215 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 216 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 217 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 218 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 219 | 220 | def get_x_prev_and_pred_x0(e_t, index): 221 | # select parameters corresponding to the currently considered timestep 222 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 223 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 224 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 225 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 226 | 227 | # current prediction for x_0 228 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 229 | if quantize_denoised: 230 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 231 | if dynamic_threshold is not None: 232 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 233 | # direction pointing to x_t 234 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 235 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 236 | if noise_dropout > 0.: 237 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 238 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 239 | return x_prev, pred_x0 240 | 241 | e_t = get_model_output(x, t) 242 | if len(old_eps) == 0: 243 | # Pseudo Improved Euler (2nd order) 244 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 245 | e_t_next = get_model_output(x_prev, t_next) 246 | e_t_prime = (e_t + e_t_next) / 2 247 | elif len(old_eps) == 1: 248 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 249 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 250 | elif len(old_eps) == 2: 251 | # 3rd order Pseudo Linear Multistep (Adams-Bashforth) 252 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 253 | elif len(old_eps) >= 3: 254 | # 4th order Pseudo Linear Multistep (Adams-Bashforth) 255 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 256 | 257 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 258 | 259 | return x_prev, pred_x0, e_t 260 | -------------------------------------------------------------------------------- /ldm/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os, yaml, pickle, shutil, tarfile, glob 2 | import cv2 3 | import albumentations 4 | import PIL 5 | import numpy as np 6 | import torchvision.transforms.functional as TF 7 | from omegaconf import OmegaConf 8 | from functools import partial 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset, Subset 12 | 13 | import taming.data.utils as tdu 14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve 15 | from taming.data.imagenet import ImagePaths 16 | 17 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light 18 | 19 | 20 | def synset2idx(path_to_yaml="data/index_synset.yaml"): 21 | with open(path_to_yaml) as f: 22 | di2s = yaml.safe_load(f) # PyYAML >= 6 requires an explicit loader 23 | return dict((v,k) for k,v in di2s.items()) 24 | 25 | 26 | class ImageNetBase(Dataset): 27 | def __init__(self, config=None): 28 | self.config = config or OmegaConf.create() 29 | if not type(self.config)==dict: 30 | self.config = OmegaConf.to_container(self.config) 31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) 32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths 33 | self._prepare() 34 | self._prepare_synset_to_human() 35 | self._prepare_idx_to_synset() 36 | self._prepare_human_to_integer_label() 37 | self._load() 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, i): 43 | return self.data[i] 44 | 45 | def _prepare(self): 46 | raise NotImplementedError() 47 | 48 | def _filter_relpaths(self, relpaths): 49 | ignore = set([ 50 | "n06596364_9591.JPEG", 51 | ]) 52 | relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] 53 | if "sub_indices" in self.config: 54 | indices = str_to_indices(self.config["sub_indices"]) 55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings 56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) 57 | files = [] 58 | for rpath in relpaths: 59 | syn = rpath.split("/")[0] 60 | if syn in synsets: 61 | files.append(rpath) 62 | return files 63 | else: 64 | return relpaths 65 | 66 | def
_prepare_synset_to_human(self): 67 | SIZE = 2655750 68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" 69 | self.human_dict = os.path.join(self.root, "synset_human.txt") 70 | if (not os.path.exists(self.human_dict) or 71 | not os.path.getsize(self.human_dict)==SIZE): 72 | download(URL, self.human_dict) 73 | 74 | def _prepare_idx_to_synset(self): 75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" 76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml") 77 | if (not os.path.exists(self.idx2syn)): 78 | download(URL, self.idx2syn) 79 | 80 | def _prepare_human_to_integer_label(self): 81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" 82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") 83 | if (not os.path.exists(self.human2integer)): 84 | download(URL, self.human2integer) 85 | with open(self.human2integer, "r") as f: 86 | lines = f.read().splitlines() 87 | assert len(lines) == 1000 88 | self.human2integer_dict = dict() 89 | for line in lines: 90 | value, key = line.split(":") 91 | self.human2integer_dict[key] = int(value) 92 | 93 | def _load(self): 94 | with open(self.txt_filelist, "r") as f: 95 | self.relpaths = f.read().splitlines() 96 | l1 = len(self.relpaths) 97 | self.relpaths = self._filter_relpaths(self.relpaths) 98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) 99 | 100 | self.synsets = [p.split("/")[0] for p in self.relpaths] 101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] 102 | 103 | unique_synsets = np.unique(self.synsets) 104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) 105 | if not self.keep_orig_class_label: 106 | self.class_labels = [class_dict[s] for s in self.synsets] 107 | else: 108 | self.class_labels = [self.synset2idx[s] for s in self.synsets] 109 | 110 | with open(self.human_dict, "r") as f: 111 | human_dict = f.read().splitlines() 112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict) 113 | 114 | self.human_labels = [human_dict[s] for s in self.synsets] 115 | 116 | labels = { 117 | "relpath": np.array(self.relpaths), 118 | "synsets": np.array(self.synsets), 119 | "class_label": np.array(self.class_labels), 120 | "human_label": np.array(self.human_labels), 121 | } 122 | 123 | if self.process_images: 124 | self.size = retrieve(self.config, "size", default=256) 125 | self.data = ImagePaths(self.abspaths, 126 | labels=labels, 127 | size=self.size, 128 | random_crop=self.random_crop, 129 | ) 130 | else: 131 | self.data = self.abspaths 132 | 133 | 134 | class ImageNetTrain(ImageNetBase): 135 | NAME = "ILSVRC2012_train" 136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/" 137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" 138 | FILES = [ 139 | "ILSVRC2012_img_train.tar", 140 | ] 141 | SIZES = [ 142 | 147897477120, 143 | ] 144 | 145 | def __init__(self, process_images=True, data_root=None, **kwargs): 146 | self.process_images = process_images 147 | self.data_root = data_root 148 | super().__init__(**kwargs) 149 | 150 | def _prepare(self): 151 | if self.data_root: 152 | self.root = os.path.join(self.data_root, self.NAME) 153 | else: 154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 156 | 157 | self.datadir = os.path.join(self.root, "data") 158 | self.txt_filelist = os.path.join(self.root, "filelist.txt") 159 | 
self.expected_length = 1281167 160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", 161 | default=True) 162 | if not tdu.is_prepared(self.root): 163 | # prep 164 | print("Preparing dataset {} in {}".format(self.NAME, self.root)) 165 | 166 | datadir = self.datadir 167 | if not os.path.exists(datadir): 168 | path = os.path.join(self.root, self.FILES[0]) 169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: 170 | import academictorrents as at 171 | atpath = at.get(self.AT_HASH, datastore=self.root) 172 | assert atpath == path 173 | 174 | print("Extracting {} to {}".format(path, datadir)) 175 | os.makedirs(datadir, exist_ok=True) 176 | with tarfile.open(path, "r:") as tar: 177 | tar.extractall(path=datadir) 178 | 179 | print("Extracting sub-tars.") 180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) 181 | for subpath in tqdm(subpaths): 182 | subdir = subpath[:-len(".tar")] 183 | os.makedirs(subdir, exist_ok=True) 184 | with tarfile.open(subpath, "r:") as tar: 185 | tar.extractall(path=subdir) 186 | 187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) 188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist] 189 | filelist = sorted(filelist) 190 | filelist = "\n".join(filelist)+"\n" 191 | with open(self.txt_filelist, "w") as f: 192 | f.write(filelist) 193 | 194 | tdu.mark_prepared(self.root) 195 | 196 | 197 | class ImageNetValidation(ImageNetBase): 198 | NAME = "ILSVRC2012_validation" 199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/" 200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" 201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" 202 | FILES = [ 203 | "ILSVRC2012_img_val.tar", 204 | "validation_synset.txt", 205 | ] 206 | SIZES = [ 207 | 6744924160, 208 | 1950000, 209 | ] 210 | 211 | def __init__(self, process_images=True, data_root=None, **kwargs): 212 | self.data_root = data_root 213 | self.process_images = process_images 214 | super().__init__(**kwargs) 215 | 216 | def _prepare(self): 217 | if self.data_root: 218 | self.root = os.path.join(self.data_root, self.NAME) 219 | else: 220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 221 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 222 | self.datadir = os.path.join(self.root, "data") 223 | self.txt_filelist = os.path.join(self.root, "filelist.txt") 224 | self.expected_length = 50000 225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", 226 | default=False) 227 | if not tdu.is_prepared(self.root): 228 | # prep 229 | print("Preparing dataset {} in {}".format(self.NAME, self.root)) 230 | 231 | datadir = self.datadir 232 | if not os.path.exists(datadir): 233 | path = os.path.join(self.root, self.FILES[0]) 234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: 235 | import academictorrents as at 236 | atpath = at.get(self.AT_HASH, datastore=self.root) 237 | assert atpath == path 238 | 239 | print("Extracting {} to {}".format(path, datadir)) 240 | os.makedirs(datadir, exist_ok=True) 241 | with tarfile.open(path, "r:") as tar: 242 | tar.extractall(path=datadir) 243 | 244 | vspath = os.path.join(self.root, self.FILES[1]) 245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: 246 | download(self.VS_URL, vspath) 247 | 248 | with open(vspath, "r") as f: 249 | synset_dict = f.read().splitlines() 250 | synset_dict = dict(line.split() for line in synset_dict) 251 | 252 | print("Reorganizing 
into synset folders") 253 | synsets = np.unique(list(synset_dict.values())) 254 | for s in synsets: 255 | os.makedirs(os.path.join(datadir, s), exist_ok=True) 256 | for k, v in synset_dict.items(): 257 | src = os.path.join(datadir, k) 258 | dst = os.path.join(datadir, v) 259 | shutil.move(src, dst) 260 | 261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) 262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist] 263 | filelist = sorted(filelist) 264 | filelist = "\n".join(filelist)+"\n" 265 | with open(self.txt_filelist, "w") as f: 266 | f.write(filelist) 267 | 268 | tdu.mark_prepared(self.root) 269 | 270 | 271 | 272 | class ImageNetSR(Dataset): 273 | def __init__(self, size=None, 274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., 275 | random_crop=True): 276 | """ 277 | ImageNet Superresolution Dataloader 278 | Performs the following ops in order: 279 | 1. crops a crop of size s from the image, either as a random or center crop 280 | 2. resizes the crop to size with cv2.area_interpolation 281 | 3. degrades the resized crop with degradation_fn 282 | 283 | :param size: resizing to size after cropping 284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light 285 | :param downscale_f: Low Resolution Downsample factor 286 | :param min_crop_f: determines crop size s, 287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) 288 | :param max_crop_f: "" 289 | :param data_root: 290 | :param random_crop: 291 | """ 292 | self.base = self.get_base() 293 | assert size 294 | assert (size / downscale_f).is_integer() 295 | self.size = size 296 | self.LR_size = int(size / downscale_f) 297 | self.min_crop_f = min_crop_f 298 | self.max_crop_f = max_crop_f 299 | assert(max_crop_f <= 1.) 300 | self.center_crop = not random_crop 301 | 302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 303 | 304 | self.pil_interpolation = False # gets reset later in case interp_op is from pillow 305 | 306 | if degradation == "bsrgan": 307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) 308 | 309 | elif degradation == "bsrgan_light": 310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) 311 | 312 | else: 313 | interpolation_fn = { 314 | "cv_nearest": cv2.INTER_NEAREST, 315 | "cv_bilinear": cv2.INTER_LINEAR, 316 | "cv_bicubic": cv2.INTER_CUBIC, 317 | "cv_area": cv2.INTER_AREA, 318 | "cv_lanczos": cv2.INTER_LANCZOS4, 319 | "pil_nearest": PIL.Image.NEAREST, 320 | "pil_bilinear": PIL.Image.BILINEAR, 321 | "pil_bicubic": PIL.Image.BICUBIC, 322 | "pil_box": PIL.Image.BOX, 323 | "pil_hamming": PIL.Image.HAMMING, 324 | "pil_lanczos": PIL.Image.LANCZOS, 325 | }[degradation] 326 | 327 | self.pil_interpolation = degradation.startswith("pil_") 328 | 329 | if self.pil_interpolation: 330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) 331 | 332 | else: 333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, 334 | interpolation=interpolation_fn) 335 | 336 | def __len__(self): 337 | return len(self.base) 338 | 339 | def __getitem__(self, i): 340 | example = self.base[i] 341 | image = Image.open(example["file_path_"]) 342 | 343 | if not image.mode == "RGB": 344 | image = image.convert("RGB") 345 | 346 | image = np.array(image).astype(np.uint8) 347 | 348 | min_side_len = min(image.shape[:2]) 349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f,
size=None) 350 | crop_side_len = int(crop_side_len) 351 | 352 | if self.center_crop: 353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) 354 | 355 | else: 356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) 357 | 358 | image = self.cropper(image=image)["image"] 359 | image = self.image_rescaler(image=image)["image"] 360 | 361 | if self.pil_interpolation: 362 | image_pil = PIL.Image.fromarray(image) 363 | LR_image = self.degradation_process(image_pil) 364 | LR_image = np.array(LR_image).astype(np.uint8) 365 | 366 | else: 367 | LR_image = self.degradation_process(image=image)["image"] 368 | 369 | example["image"] = (image/127.5 - 1.0).astype(np.float32) 370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) 371 | example["caption"] = example["human_label"] # dummy caption 372 | return example 373 | 374 | 375 | class ImageNetSRTrain(ImageNetSR): 376 | def __init__(self, **kwargs): 377 | super().__init__(**kwargs) 378 | 379 | def get_base(self): 380 | with open("data/imagenet_train_hr_indices.p", "rb") as f: 381 | indices = pickle.load(f) 382 | dset = ImageNetTrain(process_images=False,) 383 | return Subset(dset, indices) 384 | 385 | 386 | class ImageNetSRValidation(ImageNetSR): 387 | def __init__(self, **kwargs): 388 | super().__init__(**kwargs) 389 | 390 | def get_base(self): 391 | with open("data/imagenet_val_hr_indices.p", "rb") as f: 392 | indices = pickle.load(f) 393 | dset = ImageNetValidation(process_images=False,) 394 | return Subset(dset, indices) 395 | --------------------------------------------------------------------------------
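The following standalone sketch is not part of the repository; it is a minimal illustration of the one-hot segmentation encoding used by CocoBase.preprocess_image in /ldm/data/coco.py above. Labels arrive in "caffe" layout (unlabeled == 255, stuff/thing classes 0-181), a uint8 add-one wraps the unlabeled value to 0 while shifting every class up by one, and the result is scattered into 183 one-hot channels. The 4x4 label map below is synthetic and chosen only for demonstration.

import numpy as np

# synthetic 4x4 label map in "caffe" layout: unlabeled == 255, classes 0-181
segmentation = np.array([[255,   0,   1, 181],
                         [255, 255,   3,   3],
                         [  7,   7,   7, 255],
                         [  0,   0, 181, 181]], dtype=np.uint8)

# uint8 overflow maps 255 -> 0 and shifts every class up by one (0 -> 1, ..., 181 -> 182)
segmentation = segmentation + 1

n_labels = 183
flatseg = np.ravel(segmentation)
onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
onehot[np.arange(flatseg.size), flatseg] = True
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)

print(onehot.shape)           # (4, 4, 183)
print(onehot[0, 0].argmax())  # 0 -> the formerly-unlabeled pixel now lives in channel 0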
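The next standalone sketch, also not part of the repository, illustrates the pseudo linear multistep (Adams-Bashforth) combination used by PLMSSampler.p_sample_plms in /ldm/models/diffusion/plms.py above. To keep it runnable without a diffusion model, the same 2nd-, 3rd- and 4th-order coefficient families are applied to the toy ODE dy/dt = -y; the stand-in function f, the step size and the step count are illustrative only, and the very first step falls back to plain Euler here, whereas the sampler instead applies the Heun-style "Pseudo Improved Euler" correction.

import numpy as np

def f(t, y):
    # illustrative stand-in for the model's epsilon prediction
    return -y

def adams_bashforth_step(y, t, dt, old_f):
    f_t = f(t, y)
    if len(old_f) == 0:
        f_prime = f_t  # plain Euler on the very first step (sketch only)
    elif len(old_f) == 1:
        f_prime = (3 * f_t - old_f[-1]) / 2                                    # 2nd order
    elif len(old_f) == 2:
        f_prime = (23 * f_t - 16 * old_f[-1] + 5 * old_f[-2]) / 12             # 3rd order
    else:
        f_prime = (55 * f_t - 59 * old_f[-1] + 37 * old_f[-2] - 9 * old_f[-3]) / 24  # 4th order
    old_f.append(f_t)
    if len(old_f) > 3:
        old_f.pop(0)  # keep only the three most recent evaluations, mirroring the old_eps buffer in plms_sampling
    return y + dt * f_prime

y, t, dt, old_f = 1.0, 0.0, 0.05, []
for _ in range(20):
    y = adams_bashforth_step(y, t, dt, old_f)
    t += dt
print(f"multistep estimate: {y:.6f}, exact exp(-t): {np.exp(-t):.6f}")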