├── CtrlColor_environ.yaml ├── LICENSE ├── README.md ├── annotator └── util.py ├── assets ├── hf_demo1.png ├── iterative.gif ├── region.gif ├── region_cond.gif └── teaser_aligned.png ├── cldm ├── cldm.py ├── ddim_haced_sag_step.py ├── ddim_hacked_sag.py ├── hack.py └── model.py ├── config.py ├── ldm ├── models │ ├── autoencoder.py │ ├── autoencoder_train.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── ddim.cpython-38.pyc │ │ │ ├── ddpm.cpython-38.pyc │ │ │ └── ddpm_nonoise.cpython-38.pyc │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py │ └── logger.py ├── modules │ ├── attention.py │ ├── attention_dcn_control.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── model_brefore_dcn.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── model_brefore_dcn.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── modules.cpython-38.pyc │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── contperceptual.cpython-38.pyc │ │ │ └── vqperceptual.cpython-38.pyc │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── models ├── cldm_v15_inpainting_infer.yaml └── cldm_v15_inpainting_infer1.yaml ├── pretrained_models └── put_ckpts_here.txt ├── share.py ├── taming ├── data │ ├── ade20k.py │ ├── annotated_objects_coco.py │ ├── annotated_objects_dataset.py │ ├── annotated_objects_open_images.py │ ├── base.py │ ├── coco.py │ ├── conditional_builder │ │ ├── objects_bbox.py │ │ ├── objects_center_points.py │ │ └── utils.py │ ├── custom.py │ ├── faceshq.py │ ├── helper_types.py │ ├── image_transforms.py │ ├── imagenet.py │ ├── open_images_helper.py │ ├── sflckr.py │ └── utils.py ├── lr_scheduler.py ├── models │ ├── cond_transformer.py │ ├── dummy_cond_stage.py │ └── vqgan.py ├── modules │ ├── __pycache__ │ │ └── util.cpython-38.pyc │ ├── autoencoder │ │ └── lpips │ │ │ └── vgg.pth │ ├── diffusionmodules │ │ └── model.py │ ├── discriminator │ │ ├── __pycache__ │ │ │ └── model.cpython-38.pyc │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── lpips.cpython-38.pyc │ │ │ └── vqperceptual.cpython-38.pyc │ │ ├── lpips.py │ │ ├── segmentation.py │ │ └── vqperceptual.py │ ├── misc │ │ └── coord.py │ ├── transformer │ │ ├── mingpt.py │ │ └── permuter.py │ ├── util.py │ └── vqvae │ │ └── quantize.py └── util.py └── test.py /CtrlColor_environ.yaml: -------------------------------------------------------------------------------- 1 | name: CtrlColor 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 
9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.31.0 14 | - gradio-client==0.2.5 15 | - albumentations==1.3.0 16 | - opencv-python==4.9.0.80 17 | - opencv-python-headless==4.5.5.64 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.5.0 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit==1.12.1 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - addict==2.4.0 31 | - yapf==0.32.0 32 | - prettytable==3.6.0 33 | - basicsr==1.4.2 34 | - salesforce-lavis==1.0.2 35 | - grpcio==1.60 36 | - pydantic==1.10.5 37 | - spacy==3.5.1 38 | - typer==0.7.0 39 | - typing-extensions==4.4.0 40 | - fastapi==0.92.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
2 | 3 |

Control Color: Multimodal Diffusion-based Interactive Image Colorization

4 | 5 |
6 | Zhexin Liang  7 | Zhaochen Li  8 | Shangchen Zhou  9 | Chongyi Li  10 | Chen Change Loy 11 |
12 |
13 | S-Lab, Nanyang Technological University  14 |
45 | Control Color (CtrlColor) achieves highly controllable multimodal image colorization based on the Stable Diffusion model. 46 | 47 |
Region colorization | Iterative editing
60 | 61 | :open_book: For more visual results and applications of CtrlColor, go check out our project page. 62 | 63 | --- 64 | 65 | 88 | ## :mega: Updates 89 | - **2024.12.16**: The test code (gradio demo), the colorization model checkpoint, and the autoencoder checkpoint are now publicly available. 90 | 91 | ## :desktop_computer: Requirements 92 | 93 | - Required packages are listed in `CtrlColor_environ.yaml`. 94 | 95 | ``` 96 | # git clone this repository 97 | git clone https://github.com/ZhexinLiang/Control-Color.git 98 | cd Control_Color 99 | 100 | # create a new anaconda env and install python dependencies 101 | conda env create -f CtrlColor_environ.yaml 102 | conda activate CtrlColor 103 | ``` 104 | 105 | ## :running_woman: Inference 106 | 107 | ### Prepare models: 108 | 109 | Please download the checkpoints of both the colorization model and the VAE from [[Google Drive](https://drive.google.com/drive/folders/1lgqstNwrMCzymowRsbGM-4hk0-7L-eOT?usp=sharing)] and put both checkpoints in the `./pretrained_models` folder. 110 | 111 | ### Testing: 112 | 113 | You can use the following command to run the gradio demo: 114 | 115 | ``` 116 | python test.py 117 | ``` 118 | Then you will get our interactive interface as below (see `assets/hf_demo1.png`). To load the checkpoints programmatically instead of through the gradio demo, see the loading sketch at the end of this README. 119 | 120 | 121 | 122 | ## :love_you_gesture: Citation 123 | If you find our work useful for your research, please consider citing the paper: 124 | ``` 125 | @article{liang2024control, 126 | title={Control Color: Multimodal Diffusion-based Interactive Image Colorization}, 127 | author={Liang, Zhexin and Li, Zhaochen and Zhou, Shangchen and Li, Chongyi and Loy, Chen Change}, 128 | journal={arXiv preprint arXiv:2402.10855}, 129 | year={2024} 130 | } 131 | ``` 132 | 133 | ### Contact 134 | If you have any questions, please feel free to reach out at `zhexinliang@gmail.com`.
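For reference, here is a minimal loading sketch using the helpers in `cldm/model.py` (`create_model` and `load_state_dict`). The config path below ships with this repo; the checkpoint filename is a placeholder, since the exact name of the file downloaded from Google Drive may differ.

```python
# Sketch only: './pretrained_models/colorization.ckpt' is a placeholder filename,
# not the actual name of the released checkpoint.
from cldm.model import create_model, load_state_dict

model = create_model('./models/cldm_v15_inpainting_infer.yaml')        # instantiates the model on CPU from the yaml config
state_dict = load_state_dict('./pretrained_models/colorization.ckpt',  # handles both .ckpt and .safetensors files
                             location='cpu')
model.load_state_dict(state_dict, strict=False)
model = model.cuda().eval()
```

`test.py` presumably performs a similar setup before launching the gradio interface, so a sketch like this is only needed for custom inference scripts.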
135 | 136 | 145 | -------------------------------------------------------------------------------- /annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W)#min(H,W) 33 | H *= k 34 | W *= k 35 | H_new = int(np.round(H / 64.0)) * 64 36 | W_new = int(np.round(W / 64.0)) * 64 37 | H = H_new if H_new<800 else int(np.round(800 / 64.0)) * 64#1024->896 38 | W=W_new if W_new<800 else int(np.round(800 / 64.0)) * 64 39 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 40 | return img 41 | -------------------------------------------------------------------------------- /assets/hf_demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/hf_demo1.png -------------------------------------------------------------------------------- /assets/iterative.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/iterative.gif -------------------------------------------------------------------------------- /assets/region.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/region.gif -------------------------------------------------------------------------------- /assets/region_cond.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/region_cond.gif -------------------------------------------------------------------------------- /assets/teaser_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/teaser_aligned.png -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | 
print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import 
instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | save_memory = False 2 | -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | # from ldm.modules.diffusionmodules.model_window import Encoder, Decoder 7 | from ldm.modules.diffusionmodules.model_brefore_dcn import Encoder, Decoder 8 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 9 | 10 | from ldm.util import instantiate_from_config 11 | from ldm.modules.ema import LitEma 12 | 13 | 14 | class AutoencoderKL(pl.LightningModule): 15 | def __init__(self, 16 | ddconfig, 17 | lossconfig, 18 | embed_dim, 19 | ckpt_path=None, 20 | ignore_keys=[], 21 | image_key="image", 22 | colorize_nlabels=None, 23 | monitor=None, 24 | ema_decay=None, 25 | learn_logvar=False 26 | ): 27 | super().__init__() 28 | self.learn_logvar = learn_logvar 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | self.loss = instantiate_from_config(lossconfig) 33 | assert ddconfig["double_z"] 34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 36 | self.embed_dim = embed_dim 37 | if colorize_nlabels is not None: 38 | assert type(colorize_nlabels)==int 39 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 40 | if monitor is not None: 41 | self.monitor = monitor 42 | 43 | self.use_ema = ema_decay is not None 44 | if self.use_ema: 45 | self.ema_decay = ema_decay 46 | assert 0. < ema_decay < 1. 
47 | self.model_ema = LitEma(self, decay=ema_decay) 48 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 49 | 50 | if ckpt_path is not None: 51 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 52 | 53 | def init_from_ckpt(self, path, ignore_keys=list()): 54 | sd = torch.load(path, map_location="cpu")["state_dict"] 55 | keys = list(sd.keys()) 56 | for k in keys: 57 | for ik in ignore_keys: 58 | if k.startswith(ik): 59 | print("Deleting key {} from state_dict.".format(k)) 60 | del sd[k] 61 | self.load_state_dict(sd, strict=False) 62 | print(f"Restored from {path}") 63 | 64 | @contextmanager 65 | def ema_scope(self, context=None): 66 | if self.use_ema: 67 | self.model_ema.store(self.parameters()) 68 | self.model_ema.copy_to(self) 69 | if context is not None: 70 | print(f"{context}: Switched to EMA weights") 71 | try: 72 | yield None 73 | finally: 74 | if self.use_ema: 75 | self.model_ema.restore(self.parameters()) 76 | if context is not None: 77 | print(f"{context}: Restored training weights") 78 | 79 | def on_train_batch_end(self, *args, **kwargs): 80 | if self.use_ema: 81 | self.model_ema(self) 82 | 83 | def encode(self, x): 84 | h = self.encoder(x) 85 | moments = self.quant_conv(h) 86 | posterior = DiagonalGaussianDistribution(moments) 87 | return posterior 88 | 89 | def decode(self, z): 90 | z = self.post_quant_conv(z) 91 | dec = self.decoder(z) 92 | return dec 93 | 94 | def forward(self, input, sample_posterior=True): 95 | posterior = self.encode(input) 96 | if sample_posterior: 97 | z = posterior.sample() 98 | else: 99 | z = posterior.mode() 100 | dec = self.decode(z) 101 | return dec, posterior 102 | 103 | def get_input(self, batch, k): 104 | x = batch[k] 105 | if len(x.shape) == 3: 106 | x = x[..., None] 107 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 108 | return x 109 | 110 | def training_step(self, batch, batch_idx, optimizer_idx): 111 | inputs = self.get_input(batch, self.image_key) 112 | reconstructions, posterior = self(inputs) 113 | 114 | if optimizer_idx == 0: 115 | # train encoder+decoder+logvar 116 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 117 | last_layer=self.get_last_layer(), split="train") 118 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 119 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 120 | return aeloss 121 | 122 | if optimizer_idx == 1: 123 | # train the discriminator 124 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 125 | last_layer=self.get_last_layer(), split="train") 126 | 127 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 128 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 129 | return discloss 130 | 131 | def validation_step(self, batch, batch_idx): 132 | log_dict = self._validation_step(batch, batch_idx) 133 | with self.ema_scope(): 134 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 135 | return log_dict 136 | 137 | def _validation_step(self, batch, batch_idx, postfix=""): 138 | inputs = self.get_input(batch, self.image_key) 139 | reconstructions, posterior = self(inputs) 140 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 141 | last_layer=self.get_last_layer(), split="val"+postfix) 142 | 143 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 144 | last_layer=self.get_last_layer(), split="val"+postfix) 145 | 146 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 147 | self.log_dict(log_dict_ae) 148 | self.log_dict(log_dict_disc) 149 | return self.log_dict 150 | 151 | def configure_optimizers(self): 152 | lr = self.learning_rate 153 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 154 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 155 | if self.learn_logvar: 156 | print(f"{self.__class__.__name__}: Learning logvar") 157 | ae_params_list.append(self.loss.logvar) 158 | opt_ae = torch.optim.Adam(ae_params_list, 159 | lr=lr, betas=(0.5, 0.9)) 160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 161 | lr=lr, betas=(0.5, 0.9)) 162 | return [opt_ae, opt_disc], [] 163 | 164 | def get_last_layer(self): 165 | return self.decoder.conv_out.weight 166 | 167 | @torch.no_grad() 168 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 169 | log = dict() 170 | x = self.get_input(batch, self.image_key) 171 | x = x.to(self.device) 172 | if not only_inputs: 173 | xrec, posterior = self(x) 174 | if x.shape[1] > 3: 175 | # colorize with random projection 176 | assert xrec.shape[1] > 3 177 | x = self.to_rgb(x) 178 | xrec = self.to_rgb(xrec) 179 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 180 | log["reconstructions"] = xrec 181 | if log_ema or self.use_ema: 182 | with self.ema_scope(): 183 | xrec_ema, posterior_ema = self(x) 184 | if x.shape[1] > 3: 185 | # colorize with random projection 186 | assert xrec_ema.shape[1] > 3 187 | xrec_ema = self.to_rgb(xrec_ema) 188 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 189 | log["reconstructions_ema"] = xrec_ema 190 | log["inputs"] = x 191 | return log 192 | 193 | def to_rgb(self, x): 194 | assert self.image_key == "segmentation" 195 | if not hasattr(self, "colorize"): 196 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 197 | x = F.conv2d(x, weight=self.colorize) 198 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
199 | return x 200 | 201 | 202 | class IdentityFirstStage(torch.nn.Module): 203 | def __init__(self, *args, vq_interface=False, **kwargs): 204 | self.vq_interface = vq_interface 205 | super().__init__() 206 | 207 | def encode(self, x, *args, **kwargs): 208 | return x 209 | 210 | def decode(self, x, *args, **kwargs): 211 | return x 212 | 213 | def quantize(self, x, *args, **kwargs): 214 | if self.vq_interface: 215 | return x, None, [None, None, None] 216 | return x 217 | 218 | def forward(self, x, *args, **kwargs): 219 | return x 220 | 221 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 
| def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/models/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | # import pdb 11 | 12 | class ImageLogger(Callback): 13 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 14 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 15 | log_images_kwargs=None,ckpt_dir="./ckpt"): 16 | super().__init__() 17 | self.rescale = rescale 18 | self.batch_freq = batch_frequency 19 | self.max_images = max_images 20 | if not increase_log_steps: 21 | self.log_steps = [self.batch_freq] 22 | self.clamp = clamp 23 | self.disabled = disabled 24 | self.log_on_batch_idx = log_on_batch_idx 25 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 26 | self.log_first_step = log_first_step 27 | self.ckpt_dir=ckpt_dir 28 | self.global_save_num=-2000 29 | self.global_save_num1=-100 30 | 31 | @rank_zero_only 32 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 33 | root = os.path.join(save_dir, "image_log", split) 34 | # print(images) 35 | for k in images: 36 | grid = torchvision.utils.make_grid(images[k], nrow=4) 37 | if self.rescale: 38 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 39 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 40 | grid = grid.numpy() 41 | grid = (grid * 255).astype(np.uint8) 42 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 43 | path = os.path.join(root, filename) 44 | os.makedirs(os.path.split(path)[0], exist_ok=True) 45 | Image.fromarray(grid).save(path) 46 | 47 | def log_img(self, pl_module, batch, batch_idx, split="train"): 48 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 49 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 50 | hasattr(pl_module, "log_images") and 51 | callable(pl_module.log_images) and 52 | self.max_images > 0): 53 | logger = type(pl_module.logger) 54 | 55 | is_train = pl_module.training 56 | if is_train: 57 | pl_module.eval() 58 | 59 | with torch.no_grad(): 60 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 61 | 62 | for k in images: 63 | N = min(images[k].shape[0], self.max_images) 64 | images[k] = images[k][:N] 65 | if isinstance(images[k], torch.Tensor): 66 | images[k] = images[k].detach().cpu() 67 | if self.clamp: 68 | images[k] = torch.clamp(images[k], -1., 1.) 
69 | 70 | self.log_local(pl_module.logger.save_dir, split, images, 71 | pl_module.global_step, pl_module.current_epoch, batch_idx) 72 | 73 | if is_train: 74 | pl_module.train() 75 | 76 | def check_frequency(self, check_idx): 77 | return check_idx % self.batch_freq == 0 78 | 79 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 80 | #if not self.disabled: 81 | #if pl_module.global_step%50 == 0: 82 | # if pl_module.current_epoch-self.global_save_num1 > 0: 83 | # print(batch_idx) 84 | if batch_idx % 500 == 0: 85 | # print("inside") 86 | # pdb.set_trace() 87 | # self.global_save_num1=pl_module.current_epoch 88 | self.log_img(pl_module, batch, batch_idx, split="train_"+"ckpt_inpainting_from5625_2+3750_exemplar_only_vae") 89 | #if pl_module.global_step%1200 == 0 and self.check_frequency(batch_idx): 90 | if batch_idx % 1000 == 0: 91 | # if pl_module.current_epoch-self.global_save_num>10 and self.check_frequency(batch_idx): 92 | # self.global_save_num=pl_module.current_epoch 93 | trainer.save_checkpoint(self.ckpt_dir+"/epoch"+str(pl_module.current_epoch)+"_global-step"+str(pl_module.global_step)+".ckpt") 94 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc 
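The `ImageLogger` callback defined in `ldm/models/logger.py` above is meant to be attached to a PyTorch Lightning trainer: it periodically writes image grids under `image_log/` in the logger's save directory and saves checkpoints into `ckpt_dir`. The training entry point itself is not part of this repository, so the trainer arguments and the `model`/`train_dataloader` objects below are placeholders in a minimal sketch only:

```python
# Hypothetical wiring sketch; this repository does not ship a training script.
import pytorch_lightning as pl
from ldm.models.logger import ImageLogger

image_logger = ImageLogger(batch_frequency=2000, max_images=4, ckpt_dir="./ckpt")

trainer = pl.Trainer(
    gpus=1,                    # single-GPU example
    precision=32,
    callbacks=[image_logger],  # logs image grids and saves .ckpt files during training
)
# trainer.fit(model, train_dataloader)  # `model` (a LightningModule) and `train_dataloader` defined elsewhere
```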
-------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def sample_addhint(self, generator): 40 | latents = torch.randn(self.mean.shape, generator=generator, device='cpu', dtype=self.parameters.dtype).cuda() 41 | x = self.mean + self.std * latents 42 | return x 43 | 44 | def kl(self, other=None): 45 | if self.deterministic: 46 | return torch.Tensor([0.]) 47 | else: 48 | if other is None: 49 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 50 | + self.var - 1.0 - self.logvar, 51 | dim=[1, 2, 3]) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 56 | dim=[1, 2, 3]) 57 | 58 | def nll(self, sample, dims=[1,2,3]): 59 | if self.deterministic: 60 | return torch.Tensor([0.]) 61 | logtwopi = np.log(2.0 * np.pi) 62 | return 0.5 * torch.sum( 63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 64 | dim=dims) 65 | 66 | def mode(self): 67 | return self.mean 68 | 69 | 70 | def normal_kl(mean1, logvar1, mean2, logvar2): 71 | """ 72 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 73 | Compute 
the KL divergence between two gaussians. 74 | Shapes are automatically broadcasted, so batches can be compared to 75 | scalars, among other use cases. 76 | """ 77 | tensor = None 78 | for obj in (mean1, logvar1, mean2, logvar2): 79 | if isinstance(obj, torch.Tensor): 80 | tensor = obj 81 | break 82 | assert tensor is not None, "at least one argument must be a Tensor" 83 | 84 | # Force variances to be Tensors. Broadcasting helps convert scalars to 85 | # Tensors, but it does not work for torch.exp(). 86 | logvar1, logvar2 = [ 87 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 88 | for x in (logvar1, logvar2) 89 | ] 90 | 91 | return 0.5 * ( 92 | -1.0 93 | + logvar2 94 | - logvar1 95 | + torch.exp(logvar1 - logvar2) 96 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 97 | ) 98 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. 
After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator 2 | from ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | #https://github.com/IceClear/StableSR/blob/main/ldm/modules/losses/contperceptual.py 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 10 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 11 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 12 | disc_loss="hinge"): 13 | 14 | super().__init__() 15 | assert disc_loss in ["hinge", "vanilla"] 16 | self.kl_weight = kl_weight 17 | self.pixel_weight = pixelloss_weight 18 | self.perceptual_loss = LPIPS().eval() 19 | self.perceptual_weight = perceptual_weight 20 | # output log variance 21 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 22 | 23 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 24 | n_layers=disc_num_layers, 25 | use_actnorm=use_actnorm 26 | ).apply(weights_init) 27 | self.discriminator_iter_start = disc_start 28 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 29 | self.disc_factor = disc_factor 30 | self.discriminator_weight = disc_weight 31 | self.disc_conditional = disc_conditional 32 | 33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 34 | if last_layer is not None: 35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 37 | else: 38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 40 | 41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 43 | d_weight = d_weight * self.discriminator_weight 44 | return d_weight 45 | 46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 47 | global_step, last_layer=None, cond=None, split="train", 48 | weights=None, return_dic=False): 49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 50 | if self.perceptual_weight > 0: 51 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 52 | rec_loss = rec_loss + self.perceptual_weight * p_loss 53 | 54 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 55 | weighted_nll_loss = nll_loss 56 | if weights is not None: 57 | weighted_nll_loss = weights*nll_loss 58 | weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0] 59 | nll_loss = torch.mean(nll_loss) / nll_loss.shape[0] 60 | if self.kl_weight>0: 61 | kl_loss = posteriors.kl() 62 | kl_loss = torch.mean(kl_loss) / kl_loss.shape[0] 63 | 64 | # now the GAN part 65 | if optimizer_idx == 0: 66 | # generator update 67 | if cond is None: 68 | assert not self.disc_conditional 69 | logits_fake = self.discriminator(reconstructions.contiguous()) 70 | else: 71 | assert self.disc_conditional 72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 73 | g_loss = 
-torch.mean(logits_fake) 74 | 75 | if self.disc_factor > 0.0: 76 | try: 77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 78 | except RuntimeError: 79 | # assert not self.training 80 | d_weight = torch.tensor(1.0) * self.discriminator_weight 81 | else: 82 | # d_weight = torch.tensor(0.0) 83 | d_weight = torch.tensor(0.0) 84 | 85 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 86 | if self.kl_weight>0: 87 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 88 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 89 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 90 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 91 | "{}/d_weight".format(split): d_weight.detach(), 92 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 93 | "{}/g_loss".format(split): g_loss.detach().mean(), 94 | } 95 | if return_dic: 96 | loss_dic = {} 97 | loss_dic['total_loss'] = loss.clone().detach().mean() 98 | loss_dic['logvar'] = self.logvar.detach() 99 | loss_dic['kl_loss'] = kl_loss.detach().mean() 100 | loss_dic['nll_loss'] = nll_loss.detach().mean() 101 | loss_dic['rec_loss'] = rec_loss.detach().mean() 102 | loss_dic['d_weight'] = d_weight.detach() 103 | loss_dic['disc_factor'] = torch.tensor(disc_factor) 104 | loss_dic['g_loss'] = g_loss.detach().mean() 105 | else: 106 | loss = weighted_nll_loss + d_weight * disc_factor * g_loss 107 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 108 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 109 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 110 | "{}/d_weight".format(split): d_weight.detach(), 111 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 112 | "{}/g_loss".format(split): g_loss.detach().mean(), 113 | } 114 | if return_dic: 115 | loss_dic = {} 116 | loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean() 117 | loss_dic["{}/logvar".format(split)] = self.logvar.detach() 118 | loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean() 119 | loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean() 120 | loss_dic['d_weight'.format(split)] = d_weight.detach() 121 | loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor) 122 | loss_dic['g_loss'.format(split)] = g_loss.detach().mean() 123 | 124 | if return_dic: 125 | return loss, log, loss_dic 126 | return loss, log 127 | 128 | if optimizer_idx == 1: 129 | # second pass for discriminator update 130 | if cond is None: 131 | logits_real = self.discriminator(inputs.contiguous().detach()) 132 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 133 | else: 134 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 135 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 136 | 137 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 138 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 139 | 140 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 141 | "{}/logits_real".format(split): logits_real.detach().mean(), 142 | "{}/logits_fake".format(split): logits_fake.detach().mean() 143 | } 144 | 145 | if return_dic: 146 | loss_dic = {} 147 | 
loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean() 148 | loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean() 149 | loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean() 150 | return d_loss, log, loss_dic 151 | 152 | return d_loss, log -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = 
self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": 
"midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return 
model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | 
self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
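    Illustrative usage (added sketch, not from the original code; it mirrors how
    ldm/modules/midas/api.py builds the MiDaS transform and assumes `img` is an
    HxWx3 float image):

        resize = Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
                        ensure_multiple_of=32, resize_method="minimal",
                        image_interpolation_method=cv2.INTER_CUBIC)
        sample = resize({"image": img})  # both output sides become multiples of 32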
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torchvision import transforms 10 | import cv2 11 | 12 | # def get_hint_image(image_withmask): 13 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. 14 | # image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0] 15 | # image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1) 16 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. 17 | # H,W=mask.shape 18 | # for i in range(H): 19 | # for j in range(W): 20 | # if mask[i,j]==0: 21 | # image[:,i,j]=image_gray[:,i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j] 22 | # return image 23 | 24 | def get_hint_image(image,image_gray,mask): 25 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. 26 | # image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0] 27 | # image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1) 28 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. 29 | image=np.array(image.copy()) 30 | image_gray=np.array(image_gray.copy()) 31 | H,W=mask.shape 32 | for i in range(H): 33 | for j in range(W): 34 | if mask[i,j]==0: 35 | image[i,j]=image_gray[i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j] 36 | return Image.fromarray(image) 37 | 38 | def log_txt_as_img(wh,masked_image, xc, size=10): 39 | # wh a tuple of (width, height) 40 | # xc a list of captions to plot 41 | xc=xc 42 | b = len(xc) 43 | txts = list() 44 | for bi in range(b): 45 | txt = Image.new("RGB", wh, color="white") 46 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. 47 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. 48 | # image=(image_withmask+1.)/2. 49 | # # image = get_hint_image(image_withmask) 50 | # # print(image.shape) 51 | # image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB") 52 | # # image_gray=transforms.ToPILImage()(image).convert("L") 53 | image=(masked_image.squeeze(0)+1.)/2. 
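        # (Added comment, not in the original file) the logged tensor is in model
        # space, i.e. roughly [-1, 1] as in the commented-out variants above, so
        # (x + 1) / 2 maps it back to [0, 1] before the ToPILImage conversion below.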
54 | image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB") 55 | txt = image_target#get_hint_image(image_target,image_gray,mask) 56 | draw = ImageDraw.Draw(txt) 57 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 58 | nc = int(40 * (wh[0] / 256)) 59 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 60 | 61 | try: 62 | draw.text((0, 0), lines, fill="black", font=font) 63 | except UnicodeEncodeError: 64 | print("Cant encode string for logging. Skipping.") 65 | 66 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 67 | txts.append(txt) 68 | txts = np.stack(txts) 69 | txts = torch.tensor(txts) 70 | return txts 71 | 72 | 73 | def ismap(x): 74 | if not isinstance(x, torch.Tensor): 75 | return False 76 | return (len(x.shape) == 4) and (x.shape[1] > 3) 77 | 78 | 79 | def isimage(x): 80 | if not isinstance(x,torch.Tensor): 81 | return False 82 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 83 | 84 | 85 | def exists(x): 86 | return x is not None 87 | 88 | 89 | def default(val, d): 90 | if exists(val): 91 | return val 92 | return d() if isfunction(d) else d 93 | 94 | 95 | def mean_flat(tensor): 96 | """ 97 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 98 | Take the mean over all non-batch dimensions. 99 | """ 100 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 101 | 102 | 103 | def count_params(model, verbose=False): 104 | total_params = sum(p.numel() for p in model.parameters()) 105 | if verbose: 106 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 107 | return total_params 108 | 109 | 110 | def instantiate_from_config(config): 111 | if not "target" in config: 112 | if not config == '__is_first_stage__':#changed for only training vae 113 | return None 114 | # elif config == "__is_unconditional__":#changed for only training vae 115 | # return None 116 | raise KeyError("Expected key `target` to instantiate.") 117 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 118 | 119 | 120 | def get_obj_from_str(string, reload=False): 121 | module, cls = string.rsplit(".", 1) 122 | if reload: 123 | module_imp = importlib.import_module(module) 124 | importlib.reload(module_imp) 125 | return getattr(importlib.import_module(module, package=None), cls) 126 | 127 | 128 | class AdamWwithEMAandWings(optim.Optimizer): 129 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 130 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 131 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 132 | ema_power=1., param_names=()): 133 | """AdamW that saves EMA versions of the parameters.""" 134 | if not 0.0 <= lr: 135 | raise ValueError("Invalid learning rate: {}".format(lr)) 136 | if not 0.0 <= eps: 137 | raise ValueError("Invalid epsilon value: {}".format(eps)) 138 | if not 0.0 <= betas[0] < 1.0: 139 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 140 | if not 0.0 <= betas[1] < 1.0: 141 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 142 | if not 0.0 <= weight_decay: 143 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 144 | if not 0.0 <= ema_decay <= 1.0: 145 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 146 | defaults = dict(lr=lr, betas=betas, eps=eps, 147 | 
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 148 | ema_power=ema_power, param_names=param_names) 149 | super().__init__(params, defaults) 150 | 151 | def __setstate__(self, state): 152 | super().__setstate__(state) 153 | for group in self.param_groups: 154 | group.setdefault('amsgrad', False) 155 | 156 | @torch.no_grad() 157 | def step(self, closure=None): 158 | """Performs a single optimization step. 159 | Args: 160 | closure (callable, optional): A closure that reevaluates the model 161 | and returns the loss. 162 | """ 163 | loss = None 164 | if closure is not None: 165 | with torch.enable_grad(): 166 | loss = closure() 167 | 168 | for group in self.param_groups: 169 | params_with_grad = [] 170 | grads = [] 171 | exp_avgs = [] 172 | exp_avg_sqs = [] 173 | ema_params_with_grad = [] 174 | state_sums = [] 175 | max_exp_avg_sqs = [] 176 | state_steps = [] 177 | amsgrad = group['amsgrad'] 178 | beta1, beta2 = group['betas'] 179 | ema_decay = group['ema_decay'] 180 | ema_power = group['ema_power'] 181 | 182 | for p in group['params']: 183 | if p.grad is None: 184 | continue 185 | params_with_grad.append(p) 186 | if p.grad.is_sparse: 187 | raise RuntimeError('AdamW does not support sparse gradients') 188 | grads.append(p.grad) 189 | 190 | state = self.state[p] 191 | 192 | # State initialization 193 | if len(state) == 0: 194 | state['step'] = 0 195 | # Exponential moving average of gradient values 196 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 197 | # Exponential moving average of squared gradient values 198 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 199 | if amsgrad: 200 | # Maintains max of all exp. moving avg. of sq. grad. values 201 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 202 | # Exponential moving average of parameter values 203 | state['param_exp_avg'] = p.detach().float().clone() 204 | 205 | exp_avgs.append(state['exp_avg']) 206 | exp_avg_sqs.append(state['exp_avg_sq']) 207 | ema_params_with_grad.append(state['param_exp_avg']) 208 | 209 | if amsgrad: 210 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 211 | 212 | # update the steps for each param group update 213 | state['step'] += 1 214 | # record the step after step update 215 | state_steps.append(state['step']) 216 | 217 | optim._functional.adamw(params_with_grad, 218 | grads, 219 | exp_avgs, 220 | exp_avg_sqs, 221 | max_exp_avg_sqs, 222 | state_steps, 223 | amsgrad=amsgrad, 224 | beta1=beta1, 225 | beta2=beta2, 226 | lr=group['lr'], 227 | weight_decay=group['weight_decay'], 228 | eps=group['eps'], 229 | maximize=False) 230 | 231 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 232 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 233 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 234 | 235 | return loss -------------------------------------------------------------------------------- /models/cldm_v15_inpainting_infer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | masked_image: "mask_img" 13 | mask: "mask" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 
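    # Added note (not part of the original config): scale_factor 0.18215 is the
    # standard Stable Diffusion latent scaling constant, and the UNet's
    # in_channels: 9 below matches the usual inpainting layout of 4 noisy-latent
    # + 4 masked-image-latent + 1 mask channels.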
19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | load_loss: False 23 | 24 | control_stage_config: 25 | target: cldm.cldm.ControlNet 26 | params: 27 | image_size: 32 # unused 28 | in_channels: 4 29 | hint_channels: 3 30 | model_channels: 320 31 | attention_resolutions: [ 4, 2, 1 ] 32 | num_res_blocks: 2 33 | channel_mult: [ 1, 2, 4, 4 ] 34 | num_heads: 8 35 | use_spatial_transformer: True 36 | transformer_depth: 1 37 | context_dim: 768 38 | use_checkpoint: True 39 | legacy: False 40 | 41 | unet_config: 42 | target: cldm.cldm.ControlledUnetModel 43 | params: 44 | image_size: 32 # unused 45 | in_channels: 9 46 | out_channels: 4 47 | model_channels: 320 48 | attention_resolutions: [ 4, 2, 1 ] 49 | num_res_blocks: 2 50 | channel_mult: [ 1, 2, 4, 4 ] 51 | num_heads: 8 52 | use_spatial_transformer: True 53 | transformer_depth: 1 54 | context_dim: 768 55 | use_checkpoint: True 56 | legacy: False 57 | 58 | first_stage_config: 59 | target: ldm.models.autoencoder.AutoencoderKL 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | double_z: true 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 | num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | contextual_stage_config: 82 | target: models_deep_exp.NonlocalNet.VGG19_pytorch 83 | 84 | cond_stage_config: 85 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 86 | target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder 87 | #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder 88 | -------------------------------------------------------------------------------- /models/cldm_v15_inpainting_infer1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | masked_image: "mask_img" 13 | mask: "mask" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | load_loss: False 23 | 24 | control_stage_config: 25 | target: cldm.cldm.ControlNet 26 | params: 27 | image_size: 32 # unused 28 | in_channels: 4 29 | hint_channels: 3 30 | model_channels: 320 31 | attention_resolutions: [ 4, 2, 1 ] 32 | num_res_blocks: 2 33 | channel_mult: [ 1, 2, 4, 4 ] 34 | num_heads: 8 35 | use_spatial_transformer: True 36 | transformer_depth: 1 37 | context_dim: 768 38 | use_checkpoint: True 39 | legacy: False 40 | 41 | unet_config: 42 | target: cldm.cldm.ControlledUnetModel 43 | params: 44 | image_size: 32 # unused 45 | in_channels: 9 46 | out_channels: 4 47 | model_channels: 320 48 | attention_resolutions: [ 4, 2, 1 ] 49 | num_res_blocks: 2 50 | channel_mult: [ 1, 2, 4, 4 ] 51 | num_heads: 8 52 | use_spatial_transformer: True 53 | transformer_depth: 1 54 | context_dim: 768 55 | use_checkpoint: True 56 | legacy: False 57 | 58 | first_stage_config: 59 | target: ldm.models.autoencoder.AutoencoderKL 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | double_z: true 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 
| num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | contextual_stage_config: 82 | target: models_deep_exp.NonlocalNet.VGG19_pytorch 83 | 84 | cond_stage_config: 85 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 86 | # target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder 87 | #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder 88 | -------------------------------------------------------------------------------- /pretrained_models/put_ckpts_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/pretrained_models/put_ckpts_here.txt -------------------------------------------------------------------------------- /share.py: -------------------------------------------------------------------------------- 1 | import config 2 | from cldm.hack import disable_verbosity, enable_sliced_attention 3 | 4 | 5 | disable_verbosity() 6 | 7 | if config.save_memory: 8 | enable_sliced_attention() 9 | -------------------------------------------------------------------------------- /taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": 
cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, Dict, List, Callable, Any 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def 
load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 
97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /taming/data/annotated_objects_open_images.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from csv import DictReader, reader as TupleReader 3 | from pathlib import Path 4 | from typing import Dict, List, Any 5 | import warnings 6 | 7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 8 | from taming.data.helper_types import Annotation, Category 9 | from tqdm import tqdm 10 | 11 | OPEN_IMAGES_STRUCTURE = { 12 | 'train': { 13 | 'top_level': '', 14 | 'class_descriptions': 'class-descriptions-boxable.csv', 15 | 'annotations': 'oidv6-train-annotations-bbox.csv', 16 | 'file_list': 'train-images-boxable.csv', 17 | 'files': 'train' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'class_descriptions': 'class-descriptions-boxable.csv', 22 | 'annotations': 'validation-annotations-bbox.csv', 23 | 'file_list': 'validation-images.csv', 24 | 'files': 'validation' 25 | }, 26 | 'test': { 27 | 'top_level': '', 28 | 'class_descriptions': 'class-descriptions-boxable.csv', 29 | 'annotations': 'test-annotations-bbox.csv', 30 | 'file_list': 'test-images.csv', 31 | 'files': 'test' 32 | } 33 | } 34 | 35 | 36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], 37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: 38 | annotations: Dict[str, List[Annotation]] = defaultdict(list) 39 | with open(descriptor_path) as file: 40 | reader = DictReader(file) 41 | for i, 
row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): 42 | width = float(row['XMax']) - float(row['XMin']) 43 | height = float(row['YMax']) - float(row['YMin']) 44 | area = width * height 45 | category_id = row['LabelName'] 46 | if category_id in category_mapping: 47 | category_id = category_mapping[category_id] 48 | if area >= min_object_area and category_id in category_no_for_id: 49 | annotations[row['ImageID']].append( 50 | Annotation( 51 | id=i, 52 | image_id=row['ImageID'], 53 | source=row['Source'], 54 | category_id=category_id, 55 | category_no=category_no_for_id[category_id], 56 | confidence=float(row['Confidence']), 57 | bbox=(float(row['XMin']), float(row['YMin']), width, height), 58 | area=area, 59 | is_occluded=bool(int(row['IsOccluded'])), 60 | is_truncated=bool(int(row['IsTruncated'])), 61 | is_group_of=bool(int(row['IsGroupOf'])), 62 | is_depiction=bool(int(row['IsDepiction'])), 63 | is_inside=bool(int(row['IsInside'])) 64 | ) 65 | ) 66 | if 'train' in str(descriptor_path) and i < 14000000: 67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') 68 | return dict(annotations) 69 | 70 | 71 | def load_image_ids(csv_path: Path) -> List[str]: 72 | with open(csv_path) as file: 73 | reader = DictReader(file) 74 | return [row['image_name'] for row in reader] 75 | 76 | 77 | def load_categories(csv_path: Path) -> Dict[str, Category]: 78 | with open(csv_path) as file: 79 | reader = TupleReader(file) 80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} 81 | 82 | 83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): 84 | def __init__(self, use_additional_parameters: bool, **kwargs): 85 | """ 86 | @param data_path: is the path to the following folder structure: 87 | open_images/ 88 | │ oidv6-train-annotations-bbox.csv 89 | ├── class-descriptions-boxable.csv 90 | ├── oidv6-train-annotations-bbox.csv 91 | ├── test 92 | │ ├── 000026e7ee790996.jpg 93 | │ ├── 000062a39995e348.jpg 94 | │ └── ... 95 | ├── test-annotations-bbox.csv 96 | ├── test-images.csv 97 | ├── train 98 | │ ├── 000002b66c9c498e.jpg 99 | │ ├── 000002b97e5471a0.jpg 100 | │ └── ... 101 | ├── train-images-boxable.csv 102 | ├── validation 103 | │ ├── 0001eeaf4aed83f9.jpg 104 | │ ├── 0004886b7d043cfd.jpg 105 | │ └── ... 
106 | ├── validation-annotations-bbox.csv 107 | └── validation-images.csv 108 | @param: split: one of 'train', 'validation' or 'test' 109 | @param: desired image size (returns square images) 110 | """ 111 | 112 | super().__init__(**kwargs) 113 | self.use_additional_parameters = use_additional_parameters 114 | 115 | self.categories = load_categories(self.paths['class_descriptions']) 116 | self.filter_categories() 117 | self.setup_category_id_and_number() 118 | 119 | self.image_descriptions = {} 120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, 121 | self.category_number) 122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, 123 | self.max_objects_per_image) 124 | self.image_ids = list(self.annotations.keys()) 125 | self.clean_up_annotations_and_image_descriptions() 126 | 127 | def get_path_structure(self) -> Dict[str, str]: 128 | if self.split not in OPEN_IMAGES_STRUCTURE: 129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') 130 | return OPEN_IMAGES_STRUCTURE[self.split] 131 | 132 | def get_image_path(self, image_id: str) -> Path: 133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') 134 | 135 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 136 | image_path = self.get_image_path(image_id) 137 | return {'file_path': str(image_path), 'file_name': image_path.name} 138 | -------------------------------------------------------------------------------- /taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 
| example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /taming/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | 9 | from taming.data.sflckr import SegmentationBase # for examples included in repo 10 | 11 | 12 | class Examples(SegmentationBase): 13 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 14 | super().__init__(data_csv="data/coco_examples.txt", 15 | data_root="data/coco_images", 16 | segmentation_root="data/coco_segmentations", 17 | size=size, random_crop=random_crop, 18 | interpolation=interpolation, 19 | n_labels=183, shift_segmentation=True) 20 | 21 | 22 | class CocoBase(Dataset): 23 | """needed for (image, caption, segmentation) pairs""" 24 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 25 | crop_size=None, force_no_crop=False, given_files=None): 26 | self.split = self.get_split() 27 | self.size = size 28 | if crop_size is None: 29 | self.crop_size = size 30 | else: 31 | self.crop_size = crop_size 32 | 33 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 34 | self.stuffthing = use_stuffthing # include thing in segmentation 35 | if self.onehot and not self.stuffthing: 36 | raise NotImplementedError("One hot mode is only supported for the " 37 | "stuffthings version because labels are stored " 38 | "a bit differently.") 39 | 40 | data_json = datajson 41 | with open(data_json) as json_file: 42 | self.json_data = json.load(json_file) 43 | self.img_id_to_captions = dict() 44 | self.img_id_to_filepath = dict() 45 | self.img_id_to_segmentation_filepath = dict() 46 | 47 | assert data_json.split("/")[-1] in ["captions_train2017.json", 48 | "captions_val2017.json"] 49 | if self.stuffthing: 50 | self.segmentation_prefix = ( 51 | "data/cocostuffthings/val2017" if 52 | data_json.endswith("captions_val2017.json") else 53 | "data/cocostuffthings/train2017") 54 | else: 55 | self.segmentation_prefix = ( 56 | "data/coco/annotations/stuff_val2017_pixelmaps" if 57 | data_json.endswith("captions_val2017.json") else 58 | "data/coco/annotations/stuff_train2017_pixelmaps") 59 | 60 | imagedirs = self.json_data["images"] 61 | self.labels = {"image_ids": list()} 62 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 63 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 64 | self.img_id_to_captions[imgdir["id"]] = list() 65 | pngfilename = imgdir["file_name"].replace("jpg", "png") 66 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 67 | self.segmentation_prefix, pngfilename) 68 | if given_files is not None: 69 | if pngfilename in given_files: 70 | self.labels["image_ids"].append(imgdir["id"]) 71 | else: 72 | 
self.labels["image_ids"].append(imgdir["id"]) 73 | 74 | capdirs = self.json_data["annotations"] 75 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 76 | # there are in average 5 captions per image 77 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 78 | 79 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 80 | if self.split=="validation": 81 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 82 | else: 83 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 84 | self.preprocessor = albumentations.Compose( 85 | [self.rescaler, self.cropper], 86 | additional_targets={"segmentation": "image"}) 87 | if force_no_crop: 88 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 89 | self.preprocessor = albumentations.Compose( 90 | [self.rescaler], 91 | additional_targets={"segmentation": "image"}) 92 | 93 | def __len__(self): 94 | return len(self.labels["image_ids"]) 95 | 96 | def preprocess_image(self, image_path, segmentation_path): 97 | image = Image.open(image_path) 98 | if not image.mode == "RGB": 99 | image = image.convert("RGB") 100 | image = np.array(image).astype(np.uint8) 101 | 102 | segmentation = Image.open(segmentation_path) 103 | if not self.onehot and not segmentation.mode == "RGB": 104 | segmentation = segmentation.convert("RGB") 105 | segmentation = np.array(segmentation).astype(np.uint8) 106 | if self.onehot: 107 | assert self.stuffthing 108 | # stored in caffe format: unlabeled==255. stuff and thing from 109 | # 0-181. to be compatible with the labels in 110 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 111 | # we shift stuffthing one to the right and put unlabeled in zero 112 | # as long as segmentation is uint8 shifting to right handles the 113 | # latter too 114 | assert segmentation.dtype == np.uint8 115 | segmentation = segmentation + 1 116 | 117 | processed = self.preprocessor(image=image, segmentation=segmentation) 118 | image, segmentation = processed["image"], processed["segmentation"] 119 | image = (image / 127.5 - 1.0).astype(np.float32) 120 | 121 | if self.onehot: 122 | assert segmentation.dtype == np.uint8 123 | # make it one hot 124 | n_labels = 183 125 | flatseg = np.ravel(segmentation) 126 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 127 | onehot[np.arange(flatseg.size), flatseg] = True 128 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 129 | segmentation = onehot 130 | else: 131 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 132 | return image, segmentation 133 | 134 | def __getitem__(self, i): 135 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 136 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 137 | image, segmentation = self.preprocess_image(img_path, seg_path) 138 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 139 | # randomly draw one of all available captions per image 140 | caption = captions[np.random.randint(0, len(captions))] 141 | example = {"image": image, 142 | "caption": [str(caption[0])], 143 | "segmentation": segmentation, 144 | "img_path": img_path, 145 | "seg_path": seg_path, 146 | "filename_": img_path.split(os.sep)[-1] 147 | } 148 | return example 149 | 150 | 151 | class CocoImagesAndCaptionsTrain(CocoBase): 152 | """returns a pair of (image, caption)""" 153 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, 
force_no_crop=False): 154 | super().__init__(size=size, 155 | dataroot="data/coco/train2017", 156 | datajson="data/coco/annotations/captions_train2017.json", 157 | onehot_segmentation=onehot_segmentation, 158 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 159 | 160 | def get_split(self): 161 | return "train" 162 | 163 | 164 | class CocoImagesAndCaptionsValidation(CocoBase): 165 | """returns a pair of (image, caption)""" 166 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 167 | given_files=None): 168 | super().__init__(size=size, 169 | dataroot="data/coco/val2017", 170 | datajson="data/coco/annotations/captions_val2017.json", 171 | onehot_segmentation=onehot_segmentation, 172 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 173 | given_files=given_files) 174 | 175 | def get_split(self): 176 | return "validation" 177 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = 
plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/objects_center_points.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import warnings 4 | from itertools import cycle 5 | from typing import List, Optional, Tuple, Callable 6 | 7 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 8 | from more_itertools.recipes import grouper 9 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ 10 | additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ 11 | absolute_bbox, rescale_annotations 12 | from taming.data.helper_types import BoundingBox, Annotation 13 | from taming.data.image_transforms import convert_pil_to_tensor 14 | from torch import LongTensor, Tensor 15 | 16 | 17 | class ObjectsCenterPointsConditionalBuilder: 18 | def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, 19 | use_group_parameter: bool, use_additional_parameters: bool): 20 | self.no_object_classes = no_object_classes 21 | self.no_max_objects = no_max_objects 22 | self.no_tokens = no_tokens 23 | self.encode_crop = encode_crop 24 | self.no_sections = int(math.sqrt(self.no_tokens)) 25 | self.use_group_parameter = use_group_parameter 26 | self.use_additional_parameters = use_additional_parameters 27 | 28 | @property 29 | def none(self) -> int: 30 | return self.no_tokens - 1 31 | 32 | @property 33 | def object_descriptor_length(self) -> int: 34 | return 2 35 | 36 | @property 37 | def embedding_dim(self) -> int: 38 | extra_length = 2 if self.encode_crop else 0 39 | return self.no_max_objects * self.object_descriptor_length + extra_length 40 | 41 | def tokenize_coordinates(self, x: float, y: float) -> int: 42 | """ 43 | Express 2d coordinates with one number. 44 | Example: assume self.no_tokens = 16, then no_sections = 4: 45 | 0 0 0 0 46 | 0 0 # 0 47 | 0 0 0 0 48 | 0 0 0 x 49 | Then the # position corresponds to token 6, the x position to token 15. 
50 | @param x: float in [0, 1] 51 | @param y: float in [0, 1] 52 | @return: discrete tokenized coordinate 53 | """ 54 | x_discrete = int(round(x * (self.no_sections - 1))) 55 | y_discrete = int(round(y * (self.no_sections - 1))) 56 | return y_discrete * self.no_sections + x_discrete 57 | 58 | def coordinates_from_token(self, token: int) -> (float, float): 59 | x = token % self.no_sections 60 | y = token // self.no_sections 61 | return x / (self.no_sections - 1), y / (self.no_sections - 1) 62 | 63 | def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: 64 | x0, y0 = self.coordinates_from_token(token1) 65 | x1, y1 = self.coordinates_from_token(token2) 66 | return x0, y0, x1 - x0, y1 - y0 67 | 68 | def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: 69 | return self.tokenize_coordinates(bbox[0], bbox[1]), \ 70 | self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) 71 | 72 | def inverse_build(self, conditional: LongTensor) \ 73 | -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: 74 | conditional_list = conditional.tolist() 75 | crop_coordinates = None 76 | if self.encode_crop: 77 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 78 | conditional_list = conditional_list[:-2] 79 | table_of_content = grouper(conditional_list, self.object_descriptor_length) 80 | assert conditional.shape[0] == self.embedding_dim 81 | return [ 82 | (object_tuple[0], self.coordinates_from_token(object_tuple[1])) 83 | for object_tuple in table_of_content if object_tuple[0] != self.none 84 | ], crop_coordinates 85 | 86 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 87 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 88 | plot = pil_image.new('RGB', figure_size, WHITE) 89 | draw = pil_img_draw.Draw(plot) 90 | circle_size = get_circle_size(figure_size) 91 | font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', 92 | size=get_plot_font_size(font_size, figure_size)) 93 | width, height = plot.size 94 | description, crop_coordinates = self.inverse_build(conditional) 95 | for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): 96 | x_abs, y_abs = x * width, y * height 97 | ann = self.representation_to_annotation(representation) 98 | label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) 99 | ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] 100 | draw.ellipse(ellipse_bbox, fill=color, width=0) 101 | draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) 102 | if crop_coordinates is not None: 103 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 104 | return convert_pil_to_tensor(plot) / 127.5 - 1. 
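# Editor's note: a worked example of the coordinate tokenization above (illustrative sketch
# only, not part of the original file; the constructor values below are made up, the method
# signatures are the ones defined in this class). With no_tokens=16 we get no_sections=4,
# i.e. the 4x4 grid from the tokenize_coordinates docstring:
#
#   builder = ObjectsCenterPointsConditionalBuilder(
#       no_object_classes=183, no_max_objects=8, no_tokens=16,
#       encode_crop=False, use_group_parameter=False, use_additional_parameters=False)
#   builder.tokenize_coordinates(2/3, 1/3)              # row 1, col 2 -> 1*4 + 2 = 6 (the '#' cell)
#   builder.tokenize_coordinates(1.0, 1.0)              # row 3, col 3 -> 3*4 + 3 = 15 (the 'x' cell)
#   builder.token_pair_from_bbox((0.0, 0.0, 2/3, 1/3))  # -> (0, 6), top-left and bottom-right corner tokens
#   builder.coordinates_from_token(6)                   # -> (2/3, 1/3), inverting tokenize_coordinates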
105 | 106 | def object_representation(self, annotation: Annotation) -> int: 107 | modifier = 0 108 | if self.use_group_parameter: 109 | modifier |= 1 * (annotation.is_group_of is True) 110 | if self.use_additional_parameters: 111 | modifier |= 2 * (annotation.is_occluded is True) 112 | modifier |= 4 * (annotation.is_depiction is True) 113 | modifier |= 8 * (annotation.is_inside is True) 114 | return annotation.category_no + self.no_object_classes * modifier 115 | 116 | def representation_to_annotation(self, representation: int) -> Annotation: 117 | category_no = representation % self.no_object_classes 118 | modifier = representation // self.no_object_classes 119 | # noinspection PyTypeChecker 120 | return Annotation( 121 | area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, 122 | category_no=category_no, 123 | is_group_of=bool((modifier & 1) * self.use_group_parameter), 124 | is_occluded=bool((modifier & 2) * self.use_additional_parameters), 125 | is_depiction=bool((modifier & 4) * self.use_additional_parameters), 126 | is_inside=bool((modifier & 8) * self.use_additional_parameters) 127 | ) 128 | 129 | def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: 130 | return list(self.token_pair_from_bbox(crop_coordinates)) 131 | 132 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 133 | object_tuples = [ 134 | (self.object_representation(a), 135 | self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) 136 | for a in annotations 137 | ] 138 | empty_tuple = (self.none, self.none) 139 | object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) 140 | return object_tuples 141 | 142 | def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ 143 | -> LongTensor: 144 | if len(annotations) == 0: 145 | warnings.warn('Did not receive any annotations.') 146 | if len(annotations) > self.no_max_objects: 147 | warnings.warn('Received more annotations than allowed.') 148 | annotations = annotations[:self.no_max_objects] 149 | 150 | if not crop_coordinates: 151 | crop_coordinates = FULL_CROP 152 | 153 | random.shuffle(annotations) 154 | annotations = filter_annotations(annotations, crop_coordinates) 155 | if self.encode_crop: 156 | annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) 157 | if horizontal_flip: 158 | crop_coordinates = horizontally_flip_bbox(crop_coordinates) 159 | extra = self._crop_encoder(crop_coordinates) 160 | else: 161 | annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) 162 | extra = [] 163 | 164 | object_tuples = self._make_object_descriptors(annotations) 165 | flattened = [token for tuple_ in object_tuples for token in tuple_] + extra 166 | assert len(flattened) == self.embedding_dim 167 | assert all(0 <= value < self.no_tokens for value in flattened) 168 | return LongTensor(flattened) 169 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 
33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> 
Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with 
open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | 
class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 
63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /taming/data/open_images_helper.py: -------------------------------------------------------------------------------- 1 | open_images_unify_categories_for_coco = { 2 | '/m/03bt1vf': '/m/01g317', 3 | '/m/04yx4': '/m/01g317', 4 | '/m/05r655': '/m/01g317', 5 | '/m/01bl7v': '/m/01g317', 6 | '/m/0cnyhnx': '/m/01xq0k1', 7 | '/m/01226z': '/m/018xm', 8 | '/m/05ctyq': '/m/018xm', 9 | '/m/058qzx': '/m/04ctx', 10 | '/m/06pcq': '/m/0l515', 11 | '/m/03m3pdh': '/m/02crq1', 12 | '/m/046dlr': '/m/01x3z', 13 | '/m/0h8mzrc': '/m/01x3z', 14 | } 15 | 16 | 17 | top_300_classes_plus_coco_compatibility = [ 18 | ('Man', 1060962), 19 | ('Clothing', 986610), 20 | ('Tree', 748162), 21 | ('Woman', 611896), 22 | ('Person', 610294), 23 | ('Human face', 442948), 24 | ('Girl', 175399), 25 | ('Building', 162147), 26 | ('Car', 159135), 27 | ('Plant', 155704), 28 | ('Human body', 137073), 29 | ('Flower', 133128), 30 | ('Window', 127485), 31 | ('Human arm', 118380), 32 | ('House', 114365), 33 | ('Wheel', 111684), 34 | ('Suit', 99054), 35 | ('Human hair', 98089), 36 | ('Human head', 92763), 37 | ('Chair', 88624), 38 | ('Boy', 79849), 39 | ('Table', 73699), 40 | ('Jeans', 57200), 41 | ('Tire', 55725), 42 | ('Skyscraper', 53321), 43 | ('Food', 52400), 44 | ('Footwear', 50335), 45 | ('Dress', 50236), 46 | ('Human leg', 47124), 47 | ('Toy', 46636), 48 | ('Tower', 45605), 49 | ('Boat', 43486), 50 | ('Land vehicle', 40541), 51 | ('Bicycle wheel', 34646), 52 | ('Palm tree', 33729), 53 | ('Fashion accessory', 32914), 54 | ('Glasses', 31940), 55 | ('Bicycle', 31409), 56 | ('Furniture', 30656), 57 | ('Sculpture', 29643), 58 | ('Bottle', 27558), 59 | ('Dog', 26980), 60 | ('Snack', 26796), 61 | ('Human hand', 26664), 62 | ('Bird', 25791), 63 | ('Book', 25415), 64 | ('Guitar', 24386), 65 | ('Jacket', 23998), 66 | ('Poster', 22192), 67 | ('Dessert', 21284), 68 | ('Baked goods', 20657), 69 | ('Drink', 19754), 70 | ('Flag', 18588), 71 | ('Houseplant', 18205), 72 | ('Tableware', 17613), 73 | ('Airplane', 17218), 74 | ('Door', 17195), 75 | ('Sports uniform', 17068), 76 | ('Shelf', 16865), 77 | ('Drum', 16612), 78 | ('Vehicle', 16542), 79 | ('Microphone', 15269), 80 | ('Street light', 14957), 81 | ('Cat', 14879), 82 | ('Fruit', 13684), 83 | ('Fast food', 13536), 84 | ('Animal', 12932), 85 | ('Vegetable', 12534), 86 | ('Train', 12358), 87 | ('Horse', 11948), 88 | ('Flowerpot', 11728), 89 | ('Motorcycle', 11621), 90 | ('Fish', 11517), 91 | ('Desk', 11405), 92 | ('Helmet', 10996), 93 | ('Truck', 10915), 94 | ('Bus', 10695), 95 | ('Hat', 10532), 96 | ('Auto part', 10488), 97 | ('Musical instrument', 10303), 98 | ('Sunglasses', 10207), 99 | ('Picture frame', 10096), 100 | ('Sports equipment', 10015), 101 | ('Shorts', 9999), 102 | ('Wine glass', 9632), 103 | ('Duck', 9242), 104 | ('Wine', 9032), 105 | ('Rose', 8781), 106 | ('Tie', 8693), 107 | ('Butterfly', 8436), 108 | ('Beer', 7978), 109 | ('Cabinetry', 7956), 110 | ('Laptop', 7907), 111 | ('Insect', 7497), 112 | ('Goggles', 7363), 113 | ('Shirt', 7098), 114 | ('Dairy Product', 7021), 115 | ('Marine invertebrates', 7014), 116 | ('Cattle', 7006), 117 | ('Trousers', 6903), 118 | ('Van', 6843), 119 | ('Billboard', 6777), 120 | ('Balloon', 6367), 121 | ('Human nose', 6103), 122 | ('Tent', 6073), 123 | ('Camera', 6014), 124 | ('Doll', 
6002), 125 | ('Coat', 5951), 126 | ('Mobile phone', 5758), 127 | ('Swimwear', 5729), 128 | ('Strawberry', 5691), 129 | ('Stairs', 5643), 130 | ('Goose', 5599), 131 | ('Umbrella', 5536), 132 | ('Cake', 5508), 133 | ('Sun hat', 5475), 134 | ('Bench', 5310), 135 | ('Bookcase', 5163), 136 | ('Bee', 5140), 137 | ('Computer monitor', 5078), 138 | ('Hiking equipment', 4983), 139 | ('Office building', 4981), 140 | ('Coffee cup', 4748), 141 | ('Curtain', 4685), 142 | ('Plate', 4651), 143 | ('Box', 4621), 144 | ('Tomato', 4595), 145 | ('Coffee table', 4529), 146 | ('Office supplies', 4473), 147 | ('Maple', 4416), 148 | ('Muffin', 4365), 149 | ('Cocktail', 4234), 150 | ('Castle', 4197), 151 | ('Couch', 4134), 152 | ('Pumpkin', 3983), 153 | ('Computer keyboard', 3960), 154 | ('Human mouth', 3926), 155 | ('Christmas tree', 3893), 156 | ('Mushroom', 3883), 157 | ('Swimming pool', 3809), 158 | ('Pastry', 3799), 159 | ('Lavender (Plant)', 3769), 160 | ('Football helmet', 3732), 161 | ('Bread', 3648), 162 | ('Traffic sign', 3628), 163 | ('Common sunflower', 3597), 164 | ('Television', 3550), 165 | ('Bed', 3525), 166 | ('Cookie', 3485), 167 | ('Fountain', 3484), 168 | ('Paddle', 3447), 169 | ('Bicycle helmet', 3429), 170 | ('Porch', 3420), 171 | ('Deer', 3387), 172 | ('Fedora', 3339), 173 | ('Canoe', 3338), 174 | ('Carnivore', 3266), 175 | ('Bowl', 3202), 176 | ('Human eye', 3166), 177 | ('Ball', 3118), 178 | ('Pillow', 3077), 179 | ('Salad', 3061), 180 | ('Beetle', 3060), 181 | ('Orange', 3050), 182 | ('Drawer', 2958), 183 | ('Platter', 2937), 184 | ('Elephant', 2921), 185 | ('Seafood', 2921), 186 | ('Monkey', 2915), 187 | ('Countertop', 2879), 188 | ('Watercraft', 2831), 189 | ('Helicopter', 2805), 190 | ('Kitchen appliance', 2797), 191 | ('Personal flotation device', 2781), 192 | ('Swan', 2739), 193 | ('Lamp', 2711), 194 | ('Boot', 2695), 195 | ('Bronze sculpture', 2693), 196 | ('Chicken', 2677), 197 | ('Taxi', 2643), 198 | ('Juice', 2615), 199 | ('Cowboy hat', 2604), 200 | ('Apple', 2600), 201 | ('Tin can', 2590), 202 | ('Necklace', 2564), 203 | ('Ice cream', 2560), 204 | ('Human beard', 2539), 205 | ('Coin', 2536), 206 | ('Candle', 2515), 207 | ('Cart', 2512), 208 | ('High heels', 2441), 209 | ('Weapon', 2433), 210 | ('Handbag', 2406), 211 | ('Penguin', 2396), 212 | ('Rifle', 2352), 213 | ('Violin', 2336), 214 | ('Skull', 2304), 215 | ('Lantern', 2285), 216 | ('Scarf', 2269), 217 | ('Saucer', 2225), 218 | ('Sheep', 2215), 219 | ('Vase', 2189), 220 | ('Lily', 2180), 221 | ('Mug', 2154), 222 | ('Parrot', 2140), 223 | ('Human ear', 2137), 224 | ('Sandal', 2115), 225 | ('Lizard', 2100), 226 | ('Kitchen & dining room table', 2063), 227 | ('Spider', 1977), 228 | ('Coffee', 1974), 229 | ('Goat', 1926), 230 | ('Squirrel', 1922), 231 | ('Cello', 1913), 232 | ('Sushi', 1881), 233 | ('Tortoise', 1876), 234 | ('Pizza', 1870), 235 | ('Studio couch', 1864), 236 | ('Barrel', 1862), 237 | ('Cosmetics', 1841), 238 | ('Moths and butterflies', 1841), 239 | ('Convenience store', 1817), 240 | ('Watch', 1792), 241 | ('Home appliance', 1786), 242 | ('Harbor seal', 1780), 243 | ('Luggage and bags', 1756), 244 | ('Vehicle registration plate', 1754), 245 | ('Shrimp', 1751), 246 | ('Jellyfish', 1730), 247 | ('French fries', 1723), 248 | ('Egg (Food)', 1698), 249 | ('Football', 1697), 250 | ('Musical keyboard', 1683), 251 | ('Falcon', 1674), 252 | ('Candy', 1660), 253 | ('Medical equipment', 1654), 254 | ('Eagle', 1651), 255 | ('Dinosaur', 1634), 256 | ('Surfboard', 1630), 257 | ('Tank', 1628), 258 | ('Grape', 1624), 259 | 
('Lion', 1624), 260 | ('Owl', 1622), 261 | ('Ski', 1613), 262 | ('Waste container', 1606), 263 | ('Frog', 1591), 264 | ('Sparrow', 1585), 265 | ('Rabbit', 1581), 266 | ('Pen', 1546), 267 | ('Sea lion', 1537), 268 | ('Spoon', 1521), 269 | ('Sink', 1512), 270 | ('Teddy bear', 1507), 271 | ('Bull', 1495), 272 | ('Sofa bed', 1490), 273 | ('Dragonfly', 1479), 274 | ('Brassiere', 1478), 275 | ('Chest of drawers', 1472), 276 | ('Aircraft', 1466), 277 | ('Human foot', 1463), 278 | ('Pig', 1455), 279 | ('Fork', 1454), 280 | ('Antelope', 1438), 281 | ('Tripod', 1427), 282 | ('Tool', 1424), 283 | ('Cheese', 1422), 284 | ('Lemon', 1397), 285 | ('Hamburger', 1393), 286 | ('Dolphin', 1390), 287 | ('Mirror', 1390), 288 | ('Marine mammal', 1387), 289 | ('Giraffe', 1385), 290 | ('Snake', 1368), 291 | ('Gondola', 1364), 292 | ('Wheelchair', 1360), 293 | ('Piano', 1358), 294 | ('Cupboard', 1348), 295 | ('Banana', 1345), 296 | ('Trumpet', 1335), 297 | ('Lighthouse', 1333), 298 | ('Invertebrate', 1317), 299 | ('Carrot', 1268), 300 | ('Sock', 1260), 301 | ('Tiger', 1241), 302 | ('Camel', 1224), 303 | ('Parachute', 1224), 304 | ('Bathroom accessory', 1223), 305 | ('Earrings', 1221), 306 | ('Headphones', 1218), 307 | ('Skirt', 1198), 308 | ('Skateboard', 1190), 309 | ('Sandwich', 1148), 310 | ('Saxophone', 1141), 311 | ('Goldfish', 1136), 312 | ('Stool', 1104), 313 | ('Traffic light', 1097), 314 | ('Shellfish', 1081), 315 | ('Backpack', 1079), 316 | ('Sea turtle', 1078), 317 | ('Cucumber', 1075), 318 | ('Tea', 1051), 319 | ('Toilet', 1047), 320 | ('Roller skates', 1040), 321 | ('Mule', 1039), 322 | ('Bust', 1031), 323 | ('Broccoli', 1030), 324 | ('Crab', 1020), 325 | ('Oyster', 1019), 326 | ('Cannon', 1012), 327 | ('Zebra', 1012), 328 | ('French horn', 1008), 329 | ('Grapefruit', 998), 330 | ('Whiteboard', 997), 331 | ('Zucchini', 997), 332 | ('Crocodile', 992), 333 | 334 | ('Clock', 960), 335 | ('Wall clock', 958), 336 | 337 | ('Doughnut', 869), 338 | ('Snail', 868), 339 | 340 | ('Baseball glove', 859), 341 | 342 | ('Panda', 830), 343 | ('Tennis racket', 830), 344 | 345 | ('Pear', 652), 346 | 347 | ('Bagel', 617), 348 | ('Oven', 616), 349 | ('Ladybug', 615), 350 | ('Shark', 615), 351 | ('Polar bear', 614), 352 | ('Ostrich', 609), 353 | 354 | ('Hot dog', 473), 355 | ('Microwave oven', 467), 356 | ('Fire hydrant', 20), 357 | ('Stop sign', 20), 358 | ('Parking meter', 20), 359 | ('Bear', 20), 360 | ('Flying disc', 20), 361 | ('Snowboard', 20), 362 | ('Tennis ball', 20), 363 | ('Kite', 20), 364 | ('Baseball bat', 20), 365 | ('Kitchen knife', 20), 366 | ('Knife', 20), 367 | ('Submarine sandwich', 20), 368 | ('Computer mouse', 20), 369 | ('Remote control', 20), 370 | ('Toaster', 20), 371 | ('Sink', 20), 372 | ('Refrigerator', 20), 373 | ('Alarm clock', 20), 374 | ('Wall clock', 20), 375 | ('Scissors', 20), 376 | ('Hair dryer', 20), 377 | ('Toothbrush', 20), 378 | ('Suitcase', 20) 379 | ] 380 | -------------------------------------------------------------------------------- /taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = 
shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, 
default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = 
None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
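# --- Minimal usage sketch for custom_collate above, used as a DataLoader
# collate_fn. Its only change versus stock PyTorch collation is that lists of
# Annotation objects are passed through un-stacked; this assumes a PyTorch
# version that still provides torch._six, as pinned by the repo environment.
import torch
from torch.utils.data import DataLoader, TensorDataset
from taming.data.utils import custom_collate

ds = TensorDataset(torch.randn(8, 3, 32, 32), torch.arange(8))
loader = DataLoader(ds, batch_size=4, collate_fn=custom_collate)
images, labels = next(iter(loader))
print(images.shape, labels.shape)           # [4, 3, 32, 32] and [4]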
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /taming/modules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /taming/modules/discriminator/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/discriminator/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in 
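# --- Minimal usage sketch for LambdaWarmUpCosineScheduler above, wrapped in a
# standard torch LambdaLR. Per the class docstring it multiplies a base_lr of
# 1.0, so the values it returns are the actual learning rates; the step counts
# below are illustrative.
import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1.0)          # base_lr of 1.0
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=100_000)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=schedule)

for _ in range(5):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])            # ramps linearly from lr_start towards lr_max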
the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /taming/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /taming/modules/losses/__pycache__/lpips.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/lpips.cpython-38.pyc -------------------------------------------------------------------------------- /taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 
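# --- Minimal usage sketch for the PatchGAN NLayerDiscriminator above,
# initialised with weights_init; a 256x256 RGB input yields a 30x30 patch map.
import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)                         # torch.Size([1, 1, 30, 30])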
512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def 
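# --- Minimal usage sketch for the LPIPS metric above. Inputs are expected in
# [-1, 1]; the first call loads the bundled vgg.pth and the torchvision VGG16
# weights, so it assumes both are reachable from the working directory or cache.
import torch
from taming.modules.losses.lpips import LPIPS

lpips = LPIPS().eval()
a = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
b = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
with torch.no_grad():
    dist = lpips(a, b)
print(dist.shape, float(dist))              # (1, 1, 1, 1); larger means more dissimilar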
forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /taming/modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = indices.reshape(C,H//2,2,W//2,2) 28 | indices = 
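# --- Minimal usage sketch for the CoordStage "fake vqmodel" above: a single-
# channel coordinate map in [0, 1] is area-downsampled, quantised into n_embed
# bins and decoded back with nearest-neighbour upsampling. The sizes below are
# illustrative.
import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
coords = torch.rand(1, 1, 256, 256)
c_quant, _, (_, _, c_ind) = stage.encode(coords)
recon = stage.decode(c_quant)
print(c_quant.shape, c_ind.dtype, recon.shape)   # (1,1,16,16), torch.int64, (1,1,256,256)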
indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < size//2: 172 | # step right 173 | for k in 
range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 | else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, 
reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name 
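# --- Minimal usage sketch for ActNorm above: in training mode the first forward
# pass performs data-dependent initialisation so the output is roughly zero-mean
# and unit-variance per channel; with logdet=True it also returns a per-sample
# log-determinant.
import torch
from taming.modules.util import ActNorm

act = ActNorm(num_features=8, logdet=True)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
h, logdet = act(x)                          # first call triggers initialisation
print(round(h.mean().item(), 3), round(h.std().item(), 3))   # close to 0 and 1
print(logdet.shape)                         # torch.Size([4])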
in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | --------------------------------------------------------------------------------
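A minimal sketch of retrieve() from taming/util.py, extending the __main__ demo above; the nested dict, keys and default values here are illustrative only. retrieve walks nested containers with "/"-separated keys (list indices included) and can fall back to a default instead of raising KeyNotFoundError:

from taming.util import retrieve

config = {"model": {"params": {"lr": 2e-4, "layers": [32, 64]}}}
print(retrieve(config, "model/params/lr"))                        # 0.0002
print(retrieve(config, "model/params/layers/1"))                  # 64
print(retrieve(config, "model/params/dropout", default=0.1))      # 0.1 (key absent)
value, ok = retrieve(config, "model/params/momentum", default=0.9, pass_success=True)
print(value, ok)                                                  # 0.9 False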