├── .gitignore ├── LICENSE ├── README.md ├── configs ├── ControlNet │ └── prompt_config.yml └── text_guided │ ├── napoleon.yaml │ ├── nascar.yaml │ └── spiderman_example.yaml ├── demos └── TEXTure_demo.ipynb ├── environment.yaml ├── requirements.txt ├── scripts └── run_texture.py ├── shapes ├── env_sphere.obj ├── napoleon.obj ├── nascar.obj └── spiderman_example.obj ├── src ├── __init__.py ├── annotator │ └── util.py ├── cldm │ ├── cldm.py │ ├── ddim_hacked.py │ ├── hack.py │ ├── logger.py │ └── model.py ├── configs │ ├── __init__.py │ └── train_config.py ├── controlnet_depth.py ├── ldm │ ├── data │ │ ├── __init__.py │ │ └── util.py │ ├── models │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ ├── plms.py │ │ │ └── sampling_util.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── upscaling.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ └── midas │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net.py │ │ │ ├── midas_net_custom.py │ │ │ ├── transforms.py │ │ │ └── vit.py │ │ │ └── utils.py │ └── util.py ├── models │ ├── __init__.py │ ├── cldm_v15.yaml │ ├── mesh.py │ ├── render.py │ └── textured_mesh.py ├── stable_diffusion_depth.py ├── training │ ├── __init__.py │ ├── trainer.py │ └── views_dataset.py └── utils.py └── textures └── brick_wall.png /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ 3 | *.mp4 4 | *.npy 5 | *.npz 6 | *.dae 7 | data/* 8 | logs/* 9 | .idea 10 | TOKEN 11 | control_sd15_depth.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adding ControlNet to TEXTure 2 | 3 | In this repository we added [ControlNet](https://github.com/lllyasviel/ControlNet) to [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper) to conduct experiments with 3D models of the Avengers generated using [ECON](https://github.com/YuliangXiu/ECON). 4 | 5 | ## Content 6 | 7 | This repository is forked from the [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper) repository, so it contains all of the code present there. We also added the files from the [ControlNet](https://github.com/lllyasviel/ControlNet) repository in order to use it. We then added the files `controlnet_depth.py` and `prompt_config.yml`, and made some changes to the TEXTure code so that ControlNet can be used within TEXTure. 8 | 9 | ## Installation 10 | 11 | To install TEXTure with ControlNet, first create a virtual environment with the requirements for ControlNet: 12 | 13 | ```bash 14 | conda env create -f environment.yaml 15 | conda activate texture 16 | ``` 17 | 18 | Then, install the requirements for TEXTure: 19 | 20 | ```bash 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | and Kaolin: 25 | 26 | ```bash 27 | pip install kaolin==0.11.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/{TORCH_VER}_{CUDA_VER}.html 28 | ``` 29 | 30 | Note that you also need an access token for Stable Diffusion. 31 | First, accept the conditions for the model you want to use; the default is [`stabilityai/stable-diffusion-2-depth`](https://huggingface.co/stabilityai/stable-diffusion-2-depth). Then either add a `TOKEN` file containing your [access token](https://huggingface.co/settings/tokens) to the root folder of this project, or use the `huggingface-cli login` command. 32 | 33 | To be able to run ControlNet, you will also need to download the [ControlNet depth model](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_depth.pth) and put it inside the `src/models` folder. 34 | 35 | 36 | ## Running 37 | 38 | We provide an example that uses a 3D avatar of Spider-Man generated with [ECON](https://github.com/YuliangXiu/ECON). To run it, move to the root folder of this repository and execute the following command: 39 | 40 | ```bash 41 | python -m scripts.run_texture --config_path=configs/text_guided/spiderman_example.yaml 42 | ``` 43 | 44 | You will find the results in the `experiments` folder. 45 | 46 | If you want to run the basic TEXTure without ControlNet using this code, add the `control_net: False` attribute to the config file you want to execute.
For example: 47 | 48 | ```yaml 49 | log: 50 | exp_name: spiderman_example 51 | guide: 52 | text: "Amazing Spiderman, hyper realistic, {} view" 53 | append_direction: True 54 | shape_path: shapes/spiderman_example.obj 55 | control_net: False 56 | optim: 57 | seed: 3 58 | ``` 59 | -------------------------------------------------------------------------------- /configs/ControlNet/prompt_config.yml: -------------------------------------------------------------------------------- 1 | negative_prompt: 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 2 | additional_prompt: 'best quality, extremely detailed' -------------------------------------------------------------------------------- /configs/text_guided/napoleon.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | exp_name: napoleon 3 | guide: 4 | text: "A photo of Napoleon Bonaparte, {} view" 5 | append_direction: True 6 | shape_path: shapes/napoleon.obj 7 | optim: 8 | seed: 3 9 | -------------------------------------------------------------------------------- /configs/text_guided/nascar.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | exp_name: nascar 3 | save_interval: 18 4 | guide: 5 | text: "A next gen nascar, {} view" 6 | diffusion_name: stabilityai/stable-diffusion-2-depth 7 | shape_scale: 0.6 8 | append_direction: True 9 | shape_path: shapes/nascar.obj 10 | texture_resolution: 1024 11 | guidance_scale: 10 12 | texture_interpolation_mode: 'bilinear' 13 | optim: 14 | seed: 2 15 | render: 16 | front_offset: 0 17 | -------------------------------------------------------------------------------- /configs/text_guided/spiderman_example.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | exp_name: spiderman_example 3 | guide: 4 | text: "Amazing Spiderman, hyper realistic, {} view" 5 | append_direction: True 6 | shape_path: shapes/spiderman_example.obj 7 | optim: 8 | seed: 3 -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: texture 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.16.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | transformers==4.22.1 4 | diffusers==0.12.1 5 | accelerate 6 | huggingface-hub 7 | ninja 8 | xatlas 9 | imageio 10 | matplotlib 11 | pyrallis 12 | loguru 13 | tqdm 14 | einops 15 | opencv-python 16 | 
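The following is an illustrative sketch, not a file in this repository, showing how the YAML configs above map onto the dataclasses defined in `src/configs/train_config.py` (shown further below) and how ControlNet can be toggled off programmatically via `GuideConfig.control_net`, mirroring what `scripts/run_texture.py` does through pyrallis. The class and field names (`TrainConfig`, `GuideConfig`, `LogConfig`, `OptimConfig`, `TEXTure`) come from the repository itself; everything else here is an assumption for illustration only.

```python
# Illustrative sketch only: build the Spider-Man example configuration in Python
# instead of YAML, using the dataclasses defined in src/configs/train_config.py.
from src.configs.train_config import TrainConfig, GuideConfig, LogConfig, OptimConfig
from src.training.trainer import TEXTure

cfg = TrainConfig(
    log=LogConfig(exp_name='spiderman_example'),
    guide=GuideConfig(
        text='Amazing Spiderman, hyper realistic, {} view',
        append_direction=True,
        shape_path='shapes/spiderman_example.obj',
        control_net=False,  # False = plain depth-guided TEXTure; True (the default) uses ControlNet
    ),
    optim=OptimConfig(seed=3),
)

if __name__ == '__main__':
    # Mirrors scripts/run_texture.py: paint the mesh and write the results
    # under experiments/spiderman_example
    TEXTure(cfg).paint()
```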
-------------------------------------------------------------------------------- /scripts/run_texture.py: -------------------------------------------------------------------------------- 1 | import pyrallis 2 | 3 | from src.configs.train_config import TrainConfig 4 | from src.training.trainer import TEXTure 5 | 6 | 7 | @pyrallis.wrap() 8 | def main(cfg: TrainConfig): 9 | trainer = TEXTure(cfg) 10 | if cfg.log.eval_only: 11 | trainer.full_eval() 12 | else: 13 | trainer.paint() 14 | 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/__init__.py -------------------------------------------------------------------------------- /src/annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | -------------------------------------------------------------------------------- /src/cldm/ddim_hacked.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from src.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | 17 | def register_buffer(self, name, attr): 18 | if type(attr) == torch.Tensor: 19 | if attr.device != torch.device("cuda"): 20 | attr = attr.to(torch.device("cuda")) 21 | setattr(self, name, attr) 22 | 23 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 24 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 25 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 26 | alphas_cumprod = self.model.alphas_cumprod 27 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 28 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 29 | 30 | self.register_buffer('betas', to_torch(self.model.betas)) 31 | self.register_buffer('alphas_cumprod', 
to_torch(alphas_cumprod)) 32 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 33 | 34 | # calculations for diffusion q(x_t | x_{t-1}) and others 35 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 36 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 37 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 40 | 41 | # ddim sampling parameters 42 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 43 | ddim_timesteps=self.ddim_timesteps, 44 | eta=ddim_eta,verbose=verbose) 45 | self.register_buffer('ddim_sigmas', ddim_sigmas) 46 | self.register_buffer('ddim_alphas', ddim_alphas) 47 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 48 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 49 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 50 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 51 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 52 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 53 | 54 | @torch.no_grad() 55 | def sample(self, 56 | S, 57 | batch_size, 58 | shape, 59 | conditioning=None, 60 | callback=None, 61 | normals_sequence=None, 62 | img_callback=None, 63 | quantize_x0=False, 64 | eta=0., 65 | mask=None, 66 | x0=None, 67 | temperature=1., 68 | noise_dropout=0., 69 | score_corrector=None, 70 | corrector_kwargs=None, 71 | verbose=True, 72 | x_T=None, 73 | log_every_t=100, 74 | unconditional_guidance_scale=1., 75 | unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
76 | dynamic_threshold=None, 77 | ucg_schedule=None, 78 | **kwargs 79 | ): 80 | if conditioning is not None: 81 | if isinstance(conditioning, dict): 82 | ctmp = conditioning[list(conditioning.keys())[0]] 83 | while isinstance(ctmp, list): ctmp = ctmp[0] 84 | cbs = ctmp.shape[0] 85 | if cbs != batch_size: 86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 87 | 88 | elif isinstance(conditioning, list): 89 | for ctmp in conditioning: 90 | if ctmp.shape[0] != batch_size: 91 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 92 | 93 | else: 94 | if conditioning.shape[0] != batch_size: 95 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 96 | 97 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 98 | # sampling 99 | C, H, W = shape 100 | size = (batch_size, C, H, W) 101 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 102 | 103 | samples, intermediates = self.ddim_sampling(conditioning, size, 104 | callback=callback, 105 | img_callback=img_callback, 106 | quantize_denoised=quantize_x0, 107 | mask=mask, x0=x0, 108 | ddim_use_original_steps=False, 109 | noise_dropout=noise_dropout, 110 | temperature=temperature, 111 | score_corrector=score_corrector, 112 | corrector_kwargs=corrector_kwargs, 113 | x_T=x_T, 114 | log_every_t=log_every_t, 115 | unconditional_guidance_scale=unconditional_guidance_scale, 116 | unconditional_conditioning=unconditional_conditioning, 117 | dynamic_threshold=dynamic_threshold, 118 | ucg_schedule=ucg_schedule 119 | ) 120 | return samples, intermediates 121 | 122 | @torch.no_grad() 123 | def ddim_sampling(self, cond, shape, 124 | x_T=None, ddim_use_original_steps=False, 125 | callback=None, timesteps=None, quantize_denoised=False, 126 | mask=None, x0=None, img_callback=None, log_every_t=100, 127 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 128 | unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, 129 | ucg_schedule=None): 130 | device = self.model.betas.device 131 | b = shape[0] 132 | if x_T is None: 133 | img = torch.randn(shape, device=device) 134 | else: 135 | img = x_T 136 | 137 | if timesteps is None: 138 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 139 | elif timesteps is not None and not ddim_use_original_steps: 140 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 141 | timesteps = self.ddim_timesteps[:subset_end] 142 | 143 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 144 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 145 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 146 | print(f"Running DDIM Sampling with {total_steps} timesteps") 147 | 148 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 149 | 150 | for i, step in enumerate(iterator): 151 | index = total_steps - i - 1 152 | ts = torch.full((b,), step, device=device, dtype=torch.long) 153 | 154 | if mask is not None: 155 | assert x0 is not None 156 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 157 | img = img_orig * mask + (1. 
- mask) * img 158 | 159 | if ucg_schedule is not None: 160 | assert len(ucg_schedule) == len(time_range) 161 | unconditional_guidance_scale = ucg_schedule[i] 162 | 163 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 164 | quantize_denoised=quantize_denoised, temperature=temperature, 165 | noise_dropout=noise_dropout, score_corrector=score_corrector, 166 | corrector_kwargs=corrector_kwargs, 167 | unconditional_guidance_scale=unconditional_guidance_scale, 168 | unconditional_conditioning=unconditional_conditioning, 169 | dynamic_threshold=dynamic_threshold) 170 | img, pred_x0 = outs 171 | if callback: callback(i) 172 | if img_callback: img_callback(pred_x0, i) 173 | 174 | if index % log_every_t == 0 or index == total_steps - 1: 175 | intermediates['x_inter'].append(img) 176 | intermediates['pred_x0'].append(pred_x0) 177 | 178 | return img, intermediates 179 | 180 | @torch.no_grad() 181 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 182 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 183 | unconditional_guidance_scale=1., unconditional_conditioning=None, 184 | dynamic_threshold=None): 185 | b, *_, device = *x.shape, x.device 186 | 187 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 188 | model_output = self.model.apply_model(x, t, c) 189 | else: 190 | model_t = self.model.apply_model(x, t, c) 191 | model_uncond = self.model.apply_model(x, t, unconditional_conditioning) 192 | model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) 193 | 194 | if self.model.parameterization == "v": 195 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 196 | else: 197 | e_t = model_output 198 | 199 | if score_corrector is not None: 200 | assert self.model.parameterization == "eps", 'not implemented' 201 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 202 | 203 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 204 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 205 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 206 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 207 | # select parameters corresponding to the currently considered timestep 208 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 209 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 210 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 211 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 212 | 213 | # current prediction for x_0 214 | if self.model.parameterization != "v": 215 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 216 | else: 217 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 218 | 219 | if quantize_denoised: 220 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 221 | 222 | if dynamic_threshold is not None: 223 | raise NotImplementedError() 224 | 225 | # direction pointing to x_t 226 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 227 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 228 | if noise_dropout > 0.: 229 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 230 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 231 | return x_prev, pred_x0 232 | 233 | @torch.no_grad() 234 | def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, 235 | unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): 236 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 237 | num_reference_steps = timesteps.shape[0] 238 | 239 | assert t_enc <= num_reference_steps 240 | num_steps = t_enc 241 | 242 | if use_original_steps: 243 | alphas_next = self.alphas_cumprod[:num_steps] 244 | alphas = self.alphas_cumprod_prev[:num_steps] 245 | else: 246 | alphas_next = self.ddim_alphas[:num_steps] 247 | alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) 248 | 249 | x_next = x0 250 | intermediates = [] 251 | inter_steps = [] 252 | for i in tqdm(range(num_steps), desc='Encoding Image'): 253 | t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long) 254 | if unconditional_guidance_scale == 1.: 255 | noise_pred = self.model.apply_model(x_next, t, c) 256 | else: 257 | assert unconditional_conditioning is not None 258 | e_t_uncond, noise_pred = torch.chunk( 259 | self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), 260 | torch.cat((unconditional_conditioning, c))), 2) 261 | noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) 262 | 263 | xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next 264 | weighted_noise_pred = alphas_next[i].sqrt() * ( 265 | (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred 266 | x_next = xt_weighted + weighted_noise_pred 267 | if return_intermediates and i % ( 268 | num_steps // return_intermediates) == 0 and i < num_steps - 1: 269 | intermediates.append(x_next) 270 | inter_steps.append(i) 271 | elif return_intermediates and i >= num_steps - 2: 272 | intermediates.append(x_next) 273 | inter_steps.append(i) 274 | if callback: callback(i) 275 | 276 | out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} 277 | if return_intermediates: 278 | out.update({'intermediates': intermediates}) 279 | return x_next, out 280 | 281 | @torch.no_grad() 282 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 283 | # fast, but does not allow for exact reconstruction 284 | # t serves as an index to gather the correct alphas 285 | if use_original_steps: 286 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 287 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 288 | else: 289 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 290 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 291 | 292 | if noise is None: 293 | noise = torch.randn_like(x0) 294 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 295 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 296 | 297 | @torch.no_grad() 298 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 299 | use_original_steps=False, callback=None): 300 | 301 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 302 | timesteps = timesteps[:t_start] 303 | 304 | time_range = np.flip(timesteps) 305 | total_steps = 
timesteps.shape[0] 306 | print(f"Running DDIM Sampling with {total_steps} timesteps") 307 | 308 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 309 | x_dec = x_latent 310 | for i, step in enumerate(iterator): 311 | index = total_steps - i - 1 312 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 313 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 314 | unconditional_guidance_scale=unconditional_guidance_scale, 315 | unconditional_conditioning=unconditional_conditioning) 316 | if callback: callback(i) 317 | return x_dec 318 | -------------------------------------------------------------------------------- /src/cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import src.ldm.modules.encoders.modules 5 | import src.ldm.modules.attention 6 | 7 | from transformers import logging 8 | from src.ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = 
list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /src/cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if 
self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /src/cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from src.ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/configs/__init__.py -------------------------------------------------------------------------------- /src/configs/train_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Tuple, List 4 | from loguru import logger 5 | 6 | 7 | @dataclass 8 | class RenderConfig: 9 | """ Parameters for the Mesh Renderer """ 10 | # Grid size for rendering during painting 11 | train_grid_size: int = 1200 12 | # Grid size of evaluation 13 | eval_grid_size: int = 1024 14 | # training camera radius range 15 | radius: float = 1.5 16 | # Set [0,overhead_range] as the overhead region 17 | overhead_range: float = 40 18 | # Define the front angle region 19 | front_range: float = 70 20 | # The front offset, use to rotate shape from code 21 | front_offset:float = 0.0 22 | # Number of views to use 23 | n_views: int = 8 24 | # Theta value for rendering during training 25 | base_theta:float = 60 26 | # Additional views to use before rotating around shape 27 | views_before: List[Tuple[float,float]] = field(default_factory=list) 28 | # Additional views to use after rotating around shape 29 | views_after: List[Tuple[float, float]] = field(default_factory=[[180,30],[180,150]].copy) 30 | # Whether to alternate between the rotating views from the different sides 31 | alternate_views: bool = True 32 | 33 | @dataclass 34 | class GuideConfig: 35 | """ Parameters defining the guidance """ 36 | # Guiding text prompt 37 | text: str 38 | # The mesh to paint 39 | shape_path: str = 'shapes/spot_triangulated.obj' 40 | # Append 
direction to text prompts 41 | append_direction: bool = True 42 | # A Textual-Inversion concept to use 43 | concept_name: Optional[str] = None 44 | # Path to the TI embedding 45 | concept_path: Optional[Path] = None 46 | # A huggingface diffusion model to use 47 | diffusion_name: str = 'stabilityai/stable-diffusion-2-depth' 48 | # Whether to use ControlNET as main model or not (if true, the model specified in diffusion_name is used only for inpainting) 49 | control_net: bool = True 50 | # Scale of mesh in 1x1x1 cube 51 | shape_scale: float = 0.6 52 | # height of mesh 53 | dy: float = 0.25 54 | # texture image resolution 55 | texture_resolution: int = 1024 56 | # texture mapping interpolation mode from texture image, options: 'nearest', 'bilinear', 'bicubic' 57 | texture_interpolation_mode: str= 'bilinear' 58 | # Guidance scale for score distillation 59 | guidance_scale: float = 7.5 60 | # Use inpainting in relevant iterations 61 | use_inpainting: bool = True 62 | # The texture before editing 63 | reference_texture: Optional[Path] = None 64 | # The edited texture 65 | initial_texture: Optional[Path] = None 66 | # Whether to use background color or image 67 | use_background_color: bool = False 68 | # Background image to use 69 | background_img: str = 'textures/brick_wall.png' 70 | # Threshold for defining refine regions 71 | z_update_thr: float = 0.2 72 | # Some more strict masking for projecting back 73 | strict_projection: bool = True 74 | 75 | 76 | @dataclass 77 | class OptimConfig: 78 | """ Parameters for the optimization process """ 79 | # Seed for experiment 80 | seed: int = 0 81 | # Learning rate for projection 82 | lr: float = 1e-2 83 | # For Diffusion model 84 | min_timestep: float = 0.02 85 | # For Diffusion model 86 | max_timestep: float = 0.98 87 | # For Diffusion model 88 | no_noise: bool = False 89 | 90 | 91 | @dataclass 92 | class LogConfig: 93 | """ Parameters for logging and saving """ 94 | # Experiment name 95 | exp_name: str 96 | # Experiment output dir 97 | exp_root: Path = Path('experiments/') 98 | # How many steps between save step 99 | save_interval: int = 100 100 | # Run only test 101 | eval_only: bool = False 102 | # Number of angles to sample for eval during training 103 | eval_size: int = 10 104 | # Number of angles to sample for eval after training 105 | full_eval_size: int = 100 106 | # Export a mesh 107 | save_mesh: bool = True 108 | # Whether to show intermediate diffusion visualizations 109 | vis_diffusion_steps: bool = False 110 | # Whether to log intermediate images 111 | log_images: bool = True 112 | 113 | @property 114 | def exp_dir(self) -> Path: 115 | return self.exp_root / self.exp_name 116 | 117 | 118 | @dataclass 119 | class TrainConfig: 120 | """ The main configuration for the coach trainer """ 121 | log: LogConfig = field(default_factory=LogConfig) 122 | render: RenderConfig = field(default_factory=RenderConfig) 123 | optim: OptimConfig = field(default_factory=OptimConfig) 124 | guide: GuideConfig = field(default_factory=GuideConfig) 125 | -------------------------------------------------------------------------------- /src/controlnet_depth.py: -------------------------------------------------------------------------------- 1 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler 2 | from huggingface_hub import hf_hub_download 3 | from transformers import CLIPTextModel, CLIPTokenizer, logging 4 | 5 | # suppress partial model loading warning 6 | from src import utils 7 | from src.utils import seed_everything 8 | from 
src.annotator.util import resize_image, HWC3 9 | from src.cldm.model import create_model, load_state_dict 10 | 11 | logging.set_verbosity_error() 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from loguru import logger 16 | from tqdm.auto import tqdm 17 | import cv2 18 | import numpy as np 19 | from PIL import Image 20 | import einops 21 | 22 | 23 | class ControlNet(nn.Module): 24 | def __init__(self, device, model_name='CompVis/stable-diffusion-v1-4', concept_name=None, concept_path=None, 25 | latent_mode=True, min_timestep=0.02, max_timestep=0.98, no_noise=False, 26 | use_inpaint=False): 27 | super().__init__() 28 | 29 | try: 30 | with open('./TOKEN', 'r') as f: 31 | self.token = f.read().replace('\n', '') # remove the last \n! 32 | logger.info(f'loaded hugging face access token from ./TOKEN!') 33 | except FileNotFoundError as e: 34 | self.token = True 35 | logger.warning( 36 | f'try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.') 37 | 38 | self.device = device 39 | self.latent_mode = latent_mode 40 | self.no_noise = no_noise 41 | self.num_train_timesteps = 1000 42 | self.min_step = int(self.num_train_timesteps * min_timestep) 43 | self.max_step = int(self.num_train_timesteps * max_timestep) 44 | self.use_inpaint = use_inpaint 45 | 46 | logger.info(f'loading stable diffusion with {model_name}...') 47 | 48 | # 1. Load the autoencoder model which will be used to decode the latents into image space. 49 | self.vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", use_auth_token=self.token).to(self.device) 50 | 51 | # 2. Load the tokenizer and text encoder to tokenize and encode the text. 52 | self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer', use_auth_token=self.token) 53 | self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', 54 | use_auth_token=self.token).to(self.device) 55 | self.image_encoder = None 56 | self.image_processor = None 57 | 58 | # 3. The UNet model for generating the latents. 59 | # CHANGED TO REPLACE UNET WITH CONTROLNET 60 | logger.info(f'loading control_net from ./src/models/cldm_v15.yaml...') 61 | self.unet = create_model('./src/models/cldm_v15.yaml').cpu() 62 | self.unet.load_state_dict(load_state_dict('./src/models/control_sd15_depth.pth', location=self.device)) 63 | self.unet.to(self.device) 64 | 65 | if self.use_inpaint: 66 | self.inpaint_unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-inpainting", 67 | subfolder="unet", use_auth_token=self.token).to( 68 | self.device) 69 | 70 | 71 | # 4. 
Create a scheduler for inference 72 | self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 73 | num_train_timesteps=self.num_train_timesteps, steps_offset=1, 74 | skip_prk_steps=True) 75 | self.alphas = self.scheduler.alphas_cumprod.to(self.device) 76 | 77 | if concept_name is not None: 78 | self.load_concept(concept_name, concept_path) 79 | logger.info(f'\t successfully loaded stable diffusion!') 80 | 81 | def load_concept(self, concept_name, concept_path=None): 82 | if concept_path is None: 83 | repo_id_embeds = f"sd-concepts-library/{concept_name}" 84 | learned_embeds_path = hf_hub_download(repo_id=repo_id_embeds, filename="learned_embeds.bin") 85 | else: 86 | learned_embeds_path = concept_path 87 | 88 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 89 | 90 | # separate token and the embeds 91 | for trained_token in loaded_learned_embeds: 92 | # trained_token = list(loaded_learned_embeds.keys())[0] 93 | print(f'Loading token for {trained_token}') 94 | embeds = loaded_learned_embeds[trained_token] 95 | 96 | # cast to dtype of text_encoder 97 | dtype = self.text_encoder.get_input_embeddings().weight.dtype 98 | embeds.to(dtype) 99 | 100 | # add the token in tokenizer 101 | token = trained_token 102 | num_added_tokens = self.tokenizer.add_tokens(token) 103 | if num_added_tokens == 0: 104 | raise ValueError( 105 | f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.") 106 | 107 | # resize the token embeddings 108 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 109 | 110 | # get the id for the token and assign the embeds 111 | token_id = self.tokenizer.convert_tokens_to_ids(token) 112 | self.text_encoder.get_input_embeddings().weight.data[token_id] = embeds 113 | 114 | def get_text_embeds(self, prompt, negative_prompt=None): 115 | # Tokenize text and get embeddings 116 | text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, 117 | truncation=True, return_tensors='pt') 118 | logger.info(prompt) 119 | logger.info(text_input.input_ids) 120 | 121 | with torch.no_grad(): 122 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 123 | 124 | # Do the same for unconditional embeddings 125 | if negative_prompt is None: 126 | negative_prompt = [''] * len(prompt) 127 | uncond_input = self.tokenizer(negative_prompt, padding='max_length', 128 | max_length=self.tokenizer.model_max_length, return_tensors='pt') 129 | 130 | with torch.no_grad(): 131 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 132 | 133 | # Cat for final embeddings 134 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 135 | return text_embeddings 136 | 137 | 138 | def get_control_net_inputs(self, depth_mask, prompt, a_prompt, n_prompt, image_resolution, strength): 139 | num_samples = 1 140 | guess_mode = False 141 | input_image = np.uint8(np.array(depth_mask[0,0,:,:].cpu())*255) 142 | input_image = HWC3(input_image) 143 | img = resize_image(input_image, image_resolution) 144 | H, W, C = img.shape 145 | 146 | input_image = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LINEAR) 147 | 148 | control = torch.from_numpy(input_image.copy()).float().cuda() / 255.0 149 | control = torch.stack([control for _ in range(num_samples)], dim=0) 150 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 151 | 152 | cond = {"c_concat": [control], 
"c_crossattn": [self.unet.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} 153 | un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [self.unet.get_learned_conditioning([n_prompt] * num_samples)]} 154 | shape = (4, H // 8, W // 8) 155 | 156 | self.unet.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 157 | 158 | return cond, un_cond 159 | 160 | 161 | 162 | def img2img_step(self, text_embeddings, inputs, depth_mask, control_net_cond, control_net_uncond, 163 | guidance_scale=100, strength=0.5, num_inference_steps=50, update_mask=None, 164 | latent_mode=False, check_mask=None, fixed_seed=None, check_mask_iters=0.5, intermediate_vis=False): 165 | # input is 1 3 512 512 166 | # depth_mask is 1 1 512 512 167 | # text_embeddings is 2 512 168 | intermediate_results = [] 169 | 170 | def sample(latents, depth_mask, strength, num_inference_steps, update_mask=None, check_mask=None, 171 | masked_latents=None): 172 | self.scheduler.set_timesteps(num_inference_steps) 173 | noise = None 174 | if latents is None: 175 | # Last chanel is reserved for depth 176 | latents = torch.randn( 177 | ( 178 | text_embeddings.shape[0] // 2, self.unet.control_model.in_channels, depth_mask.shape[2], 179 | depth_mask.shape[3]), 180 | device=self.device) 181 | timesteps = self.scheduler.timesteps 182 | else: 183 | # Strength has meaning only when latents are given 184 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) 185 | latent_timestep = timesteps[:1] 186 | if fixed_seed is not None: 187 | seed_everything(fixed_seed) 188 | noise = torch.randn_like(latents) 189 | if update_mask is not None: 190 | gt_latents = latents 191 | latents = torch.randn( 192 | (text_embeddings.shape[0] // 2, self.unet.control_model.in_channels, depth_mask.shape[2], 193 | depth_mask.shape[3]), 194 | device=self.device) 195 | else: 196 | latents = self.scheduler.add_noise(latents, noise, latent_timestep) 197 | 198 | depth_mask = torch.cat([depth_mask] * 2) 199 | 200 | with torch.autocast('cuda'): 201 | for i, t in tqdm(enumerate(timesteps)): 202 | is_inpaint_range = self.use_inpaint and (10 < i < 20) 203 | mask_constraints_iters = True # i < 20 204 | is_inpaint_iter = is_inpaint_range # and i %2 == 1 205 | 206 | if not is_inpaint_range and mask_constraints_iters: 207 | if update_mask is not None: 208 | noised_truth = self.scheduler.add_noise(gt_latents, noise, t) 209 | if check_mask is not None and i < int(len(timesteps) * check_mask_iters): 210 | curr_mask = check_mask 211 | else: 212 | curr_mask = update_mask 213 | latents = latents * curr_mask + noised_truth * (1 - curr_mask) 214 | 215 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
216 | latent_model_input = torch.cat([latents] * 2) 217 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, 218 | t) # NOTE: This does nothing 219 | 220 | if is_inpaint_iter: 221 | latent_mask = torch.cat([update_mask] * 2) 222 | latent_image = torch.cat([masked_latents] * 2) 223 | latent_model_input_inpaint = torch.cat([latent_model_input, latent_mask, latent_image], dim=1) 224 | with torch.no_grad(): 225 | noise_pred_inpaint = \ 226 | self.inpaint_unet(latent_model_input_inpaint, t, encoder_hidden_states=text_embeddings)[ 227 | 'sample'] 228 | noise_pred = noise_pred_inpaint 229 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 230 | else: 231 | # CHANGED TO USE CONTROLNET 232 | # predict the noise residual 233 | with torch.no_grad(): 234 | ts = torch.full((1,), t, device=self.device, dtype=torch.long) 235 | noise_pred_text = self.unet.apply_model(latents, ts, control_net_cond) 236 | noise_pred_uncond = self.unet.apply_model(latents, ts, control_net_uncond) 237 | 238 | # perform guidance 239 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 240 | 241 | # compute the previous noisy sample x_t -> x_t-1 242 | 243 | if intermediate_vis: 244 | vis_alpha_t = torch.sqrt(self.scheduler.alphas_cumprod) 245 | vis_sigma_t = torch.sqrt(1 - self.scheduler.alphas_cumprod) 246 | a_t, s_t = vis_alpha_t[t], vis_sigma_t[t] 247 | vis_latents = (latents - s_t * noise) / a_t 248 | vis_latents = 1 / 0.18215 * vis_latents 249 | image = self.vae.decode(vis_latents).sample 250 | image = (image / 2 + 0.5).clamp(0, 1) 251 | image = image.cpu().permute(0, 2, 3, 1).numpy() 252 | image = Image.fromarray((image[0] * 255).round().astype("uint8")) 253 | intermediate_results.append(image) 254 | latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] 255 | 256 | return latents 257 | 258 | depth_mask = F.interpolate(depth_mask, size=(64, 64), mode='bicubic', 259 | align_corners=False) 260 | masked_latents = None 261 | if inputs is None: 262 | latents = None 263 | elif latent_mode: 264 | latents = inputs 265 | else: 266 | pred_rgb_512 = F.interpolate(inputs, (512, 512), mode='bilinear', 267 | align_corners=False) 268 | latents = self.encode_imgs(pred_rgb_512) 269 | if self.use_inpaint: 270 | update_mask_512 = F.interpolate(update_mask, (512, 512)) 271 | masked_inputs = pred_rgb_512 * (update_mask_512 < 0.5) + 0.5 * (update_mask_512 >= 0.5) 272 | masked_latents = self.encode_imgs(masked_inputs) 273 | 274 | if update_mask is not None: 275 | update_mask = F.interpolate(update_mask, (64, 64), mode='nearest') 276 | if check_mask is not None: 277 | check_mask = F.interpolate(check_mask, (64, 64), mode='nearest') 278 | 279 | depth_mask = 2.0 * (depth_mask - depth_mask.min()) / (depth_mask.max() - depth_mask.min()) - 1.0 280 | 281 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 282 | t = (self.min_step + self.max_step) // 2 283 | 284 | with torch.no_grad(): 285 | target_latents = sample(latents, depth_mask, strength=strength, num_inference_steps=num_inference_steps, 286 | update_mask=update_mask, check_mask=check_mask, masked_latents=masked_latents) 287 | target_rgb = self.decode_latents(target_latents) 288 | 289 | if latent_mode: 290 | return target_rgb, target_latents 291 | else: 292 | return target_rgb, intermediate_results 293 | 294 | def decode_latents(self, latents): 295 | latents = 1 / 0.18215 * latents 296 | 297 | with torch.no_grad(): 298 | imgs = self.vae.decode(latents).sample 299 | 300 | imgs = (imgs / 2 + 0.5).clamp(0, 
1) 301 | 302 | return imgs 303 | 304 | def encode_imgs(self, imgs): 305 | # imgs: [B, 3, H, W] 306 | 307 | imgs = 2 * imgs - 1 308 | 309 | posterior = self.vae.encode(imgs).latent_dist 310 | latents = posterior.sample() * 0.18215 311 | 312 | return latents 313 | 314 | def get_timesteps(self, num_inference_steps, strength): 315 | # get the original timestep using init_timestep 316 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 317 | 318 | t_start = max(num_inference_steps - init_timestep, 0) 319 | timesteps = self.scheduler.timesteps[t_start:] 320 | 321 | return timesteps, num_inference_steps - t_start 322 | 323 | -------------------------------------------------------------------------------- /src/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/data/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /src/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from src.ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from src.ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from src.ldm.util import instantiate_from_config 10 | from src.ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
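        # min-max rescale the randomly colorized segmentation to [-1, 1] for logging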
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
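               # note (annotation): unconditional_guidance_scale > 1 enables classifier-free guidance via model_wrapper below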
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /src/ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from src.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from src.ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps 
else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. - mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters 
corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /src/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
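    Example (illustrative):
        >>> append_dims(torch.ones(2), 4).shape
        torch.Size([2, 1, 1, 1])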
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /src/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | from typing import Optional, Any 8 | 9 | from src.ldm.modules.diffusionmodules.util import checkpoint 10 | 11 | 12 | try: 13 | import xformers 14 | import xformers.ops 15 | XFORMERS_IS_AVAILBLE = True 16 | except: 17 | XFORMERS_IS_AVAILBLE = False 18 | 19 | # CrossAttn precision handling 20 | import os 21 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | 27 | def uniq(arr): 28 | return{el: True for el in arr}.keys() 29 | 30 | 31 | def default(val, d): 32 | if exists(val): 33 | return val 34 | return d() if isfunction(d) else d 35 | 36 | 37 | def max_neg_value(t): 38 | return -torch.finfo(t.dtype).max 39 | 40 | 41 | def init_(tensor): 42 | dim = tensor.shape[-1] 43 | std = 1 / math.sqrt(dim) 44 | tensor.uniform_(-std, std) 45 | return tensor 46 | 47 | 48 | # feedforward 49 | class GEGLU(nn.Module): 50 | def __init__(self, dim_in, dim_out): 51 | super().__init__() 52 | self.proj = nn.Linear(dim_in, dim_out * 2) 53 | 54 | def forward(self, x): 55 | x, gate = self.proj(x).chunk(2, dim=-1) 56 | return x * F.gelu(gate) 57 | 58 | 59 | class FeedForward(nn.Module): 60 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 61 | super().__init__() 62 | inner_dim = int(dim * mult) 63 | dim_out = default(dim_out, dim) 64 | project_in = nn.Sequential( 65 | nn.Linear(dim, inner_dim), 66 | nn.GELU() 67 | ) if not glu else GEGLU(dim, inner_dim) 68 | 69 | self.net = nn.Sequential( 70 | project_in, 71 | nn.Dropout(dropout), 72 | nn.Linear(inner_dim, dim_out) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | def zero_module(module): 80 | """ 81 | Zero out the parameters of a module and return it. 
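    In this file it zero-initializes proj_out of SpatialTransformer, so the block initially acts as an identity (its residual branch starts at zero).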
82 | """ 83 | for p in module.parameters(): 84 | p.detach().zero_() 85 | return module 86 | 87 | 88 | def Normalize(in_channels): 89 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 90 | 91 | 92 | class SpatialSelfAttention(nn.Module): 93 | def __init__(self, in_channels): 94 | super().__init__() 95 | self.in_channels = in_channels 96 | 97 | self.norm = Normalize(in_channels) 98 | self.q = torch.nn.Conv2d(in_channels, 99 | in_channels, 100 | kernel_size=1, 101 | stride=1, 102 | padding=0) 103 | self.k = torch.nn.Conv2d(in_channels, 104 | in_channels, 105 | kernel_size=1, 106 | stride=1, 107 | padding=0) 108 | self.v = torch.nn.Conv2d(in_channels, 109 | in_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | self.proj_out = torch.nn.Conv2d(in_channels, 114 | in_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | 119 | def forward(self, x): 120 | h_ = x 121 | h_ = self.norm(h_) 122 | q = self.q(h_) 123 | k = self.k(h_) 124 | v = self.v(h_) 125 | 126 | # compute attention 127 | b,c,h,w = q.shape 128 | q = rearrange(q, 'b c h w -> b (h w) c') 129 | k = rearrange(k, 'b c h w -> b c (h w)') 130 | w_ = torch.einsum('bij,bjk->bik', q, k) 131 | 132 | w_ = w_ * (int(c)**(-0.5)) 133 | w_ = torch.nn.functional.softmax(w_, dim=2) 134 | 135 | # attend to values 136 | v = rearrange(v, 'b c h w -> b c (h w)') 137 | w_ = rearrange(w_, 'b i j -> b j i') 138 | h_ = torch.einsum('bij,bjk->bik', v, w_) 139 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 140 | h_ = self.proj_out(h_) 141 | 142 | return x+h_ 143 | 144 | 145 | class CrossAttention(nn.Module): 146 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 147 | super().__init__() 148 | inner_dim = dim_head * heads 149 | context_dim = default(context_dim, query_dim) 150 | 151 | self.scale = dim_head ** -0.5 152 | self.heads = heads 153 | 154 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 155 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 156 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 157 | 158 | self.to_out = nn.Sequential( 159 | nn.Linear(inner_dim, query_dim), 160 | nn.Dropout(dropout) 161 | ) 162 | 163 | def forward(self, x, context=None, mask=None): 164 | h = self.heads 165 | 166 | q = self.to_q(x) 167 | context = default(context, x) 168 | k = self.to_k(context) 169 | v = self.to_v(context) 170 | 171 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 172 | 173 | # force cast to fp32 to avoid overflowing 174 | if _ATTN_PRECISION =="fp32": 175 | with torch.autocast(enabled=False, device_type = 'cuda'): 176 | q, k = q.float(), k.float() 177 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 178 | else: 179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 180 | 181 | del q, k 182 | 183 | if exists(mask): 184 | mask = rearrange(mask, 'b ... 
-> b (...)') 185 | max_neg_value = -torch.finfo(sim.dtype).max 186 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 187 | sim.masked_fill_(~mask, max_neg_value) 188 | 189 | # attention, what we cannot get enough of 190 | sim = sim.softmax(dim=-1) 191 | 192 | out = einsum('b i j, b j d -> b i d', sim, v) 193 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 194 | return self.to_out(out) 195 | 196 | 197 | class MemoryEfficientCrossAttention(nn.Module): 198 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 199 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 200 | super().__init__() 201 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " 202 | f"{heads} heads.") 203 | inner_dim = dim_head * heads 204 | context_dim = default(context_dim, query_dim) 205 | 206 | self.heads = heads 207 | self.dim_head = dim_head 208 | 209 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 210 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 211 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 212 | 213 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 214 | self.attention_op: Optional[Any] = None 215 | 216 | def forward(self, x, context=None, mask=None): 217 | q = self.to_q(x) 218 | context = default(context, x) 219 | k = self.to_k(context) 220 | v = self.to_v(context) 221 | 222 | b, _, _ = q.shape 223 | q, k, v = map( 224 | lambda t: t.unsqueeze(3) 225 | .reshape(b, t.shape[1], self.heads, self.dim_head) 226 | .permute(0, 2, 1, 3) 227 | .reshape(b * self.heads, t.shape[1], self.dim_head) 228 | .contiguous(), 229 | (q, k, v), 230 | ) 231 | 232 | # actually compute the attention, what we cannot get enough of 233 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 234 | 235 | if exists(mask): 236 | raise NotImplementedError 237 | out = ( 238 | out.unsqueeze(0) 239 | .reshape(b, self.heads, out.shape[1], self.dim_head) 240 | .permute(0, 2, 1, 3) 241 | .reshape(b, out.shape[1], self.heads * self.dim_head) 242 | ) 243 | return self.to_out(out) 244 | 245 | 246 | class BasicTransformerBlock(nn.Module): 247 | ATTENTION_MODES = { 248 | "softmax": CrossAttention, # vanilla attention 249 | "softmax-xformers": MemoryEfficientCrossAttention 250 | } 251 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 252 | disable_self_attn=False): 253 | super().__init__() 254 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" 255 | assert attn_mode in self.ATTENTION_MODES 256 | attn_cls = self.ATTENTION_MODES[attn_mode] 257 | self.disable_self_attn = disable_self_attn 258 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 259 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 260 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 261 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 262 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 263 | self.norm1 = nn.LayerNorm(dim) 264 | self.norm2 = nn.LayerNorm(dim) 265 | self.norm3 = nn.LayerNorm(dim) 266 | self.checkpoint = checkpoint 267 | 268 | def forward(self, x, context=None): 269 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 270 | 271 | 
def _forward(self, x, context=None): 272 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 273 | x = self.attn2(self.norm2(x), context=context) + x 274 | x = self.ff(self.norm3(x)) + x 275 | return x 276 | 277 | 278 | class SpatialTransformer(nn.Module): 279 | """ 280 | Transformer block for image-like data. 281 | First, project the input (aka embedding) 282 | and reshape to b, t, d. 283 | Then apply standard transformer action. 284 | Finally, reshape to image 285 | NEW: use_linear for more efficiency instead of the 1x1 convs 286 | """ 287 | def __init__(self, in_channels, n_heads, d_head, 288 | depth=1, dropout=0., context_dim=None, 289 | disable_self_attn=False, use_linear=False, 290 | use_checkpoint=True): 291 | super().__init__() 292 | if exists(context_dim) and not isinstance(context_dim, list): 293 | context_dim = [context_dim] 294 | self.in_channels = in_channels 295 | inner_dim = n_heads * d_head 296 | self.norm = Normalize(in_channels) 297 | if not use_linear: 298 | self.proj_in = nn.Conv2d(in_channels, 299 | inner_dim, 300 | kernel_size=1, 301 | stride=1, 302 | padding=0) 303 | else: 304 | self.proj_in = nn.Linear(in_channels, inner_dim) 305 | 306 | self.transformer_blocks = nn.ModuleList( 307 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 308 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 309 | for d in range(depth)] 310 | ) 311 | if not use_linear: 312 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 313 | in_channels, 314 | kernel_size=1, 315 | stride=1, 316 | padding=0)) 317 | else: 318 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 319 | self.use_linear = use_linear 320 | 321 | def forward(self, x, context=None): 322 | # note: if no context is given, cross-attention defaults to self-attention 323 | if not isinstance(context, list): 324 | context = [context] 325 | b, c, h, w = x.shape 326 | x_in = x 327 | x = self.norm(x) 328 | if not self.use_linear: 329 | x = self.proj_in(x) 330 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 331 | if self.use_linear: 332 | x = self.proj_in(x) 333 | for i, block in enumerate(self.transformer_blocks): 334 | x = block(x, context=context[i]) 335 | if self.use_linear: 336 | x = self.proj_out(x) 337 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 338 | if not self.use_linear: 339 | x = self.proj_out(x) 340 | return x + x_in 341 | 342 | -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from src.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from src.ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | 
self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from src.ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
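    Example (illustrative): the cosine schedule of Nichol & Dhariwal (2021) corresponds to
        betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)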
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 126 | "dtype": torch.get_autocast_gpu_dtype(), 127 | "cache_enabled": torch.is_autocast_cache_enabled()} 128 | with torch.no_grad(): 129 | output_tensors = ctx.run_function(*ctx.input_tensors) 130 | return output_tensors 131 | 132 | @staticmethod 133 | def backward(ctx, *output_grads): 134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 135 | with torch.enable_grad(), \ 136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 137 | # Fixes a bug where the first op in run_function modifies the 138 | # Tensor storage in place, which is not allowed for detach()'d 139 | # Tensors. 140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 141 | output_tensors = ctx.run_function(*shallow_copies) 142 | input_grads = torch.autograd.grad( 143 | output_tensors, 144 | ctx.input_tensors + ctx.input_params, 145 | output_grads, 146 | allow_unused=True, 147 | ) 148 | del ctx.input_tensors 149 | del ctx.input_params 150 | del output_tensors 151 | return (None, None) + input_grads 152 | 153 | 154 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 155 | """ 156 | Create sinusoidal timestep embeddings. 157 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 158 | These may be fractional. 159 | :param dim: the dimension of the output. 160 | :param max_period: controls the minimum frequency of the embeddings. 161 | :return: an [N x dim] Tensor of positional embeddings. 162 | """ 163 | if not repeat_only: 164 | half = dim // 2 165 | freqs = torch.exp( 166 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 167 | ).to(device=timesteps.device) 168 | args = timesteps[:, None].float() * freqs[None] 169 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 170 | if dim % 2: 171 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 172 | else: 173 | embedding = repeat(timesteps, 'b -> b d', d=dim) 174 | return embedding 175 | 176 | 177 | def zero_module(module): 178 | """ 179 | Zero out the parameters of a module and return it. 
180 | """ 181 | for p in module.parameters(): 182 | p.detach().zero_() 183 | return module 184 | 185 | 186 | def scale_module(module, scale): 187 | """ 188 | Scale the parameters of a module and return it. 189 | """ 190 | for p in module.parameters(): 191 | p.detach().mul_(scale) 192 | return module 193 | 194 | 195 | def mean_flat(tensor): 196 | """ 197 | Take the mean over all non-batch dimensions. 198 | """ 199 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 200 | 201 | 202 | def normalization(channels): 203 | """ 204 | Make a standard normalization layer. 205 | :param channels: number of input channels. 206 | :return: an nn.Module for normalization. 207 | """ 208 | return GroupNorm32(32, channels) 209 | 210 | 211 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 212 | class SiLU(nn.Module): 213 | def forward(self, x): 214 | return x * torch.sigmoid(x) 215 | 216 | 217 | class GroupNorm32(nn.GroupNorm): 218 | def forward(self, x): 219 | return super().forward(x.float()).type(x.dtype) 220 | 221 | def conv_nd(dims, *args, **kwargs): 222 | """ 223 | Create a 1D, 2D, or 3D convolution module. 224 | """ 225 | if dims == 1: 226 | return nn.Conv1d(*args, **kwargs) 227 | elif dims == 2: 228 | return nn.Conv2d(*args, **kwargs) 229 | elif dims == 3: 230 | return nn.Conv3d(*args, **kwargs) 231 | raise ValueError(f"unsupported dimensions: {dims}") 232 | 233 | 234 | def linear(*args, **kwargs): 235 | """ 236 | Create a linear module. 237 | """ 238 | return nn.Linear(*args, **kwargs) 239 | 240 | 241 | def avg_pool_nd(dims, *args, **kwargs): 242 | """ 243 | Create a 1D, 2D, or 3D average pooling module. 244 | """ 245 | if dims == 1: 246 | return nn.AvgPool1d(*args, **kwargs) 247 | elif dims == 2: 248 | return nn.AvgPool2d(*args, **kwargs) 249 | elif dims == 3: 250 | return nn.AvgPool3d(*args, **kwargs) 251 | raise ValueError(f"unsupported dimensions: {dims}") 252 | 253 | 254 | class HybridConditioner(nn.Module): 255 | 256 | def __init__(self, c_concat_config, c_crossattn_config): 257 | super().__init__() 258 | self.concat_conditioner = instantiate_from_config(c_concat_config) 259 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 260 | 261 | def forward(self, c_concat, c_crossattn): 262 | c_concat = self.concat_conditioner(c_concat) 263 | c_crossattn = self.crossattn_conditioner(c_crossattn) 264 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 265 | 266 | 267 | def noise_like(shape, device, repeat=False): 268 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 269 | noise = lambda: torch.randn(shape, device=device) 270 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /src/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, 
value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
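    # closed form: KL = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)**2 * exp(-logvar2) - 1)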
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /src/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /src/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from src.ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, text): 169 | tokens = open_clip.tokenize(text) 170 | z = self.encode_with_transformer(tokens.to(self.device)) 171 | return z 172 | 173 | def encode_with_transformer(self, text): 174 | x = 
self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 175 | x = x + self.model.positional_embedding 176 | x = x.permute(1, 0, 2) # NLD -> LND 177 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 178 | x = x.permute(1, 0, 2) # LND -> NLD 179 | x = self.model.ln_final(x) 180 | return x 181 | 182 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 183 | for i, r in enumerate(self.model.transformer.resblocks): 184 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 185 | break 186 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 187 | x = checkpoint(r, x, attn_mask) 188 | else: 189 | x = r(x, attn_mask=attn_mask) 190 | return x 191 | 192 | def encode(self, text): 193 | return self(text) 194 | 195 | 196 | class FrozenCLIPT5Encoder(AbstractEncoder): 197 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 198 | clip_max_length=77, t5_max_length=77): 199 | super().__init__() 200 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 201 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 202 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 203 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 204 | 205 | def encode(self, text): 206 | return self(text) 207 | 208 | def forward(self, text): 209 | clip_z = self.clip_encoder.encode(text) 210 | t5_z = self.t5_encoder.encode(text) 211 | return [clip_z, t5_z] 212 | 213 | 214 | -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /src/ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from src.ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from src.ldm.modules.midas.midas.midas_net import MidasNet 10 | from src.ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from src.ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | 
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class 
MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, 
groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 
157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 
297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, 
layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 64. 23 | backbone (str, optional): Backbone network for encoder.
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class Slice(nn.Module): 10 | def __init__(self, start_index=1): 11 | super(Slice, self).__init__() 12 | self.start_index = start_index 13 | 14 | def forward(self, x): 15 | return x[:, self.start_index :] 16 | 17 | 18 | class AddReadout(nn.Module): 19 | def __init__(self, start_index=1): 20 | super(AddReadout, self).__init__() 21 | self.start_index = start_index 22 | 23 | def forward(self, x): 24 | if self.start_index == 2: 25 | readout = (x[:, 0] + x[:, 1]) / 2 26 | else: 27 | readout = x[:, 0] 28 | return x[:, self.start_index :] + readout.unsqueeze(1) 29 | 30 | 31 | class ProjectReadout(nn.Module): 32 | def __init__(self, in_features, start_index=1): 33 | super(ProjectReadout, self).__init__() 34 | self.start_index = start_index 35 | 36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 37 | 38 | def forward(self, x): 
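# Broadcast the readout (class) token over every patch token, concatenate along the feature dimension, and project the resulting 2 * in_features channels back to in_features with the Linear + GELU head defined in __init__.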
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 40 | features = torch.cat((x[:, self.start_index :], readout), -1) 41 | 42 | return self.project(features) 43 | 44 | 45 | class Transpose(nn.Module): 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0 = dim0 49 | self.dim1 = dim1 50 | 51 | def forward(self, x): 52 | x = x.transpose(self.dim0, self.dim1) 53 | return x 54 | 55 | 56 | def forward_vit(pretrained, x): 57 | b, c, h, w = x.shape 58 | 59 | glob = pretrained.model.forward_flex(x) 60 | 61 | layer_1 = pretrained.activations["1"] 62 | layer_2 = pretrained.activations["2"] 63 | layer_3 = pretrained.activations["3"] 64 | layer_4 = pretrained.activations["4"] 65 | 66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 70 | 71 | unflatten = nn.Sequential( 72 | nn.Unflatten( 73 | 2, 74 | torch.Size( 75 | [ 76 | h // pretrained.model.patch_size[1], 77 | w // pretrained.model.patch_size[0], 78 | ] 79 | ), 80 | ) 81 | ) 82 | 83 | if layer_1.ndim == 3: 84 | layer_1 = unflatten(layer_1) 85 | if layer_2.ndim == 3: 86 | layer_2 = unflatten(layer_2) 87 | if layer_3.ndim == 3: 88 | layer_3 = unflatten(layer_3) 89 | if layer_4.ndim == 3: 90 | layer_4 = unflatten(layer_4) 91 | 92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 96 | 97 | return layer_1, layer_2, layer_3, layer_4 98 | 99 | 100 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 101 | posemb_tok, posemb_grid = ( 102 | posemb[:, : self.start_index], 103 | posemb[0, self.start_index :], 104 | ) 105 | 106 | gs_old = int(math.sqrt(len(posemb_grid))) 107 | 108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 111 | 112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 113 | 114 | return posemb 115 | 116 | 117 | def forward_flex(self, x): 118 | b, c, h, w = x.shape 119 | 120 | pos_embed = self._resize_pos_embed( 121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 122 | ) 123 | 124 | B = x.shape[0] 125 | 126 | if hasattr(self.patch_embed, "backbone"): 127 | x = self.patch_embed.backbone(x) 128 | if isinstance(x, (list, tuple)): 129 | x = x[-1] # last feature if backbone outputs list/tuple of features 130 | 131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 132 | 133 | if getattr(self, "dist_token", None) is not None: 134 | cls_tokens = self.cls_token.expand( 135 | B, -1, -1 136 | ) # stole cls_tokens impl from Phil Wang, thanks 137 | dist_token = self.dist_token.expand(B, -1, -1) 138 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 139 | else: 140 | cls_tokens = self.cls_token.expand( 141 | B, -1, -1 142 | ) # stole cls_tokens impl from Phil Wang, thanks 143 | x = torch.cat((cls_tokens, x), dim=1) 144 | 145 | x = x + pos_embed 146 | x = self.pos_drop(x) 147 | 148 | for blk in self.blocks: 149 | x = blk(x) 150 | 151 | x = self.norm(x) 152 | 153 | return x 154 | 155 | 156 | activations = 
{} 157 | 158 | 159 | def get_activation(name): 160 | def hook(model, input, output): 161 | activations[name] = output 162 | 163 | return hook 164 | 165 | 166 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 167 | if use_readout == "ignore": 168 | readout_oper = [Slice(start_index)] * len(features) 169 | elif use_readout == "add": 170 | readout_oper = [AddReadout(start_index)] * len(features) 171 | elif use_readout == "project": 172 | readout_oper = [ 173 | ProjectReadout(vit_features, start_index) for out_feat in features 174 | ] 175 | else: 176 | assert ( 177 | False 178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 179 | 180 | return readout_oper 181 | 182 | 183 | def _make_vit_b16_backbone( 184 | model, 185 | features=[96, 192, 384, 768], 186 | size=[384, 384], 187 | hooks=[2, 5, 8, 11], 188 | vit_features=768, 189 | use_readout="ignore", 190 | start_index=1, 191 | ): 192 | pretrained = nn.Module() 193 | 194 | pretrained.model = model 195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 199 | 200 | pretrained.activations = activations 201 | 202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 203 | 204 | # 32, 48, 136, 384 205 | pretrained.act_postprocess1 = nn.Sequential( 206 | readout_oper[0], 207 | Transpose(1, 2), 208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 209 | nn.Conv2d( 210 | in_channels=vit_features, 211 | out_channels=features[0], 212 | kernel_size=1, 213 | stride=1, 214 | padding=0, 215 | ), 216 | nn.ConvTranspose2d( 217 | in_channels=features[0], 218 | out_channels=features[0], 219 | kernel_size=4, 220 | stride=4, 221 | padding=0, 222 | bias=True, 223 | dilation=1, 224 | groups=1, 225 | ), 226 | ) 227 | 228 | pretrained.act_postprocess2 = nn.Sequential( 229 | readout_oper[1], 230 | Transpose(1, 2), 231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 232 | nn.Conv2d( 233 | in_channels=vit_features, 234 | out_channels=features[1], 235 | kernel_size=1, 236 | stride=1, 237 | padding=0, 238 | ), 239 | nn.ConvTranspose2d( 240 | in_channels=features[1], 241 | out_channels=features[1], 242 | kernel_size=2, 243 | stride=2, 244 | padding=0, 245 | bias=True, 246 | dilation=1, 247 | groups=1, 248 | ), 249 | ) 250 | 251 | pretrained.act_postprocess3 = nn.Sequential( 252 | readout_oper[2], 253 | Transpose(1, 2), 254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 255 | nn.Conv2d( 256 | in_channels=vit_features, 257 | out_channels=features[2], 258 | kernel_size=1, 259 | stride=1, 260 | padding=0, 261 | ), 262 | ) 263 | 264 | pretrained.act_postprocess4 = nn.Sequential( 265 | readout_oper[3], 266 | Transpose(1, 2), 267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 268 | nn.Conv2d( 269 | in_channels=vit_features, 270 | out_channels=features[3], 271 | kernel_size=1, 272 | stride=1, 273 | padding=0, 274 | ), 275 | nn.Conv2d( 276 | in_channels=features[3], 277 | out_channels=features[3], 278 | kernel_size=3, 279 | stride=2, 280 | padding=1, 281 | ), 282 | ) 283 | 284 | pretrained.model.start_index = start_index 285 | pretrained.model.patch_size = [16, 16] 286 | 287 | # We inject this function into the VisionTransformer instances so that 288 | # we 
can use it with interpolated position embeddings without modifying the library source. 289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 290 | pretrained.model._resize_pos_embed = types.MethodType( 291 | _resize_pos_embed, pretrained.model 292 | ) 293 | 294 | return pretrained 295 | 296 | 297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): 298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 299 | 300 | hooks = [5, 11, 17, 23] if hooks == None else hooks 301 | return _make_vit_b16_backbone( 302 | model, 303 | features=[256, 512, 1024, 1024], 304 | hooks=hooks, 305 | vit_features=1024, 306 | use_readout=use_readout, 307 | ) 308 | 309 | 310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): 311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 312 | 313 | hooks = [2, 5, 8, 11] if hooks == None else hooks 314 | return _make_vit_b16_backbone( 315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 316 | ) 317 | 318 | 319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): 320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 321 | 322 | hooks = [2, 5, 8, 11] if hooks == None else hooks 323 | return _make_vit_b16_backbone( 324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 325 | ) 326 | 327 | 328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): 329 | model = timm.create_model( 330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 331 | ) 332 | 333 | hooks = [2, 5, 8, 11] if hooks == None else hooks 334 | return _make_vit_b16_backbone( 335 | model, 336 | features=[96, 192, 384, 768], 337 | hooks=hooks, 338 | use_readout=use_readout, 339 | start_index=2, 340 | ) 341 | 342 | 343 | def _make_vit_b_rn50_backbone( 344 | model, 345 | features=[256, 512, 768, 768], 346 | size=[384, 384], 347 | hooks=[0, 1, 8, 11], 348 | vit_features=768, 349 | use_vit_only=False, 350 | use_readout="ignore", 351 | start_index=1, 352 | ): 353 | pretrained = nn.Module() 354 | 355 | pretrained.model = model 356 | 357 | if use_vit_only == True: 358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 360 | else: 361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 362 | get_activation("1") 363 | ) 364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 365 | get_activation("2") 366 | ) 367 | 368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 370 | 371 | pretrained.activations = activations 372 | 373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 374 | 375 | if use_vit_only == True: 376 | pretrained.act_postprocess1 = nn.Sequential( 377 | readout_oper[0], 378 | Transpose(1, 2), 379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 380 | nn.Conv2d( 381 | in_channels=vit_features, 382 | out_channels=features[0], 383 | kernel_size=1, 384 | stride=1, 385 | padding=0, 386 | ), 387 | nn.ConvTranspose2d( 388 | in_channels=features[0], 389 | out_channels=features[0], 390 | kernel_size=4, 391 | stride=4, 392 | padding=0, 393 | bias=True, 394 | dilation=1, 395 | groups=1, 396 | ), 397 | ) 398 | 399 | 
pretrained.act_postprocess2 = nn.Sequential( 400 | readout_oper[1], 401 | Transpose(1, 2), 402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 403 | nn.Conv2d( 404 | in_channels=vit_features, 405 | out_channels=features[1], 406 | kernel_size=1, 407 | stride=1, 408 | padding=0, 409 | ), 410 | nn.ConvTranspose2d( 411 | in_channels=features[1], 412 | out_channels=features[1], 413 | kernel_size=2, 414 | stride=2, 415 | padding=0, 416 | bias=True, 417 | dilation=1, 418 | groups=1, 419 | ), 420 | ) 421 | else: 422 | pretrained.act_postprocess1 = nn.Sequential( 423 | nn.Identity(), nn.Identity(), nn.Identity() 424 | ) 425 | pretrained.act_postprocess2 = nn.Sequential( 426 | nn.Identity(), nn.Identity(), nn.Identity() 427 | ) 428 | 429 | pretrained.act_postprocess3 = nn.Sequential( 430 | readout_oper[2], 431 | Transpose(1, 2), 432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 433 | nn.Conv2d( 434 | in_channels=vit_features, 435 | out_channels=features[2], 436 | kernel_size=1, 437 | stride=1, 438 | padding=0, 439 | ), 440 | ) 441 | 442 | pretrained.act_postprocess4 = nn.Sequential( 443 | readout_oper[3], 444 | Transpose(1, 2), 445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 446 | nn.Conv2d( 447 | in_channels=vit_features, 448 | out_channels=features[3], 449 | kernel_size=1, 450 | stride=1, 451 | padding=0, 452 | ), 453 | nn.Conv2d( 454 | in_channels=features[3], 455 | out_channels=features[3], 456 | kernel_size=3, 457 | stride=2, 458 | padding=1, 459 | ), 460 | ) 461 | 462 | pretrained.model.start_index = start_index 463 | pretrained.model.patch_size = [16, 16] 464 | 465 | # We inject this function into the VisionTransformer instances so that 466 | # we can use it with interpolated position embeddings without modifying the library source. 467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 468 | 469 | # We inject this function into the VisionTransformer instances so that 470 | # we can use it with interpolated position embeddings without modifying the library source. 471 | pretrained.model._resize_pos_embed = types.MethodType( 472 | _resize_pos_embed, pretrained.model 473 | ) 474 | 475 | return pretrained 476 | 477 | 478 | def _make_pretrained_vitb_rn50_384( 479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False 480 | ): 481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 482 | 483 | hooks = [0, 1, 8, 11] if hooks == None else hooks 484 | return _make_vit_b_rn50_backbone( 485 | model, 486 | features=[256, 512, 768, 768], 487 | size=[384, 384], 488 | hooks=hooks, 489 | use_vit_only=use_vit_only, 490 | use_readout=use_readout, 491 | ) 492 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /src/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
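For example, a tensor of shape (batch, channels, height, width) is reduced to a tensor of shape (batch,).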
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | 85 | if reload: 86 | module_imp = importlib.import_module(module) 87 | importlib.reload(module_imp) 88 | return getattr(importlib.import_module(module, package=None), cls) 89 | 90 | 91 | class AdamWwithEMAandWings(optim.Optimizer): 92 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 93 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 94 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 95 | ema_power=1., param_names=()): 96 | """AdamW that saves EMA versions of the parameters.""" 97 | if not 0.0 <= lr: 98 | raise ValueError("Invalid learning rate: {}".format(lr)) 99 | if not 0.0 <= eps: 100 | raise ValueError("Invalid epsilon value: {}".format(eps)) 101 | if not 0.0 <= betas[0] < 1.0: 102 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 103 | if not 0.0 <= betas[1] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 105 | if not 0.0 <= weight_decay: 106 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 107 | if not 0.0 <= ema_decay <= 1.0: 108 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 109 | defaults = dict(lr=lr, betas=betas, eps=eps, 110 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 111 | ema_power=ema_power, param_names=param_names) 112 | super().__init__(params, defaults) 113 | 114 | def __setstate__(self, state): 115 | super().__setstate__(state) 116 | for group in self.param_groups: 117 | group.setdefault('amsgrad', False) 118 | 119 | @torch.no_grad() 120 | def step(self, closure=None): 121 | """Performs a single optimization step. 122 | Args: 123 | closure (callable, optional): A closure that reevaluates the model 124 | and returns the loss. 
125 | """ 126 | loss = None 127 | if closure is not None: 128 | with torch.enable_grad(): 129 | loss = closure() 130 | 131 | for group in self.param_groups: 132 | params_with_grad = [] 133 | grads = [] 134 | exp_avgs = [] 135 | exp_avg_sqs = [] 136 | ema_params_with_grad = [] 137 | state_sums = [] 138 | max_exp_avg_sqs = [] 139 | state_steps = [] 140 | amsgrad = group['amsgrad'] 141 | beta1, beta2 = group['betas'] 142 | ema_decay = group['ema_decay'] 143 | ema_power = group['ema_power'] 144 | 145 | for p in group['params']: 146 | if p.grad is None: 147 | continue 148 | params_with_grad.append(p) 149 | if p.grad.is_sparse: 150 | raise RuntimeError('AdamW does not support sparse gradients') 151 | grads.append(p.grad) 152 | 153 | state = self.state[p] 154 | 155 | # State initialization 156 | if len(state) == 0: 157 | state['step'] = 0 158 | # Exponential moving average of gradient values 159 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 160 | # Exponential moving average of squared gradient values 161 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 162 | if amsgrad: 163 | # Maintains max of all exp. moving avg. of sq. grad. values 164 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 165 | # Exponential moving average of parameter values 166 | state['param_exp_avg'] = p.detach().float().clone() 167 | 168 | exp_avgs.append(state['exp_avg']) 169 | exp_avg_sqs.append(state['exp_avg_sq']) 170 | ema_params_with_grad.append(state['param_exp_avg']) 171 | 172 | if amsgrad: 173 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 174 | 175 | # update the steps for each param group update 176 | state['step'] += 1 177 | # record the step after step update 178 | state_steps.append(state['step']) 179 | 180 | optim._functional.adamw(params_with_grad, 181 | grads, 182 | exp_avgs, 183 | exp_avg_sqs, 184 | max_exp_avg_sqs, 185 | state_steps, 186 | amsgrad=amsgrad, 187 | beta1=beta1, 188 | beta2=beta2, 189 | lr=group['lr'], 190 | weight_decay=group['weight_decay'], 191 | eps=group['eps'], 192 | maximize=False) 193 | 194 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 195 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 196 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 197 | 198 | return loss -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/cldm_v15.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: src.cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: src.cldm.cldm.ControlNet 23 | params: 24 | image_size: 32 # unused 25 | in_channels: 4 26 | hint_channels: 3 27 | model_channels: 320 28 | attention_resolutions: [ 
4, 2, 1 ] 29 | num_res_blocks: 2 30 | channel_mult: [ 1, 2, 4, 4 ] 31 | num_heads: 8 32 | use_spatial_transformer: True 33 | transformer_depth: 1 34 | context_dim: 768 35 | use_checkpoint: True 36 | legacy: False 37 | 38 | unet_config: 39 | target: src.cldm.cldm.ControlledUnetModel 40 | params: 41 | image_size: 32 # unused 42 | in_channels: 4 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: src.ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: src.ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /src/models/mesh.py: -------------------------------------------------------------------------------- 1 | import kaolin as kal 2 | import torch 3 | 4 | import copy 5 | 6 | class Mesh: 7 | def __init__(self,obj_path, device): 8 | # from https://github.com/threedle/text2mesh 9 | 10 | if ".obj" in obj_path: 11 | try: 12 | mesh = kal.io.obj.import_mesh(obj_path, with_normals=True, with_materials=True) 13 | except: 14 | mesh = kal.io.obj.import_mesh(obj_path, with_normals=True, with_materials=False) 15 | 16 | elif ".off" in obj_path: 17 | mesh = kal.io.off.import_mesh(obj_path) 18 | else: 19 | raise ValueError(f"{obj_path} extension not implemented in mesh reader.") 20 | 21 | self.vertices = mesh.vertices.to(device) 22 | self.faces = mesh.faces.to(device) 23 | self.normals, self.face_area = self.calculate_face_normals(self.vertices, self.faces) 24 | self.ft = mesh.face_uvs_idx 25 | self.vt = mesh.uvs 26 | 27 | @staticmethod 28 | def calculate_face_normals(vertices: torch.Tensor, faces: torch.Tensor): 29 | """ 30 | calculate per face normals from vertices and faces 31 | """ 32 | v0 = vertices[faces[:, 0]] 33 | v1 = vertices[faces[:, 1]] 34 | v2 = vertices[faces[:, 2]] 35 | e0 = v1 - v0 36 | e1 = v2 - v0 37 | n = torch.cross(e0, e1, dim=-1) 38 | twice_area = torch.norm(n, dim=-1) 39 | n = n / twice_area[:, None] 40 | return n, twice_area / 2 41 | 42 | def standardize_mesh(self,inplace=False): 43 | mesh = self if inplace else copy.deepcopy(self) 44 | 45 | verts = mesh.vertices 46 | center = verts.mean(dim=0) 47 | verts -= center 48 | scale = torch.std(torch.norm(verts, p=2, dim=1)) 49 | verts /= scale 50 | mesh.vertices = verts 51 | return mesh 52 | 53 | def normalize_mesh(self,inplace=False, target_scale=1, dy=0): 54 | mesh = self if inplace else copy.deepcopy(self) 55 | 56 | verts = mesh.vertices 57 | center = verts.mean(dim=0) 58 | verts = verts - center 59 | scale = torch.max(torch.norm(verts, p=2, dim=1)) 60 | verts = verts / scale 61 | verts *= target_scale 62 | verts[:, 1] += dy 63 | mesh.vertices = verts 64 | return mesh 65 | 66 | -------------------------------------------------------------------------------- /src/models/render.py: -------------------------------------------------------------------------------- 1 | import kaolin as kal 2 | import torch 3 | import numpy 
as np 4 | from loguru import logger 5 | class Renderer: 6 | # from https://github.com/threedle/text2mesh 7 | 8 | def __init__(self, device, dim=(224, 224), interpolation_mode='nearest'): 9 | assert interpolation_mode in ['nearest', 'bilinear', 'bicubic'], f'no interpolation mode {interpolation_mode}' 10 | 11 | camera = kal.render.camera.generate_perspective_projection(np.pi / 3).to(device) 12 | 13 | self.device = device 14 | self.interpolation_mode = interpolation_mode 15 | self.camera_projection = camera 16 | self.dim = dim 17 | self.background = torch.ones(dim).to(device).float() 18 | 19 | @staticmethod 20 | def get_camera_from_view(elev, azim, r=3.0, look_at_height=0.0): 21 | x = r * torch.sin(elev) * torch.sin(azim) 22 | y = r * torch.cos(elev) 23 | z = r * torch.sin(elev) * torch.cos(azim) 24 | 25 | pos = torch.tensor([x, y, z]).unsqueeze(0) 26 | look_at = torch.zeros_like(pos) 27 | look_at[:, 1] = look_at_height 28 | direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0) 29 | 30 | camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction) 31 | return camera_proj 32 | 33 | 34 | def normalize_depth(self, depth_map): 35 | assert depth_map.max() <= 0.0, 'depth map should be negative' 36 | object_mask = depth_map != 0 37 | # depth_map[object_mask] = (depth_map[object_mask] - depth_map[object_mask].min()) / ( 38 | # depth_map[object_mask].max() - depth_map[object_mask].min()) 39 | # depth_map = depth_map ** 4 40 | min_val = 0.5 41 | depth_map[object_mask] = ((1 - min_val) * (depth_map[object_mask] - depth_map[object_mask].min()) / ( 42 | depth_map[object_mask].max() - depth_map[object_mask].min())) + min_val 43 | # depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) 44 | # depth_map[depth_map == 1] = 0 # Background gets largest value, set to 0 45 | 46 | return depth_map 47 | 48 | def render_single_view(self, mesh, face_attributes, elev=0, azim=0, radius=2, look_at_height=0.0,calc_depth=True,dims=None, background_type='none'): 49 | dims = self.dim if dims is None else dims 50 | 51 | camera_transform = self.get_camera_from_view(torch.tensor(elev), torch.tensor(azim), r=radius, 52 | look_at_height=look_at_height).to(self.device) 53 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 54 | mesh.vertices.to(self.device), mesh.faces.to(self.device), self.camera_projection, camera_transform=camera_transform) 55 | 56 | if calc_depth: 57 | depth_map, _ = kal.render.mesh.rasterize(dims[1], dims[0], face_vertices_camera[:, :, :, -1], 58 | face_vertices_image, face_vertices_camera[:, :, :, -1:]) 59 | depth_map = self.normalize_depth(depth_map) 60 | else: 61 | depth_map = torch.zeros(1,64,64,1) 62 | 63 | image_features, face_idx = kal.render.mesh.rasterize(dims[1], dims[0], face_vertices_camera[:, :, :, -1], 64 | face_vertices_image, face_attributes) 65 | 66 | mask = (face_idx > -1).float()[..., None] 67 | if background_type == 'white': 68 | image_features += 1 * (1 - mask) 69 | if background_type == 'random': 70 | image_features += torch.rand((1,1,1,3)).to(self.device) * (1 - mask) 71 | 72 | return image_features.permute(0, 3, 1, 2), mask.permute(0, 3, 1, 2), depth_map.permute(0, 3, 1, 2) 73 | 74 | 75 | def render_single_view_texture(self, verts, faces, uv_face_attr, texture_map, elev=0, azim=0, radius=2, 76 | look_at_height=0.0, dims=None, background_type='none', render_cache=None): 77 | dims = self.dim if dims is None else dims 78 | 79 | if render_cache is None: 80 | 81 | camera_transform = 
self.get_camera_from_view(torch.tensor(elev), torch.tensor(azim), r=radius, 82 | look_at_height=look_at_height).to(self.device) 83 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 84 | verts.to(self.device), faces.to(self.device), self.camera_projection, camera_transform=camera_transform) 85 | 86 | depth_map, _ = kal.render.mesh.rasterize(dims[1], dims[0], face_vertices_camera[:, :, :, -1], 87 | face_vertices_image, face_vertices_camera[:, :, :, -1:]) 88 | depth_map = self.normalize_depth(depth_map) 89 | 90 | uv_features, face_idx = kal.render.mesh.rasterize(dims[1], dims[0], face_vertices_camera[:, :, :, -1], 91 | face_vertices_image, uv_face_attr) 92 | uv_features = uv_features.detach() 93 | 94 | else: 95 | # logger.info('Using render cache') 96 | face_normals, uv_features, face_idx, depth_map = render_cache['face_normals'], render_cache['uv_features'], render_cache['face_idx'], render_cache['depth_map'] 97 | mask = (face_idx > -1).float()[..., None] 98 | 99 | image_features = kal.render.mesh.texture_mapping(uv_features, texture_map, mode=self.interpolation_mode) 100 | image_features = image_features * mask 101 | if background_type == 'white': 102 | image_features += 1 * (1 - mask) 103 | elif background_type == 'random': 104 | image_features += torch.rand((1,1,1,3)).to(self.device) * (1 - mask) 105 | 106 | normals_image = face_normals[0][face_idx, :] 107 | 108 | render_cache = {'uv_features':uv_features, 'face_normals':face_normals,'face_idx':face_idx, 'depth_map':depth_map} 109 | 110 | return image_features.permute(0, 3, 1, 2), mask.permute(0, 3, 1, 2),\ 111 | depth_map.permute(0, 3, 1, 2), normals_image.permute(0, 3, 1, 2), render_cache 112 | 113 | def project_uv_single_view(self, verts, faces, uv_face_attr, elev=0, azim=0, radius=2, 114 | look_at_height=0.0, dims=None, background_type='none'): 115 | # project the vertices and interpolate the uv coordinates 116 | 117 | dims = self.dim if dims is None else dims 118 | 119 | camera_transform = self.get_camera_from_view(torch.tensor(elev), torch.tensor(azim), r=radius, 120 | look_at_height=look_at_height).to(self.device) 121 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 122 | verts.to(self.device), faces.to(self.device), self.camera_projection, camera_transform=camera_transform) 123 | 124 | uv_features, face_idx = kal.render.mesh.rasterize(dims[1], dims[0], face_vertices_camera[:, :, :, -1], 125 | face_vertices_image, uv_face_attr) 126 | return face_vertices_image, face_vertices_camera, uv_features, face_idx 127 | 128 | def project_single_view(self, verts, faces, elev=0, azim=0, radius=2, 129 | look_at_height=0.0): 130 | # only project the vertices 131 | camera_transform = self.get_camera_from_view(torch.tensor(elev), torch.tensor(azim), r=radius, 132 | look_at_height=look_at_height).to(self.device) 133 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 134 | verts.to(self.device), faces.to(self.device), self.camera_projection, camera_transform=camera_transform) 135 | 136 | return face_vertices_image 137 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/src/training/__init__.py -------------------------------------------------------------------------------- 
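A minimal usage sketch (not part of the repository) for the `Renderer` defined in `src/models/render.py` above: it loads one of the bundled meshes and renders a single normalized depth map, which is the kind of conditioning map the depth-guided diffusion wrappers in this repo consume. The mesh path, image size, viewing angles, and the constant per-face attributes are illustrative assumptions only; the actual pipeline builds these tensors inside `textured_mesh.py` and `trainer.py`.

```python
# Sketch only: render a normalized depth map with the Renderer above.
# Assumes a CUDA device (kaolin rasterization) and uses placeholder inputs.
import numpy as np
import torch

from src.models.mesh import Mesh
from src.models.render import Renderer

device = torch.device('cuda')

mesh = Mesh('shapes/spiderman_example.obj', device)   # any mesh from the shapes/ folder
renderer = Renderer(device, dim=(512, 512))

# Dummy per-face vertex attributes (constant white), shaped as kaolin expects:
# (batch, num_faces, 3 vertices, 3 channels).
face_attributes = torch.ones(1, mesh.faces.shape[0], 3, 3, device=device)

image, mask, depth = renderer.render_single_view(
    mesh, face_attributes,
    elev=np.deg2rad(60.0), azim=np.deg2rad(0.0), radius=1.5,
)
# depth is a [1, 1, H, W] tensor: background stays at 0 and the object is mapped
# to the [0.5, 1.0] range by normalize_depth.
```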
/src/training/views_dataset.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from src.configs.train_config import RenderConfig 8 | from src.utils import get_view_direction 9 | from loguru import logger 10 | 11 | 12 | def rand_poses(size, device, radius_range=(1.0, 1.5), theta_range=(0.0, 150.0), phi_range=(0.0, 360.0), 13 | angle_overhead=30.0, angle_front=60.0, biased_angles=True): 14 | if theta_range != (0.0, 180.0): 15 | warnings.warn("theta_range is not (0.0, 180.0) in rand_poses\n Will use (0.0, 180.0) instead") 16 | 17 | angle_overhead = np.deg2rad(angle_overhead) 18 | angle_front = np.deg2rad(angle_front) 19 | 20 | radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0] 21 | 22 | theta_range = np.deg2rad(theta_range) 23 | phi_range = np.deg2rad(phi_range) 24 | 25 | if biased_angles: 26 | top_flag = np.random.rand() > 0.3 # 70% of the time, the camera is at the top 27 | if top_flag: 28 | x = 1 - torch.rand(size, device=device) 29 | thetas = torch.acos(x) 30 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 31 | else: 32 | x = 1 - (torch.rand(size, device=device) + 1) 33 | thetas = torch.acos(x) 34 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 35 | else: 36 | # logger.warning('Using old theta calc') 37 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] 38 | # thetas = torch.acos(1-2*torch.rand(size, device=device)) 39 | 40 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 41 | 42 | dirs = get_view_direction(thetas, phis, angle_overhead, angle_front) 43 | 44 | return dirs, thetas.item(), phis.item(), radius.item() 45 | 46 | 47 | def rand_modal_poses(size, device, radius_range=(1.4, 1.6), theta_range=(45.0, 90.0), phi_range=(0.0, 360.0), 48 | angle_overhead=30.0, theta_range_overhead=(0.0, 20.0), angle_front=60.0): 49 | theta_range = np.deg2rad(theta_range) 50 | theta_range_overhead = np.deg2rad(theta_range_overhead) 51 | phi_range = np.deg2rad(phi_range) 52 | angle_overhead = np.deg2rad(angle_overhead) 53 | angle_front = np.deg2rad(angle_front) 54 | 55 | radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0] 56 | 57 | overhead_flag = torch.rand(1, device=device) > 0.85 58 | if overhead_flag: 59 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 60 | thetas = torch.rand(size, device=device) * (theta_range_overhead[1] - theta_range_overhead[0]) + \ 61 | theta_range_overhead[0] 62 | else: 63 | phi_mods = np.deg2rad([0, 90, 180, 270]) 64 | pertube_magnitude = np.deg2rad(15) 65 | rand_pertubations = torch.rand(size, device=device) * pertube_magnitude 66 | phis = rand_pertubations + torch.from_numpy(phi_mods[np.random.randint(0, 4, size)]).to(device) 67 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] 68 | 69 | dirs = get_view_direction(thetas, phis, angle_overhead, angle_front) 70 | 71 | return dirs, thetas.item(), phis.item(), radius.item() 72 | 73 | 74 | def circle_poses(device, radius=1.25, theta=60.0, phi=0.0, angle_overhead=30.0, angle_front=60.0): 75 | theta = np.deg2rad(theta) 76 | phi = np.deg2rad(phi) 77 | angle_overhead = np.deg2rad(angle_overhead) 78 | angle_front = np.deg2rad(angle_front) 79 | 80 | thetas 
= torch.FloatTensor([theta]).to(device) 81 | phis = torch.FloatTensor([phi]).to(device) 82 | dirs = get_view_direction(thetas, phis, angle_overhead, angle_front) 83 | 84 | return dirs, thetas.item(), phis.item(), radius 85 | 86 | 87 | class MultiviewDataset: 88 | def __init__(self, cfg: RenderConfig, device): 89 | super().__init__() 90 | 91 | self.cfg = cfg 92 | self.device = device 93 | self.type = type # train, val, tests 94 | size = self.cfg.n_views 95 | 96 | self.phis = [(index / size) * 360 for index in range(size)] 97 | self.thetas = [self.cfg.base_theta for _ in range(size)] 98 | 99 | # Alternate lists 100 | alternate_lists = lambda l: [l[0]] + [i for j in zip(l[1:size // 2], l[-1:size // 2:-1]) for i in j] + [ 101 | l[size // 2]] 102 | if self.cfg.alternate_views: 103 | self.phis = alternate_lists(self.phis) 104 | self.thetas = alternate_lists(self.thetas) 105 | logger.info(f'phis: {self.phis}') 106 | # self.phis = self.phis[1:2] 107 | # self.thetas = self.thetas[1:2] 108 | # if append_upper: 109 | # # self.phis = [0,180, 0, 180]+self.phis 110 | # # self.thetas =[30, 30, 150, 150]+self.thetas 111 | # self.phis =[180,180]+self.phis 112 | # self.thetas = [30,150]+self.thetas 113 | 114 | for phi, theta in self.cfg.views_before: 115 | self.phis = [phi] + self.phis 116 | self.thetas = [theta] + self.thetas 117 | for phi, theta in self.cfg.views_after: 118 | self.phis = self.phis + [phi] 119 | self.thetas = self.thetas + [theta] 120 | # self.phis = [0, 0] + self.phis 121 | # self.thetas = [20, 160] + self.thetas 122 | 123 | self.size = len(self.phis) 124 | 125 | def collate(self, index): 126 | 127 | # B = len(index) # always 1 128 | 129 | # phi = (index[0] / self.size) * 360 130 | phi = self.phis[index[0]] 131 | theta = self.thetas[index[0]] 132 | radius = self.cfg.radius 133 | dirs, thetas, phis, radius = circle_poses(self.device, radius=radius, theta=theta, 134 | phi=phi, 135 | angle_overhead=self.cfg.overhead_range, 136 | angle_front=self.cfg.front_range) 137 | 138 | data = { 139 | 'dir': dirs, 140 | 'theta': thetas, 141 | 'phi': phis, 142 | 'radius': radius 143 | } 144 | 145 | return data 146 | 147 | def dataloader(self): 148 | loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=False, 149 | num_workers=0) 150 | loader._data = self # an ugly fix... we need to access dataset in trainer. 151 | return loader 152 | 153 | 154 | class ViewsDataset: 155 | def __init__(self, cfg: RenderConfig, device, size=100): 156 | super().__init__() 157 | 158 | self.cfg = cfg 159 | self.device = device 160 | self.type = type # train, val, test 161 | self.size = size 162 | 163 | def collate(self, index): 164 | # circle pose 165 | phi = (index[0] / self.size) * 360 166 | dirs, thetas, phis, radius = circle_poses(self.device, radius=self.cfg.radius * 1.2, theta=self.cfg.base_theta, 167 | phi=phi, 168 | angle_overhead=self.cfg.overhead_range, 169 | angle_front=self.cfg.front_range) 170 | 171 | data = { 172 | 'dir': dirs, 173 | 'theta': thetas, 174 | 'phi': phis, 175 | 'radius': radius 176 | } 177 | 178 | return data 179 | 180 | def dataloader(self): 181 | loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=False, 182 | num_workers=0) 183 | loader._data = self # an ugly fix... we need to access dataset in trainer. 
184 | return loader 185 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as T 9 | from PIL import Image 10 | import einops 11 | from matplotlib import cm 12 | import torch.nn.functional as F 13 | 14 | 15 | def get_view_direction(thetas, phis, overhead, front): 16 | # phis [B,]; thetas: [B,] 17 | # front = 0 [0, front) 18 | # side (left) = 1 [front, 180) 19 | # back = 2 [180, 180+front) 20 | # side (right) = 3 [180+front, 360) 21 | # top = 4 [0, overhead] 22 | # bottom = 5 [180-overhead, 180] 23 | res = torch.zeros(thetas.shape[0], dtype=torch.long) 24 | # first determine by phis 25 | 26 | # res[(phis < front)] = 0 27 | res[(phis >= (2 * np.pi - front / 2)) & (phis < front / 2)] = 0 28 | 29 | # res[(phis >= front) & (phis < np.pi)] = 1 30 | res[(phis >= front / 2) & (phis < (np.pi - front / 2))] = 1 31 | 32 | # res[(phis >= np.pi) & (phis < (np.pi + front))] = 2 33 | res[(phis >= (np.pi - front / 2)) & (phis < (np.pi + front / 2))] = 2 34 | 35 | # res[(phis >= (np.pi + front))] = 3 36 | res[(phis >= (np.pi + front / 2)) & (phis < (2 * np.pi - front / 2))] = 3 37 | # override by thetas 38 | res[thetas <= overhead] = 4 39 | res[thetas >= (np.pi - overhead)] = 5 40 | return res 41 | 42 | 43 | def tensor2numpy(tensor: torch.Tensor) -> np.ndarray: 44 | tensor = tensor.detach().cpu().numpy() 45 | tensor = (tensor * 255).astype(np.uint8) 46 | return tensor 47 | 48 | 49 | def make_path(path: Path) -> Path: 50 | path.mkdir(exist_ok=True, parents=True) 51 | return path 52 | 53 | 54 | 55 | def save_colormap(tensor: torch.Tensor, path: Path): 56 | Image.fromarray((cm.seismic(tensor.cpu().numpy())[:, :, :3] * 255).astype(np.uint8)).save(path) 57 | 58 | 59 | 60 | def seed_everything(seed): 61 | random.seed(seed) 62 | os.environ['PYTHONHASHSEED'] = str(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed(seed) 66 | # torch.backends.cudnn.deterministic = True 67 | # torch.backends.cudnn.benchmark = True 68 | 69 | 70 | def smooth_image(self, img: torch.Tensor, sigma: float) -> torch.Tensor: 71 | """apply gaussian blur to an image tensor with shape [C, H, W]""" 72 | img = T.GaussianBlur(kernel_size=(51, 51), sigma=(sigma, sigma))(img) 73 | return img 74 | 75 | 76 | def get_nonzero_region(mask:torch.Tensor): 77 | # Get the indices of the non-zero elements 78 | nz_indices = mask.nonzero() 79 | # Get the minimum and maximum indices along each dimension 80 | min_h, max_h = nz_indices[:, 0].min(), nz_indices[:, 0].max() 81 | min_w, max_w = nz_indices[:, 1].min(), nz_indices[:, 1].max() 82 | 83 | # Calculate the size of the square region 84 | size = max(max_h - min_h + 1, max_w - min_w + 1) * 1.1 85 | # Calculate the upper left corner of the square region 86 | h_start = min(min_h, max_h) - (size - (max_h - min_h + 1)) / 2 87 | w_start = min(min_w, max_w) - (size - (max_w - min_w + 1)) / 2 88 | 89 | min_h = int(h_start) 90 | min_w = int(w_start) 91 | max_h = int(min_h + size) 92 | max_w = int(min_w + size) 93 | 94 | return min_h, min_w, max_h, max_w 95 | 96 | 97 | def gaussian_fn(M, std): 98 | n = torch.arange(0, M) - (M - 1.0) / 2.0 99 | sig2 = 2 * std * std 100 | w = torch.exp(-n ** 2 / sig2) 101 | return w 102 | 103 | 104 | def gkern(kernlen=256, std=128): 105 | """Returns a 2D 
Gaussian kernel array.""" 106 | gkern1d = gaussian_fn(kernlen, std=std) 107 | gkern2d = torch.outer(gkern1d, gkern1d) 108 | return gkern2d 109 | 110 | def gaussian_blur(image:torch.Tensor, kernel_size:int, std:int) -> torch.Tensor: 111 | gaussian_filter = gkern(kernel_size, std=std) 112 | gaussian_filter /= gaussian_filter.sum() 113 | 114 | image = F.conv2d(image, 115 | gaussian_filter.unsqueeze(0).unsqueeze(0).cuda(), padding=kernel_size // 2) 116 | return image 117 | 118 | def color_with_shade(color: List[float],z_normals:torch.Tensor,light_coef=0.7): 119 | normals_with_light = (light_coef + (1 - light_coef) * z_normals.detach()) 120 | shaded_color = torch.tensor(color).view(1, 3, 1, 1).to( 121 | z_normals.device) * normals_with_light 122 | return shaded_color 123 | -------------------------------------------------------------------------------- /textures/brick_wall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EliM0/TEXTureControlNet/f5f5fbe018e4eca59a76809bfe3d4d75e85bac86/textures/brick_wall.png --------------------------------------------------------------------------------
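A minimal usage sketch (not part of the repository) for the `gkern` / `gaussian_blur` helpers in `src/utils.py` above. Note that `gaussian_blur` builds a `[1, 1, K, K]` kernel and moves it to CUDA, so it expects a CUDA input of shape `[N, 1, H, W]` (e.g. an object mask) rather than a 3-channel image, and an odd `kernel_size` so the spatial size is preserved. The random mask below is purely illustrative.

```python
# Sketch only: soften a binary mask with the gaussian_blur helper above.
import torch

from src.utils import gaussian_blur, gkern

mask = (torch.rand(1, 1, 512, 512, device='cuda') > 0.5).float()
soft_mask = gaussian_blur(mask, kernel_size=21, std=7)   # same spatial size, smoothed values

# gkern itself just returns the (unnormalized) 2D Gaussian window;
# gaussian_blur normalizes it before the convolution.
kernel = gkern(kernlen=21, std=7)                        # shape [21, 21], CPU tensor
```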