├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── diffusion │ │ ├── __init__.py │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── sampler.cpython-312.pyc │ │ │ │ ├── sampler.cpython-39.pyc │ │ │ │ ├── __init__.cpython-312.pyc │ │ │ │ ├── dpm_solver.cpython-312.pyc │ │ │ │ └── dpm_solver.cpython-39.pyc │ │ │ └── sampler.py │ │ └── sampling_util.py │ └── autoencoder.py ├── modules │ ├── encoders │ │ └── __init__.py │ ├── karlo │ │ ├── __init__.py │ │ └── kakao │ │ │ ├── __init__.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── clip.cpython-39.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── sr_64_256.cpython-39.pyc │ │ │ │ ├── prior_model.cpython-39.pyc │ │ │ │ └── decoder_model.cpython-39.pyc │ │ │ ├── sr_256_1k.py │ │ │ ├── sr_64_256.py │ │ │ ├── prior_model.py │ │ │ ├── clip.py │ │ │ └── decoder_model.py │ │ │ ├── modules │ │ │ ├── diffusion │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── respace.cpython-39.pyc │ │ │ │ │ └── gaussian_diffusion.cpython-39.pyc │ │ │ │ └── respace.py │ │ │ ├── __init__.py │ │ │ ├── resample.py │ │ │ ├── nn.py │ │ │ └── xf.py │ │ │ ├── template.py │ │ │ └── sampler.py │ ├── midas │ │ ├── __init__.py │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── midas_net.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net_custom.py │ │ │ ├── transforms.py │ │ │ └── blocks.py │ │ ├── utils.py │ │ └── api.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── upscaling.py │ │ └── util.py │ ├── image_degradation │ │ ├── utils │ │ │ └── test.png │ │ └── __init__.py │ └── ema.py └── util.py ├── stable_diffusion.egg-info ├── top_level.txt ├── dependency_links.txt ├── requires.txt ├── PKG-INFO └── SOURCES.txt ├── samples_paper ├── image.png ├── results.png ├── Architecture.PNG ├── firstPageImg.PNG ├── 17273391_55cfc7d3d4.jpg ├── 55473406_1d2271c1f2.jpg ├── 96985174_31d4c6f06d.jpg 
└── 431282339_0aa60dd78e.jpg ├── setup.py ├── requirements.txt ├── environment.yaml ├── LICENSE ├── configs └── stable-diffusion │ ├── v2-inference.yaml │ ├── v2-inference-v.yaml │ ├── intel │ ├── v2-inference-fp32.yaml │ ├── v2-inference-bf16.yaml │ ├── v2-inference-v-fp32.yaml │ └── v2-inference-v-bf16.yaml │ └── v1-inference.yaml ├── scripts ├── evaluate_metrics.py ├── metrics.py ├── dataset.py ├── qam.py ├── semantic_t2i.py └── semantic_i2i.py ├── .gitignore └── README.md /ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/karlo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- 1 
| -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable_diffusion.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /stable_diffusion.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /stable_diffusion.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tqdm 4 | -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /samples_paper/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/image.png -------------------------------------------------------------------------------- /samples_paper/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/results.png -------------------------------------------------------------------------------- /samples_paper/Architecture.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/Architecture.PNG -------------------------------------------------------------------------------- /samples_paper/firstPageImg.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/firstPageImg.PNG -------------------------------------------------------------------------------- /samples_paper/17273391_55cfc7d3d4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/17273391_55cfc7d3d4.jpg -------------------------------------------------------------------------------- /samples_paper/55473406_1d2271c1f2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/55473406_1d2271c1f2.jpg -------------------------------------------------------------------------------- /samples_paper/96985174_31d4c6f06d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/96985174_31d4c6f06d.jpg -------------------------------------------------------------------------------- /samples_paper/431282339_0aa60dd78e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/431282339_0aa60dd78e.jpg -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/image_degradation/utils/test.png 
-------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__pycache__/clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/clip.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-312.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__pycache__/sr_64_256.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/sr_64_256.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-312.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__pycache__/prior_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/prior_model.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__pycache__/decoder_model.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/decoder_model.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/diffusion/__pycache__/respace.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/respace.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc -------------------------------------------------------------------------------- /stable_diffusion.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: stable-diffusion 3 | Version: 0.0.1 4 | License-File: LICENSE 5 | License-File: LICENSE-MODEL 6 | Requires-Dist: torch 7 | Requires-Dist: numpy 8 | Requires-Dist: tqdm 9 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-diffusion', 
5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /stable_diffusion.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | LICENSE-MODEL 3 | README.md 4 | setup.py 5 | stable_diffusion.egg-info/PKG-INFO 6 | stable_diffusion.egg-info/SOURCES.txt 7 | stable_diffusion.egg-info/dependency_links.txt 8 | stable_diffusion.egg-info/requires.txt 9 | stable_diffusion.egg-info/top_level.txt -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/sr_256_1k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive 7 | 8 | 9 | class SupRes256to1kProgressive(SupRes64to256Progressive): 10 | pass # no difference currently 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | opencv-python 3 | pudb==2019.2 4 | imageio==2.9.0 5 | imageio-ffmpeg==0.4.2 6 | pytorch-lightning==1.6.0 #pytorch-lightning==1.4.2 7 | torchmetrics==0.6 8 | omegaconf==2.1.1 9 | test-tube>=0.7.5 10 | streamlit>=0.73.1 11 | einops==0.3.0 12 | transformers 13 | webdataset==0.2.5 14 | open-clip-torch==2.7.0 15 | kornia==0.6 16 | invisible-watermark>=0.1.5 17 | streamlit-drawable-canvas==0.8.0 18 | 19 | diffusers 20 | bitstring 21 | lpips 22 | -q kaggle 23 | accelerate 24 | SSIM-PIL 25 | 26 | 27 | -e . 28 | -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - albumentations==1.3.0 14 | - opencv-python==4.6.0.66 15 | - imageio==2.9.0 16 | - imageio-ffmpeg==0.4.2 17 | - pytorch-lightning==1.4.2 18 | - omegaconf==2.1.1 19 | - test-tube>=0.7.5 20 | - streamlit==1.12.1 21 | - einops==0.3.0 22 | - transformers==4.19.2 23 | - webdataset==0.2.5 24 | - kornia==0.6 25 | - open_clip_torch==2.0.2 26 | - invisible-watermark>=0.1.5 27 | - streamlit-drawable-canvas==0.8.0 28 | - torchmetrics==0.6.0 29 | - -e . 30 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | from .diffusion import gaussian_diffusion as gd 7 | from .diffusion.respace import ( 8 | SpacedDiffusion, 9 | space_timesteps, 10 | ) 11 | 12 | 13 | def create_gaussian_diffusion( 14 | steps, 15 | learn_sigma, 16 | sigma_small, 17 | noise_schedule, 18 | use_kl, 19 | predict_xstart, 20 | rescale_learned_sigmas, 21 | timestep_respacing, 22 | ): 23 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 24 | if use_kl: 25 | loss_type = gd.LossType.RESCALED_KL 26 | elif rescale_learned_sigmas: 27 | loss_type = gd.LossType.RESCALED_MSE 28 | else: 29 | loss_type = gd.LossType.MSE 30 | if not timestep_respacing: 31 | timestep_respacing = [steps] 32 | 33 | return SpacedDiffusion( 34 | use_timesteps=space_timesteps(steps, timestep_respacing), 35 | betas=betas, 36 | model_mean_type=( 37 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 38 | ), 39 | model_var_type=( 40 | ( 41 | gd.ModelVarType.FIXED_LARGE 42 | if not sigma_small 43 | else gd.ModelVarType.FIXED_SMALL 44 | ) 45 | if not learn_sigma 46 | else gd.ModelVarType.LEARNED_RANGE 47 | ), 48 | loss_type=loss_type, 49 | ) 50 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False # we set this to false because this is an inference only config 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | use_checkpoint: True 24 | #use_fp16: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference-v.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | 
first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False # we set this to false because this is an inference only config 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | use_fp16: True 26 | image_size: 32 # unused 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: [ 4, 2, 1 ] 31 | num_res_blocks: 2 32 | channel_mult: [ 1, 2, 4, 4 ] 33 | num_head_channels: 64 # need to fix for flash-attn 34 | use_spatial_transformer: True 35 | use_linear_in_transformer: True 36 | transformer_depth: 1 37 | context_dim: 1024 38 | legacy: False 39 | 40 | first_stage_config: 41 | target: ldm.models.autoencoder.AutoencoderKL 42 | params: 43 | embed_dim: 4 44 | monitor: val/rec_loss 45 | ddconfig: 46 | #attn_type: "vanilla-xformers" 47 | double_z: true 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: [] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 66 | params: 67 | freeze: True 68 | layer: "penultimate" 69 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-fp32.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: 
"jpg" 14 | cond_stage_key: "txt" 15 | image_size: 64 16 | channels: 4 17 | cond_stage_trainable: false 18 | conditioning_key: crossattn 19 | monitor: val/loss_simple_ema 20 | scale_factor: 0.18215 21 | use_ema: False # we set this to false because this is an inference only config 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | use_checkpoint: False 27 | use_fp16: False 28 | image_size: 32 # unused 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 320 32 | attention_resolutions: [ 4, 2, 1 ] 33 | num_res_blocks: 2 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | num_head_channels: 64 # need to fix for flash-attn 36 | use_spatial_transformer: True 37 | use_linear_in_transformer: True 38 | transformer_depth: 1 39 | context_dim: 1024 40 | legacy: False 41 | 42 | first_stage_config: 43 | target: ldm.models.autoencoder.AutoencoderKL 44 | params: 45 | embed_dim: 4 46 | monitor: val/rec_loss 47 | ddconfig: 48 | #attn_type: "vanilla-xformers" 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 2 58 | - 4 59 | - 4 60 | num_res_blocks: 2 61 | attn_resolutions: [] 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | 66 | cond_stage_config: 67 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 68 | params: 69 | freeze: True 70 | layer: "penultimate" 71 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-bf16.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: "jpg" 14 | 
cond_stage_key: "txt" 15 | image_size: 64 16 | channels: 4 17 | cond_stage_trainable: false 18 | conditioning_key: crossattn 19 | monitor: val/loss_simple_ema 20 | scale_factor: 0.18215 21 | use_ema: False # we set this to false because this is an inference only config 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | use_checkpoint: False 27 | use_fp16: False 28 | use_bf16: True 29 | image_size: 32 # unused 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: [ 4, 2, 1 ] 34 | num_res_blocks: 2 35 | channel_mult: [ 1, 2, 4, 4 ] 36 | num_head_channels: 64 # need to fix for flash-attn 37 | use_spatial_transformer: True 38 | use_linear_in_transformer: True 39 | transformer_depth: 1 40 | context_dim: 1024 41 | legacy: False 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | #attn_type: "vanilla-xformers" 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 69 | params: 70 | freeze: True 71 | layer: "penultimate" 72 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-v-fp32.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 
14 | first_stage_key: "jpg" 15 | cond_stage_key: "txt" 16 | image_size: 64 17 | channels: 4 18 | cond_stage_trainable: false 19 | conditioning_key: crossattn 20 | monitor: val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: False # we set this to false because this is an inference only config 23 | 24 | unet_config: 25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 26 | params: 27 | use_checkpoint: False 28 | use_fp16: False 29 | image_size: 32 # unused 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: [ 4, 2, 1 ] 34 | num_res_blocks: 2 35 | channel_mult: [ 1, 2, 4, 4 ] 36 | num_head_channels: 64 # need to fix for flash-attn 37 | use_spatial_transformer: True 38 | use_linear_in_transformer: True 39 | transformer_depth: 1 40 | context_dim: 1024 41 | legacy: False 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | #attn_type: "vanilla-xformers" 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 69 | params: 70 | freeze: True 71 | layer: "penultimate" 72 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | 
cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder #ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-v-bf16.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | 
num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: "jpg" 15 | cond_stage_key: "txt" 16 | image_size: 64 17 | channels: 4 18 | cond_stage_trainable: false 19 | conditioning_key: crossattn 20 | monitor: val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: False # we set this to false because this is an inference only config 23 | 24 | unet_config: 25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 26 | params: 27 | use_checkpoint: False 28 | use_fp16: False 29 | use_bf16: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | #attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/resample.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) 3 | # 
------------------------------------------------------------------------------------ 4 | 5 | from abc import abstractmethod 6 | 7 | import torch as th 8 | 9 | 10 | def create_named_schedule_sampler(name, diffusion): 11 | """ 12 | Create a ScheduleSampler from a library of pre-defined samplers. 13 | 14 | :param name: the name of the sampler. 15 | :param diffusion: the diffusion object to sample for. 16 | """ 17 | if name == "uniform": 18 | return UniformSampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(th.nn.Module): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a 1-D torch tensor of weights, one per diffusion step (sample() applies th.sum() and .multinomial() to it, so a numpy array will not work). 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses.
51 | """ 52 | w = self.weights() 53 | p = w / th.sum(w) 54 | indices = p.multinomial(batch_size, replacement=True) 55 | weights = 1 / (len(p) * p[indices]) 56 | return indices, weights 57 | 58 | 59 | class UniformSampler(ScheduleSampler): 60 | def __init__(self, diffusion): 61 | super(UniformSampler, self).__init__() 62 | self.diffusion = diffusion 63 | self.register_buffer( 64 | "_weights", th.ones([diffusion.num_timesteps]), persistent=False 65 | ) 66 | 67 | def weights(self): 68 | return self._weights 69 | -------------------------------------------------------------------------------- /scripts/evaluate_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scripts.metrics as Metrics 3 | from PIL import Image 4 | import numpy as np 5 | import glob 6 | import torch 7 | from tqdm import tqdm 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-p_gt', '--path_gt', type=str, 13 | default='G:\Giordano\stablediffusion\outputs\img2img-samples\samples-orig-10') 14 | parser.add_argument('-p', '--path', type=str, 15 | default='G:\Giordano\stablediffusion\outputs\img2img-samples\samples-10') 16 | args = parser.parse_args() 17 | real_names = list(glob.glob('{}/*.png'.format(args.path_gt))) 18 | # real_names = list(glob.glob('{}/*.jpg'.format(args.path_gt))) 19 | print(real_names, args.path_gt) 20 | 21 | 22 | 23 | fake_names = list(glob.glob('{}/*.png'.format(args.path))) 24 | 25 | real_names.sort() 26 | fake_names.sort() 27 | 28 | avg_psnr = 0.0 29 | avg_ssim = 0.0 30 | avg_fid = 0.0 31 | fid_img_list_real, fid_img_list_fake = [],[] 32 | idx = 0 33 | for rname, fname in tqdm(zip(real_names, fake_names), total=len(real_names)): 34 | idx += 1 35 | 36 | 37 | 38 | hr_img = np.array(Image.open(rname)) 39 | sr_img = np.array(Image.open(fname)) 40 | psnr = Metrics.calculate_psnr(sr_img, hr_img) 41 | ssim = Metrics.calculate_ssim(sr_img, hr_img) 42 | 
fid_img_list_real.append(torch.from_numpy(hr_img).permute(2,0,1).unsqueeze(0)) 43 | fid_img_list_fake.append(torch.from_numpy(sr_img).permute(2,0,1).unsqueeze(0)) 44 | avg_psnr += psnr 45 | avg_ssim += ssim 46 | if idx % 10 == 0: 47 | # fid = Metrics.calculate_FID(torch.cat(fid_img_list_real,dim=0), torch.cat(fid_img_list_fake,dim=0)) 48 | # fid_img_list_real, fid_img_list_fake = [],[] 49 | # avg_fid += fid 50 | print('Image:{}, PSNR:{:.4f}, SSIM:{:.4f}'.format(idx, psnr, ssim)) 51 | 52 | 53 | #last FID 54 | fid = Metrics.calculate_FID(torch.cat(fid_img_list_real,dim=0), torch.cat(fid_img_list_fake,dim=0)) 55 | avg_fid += fid 56 | 57 | avg_psnr = avg_psnr / idx 58 | avg_ssim = avg_ssim / idx 59 | # avg_fid = avg_fid / idx 60 | 61 | # log 62 | print('# Validation # PSNR: {}'.format(avg_psnr)) 63 | print('# Validation # SSIM: {}'.format(avg_ssim)) 64 | print('# Validation # FID: {}'.format(avg_fid)) -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + 
self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | 
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | 
x = x.contiguous(memory_format=torch.channels_last) # contiguous() returns a new tensor; it must be rebound or channels_last has no effect 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs.get("features", 256) 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self,
noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/nn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class GroupNorm32(nn.GroupNorm): 13 | 
def __init__(self, num_groups, num_channels, swish, eps=1e-5): 14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 15 | self.swish = swish 16 | 17 | def forward(self, x): 18 | y = super().forward(x.float()).to(x.dtype) 19 | if self.swish == 1.0: 20 | y = F.silu(y) 21 | elif self.swish: 22 | y = y * F.sigmoid(y * float(self.swish)) 23 | return y 24 | 25 | 26 | def conv_nd(dims, *args, **kwargs): 27 | """ 28 | Create a 1D, 2D, or 3D convolution module. 29 | """ 30 | if dims == 1: 31 | return nn.Conv1d(*args, **kwargs) 32 | elif dims == 2: 33 | return nn.Conv2d(*args, **kwargs) 34 | elif dims == 3: 35 | return nn.Conv3d(*args, **kwargs) 36 | raise ValueError(f"unsupported dimensions: {dims}") 37 | 38 | 39 | def linear(*args, **kwargs): 40 | """ 41 | Create a linear module. 42 | """ 43 | return nn.Linear(*args, **kwargs) 44 | 45 | 46 | def avg_pool_nd(dims, *args, **kwargs): 47 | """ 48 | Create a 1D, 2D, or 3D average pooling module. 49 | """ 50 | if dims == 1: 51 | return nn.AvgPool1d(*args, **kwargs) 52 | elif dims == 2: 53 | return nn.AvgPool2d(*args, **kwargs) 54 | elif dims == 3: 55 | return nn.AvgPool3d(*args, **kwargs) 56 | raise ValueError(f"unsupported dimensions: {dims}") 57 | 58 | 59 | def zero_module(module): 60 | """ 61 | Zero out the parameters of a module and return it. 62 | """ 63 | for p in module.parameters(): 64 | p.detach().zero_() 65 | return module 66 | 67 | 68 | def scale_module(module, scale): 69 | """ 70 | Scale the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().mul_(scale) 74 | return module 75 | 76 | 77 | def normalization(channels, swish=0.0): 78 | """ 79 | Make a standard normalization layer, with an optional swish activation. 80 | 81 | :param channels: number of input channels. 82 | :return: an nn.Module for normalization. 
83 | """ 84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 85 | 86 | 87 | def timestep_embedding(timesteps, dim, max_period=10000): 88 | """ 89 | Create sinusoidal timestep embeddings. 90 | 91 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 92 | These may be fractional. 93 | :param dim: the dimension of the output. 94 | :param max_period: controls the minimum frequency of the embeddings. 95 | :return: an [N x dim] Tensor of positional embeddings. 96 | """ 97 | half = dim // 2 98 | freqs = th.exp( 99 | -math.log(max_period) 100 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) 101 | / half 102 | ) 103 | args = timesteps[:, None].float() * freqs[None] 104 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 105 | if dim % 2: 106 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 107 | return embedding 108 | 109 | 110 | def mean_flat(tensor): 111 | """ 112 | Take the mean over all non-batch dimensions. 113 | """ 114 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 115 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by project 2 | outputs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # General MacOS 13 | .DS_Store 14 | .AppleDouble 15 | .LSOverride 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # IDEs 164 | .idea/ 165 | .vscode/ 166 | -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | MODEL_TYPES = { 7 | "eps": "noise", 8 | "v": "v" 9 | } 10 | 11 | 12 | class DPMSolverSampler(object): 13 | def __init__(self, model, device=torch.device("cuda"), **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.device = device 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != self.device: 23 | attr = attr.to(self.device) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 
28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | ctmp = conditioning[list(conditioning.keys())[0]] 54 | while isinstance(ctmp, list): ctmp = ctmp[0] 55 | if isinstance(ctmp, torch.Tensor): 56 | cbs = ctmp.shape[0] 57 | if cbs != batch_size: 58 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 59 | elif isinstance(conditioning, list): 60 | for ctmp in conditioning: 61 | if ctmp.shape[0] != batch_size: 62 | print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") 63 | else: 64 | if isinstance(conditioning, torch.Tensor): 65 | if conditioning.shape[0] != batch_size: 66 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 67 | 68 | # sampling 69 | C, H, W = shape 70 | size = (batch_size, C, H, W) 71 | 72 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 73 | 74 | device = self.model.betas.device 75 | if x_T is None: 76 | img = torch.randn(size, device=device) 77 | else: 78 | img = x_T 79 | 80 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 81 | 82 | model_fn = model_wrapper( 83 | lambda x, t, c: self.model.apply_model(x, t, c), 84 | ns, 85 | model_type=MODEL_TYPES[self.model.parameterization], 86 | guidance_type="classifier-free", 87 | condition=conditioning, 88 | unconditional_condition=unconditional_conditioning, 89 | 
guidance_scale=unconditional_guidance_scale, 90 | ) 91 | 92 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 93 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, 94 | lower_order_final=True) 95 | 96 | return x.to(device), None 97 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/sr_64_256.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel 10 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 11 | 12 | 13 | class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module): 14 | """ 15 | ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses. 16 | In specific, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model. 17 | In the following additional one step, a seperate fine-tuned model recovers high-frequency details. 18 | This approach greatly improves the fidelity of images of 256x256px, even with small number of reverse steps. 
19 | """ 20 | 21 | def __init__(self, config): 22 | super().__init__() 23 | 24 | self._config = config 25 | self._diffusion_kwargs = dict( 26 | steps=config.diffusion.steps, 27 | learn_sigma=config.diffusion.learn_sigma, 28 | sigma_small=config.diffusion.sigma_small, 29 | noise_schedule=config.diffusion.noise_schedule, 30 | use_kl=config.diffusion.use_kl, 31 | predict_xstart=config.diffusion.predict_xstart, 32 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 33 | ) 34 | 35 | self.model_first_steps = SuperResUNetModel( 36 | in_channels=3, # auto-changed to 6 inside the model 37 | model_channels=config.model.hparams.channels, 38 | out_channels=3, 39 | num_res_blocks=config.model.hparams.depth, 40 | attention_resolutions=(), # no attention 41 | dropout=config.model.hparams.dropout, 42 | channel_mult=config.model.hparams.channels_multiple, 43 | resblock_updown=True, 44 | use_middle_attention=False, 45 | ) 46 | self.model_last_step = SuperResUNetModel( 47 | in_channels=3, # auto-changed to 6 inside the model 48 | model_channels=config.model.hparams.channels, 49 | out_channels=3, 50 | num_res_blocks=config.model.hparams.depth, 51 | attention_resolutions=(), # no attention 52 | dropout=config.model.hparams.dropout, 53 | channel_mult=config.model.hparams.channels_multiple, 54 | resblock_updown=True, 55 | use_middle_attention=False, 56 | ) 57 | 58 | @classmethod 59 | def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True): 60 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 61 | 62 | model = cls(config) 63 | model.load_state_dict(ckpt, strict=strict) 64 | return model 65 | 66 | def get_sample_fn(self, timestep_respacing): 67 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 68 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 69 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 70 | return diffusion.p_sample_loop_progressive_for_improved_sr 71 | 72 | def forward(self, low_res, 
timestep_respacing="7", **kwargs): 73 | assert ( 74 | timestep_respacing == "7" 75 | ), "different respacing method may work, but no guaranteed" 76 | 77 | sample_fn = self.get_sample_fn(timestep_respacing) 78 | sample_outputs = sample_fn( 79 | self.model_first_steps, 80 | self.model_last_step, 81 | shape=low_res.shape, 82 | clip_denoised=True, 83 | model_kwargs=dict(low_res=low_res), 84 | **kwargs, 85 | ) 86 | for x in sample_outputs: 87 | sample = x["sample"] 88 | yield sample 89 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import cv2 5 | from torchvision.utils import make_grid 6 | from torchmetrics.image.fid import FrechetInceptionDistance 7 | 8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 9 | ''' 10 | Converts a torch Tensor into an image Numpy array 11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 13 | ''' 14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 15 | tensor = (tensor - min_max[0]) / \ 16 | (min_max[1] - min_max[0]) # to range [0,1] 17 | n_dim = tensor.dim() 18 | if n_dim == 4: 19 | n_img = len(tensor) 20 | img_np = make_grid(tensor, nrow=int( 21 | math.sqrt(n_img)), normalize=False).numpy() 22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 23 | elif n_dim == 3: 24 | img_np = tensor.numpy() 25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 26 | elif n_dim == 2: 27 | img_np = tensor.numpy() 28 | else: 29 | raise TypeError( 30 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 31 | if out_type == np.uint8: 32 | img_np = (img_np * 255.0).round() 33 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
34 | return img_np.astype(out_type) 35 | 36 | 37 | def save_img(img, img_path, mode='RGB'): 38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 39 | # cv2.imwrite(img_path, img) 40 | 41 | 42 | def calculate_psnr(img1, img2): 43 | # img1 and img2 have range [0, 255] 44 | img1 = img1.astype(np.float64) 45 | img2 = img2.astype(np.float64) 46 | mse = np.mean((img1 - img2)**2) 47 | if mse == 0: 48 | return float('inf') 49 | return 20 * math.log10(255.0 / math.sqrt(mse)) 50 | 51 | 52 | def ssim(img1, img2): 53 | C1 = (0.01 * 255)**2 54 | C2 = (0.03 * 255)**2 55 | 56 | img1 = img1.astype(np.float64) 57 | img2 = img2.astype(np.float64) 58 | kernel = cv2.getGaussianKernel(11, 1.5) 59 | window = np.outer(kernel, kernel.transpose()) 60 | 61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 63 | mu1_sq = mu1**2 64 | mu2_sq = mu2**2 65 | mu1_mu2 = mu1 * mu2 66 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 67 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 69 | 70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 71 | (sigma1_sq + sigma2_sq + C2)) 72 | return ssim_map.mean() 73 | 74 | 75 | def calculate_ssim(img1, img2): 76 | '''calculate SSIM 77 | the same outputs as MATLAB's 78 | img1, img2: [0, 255] 79 | ''' 80 | if not img1.shape == img2.shape: 81 | raise ValueError('Input images must have the same dimensions.') 82 | if img1.ndim == 2: 83 | return ssim(img1, img2) 84 | elif img1.ndim == 3: 85 | if img1.shape[2] == 3: 86 | ssims = [] 87 | for i in range(3): 88 | ssims.append(ssim(img1, img2)) 89 | return np.array(ssims).mean() 90 | elif img1.shape[2] == 1: 91 | return ssim(np.squeeze(img1), np.squeeze(img2)) 92 | else: 93 | raise ValueError('Wrong input image dimensions.') 94 | 95 | fid = FrechetInceptionDistance() 96 | def calculate_FID(imgs_dist1, 
imgs_dist2): 97 | # generate two slightly overlapping image intensity distributions 98 | fid.update(imgs_dist1, real=True) 99 | fid.update(imgs_dist2, real=False) 100 | # fid.compute() 101 | return fid.compute() -------------------------------------------------------------------------------- /scripts/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from PIL import Image 6 | from torchvision.transforms import Resize, ToTensor, Normalize, Compose 7 | from matplotlib import pyplot as plt 8 | 9 | from io import open 10 | import unicodedata 11 | import re 12 | 13 | 14 | 15 | 16 | class Flickr8kDataset(Dataset): 17 | def __init__(self,images_dir_path, capt_file_path): 18 | super().__init__() 19 | #read data 20 | data=open(capt_file_path).read().strip().split('\n') 21 | data=data[1:] 22 | 23 | img_filenames_list=[] 24 | captions_list=[] 25 | 26 | for s in data: 27 | templist=s.lower().split(",") 28 | img_path=templist[0] 29 | caption=",".join(s for s in templist[1:]) 30 | caption=self.normalizeString(caption) 31 | img_filenames_list.append(img_path) 32 | captions_list.append(caption) 33 | 34 | self.images_dir_path=images_dir_path 35 | self.img_filenames_list=img_filenames_list 36 | self.captions_list=captions_list 37 | self.length=len(self.captions_list) 38 | self.transform=Compose([Resize((224,224), antialias=True), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 39 | 40 | 41 | def __len__(self): 42 | return self.length 43 | 44 | #unicode 2 ascii, remove non-letter characters, trim 45 | def normalizeString(self,s): 46 | sres="" 47 | for ch in unicodedata.normalize('NFD', s): 48 | #Return the normal form form ('NFD') for the Unicode string s. 
49 | if unicodedata.category(ch) != 'Mn': 50 | # The function in the first part returns the general 51 | # category assigned to the character ch as string. 52 | # "Mn' refers to Mark, Nonspacing 53 | sres+=ch 54 | #sres = re.sub(r"([.!?])", r" \1", sres) 55 | # inserts a space before any occurrence of ".", "!", or "?" in the string sres. 56 | sres = re.sub(r"[^a-zA-Z!?,]+", r" ", sres) 57 | # this line of code replaces any sequence of characters in sres 58 | # that are not letters (a-z or A-Z) or the punctuation marks 59 | # "!", "," or "?" with a single space character. 60 | return sres.strip() 61 | 62 | 63 | 64 | def __getitem__(self,idx): 65 | imgfname,caption=self.img_filenames_list[idx],self.captions_list[idx] 66 | 67 | imgfname=self.images_dir_path+imgfname 68 | 69 | return imgfname, caption 70 | 71 | import os 72 | class Only_images_Flickr8kDataset(Dataset): 73 | def __init__(self,images_dir_path): 74 | super().__init__() 75 | #read data 76 | img_filenames_list=[] 77 | 78 | for root, dirs, files in os.walk(images_dir_path): 79 | for file in files: 80 | if file.endswith('.jpg'): 81 | img_filenames_list.append(file) 82 | 83 | 84 | #print(img_filenames_list) 85 | self.images_dir_path=images_dir_path 86 | self.img_filenames_list=img_filenames_list 87 | self.length = len(img_filenames_list) 88 | 89 | 90 | def __len__(self): 91 | return self.length 92 | 93 | 94 | 95 | def __getitem__(self,idx): 96 | 97 | imgfname = self.img_filenames_list[idx] 98 | imgfname = self.images_dir_path+imgfname 99 | 100 | return imgfname 101 | 102 | 103 | 104 | if __name__ == "__main__": 105 | 106 | 107 | capt_file_path= "path/to/captions.txt" #"G:/Giordano/Flickr8kDataset/captions.txt" 108 | images_dir_path= "path/to/Images" #"G:/Giordano/Flickr8kDataset/Images/" 109 | 110 | dataset=Flickr8kDataset(images_dir_path, capt_file_path) 111 | 112 | 113 | batch_size=1 114 | train_dataloader=DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True) 115 | 116 | 117 | for i in 
train_dataloader: 118 | print(i[0][0]) 119 | print(i[1]) 120 | break -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 |
6 | 7 |
8 | 9 | 10 | # Language-Oriented Semantic Latent Representation for Image Transmission 11 | 12 | > [**Language-Oriented Semantic Latent Representation for Image Transmission**](https://arxiv.org/abs/2405.09976) 13 | 14 | > [Giordano Cicchetti](https://www.linkedin.com/in/giordano-cicchetti-1ab73b258/), [Eleonora Grassucci](https://scholar.google.com/citations?user=Jcv0TgQAAAAJ&hl=it), 15 | > [Jihong Park](https://scholar.google.com/citations?user=I0CO72QAAAAJ&hl=en), [Jinho Choi](https://scholar.google.co.uk/citations?user=QzFia5YAAAAJ&hl=en), 16 | > [Sergio Barbarossa](https://scholar.google.com/citations?user=2woHFu8AAAAJ&hl=en), [Danilo Comminiello](https://scholar.google.com/citations?user=H3Y52cMAAAAJ&hl=en) 17 | 18 | 19 | This is the official implementation of the paper: [Language-Oriented Semantic Latent Representation for Image Transmission](https://arxiv.org/abs/2405.09976) 20 | 21 | 22 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2405.09976) 23 | ![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fgithub.com%2Fispamm%2FImg2Img-SC&label=VISITORS&countColor=%23263759) 24 | 25 |
26 | 27 | 28 | 29 | ## News 30 | 31 | 32 | **June 17, 2024** *Code Released* 33 | 34 | ________________________________ 35 | 36 | 37 | 38 | ## Requirements 39 | 40 | Create a dedicated conda environment: 41 | ``` 42 | conda create -n SemanticI2I python=3.9 43 | conda activate SemanticI2I 44 | 45 | ``` 46 | 47 | 48 | You can clone the repository by typing: 49 | ``` 50 | 51 | git clone https://github.com/ispamm/Img2Img-SC.git 52 | cd Img2Img-SC 53 | 54 | ``` 55 | 56 | Then install the core dependencies (the same commands also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment): 57 | 58 | ``` 59 | conda install pytorch==1.12.1 torchvision==0.13.1 -c pytorch 60 | pip install transformers==4.19.2 diffusers invisible-watermark 61 | pip install -e . 62 | 63 | ``` 64 | 65 | After that, install the remaining required packages: 66 | 67 | 68 | ``` 69 | 70 | pip install -r requirements.txt 71 | 72 | ``` 73 | 74 | ## Download pretrained models 75 | 76 | Download the pretrained checkpoints and copy them into the "checkpoints" folder. 77 | 78 | ## Img2Img 79 | 80 | The scripts are located in the "scripts" folder. 81 | 82 | - `scripts/semantic_i2i.py`: the proposed I2I framework, which uses both the latent embedding and the image caption. 83 | - `scripts/semantic_t2i.py`: the I2I framework that uses only the image caption. 84 | 85 | To test the img2img framework, set the model checkpoint and configuration paths inside the script files and then run: 86 | 87 | ``` 88 | python scripts/semantic_i2i.py 89 | 90 | #Or 91 | 92 | python scripts/semantic_t2i.py 93 | ``` 94 | and adapt the checkpoint and config paths accordingly. 95 | 96 | ## Results 97 | 98 |
99 | 100 | 101 | 102 |
103 | 104 |
105 | 106 | 107 | 108 |
109 | 110 |
111 | 112 | 113 | 114 |
115 | 116 | 117 | ## License 118 | 119 | The code in this repository is released under the MIT License. 120 | 121 | ## Acknowledgment 122 | 123 | Most of the code contained in this repository is based on the Stable Diffusion repository: https://github.com/Stability-AI/stablediffusion 124 | 125 | 126 | ## BibTeX 127 | 128 | ``` 129 | @misc{cicchetti2024languageoriented, 130 | title={Language-Oriented Semantic Latent Representation for Image Transmission}, 131 | author={Giordano Cicchetti and Eleonora Grassucci and Jihong Park and Jinho Choi and Sergio Barbarossa and Danilo Comminiello}, 132 | year={2024}, 133 | eprint={2405.09976}, 134 | archivePrefix={arXiv}, 135 | primaryClass={cs.CV} 136 | } 137 | ``` 138 | 139 | 140 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/template.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import os 7 | import logging 8 | import torch 9 | 10 | from omegaconf import OmegaConf 11 | 12 | from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer 13 | from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel 14 | from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel 15 | from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel 16 | 17 | 18 | SAMPLING_CONF = { 19 | "default": { 20 | "prior_sm": "25", 21 | "prior_n_samples": 1, 22 | "prior_cf_scale": 4.0, 23 | "decoder_sm": "50", 24 | "decoder_cf_scale": 8.0, 25 | "sr_sm": "7", 26 | }, 27 | "fast": { 28 | "prior_sm": "25", 29 | "prior_n_samples": 1, 30 | "prior_cf_scale": 4.0, 31 | "decoder_sm": "25", 32 | "decoder_cf_scale": 8.0, 33 | "sr_sm": "7", 34 | }, 35 | } 36 | 37 | CKPT_PATH = { 38 | "prior": "prior-ckpt-step=01000000-of-01000000.ckpt", 39 | "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt", 40 | "sr_256": "improved-sr-ckpt-step=1.2M.ckpt", 41 | } 42 | 43 | 44 | class BaseSampler: 45 | _PRIOR_CLASS = PriorDiffusionModel 46 | _DECODER_CLASS = Text2ImProgressiveModel 47 | _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel 48 | 49 | def __init__( 50 | self, 51 | root_dir: str, 52 | sampling_type: str = "fast", 53 | ): 54 | self._root_dir = root_dir 55 | 56 | sampling_type = SAMPLING_CONF[sampling_type] 57 | self._prior_sm = sampling_type["prior_sm"] 58 | self._prior_n_samples = sampling_type["prior_n_samples"] 59 | self._prior_cf_scale = sampling_type["prior_cf_scale"] 60 | 61 | assert self._prior_n_samples == 1 62 | 63 | self._decoder_sm = sampling_type["decoder_sm"] 64 | self._decoder_cf_scale = sampling_type["decoder_cf_scale"] 65 | 66 | self._sr_sm = sampling_type["sr_sm"] 67 | 68 | def __repr__(self): 69 | line = "" 70 | line += f"Prior, sampling method: {self._prior_sm}, cf_scale: 
{self._prior_cf_scale}\n" 71 | line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n" 72 | line += f"SR(64->256), sampling method: {self._sr_sm}" 73 | 74 | return line 75 | 76 | def load_clip(self, clip_path: str): 77 | clip = CustomizedCLIP.load_from_checkpoint( 78 | os.path.join(self._root_dir, clip_path) 79 | ) 80 | clip = torch.jit.script(clip) 81 | clip.cuda() 82 | clip.eval() 83 | 84 | self._clip = clip 85 | self._tokenizer = CustomizedTokenizer() 86 | 87 | def load_prior( 88 | self, 89 | ckpt_path: str, 90 | clip_stat_path: str, 91 | prior_config: str = "configs/prior_1B_vit_l.yaml" 92 | ): 93 | logging.info(f"Loading prior: {ckpt_path}") 94 | 95 | config = OmegaConf.load(prior_config) 96 | clip_mean, clip_std = torch.load( 97 | os.path.join(self._root_dir, clip_stat_path), map_location="cpu" 98 | ) 99 | 100 | prior = self._PRIOR_CLASS.load_from_checkpoint( 101 | config, 102 | self._tokenizer, 103 | clip_mean, 104 | clip_std, 105 | os.path.join(self._root_dir, ckpt_path), 106 | strict=True, 107 | ) 108 | prior.cuda() 109 | prior.eval() 110 | logging.info("done.") 111 | 112 | self._prior = prior 113 | 114 | def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"): 115 | logging.info(f"Loading decoder: {ckpt_path}") 116 | 117 | config = OmegaConf.load(decoder_config) 118 | decoder = self._DECODER_CLASS.load_from_checkpoint( 119 | config, 120 | self._tokenizer, 121 | os.path.join(self._root_dir, ckpt_path), 122 | strict=True, 123 | ) 124 | decoder.cuda() 125 | decoder.eval() 126 | logging.info("done.") 127 | 128 | self._decoder = decoder 129 | 130 | def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"): 131 | logging.info(f"Loading SR(64->256): {ckpt_path}") 132 | 133 | config = OmegaConf.load(sr_config) 134 | sr = self._SR256_CLASS.load_from_checkpoint( 135 | config, os.path.join(self._root_dir, ckpt_path), strict=True 136 | ) 137 | 
137 | sr.cuda() 138 | sr.eval() 139 | logging.info("done.") 140 | 141 | self._sr_64_256 = sr -------------------------------------------------------------------------------- /scripts/qam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import erfc 3 | import random 4 | import torch 5 | import math 6 | import bitstring 7 | import numpy as np 8 | 9 | ''' 10 | This Python file is dedicated to QAM modulation and noisy-channel simulation. 11 | 12 | We use a K-ary discrete memoryless channel (DMC) where the crossover probability is given by 13 | the bit error rate (BER) of the M-ary quadrature amplitude modulation (QAM). 14 | 15 | For DMC, you can find the channel model from (9) in https://ieeexplore.ieee.org/abstract/document/10437659. 16 | 17 | For the crossover probability, we assume an AWGN channel where the BER is a Q-function 18 | of the SNR and M: https://www.etti.unibw.de/labalive/experiment/qam/. 19 | 20 | ''' 21 | 22 | # Modulate Tensor in 16QAM transmission and noisy channel conditions 23 | def qam16ModulationTensor(input_tensor,snr_db=10): 24 | 25 | message_shape = input_tensor.shape 26 | 27 | message = input_tensor 28 | 29 | #Convert tensor to bitstream 30 | bit_list = tensor2bin(message) 31 | 32 | #Introduce noise to the bitstream according to SNR 33 | bit_list_noisy = introduce_noise(bit_list,snr=snr_db) 34 | 35 | #Convert bitstream back to tensor 36 | back_to_tensor = bin2tensor(bit_list_noisy) 37 | 38 | return back_to_tensor.reshape(message_shape) 39 | 40 | 41 | # Modulate String in 16QAM transmission and noisy channel conditions 42 | def qam16ModulationString(input_tensor,snr_db=10): 43 | 44 | message = input_tensor 45 | 46 | #Convert string to bitstream 47 | bit_list = list2bin(message) 48 | 49 | #Introduce noise to the bitstream according to SNR 50 | bit_list_noisy = introduce_noise(bit_list,snr=snr_db) 51 | 52 | #Convert bitstream back to list of char 53 | back_to_tensor
= bin2list(bit_list_noisy) 54 | 55 | return "".join(back_to_tensor) 56 | 57 | 58 | 59 | 60 | 61 | def introduce_noise(bit_list,snr=10,qam=16): 62 | 63 | # Compute ebno according to SNR 64 | ebno = 10 ** (snr/10) 65 | 66 | # Estimate probability of bit error according to https://www.etti.unibw.de/labalive/experiment/qam/ 67 | K = np.sqrt(qam) # 4 68 | M = 2 ** K 69 | Pm = (1 - 1 / np.sqrt(M)) * erfc(np.sqrt(3 / 2 / (M - 1) * K * ebno)) 70 | Ps_qam = 1 - (1 - Pm) ** 2 71 | Pb_qam = Ps_qam / K 72 | 73 | bit_flipped = 0 74 | bit_tot = 0 75 | new_list = [] 76 | for num in bit_list: 77 | num_new = [] 78 | for b in num: 79 | 80 | if random.random() < Pb_qam: 81 | num_new.append(str(1 - int(b))) # Flipping the bit 82 | bit_flipped+=1 83 | else: 84 | num_new.append(b) 85 | bit_tot+=1 86 | new_list.append(''.join(num_new)) 87 | 88 | #print(bit_flipped/bit_tot) 89 | return new_list 90 | 91 | 92 | 93 | 94 | 95 | def bin2float(b): 96 | ''' Convert binary string to a float. 97 | 98 | Attributes: 99 | :b: Binary string to transform. 100 | ''' 101 | 102 | num = bitstring.BitArray(bin=b).float 103 | 104 | #print(num) 105 | if math.isnan(num) or math.isinf(num): 106 | 107 | num = np.random.randn() 108 | 109 | 110 | if num > 10: 111 | 112 | num=np.random.randn() 113 | 114 | if num < -10: 115 | 116 | num=np.random.randn() 117 | 118 | if num < 1e-2 and num>-1e-2: 119 | 120 | num = np.random.randn() 121 | 122 | return num 123 | 124 | 125 | def float2bin(f): 126 | ''' Convert float to 64-bit binary string. 127 | 128 | Attributes: 129 | :f: Float number to transform. 
130 | ''' 131 | 132 | f1 = bitstring.BitArray(float=f, length=64) 133 | return f1.bin 134 | 135 | 136 | def tensor2bin(tensor): 137 | 138 | tensor_flattened = tensor.view(-1).numpy() 139 | 140 | bit_list = [] 141 | for number in tensor_flattened: 142 | bit_list.append(float2bin(number)) 143 | 144 | 145 | return bit_list 146 | 147 | 148 | def bin2tensor(input_list): 149 | tensor_reconstructed = [bin2float(bin) for bin in input_list] 150 | return torch.FloatTensor(tensor_reconstructed) 151 | 152 | 153 | def string2int(char): 154 | return ord(char) 155 | 156 | 157 | def int2bin(int_num): 158 | return '{0:08b}'.format(int_num) 159 | 160 | def int2string(int_num): 161 | return chr(int_num) 162 | 163 | def bin2int(bin_num): 164 | return int(bin_num, 2) 165 | 166 | 167 | def list2bin(input_list): 168 | 169 | bit_list = [] 170 | for number in input_list: 171 | bit_list.append(int2bin(string2int(number))) 172 | 173 | return bit_list 174 | 175 | def bin2list(input_list): 176 | list_reconstructed = [int2string(bin2int(bin)) for bin in input_list] 177 | return list_reconstructed 178 | 179 | 180 | 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | import torch as th 7 | 8 | from .gaussian_diffusion import GaussianDiffusion 9 | 10 | 11 | def space_timesteps(num_timesteps, section_counts): 12 | """ 13 | Create a list of timesteps to use from an original diffusion process, 14 | given the number of timesteps we want to take from equally-sized portions 15 | of the original process. 
16 | 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | elif section_counts == "fast27": 40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2") 41 | # Help reduce DDIM artifacts from noisiest timesteps. 
42 | steps.remove(num_timesteps - 1) 43 | steps.add(num_timesteps - 3) 44 | return steps 45 | section_counts = [int(x) for x in section_counts.split(",")] 46 | size_per = num_timesteps // len(section_counts) 47 | extra = num_timesteps % len(section_counts) 48 | start_idx = 0 49 | all_steps = [] 50 | for i, section_count in enumerate(section_counts): 51 | size = size_per + (1 if i < extra else 0) 52 | if size < section_count: 53 | raise ValueError( 54 | f"cannot divide section of {size} steps into {section_count}" 55 | ) 56 | if section_count <= 1: 57 | frac_stride = 1 58 | else: 59 | frac_stride = (size - 1) / (section_count - 1) 60 | cur_idx = 0.0 61 | taken_steps = [] 62 | for _ in range(section_count): 63 | taken_steps.append(start_idx + round(cur_idx)) 64 | cur_idx += frac_stride 65 | all_steps += taken_steps 66 | start_idx += size 67 | return set(all_steps) 68 | 69 | 70 | class SpacedDiffusion(GaussianDiffusion): 71 | """ 72 | A diffusion process which can skip steps in a base diffusion process. 73 | 74 | :param use_timesteps: a collection (sequence or set) of timesteps from the 75 | original diffusion process to retain. 76 | :param kwargs: the kwargs to create the base diffusion process. 
77 | """ 78 | 79 | def __init__(self, use_timesteps, **kwargs): 80 | self.use_timesteps = set(use_timesteps) 81 | self.original_num_steps = len(kwargs["betas"]) 82 | 83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | timestep_map = [] 87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 88 | if i in self.use_timesteps: 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | timestep_map.append(i) 92 | kwargs["betas"] = th.tensor(new_betas).numpy() 93 | super().__init__(**kwargs) 94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) 95 | 96 | def p_mean_variance(self, model, *args, **kwargs): 97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | def wrapped(x, ts, **kwargs): 107 | ts_cpu = ts.detach().to("cpu") 108 | return model( 109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs 110 | ) 111 | 112 | return wrapped 113 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 
65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 
118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 
167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/prior_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 10 | from ldm.modules.karlo.kakao.modules.xf import PriorTransformer 11 | 12 | 13 | class PriorDiffusionModel(torch.nn.Module): 14 | """ 15 | A prior that generates clip image feature based on the text prompt. 16 | 17 | :param config: yaml config to define the decoder. 18 | :param tokenizer: tokenizer used in clip. 19 | :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance). 20 | :param clip_std: std to noramlize the clip image feature (zero-mean, unit variance). 
21 | """ 22 | 23 | def __init__(self, config, tokenizer, clip_mean, clip_std): 24 | super().__init__() 25 | 26 | self._conf = config 27 | self._model_conf = config.model.hparams 28 | self._diffusion_kwargs = dict( 29 | steps=config.diffusion.steps, 30 | learn_sigma=config.diffusion.learn_sigma, 31 | sigma_small=config.diffusion.sigma_small, 32 | noise_schedule=config.diffusion.noise_schedule, 33 | use_kl=config.diffusion.use_kl, 34 | predict_xstart=config.diffusion.predict_xstart, 35 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 36 | timestep_respacing=config.diffusion.timestep_respacing, 37 | ) 38 | self._tokenizer = tokenizer 39 | 40 | self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)  # leading axis of size 1 so it broadcasts over the batch in forward() 41 | self.register_buffer("clip_std", clip_std[None, :], persistent=False) 42 | 43 | causal_mask = self.get_causal_mask() 44 | self.register_buffer("causal_mask", causal_mask, persistent=False) 45 | 46 | self.model = PriorTransformer( 47 | text_ctx=self._model_conf.text_ctx, 48 | xf_width=self._model_conf.xf_width, 49 | xf_layers=self._model_conf.xf_layers, 50 | xf_heads=self._model_conf.xf_heads, 51 | xf_final_ln=self._model_conf.xf_final_ln, 52 | clip_dim=self._model_conf.clip_dim, 53 | ) 54 | 55 | cf_token, cf_mask = self.set_cf_text_tensor() 56 | self.register_buffer("cf_token", cf_token, persistent=False) 57 | self.register_buffer("cf_mask", cf_mask, persistent=False) 58 | 59 | @classmethod 60 | def load_from_checkpoint( 61 | cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True 62 | ): 63 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]  # NOTE(review): torch.load is pickle-based; only load trusted checkpoints 64 | 65 | model = cls(config, tokenizer, clip_mean, clip_std) 66 | model.load_state_dict(ckpt, strict=strict) 67 | return model 68 | 69 | def set_cf_text_tensor(self): 70 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 71 | 72 | def get_sample_fn(self, timestep_respacing): 73 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 74
| 75 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 76 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 77 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 78 | sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop  # respacing names starting with ddim/fast select the DDIM loop 79 | 80 | return sample_fn 81 | 82 | def get_causal_mask(self): 83 | seq_len = self._model_conf.text_ctx + 4 84 | mask = torch.empty(seq_len, seq_len) 85 | mask.fill_(float("-inf")) 86 | mask.triu_(1)  # additive mask: -inf strictly above the diagonal, 0 on and below it 87 | mask = mask[None, ...]  # add a leading axis of size 1 88 | return mask 89 | 90 | def forward( 91 | self, 92 | txt_feat, 93 | txt_feat_seq, 94 | mask, 95 | cf_guidance_scales=None, 96 | timestep_respacing=None, 97 | denoised_fn=True,  # NOTE(review): unused; sampling below always clamps to [-10, 10] 98 | ): 99 | # cfg should be enabled in inference 100 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 101 | 102 | bsz_ = txt_feat.shape[0] 103 | bsz = bsz_ // 2  # NOTE(review): inputs presumably stack conditional then unconditional rows on dim 0; confirm with caller 104 | 105 | def guided_model_fn(x_t, ts, **kwargs): 106 | half = x_t[: len(x_t) // 2] 107 | combined = torch.cat([half, half], dim=0) 108 | model_out = self.model(combined, ts, **kwargs) 109 | eps, rest = ( 110 | model_out[:, : int(x_t.shape[1])], 111 | model_out[:, int(x_t.shape[1]) :], 112 | ) 113 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 114 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * ( 115 | cond_eps - uncond_eps 116 | ) 117 | eps = torch.cat([half_eps, half_eps], dim=0)  # the same guided eps drives both halves 118 | return torch.cat([eps, rest], dim=1) 119 | 120 | cond = { 121 | "text_emb": txt_feat, 122 | "text_enc": txt_feat_seq, 123 | "mask": mask, 124 | "causal_mask": self.causal_mask, 125 | } 126 | sample_fn = self.get_sample_fn(timestep_respacing) 127 | sample = sample_fn( 128 | guided_model_fn, 129 | (bsz_, self.model.clip_dim), 130 | noise=None, 131 | device=txt_feat.device, 132 | clip_denoised=False, 133 | denoised_fn=lambda x: torch.clamp(x, -10, 10), 134 | model_kwargs=cond, 135 | ) 136 | sample = (sample * self.clip_std) + self.clip_mean  # undo the clip_mean/clip_std normalization 137 | 138 | return sample[:bsz]  # return only the first half of the doubled batch 139 |
-------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | 
"""Forward pass. 75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | 
previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 
| 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"  # NOTE(review): asserts vanish under python -O; net_w etc. would then be unbound below 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type]  # an unknown model_type raises KeyError here, before the else branch below 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 |
Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform  # eval() returns the model itself, already in eval mode 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train  # freeze train/eval mode (see disabled_train); NOTE(review): stored unbound, so a bare self.model.train() would raise TypeError 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate(  # add a channel axis and resize back to the input's spatial size 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/clip.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------ 5 | # ------------------------------------------------------------------------------------ 6 | # Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/) 7 | # ------------------------------------------------------------------------------------ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import clip 14 | 15 | from clip.model import CLIP, convert_weights 16 | from clip.simple_tokenizer import SimpleTokenizer, default_bpe 17 | 18 | 19 | """===== Monkey-Patching original CLIP for JIT compile =====""" 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """Subclass torch's LayerNorm to handle fp16.""" 24 | 25 | def forward(self, x: torch.Tensor): 26 | orig_type = x.dtype 27 | ret = F.layer_norm( 28 | x.type(torch.float32), 29 | self.normalized_shape, 30 | self.weight, 31 | self.bias, 32 | self.eps, 33 | ) 34 | return ret.type(orig_type) 35 | 36 | 37 | clip.model.LayerNorm = LayerNorm 38 | delattr(clip.model.CLIP, "forward")  # NOTE: patches the upstream clip package globally (LayerNorm above, CLIP.forward here) 39 | 40 | """===== End of Monkey-Patching =====""" 41 | 42 | 43 | class CustomizedCLIP(CLIP): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | @torch.jit.export 48 | def encode_image(self, image): 49 | return self.visual(image) 50 | 51 | @torch.jit.export 52 | def encode_text(self, text): 53 | # re-define this function to return unpooled text features 54 | 55 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 56 | 57 | x = x + self.positional_embedding.type(self.dtype) 58 | x = x.permute(1, 0, 2) # NLD -> LND 59 | x = self.transformer(x) 60 | x = x.permute(1, 0, 2) # LND -> NLD 61 | x = self.ln_final(x).type(self.dtype) 62 | 63 | x_seq = x 64 | # x.shape = [batch_size, n_ctx, transformer.width] 65 | # take features from the eot embedding (eot_token is the highest number in each sequence) 66 | x_out = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection 67 | 68 | return x_out, x_seq 69 | 70 | @torch.jit.ignore 71 | def forward(self, image, text): 72 | super().forward(image, text)  # NOTE(review): CLIP.forward was deleted above, so this likely hits nn.Module's unimplemented forward; the result is not returned either 73 | 74 | @classmethod 75 | def load_from_checkpoint(cls, ckpt_path: str): 76 | state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()  # the file presumably holds a pickled module (note .state_dict()); only use trusted checkpoints 77 | 78 | vit = "visual.proj" in state_dict  # a visual.proj key selects the ViT branch; otherwise ResNet-style keys are read below 79 | if vit: 80 | vision_width = state_dict["visual.conv1.weight"].shape[0] 81 | vision_layers = len( 82 | [ 83 | k 84 | for k in state_dict.keys() 85 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") 86 | ] 87 | ) 88 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 89 | grid_size = round( 90 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 91 | ) 92 | image_resolution = vision_patch_size * grid_size 93 | else: 94 | counts: list = [ 95 | len( 96 | set( 97 | k.split(".")[2] 98 | for k in state_dict 99 | if k.startswith(f"visual.layer{b}") 100 | ) 101 | ) 102 | for b in [1, 2, 3, 4] 103 | ] 104 | vision_layers = tuple(counts) 105 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 106 | output_width = round( 107 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 108 | ) 109 | vision_patch_size = None 110 | assert ( 111 | output_width**2 + 1 112 | == state_dict["visual.attnpool.positional_embedding"].shape[0] 113 | ) 114 | image_resolution = output_width * 32 115 | 116 | embed_dim = state_dict["text_projection"].shape[1] 117 | context_length = state_dict["positional_embedding"].shape[0] 118 | vocab_size = state_dict["token_embedding.weight"].shape[0] 119 | transformer_width = state_dict["ln_final.weight"].shape[0] 120 | transformer_heads = transformer_width // 64  # assumes 64-dim attention heads 121 | transformer_layers = len( 122 | set( 123 | k.split(".")[2] 124 | for k in state_dict 125 | if k.startswith("transformer.resblocks") 126 | ) 127 | ) 128 | 129 | model = cls( 130 | embed_dim, 131 | image_resolution, 132 | vision_layers, 133 | vision_width,
134 | vision_patch_size, 135 | context_length, 136 | vocab_size, 137 | transformer_width, 138 | transformer_heads, 139 | transformer_layers, 140 | ) 141 | 142 | for key in ["input_resolution", "context_length", "vocab_size"]:  # drop these entries so load_state_dict below does not see them 143 | if key in state_dict: 144 | del state_dict[key] 145 | 146 | convert_weights(model) 147 | model.load_state_dict(state_dict) 148 | model.eval() 149 | model.float() 150 | return model 151 | 152 | 153 | class CustomizedTokenizer(SimpleTokenizer): 154 | def __init__(self): 155 | super().__init__(bpe_path=default_bpe()) 156 | 157 | self.sot_token = self.encoder["<|startoftext|>"] 158 | self.eot_token = self.encoder["<|endoftext|>"] 159 | 160 | def padded_tokens_and_mask(self, texts, text_ctx): 161 | assert isinstance(texts, list) and all( 162 | isinstance(elem, str) for elem in texts 163 | ), "texts should be a list of strings" 164 | 165 | all_tokens = [ 166 | [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts 167 | ] 168 | 169 | mask = [ 170 | [True] * min(text_ctx, len(tokens))  # True over real tokens, False over padding 171 | + [False] * max(text_ctx - len(tokens), 0) 172 | for tokens in all_tokens 173 | ] 174 | mask = torch.tensor(mask, dtype=torch.bool) 175 | result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int) 176 | for i, tokens in enumerate(all_tokens): 177 | if len(tokens) > text_ctx: 178 | tokens = tokens[:text_ctx] 179 | tokens[-1] = self.eot_token  # a truncated sequence still ends with EOT 180 | result[i, : len(tokens)] = torch.tensor(tokens) 181 | 182 | return result, mask 183 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/decoder_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 10 | from ldm.modules.karlo.kakao.modules.unet import PLMImUNet 11 | 12 | 13 | class Text2ImProgressiveModel(torch.nn.Module): 14 | """ 15 | A decoder that generates square images (side length config.model.hparams.image_size, e.g. 64px) based on the text prompt. 16 | 17 | :param config: yaml config to define the decoder. 18 | :param tokenizer: tokenizer used in clip. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | config, 24 | tokenizer, 25 | ): 26 | super().__init__() 27 | 28 | self._conf = config 29 | self._model_conf = config.model.hparams 30 | self._diffusion_kwargs = dict( 31 | steps=config.diffusion.steps, 32 | learn_sigma=config.diffusion.learn_sigma, 33 | sigma_small=config.diffusion.sigma_small, 34 | noise_schedule=config.diffusion.noise_schedule, 35 | use_kl=config.diffusion.use_kl, 36 | predict_xstart=config.diffusion.predict_xstart, 37 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 38 | timestep_respacing=config.diffusion.timestep_respacing, 39 | ) 40 | self._tokenizer = tokenizer 41 | 42 | self.model = self.create_plm_dec_model() 43 | 44 | cf_token, cf_mask = self.set_cf_text_tensor() 45 | self.register_buffer("cf_token", cf_token, persistent=False) 46 | self.register_buffer("cf_mask", cf_mask, persistent=False) 47 | 48 | @classmethod 49 | def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True): 50 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 51 | 52 | model = cls(config, tokenizer) 53 | model.load_state_dict(ckpt, strict=strict) 54 | return model 55 | 56 | def create_plm_dec_model(self): 57 | image_size = self._model_conf.image_size 58 | if self._model_conf.channel_mult == "": 59 | if image_size == 256: 60 | channel_mult = (1, 1, 2, 2, 4, 4) 61 | elif image_size == 128: 62 | channel_mult = (1, 1, 2, 3, 4) 63 | elif image_size == 64: 64 |
channel_mult = (1, 2, 3, 4) 65 | else: 66 | raise ValueError(f"unsupported image size: {image_size}") 67 | else: 68 | channel_mult = tuple( 69 | int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",") 70 | ) 71 | assert 2 ** (len(channel_mult) + 2) == image_size  # e.g. 4 multipliers -> 64px, matching the defaults above 72 | 73 | attention_ds = [] 74 | for res in self._model_conf.attention_resolutions.split(","): 75 | attention_ds.append(image_size // int(res))  # presumably converts each listed resolution into a downsample factor 76 | 77 | return PLMImUNet( 78 | text_ctx=self._model_conf.text_ctx, 79 | xf_width=self._model_conf.xf_width, 80 | in_channels=3, 81 | model_channels=self._model_conf.num_channels, 82 | out_channels=6 if self._model_conf.learn_sigma else 3,  # the first 3 channels are eps (split off in guided_model_fn) 83 | num_res_blocks=self._model_conf.num_res_blocks, 84 | attention_resolutions=tuple(attention_ds), 85 | dropout=self._model_conf.dropout, 86 | channel_mult=channel_mult, 87 | num_heads=self._model_conf.num_heads, 88 | num_head_channels=self._model_conf.num_head_channels, 89 | num_heads_upsample=self._model_conf.num_heads_upsample, 90 | use_scale_shift_norm=self._model_conf.use_scale_shift_norm, 91 | resblock_updown=self._model_conf.resblock_updown, 92 | clip_dim=self._model_conf.clip_dim, 93 | clip_emb_mult=self._model_conf.clip_emb_mult, 94 | clip_emb_type=self._model_conf.clip_emb_type, 95 | clip_emb_drop=self._model_conf.clip_emb_drop, 96 | ) 97 | 98 | def set_cf_text_tensor(self): 99 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 100 | 101 | def get_sample_fn(self, timestep_respacing): 102 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 103 | 104 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 105 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 106 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 107 | sample_fn = ( 108 | diffusion.ddim_sample_loop_progressive 109 | if use_ddim 110 | else diffusion.p_sample_loop_progressive 111 | ) 112 | 113 | return sample_fn 114 | 115 | def forward( 116 | self, 117 | txt_feat, 118 |
119 | tok,  # NOTE(review): not used in this method 120 | mask, 121 | img_feat=None, 122 | cf_guidance_scales=None, 123 | timestep_respacing=None, 124 | ): 125 | # cfg should be enabled in inference 126 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 127 | assert img_feat is not None 128 | 129 | bsz = txt_feat.shape[0] 130 | img_sz = self._model_conf.image_size 131 | 132 | def guided_model_fn(x_t, ts, **kwargs): 133 | half = x_t[: len(x_t) // 2] 134 | combined = torch.cat([half, half], dim=0) 135 | model_out = self.model(combined, ts, **kwargs) 136 | eps, rest = model_out[:, :3], model_out[:, 3:] 137 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 138 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * ( 139 | cond_eps - uncond_eps 140 | ) 141 | eps = torch.cat([half_eps, half_eps], dim=0) 142 | return torch.cat([eps, rest], dim=1) 143 | 144 | cf_feat = self.model.cf_param.unsqueeze(0) 145 | cf_feat = cf_feat.expand(bsz // 2, -1) 146 | feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)  # img_feat for the first half, cf_param (presumably the unconditional embedding) for the second 147 | 148 | cond = { 149 | "y": feat, 150 | "txt_feat": txt_feat, 151 | "txt_feat_seq": txt_feat_seq, 152 | "mask": mask, 153 | } 154 | sample_fn = self.get_sample_fn(timestep_respacing) 155 | sample_outputs = sample_fn( 156 | guided_model_fn, 157 | (bsz, 3, img_sz, img_sz), 158 | noise=None, 159 | device=txt_feat.device, 160 | clip_denoised=True, 161 | model_kwargs=cond, 162 | ) 163 | 164 | for out in sample_outputs: 165 | sample = out["sample"] 166 | yield sample if cf_guidance_scales is None else sample[  # cf_guidance_scales is asserted non-None, so the first half is always taken 167 | : sample.shape[0] // 2 168 | ] 169 | 170 | 171 | class Text2ImModel(Text2ImProgressiveModel):  # non-progressive variant: returns only the final sample 172 | def forward( 173 | self, 174 | txt_feat, 175 | txt_feat_seq, 176 | tok, 177 | mask, 178 | img_feat=None, 179 | cf_guidance_scales=None, 180 | timestep_respacing=None, 181 | ): 182 | last_out = None 183 | for out in super().forward( 184 | txt_feat, 185 | txt_feat_seq, 186 | tok, 187 | mask, 188 | img_feat, 189 | cf_guidance_scales, 190 |
timestep_respacing, 191 | ): 192 | last_out = out 193 | return last_out 194 | -------------------------------------------------------------------------------- /scripts/semantic_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from scripts.dataset import Flickr8kDataset,Only_images_Flickr8kDataset 6 | from itertools import islice 7 | from ldm.util import instantiate_from_config 8 | from PIL import Image 9 | import PIL 10 | import torch 11 | import numpy as np 12 | import argparse, os 13 | from omegaconf import OmegaConf 14 | from pytorch_lightning import seed_everything 15 | from imwatermark import WatermarkEncoder 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from tqdm import tqdm 18 | import lpips as lp 19 | from einops import rearrange, repeat 20 | from torch import autocast 21 | from tqdm import tqdm, trange 22 | from transformers import pipeline 23 | from scripts.qam import qam16ModulationTensor, qam16ModulationString 24 | import time 25 | from PIL import Image 26 | import torchvision.transforms as transforms 27 | from diffusers import StableDiffusionPipeline 28 | from transformers import pipeline 29 | from SSIM_PIL import compare_ssim 30 | from torchvision.transforms import Resize, ToTensor, Normalize, Compose 31 | 32 | 33 | ''' 34 | 35 | INIT DATASET AND DATALOADER 36 | 37 | ''' 38 | capt_file_path = "path/to/captions.txt" #"G:/Giordano/Flickr8kDataset/captions.txt" 39 | images_dir_path = "path/to/Images" #"G:/Giordano/Flickr8kDataset/Images/" 40 | batch_size = 1 41 | 42 | dataset = Only_images_Flickr8kDataset(images_dir_path)  # NOTE(review): runs at import time; images_dir_path is a placeholder 43 | 44 | test_dataloader=DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True) 45 | 46 | 47 | ''' 48 | MODEL CHECKPOINT 49 | 50 | ''' 51 | 52 | 53 | model_ckpt_path = "path/to/model-checkpoint"
#"G:/Giordano/stablediffusion/checkpoints/v1-5-pruned.ckpt" #v2-1_512-ema-pruned.ckpt" 54 | config_path = "path/to/model-config" #"G:/Giordano/stablediffusion/configs/stable-diffusion/v1-inference.yaml" 55 | 56 | 57 | 58 | def load_model_from_config(config, ckpt, verbose=False): 59 | print(f"Loading model from {ckpt}") 60 | pl_sd = torch.load(ckpt, map_location="cpu") 61 | if "global_step" in pl_sd: 62 | print(f"Global Step: {pl_sd['global_step']}") 63 | sd = pl_sd["state_dict"] 64 | model = instantiate_from_config(config.model) 65 | m, u = model.load_state_dict(sd, strict=False) 66 | if len(m) > 0 and verbose: 67 | print("missing keys:") 68 | print(m) 69 | if len(u) > 0 and verbose: 70 | print("unexpected keys:") 71 | print(u) 72 | 73 | model.cuda() 74 | model.eval() 75 | return model 76 | 77 | 78 | def load_img(path): 79 | image = Image.open(path).convert("RGB") 80 | w, h = (512,512)#image.size 81 | #print(f"loaded input image of size ({w}, {h}) from {path}") 82 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 83 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 84 | image = np.array(image).astype(np.float32) / 255.0 85 | image = image[None].transpose(0, 3, 1, 2) 86 | image = torch.from_numpy(image) 87 | return 2. * image - 1. 
88 | 89 | 90 | 91 | 92 | def test(dataloader, 93 | snr=10, 94 | num_images=100, 95 | sampling_steps = 50, 96 | outpath="outpath" 97 | ): 98 | 99 | blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large") 100 | 101 | model_id = "runwayml/stable-diffusion-v1-5" 102 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) 103 | pipe = pipe.to("cuda") 104 | 105 | transform = Compose([Resize((512,512), antialias=True), transforms.PILToTensor() ]) 106 | 107 | lpips = lp.LPIPS(net='alex') 108 | 109 | 110 | sample_path = os.path.join(outpath,f'Test-TEXTONLY-sample-{snr}-{sampling_steps}') 111 | 112 | 113 | os.makedirs(sample_path, exist_ok=True) 114 | 115 | sample_orig_path = os.path.join(outpath,f'Test-TEXTONLY-sample-orig-{snr}-{sampling_steps}') 116 | 117 | os.makedirs(sample_orig_path, exist_ok=True) 118 | 119 | text_path = os.path.join(outpath,f'Test-TEXTONLY-text-{snr}-{sampling_steps}') 120 | 121 | os.makedirs(text_path, exist_ok=True) 122 | 123 | 124 | lpips_values = [] 125 | time_values = [] 126 | ssim_values = [] 127 | 128 | i=0 129 | 130 | 131 | for batch in tqdm(dataloader,total=num_images): 132 | 133 | img_file_path = batch[0] 134 | 135 | #Open Image 136 | init_image = Image.open(img_file_path) 137 | 138 | #Automatically extract caption using BLIP model 139 | prompt_blip = blip(init_image)[0]["generated_text"] 140 | 141 | #Save Caption for Clip metric computation 142 | f = open(os.path.join(text_path, f"{i}.txt"),"a") 143 | f.write(prompt_blip) 144 | f.close() 145 | 146 | 147 | #Introduce noise in the text (aka. 
simulate noisy channel) 148 | prompt_corrupted = qam16ModulationString(prompt_blip,snr) 149 | 150 | #Compute time to reconstruct image 151 | time_start = time.time() 152 | #Reconstruct image using noisy text caption 153 | image_generated = pipe(prompt_corrupted,num_inference_steps=sampling_steps).images[0] 154 | time_finish = time.time() 155 | 156 | time_elapsed = time_finish - time_start 157 | time_values.append(time_elapsed) 158 | 159 | #Save images for subsequent FID and CLIP Score computation 160 | image_generated.save(os.path.join(sample_path,f'{i}.png')) 161 | init_image.save(os.path.join(sample_orig_path,f'{i}.png')) 162 | 163 | #Compute SSIM 164 | init_image_copy = init_image.resize((512, 512), resample=PIL.Image.LANCZOS) 165 | ssim_values.append(compare_ssim(init_image_copy, image_generated)) 166 | 167 | #Compute LPIPS 168 | image_generated = (transform(image_generated) / 255) *2 -1 169 | init_image = (transform(init_image) / 255 ) *2 - 1 170 | lp_score=lpips(init_image.cpu(),image_generated.cpu()).item() 171 | lpips_values.append(lp_score) 172 | 173 | i+=1 174 | if i==num_images: 175 | break 176 | 177 | print(f'mean lpips score: {sum(lpips_values)/len(lpips_values)}') 178 | 179 | print(f'mean ssim score: {sum(ssim_values)/len(ssim_values)}') 180 | 181 | print(f'mean time score: {sum(time_values)/len(time_values)}') 182 | 183 | if __name__ == "__main__": 184 | 185 | parser = argparse.ArgumentParser() 186 | 187 | parser.add_argument( 188 | "--outdir", 189 | type=str, 190 | nargs="?", 191 | help="dir to write results to", 192 | default="outputs/img2img-samples" 193 | ) 194 | 195 | parser.add_argument( 196 | "--seed", 197 | type=int, 198 | default=42, 199 | help="the seed (for reproducible sampling)", 200 | ) 201 | 202 | 203 | opt = parser.parse_args() 204 | seed_everything(opt.seed) 205 | 206 | 207 | os.makedirs(opt.outdir, exist_ok=True) 208 | outpath = opt.outdir 209 | 210 | 211 | #START TESTING 212 | 213 | 
test(test_dataloader,snr=10,num_images=100,outpath=outpath) 214 | 215 | test(test_dataloader,snr=8.75,num_images=100,outpath=outpath) 216 | 217 | test(test_dataloader,snr=7.50,num_images=100,outpath=outpath) 218 | 219 | test(test_dataloader,snr=6.25,num_images=100,outpath=outpath) 220 | 221 | test(test_dataloader,snr=5,num_images=100,outpath=outpath) 222 | 223 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/xf.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from the repos below: 3 | # (a) Guided-Diffusion (https://github.com/openai/guided-diffusion) 4 | # (b) CLIP ViT (https://github.com/openai/CLIP/) 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import math 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .nn import timestep_embedding 14 | 15 | 16 | def convert_module_to_f16(param): 17 | """ 18 | Convert primitive modules to float16. 19 | """ 20 | if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 21 | param.weight.data = param.weight.data.half() 22 | if param.bias is not None: 23 | param.bias.data = param.bias.data.half() 24 | 25 | 26 | class LayerNorm(nn.LayerNorm): 27 | """ 28 | Implementation that supports fp16 inputs but fp32 gains/biases. 
29 | """ 30 | 31 | def forward(self, x: th.Tensor): 32 | return super().forward(x.float()).to(x.dtype) 33 | 34 | 35 | class MultiheadAttention(nn.Module): 36 | def __init__(self, n_ctx, width, heads): 37 | super().__init__() 38 | self.n_ctx = n_ctx 39 | self.width = width 40 | self.heads = heads 41 | self.c_qkv = nn.Linear(width, width * 3) 42 | self.c_proj = nn.Linear(width, width) 43 | self.attention = QKVMultiheadAttention(heads, n_ctx) 44 | 45 | def forward(self, x, mask=None): 46 | x = self.c_qkv(x) 47 | x = self.attention(x, mask=mask) 48 | x = self.c_proj(x) 49 | return x 50 | 51 | 52 | class MLP(nn.Module): 53 | def __init__(self, width): 54 | super().__init__() 55 | self.width = width 56 | self.c_fc = nn.Linear(width, width * 4) 57 | self.c_proj = nn.Linear(width * 4, width) 58 | self.gelu = nn.GELU() 59 | 60 | def forward(self, x): 61 | return self.c_proj(self.gelu(self.c_fc(x))) 62 | 63 | 64 | class QKVMultiheadAttention(nn.Module): 65 | def __init__(self, n_heads: int, n_ctx: int): 66 | super().__init__() 67 | self.n_heads = n_heads 68 | self.n_ctx = n_ctx 69 | 70 | def forward(self, qkv, mask=None): 71 | bs, n_ctx, width = qkv.shape 72 | attn_ch = width // self.n_heads // 3 73 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 74 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 75 | q, k, v = th.split(qkv, attn_ch, dim=-1) 76 | weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale) 77 | wdtype = weight.dtype 78 | if mask is not None: 79 | weight = weight + mask[:, None, ...] 
80 | weight = th.softmax(weight, dim=-1).type(wdtype) 81 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 82 | 83 | 84 | class ResidualAttentionBlock(nn.Module): 85 | def __init__( 86 | self, 87 | n_ctx: int, 88 | width: int, 89 | heads: int, 90 | ): 91 | super().__init__() 92 | 93 | self.attn = MultiheadAttention( 94 | n_ctx, 95 | width, 96 | heads, 97 | ) 98 | self.ln_1 = LayerNorm(width) 99 | self.mlp = MLP(width) 100 | self.ln_2 = LayerNorm(width) 101 | 102 | def forward(self, x, mask=None): 103 | x = x + self.attn(self.ln_1(x), mask=mask) 104 | x = x + self.mlp(self.ln_2(x)) 105 | return x 106 | 107 | 108 | class Transformer(nn.Module): 109 | def __init__( 110 | self, 111 | n_ctx: int, 112 | width: int, 113 | layers: int, 114 | heads: int, 115 | ): 116 | super().__init__() 117 | self.n_ctx = n_ctx 118 | self.width = width 119 | self.layers = layers 120 | self.resblocks = nn.ModuleList( 121 | [ 122 | ResidualAttentionBlock( 123 | n_ctx, 124 | width, 125 | heads, 126 | ) 127 | for _ in range(layers) 128 | ] 129 | ) 130 | 131 | def forward(self, x, mask=None): 132 | for block in self.resblocks: 133 | x = block(x, mask=mask) 134 | return x 135 | 136 | 137 | class PriorTransformer(nn.Module): 138 | """ 139 | A Causal Transformer that conditions on CLIP text embedding, text. 140 | 141 | :param text_ctx: number of text tokens to expect. 142 | :param xf_width: width of the transformer. 143 | :param xf_layers: depth of the transformer. 144 | :param xf_heads: heads in the transformer. 145 | :param xf_final_ln: use a LayerNorm after the output layer. 146 | :param clip_dim: dimension of clip feature. 
147 | """ 148 | 149 | def __init__( 150 | self, 151 | text_ctx, 152 | xf_width, 153 | xf_layers, 154 | xf_heads, 155 | xf_final_ln, 156 | clip_dim, 157 | ): 158 | super().__init__() 159 | 160 | self.text_ctx = text_ctx 161 | self.xf_width = xf_width 162 | self.xf_layers = xf_layers 163 | self.xf_heads = xf_heads 164 | self.clip_dim = clip_dim 165 | self.ext_len = 4 166 | 167 | self.time_embed = nn.Sequential( 168 | nn.Linear(xf_width, xf_width), 169 | nn.SiLU(), 170 | nn.Linear(xf_width, xf_width), 171 | ) 172 | self.text_enc_proj = nn.Linear(clip_dim, xf_width) 173 | self.text_emb_proj = nn.Linear(clip_dim, xf_width) 174 | self.clip_img_proj = nn.Linear(clip_dim, xf_width) 175 | self.out_proj = nn.Linear(xf_width, clip_dim) 176 | self.transformer = Transformer( 177 | text_ctx + self.ext_len, 178 | xf_width, 179 | xf_layers, 180 | xf_heads, 181 | ) 182 | if xf_final_ln: 183 | self.final_ln = LayerNorm(xf_width) 184 | else: 185 | self.final_ln = None 186 | 187 | self.positional_embedding = nn.Parameter( 188 | th.empty(1, text_ctx + self.ext_len, xf_width) 189 | ) 190 | self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width))) 191 | 192 | nn.init.normal_(self.prd_emb, std=0.01) 193 | nn.init.normal_(self.positional_embedding, std=0.01) 194 | 195 | def forward( 196 | self, 197 | x, 198 | timesteps, 199 | text_emb=None, 200 | text_enc=None, 201 | mask=None, 202 | causal_mask=None, 203 | ): 204 | bsz = x.shape[0] 205 | mask = F.pad(mask, (0, self.ext_len), value=True) 206 | 207 | t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width)) 208 | text_enc = self.text_enc_proj(text_enc) 209 | text_emb = self.text_emb_proj(text_emb) 210 | x = self.clip_img_proj(x) 211 | 212 | input_seq = [ 213 | text_enc, 214 | text_emb[:, None, :], 215 | t_emb[:, None, :], 216 | x[:, None, :], 217 | self.prd_emb.to(x.dtype).expand(bsz, -1, -1), 218 | ] 219 | input = th.cat(input_seq, dim=1) 220 | input = input + self.positional_embedding.to(input.dtype) 221 | 222 | mask = 
th.where(mask, 0.0, float("-inf")) 223 | mask = (mask[:, None, :] + causal_mask).to(input.dtype) 224 | 225 | out = self.transformer(input, mask=mask) 226 | if self.final_ln is not None: 227 | out = self.final_ln(out) 228 | 229 | out = self.out_proj(out[:, -1]) 230 | 231 | return out 232 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def autocast(f): 12 | def do_autocast(*args, **kwargs): 13 | with torch.cuda.amp.autocast(enabled=True, 14 | dtype=torch.get_autocast_gpu_dtype(), 15 | cache_enabled=torch.is_autocast_cache_enabled()): 16 | return f(*args, **kwargs) 17 | 18 | return do_autocast 19 | 20 | 21 | def log_txt_as_img(wh, xc, size=10): 22 | # wh a tuple of (width, height) 23 | # xc a list of captions to plot 24 | b = len(xc) 25 | txts = list() 26 | for bi in range(b): 27 | txt = Image.new("RGB", wh, color="white") 28 | draw = ImageDraw.Draw(txt) 29 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 30 | nc = int(40 * (wh[0] / 256)) 31 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 32 | 33 | try: 34 | draw.text((0, 0), lines, fill="black", font=font) 35 | except UnicodeEncodeError: 36 | print("Cant encode string for logging. 
Skipping.") 37 | 38 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 39 | txts.append(txt) 40 | txts = np.stack(txts) 41 | txts = torch.tensor(txts) 42 | return txts 43 | 44 | 45 | def ismap(x): 46 | if not isinstance(x, torch.Tensor): 47 | return False 48 | return (len(x.shape) == 4) and (x.shape[1] > 3) 49 | 50 | 51 | def isimage(x): 52 | if not isinstance(x,torch.Tensor): 53 | return False 54 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 55 | 56 | 57 | def exists(x): 58 | return x is not None 59 | 60 | 61 | def default(val, d): 62 | if exists(val): 63 | return val 64 | return d() if isfunction(d) else d 65 | 66 | 67 | def mean_flat(tensor): 68 | """ 69 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 70 | Take the mean over all non-batch dimensions. 71 | """ 72 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 73 | 74 | 75 | def count_params(model, verbose=False): 76 | total_params = sum(p.numel() for p in model.parameters()) 77 | if verbose: 78 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 79 | return total_params 80 | 81 | 82 | def instantiate_from_config(config): 83 | if not "target" in config: 84 | if config == '__is_first_stage__': 85 | return None 86 | elif config == "__is_unconditional__": 87 | return None 88 | raise KeyError("Expected key `target` to instantiate.") 89 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 90 | 91 | 92 | def get_obj_from_str(string, reload=False): 93 | module, cls = string.rsplit(".", 1) 94 | if reload: 95 | module_imp = importlib.import_module(module) 96 | importlib.reload(module_imp) 97 | return getattr(importlib.import_module(module, package=None), cls) 98 | 99 | 100 | class AdamWwithEMAandWings(optim.Optimizer): 101 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 102 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), 
eps=1.e-8, # TODO: check hyperparameters before using 103 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 104 | ema_power=1., param_names=()): 105 | """AdamW that saves EMA versions of the parameters.""" 106 | if not 0.0 <= lr: 107 | raise ValueError("Invalid learning rate: {}".format(lr)) 108 | if not 0.0 <= eps: 109 | raise ValueError("Invalid epsilon value: {}".format(eps)) 110 | if not 0.0 <= betas[0] < 1.0: 111 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 112 | if not 0.0 <= betas[1] < 1.0: 113 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 114 | if not 0.0 <= weight_decay: 115 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 116 | if not 0.0 <= ema_decay <= 1.0: 117 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 118 | defaults = dict(lr=lr, betas=betas, eps=eps, 119 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 120 | ema_power=ema_power, param_names=param_names) 121 | super().__init__(params, defaults) 122 | 123 | def __setstate__(self, state): 124 | super().__setstate__(state) 125 | for group in self.param_groups: 126 | group.setdefault('amsgrad', False) 127 | 128 | @torch.no_grad() 129 | def step(self, closure=None): 130 | """Performs a single optimization step. 131 | Args: 132 | closure (callable, optional): A closure that reevaluates the model 133 | and returns the loss. 
134 | """ 135 | loss = None 136 | if closure is not None: 137 | with torch.enable_grad(): 138 | loss = closure() 139 | 140 | for group in self.param_groups: 141 | params_with_grad = [] 142 | grads = [] 143 | exp_avgs = [] 144 | exp_avg_sqs = [] 145 | ema_params_with_grad = [] 146 | state_sums = [] 147 | max_exp_avg_sqs = [] 148 | state_steps = [] 149 | amsgrad = group['amsgrad'] 150 | beta1, beta2 = group['betas'] 151 | ema_decay = group['ema_decay'] 152 | ema_power = group['ema_power'] 153 | 154 | for p in group['params']: 155 | if p.grad is None: 156 | continue 157 | params_with_grad.append(p) 158 | if p.grad.is_sparse: 159 | raise RuntimeError('AdamW does not support sparse gradients') 160 | grads.append(p.grad) 161 | 162 | state = self.state[p] 163 | 164 | # State initialization 165 | if len(state) == 0: 166 | state['step'] = 0 167 | # Exponential moving average of gradient values 168 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 169 | # Exponential moving average of squared gradient values 170 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 171 | if amsgrad: 172 | # Maintains max of all exp. moving avg. of sq. grad. 
values 173 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 174 | # Exponential moving average of parameter values 175 | state['param_exp_avg'] = p.detach().float().clone() 176 | 177 | exp_avgs.append(state['exp_avg']) 178 | exp_avg_sqs.append(state['exp_avg_sq']) 179 | ema_params_with_grad.append(state['param_exp_avg']) 180 | 181 | if amsgrad: 182 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 183 | 184 | # update the steps for each param group update 185 | state['step'] += 1 186 | # record the step after step update 187 | state_steps.append(state['step']) 188 | 189 | optim._functional.adamw(params_with_grad, 190 | grads, 191 | exp_avgs, 192 | exp_avg_sqs, 193 | max_exp_avg_sqs, 194 | state_steps, 195 | amsgrad=amsgrad, 196 | beta1=beta1, 197 | beta2=beta2, 198 | lr=group['lr'], 199 | weight_decay=group['weight_decay'], 200 | eps=group['eps'], 201 | maximize=False) 202 | 203 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 204 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 205 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 206 | 207 | return loss -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 
8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 
79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit 
height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 
199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | 
self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return 
posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | 
reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | 
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/sampler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | 5 | # source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15 6 | # ------------------------------------------------------------------------------------ 7 | 8 | from typing import Iterator 9 | 10 | import torch 11 | import torchvision.transforms.functional as TVF 12 | from torchvision.transforms import InterpolationMode 13 | 14 | from .template import BaseSampler, CKPT_PATH 15 | 16 | 17 | class T2ISampler(BaseSampler): 18 | """ 19 | A sampler for text-to-image generation. 20 | :param root_dir: directory for model checkpoints. 
21 | :param sampling_type: ["default", "fast"] 22 | """ 23 | 24 | def __init__( 25 | self, 26 | root_dir: str, 27 | sampling_type: str = "default", 28 | ): 29 | super().__init__(root_dir, sampling_type) 30 | 31 | @classmethod 32 | def from_pretrained( 33 | cls, 34 | root_dir: str, 35 | clip_model_path: str, 36 | clip_stat_path: str, 37 | sampling_type: str = "default", 38 | ): 39 | 40 | model = cls( 41 | root_dir=root_dir, 42 | sampling_type=sampling_type, 43 | ) 44 | model.load_clip(clip_model_path) 45 | model.load_prior( 46 | f"{CKPT_PATH['prior']}", 47 | clip_stat_path=clip_stat_path, 48 | prior_config="configs/karlo/prior_1B_vit_l.yaml" 49 | ) 50 | model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml") 51 | model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml") 52 | return model 53 | 54 | def preprocess( 55 | self, 56 | prompt: str, 57 | bsz: int, 58 | ): 59 | """Setup prompts & cfg scales""" 60 | prompts_batch = [prompt for _ in range(bsz)] 61 | 62 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) 63 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") 64 | 65 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) 66 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") 67 | 68 | """ Get CLIP text feature """ 69 | clip_model = self._clip 70 | tokenizer = self._tokenizer 71 | max_txt_length = self._prior.model.text_ctx 72 | 73 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) 74 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) 75 | if not (cf_token.shape == tok.shape): 76 | cf_token = cf_token.expand(tok.shape[0], -1) 77 | cf_mask = cf_mask.expand(tok.shape[0], -1) 78 | 79 | tok = torch.cat([tok, cf_token], dim=0) 80 | mask = torch.cat([mask, cf_mask], dim=0) 81 | 82 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda") 83 
| txt_feat, txt_feat_seq = clip_model.encode_text(tok) 84 | 85 | return ( 86 | prompts_batch, 87 | prior_cf_scales_batch, 88 | decoder_cf_scales_batch, 89 | txt_feat, 90 | txt_feat_seq, 91 | tok, 92 | mask, 93 | ) 94 | 95 | def __call__( 96 | self, 97 | prompt: str, 98 | bsz: int, 99 | progressive_mode=None, 100 | ) -> Iterator[torch.Tensor]: 101 | assert progressive_mode in ("loop", "stage", "final") 102 | with torch.no_grad(), torch.cuda.amp.autocast(): 103 | ( 104 | prompts_batch, 105 | prior_cf_scales_batch, 106 | decoder_cf_scales_batch, 107 | txt_feat, 108 | txt_feat_seq, 109 | tok, 110 | mask, 111 | ) = self.preprocess( 112 | prompt, 113 | bsz, 114 | ) 115 | 116 | """ Transform CLIP text feature into image feature """ 117 | img_feat = self._prior( 118 | txt_feat, 119 | txt_feat_seq, 120 | mask, 121 | prior_cf_scales_batch, 122 | timestep_respacing=self._prior_sm, 123 | ) 124 | 125 | """ Generate 64x64px images """ 126 | images_64_outputs = self._decoder( 127 | txt_feat, 128 | txt_feat_seq, 129 | tok, 130 | mask, 131 | img_feat, 132 | cf_guidance_scales=decoder_cf_scales_batch, 133 | timestep_respacing=self._decoder_sm, 134 | ) 135 | 136 | images_64 = None 137 | for k, out in enumerate(images_64_outputs): 138 | images_64 = out 139 | if progressive_mode == "loop": 140 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 141 | if progressive_mode == "stage": 142 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 143 | 144 | images_64 = torch.clamp(images_64, -1, 1) 145 | 146 | """ Upsample 64x64 to 256x256 """ 147 | images_256 = TVF.resize( 148 | images_64, 149 | [256, 256], 150 | interpolation=InterpolationMode.BICUBIC, 151 | antialias=True, 152 | ) 153 | images_256_outputs = self._sr_64_256( 154 | images_256, timestep_respacing=self._sr_sm 155 | ) 156 | 157 | for k, out in enumerate(images_256_outputs): 158 | images_256 = out 159 | if progressive_mode == "loop": 160 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 161 | if progressive_mode == "stage": 162 | yield 
torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 163 | 164 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0) 165 | 166 | 167 | class PriorSampler(BaseSampler): 168 | """ 169 | A sampler for text-to-image generation, but only the prior. 170 | :param root_dir: directory for model checkpoints. 171 | :param sampling_type: ["default", "fast"] 172 | """ 173 | 174 | def __init__( 175 | self, 176 | root_dir: str, 177 | sampling_type: str = "default", 178 | ): 179 | super().__init__(root_dir, sampling_type) 180 | 181 | @classmethod 182 | def from_pretrained( 183 | cls, 184 | root_dir: str, 185 | clip_model_path: str, 186 | clip_stat_path: str, 187 | sampling_type: str = "default", 188 | ): 189 | model = cls( 190 | root_dir=root_dir, 191 | sampling_type=sampling_type, 192 | ) 193 | model.load_clip(clip_model_path) 194 | model.load_prior( 195 | f"{CKPT_PATH['prior']}", 196 | clip_stat_path=clip_stat_path, 197 | prior_config="configs/karlo/prior_1B_vit_l.yaml" 198 | ) 199 | return model 200 | 201 | def preprocess( 202 | self, 203 | prompt: str, 204 | bsz: int, 205 | ): 206 | """Setup prompts & cfg scales""" 207 | prompts_batch = [prompt for _ in range(bsz)] 208 | 209 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) 210 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") 211 | 212 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) 213 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") 214 | 215 | """ Get CLIP text feature """ 216 | clip_model = self._clip 217 | tokenizer = self._tokenizer 218 | max_txt_length = self._prior.model.text_ctx 219 | 220 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) 221 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) 222 | if not (cf_token.shape == tok.shape): 223 | cf_token = cf_token.expand(tok.shape[0], -1) 224 | cf_mask = cf_mask.expand(tok.shape[0], -1) 225 | 226 | tok = torch.cat([tok, 
cf_token], dim=0) 227 | mask = torch.cat([mask, cf_mask], dim=0) 228 | 229 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda") 230 | txt_feat, txt_feat_seq = clip_model.encode_text(tok) 231 | 232 | return ( 233 | prompts_batch, 234 | prior_cf_scales_batch, 235 | decoder_cf_scales_batch, 236 | txt_feat, 237 | txt_feat_seq, 238 | tok, 239 | mask, 240 | ) 241 | 242 | def __call__( 243 | self, 244 | prompt: str, 245 | bsz: int, 246 | progressive_mode=None, 247 | ) -> Iterator[torch.Tensor]: 248 | assert progressive_mode in ("loop", "stage", "final") 249 | with torch.no_grad(), torch.cuda.amp.autocast(): 250 | ( 251 | prompts_batch, 252 | prior_cf_scales_batch, 253 | decoder_cf_scales_batch, 254 | txt_feat, 255 | txt_feat_seq, 256 | tok, 257 | mask, 258 | ) = self.preprocess( 259 | prompt, 260 | bsz, 261 | ) 262 | 263 | """ Transform CLIP text feature into image feature """ 264 | img_feat = self._prior( 265 | txt_feat, 266 | txt_feat_seq, 267 | mask, 268 | prior_cf_scales_batch, 269 | timestep_respacing=self._prior_sm, 270 | ) 271 | 272 | yield img_feat 273 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = 
_make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, 
kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 
140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 
233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 
297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "squaredcos_cap_v2": # used for karlo prior 38 | # return early 39 | return betas_for_alpha_bar( 40 | n_timestep, 41 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 42 | ) 43 | 44 | elif schedule == "sqrt_linear": 45 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 46 | elif schedule == "sqrt": 47 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 48 | else: 49 | raise ValueError(f"schedule '{schedule}' unknown.") 50 | return betas.numpy() 51 | 52 | 53 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 54 | if ddim_discr_method == 'uniform': 55 | c = num_ddpm_timesteps // num_ddim_timesteps 56 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 57 | elif ddim_discr_method == 'quad': 58 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 59 | else: 60 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 61 | 62 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 63 | # add one to 
get the final alpha values right (the ones from first scale to data during sampling) 64 | steps_out = ddim_timesteps + 1 65 | if verbose: 66 | print(f'Selected timesteps for ddim sampler: {steps_out}') 67 | return steps_out 68 | 69 | 70 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 71 | # select alphas for computing the variance schedule 72 | alphas = alphacums[ddim_timesteps] 73 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 74 | 75 | # according the the formula provided in https://arxiv.org/abs/2010.02502 76 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 77 | if verbose: 78 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 79 | print(f'For the chosen value of eta, which is {eta}, ' 80 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 81 | return sigmas, alphas, alphas_prev 82 | 83 | 84 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 85 | """ 86 | Create a beta schedule that discretizes the given alpha_t_bar function, 87 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 88 | :param num_diffusion_timesteps: the number of betas to produce. 89 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 90 | produces the cumulative product of (1-beta) up to that 91 | part of the diffusion process. 92 | :param max_beta: the maximum beta to use; use values lower than 1 to 93 | prevent singularities. 
94 | """ 95 | betas = [] 96 | for i in range(num_diffusion_timesteps): 97 | t1 = i / num_diffusion_timesteps 98 | t2 = (i + 1) / num_diffusion_timesteps 99 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 100 | return np.array(betas) 101 | 102 | 103 | def extract_into_tensor(a, t, x_shape): 104 | b, *_ = t.shape 105 | out = a.gather(-1, t) 106 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 107 | 108 | 109 | def checkpoint(func, inputs, params, flag): 110 | """ 111 | Evaluate a function without caching intermediate activations, allowing for 112 | reduced memory at the expense of extra compute in the backward pass. 113 | :param func: the function to evaluate. 114 | :param inputs: the argument sequence to pass to `func`. 115 | :param params: a sequence of parameters `func` depends on but does not 116 | explicitly take as arguments. 117 | :param flag: if False, disable gradient checkpointing. 118 | """ 119 | if flag: 120 | args = tuple(inputs) + tuple(params) 121 | return CheckpointFunction.apply(func, len(inputs), *args) 122 | else: 123 | return func(*inputs) 124 | 125 | 126 | class CheckpointFunction(torch.autograd.Function): 127 | @staticmethod 128 | def forward(ctx, run_function, length, *args): 129 | ctx.run_function = run_function 130 | ctx.input_tensors = list(args[:length]) 131 | ctx.input_params = list(args[length:]) 132 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 133 | "dtype": torch.get_autocast_gpu_dtype(), 134 | "cache_enabled": torch.is_autocast_cache_enabled()} 135 | with torch.no_grad(): 136 | output_tensors = ctx.run_function(*ctx.input_tensors) 137 | return output_tensors 138 | 139 | @staticmethod 140 | def backward(ctx, *output_grads): 141 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 142 | with torch.enable_grad(), \ 143 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 144 | # Fixes a bug where the first op in run_function modifies the 145 | # Tensor storage in 
place, which is not allowed for detach()'d 146 | # Tensors. 147 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 148 | output_tensors = ctx.run_function(*shallow_copies) 149 | input_grads = torch.autograd.grad( 150 | output_tensors, 151 | ctx.input_tensors + ctx.input_params, 152 | output_grads, 153 | allow_unused=True, 154 | ) 155 | del ctx.input_tensors 156 | del ctx.input_params 157 | del output_tensors 158 | return (None, None) + input_grads 159 | 160 | 161 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 162 | """ 163 | Create sinusoidal timestep embeddings. 164 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 165 | These may be fractional. 166 | :param dim: the dimension of the output. 167 | :param max_period: controls the minimum frequency of the embeddings. 168 | :return: an [N x dim] Tensor of positional embeddings. 169 | """ 170 | if not repeat_only: 171 | half = dim // 2 172 | freqs = torch.exp( 173 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 174 | ).to(device=timesteps.device) 175 | args = timesteps[:, None].float() * freqs[None] 176 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 177 | if dim % 2: 178 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 179 | else: 180 | embedding = repeat(timesteps, 'b -> b d', d=dim) 181 | return embedding 182 | 183 | 184 | def zero_module(module): 185 | """ 186 | Zero out the parameters of a module and return it. 187 | """ 188 | for p in module.parameters(): 189 | p.detach().zero_() 190 | return module 191 | 192 | 193 | def scale_module(module, scale): 194 | """ 195 | Scale the parameters of a module and return it. 196 | """ 197 | for p in module.parameters(): 198 | p.detach().mul_(scale) 199 | return module 200 | 201 | 202 | def mean_flat(tensor): 203 | """ 204 | Take the mean over all non-batch dimensions. 
205 | """ 206 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 207 | 208 | 209 | def normalization(channels): 210 | """ 211 | Make a standard normalization layer. 212 | :param channels: number of input channels. 213 | :return: an nn.Module for normalization. 214 | """ 215 | return GroupNorm32(32, channels) 216 | 217 | 218 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 219 | class SiLU(nn.Module): 220 | def forward(self, x): 221 | return x * torch.sigmoid(x) 222 | 223 | 224 | class GroupNorm32(nn.GroupNorm): 225 | def forward(self, x): 226 | return super().forward(x.float()).type(x.dtype) 227 | 228 | 229 | def conv_nd(dims, *args, **kwargs): 230 | """ 231 | Create a 1D, 2D, or 3D convolution module. 232 | """ 233 | if dims == 1: 234 | return nn.Conv1d(*args, **kwargs) 235 | elif dims == 2: 236 | return nn.Conv2d(*args, **kwargs) 237 | elif dims == 3: 238 | return nn.Conv3d(*args, **kwargs) 239 | raise ValueError(f"unsupported dimensions: {dims}") 240 | 241 | 242 | def linear(*args, **kwargs): 243 | """ 244 | Create a linear module. 245 | """ 246 | return nn.Linear(*args, **kwargs) 247 | 248 | 249 | def avg_pool_nd(dims, *args, **kwargs): 250 | """ 251 | Create a 1D, 2D, or 3D average pooling module. 
252 | """ 253 | if dims == 1: 254 | return nn.AvgPool1d(*args, **kwargs) 255 | elif dims == 2: 256 | return nn.AvgPool2d(*args, **kwargs) 257 | elif dims == 3: 258 | return nn.AvgPool3d(*args, **kwargs) 259 | raise ValueError(f"unsupported dimensions: {dims}") 260 | 261 | 262 | class HybridConditioner(nn.Module): 263 | 264 | def __init__(self, c_concat_config, c_crossattn_config): 265 | super().__init__() 266 | self.concat_conditioner = instantiate_from_config(c_concat_config) 267 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 268 | 269 | def forward(self, c_concat, c_crossattn): 270 | c_concat = self.concat_conditioner(c_concat) 271 | c_crossattn = self.crossattn_conditioner(c_crossattn) 272 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 273 | 274 | 275 | def noise_like(shape, device, repeat=False): 276 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 277 | noise = lambda: torch.randn(shape, device=device) 278 | return repeat_noise() if repeat else noise() 279 | -------------------------------------------------------------------------------- /scripts/semantic_i2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from scripts.dataset import Flickr8kDataset,Only_images_Flickr8kDataset 6 | from itertools import islice 7 | from ldm.util import instantiate_from_config 8 | from PIL import Image 9 | import PIL 10 | import torch 11 | import numpy as np 12 | import argparse, os 13 | from omegaconf import OmegaConf 14 | from pytorch_lightning import seed_everything 15 | from imwatermark import WatermarkEncoder 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from tqdm import tqdm 18 | import lpips as lp 19 | from einops import rearrange, repeat 20 | from torch import autocast 21 | 
from tqdm import tqdm, trange
from transformers import pipeline
from scripts.qam import qam16ModulationTensor, qam16ModulationString
import time

from SSIM_PIL import compare_ssim

# ---------------------------------------------------------------------------
# Dataset and dataloader
# ---------------------------------------------------------------------------
capt_file_path = "path/to/captions.txt"  # "G:/Giordano/Flickr8kDataset/captions.txt"
images_dir_path = "path/to/Images"  # "G:/Giordano/Flickr8kDataset/Images/"
batch_size = 1

dataset = Only_images_Flickr8kDataset(images_dir_path)

test_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# ---------------------------------------------------------------------------
# Model checkpoint
# ---------------------------------------------------------------------------
model_ckpt_path = "path/to/model-checkpoint"  # "G:/Giordano/stablediffusion/checkpoints/v1-5-pruned.ckpt" #v2-1_512-ema-pruned.ckpt"
config_path = "path/to/model-config"  # "G:/Giordano/stablediffusion/configs/stable-diffusion/v1-inference.yaml"


def load_model_from_config(config, ckpt, verbose=False, device="cuda"):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to the checkpoint file.
        verbose: If True, print missing / unexpected state-dict keys.
        device: Device the model is moved to. Defaults to ``"cuda"``, matching
            the previous hard-coded ``model.cuda()``.

    Returns:
        The model, in eval mode, on ``device``.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning checkpoints nest the weights under "state_dict"; otherwise
    # treat the loaded object itself as a plain state dict.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are only reported (when verbose), never raised.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("missing keys:")
        print(missing)
    if unexpected and verbose:
        print("unexpected keys:")
        print(unexpected)

    model.to(device)
    model.eval()
    return model


def load_img(path, size=(512, 512)):
    """Load an image file as a ``(1, 3, H, W)`` float tensor scaled to ``[-1, 1]``.

    The image is converted to RGB and resized with Lanczos filtering to
    ``size`` (width, height), each side rounded down to a multiple of 64.
    """
    with Image.open(path) as src:
        image = src.convert("RGB")
    w, h = (side - side % 64 for side in size)  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    array = np.asarray(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(array[None].transpose(0, 3, 1, 2))
    return 2. * tensor - 1.


def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def _synchronize(device):
    """Wait for queued CUDA work on ``device`` so wall-clock timings include it."""
    if device is not None and torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)


def test(dataloader,
         snr=10,
         num_images=100,
         batch_size=1,
         num_images_per_sample=2,
         outpath='',
         model=None,
         device=None,
         sampler=None,
         strength=0.8,
         ddim_steps=50,
         scale=9.0):
    """Run the semantic image-to-image channel experiment and print mean metrics.

    For each image the loop does the following:
      1. Caption the image with BLIP.
      2. Encode the image into the model's latent space.
      3. Pass both the latent and the caption through the 16-QAM channel
         helpers at ``snr`` dB.
      4. Regenerate the image with DDIM img2img from the noisy latent,
         conditioned on the noisy caption.
      5. Save the generated image, the clean caption and the resized original.
      6. Record LPIPS, SSIM and generation wall-clock time.

    Args:
        dataloader: Yields batches whose first element is an image file path.
        snr: Channel SNR in dB forwarded to the QAM helpers.
        num_images: Stop after this many batches (<= 0 processes everything).
        batch_size: Number of samples generated per input image.
        num_images_per_sample: Unused; kept for call compatibility.
        outpath: Root directory for the output folders.
        model: Latent diffusion model.
        device: Device the model inputs are moved to.
        sampler: DDIM sampler wrapping ``model``.
        strength: Fraction of the ``ddim_steps`` schedule that is re-noised
            and denoised; must lie in [0, 1].
        ddim_steps: Length of the DDIM schedule.
        scale: Unconditional guidance scale; 1.0 skips the unconditional branch.

    Returns:
        1 (kept for backward compatibility).

    Raises:
        ValueError: If ``strength`` is outside [0, 1].
        FileNotFoundError: If a path from the dataloader is not a file.
    """
    # Validate up front, before loading BLIP or creating any directory.
    if not 0. <= strength <= 1.:
        raise ValueError('can only work with strength in [0.0, 1.0]')

    blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")

    # Only the last t_enc steps of a ddim_steps-long schedule are run, so the
    # schedule must be built with the same ddim_steps used to compute t_enc.
    t_enc = int(strength * ddim_steps)
    sampling_steps = t_enc
    print(sampling_steps)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)

    sample_path = os.path.join(outpath, f"Test-samples-{snr}-{sampling_steps}")
    text_path = os.path.join(outpath, f"Test-text-samples-{snr}-{sampling_steps}")
    sample_orig_path = os.path.join(outpath, f"Test-samples-orig-{snr}-{sampling_steps}")
    for directory in (sample_path, text_path, sample_orig_path):
        os.makedirs(directory, exist_ok=True)

    lpips = lp.LPIPS(net='alex')
    lpips_values = []
    ssim_values = []
    time_values = []

    tq = tqdm(dataloader, total=num_images)
    for processed, batch in enumerate(tq, start=1):
        img_file_path = batch[0]
        if not os.path.isfile(img_file_path):
            raise FileNotFoundError(img_file_path)

        # Open once, as RGB: grayscale/CMYK sources would otherwise fail the
        # PNG save and the SSIM comparison against the RGB output.
        with Image.open(img_file_path) as src:
            original = src.convert("RGB")

        # Automatically extract the caption with BLIP; the clean caption is saved.
        prompt_original = blip(original)[0]["generated_text"]

        base_count = len(os.listdir(sample_path))

        init_image = load_img(img_file_path).to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

        # Channel simulation: both the latent and the caption cross the noisy link.
        init_latent = qam16ModulationTensor(init_latent.cpu(), snr_db=snr).to(device)
        prompt = qam16ModulationString(prompt_original, snr_db=snr)  # noisy BLIP prompt
        prompts = batch_size * [prompt]

        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            _synchronize(device)
            start_time = time.time()
            uc = None
            if scale != 1.0:
                uc = model.get_learned_conditioning(batch_size * [""])
            c = model.get_learned_conditioning(prompts)
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
            # decode it
            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                     unconditional_conditioning=uc)
            x_samples = model.decode_first_stage(samples)
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
            # CUDA launches are asynchronous: wait so the timing covers the GPU work.
            _synchronize(device)
            time_values.append(time.time() - start_time)

        original_resized = original.resize((512, 512), resample=PIL.Image.LANCZOS)
        for x_sample in x_samples:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            img = Image.fromarray(x_sample.astype(np.uint8))
            stem = f"{base_count:05}"
            img.save(os.path.join(sample_path, f"{stem}.png"))
            with open(os.path.join(text_path, f"{stem}.txt"), "a", encoding="utf-8") as f:
                f.write(prompt_original)
            original_resized.save(os.path.join(sample_orig_path, f"{stem}.png"))
            ssim_values.append(compare_ssim(original_resized, img))
            base_count += 1

        # LPIPS between the loaded original and the first generated sample, both in [-1, 1].
        sample_out = x_samples[0] * 2 - 1
        with torch.no_grad():
            lp_score = lpips(init_image[:1].cpu().float(), sample_out[None].cpu().float()).item()

        tq.set_postfix(lpips=lp_score)
        if not np.isnan(lp_score):
            lpips_values.append(lp_score)

        if processed == num_images:
            break

    print(f'mean lpips score at snr={snr} : {_mean(lpips_values)}')
    print(f'mean ssim score at snr={snr} : {_mean(ssim_values)}')
    print(f'mean time with sampling iterations {sampling_steps} : {_mean(time_values)}')
    return 1


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/img2img-samples"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )

    opt = parser.parse_args()
    seed_everything(opt.seed)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, model_ckpt_path, device=device)

    sampler = DDIMSampler(model)

    outpath = opt.outdir
    os.makedirs(outpath, exist_ok=True)

    # Start of the test runs, one per channel SNR (dB).
    # Strength sets the number of sampling steps: steps = ddim_steps * strength.
    for snr in (10, 8.75, 7.50, 6.25, 5):
        test(test_dataloader, snr=snr, num_images=100, batch_size=1, num_images_per_sample=1,
             outpath=outpath, model=model, device=device, sampler=sampler,
             strength=0.6, scale=9)