├── CtrlColor_environ.yaml
├── LICENSE
├── README.md
├── annotator
│   └── util.py
├── assets
│   ├── hf_demo1.png
│   ├── iterative.gif
│   ├── region.gif
│   ├── region_cond.gif
│   └── teaser_aligned.png
├── cldm
│   ├── cldm.py
│   ├── ddim_haced_sag_step.py
│   ├── ddim_hacked_sag.py
│   ├── hack.py
│   └── model.py
├── config.py
├── ldm
│   ├── models
│   │   ├── autoencoder.py
│   │   ├── autoencoder_train.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── ddim.cpython-38.pyc
│   │   │   │   ├── ddpm.cpython-38.pyc
│   │   │   │   └── ddpm_nonoise.cpython-38.pyc
│   │   │   ├── ddim.py
│   │   │   ├── ddpm.py
│   │   │   ├── dpm_solver
│   │   │   │   ├── __init__.py
│   │   │   │   ├── dpm_solver.py
│   │   │   │   └── sampler.py
│   │   │   ├── plms.py
│   │   │   └── sampling_util.py
│   │   └── logger.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── attention_dcn_control.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── model_brefore_dcn.cpython-38.pyc
│   │   │   │   ├── openaimodel.cpython-38.pyc
│   │   │   │   └── util.cpython-38.pyc
│   │   │   ├── model.py
│   │   │   ├── model_brefore_dcn.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── contperceptual.cpython-38.pyc
│   │   │   │   └── vqperceptual.cpython-38.pyc
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── midas
│   │       ├── __init__.py
│   │       ├── api.py
│   │       ├── midas
│   │       │   ├── __init__.py
│   │       │   ├── base_model.py
│   │       │   ├── blocks.py
│   │       │   ├── dpt_depth.py
│   │       │   ├── midas_net.py
│   │       │   ├── midas_net_custom.py
│   │       │   ├── transforms.py
│   │       │   └── vit.py
│   │       └── utils.py
│   └── util.py
├── models
│   ├── cldm_v15_inpainting_infer.yaml
│   └── cldm_v15_inpainting_infer1.yaml
├── pretrained_models
│   └── put_ckpts_here.txt
├── share.py
├── taming
│   ├── data
│   │   ├── ade20k.py
│   │   ├── annotated_objects_coco.py
│   │   ├── annotated_objects_dataset.py
│   │   ├── annotated_objects_open_images.py
│   │   ├── base.py
│   │   ├── coco.py
│   │   ├── conditional_builder
│   │   │   ├── objects_bbox.py
│   │   │   ├── objects_center_points.py
│   │   │   └── utils.py
│   │   ├── custom.py
│   │   ├── faceshq.py
│   │   ├── helper_types.py
│   │   ├── image_transforms.py
│   │   ├── imagenet.py
│   │   ├── open_images_helper.py
│   │   ├── sflckr.py
│   │   └── utils.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── cond_transformer.py
│   │   ├── dummy_cond_stage.py
│   │   └── vqgan.py
│   ├── modules
│   │   ├── __pycache__
│   │   │   └── util.cpython-38.pyc
│   │   ├── autoencoder
│   │   │   └── lpips
│   │   │       └── vgg.pth
│   │   ├── diffusionmodules
│   │   │   └── model.py
│   │   ├── discriminator
│   │   │   ├── __pycache__
│   │   │   │   └── model.cpython-38.pyc
│   │   │   └── model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── lpips.cpython-38.pyc
│   │   │   │   └── vqperceptual.cpython-38.pyc
│   │   │   ├── lpips.py
│   │   │   ├── segmentation.py
│   │   │   └── vqperceptual.py
│   │   ├── misc
│   │   │   └── coord.py
│   │   ├── transformer
│   │   │   ├── mingpt.py
│   │   │   └── permuter.py
│   │   ├── util.py
│   │   └── vqvae
│   │       └── quantize.py
│   └── util.py
└── test.py
/CtrlColor_environ.yaml:
--------------------------------------------------------------------------------
1 | name: CtrlColor
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
12 | - pip:
13 | - gradio==3.31.0
14 | - gradio-client==0.2.5
15 | - albumentations==1.3.0
16 | - opencv-python==4.9.0.80
17 | - opencv-python-headless==4.5.5.64
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.5.0
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit==1.12.1
24 | - webdataset==0.2.5
25 | - kornia==0.6
26 | - open_clip_torch==2.0.2
27 | - invisible-watermark>=0.1.5
28 | - streamlit-drawable-canvas==0.8.0
29 | - torchmetrics==0.6.0
30 | - addict==2.4.0
31 | - yapf==0.32.0
32 | - prettytable==3.6.0
33 | - basicsr==1.4.2
34 | - salesforce-lavis==1.0.2
35 | - grpcio==1.60
36 | - pydantic==1.10.5
37 | - spacy==3.5.1
38 | - typer==0.7.0
39 | - typing-extensions==4.4.0
40 | - fastapi==0.92.0
--------------------------------------------------------------------------------
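A quick way to confirm that the pinned PyTorch/CUDA stack above resolved correctly after `conda env create` (a minimal check, not part of the repo; run inside the activated `CtrlColor` environment):

```python
# Verify the CtrlColor environment: PyTorch 1.12.1 built against CUDA 11.3 and a visible GPU.
import torch

print("torch:", torch.__version__)         # expected 1.12.1
print("cuda build:", torch.version.cuda)   # expected 11.3
print("gpu available:", torch.cuda.is_available())
```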
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2022 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Control Color: Multimodal Diffusion-based Interactive Image Colorization
2 |
13 | S-Lab, Nanyang Technological University
14 |
43 | ![teaser](assets/teaser_aligned.png)
44 |
45 | Control Color (CtrlColor) achieves highly controllable multimodal image colorization based on the Stable Diffusion model.
46 |
51 | ![Region colorization](assets/region.gif) | ![Iterative editing](assets/iterative.gif)
52 | :---: | :---:
55 | Region colorization | Iterative editing
56 |
61 | :open_book: For more visual results and applications of CtrlColor, please check out our project page.
62 |
63 | ---
64 |
88 | ## :mega: Updates
89 | - **2024.12.16**: The test code (Gradio demo), the colorization model checkpoint, and the autoencoder checkpoint are now publicly available.
90 |
91 | ## :desktop_computer: Requirements
92 |
93 | - Required packages are listed in `CtrlColor_environ.yaml`
94 |
95 | ```
96 | # clone this repository
97 | git clone https://github.com/ZhexinLiang/Control-Color.git
98 | cd Control-Color
99 |
100 | # create a new conda env and install the python dependencies
101 | conda env create -f CtrlColor_environ.yaml
102 | conda activate CtrlColor
103 | ```
104 |
105 | ## :running_woman: Inference
106 |
107 | ### Prepare models:
108 |
109 | Please download the checkpoints of both the colorization model and the VAE from [[Google Drive](https://drive.google.com/drive/folders/1lgqstNwrMCzymowRsbGM-4hk0-7L-eOT?usp=sharing)] and put both checkpoints in the `./pretrained_models` folder.
110 |
111 | ### Testing:
112 |
113 | You can use the following command to run the Gradio demo:
114 |
115 | ```
116 | python test.py
117 | ```
118 | Then you will see our interactive interface, as shown below:
119 |
120 | ![Gradio demo interface](assets/hf_demo1.png)
121 |
122 | ## :love_you_gesture: Citation
123 | If you find our work useful for your research, please consider citing the paper:
124 | ```
125 | @article{liang2024control,
126 | title={Control Color: Multimodal Diffusion-based Interactive Image Colorization},
127 | author={Liang, Zhexin and Li, Zhaochen and Zhou, Shangchen and Li, Chongyi and Loy, Chen Change},
128 | journal={arXiv preprint arXiv:2402.10855},
129 | year={2024}
130 | }
131 | ```
132 |
133 | ### Contact
134 | If you have any questions, please feel free to reach out at `zhexinliang@gmail.com`.
135 |
--------------------------------------------------------------------------------
/annotator/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 |
5 |
6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7 |
8 |
9 | def HWC3(x):
10 | assert x.dtype == np.uint8
11 | if x.ndim == 2:
12 | x = x[:, :, None]
13 | assert x.ndim == 3
14 | H, W, C = x.shape
15 | assert C == 1 or C == 3 or C == 4
16 | if C == 3:
17 | return x
18 | if C == 1:
19 | return np.concatenate([x, x, x], axis=2)
20 | if C == 4:
21 | color = x[:, :, 0:3].astype(np.float32)
22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23 | y = color * alpha + 255.0 * (1.0 - alpha)
24 | y = y.clip(0, 255).astype(np.uint8)
25 | return y
26 |
27 |
28 | def resize_image(input_image, resolution):
29 | H, W, C = input_image.shape
30 | H = float(H)
31 | W = float(W)
32 |     k = float(resolution) / min(H, W)  # scale so the short side matches `resolution`
33 |     H *= k
34 |     W *= k
35 |     H_new = int(np.round(H / 64.0)) * 64   # snap both sides to multiples of 64
36 |     W_new = int(np.round(W / 64.0)) * 64
37 |     H = H_new if H_new < 800 else int(np.round(800 / 64.0)) * 64  # cap each side below 800
38 |     W = W_new if W_new < 800 else int(np.round(800 / 64.0)) * 64
39 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
40 | return img
41 |
--------------------------------------------------------------------------------
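A small usage sketch for the two helpers above (synthetic input, not part of the repo): `HWC3` lifts a grayscale uint8 array to three channels, and `resize_image` rescales so the short side matches the requested resolution, snapping both sides to multiples of 64.

```python
import numpy as np
from annotator.util import HWC3, resize_image

gray = np.random.randint(0, 256, (480, 640), dtype=np.uint8)  # H x W grayscale
rgb = HWC3(gray)                   # -> (480, 640, 3), uint8
resized = resize_image(rgb, 512)   # short side ~512, both sides multiples of 64
print(rgb.shape, resized.shape)    # (480, 640, 3) (512, 704, 3)
```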
/assets/hf_demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/hf_demo1.png
--------------------------------------------------------------------------------
/assets/iterative.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/iterative.gif
--------------------------------------------------------------------------------
/assets/region.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/region.gif
--------------------------------------------------------------------------------
/assets/region_cond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/region_cond.gif
--------------------------------------------------------------------------------
/assets/teaser_aligned.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/assets/teaser_aligned.png
--------------------------------------------------------------------------------
/cldm/hack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 |
4 | import ldm.modules.encoders.modules
5 | import ldm.modules.attention
6 |
7 | from transformers import logging
8 | from ldm.modules.attention import default
9 |
10 |
11 | def disable_verbosity():
12 | logging.set_verbosity_error()
13 | print('logging improved.')
14 | return
15 |
16 |
17 | def enable_sliced_attention():
18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19 | print('Enabled sliced_attention.')
20 | return
21 |
22 |
23 | def hack_everything(clip_skip=0):
24 | disable_verbosity()
25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27 | print('Enabled clip hacks.')
28 | return
29 |
30 |
31 | # Written by Lvmin
32 | def _hacked_clip_forward(self, text):
33 | PAD = self.tokenizer.pad_token_id
34 | EOS = self.tokenizer.eos_token_id
35 | BOS = self.tokenizer.bos_token_id
36 |
37 | def tokenize(t):
38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39 |
40 | def transformer_encode(t):
41 | if self.clip_skip > 1:
42 | rt = self.transformer(input_ids=t, output_hidden_states=True)
43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44 | else:
45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46 |
47 | def split(x):
48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49 |
50 | def pad(x, p, i):
51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52 |
53 | raw_tokens_list = tokenize(text)
54 | tokens_list = []
55 |
56 | for raw_tokens in raw_tokens_list:
57 | raw_tokens_123 = split(raw_tokens)
58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60 | tokens_list.append(raw_tokens_123)
61 |
62 | tokens_list = torch.IntTensor(tokens_list).to(self.device)
63 |
64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65 | y = transformer_encode(feed)
66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67 |
68 | return z
69 |
70 |
71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73 | h = self.heads
74 |
75 | q = self.to_q(x)
76 | context = default(context, x)
77 | k = self.to_k(context)
78 | v = self.to_v(context)
79 | del context, x
80 |
81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82 |
83 | limit = k.shape[0]
84 | att_step = 1
85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88 |
89 | q_chunks.reverse()
90 | k_chunks.reverse()
91 | v_chunks.reverse()
92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93 | del k, q, v
94 | for i in range(0, limit, att_step):
95 | q_buffer = q_chunks.pop()
96 | k_buffer = k_chunks.pop()
97 | v_buffer = v_chunks.pop()
98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99 |
100 | del k_buffer, q_buffer
101 | # attention, what we cannot get enough of, by chunks
102 |
103 | sim_buffer = sim_buffer.softmax(dim=-1)
104 |
105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106 | del v_buffer
107 | sim[i:i + att_step, :, :] = sim_buffer
108 |
109 | del sim_buffer
110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111 | return self.to_out(sim)
112 |
--------------------------------------------------------------------------------
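The hooks above are meant to be installed once at start-up, before the model is built. A minimal sketch of how an entry script might wire them up; gating sliced attention on `config.save_memory` (defined in `config.py` below) is an assumption, not something this file mandates.

```python
# Install the CLIP / attention hacks before instantiating the diffusion model.
from cldm.hack import disable_verbosity, enable_sliced_attention
import config

disable_verbosity()            # silence transformers' logging
if config.save_memory:         # assumption: only slice attention in low-VRAM mode
    enable_sliced_attention()  # patches CrossAttention.forward with the chunked version
```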
/cldm/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from omegaconf import OmegaConf
5 | from ldm.util import instantiate_from_config
6 |
7 |
8 | def get_state_dict(d):
9 | return d.get('state_dict', d)
10 |
11 |
12 | def load_state_dict(ckpt_path, location='cpu'):
13 | _, extension = os.path.splitext(ckpt_path)
14 | if extension.lower() == ".safetensors":
15 | import safetensors.torch
16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17 | else:
18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19 | state_dict = get_state_dict(state_dict)
20 | print(f'Loaded state_dict from [{ckpt_path}]')
21 | return state_dict
22 |
23 |
24 | def create_model(config_path):
25 | config = OmegaConf.load(config_path)
26 | model = instantiate_from_config(config.model).cpu()
27 | print(f'Loaded model config from [{config_path}]')
28 | return model
29 |
--------------------------------------------------------------------------------
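A typical loading sequence with the two helpers above. The config path exists under `./models` in this repo; the checkpoint filename under `./pretrained_models` is hypothetical, so substitute the names from the released Google Drive folder.

```python
from cldm.model import create_model, load_state_dict

model = create_model('./models/cldm_v15_inpainting_infer.yaml')
sd = load_state_dict('./pretrained_models/ctrlcolor.ckpt', location='cpu')  # hypothetical filename
model.load_state_dict(sd, strict=False)
model = model.cuda().eval()
```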
/config.py:
--------------------------------------------------------------------------------
1 | save_memory = False
2 |
--------------------------------------------------------------------------------
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | # from ldm.modules.diffusionmodules.model_window import Encoder, Decoder
7 | from ldm.modules.diffusionmodules.model_brefore_dcn import Encoder, Decoder
8 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
9 |
10 | from ldm.util import instantiate_from_config
11 | from ldm.modules.ema import LitEma
12 |
13 |
14 | class AutoencoderKL(pl.LightningModule):
15 | def __init__(self,
16 | ddconfig,
17 | lossconfig,
18 | embed_dim,
19 | ckpt_path=None,
20 | ignore_keys=[],
21 | image_key="image",
22 | colorize_nlabels=None,
23 | monitor=None,
24 | ema_decay=None,
25 | learn_logvar=False
26 | ):
27 | super().__init__()
28 | self.learn_logvar = learn_logvar
29 | self.image_key = image_key
30 | self.encoder = Encoder(**ddconfig)
31 | self.decoder = Decoder(**ddconfig)
32 | self.loss = instantiate_from_config(lossconfig)
33 | assert ddconfig["double_z"]
34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
36 | self.embed_dim = embed_dim
37 | if colorize_nlabels is not None:
38 | assert type(colorize_nlabels)==int
39 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
40 | if monitor is not None:
41 | self.monitor = monitor
42 |
43 | self.use_ema = ema_decay is not None
44 | if self.use_ema:
45 | self.ema_decay = ema_decay
46 | assert 0. < ema_decay < 1.
47 | self.model_ema = LitEma(self, decay=ema_decay)
48 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
49 |
50 | if ckpt_path is not None:
51 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
52 |
53 | def init_from_ckpt(self, path, ignore_keys=list()):
54 | sd = torch.load(path, map_location="cpu")["state_dict"]
55 | keys = list(sd.keys())
56 | for k in keys:
57 | for ik in ignore_keys:
58 | if k.startswith(ik):
59 | print("Deleting key {} from state_dict.".format(k))
60 | del sd[k]
61 | self.load_state_dict(sd, strict=False)
62 | print(f"Restored from {path}")
63 |
64 | @contextmanager
65 | def ema_scope(self, context=None):
66 | if self.use_ema:
67 | self.model_ema.store(self.parameters())
68 | self.model_ema.copy_to(self)
69 | if context is not None:
70 | print(f"{context}: Switched to EMA weights")
71 | try:
72 | yield None
73 | finally:
74 | if self.use_ema:
75 | self.model_ema.restore(self.parameters())
76 | if context is not None:
77 | print(f"{context}: Restored training weights")
78 |
79 | def on_train_batch_end(self, *args, **kwargs):
80 | if self.use_ema:
81 | self.model_ema(self)
82 |
83 | def encode(self, x):
84 | h = self.encoder(x)
85 | moments = self.quant_conv(h)
86 | posterior = DiagonalGaussianDistribution(moments)
87 | return posterior
88 |
89 | def decode(self, z):
90 | z = self.post_quant_conv(z)
91 | dec = self.decoder(z)
92 | return dec
93 |
94 | def forward(self, input, sample_posterior=True):
95 | posterior = self.encode(input)
96 | if sample_posterior:
97 | z = posterior.sample()
98 | else:
99 | z = posterior.mode()
100 | dec = self.decode(z)
101 | return dec, posterior
102 |
103 | def get_input(self, batch, k):
104 | x = batch[k]
105 | if len(x.shape) == 3:
106 | x = x[..., None]
107 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
108 | return x
109 |
110 | def training_step(self, batch, batch_idx, optimizer_idx):
111 | inputs = self.get_input(batch, self.image_key)
112 | reconstructions, posterior = self(inputs)
113 |
114 | if optimizer_idx == 0:
115 | # train encoder+decoder+logvar
116 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
117 | last_layer=self.get_last_layer(), split="train")
118 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
119 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
120 | return aeloss
121 |
122 | if optimizer_idx == 1:
123 | # train the discriminator
124 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
125 | last_layer=self.get_last_layer(), split="train")
126 |
127 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
128 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
129 | return discloss
130 |
131 | def validation_step(self, batch, batch_idx):
132 | log_dict = self._validation_step(batch, batch_idx)
133 | with self.ema_scope():
134 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
135 | return log_dict
136 |
137 | def _validation_step(self, batch, batch_idx, postfix=""):
138 | inputs = self.get_input(batch, self.image_key)
139 | reconstructions, posterior = self(inputs)
140 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
141 | last_layer=self.get_last_layer(), split="val"+postfix)
142 |
143 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
144 | last_layer=self.get_last_layer(), split="val"+postfix)
145 |
146 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
147 | self.log_dict(log_dict_ae)
148 | self.log_dict(log_dict_disc)
149 | return self.log_dict
150 |
151 | def configure_optimizers(self):
152 | lr = self.learning_rate
153 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
154 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
155 | if self.learn_logvar:
156 | print(f"{self.__class__.__name__}: Learning logvar")
157 | ae_params_list.append(self.loss.logvar)
158 | opt_ae = torch.optim.Adam(ae_params_list,
159 | lr=lr, betas=(0.5, 0.9))
160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
161 | lr=lr, betas=(0.5, 0.9))
162 | return [opt_ae, opt_disc], []
163 |
164 | def get_last_layer(self):
165 | return self.decoder.conv_out.weight
166 |
167 | @torch.no_grad()
168 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
169 | log = dict()
170 | x = self.get_input(batch, self.image_key)
171 | x = x.to(self.device)
172 | if not only_inputs:
173 | xrec, posterior = self(x)
174 | if x.shape[1] > 3:
175 | # colorize with random projection
176 | assert xrec.shape[1] > 3
177 | x = self.to_rgb(x)
178 | xrec = self.to_rgb(xrec)
179 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
180 | log["reconstructions"] = xrec
181 | if log_ema or self.use_ema:
182 | with self.ema_scope():
183 | xrec_ema, posterior_ema = self(x)
184 | if x.shape[1] > 3:
185 | # colorize with random projection
186 | assert xrec_ema.shape[1] > 3
187 | xrec_ema = self.to_rgb(xrec_ema)
188 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
189 | log["reconstructions_ema"] = xrec_ema
190 | log["inputs"] = x
191 | return log
192 |
193 | def to_rgb(self, x):
194 | assert self.image_key == "segmentation"
195 | if not hasattr(self, "colorize"):
196 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
197 | x = F.conv2d(x, weight=self.colorize)
198 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
199 | return x
200 |
201 |
202 | class IdentityFirstStage(torch.nn.Module):
203 | def __init__(self, *args, vq_interface=False, **kwargs):
204 | self.vq_interface = vq_interface
205 | super().__init__()
206 |
207 | def encode(self, x, *args, **kwargs):
208 | return x
209 |
210 | def decode(self, x, *args, **kwargs):
211 | return x
212 |
213 | def quantize(self, x, *args, **kwargs):
214 | if self.vq_interface:
215 | return x, None, [None, None, None]
216 | return x
217 |
218 | def forward(self, x, *args, **kwargs):
219 | return x
220 |
221 |
--------------------------------------------------------------------------------
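A sketch of a round trip through `AutoencoderKL`, assuming `model_brefore_dcn.Encoder/Decoder` keep the standard Stable Diffusion signature; the `ddconfig` values below are the usual SD VAE settings supplied here as an assumption, and `torch.nn.Identity` stands in for the loss since none is needed at inference.

```python
import torch
from ldm.models.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
vae = AutoencoderKL(ddconfig=ddconfig,
                    lossconfig={"target": "torch.nn.Identity"},  # no training loss needed here
                    embed_dim=4)

x = torch.randn(1, 3, 256, 256)   # image in [-1, 1], NCHW
posterior = vae.encode(x)         # DiagonalGaussianDistribution
z = posterior.mode()              # deterministic latent, (1, 4, 32, 32)
rec = vae.decode(z)               # reconstruction, (1, 3, 256, 256)
```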
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 |
7 | MODEL_TYPES = {
8 | "eps": "noise",
9 | "v": "v"
10 | }
11 |
12 |
13 | class DPMSolverSampler(object):
14 | def __init__(self, model, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != torch.device("cuda"):
23 | attr = attr.to(torch.device("cuda"))
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54 | if cbs != batch_size:
55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56 | else:
57 | if conditioning.shape[0] != batch_size:
58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59 |
60 | # sampling
61 | C, H, W = shape
62 | size = (batch_size, C, H, W)
63 |
64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65 |
66 | device = self.model.betas.device
67 | if x_T is None:
68 | img = torch.randn(size, device=device)
69 | else:
70 | img = x_T
71 |
72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73 |
74 | model_fn = model_wrapper(
75 | lambda x, t, c: self.model.apply_model(x, t, c),
76 | ns,
77 | model_type=MODEL_TYPES[self.model.parameterization],
78 | guidance_type="classifier-free",
79 | condition=conditioning,
80 | unconditional_condition=unconditional_conditioning,
81 | guidance_scale=unconditional_guidance_scale,
82 | )
83 |
84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86 |
87 | return x.to(device), None
--------------------------------------------------------------------------------
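Usage mirrors the DDIM sampler interface: pass the latent shape without the batch dimension, the conditioning, and optionally an unconditional embedding for classifier-free guidance. A sketch assuming a fully loaded latent-diffusion `model` that exposes the usual `get_learned_conditioning` / `decode_first_stage` helpers:

```python
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

sampler = DPMSolverSampler(model)   # `model` loaded elsewhere, e.g. via cldm/model.py
cond = model.get_learned_conditioning(["a red vintage car"])
uncond = model.get_learned_conditioning([""])
samples, _ = sampler.sample(S=20, batch_size=1, shape=(4, 64, 64),
                            conditioning=cond,
                            unconditional_guidance_scale=7.5,
                            unconditional_conditioning=uncond)
images = model.decode_first_stage(samples)   # back to image space
```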
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
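Two quick examples of the helpers above (illustrative tensors only): `append_dims` broadcasts a per-sample scalar back over the spatial dimensions, and `norm_thresholding` rescales any sample whose RMS norm exceeds `value` down to that norm.

```python
import torch
from ldm.models.diffusion.sampling_util import append_dims, norm_thresholding

x0 = torch.randn(2, 4, 64, 64)
scale = torch.tensor([0.5, 2.0])
print(append_dims(scale, x0.ndim).shape)   # torch.Size([2, 1, 1, 1])
print(norm_thresholding(x0, 1.0).shape)    # same shape as x0, per-sample RMS <= 1.0
```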
/ldm/models/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torchvision
6 | from PIL import Image
7 | from pytorch_lightning.callbacks import Callback
8 | from pytorch_lightning.utilities.distributed import rank_zero_only
9 |
10 | # import pdb
11 |
12 | class ImageLogger(Callback):
13 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
14 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
15 | log_images_kwargs=None,ckpt_dir="./ckpt"):
16 | super().__init__()
17 | self.rescale = rescale
18 | self.batch_freq = batch_frequency
19 | self.max_images = max_images
20 | if not increase_log_steps:
21 | self.log_steps = [self.batch_freq]
22 | self.clamp = clamp
23 | self.disabled = disabled
24 | self.log_on_batch_idx = log_on_batch_idx
25 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
26 | self.log_first_step = log_first_step
27 | self.ckpt_dir=ckpt_dir
28 | self.global_save_num=-2000
29 | self.global_save_num1=-100
30 |
31 | @rank_zero_only
32 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
33 | root = os.path.join(save_dir, "image_log", split)
34 | # print(images)
35 | for k in images:
36 | grid = torchvision.utils.make_grid(images[k], nrow=4)
37 | if self.rescale:
38 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
39 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
40 | grid = grid.numpy()
41 | grid = (grid * 255).astype(np.uint8)
42 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
43 | path = os.path.join(root, filename)
44 | os.makedirs(os.path.split(path)[0], exist_ok=True)
45 | Image.fromarray(grid).save(path)
46 |
47 | def log_img(self, pl_module, batch, batch_idx, split="train"):
48 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
49 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
50 | hasattr(pl_module, "log_images") and
51 | callable(pl_module.log_images) and
52 | self.max_images > 0):
53 | logger = type(pl_module.logger)
54 |
55 | is_train = pl_module.training
56 | if is_train:
57 | pl_module.eval()
58 |
59 | with torch.no_grad():
60 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
61 |
62 | for k in images:
63 | N = min(images[k].shape[0], self.max_images)
64 | images[k] = images[k][:N]
65 | if isinstance(images[k], torch.Tensor):
66 | images[k] = images[k].detach().cpu()
67 | if self.clamp:
68 | images[k] = torch.clamp(images[k], -1., 1.)
69 |
70 | self.log_local(pl_module.logger.save_dir, split, images,
71 | pl_module.global_step, pl_module.current_epoch, batch_idx)
72 |
73 | if is_train:
74 | pl_module.train()
75 |
76 | def check_frequency(self, check_idx):
77 | return check_idx % self.batch_freq == 0
78 |
79 |     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
80 |         # Log a grid of intermediate results every 500 batches.
84 |         if batch_idx % 500 == 0:
88 |             self.log_img(pl_module, batch, batch_idx, split="train_"+"ckpt_inpainting_from5625_2+3750_exemplar_only_vae")
89 |         # Save a full checkpoint every 1000 batches.
90 |         if batch_idx % 1000 == 0:
93 |             trainer.save_checkpoint(self.ckpt_dir+"/epoch"+str(pl_module.current_epoch)+"_global-step"+str(pl_module.global_step)+".ckpt")
94 |
--------------------------------------------------------------------------------
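`ImageLogger` is a standard PyTorch Lightning callback, so it is attached through `Trainer(callbacks=...)`; a minimal sketch, with the model and dataloader prepared elsewhere:

```python
import pytorch_lightning as pl
from ldm.models.logger import ImageLogger

image_logger = ImageLogger(batch_frequency=500, max_images=4, ckpt_dir="./ckpt")
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[image_logger])
# trainer.fit(model, train_dataloader)   # model / dataloader defined elsewhere
```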
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def sample_addhint(self, generator):
40 | latents = torch.randn(self.mean.shape, generator=generator, device='cpu', dtype=self.parameters.dtype).cuda()
41 | x = self.mean + self.std * latents
42 | return x
43 |
44 | def kl(self, other=None):
45 | if self.deterministic:
46 | return torch.Tensor([0.])
47 | else:
48 | if other is None:
49 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
50 | + self.var - 1.0 - self.logvar,
51 | dim=[1, 2, 3])
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
56 | dim=[1, 2, 3])
57 |
58 | def nll(self, sample, dims=[1,2,3]):
59 | if self.deterministic:
60 | return torch.Tensor([0.])
61 | logtwopi = np.log(2.0 * np.pi)
62 | return 0.5 * torch.sum(
63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
64 | dim=dims)
65 |
66 | def mode(self):
67 | return self.mean
68 |
69 |
70 | def normal_kl(mean1, logvar1, mean2, logvar2):
71 | """
72 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
73 | Compute the KL divergence between two gaussians.
74 | Shapes are automatically broadcasted, so batches can be compared to
75 | scalars, among other use cases.
76 | """
77 | tensor = None
78 | for obj in (mean1, logvar1, mean2, logvar2):
79 | if isinstance(obj, torch.Tensor):
80 | tensor = obj
81 | break
82 | assert tensor is not None, "at least one argument must be a Tensor"
83 |
84 | # Force variances to be Tensors. Broadcasting helps convert scalars to
85 | # Tensors, but it does not work for torch.exp().
86 | logvar1, logvar2 = [
87 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
88 | for x in (logvar1, logvar2)
89 | ]
90 |
91 | return 0.5 * (
92 | -1.0
93 | + logvar2
94 | - logvar1
95 | + torch.exp(logvar1 - logvar2)
96 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
97 | )
98 |
--------------------------------------------------------------------------------
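A short illustration of `DiagonalGaussianDistribution`: the input is split into mean and log-variance halves along the channel axis, so an 8-channel moment tensor yields 4-channel latents.

```python
import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 32, 32)          # e.g. the output of AutoencoderKL.quant_conv
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                        # reparameterized sample, (2, 4, 32, 32)
kl = posterior.kl()                           # KL to a standard normal, shape (2,)
print(z.shape, kl.shape)
```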
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
--------------------------------------------------------------------------------
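A compact sketch of the `LitEma` lifecycle used by `AutoencoderKL` above: update the shadow weights after every optimizer step, then temporarily swap them in for evaluation. The toy `nn.Linear` and the in-place update stand in for a real model and optimizer step.

```python
import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(10):
    with torch.no_grad():
        model.weight.add_(0.01)       # stand-in for an optimizer step
    ema(model)                        # update the EMA shadow weights

ema.store(model.parameters())         # stash the current (training) weights
ema.copy_to(model)                    # evaluate with EMA weights ...
ema.restore(model.parameters())       # ... then restore the training weights
```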
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
2 | from ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 | #https://github.com/IceClear/StableSR/blob/main/ldm/modules/losses/contperceptual.py
7 |
8 | class LPIPSWithDiscriminator(nn.Module):
9 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
10 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
11 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
12 | disc_loss="hinge"):
13 |
14 | super().__init__()
15 | assert disc_loss in ["hinge", "vanilla"]
16 | self.kl_weight = kl_weight
17 | self.pixel_weight = pixelloss_weight
18 | self.perceptual_loss = LPIPS().eval()
19 | self.perceptual_weight = perceptual_weight
20 | # output log variance
21 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
22 |
23 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
24 | n_layers=disc_num_layers,
25 | use_actnorm=use_actnorm
26 | ).apply(weights_init)
27 | self.discriminator_iter_start = disc_start
28 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
29 | self.disc_factor = disc_factor
30 | self.discriminator_weight = disc_weight
31 | self.disc_conditional = disc_conditional
32 |
33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
34 | if last_layer is not None:
35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
37 | else:
38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
40 |
41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
43 | d_weight = d_weight * self.discriminator_weight
44 | return d_weight
45 |
46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
47 | global_step, last_layer=None, cond=None, split="train",
48 | weights=None, return_dic=False):
49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
50 | if self.perceptual_weight > 0:
51 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
52 | rec_loss = rec_loss + self.perceptual_weight * p_loss
53 |
54 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
55 | weighted_nll_loss = nll_loss
56 | if weights is not None:
57 | weighted_nll_loss = weights*nll_loss
58 | weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]
59 | nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]
60 | if self.kl_weight>0:
61 | kl_loss = posteriors.kl()
62 | kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
63 |
64 | # now the GAN part
65 | if optimizer_idx == 0:
66 | # generator update
67 | if cond is None:
68 | assert not self.disc_conditional
69 | logits_fake = self.discriminator(reconstructions.contiguous())
70 | else:
71 | assert self.disc_conditional
72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
73 | g_loss = -torch.mean(logits_fake)
74 |
75 | if self.disc_factor > 0.0:
76 | try:
77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
78 | except RuntimeError:
79 | # assert not self.training
80 | d_weight = torch.tensor(1.0) * self.discriminator_weight
81 | else:
82 | # d_weight = torch.tensor(0.0)
83 | d_weight = torch.tensor(0.0)
84 |
85 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
86 | if self.kl_weight>0:
87 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
88 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
89 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
90 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
91 | "{}/d_weight".format(split): d_weight.detach(),
92 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
93 | "{}/g_loss".format(split): g_loss.detach().mean(),
94 | }
95 | if return_dic:
96 | loss_dic = {}
97 | loss_dic['total_loss'] = loss.clone().detach().mean()
98 | loss_dic['logvar'] = self.logvar.detach()
99 | loss_dic['kl_loss'] = kl_loss.detach().mean()
100 | loss_dic['nll_loss'] = nll_loss.detach().mean()
101 | loss_dic['rec_loss'] = rec_loss.detach().mean()
102 | loss_dic['d_weight'] = d_weight.detach()
103 | loss_dic['disc_factor'] = torch.tensor(disc_factor)
104 | loss_dic['g_loss'] = g_loss.detach().mean()
105 | else:
106 | loss = weighted_nll_loss + d_weight * disc_factor * g_loss
107 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
108 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
109 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
110 | "{}/d_weight".format(split): d_weight.detach(),
111 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
112 | "{}/g_loss".format(split): g_loss.detach().mean(),
113 | }
114 | if return_dic:
115 | loss_dic = {}
116 | loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean()
117 | loss_dic["{}/logvar".format(split)] = self.logvar.detach()
118 |                     loss_dic["{}/nll_loss".format(split)] = nll_loss.detach().mean()
119 |                     loss_dic["{}/rec_loss".format(split)] = rec_loss.detach().mean()
120 |                     loss_dic["{}/d_weight".format(split)] = d_weight.detach()
121 |                     loss_dic["{}/disc_factor".format(split)] = torch.tensor(disc_factor)
122 |                     loss_dic["{}/g_loss".format(split)] = g_loss.detach().mean()
123 |
124 | if return_dic:
125 | return loss, log, loss_dic
126 | return loss, log
127 |
128 | if optimizer_idx == 1:
129 | # second pass for discriminator update
130 | if cond is None:
131 | logits_real = self.discriminator(inputs.contiguous().detach())
132 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
133 | else:
134 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
135 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
136 |
137 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
138 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
139 |
140 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
141 | "{}/logits_real".format(split): logits_real.detach().mean(),
142 | "{}/logits_fake".format(split): logits_fake.detach().mean()
143 | }
144 |
145 | if return_dic:
146 | loss_dic = {}
147 | loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean()
148 | loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean()
149 | loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean()
150 | return d_loss, log, loss_dic
151 |
152 | return d_loss, log
--------------------------------------------------------------------------------
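In `ldm/models/autoencoder.py` above this loss arrives through `instantiate_from_config(lossconfig)`. A sketch of the config it expects, written as the Python dict such a config resolves to; the hyperparameter values are placeholders, not the settings shipped with CtrlColor.

```python
from ldm.util import instantiate_from_config

lossconfig = {
    "target": "ldm.modules.losses.contperceptual.LPIPSWithDiscriminator",
    "params": {
        "disc_start": 50001,        # step at which the GAN term switches on (placeholder)
        "kl_weight": 1.0e-6,
        "disc_weight": 0.5,
        "perceptual_weight": 1.0,
    },
}
loss_fn = instantiate_from_config(lossconfig)   # what training_step then calls
```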
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from taming.modules.losses.lpips import LPIPS
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 |
8 |
9 | class DummyLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 |
14 | def adopt_weight(weight, global_step, threshold=0, value=0.):
15 | if global_step < threshold:
16 | weight = value
17 | return weight
18 |
19 |
20 | def hinge_d_loss(logits_real, logits_fake):
21 | loss_real = torch.mean(F.relu(1. - logits_real))
22 | loss_fake = torch.mean(F.relu(1. + logits_fake))
23 | d_loss = 0.5 * (loss_real + loss_fake)
24 | return d_loss
25 |
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 |
34 | class VQLPIPSWithDiscriminator(nn.Module):
35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38 | disc_ndf=64, disc_loss="hinge"):
39 | super().__init__()
40 | assert disc_loss in ["hinge", "vanilla"]
41 | self.codebook_weight = codebook_weight
42 | self.pixel_weight = pixelloss_weight
43 | self.perceptual_loss = LPIPS().eval()
44 | self.perceptual_weight = perceptual_weight
45 |
46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47 | n_layers=disc_num_layers,
48 | use_actnorm=use_actnorm,
49 | ndf=disc_ndf
50 | ).apply(weights_init)
51 | self.discriminator_iter_start = disc_start
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 | self.disc_conditional = disc_conditional
62 |
63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64 | if last_layer is not None:
65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67 | else:
68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70 |
71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73 | d_weight = d_weight * self.discriminator_weight
74 | return d_weight
75 |
76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77 | global_step, last_layer=None, cond=None, split="train"):
78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79 | if self.perceptual_weight > 0:
80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81 | rec_loss = rec_loss + self.perceptual_weight * p_loss
82 | else:
83 | p_loss = torch.tensor([0.0])
84 |
85 | nll_loss = rec_loss
86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87 | nll_loss = torch.mean(nll_loss)
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | try:
101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102 | except RuntimeError:
103 | assert not self.training
104 | d_weight = torch.tensor(0.0)
105 |
106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108 |
109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
112 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
113 | "{}/p_loss".format(split): p_loss.detach().mean(),
114 | "{}/d_weight".format(split): d_weight.detach(),
115 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
116 | "{}/g_loss".format(split): g_loss.detach().mean(),
117 | }
118 | return loss, log
119 |
120 | if optimizer_idx == 1:
121 | # second pass for discriminator update
122 | if cond is None:
123 | logits_real = self.discriminator(inputs.contiguous().detach())
124 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
125 | else:
126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128 |
129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131 |
132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133 | "{}/logits_real".format(split): logits_real.detach().mean(),
134 | "{}/logits_fake".format(split): logits_fake.detach().mean()
135 | }
136 | return d_loss, log
--------------------------------------------------------------------------------
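
calculate_adaptive_weight above balances the reconstruction and adversarial terms by the ratio of their gradient norms at the decoder's last layer: d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4), clamped to [0, 1e4] and scaled by discriminator_weight. A self-contained sketch of the same computation on a toy linear "decoder" (all names below are illustrative, not the repository's API):

    import torch
    import torch.nn as nn

    decoder = nn.Linear(8, 8)                    # toy stand-in for a VAE decoder
    x = torch.randn(4, 8)
    recon = decoder(x)
    nll_loss = (recon - x).abs().mean()          # reconstruction term
    g_loss = -recon.mean()                       # stand-in for -mean(logits_fake)

    last_layer = decoder.weight
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = (torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)).clamp(0.0, 1e4).detach()
    print(d_weight)                              # larger when reconstruction gradients dominate
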
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
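
MiDaSInference wraps one of the ISL MiDaS variants and runs it under no_grad on an already-transformed batch. A minimal usage sketch, assuming the dpt_hybrid checkpoint from ISL_PATHS has been downloaded to midas_models/ and the repository root is on PYTHONPATH:

    import numpy as np
    import torch
    from ldm.modules.midas.api import MiDaSInference, load_midas_transform

    midas = MiDaSInference("dpt_hybrid")                   # loads midas_models/dpt_hybrid-midas-501f0c75.pt
    transform = load_midas_transform("dpt_hybrid")

    img = np.random.rand(480, 640, 3).astype(np.float32)  # RGB in [0, 1], HWC
    x = transform({"image": img})["image"]                 # CHW float32, sides resized to multiples of 32
    depth = midas(torch.from_numpy(x).unsqueeze(0))        # (1, 1, H', W') relative inverse depth
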
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 | False, # Set to true if you want to train from scratch, uses ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x = x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
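
DPTDepthModel runs the ViT backbone, fuses the four hooked feature maps through the RefineNet-style blocks, and squeezes the channel dimension, so a (B, 3, H, W) input with H and W divisible by 32 yields a (B, H, W) depth map. A shape-only sketch with random weights (no checkpoint is loaded; assumes timm provides the vitb_rn50_384 backbone):

    import torch
    from ldm.modules.midas.midas.dpt_depth import DPTDepthModel

    model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 384, 384))
    print(depth.shape)   # torch.Size([1, 384, 384]); values are >= 0 because of the ReLU head
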
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 | non_negative (bool, optional): If True, apply a final ReLU so the predicted depth is non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 | features (int, optional): Number of features. Defaults to 64.
23 | backbone (str, optional): Backbone network for encoder. Defaults to "efficientnet_lite3".
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x = x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
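
fuse_model walks named_modules and fuses Conv2d -> BatchNorm2d (-> ReLU) chains in place with torch.quantization.fuse_modules, which requires the model to be in eval mode. A toy sketch on a three-layer Sequential (illustrative only, not a module from this repository):

    import torch.nn as nn
    from ldm.modules.midas.midas.midas_net_custom import fuse_model

    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()  # fusion needs eval mode
    fuse_model(toy)
    print(toy)   # the Conv/BN/ReLU triple is collapsed into a single fused ConvReLU2d
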
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 | """Rezise the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 | # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 | """Normlize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
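
The three transforms are meant to be chained: Resize picks a 32-multiple output size according to resize_method, NormalizeImage standardizes the RGB channels, and PrepareForNet converts HWC float arrays into contiguous CHW float32. A small sketch reproducing the DPT-style pipeline on a dummy image (the sizes are illustrative):

    import cv2
    import numpy as np
    from torchvision.transforms import Compose
    from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

    pipeline = Compose([
        Resize(384, 384, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32,
               resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        PrepareForNet(),
    ])
    sample = pipeline({"image": np.random.rand(481, 641, 3).astype(np.float32)})
    print(sample["image"].shape, sample["image"].dtype)   # (3, 384, 512) float32
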
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 | path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 | file.write("PF\n" if color else "Pf\n".encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 | out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
--------------------------------------------------------------------------------
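
read_pfm and write_pfm form a simple PFM round trip for float32 maps, while write_depth additionally rescales the map to an 8- or 16-bit PNG. A minimal round-trip sketch with a hypothetical file name:

    import numpy as np
    from ldm.modules.midas.utils import read_pfm, write_pfm

    depth = np.random.rand(4, 5).astype(np.float32)   # greyscale H x W map
    write_pfm("toy_depth.pfm", depth)                 # hypothetical output path
    data, scale = read_pfm("toy_depth.pfm")
    assert data.shape == (4, 5) and np.allclose(data, depth)
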
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 | from torchvision import transforms
10 | import cv2
11 |
12 | # def get_hint_image(image_withmask):
13 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2.
14 | # image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0]
15 | # image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1)
16 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2.
17 | # H,W=mask.shape
18 | # for i in range(H):
19 | # for j in range(W):
20 | # if mask[i,j]==0:
21 | # image[:,i,j]=image_gray[:,i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j]
22 | # return image
23 |
24 | def get_hint_image(image,image_gray,mask):
25 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2.
26 | # image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0]
27 | # image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1)
28 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2.
29 | image=np.array(image.copy())
30 | image_gray=np.array(image_gray.copy())
31 | H,W=mask.shape
32 | for i in range(H):
33 | for j in range(W):
34 | if mask[i,j]==0:
35 | image[i,j]=image_gray[i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j]
36 | return Image.fromarray(image)
37 |
38 | def log_txt_as_img(wh,masked_image, xc, size=10):
39 | # wh a tuple of (width, height)
40 | # xc a list of captions to plot
41 | xc=xc
42 | b = len(xc)
43 | txts = list()
44 | for bi in range(b):
45 | txt = Image.new("RGB", wh, color="white")
46 | # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2.
47 | # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2.
48 | # image=(image_withmask+1.)/2.
49 | # # image = get_hint_image(image_withmask)
50 | # # print(image.shape)
51 | # image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB")
52 | # # image_gray=transforms.ToPILImage()(image).convert("L")
53 | image=(masked_image.squeeze(0)+1.)/2.
54 | image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB")
55 | txt = image_target#get_hint_image(image_target,image_gray,mask)
56 | draw = ImageDraw.Draw(txt)
57 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
58 | nc = int(40 * (wh[0] / 256))
59 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
60 |
61 | try:
62 | draw.text((0, 0), lines, fill="black", font=font)
63 | except UnicodeEncodeError:
64 | print("Cant encode string for logging. Skipping.")
65 |
66 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
67 | txts.append(txt)
68 | txts = np.stack(txts)
69 | txts = torch.tensor(txts)
70 | return txts
71 |
72 |
73 | def ismap(x):
74 | if not isinstance(x, torch.Tensor):
75 | return False
76 | return (len(x.shape) == 4) and (x.shape[1] > 3)
77 |
78 |
79 | def isimage(x):
80 | if not isinstance(x,torch.Tensor):
81 | return False
82 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
83 |
84 |
85 | def exists(x):
86 | return x is not None
87 |
88 |
89 | def default(val, d):
90 | if exists(val):
91 | return val
92 | return d() if isfunction(d) else d
93 |
94 |
95 | def mean_flat(tensor):
96 | """
97 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
98 | Take the mean over all non-batch dimensions.
99 | """
100 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
101 |
102 |
103 | def count_params(model, verbose=False):
104 | total_params = sum(p.numel() for p in model.parameters())
105 | if verbose:
106 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
107 | return total_params
108 |
109 |
110 | def instantiate_from_config(config):
111 | if not "target" in config:
112 | if not config == '__is_first_stage__':#changed for only training vae
113 | return None
114 | # elif config == "__is_unconditional__":#changed for only training vae
115 | # return None
116 | raise KeyError("Expected key `target` to instantiate.")
117 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
118 |
119 |
120 | def get_obj_from_str(string, reload=False):
121 | module, cls = string.rsplit(".", 1)
122 | if reload:
123 | module_imp = importlib.import_module(module)
124 | importlib.reload(module_imp)
125 | return getattr(importlib.import_module(module, package=None), cls)
126 |
127 |
128 | class AdamWwithEMAandWings(optim.Optimizer):
129 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
130 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
131 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
132 | ema_power=1., param_names=()):
133 | """AdamW that saves EMA versions of the parameters."""
134 | if not 0.0 <= lr:
135 | raise ValueError("Invalid learning rate: {}".format(lr))
136 | if not 0.0 <= eps:
137 | raise ValueError("Invalid epsilon value: {}".format(eps))
138 | if not 0.0 <= betas[0] < 1.0:
139 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
140 | if not 0.0 <= betas[1] < 1.0:
141 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
142 | if not 0.0 <= weight_decay:
143 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
144 | if not 0.0 <= ema_decay <= 1.0:
145 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
146 | defaults = dict(lr=lr, betas=betas, eps=eps,
147 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
148 | ema_power=ema_power, param_names=param_names)
149 | super().__init__(params, defaults)
150 |
151 | def __setstate__(self, state):
152 | super().__setstate__(state)
153 | for group in self.param_groups:
154 | group.setdefault('amsgrad', False)
155 |
156 | @torch.no_grad()
157 | def step(self, closure=None):
158 | """Performs a single optimization step.
159 | Args:
160 | closure (callable, optional): A closure that reevaluates the model
161 | and returns the loss.
162 | """
163 | loss = None
164 | if closure is not None:
165 | with torch.enable_grad():
166 | loss = closure()
167 |
168 | for group in self.param_groups:
169 | params_with_grad = []
170 | grads = []
171 | exp_avgs = []
172 | exp_avg_sqs = []
173 | ema_params_with_grad = []
174 | state_sums = []
175 | max_exp_avg_sqs = []
176 | state_steps = []
177 | amsgrad = group['amsgrad']
178 | beta1, beta2 = group['betas']
179 | ema_decay = group['ema_decay']
180 | ema_power = group['ema_power']
181 |
182 | for p in group['params']:
183 | if p.grad is None:
184 | continue
185 | params_with_grad.append(p)
186 | if p.grad.is_sparse:
187 | raise RuntimeError('AdamW does not support sparse gradients')
188 | grads.append(p.grad)
189 |
190 | state = self.state[p]
191 |
192 | # State initialization
193 | if len(state) == 0:
194 | state['step'] = 0
195 | # Exponential moving average of gradient values
196 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
197 | # Exponential moving average of squared gradient values
198 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
199 | if amsgrad:
200 | # Maintains max of all exp. moving avg. of sq. grad. values
201 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
202 | # Exponential moving average of parameter values
203 | state['param_exp_avg'] = p.detach().float().clone()
204 |
205 | exp_avgs.append(state['exp_avg'])
206 | exp_avg_sqs.append(state['exp_avg_sq'])
207 | ema_params_with_grad.append(state['param_exp_avg'])
208 |
209 | if amsgrad:
210 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
211 |
212 | # update the steps for each param group update
213 | state['step'] += 1
214 | # record the step after step update
215 | state_steps.append(state['step'])
216 |
217 | optim._functional.adamw(params_with_grad,
218 | grads,
219 | exp_avgs,
220 | exp_avg_sqs,
221 | max_exp_avg_sqs,
222 | state_steps,
223 | amsgrad=amsgrad,
224 | beta1=beta1,
225 | beta2=beta2,
226 | lr=group['lr'],
227 | weight_decay=group['weight_decay'],
228 | eps=group['eps'],
229 | maximize=False)
230 |
231 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
232 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
233 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
234 |
235 | return loss
--------------------------------------------------------------------------------
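
instantiate_from_config is the glue used by the configs below: it imports the dotted target string and passes params as keyword arguments. A minimal sketch with a stock torch class as the target (the config dict is illustrative, not one of the repository's):

    from ldm.util import count_params, instantiate_from_config

    cfg = {"target": "torch.nn.Conv2d",
           "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3}}
    conv = instantiate_from_config(cfg)   # equivalent to torch.nn.Conv2d(3, 8, kernel_size=3)
    count_params(conv, verbose=True)      # prints "Conv2d has 0.00 M params."
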
/models/cldm_v15_inpainting_infer.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: cldm.cldm.ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | masked_image: "mask_img"
13 | mask: "mask"
14 | image_size: 64
15 | channels: 4
16 | cond_stage_trainable: false
17 | conditioning_key: crossattn
18 | monitor: val/loss_simple_ema
19 | scale_factor: 0.18215
20 | use_ema: False
21 | only_mid_control: False
22 | load_loss: False
23 |
24 | control_stage_config:
25 | target: cldm.cldm.ControlNet
26 | params:
27 | image_size: 32 # unused
28 | in_channels: 4
29 | hint_channels: 3
30 | model_channels: 320
31 | attention_resolutions: [ 4, 2, 1 ]
32 | num_res_blocks: 2
33 | channel_mult: [ 1, 2, 4, 4 ]
34 | num_heads: 8
35 | use_spatial_transformer: True
36 | transformer_depth: 1
37 | context_dim: 768
38 | use_checkpoint: True
39 | legacy: False
40 |
41 | unet_config:
42 | target: cldm.cldm.ControlledUnetModel
43 | params:
44 | image_size: 32 # unused
45 | in_channels: 9
46 | out_channels: 4
47 | model_channels: 320
48 | attention_resolutions: [ 4, 2, 1 ]
49 | num_res_blocks: 2
50 | channel_mult: [ 1, 2, 4, 4 ]
51 | num_heads: 8
52 | use_spatial_transformer: True
53 | transformer_depth: 1
54 | context_dim: 768
55 | use_checkpoint: True
56 | legacy: False
57 |
58 | first_stage_config:
59 | target: ldm.models.autoencoder.AutoencoderKL
60 | params:
61 | embed_dim: 4
62 | monitor: val/rec_loss
63 | ddconfig:
64 | double_z: true
65 | z_channels: 4
66 | resolution: 256
67 | in_channels: 3
68 | out_ch: 3
69 | ch: 128
70 | ch_mult:
71 | - 1
72 | - 2
73 | - 4
74 | - 4
75 | num_res_blocks: 2
76 | attn_resolutions: []
77 | dropout: 0.0
78 | lossconfig:
79 | target: torch.nn.Identity
80 |
81 | contextual_stage_config:
82 | target: models_deep_exp.NonlocalNet.VGG19_pytorch
83 |
84 | cond_stage_config:
85 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
86 | target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder
87 | #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder
88 |
--------------------------------------------------------------------------------
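
This config assembles a ControlLDM from a ControlNet branch, a 9-channel inpainting UNet, a frozen AutoencoderKL first stage, and a CLIP-based dual text encoder. A minimal sketch of turning the YAML into a model object (OmegaConf is assumed, the checkpoint name is a placeholder, and the repository may ship its own model-creation helpers):

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    cfg = OmegaConf.load("models/cldm_v15_inpainting_infer.yaml")
    model = instantiate_from_config(cfg.model)                               # builds ControlLDM and its sub-modules
    sd = torch.load("pretrained_models/example.ckpt", map_location="cpu")    # placeholder checkpoint path
    model.load_state_dict(sd.get("state_dict", sd), strict=False)
    model = model.eval()
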
/models/cldm_v15_inpainting_infer1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: cldm.cldm.ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | masked_image: "mask_img"
13 | mask: "mask"
14 | image_size: 64
15 | channels: 4
16 | cond_stage_trainable: false
17 | conditioning_key: crossattn
18 | monitor: val/loss_simple_ema
19 | scale_factor: 0.18215
20 | use_ema: False
21 | only_mid_control: False
22 | load_loss: False
23 |
24 | control_stage_config:
25 | target: cldm.cldm.ControlNet
26 | params:
27 | image_size: 32 # unused
28 | in_channels: 4
29 | hint_channels: 3
30 | model_channels: 320
31 | attention_resolutions: [ 4, 2, 1 ]
32 | num_res_blocks: 2
33 | channel_mult: [ 1, 2, 4, 4 ]
34 | num_heads: 8
35 | use_spatial_transformer: True
36 | transformer_depth: 1
37 | context_dim: 768
38 | use_checkpoint: True
39 | legacy: False
40 |
41 | unet_config:
42 | target: cldm.cldm.ControlledUnetModel
43 | params:
44 | image_size: 32 # unused
45 | in_channels: 9
46 | out_channels: 4
47 | model_channels: 320
48 | attention_resolutions: [ 4, 2, 1 ]
49 | num_res_blocks: 2
50 | channel_mult: [ 1, 2, 4, 4 ]
51 | num_heads: 8
52 | use_spatial_transformer: True
53 | transformer_depth: 1
54 | context_dim: 768
55 | use_checkpoint: True
56 | legacy: False
57 |
58 | first_stage_config:
59 | target: ldm.models.autoencoder.AutoencoderKL
60 | params:
61 | embed_dim: 4
62 | monitor: val/rec_loss
63 | ddconfig:
64 | double_z: true
65 | z_channels: 4
66 | resolution: 256
67 | in_channels: 3
68 | out_ch: 3
69 | ch: 128
70 | ch_mult:
71 | - 1
72 | - 2
73 | - 4
74 | - 4
75 | num_res_blocks: 2
76 | attn_resolutions: []
77 | dropout: 0.0
78 | lossconfig:
79 | target: torch.nn.Identity
80 |
81 | contextual_stage_config:
82 | target: models_deep_exp.NonlocalNet.VGG19_pytorch
83 |
84 | cond_stage_config:
85 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
86 | # target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder
87 | #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder
88 |
--------------------------------------------------------------------------------
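
cldm_v15_inpainting_infer1.yaml is identical to cldm_v15_inpainting_infer.yaml above except for cond_stage_config, which selects the plain FrozenCLIPEmbedder instead of the FrozenCLIPDualEmbedder.
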
/pretrained_models/put_ckpts_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/pretrained_models/put_ckpts_here.txt
--------------------------------------------------------------------------------
/share.py:
--------------------------------------------------------------------------------
1 | import config
2 | from cldm.hack import disable_verbosity, enable_sliced_attention
3 |
4 |
5 | disable_verbosity()
6 |
7 | if config.save_memory:
8 | enable_sliced_attention()
9 |
--------------------------------------------------------------------------------
/taming/data/ade20k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 | from taming.data.sflckr import SegmentationBase # for examples included in repo
9 |
10 |
11 | class Examples(SegmentationBase):
12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
13 | super().__init__(data_csv="data/ade20k_examples.txt",
14 | data_root="data/ade20k_images",
15 | segmentation_root="data/ade20k_segmentations",
16 | size=size, random_crop=random_crop,
17 | interpolation=interpolation,
18 | n_labels=151, shift_segmentation=False)
19 |
20 |
21 | # With semantic map and scene label
22 | class ADE20kBase(Dataset):
23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
24 | self.split = self.get_split()
25 | self.n_labels = 151 # unknown + 150
26 | self.data_csv = {"train": "data/ade20k_train.txt",
27 | "validation": "data/ade20k_test.txt"}[self.split]
28 | self.data_root = "data/ade20k_root"
29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
30 | self.scene_categories = f.read().splitlines()
31 | self.scene_categories = dict(line.split() for line in self.scene_categories)
32 | with open(self.data_csv, "r") as f:
33 | self.image_paths = f.read().splitlines()
34 | self._length = len(self.image_paths)
35 | self.labels = {
36 | "relative_file_path_": [l for l in self.image_paths],
37 | "file_path_": [os.path.join(self.data_root, "images", l)
38 | for l in self.image_paths],
39 | "relative_segmentation_path_": [l.replace(".jpg", ".png")
40 | for l in self.image_paths],
41 | "segmentation_path_": [os.path.join(self.data_root, "annotations",
42 | l.replace(".jpg", ".png"))
43 | for l in self.image_paths],
44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
45 | for l in self.image_paths],
46 | }
47 |
48 | size = None if size is not None and size<=0 else size
49 | self.size = size
50 | if crop_size is None:
51 | self.crop_size = size if size is not None else None
52 | else:
53 | self.crop_size = crop_size
54 | if self.size is not None:
55 | self.interpolation = interpolation
56 | self.interpolation = {
57 | "nearest": cv2.INTER_NEAREST,
58 | "bilinear": cv2.INTER_LINEAR,
59 | "bicubic": cv2.INTER_CUBIC,
60 | "area": cv2.INTER_AREA,
61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
63 | interpolation=self.interpolation)
64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
65 | interpolation=cv2.INTER_NEAREST)
66 |
67 | if crop_size is not None:
68 | self.center_crop = not random_crop
69 | if self.center_crop:
70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
71 | else:
72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
73 | self.preprocessor = self.cropper
74 |
75 | def __len__(self):
76 | return self._length
77 |
78 | def __getitem__(self, i):
79 | example = dict((k, self.labels[k][i]) for k in self.labels)
80 | image = Image.open(example["file_path_"])
81 | if not image.mode == "RGB":
82 | image = image.convert("RGB")
83 | image = np.array(image).astype(np.uint8)
84 | if self.size is not None:
85 | image = self.image_rescaler(image=image)["image"]
86 | segmentation = Image.open(example["segmentation_path_"])
87 | segmentation = np.array(segmentation).astype(np.uint8)
88 | if self.size is not None:
89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
90 | if self.size is not None:
91 | processed = self.preprocessor(image=image, mask=segmentation)
92 | else:
93 | processed = {"image": image, "mask": segmentation}
94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
95 | segmentation = processed["mask"]
96 | onehot = np.eye(self.n_labels)[segmentation]
97 | example["segmentation"] = onehot
98 | return example
99 |
100 |
101 | class ADE20kTrain(ADE20kBase):
102 | # default to random_crop=True
103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
104 | super().__init__(config=config, size=size, random_crop=random_crop,
105 | interpolation=interpolation, crop_size=crop_size)
106 |
107 | def get_split(self):
108 | return "train"
109 |
110 |
111 | class ADE20kValidation(ADE20kBase):
112 | def get_split(self):
113 | return "validation"
114 |
115 |
116 | if __name__ == "__main__":
117 | dset = ADE20kValidation()
118 | ex = dset[0]
119 | for k in ["image", "scene_category", "segmentation"]:
120 | print(type(ex[k]))
121 | try:
122 | print(ex[k].shape)
123 | except:
124 | print(ex[k])
125 |
--------------------------------------------------------------------------------
/taming/data/annotated_objects_coco.py:
--------------------------------------------------------------------------------
1 | import json
2 | from itertools import chain
3 | from pathlib import Path
4 | from typing import Iterable, Dict, List, Callable, Any
5 | from collections import defaultdict
6 |
7 | from tqdm import tqdm
8 |
9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
10 | from taming.data.helper_types import Annotation, ImageDescription, Category
11 |
12 | COCO_PATH_STRUCTURE = {
13 | 'train': {
14 | 'top_level': '',
15 | 'instances_annotations': 'annotations/instances_train2017.json',
16 | 'stuff_annotations': 'annotations/stuff_train2017.json',
17 | 'files': 'train2017'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'instances_annotations': 'annotations/instances_val2017.json',
22 | 'stuff_annotations': 'annotations/stuff_val2017.json',
23 | 'files': 'val2017'
24 | }
25 | }
26 |
27 |
28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
29 | return {
30 | str(img['id']): ImageDescription(
31 | id=img['id'],
32 | license=img.get('license'),
33 | file_name=img['file_name'],
34 | coco_url=img['coco_url'],
35 | original_size=(img['width'], img['height']),
36 | date_captured=img.get('date_captured'),
37 | flickr_url=img.get('flickr_url')
38 | )
39 | for img in description_json
40 | }
41 |
42 |
43 | def load_categories(category_json: Iterable) -> Dict[str, Category]:
44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
45 | for cat in category_json if cat['name'] != 'other'}
46 |
47 |
48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
50 | annotations = defaultdict(list)
51 | total = sum(len(a) for a in annotations_json)
52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
53 | image_id = str(ann['image_id'])
54 | if image_id not in image_descriptions:
55 | raise ValueError(f'image_id [{image_id}] has no image description.')
56 | category_id = ann['category_id']
57 | try:
58 | category_no = category_no_for_id(str(category_id))
59 | except KeyError:
60 | continue
61 |
62 | width, height = image_descriptions[image_id].original_size
63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
64 |
65 | annotations[image_id].append(
66 | Annotation(
67 | id=ann['id'],
68 | area=bbox[2]*bbox[3], # use bbox area
69 | is_group_of=ann['iscrowd'],
70 | image_id=ann['image_id'],
71 | bbox=bbox,
72 | category_id=str(category_id),
73 | category_no=category_no
74 | )
75 | )
76 | return dict(annotations)
77 |
78 |
79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
81 | """
82 | @param data_path: is the path to the following folder structure:
83 | coco/
84 | ├── annotations
85 | │ ├── instances_train2017.json
86 | │ ├── instances_val2017.json
87 | │ ├── stuff_train2017.json
88 | │ └── stuff_val2017.json
89 | ├── train2017
90 | │ ├── 000000000009.jpg
91 | │ ├── 000000000025.jpg
92 | │ └── ...
93 | ├── val2017
94 | │ ├── 000000000139.jpg
95 | │ ├── 000000000285.jpg
96 | │ └── ...
97 | @param: split: one of 'train' or 'validation'
98 | @param: desired image size (give square images)
99 | """
100 | super().__init__(**kwargs)
101 | self.use_things = use_things
102 | self.use_stuff = use_stuff
103 |
104 | with open(self.paths['instances_annotations']) as f:
105 | inst_data_json = json.load(f)
106 | with open(self.paths['stuff_annotations']) as f:
107 | stuff_data_json = json.load(f)
108 |
109 | category_jsons = []
110 | annotation_jsons = []
111 | if self.use_things:
112 | category_jsons.append(inst_data_json['categories'])
113 | annotation_jsons.append(inst_data_json['annotations'])
114 | if self.use_stuff:
115 | category_jsons.append(stuff_data_json['categories'])
116 | annotation_jsons.append(stuff_data_json['annotations'])
117 |
118 | self.categories = load_categories(chain(*category_jsons))
119 | self.filter_categories()
120 | self.setup_category_id_and_number()
121 |
122 | self.image_descriptions = load_image_descriptions(inst_data_json['images'])
123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
124 | self.annotations = self.filter_object_number(annotations, self.min_object_area,
125 | self.min_objects_per_image, self.max_objects_per_image)
126 | self.image_ids = list(self.annotations.keys())
127 | self.clean_up_annotations_and_image_descriptions()
128 |
129 | def get_path_structure(self) -> Dict[str, str]:
130 | if self.split not in COCO_PATH_STRUCTURE:
131 |             raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
132 | return COCO_PATH_STRUCTURE[self.split]
133 |
134 | def get_image_path(self, image_id: str) -> Path:
135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
136 |
137 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
138 | # noinspection PyProtectedMember
139 | return self.image_descriptions[image_id]._asdict()
140 |
--------------------------------------------------------------------------------
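A minimal, hedged sketch of the two pure helpers above (load_categories and load_image_descriptions); every value below is invented for illustration, and the bbox normalization mirrors what load_annotations does with real COCO records:

    from taming.data.annotated_objects_coco import load_categories, load_image_descriptions

    # Tiny, made-up COCO-style records (not real dataset entries).
    categories_json = [{'id': 1, 'supercategory': 'vehicle', 'name': 'bicycle'}]
    images_json = [{'id': 9, 'file_name': '000000000009.jpg',
                    'coco_url': 'http://example.com/000000000009.jpg',
                    'width': 640, 'height': 480}]

    categories = load_categories(categories_json)        # {'1': Category(id='1', ...)}
    descriptions = load_image_descriptions(images_json)  # {'9': ImageDescription(...)}

    # load_annotations() stores bboxes relative to the image size:
    w, h = descriptions['9'].original_size               # (640, 480)
    x, y, bw, bh = 64.0, 48.0, 320.0, 240.0              # COCO bbox in pixels
    bbox = (x / w, y / h, bw / w, bh / h)                # (0.1, 0.1, 0.5, 0.5)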
/taming/data/annotated_objects_open_images.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from csv import DictReader, reader as TupleReader
3 | from pathlib import Path
4 | from typing import Dict, List, Any
5 | import warnings
6 |
7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
8 | from taming.data.helper_types import Annotation, Category
9 | from tqdm import tqdm
10 |
11 | OPEN_IMAGES_STRUCTURE = {
12 | 'train': {
13 | 'top_level': '',
14 | 'class_descriptions': 'class-descriptions-boxable.csv',
15 | 'annotations': 'oidv6-train-annotations-bbox.csv',
16 | 'file_list': 'train-images-boxable.csv',
17 | 'files': 'train'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'class_descriptions': 'class-descriptions-boxable.csv',
22 | 'annotations': 'validation-annotations-bbox.csv',
23 | 'file_list': 'validation-images.csv',
24 | 'files': 'validation'
25 | },
26 | 'test': {
27 | 'top_level': '',
28 | 'class_descriptions': 'class-descriptions-boxable.csv',
29 | 'annotations': 'test-annotations-bbox.csv',
30 | 'file_list': 'test-images.csv',
31 | 'files': 'test'
32 | }
33 | }
34 |
35 |
36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
38 | annotations: Dict[str, List[Annotation]] = defaultdict(list)
39 | with open(descriptor_path) as file:
40 | reader = DictReader(file)
41 | for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
42 | width = float(row['XMax']) - float(row['XMin'])
43 | height = float(row['YMax']) - float(row['YMin'])
44 | area = width * height
45 | category_id = row['LabelName']
46 | if category_id in category_mapping:
47 | category_id = category_mapping[category_id]
48 | if area >= min_object_area and category_id in category_no_for_id:
49 | annotations[row['ImageID']].append(
50 | Annotation(
51 | id=i,
52 | image_id=row['ImageID'],
53 | source=row['Source'],
54 | category_id=category_id,
55 | category_no=category_no_for_id[category_id],
56 | confidence=float(row['Confidence']),
57 | bbox=(float(row['XMin']), float(row['YMin']), width, height),
58 | area=area,
59 | is_occluded=bool(int(row['IsOccluded'])),
60 | is_truncated=bool(int(row['IsTruncated'])),
61 | is_group_of=bool(int(row['IsGroupOf'])),
62 | is_depiction=bool(int(row['IsDepiction'])),
63 | is_inside=bool(int(row['IsInside']))
64 | )
65 | )
66 | if 'train' in str(descriptor_path) and i < 14000000:
67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
68 | return dict(annotations)
69 |
70 |
71 | def load_image_ids(csv_path: Path) -> List[str]:
72 | with open(csv_path) as file:
73 | reader = DictReader(file)
74 | return [row['image_name'] for row in reader]
75 |
76 |
77 | def load_categories(csv_path: Path) -> Dict[str, Category]:
78 | with open(csv_path) as file:
79 | reader = TupleReader(file)
80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
81 |
82 |
83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
84 | def __init__(self, use_additional_parameters: bool, **kwargs):
85 | """
86 | @param data_path: is the path to the following folder structure:
87 | open_images/
88 | │ oidv6-train-annotations-bbox.csv
89 | ├── class-descriptions-boxable.csv
90 | ├── oidv6-train-annotations-bbox.csv
91 | ├── test
92 | │ ├── 000026e7ee790996.jpg
93 | │ ├── 000062a39995e348.jpg
94 | │ └── ...
95 | ├── test-annotations-bbox.csv
96 | ├── test-images.csv
97 | ├── train
98 | │ ├── 000002b66c9c498e.jpg
99 | │ ├── 000002b97e5471a0.jpg
100 | │ └── ...
101 | ├── train-images-boxable.csv
102 | ├── validation
103 | │ ├── 0001eeaf4aed83f9.jpg
104 | │ ├── 0004886b7d043cfd.jpg
105 | │ └── ...
106 | ├── validation-annotations-bbox.csv
107 | └── validation-images.csv
108 | @param: split: one of 'train', 'validation' or 'test'
109 | @param: desired image size (returns square images)
110 | """
111 |
112 | super().__init__(**kwargs)
113 | self.use_additional_parameters = use_additional_parameters
114 |
115 | self.categories = load_categories(self.paths['class_descriptions'])
116 | self.filter_categories()
117 | self.setup_category_id_and_number()
118 |
119 | self.image_descriptions = {}
120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
121 | self.category_number)
122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
123 | self.max_objects_per_image)
124 | self.image_ids = list(self.annotations.keys())
125 | self.clean_up_annotations_and_image_descriptions()
126 |
127 | def get_path_structure(self) -> Dict[str, str]:
128 | if self.split not in OPEN_IMAGES_STRUCTURE:
129 |             raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
130 | return OPEN_IMAGES_STRUCTURE[self.split]
131 |
132 | def get_image_path(self, image_id: str) -> Path:
133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
134 |
135 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
136 | image_path = self.get_image_path(image_id)
137 | return {'file_path': str(image_path), 'file_name': image_path.name}
138 |
--------------------------------------------------------------------------------
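A hedged sketch of how a single Open Images CSV row is turned into an Annotation by load_annotations above; the row values are invented, and category_no would normally come from the dataset's category lookup rather than the placeholder used here:

    from taming.data.helper_types import Annotation

    row = {'ImageID': '000002b66c9c498e', 'Source': 'xclick', 'LabelName': '/m/01g317',
           'Confidence': '1', 'XMin': '0.25', 'XMax': '0.75', 'YMin': '0.25', 'YMax': '0.75',
           'IsOccluded': '0', 'IsTruncated': '0', 'IsGroupOf': '0',
           'IsDepiction': '0', 'IsInside': '0'}

    width = float(row['XMax']) - float(row['XMin'])    # 0.5
    height = float(row['YMax']) - float(row['YMin'])   # 0.5
    annotation = Annotation(
        id=0, image_id=row['ImageID'], source=row['Source'],
        category_id=row['LabelName'], category_no=0,   # placeholder class index
        confidence=float(row['Confidence']),
        bbox=(float(row['XMin']), float(row['YMin']), width, height),
        area=width * height,                           # 0.25
        is_occluded=bool(int(row['IsOccluded'])),
        is_truncated=bool(int(row['IsTruncated'])),
        is_group_of=bool(int(row['IsGroupOf'])),
        is_depiction=bool(int(row['IsDepiction'])),
        is_inside=bool(int(row['IsInside'])))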
/taming/data/base.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False, labels=None):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict() if labels is None else labels
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
--------------------------------------------------------------------------------
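A hedged usage sketch of ImagePaths; the file names are hypothetical, and each item comes back rescaled, center-cropped, and mapped to float32 in [-1, 1]:

    from taming.data.base import ImagePaths

    dataset = ImagePaths(paths=["img_0001.jpg", "img_0002.jpg"],  # hypothetical files
                         size=256, random_crop=False)
    example = dataset[0]
    example["image"].shape   # (256, 256, 3), float32 in [-1, 1]
    example["file_path_"]    # 'img_0001.jpg'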
/taming/data/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import albumentations
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 |
9 | from taming.data.sflckr import SegmentationBase # for examples included in repo
10 |
11 |
12 | class Examples(SegmentationBase):
13 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
14 | super().__init__(data_csv="data/coco_examples.txt",
15 | data_root="data/coco_images",
16 | segmentation_root="data/coco_segmentations",
17 | size=size, random_crop=random_crop,
18 | interpolation=interpolation,
19 | n_labels=183, shift_segmentation=True)
20 |
21 |
22 | class CocoBase(Dataset):
23 | """needed for (image, caption, segmentation) pairs"""
24 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
25 | crop_size=None, force_no_crop=False, given_files=None):
26 | self.split = self.get_split()
27 | self.size = size
28 | if crop_size is None:
29 | self.crop_size = size
30 | else:
31 | self.crop_size = crop_size
32 |
33 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot
34 | self.stuffthing = use_stuffthing # include thing in segmentation
35 | if self.onehot and not self.stuffthing:
36 |             raise NotImplementedError("One hot mode is only supported for the "
37 |                                       "stuffthings version because labels are stored "
38 |                                       "a bit differently.")
39 |
40 | data_json = datajson
41 | with open(data_json) as json_file:
42 | self.json_data = json.load(json_file)
43 | self.img_id_to_captions = dict()
44 | self.img_id_to_filepath = dict()
45 | self.img_id_to_segmentation_filepath = dict()
46 |
47 | assert data_json.split("/")[-1] in ["captions_train2017.json",
48 | "captions_val2017.json"]
49 | if self.stuffthing:
50 | self.segmentation_prefix = (
51 | "data/cocostuffthings/val2017" if
52 | data_json.endswith("captions_val2017.json") else
53 | "data/cocostuffthings/train2017")
54 | else:
55 | self.segmentation_prefix = (
56 | "data/coco/annotations/stuff_val2017_pixelmaps" if
57 | data_json.endswith("captions_val2017.json") else
58 | "data/coco/annotations/stuff_train2017_pixelmaps")
59 |
60 | imagedirs = self.json_data["images"]
61 | self.labels = {"image_ids": list()}
62 | for imgdir in tqdm(imagedirs, desc="ImgToPath"):
63 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
64 | self.img_id_to_captions[imgdir["id"]] = list()
65 | pngfilename = imgdir["file_name"].replace("jpg", "png")
66 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
67 | self.segmentation_prefix, pngfilename)
68 | if given_files is not None:
69 | if pngfilename in given_files:
70 | self.labels["image_ids"].append(imgdir["id"])
71 | else:
72 | self.labels["image_ids"].append(imgdir["id"])
73 |
74 | capdirs = self.json_data["annotations"]
75 | for capdir in tqdm(capdirs, desc="ImgToCaptions"):
76 |             # there are on average 5 captions per image
77 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
78 |
79 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
80 | if self.split=="validation":
81 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
82 | else:
83 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
84 | self.preprocessor = albumentations.Compose(
85 | [self.rescaler, self.cropper],
86 | additional_targets={"segmentation": "image"})
87 | if force_no_crop:
88 | self.rescaler = albumentations.Resize(height=self.size, width=self.size)
89 | self.preprocessor = albumentations.Compose(
90 | [self.rescaler],
91 | additional_targets={"segmentation": "image"})
92 |
93 | def __len__(self):
94 | return len(self.labels["image_ids"])
95 |
96 | def preprocess_image(self, image_path, segmentation_path):
97 | image = Image.open(image_path)
98 | if not image.mode == "RGB":
99 | image = image.convert("RGB")
100 | image = np.array(image).astype(np.uint8)
101 |
102 | segmentation = Image.open(segmentation_path)
103 | if not self.onehot and not segmentation.mode == "RGB":
104 | segmentation = segmentation.convert("RGB")
105 | segmentation = np.array(segmentation).astype(np.uint8)
106 | if self.onehot:
107 | assert self.stuffthing
108 |             # stored in caffe format: unlabeled==255, stuff and thing from
109 |             # 0-181. to be compatible with the labels in
110 |             # https://github.com/nightrome/cocostuff/blob/master/labels.txt
111 |             # we shift stuffthing one to the right and put unlabeled at zero;
112 |             # as long as segmentation is uint8, the shift handles the latter
113 |             # too (255 wraps around to 0)
114 | assert segmentation.dtype == np.uint8
115 | segmentation = segmentation + 1
116 |
117 | processed = self.preprocessor(image=image, segmentation=segmentation)
118 | image, segmentation = processed["image"], processed["segmentation"]
119 | image = (image / 127.5 - 1.0).astype(np.float32)
120 |
121 | if self.onehot:
122 | assert segmentation.dtype == np.uint8
123 | # make it one hot
124 | n_labels = 183
125 | flatseg = np.ravel(segmentation)
126 |             onehot = np.zeros((flatseg.size, n_labels), dtype=bool)  # np.bool is removed in newer numpy
127 | onehot[np.arange(flatseg.size), flatseg] = True
128 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
129 | segmentation = onehot
130 | else:
131 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
132 | return image, segmentation
133 |
134 | def __getitem__(self, i):
135 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
136 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
137 | image, segmentation = self.preprocess_image(img_path, seg_path)
138 | captions = self.img_id_to_captions[self.labels["image_ids"][i]]
139 | # randomly draw one of all available captions per image
140 | caption = captions[np.random.randint(0, len(captions))]
141 | example = {"image": image,
142 | "caption": [str(caption[0])],
143 | "segmentation": segmentation,
144 | "img_path": img_path,
145 | "seg_path": seg_path,
146 | "filename_": img_path.split(os.sep)[-1]
147 | }
148 | return example
149 |
150 |
151 | class CocoImagesAndCaptionsTrain(CocoBase):
152 | """returns a pair of (image, caption)"""
153 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
154 | super().__init__(size=size,
155 | dataroot="data/coco/train2017",
156 | datajson="data/coco/annotations/captions_train2017.json",
157 | onehot_segmentation=onehot_segmentation,
158 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
159 |
160 | def get_split(self):
161 | return "train"
162 |
163 |
164 | class CocoImagesAndCaptionsValidation(CocoBase):
165 | """returns a pair of (image, caption)"""
166 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
167 | given_files=None):
168 | super().__init__(size=size,
169 | dataroot="data/coco/val2017",
170 | datajson="data/coco/annotations/captions_val2017.json",
171 | onehot_segmentation=onehot_segmentation,
172 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
173 | given_files=given_files)
174 |
175 | def get_split(self):
176 | return "validation"
177 |
--------------------------------------------------------------------------------
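A hedged sketch of the caption datasets above, assuming the hard-coded data/coco layout (train2017 images, captions_train2017.json, and the stuff_train2017_pixelmaps segmentation folder) is actually present:

    from taming.data.coco import CocoImagesAndCaptionsTrain

    train = CocoImagesAndCaptionsTrain(size=256)
    sample = train[0]
    sample["image"].shape     # (256, 256, 3), float32 in [-1, 1]
    sample["caption"]         # one randomly drawn caption, wrapped in a list
    sample["segmentation"]    # RGB pixelmap, or one-hot when onehot_segmentation=True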
/taming/data/conditional_builder/objects_bbox.py:
--------------------------------------------------------------------------------
1 | from itertools import cycle
2 | from typing import List, Tuple, Callable, Optional
3 |
4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5 | from more_itertools.recipes import grouper
6 | from taming.data.image_transforms import convert_pil_to_tensor
7 | from torch import LongTensor, Tensor
8 |
9 | from taming.data.helper_types import BoundingBox, Annotation
10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12 | pad_list, get_plot_font_size, absolute_bbox
13 |
14 |
15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16 | @property
17 | def object_descriptor_length(self) -> int:
18 | return 3
19 |
20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21 | object_triples = [
22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23 | for ann in annotations
24 | ]
25 | empty_triple = (self.none, self.none, self.none)
26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27 | return object_triples
28 |
29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30 | conditional_list = conditional.tolist()
31 | crop_coordinates = None
32 | if self.encode_crop:
33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34 | conditional_list = conditional_list[:-2]
35 | object_triples = grouper(conditional_list, 3)
36 | assert conditional.shape[0] == self.embedding_dim
37 | return [
38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39 | for object_triple in object_triples if object_triple[0] != self.none
40 | ], crop_coordinates
41 |
42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44 | plot = pil_image.new('RGB', figure_size, WHITE)
45 | draw = pil_img_draw.Draw(plot)
46 | font = ImageFont.truetype(
47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48 | size=get_plot_font_size(font_size, figure_size)
49 | )
50 | width, height = plot.size
51 | description, crop_coordinates = self.inverse_build(conditional)
52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53 | annotation = self.representation_to_annotation(representation)
54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55 | bbox = absolute_bbox(bbox, width, height)
56 | draw.rectangle(bbox, outline=color, width=line_width)
57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58 | if crop_coordinates is not None:
59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60 | return convert_pil_to_tensor(plot) / 127.5 - 1.
61 |
--------------------------------------------------------------------------------
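A hedged sketch of the bounding-box builder above (the constructor arguments live in the parent class in objects_center_points.py, further down): each object becomes a triple of (object token, top-left coordinate token, bottom-right coordinate token), and unused slots are padded with the `none` token.

    from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder

    builder = ObjectsBoundingBoxConditionalBuilder(
        no_object_classes=183, no_max_objects=30, no_tokens=1024,
        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

    builder.object_descriptor_length   # 3 tokens per object
    builder.embedding_dim              # 30 objects * 3 tokens = 90
    builder.none                       # 1023, the padding token for empty slots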
/taming/data/conditional_builder/objects_center_points.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import warnings
4 | from itertools import cycle
5 | from typing import List, Optional, Tuple, Callable
6 |
7 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
8 | from more_itertools.recipes import grouper
9 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
10 | additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
11 | absolute_bbox, rescale_annotations
12 | from taming.data.helper_types import BoundingBox, Annotation
13 | from taming.data.image_transforms import convert_pil_to_tensor
14 | from torch import LongTensor, Tensor
15 |
16 |
17 | class ObjectsCenterPointsConditionalBuilder:
18 | def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
19 | use_group_parameter: bool, use_additional_parameters: bool):
20 | self.no_object_classes = no_object_classes
21 | self.no_max_objects = no_max_objects
22 | self.no_tokens = no_tokens
23 | self.encode_crop = encode_crop
24 | self.no_sections = int(math.sqrt(self.no_tokens))
25 | self.use_group_parameter = use_group_parameter
26 | self.use_additional_parameters = use_additional_parameters
27 |
28 | @property
29 | def none(self) -> int:
30 | return self.no_tokens - 1
31 |
32 | @property
33 | def object_descriptor_length(self) -> int:
34 | return 2
35 |
36 | @property
37 | def embedding_dim(self) -> int:
38 | extra_length = 2 if self.encode_crop else 0
39 | return self.no_max_objects * self.object_descriptor_length + extra_length
40 |
41 | def tokenize_coordinates(self, x: float, y: float) -> int:
42 | """
43 | Express 2d coordinates with one number.
44 | Example: assume self.no_tokens = 16, then no_sections = 4:
45 | 0 0 0 0
46 | 0 0 # 0
47 | 0 0 0 0
48 | 0 0 0 x
49 | Then the # position corresponds to token 6, the x position to token 15.
50 | @param x: float in [0, 1]
51 | @param y: float in [0, 1]
52 | @return: discrete tokenized coordinate
53 | """
54 | x_discrete = int(round(x * (self.no_sections - 1)))
55 | y_discrete = int(round(y * (self.no_sections - 1)))
56 | return y_discrete * self.no_sections + x_discrete
57 |
58 | def coordinates_from_token(self, token: int) -> (float, float):
59 | x = token % self.no_sections
60 | y = token // self.no_sections
61 | return x / (self.no_sections - 1), y / (self.no_sections - 1)
62 |
63 | def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
64 | x0, y0 = self.coordinates_from_token(token1)
65 | x1, y1 = self.coordinates_from_token(token2)
66 | return x0, y0, x1 - x0, y1 - y0
67 |
68 | def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
69 | return self.tokenize_coordinates(bbox[0], bbox[1]), \
70 | self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
71 |
72 | def inverse_build(self, conditional: LongTensor) \
73 | -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
74 | conditional_list = conditional.tolist()
75 | crop_coordinates = None
76 | if self.encode_crop:
77 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
78 | conditional_list = conditional_list[:-2]
79 | table_of_content = grouper(conditional_list, self.object_descriptor_length)
80 | assert conditional.shape[0] == self.embedding_dim
81 | return [
82 | (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
83 | for object_tuple in table_of_content if object_tuple[0] != self.none
84 | ], crop_coordinates
85 |
86 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
87 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
88 | plot = pil_image.new('RGB', figure_size, WHITE)
89 | draw = pil_img_draw.Draw(plot)
90 | circle_size = get_circle_size(figure_size)
91 | font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
92 | size=get_plot_font_size(font_size, figure_size))
93 | width, height = plot.size
94 | description, crop_coordinates = self.inverse_build(conditional)
95 | for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
96 | x_abs, y_abs = x * width, y * height
97 | ann = self.representation_to_annotation(representation)
98 | label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
99 | ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
100 | draw.ellipse(ellipse_bbox, fill=color, width=0)
101 | draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
102 | if crop_coordinates is not None:
103 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
104 | return convert_pil_to_tensor(plot) / 127.5 - 1.
105 |
106 | def object_representation(self, annotation: Annotation) -> int:
107 | modifier = 0
108 | if self.use_group_parameter:
109 | modifier |= 1 * (annotation.is_group_of is True)
110 | if self.use_additional_parameters:
111 | modifier |= 2 * (annotation.is_occluded is True)
112 | modifier |= 4 * (annotation.is_depiction is True)
113 | modifier |= 8 * (annotation.is_inside is True)
114 | return annotation.category_no + self.no_object_classes * modifier
115 |
116 | def representation_to_annotation(self, representation: int) -> Annotation:
117 | category_no = representation % self.no_object_classes
118 | modifier = representation // self.no_object_classes
119 | # noinspection PyTypeChecker
120 | return Annotation(
121 | area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
122 | category_no=category_no,
123 | is_group_of=bool((modifier & 1) * self.use_group_parameter),
124 | is_occluded=bool((modifier & 2) * self.use_additional_parameters),
125 | is_depiction=bool((modifier & 4) * self.use_additional_parameters),
126 | is_inside=bool((modifier & 8) * self.use_additional_parameters)
127 | )
128 |
129 | def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
130 | return list(self.token_pair_from_bbox(crop_coordinates))
131 |
132 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
133 | object_tuples = [
134 | (self.object_representation(a),
135 | self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
136 | for a in annotations
137 | ]
138 | empty_tuple = (self.none, self.none)
139 | object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
140 | return object_tuples
141 |
142 | def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
143 | -> LongTensor:
144 | if len(annotations) == 0:
145 | warnings.warn('Did not receive any annotations.')
146 | if len(annotations) > self.no_max_objects:
147 | warnings.warn('Received more annotations than allowed.')
148 | annotations = annotations[:self.no_max_objects]
149 |
150 | if not crop_coordinates:
151 | crop_coordinates = FULL_CROP
152 |
153 | random.shuffle(annotations)
154 | annotations = filter_annotations(annotations, crop_coordinates)
155 | if self.encode_crop:
156 | annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
157 | if horizontal_flip:
158 | crop_coordinates = horizontally_flip_bbox(crop_coordinates)
159 | extra = self._crop_encoder(crop_coordinates)
160 | else:
161 | annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
162 | extra = []
163 |
164 | object_tuples = self._make_object_descriptors(annotations)
165 | flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
166 | assert len(flattened) == self.embedding_dim
167 | assert all(0 <= value < self.no_tokens for value in flattened)
168 | return LongTensor(flattened)
169 |
--------------------------------------------------------------------------------
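A hedged sketch reproducing the tokenize_coordinates docstring above with no_tokens=16 (a 4x4 grid of sections); the other constructor arguments are placeholders:

    from taming.data.conditional_builder.objects_center_points import (
        ObjectsCenterPointsConditionalBuilder)

    builder = ObjectsCenterPointsConditionalBuilder(
        no_object_classes=10, no_max_objects=2, no_tokens=16,
        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

    builder.tokenize_coordinates(2 / 3, 1 / 3)   # 6  (row 1, column 2 of the 4x4 grid)
    builder.tokenize_coordinates(1.0, 1.0)       # 15 (bottom-right cell)
    builder.coordinates_from_token(6)            # (0.666..., 0.333...)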
/taming/data/conditional_builder/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import List, Any, Tuple, Optional
3 |
4 | from taming.data.helper_types import BoundingBox, Annotation
5 |
6 | # source: seaborn, color palette tab10
7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9 | BLACK = (0, 0, 0)
10 | GRAY_75 = (63, 63, 63)
11 | GRAY_50 = (127, 127, 127)
12 | GRAY_25 = (191, 191, 191)
13 | WHITE = (255, 255, 255)
14 | FULL_CROP = (0., 0., 1., 1.)
15 |
16 |
17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18 | """
19 | Give intersection area of two rectangles.
20 | @param rectangle1: (x0, y0, w, h) of first rectangle
21 | @param rectangle2: (x0, y0, w, h) of second rectangle
22 | """
23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27 | return x_overlap * y_overlap
28 |
29 |
30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32 |
33 |
34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35 | bbox = relative_bbox
36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38 |
39 |
40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42 |
43 |
44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45 | List[Annotation]:
46 | def clamp(x: float):
47 | return max(min(x, 1.), 0.)
48 |
49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54 | if flip:
55 | x0 = 1 - (x0 + w)
56 | return x0, y0, w, h
57 |
58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59 |
60 |
61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63 |
64 |
65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66 | sl = slice(1) if short else slice(None)
67 | string = ''
68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69 | return string
70 | if annotation.is_group_of:
71 | string += 'group'[sl] + ','
72 | if annotation.is_occluded:
73 | string += 'occluded'[sl] + ','
74 | if annotation.is_depiction:
75 | string += 'depiction'[sl] + ','
76 | if annotation.is_inside:
77 | string += 'inside'[sl]
78 | return '(' + string.strip(",") + ')'
79 |
80 |
81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82 | if font_size is None:
83 | font_size = 10
84 | if max(figure_size) >= 256:
85 | font_size = 12
86 | if max(figure_size) >= 512:
87 | font_size = 15
88 | return font_size
89 |
90 |
91 | def get_circle_size(figure_size: Tuple[int, int]) -> int:
92 | circle_size = 2
93 | if max(figure_size) >= 256:
94 | circle_size = 3
95 | if max(figure_size) >= 512:
96 | circle_size = 4
97 | return circle_size
98 |
99 |
100 | def load_object_from_string(object_string: str) -> Any:
101 | """
102 | Source: https://stackoverflow.com/a/10773699
103 | """
104 | module_name, class_name = object_string.rsplit(".", 1)
105 | return getattr(importlib.import_module(module_name), class_name)
106 |
--------------------------------------------------------------------------------
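A hedged sketch of the pure bbox helpers above; all boxes follow the (x0, y0, w, h) convention in coordinates relative to the image:

    from taming.data.conditional_builder.utils import (
        intersection_area, horizontally_flip_bbox, absolute_bbox)

    intersection_area((0.0, 0.0, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5))   # 0.0625
    horizontally_flip_bbox((0.1, 0.2, 0.3, 0.4))                      # (0.6, 0.2, 0.3, 0.4)
    absolute_bbox((0.1, 0.2, 0.3, 0.4), width=640, height=480)        # (64, 96, 256, 288)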
/taming/data/custom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class CustomBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, i):
18 | example = self.data[i]
19 | return example
20 |
21 |
22 |
23 | class CustomTrain(CustomBase):
24 | def __init__(self, size, training_images_list_file):
25 | super().__init__()
26 | with open(training_images_list_file, "r") as f:
27 | paths = f.read().splitlines()
28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29 |
30 |
31 | class CustomTest(CustomBase):
32 | def __init__(self, size, test_images_list_file):
33 | super().__init__()
34 | with open(test_images_list_file, "r") as f:
35 | paths = f.read().splitlines()
36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/taming/data/faceshq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class FacesBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 | self.keys = None
14 |
15 | def __len__(self):
16 | return len(self.data)
17 |
18 | def __getitem__(self, i):
19 | example = self.data[i]
20 | ex = {}
21 | if self.keys is not None:
22 | for k in self.keys:
23 | ex[k] = example[k]
24 | else:
25 | ex = example
26 | return ex
27 |
28 |
29 | class CelebAHQTrain(FacesBase):
30 | def __init__(self, size, keys=None):
31 | super().__init__()
32 | root = "data/celebahq"
33 | with open("data/celebahqtrain.txt", "r") as f:
34 | relpaths = f.read().splitlines()
35 | paths = [os.path.join(root, relpath) for relpath in relpaths]
36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
37 | self.keys = keys
38 |
39 |
40 | class CelebAHQValidation(FacesBase):
41 | def __init__(self, size, keys=None):
42 | super().__init__()
43 | root = "data/celebahq"
44 | with open("data/celebahqvalidation.txt", "r") as f:
45 | relpaths = f.read().splitlines()
46 | paths = [os.path.join(root, relpath) for relpath in relpaths]
47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
48 | self.keys = keys
49 |
50 |
51 | class FFHQTrain(FacesBase):
52 | def __init__(self, size, keys=None):
53 | super().__init__()
54 | root = "data/ffhq"
55 | with open("data/ffhqtrain.txt", "r") as f:
56 | relpaths = f.read().splitlines()
57 | paths = [os.path.join(root, relpath) for relpath in relpaths]
58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
59 | self.keys = keys
60 |
61 |
62 | class FFHQValidation(FacesBase):
63 | def __init__(self, size, keys=None):
64 | super().__init__()
65 | root = "data/ffhq"
66 | with open("data/ffhqvalidation.txt", "r") as f:
67 | relpaths = f.read().splitlines()
68 | paths = [os.path.join(root, relpath) for relpath in relpaths]
69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
70 | self.keys = keys
71 |
72 |
73 | class FacesHQTrain(Dataset):
74 | # CelebAHQ [0] + FFHQ [1]
75 | def __init__(self, size, keys=None, crop_size=None, coord=False):
76 | d1 = CelebAHQTrain(size=size, keys=keys)
77 | d2 = FFHQTrain(size=size, keys=keys)
78 | self.data = ConcatDatasetWithIndex([d1, d2])
79 | self.coord = coord
80 | if crop_size is not None:
81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
82 | if self.coord:
83 | self.cropper = albumentations.Compose([self.cropper],
84 | additional_targets={"coord": "image"})
85 |
86 | def __len__(self):
87 | return len(self.data)
88 |
89 | def __getitem__(self, i):
90 | ex, y = self.data[i]
91 | if hasattr(self, "cropper"):
92 | if not self.coord:
93 | out = self.cropper(image=ex["image"])
94 | ex["image"] = out["image"]
95 | else:
96 | h,w,_ = ex["image"].shape
97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
98 | out = self.cropper(image=ex["image"], coord=coord)
99 | ex["image"] = out["image"]
100 | ex["coord"] = out["coord"]
101 | ex["class"] = y
102 | return ex
103 |
104 |
105 | class FacesHQValidation(Dataset):
106 | # CelebAHQ [0] + FFHQ [1]
107 | def __init__(self, size, keys=None, crop_size=None, coord=False):
108 | d1 = CelebAHQValidation(size=size, keys=keys)
109 | d2 = FFHQValidation(size=size, keys=keys)
110 | self.data = ConcatDatasetWithIndex([d1, d2])
111 | self.coord = coord
112 | if crop_size is not None:
113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
114 | if self.coord:
115 | self.cropper = albumentations.Compose([self.cropper],
116 | additional_targets={"coord": "image"})
117 |
118 | def __len__(self):
119 | return len(self.data)
120 |
121 | def __getitem__(self, i):
122 | ex, y = self.data[i]
123 | if hasattr(self, "cropper"):
124 | if not self.coord:
125 | out = self.cropper(image=ex["image"])
126 | ex["image"] = out["image"]
127 | else:
128 | h,w,_ = ex["image"].shape
129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
130 | out = self.cropper(image=ex["image"], coord=coord)
131 | ex["image"] = out["image"]
132 | ex["coord"] = out["coord"]
133 | ex["class"] = y
134 | return ex
135 |
--------------------------------------------------------------------------------
/taming/data/helper_types.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Optional, NamedTuple, Union
2 | from PIL.Image import Image as pil_image
3 | from torch import Tensor
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | Image = Union[Tensor, pil_image]
11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13 | SplitType = Literal['train', 'validation', 'test']
14 |
15 |
16 | class ImageDescription(NamedTuple):
17 | id: int
18 | file_name: str
19 | original_size: Tuple[int, int] # w, h
20 | url: Optional[str] = None
21 | license: Optional[int] = None
22 | coco_url: Optional[str] = None
23 | date_captured: Optional[str] = None
24 | flickr_url: Optional[str] = None
25 | flickr_id: Optional[str] = None
26 | coco_id: Optional[str] = None
27 |
28 |
29 | class Category(NamedTuple):
30 | id: str
31 | super_category: Optional[str]
32 | name: str
33 |
34 |
35 | class Annotation(NamedTuple):
36 | area: float
37 | image_id: str
38 | bbox: BoundingBox
39 | category_no: int
40 | category_id: str
41 | id: Optional[int] = None
42 | source: Optional[str] = None
43 | confidence: Optional[float] = None
44 | is_group_of: Optional[bool] = None
45 | is_truncated: Optional[bool] = None
46 | is_occluded: Optional[bool] = None
47 | is_depiction: Optional[bool] = None
48 | is_inside: Optional[bool] = None
49 | segmentation: Optional[Dict] = None
50 |
--------------------------------------------------------------------------------
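A hedged sketch of the Annotation record above: it is a NamedTuple, so the optional fields default to None and _replace() (used by rescale_annotations in conditional_builder/utils.py) swaps in a new bbox without touching anything else; the values here are invented:

    from taming.data.helper_types import Annotation

    ann = Annotation(area=0.25, image_id='42', bbox=(0.1, 0.1, 0.5, 0.5),
                     category_no=3, category_id='7')
    ann.is_group_of                                   # None (optional field)
    ann._replace(bbox=(0.2, 0.2, 0.25, 0.25)).bbox    # (0.2, 0.2, 0.25, 0.25)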
/taming/data/image_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from typing import Union
4 |
5 | import torch
6 | from torch import Tensor
7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
8 | from torchvision.transforms.functional import _get_image_size as get_image_size
9 |
10 | from taming.data.helper_types import BoundingBox, Image
11 |
12 | pil_to_tensor = PILToTensor()
13 |
14 |
15 | def convert_pil_to_tensor(image: Image) -> Tensor:
16 | with warnings.catch_warnings():
17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
18 | warnings.simplefilter("ignore")
19 | return pil_to_tensor(image)
20 |
21 |
22 | class RandomCrop1dReturnCoordinates(RandomCrop):
23 | def forward(self, img: Image) -> (BoundingBox, Image):
24 | """
25 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
26 | Args:
27 | img (PIL Image or Tensor): Image to be cropped.
28 |
29 | Returns:
30 | Bounding box: x0, y0, w, h
31 | PIL Image or Tensor: Cropped image.
32 |
33 | Based on:
34 | torchvision.transforms.RandomCrop, torchvision 1.7.0
35 | """
36 | if self.padding is not None:
37 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
38 |
39 | width, height = get_image_size(img)
40 | # pad the width if needed
41 | if self.pad_if_needed and width < self.size[1]:
42 | padding = [self.size[1] - width, 0]
43 | img = F.pad(img, padding, self.fill, self.padding_mode)
44 | # pad the height if needed
45 | if self.pad_if_needed and height < self.size[0]:
46 | padding = [0, self.size[0] - height]
47 | img = F.pad(img, padding, self.fill, self.padding_mode)
48 |
49 | i, j, h, w = self.get_params(img, self.size)
50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
51 | return bbox, F.crop(img, i, j, h, w)
52 |
53 |
54 | class Random2dCropReturnCoordinates(torch.nn.Module):
55 | """
56 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
57 | Args:
58 | img (PIL Image or Tensor): Image to be cropped.
59 |
60 | Returns:
61 | Bounding box: x0, y0, w, h
62 | PIL Image or Tensor: Cropped image.
63 |
64 | Based on:
65 | torchvision.transforms.RandomCrop, torchvision 1.7.0
66 | """
67 |
68 | def __init__(self, min_size: int):
69 | super().__init__()
70 | self.min_size = min_size
71 |
72 | def forward(self, img: Image) -> (BoundingBox, Image):
73 | width, height = get_image_size(img)
74 | max_size = min(width, height)
75 | if max_size <= self.min_size:
76 | size = max_size
77 | else:
78 | size = random.randint(self.min_size, max_size)
79 | top = random.randint(0, height - size)
80 | left = random.randint(0, width - size)
81 | bbox = left / width, top / height, size / width, size / height
82 | return bbox, F.crop(img, top, left, size, size)
83 |
84 |
85 | class CenterCropReturnCoordinates(CenterCrop):
86 | @staticmethod
87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
88 | if width > height:
89 | w = height / width
90 | h = 1.0
91 | x0 = 0.5 - w / 2
92 | y0 = 0.
93 | else:
94 | w = 1.0
95 | h = width / height
96 | x0 = 0.
97 | y0 = 0.5 - h / 2
98 | return x0, y0, w, h
99 |
100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
101 | """
102 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
103 | Args:
104 | img (PIL Image or Tensor): Image to be cropped.
105 |
106 | Returns:
107 | Bounding box: x0, y0, w, h
108 | PIL Image or Tensor: Cropped image.
109 | Based on:
110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
111 | """
112 | width, height = get_image_size(img)
113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
114 |
115 |
116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip):
117 | def forward(self, img: Image) -> (bool, Image):
118 | """
119 | Additionally to flipping, returns a boolean whether it was flipped or not.
120 | Args:
121 | img (PIL Image or Tensor): Image to be flipped.
122 |
123 | Returns:
124 | flipped: whether the image was flipped or not
125 | PIL Image or Tensor: Randomly flipped image.
126 |
127 | Based on:
128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
129 | """
130 | if torch.rand(1) < self.p:
131 | return True, F.hflip(img)
132 | return False, img
133 |
--------------------------------------------------------------------------------
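A hedged sketch of the static helper in CenterCropReturnCoordinates above (the module import assumes the older torchvision pinned by the environment, which still provides the private _get_image_size helper): for a 640x480 landscape image the centered square crop keeps the full height and the middle 75% of the width.

    from taming.data.image_transforms import CenterCropReturnCoordinates

    CenterCropReturnCoordinates.get_bbox_of_center_crop(width=640, height=480)
    # -> (0.125, 0.0, 0.75, 1.0), i.e. (x0, y0, w, h) relative to the image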
/taming/data/open_images_helper.py:
--------------------------------------------------------------------------------
1 | open_images_unify_categories_for_coco = {
2 | '/m/03bt1vf': '/m/01g317',
3 | '/m/04yx4': '/m/01g317',
4 | '/m/05r655': '/m/01g317',
5 | '/m/01bl7v': '/m/01g317',
6 | '/m/0cnyhnx': '/m/01xq0k1',
7 | '/m/01226z': '/m/018xm',
8 | '/m/05ctyq': '/m/018xm',
9 | '/m/058qzx': '/m/04ctx',
10 | '/m/06pcq': '/m/0l515',
11 | '/m/03m3pdh': '/m/02crq1',
12 | '/m/046dlr': '/m/01x3z',
13 | '/m/0h8mzrc': '/m/01x3z',
14 | }
15 |
16 |
17 | top_300_classes_plus_coco_compatibility = [
18 | ('Man', 1060962),
19 | ('Clothing', 986610),
20 | ('Tree', 748162),
21 | ('Woman', 611896),
22 | ('Person', 610294),
23 | ('Human face', 442948),
24 | ('Girl', 175399),
25 | ('Building', 162147),
26 | ('Car', 159135),
27 | ('Plant', 155704),
28 | ('Human body', 137073),
29 | ('Flower', 133128),
30 | ('Window', 127485),
31 | ('Human arm', 118380),
32 | ('House', 114365),
33 | ('Wheel', 111684),
34 | ('Suit', 99054),
35 | ('Human hair', 98089),
36 | ('Human head', 92763),
37 | ('Chair', 88624),
38 | ('Boy', 79849),
39 | ('Table', 73699),
40 | ('Jeans', 57200),
41 | ('Tire', 55725),
42 | ('Skyscraper', 53321),
43 | ('Food', 52400),
44 | ('Footwear', 50335),
45 | ('Dress', 50236),
46 | ('Human leg', 47124),
47 | ('Toy', 46636),
48 | ('Tower', 45605),
49 | ('Boat', 43486),
50 | ('Land vehicle', 40541),
51 | ('Bicycle wheel', 34646),
52 | ('Palm tree', 33729),
53 | ('Fashion accessory', 32914),
54 | ('Glasses', 31940),
55 | ('Bicycle', 31409),
56 | ('Furniture', 30656),
57 | ('Sculpture', 29643),
58 | ('Bottle', 27558),
59 | ('Dog', 26980),
60 | ('Snack', 26796),
61 | ('Human hand', 26664),
62 | ('Bird', 25791),
63 | ('Book', 25415),
64 | ('Guitar', 24386),
65 | ('Jacket', 23998),
66 | ('Poster', 22192),
67 | ('Dessert', 21284),
68 | ('Baked goods', 20657),
69 | ('Drink', 19754),
70 | ('Flag', 18588),
71 | ('Houseplant', 18205),
72 | ('Tableware', 17613),
73 | ('Airplane', 17218),
74 | ('Door', 17195),
75 | ('Sports uniform', 17068),
76 | ('Shelf', 16865),
77 | ('Drum', 16612),
78 | ('Vehicle', 16542),
79 | ('Microphone', 15269),
80 | ('Street light', 14957),
81 | ('Cat', 14879),
82 | ('Fruit', 13684),
83 | ('Fast food', 13536),
84 | ('Animal', 12932),
85 | ('Vegetable', 12534),
86 | ('Train', 12358),
87 | ('Horse', 11948),
88 | ('Flowerpot', 11728),
89 | ('Motorcycle', 11621),
90 | ('Fish', 11517),
91 | ('Desk', 11405),
92 | ('Helmet', 10996),
93 | ('Truck', 10915),
94 | ('Bus', 10695),
95 | ('Hat', 10532),
96 | ('Auto part', 10488),
97 | ('Musical instrument', 10303),
98 | ('Sunglasses', 10207),
99 | ('Picture frame', 10096),
100 | ('Sports equipment', 10015),
101 | ('Shorts', 9999),
102 | ('Wine glass', 9632),
103 | ('Duck', 9242),
104 | ('Wine', 9032),
105 | ('Rose', 8781),
106 | ('Tie', 8693),
107 | ('Butterfly', 8436),
108 | ('Beer', 7978),
109 | ('Cabinetry', 7956),
110 | ('Laptop', 7907),
111 | ('Insect', 7497),
112 | ('Goggles', 7363),
113 | ('Shirt', 7098),
114 | ('Dairy Product', 7021),
115 | ('Marine invertebrates', 7014),
116 | ('Cattle', 7006),
117 | ('Trousers', 6903),
118 | ('Van', 6843),
119 | ('Billboard', 6777),
120 | ('Balloon', 6367),
121 | ('Human nose', 6103),
122 | ('Tent', 6073),
123 | ('Camera', 6014),
124 | ('Doll', 6002),
125 | ('Coat', 5951),
126 | ('Mobile phone', 5758),
127 | ('Swimwear', 5729),
128 | ('Strawberry', 5691),
129 | ('Stairs', 5643),
130 | ('Goose', 5599),
131 | ('Umbrella', 5536),
132 | ('Cake', 5508),
133 | ('Sun hat', 5475),
134 | ('Bench', 5310),
135 | ('Bookcase', 5163),
136 | ('Bee', 5140),
137 | ('Computer monitor', 5078),
138 | ('Hiking equipment', 4983),
139 | ('Office building', 4981),
140 | ('Coffee cup', 4748),
141 | ('Curtain', 4685),
142 | ('Plate', 4651),
143 | ('Box', 4621),
144 | ('Tomato', 4595),
145 | ('Coffee table', 4529),
146 | ('Office supplies', 4473),
147 | ('Maple', 4416),
148 | ('Muffin', 4365),
149 | ('Cocktail', 4234),
150 | ('Castle', 4197),
151 | ('Couch', 4134),
152 | ('Pumpkin', 3983),
153 | ('Computer keyboard', 3960),
154 | ('Human mouth', 3926),
155 | ('Christmas tree', 3893),
156 | ('Mushroom', 3883),
157 | ('Swimming pool', 3809),
158 | ('Pastry', 3799),
159 | ('Lavender (Plant)', 3769),
160 | ('Football helmet', 3732),
161 | ('Bread', 3648),
162 | ('Traffic sign', 3628),
163 | ('Common sunflower', 3597),
164 | ('Television', 3550),
165 | ('Bed', 3525),
166 | ('Cookie', 3485),
167 | ('Fountain', 3484),
168 | ('Paddle', 3447),
169 | ('Bicycle helmet', 3429),
170 | ('Porch', 3420),
171 | ('Deer', 3387),
172 | ('Fedora', 3339),
173 | ('Canoe', 3338),
174 | ('Carnivore', 3266),
175 | ('Bowl', 3202),
176 | ('Human eye', 3166),
177 | ('Ball', 3118),
178 | ('Pillow', 3077),
179 | ('Salad', 3061),
180 | ('Beetle', 3060),
181 | ('Orange', 3050),
182 | ('Drawer', 2958),
183 | ('Platter', 2937),
184 | ('Elephant', 2921),
185 | ('Seafood', 2921),
186 | ('Monkey', 2915),
187 | ('Countertop', 2879),
188 | ('Watercraft', 2831),
189 | ('Helicopter', 2805),
190 | ('Kitchen appliance', 2797),
191 | ('Personal flotation device', 2781),
192 | ('Swan', 2739),
193 | ('Lamp', 2711),
194 | ('Boot', 2695),
195 | ('Bronze sculpture', 2693),
196 | ('Chicken', 2677),
197 | ('Taxi', 2643),
198 | ('Juice', 2615),
199 | ('Cowboy hat', 2604),
200 | ('Apple', 2600),
201 | ('Tin can', 2590),
202 | ('Necklace', 2564),
203 | ('Ice cream', 2560),
204 | ('Human beard', 2539),
205 | ('Coin', 2536),
206 | ('Candle', 2515),
207 | ('Cart', 2512),
208 | ('High heels', 2441),
209 | ('Weapon', 2433),
210 | ('Handbag', 2406),
211 | ('Penguin', 2396),
212 | ('Rifle', 2352),
213 | ('Violin', 2336),
214 | ('Skull', 2304),
215 | ('Lantern', 2285),
216 | ('Scarf', 2269),
217 | ('Saucer', 2225),
218 | ('Sheep', 2215),
219 | ('Vase', 2189),
220 | ('Lily', 2180),
221 | ('Mug', 2154),
222 | ('Parrot', 2140),
223 | ('Human ear', 2137),
224 | ('Sandal', 2115),
225 | ('Lizard', 2100),
226 | ('Kitchen & dining room table', 2063),
227 | ('Spider', 1977),
228 | ('Coffee', 1974),
229 | ('Goat', 1926),
230 | ('Squirrel', 1922),
231 | ('Cello', 1913),
232 | ('Sushi', 1881),
233 | ('Tortoise', 1876),
234 | ('Pizza', 1870),
235 | ('Studio couch', 1864),
236 | ('Barrel', 1862),
237 | ('Cosmetics', 1841),
238 | ('Moths and butterflies', 1841),
239 | ('Convenience store', 1817),
240 | ('Watch', 1792),
241 | ('Home appliance', 1786),
242 | ('Harbor seal', 1780),
243 | ('Luggage and bags', 1756),
244 | ('Vehicle registration plate', 1754),
245 | ('Shrimp', 1751),
246 | ('Jellyfish', 1730),
247 | ('French fries', 1723),
248 | ('Egg (Food)', 1698),
249 | ('Football', 1697),
250 | ('Musical keyboard', 1683),
251 | ('Falcon', 1674),
252 | ('Candy', 1660),
253 | ('Medical equipment', 1654),
254 | ('Eagle', 1651),
255 | ('Dinosaur', 1634),
256 | ('Surfboard', 1630),
257 | ('Tank', 1628),
258 | ('Grape', 1624),
259 | ('Lion', 1624),
260 | ('Owl', 1622),
261 | ('Ski', 1613),
262 | ('Waste container', 1606),
263 | ('Frog', 1591),
264 | ('Sparrow', 1585),
265 | ('Rabbit', 1581),
266 | ('Pen', 1546),
267 | ('Sea lion', 1537),
268 | ('Spoon', 1521),
269 | ('Sink', 1512),
270 | ('Teddy bear', 1507),
271 | ('Bull', 1495),
272 | ('Sofa bed', 1490),
273 | ('Dragonfly', 1479),
274 | ('Brassiere', 1478),
275 | ('Chest of drawers', 1472),
276 | ('Aircraft', 1466),
277 | ('Human foot', 1463),
278 | ('Pig', 1455),
279 | ('Fork', 1454),
280 | ('Antelope', 1438),
281 | ('Tripod', 1427),
282 | ('Tool', 1424),
283 | ('Cheese', 1422),
284 | ('Lemon', 1397),
285 | ('Hamburger', 1393),
286 | ('Dolphin', 1390),
287 | ('Mirror', 1390),
288 | ('Marine mammal', 1387),
289 | ('Giraffe', 1385),
290 | ('Snake', 1368),
291 | ('Gondola', 1364),
292 | ('Wheelchair', 1360),
293 | ('Piano', 1358),
294 | ('Cupboard', 1348),
295 | ('Banana', 1345),
296 | ('Trumpet', 1335),
297 | ('Lighthouse', 1333),
298 | ('Invertebrate', 1317),
299 | ('Carrot', 1268),
300 | ('Sock', 1260),
301 | ('Tiger', 1241),
302 | ('Camel', 1224),
303 | ('Parachute', 1224),
304 | ('Bathroom accessory', 1223),
305 | ('Earrings', 1221),
306 | ('Headphones', 1218),
307 | ('Skirt', 1198),
308 | ('Skateboard', 1190),
309 | ('Sandwich', 1148),
310 | ('Saxophone', 1141),
311 | ('Goldfish', 1136),
312 | ('Stool', 1104),
313 | ('Traffic light', 1097),
314 | ('Shellfish', 1081),
315 | ('Backpack', 1079),
316 | ('Sea turtle', 1078),
317 | ('Cucumber', 1075),
318 | ('Tea', 1051),
319 | ('Toilet', 1047),
320 | ('Roller skates', 1040),
321 | ('Mule', 1039),
322 | ('Bust', 1031),
323 | ('Broccoli', 1030),
324 | ('Crab', 1020),
325 | ('Oyster', 1019),
326 | ('Cannon', 1012),
327 | ('Zebra', 1012),
328 | ('French horn', 1008),
329 | ('Grapefruit', 998),
330 | ('Whiteboard', 997),
331 | ('Zucchini', 997),
332 | ('Crocodile', 992),
333 |
334 | ('Clock', 960),
335 | ('Wall clock', 958),
336 |
337 | ('Doughnut', 869),
338 | ('Snail', 868),
339 |
340 | ('Baseball glove', 859),
341 |
342 | ('Panda', 830),
343 | ('Tennis racket', 830),
344 |
345 | ('Pear', 652),
346 |
347 | ('Bagel', 617),
348 | ('Oven', 616),
349 | ('Ladybug', 615),
350 | ('Shark', 615),
351 | ('Polar bear', 614),
352 | ('Ostrich', 609),
353 |
354 | ('Hot dog', 473),
355 | ('Microwave oven', 467),
356 | ('Fire hydrant', 20),
357 | ('Stop sign', 20),
358 | ('Parking meter', 20),
359 | ('Bear', 20),
360 | ('Flying disc', 20),
361 | ('Snowboard', 20),
362 | ('Tennis ball', 20),
363 | ('Kite', 20),
364 | ('Baseball bat', 20),
365 | ('Kitchen knife', 20),
366 | ('Knife', 20),
367 | ('Submarine sandwich', 20),
368 | ('Computer mouse', 20),
369 | ('Remote control', 20),
370 | ('Toaster', 20),
371 | ('Sink', 20),
372 | ('Refrigerator', 20),
373 | ('Alarm clock', 20),
374 | ('Wall clock', 20),
375 | ('Scissors', 20),
376 | ('Hair dryer', 20),
377 | ('Toothbrush', 20),
378 | ('Suitcase', 20)
379 | ]
380 |
--------------------------------------------------------------------------------
/taming/data/sflckr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class SegmentationBase(Dataset):
10 | def __init__(self,
11 | data_csv, data_root, segmentation_root,
12 | size=None, random_crop=False, interpolation="bicubic",
13 | n_labels=182, shift_segmentation=False,
14 | ):
15 | self.n_labels = n_labels
16 | self.shift_segmentation = shift_segmentation
17 | self.data_csv = data_csv
18 | self.data_root = data_root
19 | self.segmentation_root = segmentation_root
20 | with open(self.data_csv, "r") as f:
21 | self.image_paths = f.read().splitlines()
22 | self._length = len(self.image_paths)
23 | self.labels = {
24 | "relative_file_path_": [l for l in self.image_paths],
25 | "file_path_": [os.path.join(self.data_root, l)
26 | for l in self.image_paths],
27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28 | for l in self.image_paths]
29 | }
30 |
31 | size = None if size is not None and size<=0 else size
32 | self.size = size
33 | if self.size is not None:
34 | self.interpolation = interpolation
35 | self.interpolation = {
36 | "nearest": cv2.INTER_NEAREST,
37 | "bilinear": cv2.INTER_LINEAR,
38 | "bicubic": cv2.INTER_CUBIC,
39 | "area": cv2.INTER_AREA,
40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42 | interpolation=self.interpolation)
43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44 | interpolation=cv2.INTER_NEAREST)
45 | self.center_crop = not random_crop
46 | if self.center_crop:
47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48 | else:
49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50 | self.preprocessor = self.cropper
51 |
52 | def __len__(self):
53 | return self._length
54 |
55 | def __getitem__(self, i):
56 | example = dict((k, self.labels[k][i]) for k in self.labels)
57 | image = Image.open(example["file_path_"])
58 | if not image.mode == "RGB":
59 | image = image.convert("RGB")
60 | image = np.array(image).astype(np.uint8)
61 | if self.size is not None:
62 | image = self.image_rescaler(image=image)["image"]
63 | segmentation = Image.open(example["segmentation_path_"])
64 | assert segmentation.mode == "L", segmentation.mode
65 | segmentation = np.array(segmentation).astype(np.uint8)
66 | if self.shift_segmentation:
67 | # used to support segmentations containing unlabeled==255 label
68 | segmentation = segmentation+1
69 | if self.size is not None:
70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71 | if self.size is not None:
72 | processed = self.preprocessor(image=image,
73 | mask=segmentation
74 | )
75 | else:
76 | processed = {"image": image,
77 | "mask": segmentation
78 | }
79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80 | segmentation = processed["mask"]
81 | onehot = np.eye(self.n_labels)[segmentation]
82 | example["segmentation"] = onehot
83 | return example
84 |
85 |
86 | class Examples(SegmentationBase):
87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88 | super().__init__(data_csv="data/sflckr_examples.txt",
89 | data_root="data/sflckr_images",
90 | segmentation_root="data/sflckr_segmentations",
91 | size=size, random_crop=random_crop, interpolation=interpolation)
92 |
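A minimal usage sketch for the Examples dataset above, assuming data/sflckr_examples.txt plus the matching image and segmentation folders exist locally (they are not included in this repository dump):

    from taming.data.sflckr import Examples

    ds = Examples(size=256)              # rescale the shortest side to 256, then center-crop
    sample = ds[0]
    print(sample["image"].shape)         # (256, 256, 3), float32 scaled to [-1, 1]
    print(sample["segmentation"].shape)  # (256, 256, 182), one-hot labels (n_labels=182)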
--------------------------------------------------------------------------------
/taming/data/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import tarfile
4 | import urllib.request  # urlretrieve is used in download_url below
5 | import zipfile
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import torch
10 | from taming.data.helper_types import Annotation
11 | from torch._six import string_classes
12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13 | from tqdm import tqdm
14 |
15 |
16 | def unpack(path):
17 | if path.endswith("tar.gz"):
18 | with tarfile.open(path, "r:gz") as tar:
19 | tar.extractall(path=os.path.split(path)[0])
20 | elif path.endswith("tar"):
21 | with tarfile.open(path, "r:") as tar:
22 | tar.extractall(path=os.path.split(path)[0])
23 | elif path.endswith("zip"):
24 | with zipfile.ZipFile(path, "r") as f:
25 | f.extractall(path=os.path.split(path)[0])
26 | else:
27 | raise NotImplementedError(
28 | "Unknown file extension: {}".format(os.path.splitext(path)[1])
29 | )
30 |
31 |
32 | def reporthook(bar):
33 | """tqdm progress bar for downloads."""
34 |
35 | def hook(b=1, bsize=1, tsize=None):
36 | if tsize is not None:
37 | bar.total = tsize
38 | bar.update(b * bsize - bar.n)
39 |
40 | return hook
41 |
42 |
43 | def get_root(name):
44 | base = "data/"
45 | root = os.path.join(base, name)
46 | os.makedirs(root, exist_ok=True)
47 | return root
48 |
49 |
50 | def is_prepared(root):
51 | return Path(root).joinpath(".ready").exists()
52 |
53 |
54 | def mark_prepared(root):
55 | Path(root).joinpath(".ready").touch()
56 |
57 |
58 | def prompt_download(file_, source, target_dir, content_dir=None):
59 | targetpath = os.path.join(target_dir, file_)
60 | while not os.path.exists(targetpath):
61 | if content_dir is not None and os.path.exists(
62 | os.path.join(target_dir, content_dir)
63 | ):
64 | break
65 | print(
66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
67 | )
68 | if content_dir is not None:
69 | print(
70 | "Or place its content into '{}'.".format(
71 | os.path.join(target_dir, content_dir)
72 | )
73 | )
74 | input("Press Enter when done...")
75 | return targetpath
76 |
77 |
78 | def download_url(file_, url, target_dir):
79 | targetpath = os.path.join(target_dir, file_)
80 | os.makedirs(target_dir, exist_ok=True)
81 | with tqdm(
82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
83 | ) as bar:
84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
85 | return targetpath
86 |
87 |
88 | def download_urls(urls, target_dir):
89 | paths = dict()
90 | for fname, url in urls.items():
91 | outpath = download_url(fname, url, target_dir)
92 | paths[fname] = outpath
93 | return paths
94 |
95 |
96 | def quadratic_crop(x, bbox, alpha=1.0):
97 | """bbox is xmin, ymin, xmax, ymax"""
98 | im_h, im_w = x.shape[:2]
99 | bbox = np.array(bbox, dtype=np.float32)
100 | bbox = np.clip(bbox, 0, max(im_h, im_w))
101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
102 | w = bbox[2] - bbox[0]
103 | h = bbox[3] - bbox[1]
104 | l = int(alpha * max(w, h))
105 | l = max(l, 2)
106 |
107 | required_padding = -1 * min(
108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
109 | )
110 | required_padding = int(np.ceil(required_padding))
111 | if required_padding > 0:
112 | padding = [
113 | [required_padding, required_padding],
114 | [required_padding, required_padding],
115 | ]
116 | padding += [[0, 0]] * (len(x.shape) - 2)
117 | x = np.pad(x, padding, "reflect")
118 | center = center[0] + required_padding, center[1] + required_padding
119 | xmin = int(center[0] - l / 2)
120 | ymin = int(center[1] - l / 2)
121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
122 |
123 |
124 | def custom_collate(batch):
125 | r"""source: pytorch 1.9.0, only one modification to original code """
126 |
127 | elem = batch[0]
128 | elem_type = type(elem)
129 | if isinstance(elem, torch.Tensor):
130 | out = None
131 | if torch.utils.data.get_worker_info() is not None:
132 | # If we're in a background process, concatenate directly into a
133 | # shared memory tensor to avoid an extra copy
134 | numel = sum([x.numel() for x in batch])
135 | storage = elem.storage()._new_shared(numel)
136 | out = elem.new(storage)
137 | return torch.stack(batch, 0, out=out)
138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
139 | and elem_type.__name__ != 'string_':
140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
141 | # array of string classes and object
142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
144 |
145 | return custom_collate([torch.as_tensor(b) for b in batch])
146 | elif elem.shape == (): # scalars
147 | return torch.as_tensor(batch)
148 | elif isinstance(elem, float):
149 | return torch.tensor(batch, dtype=torch.float64)
150 | elif isinstance(elem, int):
151 | return torch.tensor(batch)
152 | elif isinstance(elem, string_classes):
153 | return batch
154 | elif isinstance(elem, collections.abc.Mapping):
155 | return {key: custom_collate([d[key] for d in batch]) for key in elem}
156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
159 | return batch # added
160 | elif isinstance(elem, collections.abc.Sequence):
161 | # check to make sure that the elements in batch have consistent size
162 | it = iter(batch)
163 | elem_size = len(next(it))
164 | if not all(len(elem) == elem_size for elem in it):
165 | raise RuntimeError('each element in list of batch should be of equal size')
166 | transposed = zip(*batch)
167 | return [custom_collate(samples) for samples in transposed]
168 |
169 | raise TypeError(default_collate_err_msg_format.format(elem_type))
170 |
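A small sketch of plugging custom_collate into a DataLoader, assuming a torch build in which the torch._six import above still resolves; the Annotation pass-through (the '# added' branch) only matters for datasets that yield lists of Annotation objects:

    import torch
    from torch.utils.data import DataLoader
    from taming.data.utils import custom_collate

    # a tiny map-style dataset: a plain list of dicts works directly
    data = [{"image": torch.rand(3, 8, 8), "caption": "a"},
            {"image": torch.rand(3, 8, 8), "caption": "b"}]
    loader = DataLoader(data, batch_size=2, collate_fn=custom_collate)
    batch = next(iter(loader))
    print(batch["image"].shape, batch["caption"])  # torch.Size([2, 3, 8, 8]) ['a', 'b']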
--------------------------------------------------------------------------------
/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
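Since the schedule returns an absolute learning rate, the docstring's "use with a base_lr of 1.0" amounts to wiring it through LambdaLR roughly as in this sketch (the parameter values are illustrative only):

    import torch
    from taming.lr_scheduler import LambdaWarmUpCosineScheduler

    model = torch.nn.Linear(4, 4)                        # placeholder module
    opt = torch.optim.Adam(model.parameters(), lr=1.0)   # base_lr of 1.0
    lr_fn = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-6,
                                        lr_max=1e-4, lr_start=1e-6, max_decay_steps=10000)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_fn)

    for _ in range(3):
        opt.step()        # effective lr = 1.0 * lr_fn(current step)
        sched.step()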
--------------------------------------------------------------------------------
/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/taming/modules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/taming/modules/autoencoder/lpips/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/autoencoder/lpips/vgg.pth
--------------------------------------------------------------------------------
/taming/modules/discriminator/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/discriminator/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | use_actnorm (bool) -- use ActNorm instead of BatchNorm2d as the normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
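A minimal construction sketch for the PatchGAN discriminator above (the input size is illustrative):

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
    logits = disc(torch.randn(2, 3, 256, 256))
    print(logits.shape)   # torch.Size([2, 1, 30, 30]) -- one real/fake logit per patch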
--------------------------------------------------------------------------------
/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/taming/modules/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/taming/modules/losses/__pycache__/lpips.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/lpips.cpython-38.pyc
--------------------------------------------------------------------------------
/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhexinLiang/Control-Color/f21054af54f524591f7a3c0862fa90392f7d33c0/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc
--------------------------------------------------------------------------------
/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from taming.util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 | self.chns = [64, 128, 256, 512, 512] # vgg16 feature channels
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 | if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
124 |
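A minimal sketch of using LPIPS as a perceptual distance; constructing it downloads the torchvision VGG16 weights and the vgg_lpips checkpoint on first use:

    import torch
    from taming.modules.losses.lpips import LPIPS

    perceptual = LPIPS().eval()
    x = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs are expected in [-1, 1]
    y = torch.rand(1, 3, 256, 256) * 2 - 1
    with torch.no_grad():
        d = perceptual(x, y)                 # shape (1, 1, 1, 1); lower = perceptually closer
    print(d.item())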
--------------------------------------------------------------------------------
/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
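A quick sketch of the fake-VQ interface exposed by CoordStage (n_embed and down_factor are illustrative):

    import torch
    from taming.modules.misc.coord import CoordStage

    stage = CoordStage(n_embed=1024, down_factor=16)
    c = torch.rand(1, 1, 256, 256)                  # coordinate map in [0, 1]
    c_quant, _, (_, _, c_ind) = stage.encode(c)     # area-downsampled to (1, 1, 16, 16) and quantised
    c_up = stage.decode(c_quant)                    # nearest-neighbour upsample back to (1, 1, 256, 256)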
--------------------------------------------------------------------------------
/taming/modules/transformer/permuter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class AbstractPermuter(nn.Module):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__()
9 | def forward(self, x, reverse=False):
10 | raise NotImplementedError
11 |
12 |
13 | class Identity(AbstractPermuter):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, x, reverse=False):
18 | return x
19 |
20 |
21 | class Subsample(AbstractPermuter):
22 | def __init__(self, H, W):
23 | super().__init__()
24 | C = 1
25 | indices = np.arange(H*W).reshape(C,H,W)
26 | while min(H, W) > 1:
27 | indices = indices.reshape(C,H//2,2,W//2,2)
28 | indices = indices.transpose(0,2,4,1,3)
29 | indices = indices.reshape(C*4,H//2, W//2)
30 | H = H//2
31 | W = W//2
32 | C = C*4
33 | assert H == W == 1
34 | idx = torch.tensor(indices.ravel())
35 | self.register_buffer('forward_shuffle_idx',
36 | nn.Parameter(idx, requires_grad=False))
37 | self.register_buffer('backward_shuffle_idx',
38 | nn.Parameter(torch.argsort(idx), requires_grad=False))
39 |
40 | def forward(self, x, reverse=False):
41 | if not reverse:
42 | return x[:, self.forward_shuffle_idx]
43 | else:
44 | return x[:, self.backward_shuffle_idx]
45 |
46 |
47 | def mortonify(i, j):
48 | """(i,j) index to linear morton code"""
49 | i = np.uint64(i)
50 | j = np.uint64(j)
51 |
52 | z = np.uint(0)
53 |
54 | for pos in range(32):
55 | z = (z |
56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58 | )
59 | return z
60 |
61 |
62 | class ZCurve(AbstractPermuter):
63 | def __init__(self, H, W):
64 | super().__init__()
65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66 | idx = np.argsort(reverseidx)
67 | idx = torch.tensor(idx)
68 | reverseidx = torch.tensor(reverseidx)
69 | self.register_buffer('forward_shuffle_idx',
70 | idx)
71 | self.register_buffer('backward_shuffle_idx',
72 | reverseidx)
73 |
74 | def forward(self, x, reverse=False):
75 | if not reverse:
76 | return x[:, self.forward_shuffle_idx]
77 | else:
78 | return x[:, self.backward_shuffle_idx]
79 |
80 |
81 | class SpiralOut(AbstractPermuter):
82 | def __init__(self, H, W):
83 | super().__init__()
84 | assert H == W
85 | size = W
86 | indices = np.arange(size*size).reshape(size,size)
87 |
88 | i0 = size//2
89 | j0 = size//2-1
90 |
91 | i = i0
92 | j = j0
93 |
94 | idx = [indices[i0, j0]]
95 | step_mult = 0
96 | for c in range(1, size//2+1):
97 | step_mult += 1
98 | # steps left
99 | for k in range(step_mult):
100 | i = i - 1
101 | j = j
102 | idx.append(indices[i, j])
103 |
104 | # step down
105 | for k in range(step_mult):
106 | i = i
107 | j = j + 1
108 | idx.append(indices[i, j])
109 |
110 | step_mult += 1
111 | if c < size//2:
112 | # step right
113 | for k in range(step_mult):
114 | i = i + 1
115 | j = j
116 | idx.append(indices[i, j])
117 |
118 | # step up
119 | for k in range(step_mult):
120 | i = i
121 | j = j - 1
122 | idx.append(indices[i, j])
123 | else:
124 | # end reached
125 | for k in range(step_mult-1):
126 | i = i + 1
127 | idx.append(indices[i, j])
128 |
129 | assert len(idx) == size*size
130 | idx = torch.tensor(idx)
131 | self.register_buffer('forward_shuffle_idx', idx)
132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133 |
134 | def forward(self, x, reverse=False):
135 | if not reverse:
136 | return x[:, self.forward_shuffle_idx]
137 | else:
138 | return x[:, self.backward_shuffle_idx]
139 |
140 |
141 | class SpiralIn(AbstractPermuter):
142 | def __init__(self, H, W):
143 | super().__init__()
144 | assert H == W
145 | size = W
146 | indices = np.arange(size*size).reshape(size,size)
147 |
148 | i0 = size//2
149 | j0 = size//2-1
150 |
151 | i = i0
152 | j = j0
153 |
154 | idx = [indices[i0, j0]]
155 | step_mult = 0
156 | for c in range(1, size//2+1):
157 | step_mult += 1
158 | # steps left
159 | for k in range(step_mult):
160 | i = i - 1
161 | j = j
162 | idx.append(indices[i, j])
163 |
164 | # step down
165 | for k in range(step_mult):
166 | i = i
167 | j = j + 1
168 | idx.append(indices[i, j])
169 |
170 | step_mult += 1
171 | if c < size//2:
172 | # step right
173 | for k in range(step_mult):
174 | i = i + 1
175 | j = j
176 | idx.append(indices[i, j])
177 |
178 | # step up
179 | for k in range(step_mult):
180 | i = i
181 | j = j - 1
182 | idx.append(indices[i, j])
183 | else:
184 | # end reached
185 | for k in range(step_mult-1):
186 | i = i + 1
187 | idx.append(indices[i, j])
188 |
189 | assert len(idx) == size*size
190 | idx = idx[::-1]
191 | idx = torch.tensor(idx)
192 | self.register_buffer('forward_shuffle_idx', idx)
193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194 |
195 | def forward(self, x, reverse=False):
196 | if not reverse:
197 | return x[:, self.forward_shuffle_idx]
198 | else:
199 | return x[:, self.backward_shuffle_idx]
200 |
201 |
202 | class Random(nn.Module):
203 | def __init__(self, H, W):
204 | super().__init__()
205 | indices = np.random.RandomState(1).permutation(H*W)
206 | idx = torch.tensor(indices.ravel())
207 | self.register_buffer('forward_shuffle_idx', idx)
208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209 |
210 | def forward(self, x, reverse=False):
211 | if not reverse:
212 | return x[:, self.forward_shuffle_idx]
213 | else:
214 | return x[:, self.backward_shuffle_idx]
215 |
216 |
217 | class AlternateParsing(AbstractPermuter):
218 | def __init__(self, H, W):
219 | super().__init__()
220 | indices = np.arange(W*H).reshape(H,W)
221 | for i in range(1, H, 2):
222 | indices[i, :] = indices[i, ::-1]
223 | idx = indices.flatten()
224 | assert len(idx) == H*W
225 | idx = torch.tensor(idx)
226 | self.register_buffer('forward_shuffle_idx', idx)
227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228 |
229 | def forward(self, x, reverse=False):
230 | if not reverse:
231 | return x[:, self.forward_shuffle_idx]
232 | else:
233 | return x[:, self.backward_shuffle_idx]
234 |
235 |
236 | if __name__ == "__main__":
237 | p0 = AlternateParsing(16, 16)
238 | print(p0.forward_shuffle_idx)
239 | print(p0.backward_shuffle_idx)
240 |
241 | x = torch.randint(0, 768, size=(11, 256))
242 | y = p0(x)
243 | xre = p0(y, reverse=True)
244 | assert torch.equal(x, xre)
245 |
246 | p1 = SpiralOut(2, 2)
247 | print(p1.forward_shuffle_idx)
248 | print(p1.backward_shuffle_idx)
249 |
--------------------------------------------------------------------------------
/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
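A small round-trip sketch for ActNorm: the first forward pass in training mode initialises loc/scale from the batch statistics, after which the reverse path inverts the normalisation:

    import torch
    from taming.modules.util import ActNorm

    norm = ActNorm(num_features=8)
    x = torch.randn(4, 8, 16, 16)
    y = norm(x)                                  # data-dependent init happens on this first call
    x_rec = norm(y, reverse=True)
    print(torch.allclose(x, x_rec, atol=1e-4))   # True up to float error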
--------------------------------------------------------------------------------
/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(len(data))  # count the bytes actually written (the last chunk may be short)
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value or if :attr:`default` is not ``None`` and the
88 | :attr:`key` is not found returns ``default``.
89 |
90 | Raises
91 | ------
92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
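Beyond the single-key demo above, retrieve also resolves nested path-like keys and falls back to default when a key is missing; a small sketch with a hypothetical config dict:

    from taming.util import retrieve

    config = {"model": {"params": {"lr": 2e-4}}}
    print(retrieve(config, "model/params/lr"))                 # 0.0002
    print(retrieve(config, "model/params/batch", default=8))   # 8 (key not found)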
--------------------------------------------------------------------------------