├── ldm
├── data
│ ├── __init__.py
│ └── util.py
├── models
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── dpm_solver
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-39.pyc
│ │ │ │ ├── sampler.cpython-312.pyc
│ │ │ │ ├── sampler.cpython-39.pyc
│ │ │ │ ├── __init__.cpython-312.pyc
│ │ │ │ ├── dpm_solver.cpython-312.pyc
│ │ │ │ └── dpm_solver.cpython-39.pyc
│ │ │ └── sampler.py
│ │ └── sampling_util.py
│ └── autoencoder.py
├── modules
│ ├── encoders
│ │ └── __init__.py
│ ├── karlo
│ │ ├── __init__.py
│ │ └── kakao
│ │ │ ├── __init__.py
│ │ │ ├── models
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── clip.cpython-39.pyc
│ │ │ │ │ ├── __init__.cpython-39.pyc
│ │ │ │ │ ├── sr_64_256.cpython-39.pyc
│ │ │ │ │ ├── prior_model.cpython-39.pyc
│ │ │ │ │ └── decoder_model.cpython-39.pyc
│ │ │ │ ├── sr_256_1k.py
│ │ │ │ ├── sr_64_256.py
│ │ │ │ ├── prior_model.py
│ │ │ │ ├── clip.py
│ │ │ │ └── decoder_model.py
│ │ │ ├── modules
│ │ │ │ ├── diffusion
│ │ │ │ │ ├── __pycache__
│ │ │ │ │ │ ├── respace.cpython-39.pyc
│ │ │ │ │ │ └── gaussian_diffusion.cpython-39.pyc
│ │ │ │ │ └── respace.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── resample.py
│ │ │ │ ├── nn.py
│ │ │ │ └── xf.py
│ │ │ ├── template.py
│ │ │ └── sampler.py
│ ├── midas
│ │ ├── __init__.py
│ │ ├── midas
│ │ │ ├── __init__.py
│ │ │ ├── base_model.py
│ │ │ ├── midas_net.py
│ │ │ ├── dpt_depth.py
│ │ │ ├── midas_net_custom.py
│ │ │ ├── transforms.py
│ │ │ └── blocks.py
│ │ ├── utils.py
│ │ └── api.py
│ ├── distributions
│ │ ├── __init__.py
│ │ └── distributions.py
│ ├── diffusionmodules
│ │ ├── __init__.py
│ │ ├── upscaling.py
│ │ └── util.py
│ ├── image_degradation
│ │ ├── utils
│ │ │ └── test.png
│ │ └── __init__.py
│ └── ema.py
└── util.py
├── stable_diffusion.egg-info
├── top_level.txt
├── dependency_links.txt
├── requires.txt
├── PKG-INFO
└── SOURCES.txt
├── samples_paper
├── image.png
├── results.png
├── Architecture.PNG
├── firstPageImg.PNG
├── 17273391_55cfc7d3d4.jpg
├── 55473406_1d2271c1f2.jpg
├── 96985174_31d4c6f06d.jpg
└── 431282339_0aa60dd78e.jpg
├── setup.py
├── requirements.txt
├── environment.yaml
├── LICENSE
├── configs
└── stable-diffusion
│ ├── v2-inference.yaml
│ ├── v2-inference-v.yaml
│ ├── intel
│ │ ├── v2-inference-fp32.yaml
│ │ ├── v2-inference-bf16.yaml
│ │ ├── v2-inference-v-fp32.yaml
│ │ └── v2-inference-v-bf16.yaml
│ └── v1-inference.yaml
├── scripts
├── evaluate_metrics.py
├── metrics.py
├── dataset.py
├── qam.py
├── semantic_t2i.py
└── semantic_i2i.py
├── .gitignore
└── README.md
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/stable_diffusion.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/stable_diffusion.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/stable_diffusion.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tqdm
4 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/samples_paper/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/image.png
--------------------------------------------------------------------------------
/samples_paper/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/results.png
--------------------------------------------------------------------------------
/samples_paper/Architecture.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/Architecture.PNG
--------------------------------------------------------------------------------
/samples_paper/firstPageImg.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/firstPageImg.PNG
--------------------------------------------------------------------------------
/samples_paper/17273391_55cfc7d3d4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/17273391_55cfc7d3d4.jpg
--------------------------------------------------------------------------------
/samples_paper/55473406_1d2271c1f2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/55473406_1d2271c1f2.jpg
--------------------------------------------------------------------------------
/samples_paper/96985174_31d4c6f06d.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/96985174_31d4c6f06d.jpg
--------------------------------------------------------------------------------
/samples_paper/431282339_0aa60dd78e.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/samples_paper/431282339_0aa60dd78e.jpg
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__pycache__/clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/clip.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-312.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__pycache__/sr_64_256.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/sr_64_256.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-312.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__pycache__/prior_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/prior_model.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/models/__pycache__/decoder_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/models/__pycache__/decoder_model.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/respace.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/respace.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/Img2Img-SC/HEAD/ldm/modules/karlo/kakao/modules/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
--------------------------------------------------------------------------------
/stable_diffusion.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: stable-diffusion
3 | Version: 0.0.1
4 | License-File: LICENSE
5 | License-File: LICENSE-MODEL
6 | Requires-Dist: torch
7 | Requires-Dist: numpy
8 | Requires-Dist: tqdm
9 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='stable-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/stable_diffusion.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | LICENSE-MODEL
3 | README.md
4 | setup.py
5 | stable_diffusion.egg-info/PKG-INFO
6 | stable_diffusion.egg-info/SOURCES.txt
7 | stable_diffusion.egg-info/dependency_links.txt
8 | stable_diffusion.egg-info/requires.txt
9 | stable_diffusion.egg-info/top_level.txt
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
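A minimal usage sketch (not part of the repository): BaseModel only adds checkpoint loading on top of torch.nn.Module, so any subclass can restore weights from a file holding either a plain state_dict or a training checkpoint with "model"/"optimizer" entries. The subclass and path below are hypothetical.

import torch
from ldm.modules.midas.midas.base_model import BaseModel

class TinyHead(BaseModel):
    # toy subclass; the real models (MidasNet, DPTDepthModel) define full architectures
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = TinyHead()
# model.load("checkpoints/tiny_head.pt")  # hypothetical path; accepts a raw state_dict
#                                         # or a checkpoint dict containing a "model" key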
/ldm/modules/karlo/kakao/models/sr_256_1k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive
7 |
8 |
9 | class SupRes256to1kProgressive(SupRes64to256Progressive):
10 | pass # no difference currently
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.4.3
2 | opencv-python
3 | pudb==2019.2
4 | imageio==2.9.0
5 | imageio-ffmpeg==0.4.2
6 | pytorch-lightning==1.6.0 #pytorch-lightning==1.4.2
7 | torchmetrics==0.6
8 | omegaconf==2.1.1
9 | test-tube>=0.7.5
10 | streamlit>=0.73.1
11 | einops==0.3.0
12 | transformers
13 | webdataset==0.2.5
14 | open-clip-torch==2.7.0
15 | kornia==0.6
16 | invisible-watermark>=0.1.5
17 | streamlit-drawable-canvas==0.8.0
18 |
19 | diffusers
20 | bitstring
21 | lpips
22 | kaggle
23 | accelerate
24 | SSIM-PIL
25 |
26 |
27 | -e .
28 |
--------------------------------------------------------------------------------
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
--------------------------------------------------------------------------------
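A usage sketch, assuming the MiDaS model type "dpt_hybrid" (any type accepted by load_midas_transform should work) and an HWC image tensor scaled to [-1, 1]:

import torch
from ldm.data.util import AddMiDaS

add_midas = AddMiDaS(model_type="dpt_hybrid")      # model type is an assumption
sample = {"jpg": torch.rand(384, 384, 3) * 2 - 1}  # HWC tensor in [-1, 1]
sample = add_midas(sample)
midas_in = sample["midas_in"]   # resized/normalized array ready for the MiDaS depth estimator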
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
12 | - pip:
13 | - albumentations==1.3.0
14 | - opencv-python==4.6.0.66
15 | - imageio==2.9.0
16 | - imageio-ffmpeg==0.4.2
17 | - pytorch-lightning==1.4.2
18 | - omegaconf==2.1.1
19 | - test-tube>=0.7.5
20 | - streamlit==1.12.1
21 | - einops==0.3.0
22 | - transformers==4.19.2
23 | - webdataset==0.2.5
24 | - kornia==0.6
25 | - open_clip_torch==2.0.2
26 | - invisible-watermark>=0.1.5
27 | - streamlit-drawable-canvas==0.8.0
28 | - torchmetrics==0.6.0
29 | - -e .
30 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
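For reference, a short sketch of how these helpers behave on a batch of latents (shapes are illustrative):

import torch
from ldm.models.diffusion.sampling_util import append_dims, norm_thresholding

x0 = torch.randn(2, 4, 64, 64)          # batch of predicted latents
sigma = torch.tensor([0.5, 1.0])        # one scalar per batch element
sigma = append_dims(sigma, x0.ndim)     # -> shape (2, 1, 1, 1), broadcastable against x0
x0 = norm_thresholding(x0, value=1.0)   # scales down samples whose per-sample RMS norm exceeds 1.0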
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Stability AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 |
6 | from .diffusion import gaussian_diffusion as gd
7 | from .diffusion.respace import (
8 | SpacedDiffusion,
9 | space_timesteps,
10 | )
11 |
12 |
13 | def create_gaussian_diffusion(
14 | steps,
15 | learn_sigma,
16 | sigma_small,
17 | noise_schedule,
18 | use_kl,
19 | predict_xstart,
20 | rescale_learned_sigmas,
21 | timestep_respacing,
22 | ):
23 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
24 | if use_kl:
25 | loss_type = gd.LossType.RESCALED_KL
26 | elif rescale_learned_sigmas:
27 | loss_type = gd.LossType.RESCALED_MSE
28 | else:
29 | loss_type = gd.LossType.MSE
30 | if not timestep_respacing:
31 | timestep_respacing = [steps]
32 |
33 | return SpacedDiffusion(
34 | use_timesteps=space_timesteps(steps, timestep_respacing),
35 | betas=betas,
36 | model_mean_type=(
37 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
38 | ),
39 | model_var_type=(
40 | (
41 | gd.ModelVarType.FIXED_LARGE
42 | if not sigma_small
43 | else gd.ModelVarType.FIXED_SMALL
44 | )
45 | if not learn_sigma
46 | else gd.ModelVarType.LEARNED_RANGE
47 | ),
48 | loss_type=loss_type,
49 | )
50 |
--------------------------------------------------------------------------------
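A hedged construction sketch: the argument values below are illustrative rather than taken from a released Karlo config, and the schedule/respacing strings follow the Guided-Diffusion conventions ("linear"/"cosine", "25", "ddim25", ...).

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion

diffusion = create_gaussian_diffusion(
    steps=1000,
    learn_sigma=True,
    sigma_small=False,
    noise_schedule="cosine",
    use_kl=False,
    predict_xstart=False,
    rescale_learned_sigmas=False,
    timestep_respacing="25",   # sample with 25 evenly respaced steps
)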
/configs/stable-diffusion/v2-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False # we set this to false because this is an inference only config
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | use_checkpoint: True
24 | #use_fp16: True
25 | image_size: 32 # unused
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: 2
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_head_channels: 64 # need to fix for flash-attn
33 | use_spatial_transformer: True
34 | use_linear_in_transformer: True
35 | transformer_depth: 1
36 | context_dim: 1024
37 | legacy: False
38 |
39 | first_stage_config:
40 | target: ldm.models.autoencoder.AutoencoderKL
41 | params:
42 | embed_dim: 4
43 | monitor: val/rec_loss
44 | ddconfig:
45 | #attn_type: "vanilla-xformers"
46 | double_z: true
47 | z_channels: 4
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 2
55 | - 4
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65 | params:
66 | freeze: True
67 | layer: "penultimate"
68 |
--------------------------------------------------------------------------------
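For reference, configs like this one are consumed by loading them with OmegaConf and instantiating the class named under model.target; the sketch below assumes the usual instantiate_from_config helper in ldm/util.py and a locally downloaded SD 2.x checkpoint (the path is hypothetical).

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config   # assumed helper from ldm/util.py

config = OmegaConf.load("configs/stable-diffusion/v2-inference.yaml")
model = instantiate_from_config(config.model)

ckpt = torch.load("checkpoints/sd21-512.ckpt", map_location="cpu")   # hypothetical path
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()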
/configs/stable-diffusion/v2-inference-v.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | parameterization: "v"
6 | linear_start: 0.00085
7 | linear_end: 0.0120
8 | num_timesteps_cond: 1
9 | log_every_t: 200
10 | timesteps: 1000
11 | first_stage_key: "jpg"
12 | cond_stage_key: "txt"
13 | image_size: 64
14 | channels: 4
15 | cond_stage_trainable: false
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_factor: 0.18215
19 | use_ema: False # we set this to false because this is an inference only config
20 |
21 | unet_config:
22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23 | params:
24 | use_checkpoint: True
25 | use_fp16: True
26 | image_size: 32 # unused
27 | in_channels: 4
28 | out_channels: 4
29 | model_channels: 320
30 | attention_resolutions: [ 4, 2, 1 ]
31 | num_res_blocks: 2
32 | channel_mult: [ 1, 2, 4, 4 ]
33 | num_head_channels: 64 # need to fix for flash-attn
34 | use_spatial_transformer: True
35 | use_linear_in_transformer: True
36 | transformer_depth: 1
37 | context_dim: 1024
38 | legacy: False
39 |
40 | first_stage_config:
41 | target: ldm.models.autoencoder.AutoencoderKL
42 | params:
43 | embed_dim: 4
44 | monitor: val/rec_loss
45 | ddconfig:
46 | #attn_type: "vanilla-xformers"
47 | double_z: true
48 | z_channels: 4
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions: []
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
66 | params:
67 | freeze: True
68 | layer: "penultimate"
69 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/intel/v2-inference-fp32.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022 Intel Corporation
2 | # SPDX-License-Identifier: MIT
3 |
4 | model:
5 | base_learning_rate: 1.0e-4
6 | target: ldm.models.diffusion.ddpm.LatentDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: "jpg"
14 | cond_stage_key: "txt"
15 | image_size: 64
16 | channels: 4
17 | cond_stage_trainable: false
18 | conditioning_key: crossattn
19 | monitor: val/loss_simple_ema
20 | scale_factor: 0.18215
21 | use_ema: False # we set this to false because this is an inference only config
22 |
23 | unet_config:
24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
25 | params:
26 | use_checkpoint: False
27 | use_fp16: False
28 | image_size: 32 # unused
29 | in_channels: 4
30 | out_channels: 4
31 | model_channels: 320
32 | attention_resolutions: [ 4, 2, 1 ]
33 | num_res_blocks: 2
34 | channel_mult: [ 1, 2, 4, 4 ]
35 | num_head_channels: 64 # need to fix for flash-attn
36 | use_spatial_transformer: True
37 | use_linear_in_transformer: True
38 | transformer_depth: 1
39 | context_dim: 1024
40 | legacy: False
41 |
42 | first_stage_config:
43 | target: ldm.models.autoencoder.AutoencoderKL
44 | params:
45 | embed_dim: 4
46 | monitor: val/rec_loss
47 | ddconfig:
48 | #attn_type: "vanilla-xformers"
49 | double_z: true
50 | z_channels: 4
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 2
58 | - 4
59 | - 4
60 | num_res_blocks: 2
61 | attn_resolutions: []
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 |
66 | cond_stage_config:
67 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
68 | params:
69 | freeze: True
70 | layer: "penultimate"
71 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/intel/v2-inference-bf16.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022 Intel Corporation
2 | # SPDX-License-Identifier: MIT
3 |
4 | model:
5 | base_learning_rate: 1.0e-4
6 | target: ldm.models.diffusion.ddpm.LatentDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: "jpg"
14 | cond_stage_key: "txt"
15 | image_size: 64
16 | channels: 4
17 | cond_stage_trainable: false
18 | conditioning_key: crossattn
19 | monitor: val/loss_simple_ema
20 | scale_factor: 0.18215
21 | use_ema: False # we set this to false because this is an inference only config
22 |
23 | unet_config:
24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
25 | params:
26 | use_checkpoint: False
27 | use_fp16: False
28 | use_bf16: True
29 | image_size: 32 # unused
30 | in_channels: 4
31 | out_channels: 4
32 | model_channels: 320
33 | attention_resolutions: [ 4, 2, 1 ]
34 | num_res_blocks: 2
35 | channel_mult: [ 1, 2, 4, 4 ]
36 | num_head_channels: 64 # need to fix for flash-attn
37 | use_spatial_transformer: True
38 | use_linear_in_transformer: True
39 | transformer_depth: 1
40 | context_dim: 1024
41 | legacy: False
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: val/rec_loss
48 | ddconfig:
49 | #attn_type: "vanilla-xformers"
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
69 | params:
70 | freeze: True
71 | layer: "penultimate"
72 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/intel/v2-inference-v-fp32.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022 Intel Corporation
2 | # SPDX-License-Identifier: MIT
3 |
4 | model:
5 | base_learning_rate: 1.0e-4
6 | target: ldm.models.diffusion.ddpm.LatentDiffusion
7 | params:
8 | parameterization: "v"
9 | linear_start: 0.00085
10 | linear_end: 0.0120
11 | num_timesteps_cond: 1
12 | log_every_t: 200
13 | timesteps: 1000
14 | first_stage_key: "jpg"
15 | cond_stage_key: "txt"
16 | image_size: 64
17 | channels: 4
18 | cond_stage_trainable: false
19 | conditioning_key: crossattn
20 | monitor: val/loss_simple_ema
21 | scale_factor: 0.18215
22 | use_ema: False # we set this to false because this is an inference only config
23 |
24 | unet_config:
25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
26 | params:
27 | use_checkpoint: False
28 | use_fp16: False
29 | image_size: 32 # unused
30 | in_channels: 4
31 | out_channels: 4
32 | model_channels: 320
33 | attention_resolutions: [ 4, 2, 1 ]
34 | num_res_blocks: 2
35 | channel_mult: [ 1, 2, 4, 4 ]
36 | num_head_channels: 64 # need to fix for flash-attn
37 | use_spatial_transformer: True
38 | use_linear_in_transformer: True
39 | transformer_depth: 1
40 | context_dim: 1024
41 | legacy: False
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: val/rec_loss
48 | ddconfig:
49 | #attn_type: "vanilla-xformers"
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
69 | params:
70 | freeze: True
71 | layer: "penultimate"
72 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder #ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
--------------------------------------------------------------------------------
/configs/stable-diffusion/intel/v2-inference-v-bf16.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022 Intel Corporation
2 | # SPDX-License-Identifier: MIT
3 |
4 | model:
5 | base_learning_rate: 1.0e-4
6 | target: ldm.models.diffusion.ddpm.LatentDiffusion
7 | params:
8 | parameterization: "v"
9 | linear_start: 0.00085
10 | linear_end: 0.0120
11 | num_timesteps_cond: 1
12 | log_every_t: 200
13 | timesteps: 1000
14 | first_stage_key: "jpg"
15 | cond_stage_key: "txt"
16 | image_size: 64
17 | channels: 4
18 | cond_stage_trainable: false
19 | conditioning_key: crossattn
20 | monitor: val/loss_simple_ema
21 | scale_factor: 0.18215
22 | use_ema: False # we set this to false because this is an inference only config
23 |
24 | unet_config:
25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
26 | params:
27 | use_checkpoint: False
28 | use_fp16: False
29 | use_bf16: True
30 | image_size: 32 # unused
31 | in_channels: 4
32 | out_channels: 4
33 | model_channels: 320
34 | attention_resolutions: [ 4, 2, 1 ]
35 | num_res_blocks: 2
36 | channel_mult: [ 1, 2, 4, 4 ]
37 | num_head_channels: 64 # need to fix for flash-attn
38 | use_spatial_transformer: True
39 | use_linear_in_transformer: True
40 | transformer_depth: 1
41 | context_dim: 1024
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | #attn_type: "vanilla-xformers"
51 | double_z: true
52 | z_channels: 4
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult:
58 | - 1
59 | - 2
60 | - 4
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 |
68 | cond_stage_config:
69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
70 | params:
71 | freeze: True
72 | layer: "penultimate"
73 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/modules/resample.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | from abc import abstractmethod
6 |
7 | import torch as th
8 |
9 |
10 | def create_named_schedule_sampler(name, diffusion):
11 | """
12 | Create a ScheduleSampler from a library of pre-defined samplers.
13 |
14 | :param name: the name of the sampler.
15 | :param diffusion: the diffusion object to sample for.
16 | """
17 | if name == "uniform":
18 | return UniformSampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(th.nn.Module):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / th.sum(w)
54 | indices = p.multinomial(batch_size, replacement=True)
55 | weights = 1 / (len(p) * p[indices])
56 | return indices, weights
57 |
58 |
59 | class UniformSampler(ScheduleSampler):
60 | def __init__(self, diffusion):
61 | super(UniformSampler, self).__init__()
62 | self.diffusion = diffusion
63 | self.register_buffer(
64 | "_weights", th.ones([diffusion.num_timesteps]), persistent=False
65 | )
66 |
67 | def weights(self):
68 | return self._weights
69 |
--------------------------------------------------------------------------------
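A small sketch of the sampler in isolation; the diffusion object only needs a num_timesteps attribute here, so the dummy class below is purely for illustration.

import torch as th
from ldm.modules.karlo.kakao.modules.resample import create_named_schedule_sampler

class _DummyDiffusion:
    num_timesteps = 1000   # the uniform sampler only reads this attribute

sampler = create_named_schedule_sampler("uniform", _DummyDiffusion())
timesteps, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
# timesteps: 8 indices drawn uniformly from [0, 999]; weights are all 1.0 in the uniform case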
/scripts/evaluate_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import scripts.metrics as Metrics
3 | from PIL import Image
4 | import numpy as np
5 | import glob
6 | import torch
7 | from tqdm import tqdm
8 |
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 |     parser.add_argument('-p_gt', '--path_gt', type=str,
13 |                         default=r'G:\Giordano\stablediffusion\outputs\img2img-samples\samples-orig-10')
14 |     parser.add_argument('-p', '--path', type=str,
15 |                         default=r'G:\Giordano\stablediffusion\outputs\img2img-samples\samples-10')
16 | args = parser.parse_args()
17 | real_names = list(glob.glob('{}/*.png'.format(args.path_gt)))
18 | # real_names = list(glob.glob('{}/*.jpg'.format(args.path_gt)))
19 | print(real_names, args.path_gt)
20 |
21 |
22 |
23 | fake_names = list(glob.glob('{}/*.png'.format(args.path)))
24 |
25 | real_names.sort()
26 | fake_names.sort()
27 |
28 | avg_psnr = 0.0
29 | avg_ssim = 0.0
30 | avg_fid = 0.0
31 | fid_img_list_real, fid_img_list_fake = [],[]
32 | idx = 0
33 | for rname, fname in tqdm(zip(real_names, fake_names), total=len(real_names)):
34 | idx += 1
35 |
36 |
37 |
38 | hr_img = np.array(Image.open(rname))
39 | sr_img = np.array(Image.open(fname))
40 | psnr = Metrics.calculate_psnr(sr_img, hr_img)
41 | ssim = Metrics.calculate_ssim(sr_img, hr_img)
42 | fid_img_list_real.append(torch.from_numpy(hr_img).permute(2,0,1).unsqueeze(0))
43 | fid_img_list_fake.append(torch.from_numpy(sr_img).permute(2,0,1).unsqueeze(0))
44 | avg_psnr += psnr
45 | avg_ssim += ssim
46 | if idx % 10 == 0:
47 | # fid = Metrics.calculate_FID(torch.cat(fid_img_list_real,dim=0), torch.cat(fid_img_list_fake,dim=0))
48 | # fid_img_list_real, fid_img_list_fake = [],[]
49 | # avg_fid += fid
50 | print('Image:{}, PSNR:{:.4f}, SSIM:{:.4f}'.format(idx, psnr, ssim))
51 |
52 |
53 | #last FID
54 | fid = Metrics.calculate_FID(torch.cat(fid_img_list_real,dim=0), torch.cat(fid_img_list_fake,dim=0))
55 | avg_fid += fid
56 |
57 | avg_psnr = avg_psnr / idx
58 | avg_ssim = avg_ssim / idx
59 | # avg_fid = avg_fid / idx
60 |
61 | # log
62 | print('# Validation # PSNR: {}'.format(avg_psnr))
63 | print('# Validation # SSIM: {}'.format(avg_ssim))
64 | print('# Validation # FID: {}'.format(avg_fid))
--------------------------------------------------------------------------------
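Usage note: because the script imports scripts.metrics, it should be launched from the repository root, e.g. python -m scripts.evaluate_metrics -p_gt <dir_with_reference_pngs> -p <dir_with_generated_pngs>. Both directories are expected to contain the same number of .png files so that the sorted file lists pair up image by image.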
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): If True, clamp the predicted depth to be non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
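A usage sketch for the diagonal Gaussian posterior used by the autoencoder (shapes are illustrative): the encoder output concatenates mean and log-variance along the channel axis.

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)   # 2 samples, 4 latent channels -> 8 maps (mean | logvar)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()               # (2, 4, 32, 32) reparameterized sample
kl = posterior.kl()                  # per-sample KL divergence to a standard normal, shape (2,)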
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
--------------------------------------------------------------------------------
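A minimal sketch of the intended call pattern (the tiny model is illustrative): update the shadow weights after each optimizer step, then swap them in for evaluation.

import torch
from ldm.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# after every optimizer step during training:
ema(model)                       # update the exponential moving average of the weights

# evaluate with EMA weights, then restore the raw ones:
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())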
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # use_pretrained: set to True to load ImageNet-pretrained backbone weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 |         if self.channels_last:
69 |             x = x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
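A construction sketch for the noise-augmentation wrapper; the schedule values mirror the register_schedule defaults and the max_noise_level is illustrative.

import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,   # illustrative cap on the sampled noise level
)
x_low = torch.randn(2, 4, 32, 32)   # low-resolution conditioning (already encoded)
z, noise_level = aug(x_low)         # z = q_sample(x_low, noise_level); noise_level is per-sample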
/ldm/modules/karlo/kakao/modules/nn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class GroupNorm32(nn.GroupNorm):
13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5):
14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
15 | self.swish = swish
16 |
17 | def forward(self, x):
18 | y = super().forward(x.float()).to(x.dtype)
19 | if self.swish == 1.0:
20 | y = F.silu(y)
21 | elif self.swish:
22 | y = y * F.sigmoid(y * float(self.swish))
23 | return y
24 |
25 |
26 | def conv_nd(dims, *args, **kwargs):
27 | """
28 | Create a 1D, 2D, or 3D convolution module.
29 | """
30 | if dims == 1:
31 | return nn.Conv1d(*args, **kwargs)
32 | elif dims == 2:
33 | return nn.Conv2d(*args, **kwargs)
34 | elif dims == 3:
35 | return nn.Conv3d(*args, **kwargs)
36 | raise ValueError(f"unsupported dimensions: {dims}")
37 |
38 |
39 | def linear(*args, **kwargs):
40 | """
41 | Create a linear module.
42 | """
43 | return nn.Linear(*args, **kwargs)
44 |
45 |
46 | def avg_pool_nd(dims, *args, **kwargs):
47 | """
48 | Create a 1D, 2D, or 3D average pooling module.
49 | """
50 | if dims == 1:
51 | return nn.AvgPool1d(*args, **kwargs)
52 | elif dims == 2:
53 | return nn.AvgPool2d(*args, **kwargs)
54 | elif dims == 3:
55 | return nn.AvgPool3d(*args, **kwargs)
56 | raise ValueError(f"unsupported dimensions: {dims}")
57 |
58 |
59 | def zero_module(module):
60 | """
61 | Zero out the parameters of a module and return it.
62 | """
63 | for p in module.parameters():
64 | p.detach().zero_()
65 | return module
66 |
67 |
68 | def scale_module(module, scale):
69 | """
70 | Scale the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().mul_(scale)
74 | return module
75 |
76 |
77 | def normalization(channels, swish=0.0):
78 | """
79 | Make a standard normalization layer, with an optional swish activation.
80 |
81 | :param channels: number of input channels.
82 | :return: an nn.Module for normalization.
83 | """
84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
85 |
86 |
87 | def timestep_embedding(timesteps, dim, max_period=10000):
88 | """
89 | Create sinusoidal timestep embeddings.
90 |
91 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
92 | These may be fractional.
93 | :param dim: the dimension of the output.
94 | :param max_period: controls the minimum frequency of the embeddings.
95 | :return: an [N x dim] Tensor of positional embeddings.
96 | """
97 | half = dim // 2
98 | freqs = th.exp(
99 | -math.log(max_period)
100 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
101 | / half
102 | )
103 | args = timesteps[:, None].float() * freqs[None]
104 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
105 | if dim % 2:
106 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
107 | return embedding
108 |
109 |
110 | def mean_flat(tensor):
111 | """
112 | Take the mean over all non-batch dimensions.
113 | """
114 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
115 |
--------------------------------------------------------------------------------
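As a quick sanity check on the helpers above, here is a hedged usage sketch showing the shapes produced by timestep_embedding and the swish-fused group norm returned by normalization; the channel and width values are arbitrary:

```
import torch as th
from ldm.modules.karlo.kakao.modules.nn import timestep_embedding, normalization

t = th.tensor([0, 10, 999])                 # one (possibly fractional) timestep per sample
emb = timestep_embedding(t, dim=320)        # sinusoidal features, shape [3, 320]

gn = normalization(channels=64, swish=1.0)  # GroupNorm32 with SiLU applied in forward()
y = gn(th.randn(3, 64, 32, 32))             # same shape as the input
print(emb.shape, y.shape)
```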
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated by project
2 | outputs/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # General MacOS
13 | .DS_Store
14 | .AppleDouble
15 | .LSOverride
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/#use-with-ide
118 | .pdm.toml
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # IDEs
164 | .idea/
165 | .vscode/
166 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 | MODEL_TYPES = {
7 | "eps": "noise",
8 | "v": "v"
9 | }
10 |
11 |
12 | class DPMSolverSampler(object):
13 | def __init__(self, model, device=torch.device("cuda"), **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.device = device
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != self.device:
23 | attr = attr.to(self.device)
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | ctmp = conditioning[list(conditioning.keys())[0]]
54 | while isinstance(ctmp, list): ctmp = ctmp[0]
55 | if isinstance(ctmp, torch.Tensor):
56 | cbs = ctmp.shape[0]
57 | if cbs != batch_size:
58 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
59 | elif isinstance(conditioning, list):
60 | for ctmp in conditioning:
61 | if ctmp.shape[0] != batch_size:
62 | print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
63 | else:
64 | if isinstance(conditioning, torch.Tensor):
65 | if conditioning.shape[0] != batch_size:
66 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
67 |
68 | # sampling
69 | C, H, W = shape
70 | size = (batch_size, C, H, W)
71 |
72 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
73 |
74 | device = self.model.betas.device
75 | if x_T is None:
76 | img = torch.randn(size, device=device)
77 | else:
78 | img = x_T
79 |
80 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
81 |
82 | model_fn = model_wrapper(
83 | lambda x, t, c: self.model.apply_model(x, t, c),
84 | ns,
85 | model_type=MODEL_TYPES[self.model.parameterization],
86 | guidance_type="classifier-free",
87 | condition=conditioning,
88 | unconditional_condition=unconditional_conditioning,
89 | guidance_scale=unconditional_guidance_scale,
90 | )
91 |
92 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
93 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
94 | lower_order_final=True)
95 |
96 | return x.to(device), None
97 |
--------------------------------------------------------------------------------
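A hedged sketch of how the sampler above is typically driven. The `model` object is assumed to be a Stable Diffusion-style latent diffusion model loaded elsewhere; `apply_model` is used in this file, while `get_learned_conditioning` and `decode_first_stage` are assumed helpers of that model and are not defined here:

```
from ldm.models.diffusion.dpm_solver.sampler import DPMSolverSampler

sampler = DPMSolverSampler(model)                       # `model` loaded elsewhere
c = model.get_learned_conditioning(["a photo of a cat"])
uc = model.get_learned_conditioning([""])

# 25 multistep DPM-Solver steps in a 4x64x64 latent space (512x512 pixels at f=8)
z, _ = sampler.sample(S=25, batch_size=1, shape=(4, 64, 64),
                      conditioning=c,
                      unconditional_conditioning=uc,
                      unconditional_guidance_scale=9.0)
x = model.decode_first_stage(z)
```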
/ldm/modules/karlo/kakao/models/sr_64_256.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
10 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
11 |
12 |
13 | class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
14 | """
15 |     The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses.
16 |     Specifically, the low-resolution sample is iteratively recovered in 6 steps with the frozen pretrained SR model.
17 |     In one additional final step, a separate fine-tuned model recovers the high-frequency details.
18 |     This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
19 | """
20 |
21 | def __init__(self, config):
22 | super().__init__()
23 |
24 | self._config = config
25 | self._diffusion_kwargs = dict(
26 | steps=config.diffusion.steps,
27 | learn_sigma=config.diffusion.learn_sigma,
28 | sigma_small=config.diffusion.sigma_small,
29 | noise_schedule=config.diffusion.noise_schedule,
30 | use_kl=config.diffusion.use_kl,
31 | predict_xstart=config.diffusion.predict_xstart,
32 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
33 | )
34 |
35 | self.model_first_steps = SuperResUNetModel(
36 | in_channels=3, # auto-changed to 6 inside the model
37 | model_channels=config.model.hparams.channels,
38 | out_channels=3,
39 | num_res_blocks=config.model.hparams.depth,
40 | attention_resolutions=(), # no attention
41 | dropout=config.model.hparams.dropout,
42 | channel_mult=config.model.hparams.channels_multiple,
43 | resblock_updown=True,
44 | use_middle_attention=False,
45 | )
46 | self.model_last_step = SuperResUNetModel(
47 | in_channels=3, # auto-changed to 6 inside the model
48 | model_channels=config.model.hparams.channels,
49 | out_channels=3,
50 | num_res_blocks=config.model.hparams.depth,
51 | attention_resolutions=(), # no attention
52 | dropout=config.model.hparams.dropout,
53 | channel_mult=config.model.hparams.channels_multiple,
54 | resblock_updown=True,
55 | use_middle_attention=False,
56 | )
57 |
58 | @classmethod
59 | def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
60 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
61 |
62 | model = cls(config)
63 | model.load_state_dict(ckpt, strict=strict)
64 | return model
65 |
66 | def get_sample_fn(self, timestep_respacing):
67 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
68 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
69 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
70 | return diffusion.p_sample_loop_progressive_for_improved_sr
71 |
72 | def forward(self, low_res, timestep_respacing="7", **kwargs):
73 | assert (
74 | timestep_respacing == "7"
75 |         ), "a different respacing method may work, but it is not guaranteed"
76 |
77 | sample_fn = self.get_sample_fn(timestep_respacing)
78 | sample_outputs = sample_fn(
79 | self.model_first_steps,
80 | self.model_last_step,
81 | shape=low_res.shape,
82 | clip_denoised=True,
83 | model_kwargs=dict(low_res=low_res),
84 | **kwargs,
85 | )
86 | for x in sample_outputs:
87 | sample = x["sample"]
88 | yield sample
89 |
--------------------------------------------------------------------------------
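Since forward above is a generator, callers iterate it and keep the last yield as the final 256x256 image. A hedged sketch; the `config` object and `low_res_64` tensor are assumptions, and the checkpoint path only mirrors CKPT_PATH["sr_256"] from kakao/template.py:

```
import torch

sr = ImprovedSupRes64to256ProgressiveModel.load_from_checkpoint(
    config, "checkpoints/improved-sr-ckpt-step=1.2M.ckpt"
)
sr.cuda().eval()

with torch.no_grad():
    for out in sr(low_res_64):   # low_res_64: [B, 3, 64, 64] tensor on the same device
        img_256 = out            # the last yielded sample is the final output
```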
/scripts/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 | from torchmetrics.image.fid import FrechetInceptionDistance
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 | '''
10 | Converts a torch Tensor into an image Numpy array
11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 | '''
14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
15 | tensor = (tensor - min_max[0]) / \
16 | (min_max[1] - min_max[0]) # to range [0,1]
17 | n_dim = tensor.dim()
18 | if n_dim == 4:
19 | n_img = len(tensor)
20 | img_np = make_grid(tensor, nrow=int(
21 | math.sqrt(n_img)), normalize=False).numpy()
22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
23 | elif n_dim == 3:
24 | img_np = tensor.numpy()
25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
26 | elif n_dim == 2:
27 | img_np = tensor.numpy()
28 | else:
29 |         raise TypeError(
30 |             'Only 4D, 3D and 2D tensors are supported, but received a tensor with dimension: {:d}'.format(n_dim))
31 | if out_type == np.uint8:
32 | img_np = (img_np * 255.0).round()
33 |         # Important. Unlike MATLAB, numpy.uint8() will not round by default.
34 | return img_np.astype(out_type)
35 |
36 |
37 | def save_img(img, img_path, mode='RGB'):
38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
39 | # cv2.imwrite(img_path, img)
40 |
41 |
42 | def calculate_psnr(img1, img2):
43 | # img1 and img2 have range [0, 255]
44 | img1 = img1.astype(np.float64)
45 | img2 = img2.astype(np.float64)
46 | mse = np.mean((img1 - img2)**2)
47 | if mse == 0:
48 | return float('inf')
49 | return 20 * math.log10(255.0 / math.sqrt(mse))
50 |
51 |
52 | def ssim(img1, img2):
53 | C1 = (0.01 * 255)**2
54 | C2 = (0.03 * 255)**2
55 |
56 | img1 = img1.astype(np.float64)
57 | img2 = img2.astype(np.float64)
58 | kernel = cv2.getGaussianKernel(11, 1.5)
59 | window = np.outer(kernel, kernel.transpose())
60 |
61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
63 | mu1_sq = mu1**2
64 | mu2_sq = mu2**2
65 | mu1_mu2 = mu1 * mu2
66 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
67 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
69 |
70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
71 | (sigma1_sq + sigma2_sq + C2))
72 | return ssim_map.mean()
73 |
74 |
75 | def calculate_ssim(img1, img2):
76 | '''calculate SSIM
77 | the same outputs as MATLAB's
78 | img1, img2: [0, 255]
79 | '''
80 | if not img1.shape == img2.shape:
81 | raise ValueError('Input images must have the same dimensions.')
82 | if img1.ndim == 2:
83 | return ssim(img1, img2)
84 | elif img1.ndim == 3:
85 | if img1.shape[2] == 3:
86 | ssims = []
87 | for i in range(3):
88 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
89 | return np.array(ssims).mean()
90 | elif img1.shape[2] == 1:
91 | return ssim(np.squeeze(img1), np.squeeze(img2))
92 | else:
93 | raise ValueError('Wrong input image dimensions.')
94 |
95 | fid = FrechetInceptionDistance()
96 | def calculate_FID(imgs_dist1, imgs_dist2):
97 |     # accumulate real and generated image batches, then return the FID score
98 | fid.update(imgs_dist1, real=True)
99 | fid.update(imgs_dist2, real=False)
100 | # fid.compute()
101 | return fid.compute()
--------------------------------------------------------------------------------
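A small hedged example of calling the pixel-space metrics defined above; the arrays are synthetic and only illustrate the expected [0, 255] uint8 inputs and HxWxC layout:

```
import numpy as np

ref = np.random.randint(0, 256, (64, 64, 3)).astype(np.uint8)
noise = np.random.randint(-10, 11, ref.shape)
deg = np.clip(ref.astype(np.int16) + noise, 0, 255).astype(np.uint8)

print(calculate_psnr(ref, deg))   # PSNR in dB; higher is better
print(calculate_ssim(ref, deg))   # mean SSIM over the three channels; higher is better
```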
/scripts/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import optim
4 | from torch.utils.data import Dataset, DataLoader, random_split
5 | from PIL import Image
6 | from torchvision.transforms import Resize, ToTensor, Normalize, Compose
7 | from matplotlib import pyplot as plt
8 |
9 | from io import open
10 | import unicodedata
11 | import re
12 |
13 |
14 |
15 |
16 | class Flickr8kDataset(Dataset):
17 | def __init__(self,images_dir_path, capt_file_path):
18 | super().__init__()
19 | #read data
20 | data=open(capt_file_path).read().strip().split('\n')
21 | data=data[1:]
22 |
23 | img_filenames_list=[]
24 | captions_list=[]
25 |
26 | for s in data:
27 | templist=s.lower().split(",")
28 | img_path=templist[0]
29 |             caption=",".join(templist[1:])
30 | caption=self.normalizeString(caption)
31 | img_filenames_list.append(img_path)
32 | captions_list.append(caption)
33 |
34 | self.images_dir_path=images_dir_path
35 | self.img_filenames_list=img_filenames_list
36 | self.captions_list=captions_list
37 | self.length=len(self.captions_list)
38 | self.transform=Compose([Resize((224,224), antialias=True), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
39 |
40 |
41 | def __len__(self):
42 | return self.length
43 |
44 |     # unicode to ascii: strip accents, remove non-letter characters, trim
45 | def normalizeString(self,s):
46 | sres=""
47 | for ch in unicodedata.normalize('NFD', s):
48 |             # Return the normal form ('NFD') of the Unicode string s.
49 | if unicodedata.category(ch) != 'Mn':
50 |                 # unicodedata.category(ch) returns the general category
51 |                 # assigned to the character ch as a string.
52 |                 # 'Mn' refers to Mark, Nonspacing.
53 | sres+=ch
54 | #sres = re.sub(r"([.!?])", r" \1", sres)
55 | # inserts a space before any occurrence of ".", "!", or "?" in the string sres.
56 | sres = re.sub(r"[^a-zA-Z!?,]+", r" ", sres)
57 | # this line of code replaces any sequence of characters in sres
58 | # that are not letters (a-z or A-Z) or the punctuation marks
59 | # "!", "," or "?" with a single space character.
60 | return sres.strip()
61 |
62 |
63 |
64 | def __getitem__(self,idx):
65 | imgfname,caption=self.img_filenames_list[idx],self.captions_list[idx]
66 |
67 | imgfname=self.images_dir_path+imgfname
68 |
69 | return imgfname, caption
70 |
71 | import os
72 | class Only_images_Flickr8kDataset(Dataset):
73 | def __init__(self,images_dir_path):
74 | super().__init__()
75 | #read data
76 | img_filenames_list=[]
77 |
78 | for root, dirs, files in os.walk(images_dir_path):
79 | for file in files:
80 | if file.endswith('.jpg'):
81 | img_filenames_list.append(file)
82 |
83 |
84 | #print(img_filenames_list)
85 | self.images_dir_path=images_dir_path
86 | self.img_filenames_list=img_filenames_list
87 | self.length = len(img_filenames_list)
88 |
89 |
90 | def __len__(self):
91 | return self.length
92 |
93 |
94 |
95 | def __getitem__(self,idx):
96 |
97 | imgfname = self.img_filenames_list[idx]
98 | imgfname = self.images_dir_path+imgfname
99 |
100 | return imgfname
101 |
102 |
103 |
104 | if __name__ == "__main__":
105 |
106 |
107 | capt_file_path= "path/to/captions.txt" #"G:/Giordano/Flickr8kDataset/captions.txt"
108 | images_dir_path= "path/to/Images" #"G:/Giordano/Flickr8kDataset/Images/"
109 |
110 | dataset=Flickr8kDataset(images_dir_path, capt_file_path)
111 |
112 |
113 | batch_size=1
114 | train_dataloader=DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)
115 |
116 |
117 | for i in train_dataloader:
118 | print(i[0][0])
119 | print(i[1])
120 | break
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |

4 |
5 |
6 |
7 |
8 |
9 |
10 | # Language-Oriented Semantic Latent Representation for Image Transmission
11 |
12 | > [**Language-Oriented Semantic Latent Representation for Image Transmission**](https://arxiv.org/abs/2405.09976)
13 |
14 | > [Giordano Cicchetti](https://www.linkedin.com/in/giordano-cicchetti-1ab73b258/), [Eleonora Grassucci](https://scholar.google.com/citations?user=Jcv0TgQAAAAJ&hl=it),
15 | > [Jihong Park](https://scholar.google.com/citations?user=I0CO72QAAAAJ&hl=en), [Jinho Choi](https://scholar.google.co.uk/citations?user=QzFia5YAAAAJ&hl=en),
16 | > [Sergio Barbarossa](https://scholar.google.com/citations?user=2woHFu8AAAAJ&hl=en), [Danilo Comminiello](https://scholar.google.com/citations?user=H3Y52cMAAAAJ&hl=en)
17 |
18 |
19 | This is the official implementation of the paper: [Language-Oriented Semantic Latent Representation for Image Transmission](https://arxiv.org/abs/2405.09976)
20 |
21 |
22 | [arXiv:2405.09976](https://arxiv.org/abs/2405.09976)
23 | 
24 |
25 |
26 |
27 |
28 |
29 | ## News
30 |
31 |
32 | **June 17, 2024** *Code Released*
33 |
34 | ________________________________
35 |
36 |
37 |
38 | ## Requirements
39 |
40 | Create a dedicated conda environment:
41 | ```
42 | conda create -n SemanticI2I python=3.9
43 | conda activate SemanticI2I
44 |
45 | ```
46 |
47 |
48 | You can clone the repository by typing:
49 | ```
50 |
51 | git clone https://github.com/ispamm/Img2Img-SC.git
52 | cd Img2Img-SC
53 |
54 | ```
55 |
56 | You can update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
57 |
58 | ```
59 | conda install pytorch==1.12.1 torchvision==0.13.1 -c pytorch
60 | pip install transformers==4.19.2 diffusers invisible-watermark
61 | pip install -e .
62 |
63 | ```
64 |
65 | After that you can install the remaining required packages by running:
66 |
67 |
68 | ```
69 |
70 | pip install -r requirements.txt
71 |
72 | ```
73 |
74 | ## Download pretrained models
75 |
76 | Download pretrained checkpoints and copy them into the "/checkpoints" folder.
77 |
78 | ## Img2Img
79 |
80 | The scripts are located in the "/scripts" folder.
81 |
82 | /scripts/semantic_i2i.py implements the proposed I2I framework, which transmits both the latent embedding and the image caption.
83 | /scripts/semantic_t2i.py implements the I2I variant that transmits only the image caption.
84 |
85 | To test the img2img framework, set the checkpoint and configuration paths inside the script files and then run:
86 |
87 | ```
88 | python /scripts/semantic_i2i.py
89 |
90 | #Or
91 |
92 | python /scripts/semantic_t2i.py
93 | ```
94 |
95 |
96 | ## Results
97 |
98 |
99 |
100 |

101 |
102 |
103 |
104 |
105 |
106 |

107 |
108 |
109 |
110 |
111 |
112 |

113 |
114 |
115 |
116 |
117 | ## License
118 |
119 | The code in this repository is released under the MIT License.
120 |
121 | ## Acknowledgment
122 |
123 | Most of the code in this repository is based on the Stable Diffusion repository: https://github.com/Stability-AI/stablediffusion
124 |
125 |
126 | ## BibTeX
127 |
128 | ```
129 | @misc{cicchetti2024languageoriented,
130 | title={Language-Oriented Semantic Latent Representation for Image Transmission},
131 | author={Giordano Cicchetti and Eleonora Grassucci and Jihong Park and Jinho Choi and Sergio Barbarossa and Danilo Comminiello},
132 | year={2024},
133 | eprint={2405.09976},
134 | archivePrefix={arXiv},
135 |       primaryClass={cs.CV}
136 | }
137 | ```
138 |
139 |
140 |
--------------------------------------------------------------------------------
/ldm/modules/karlo/kakao/template.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import logging
8 | import torch
9 |
10 | from omegaconf import OmegaConf
11 |
12 | from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
13 | from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
14 | from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
15 | from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
16 |
17 |
18 | SAMPLING_CONF = {
19 | "default": {
20 | "prior_sm": "25",
21 | "prior_n_samples": 1,
22 | "prior_cf_scale": 4.0,
23 | "decoder_sm": "50",
24 | "decoder_cf_scale": 8.0,
25 | "sr_sm": "7",
26 | },
27 | "fast": {
28 | "prior_sm": "25",
29 | "prior_n_samples": 1,
30 | "prior_cf_scale": 4.0,
31 | "decoder_sm": "25",
32 | "decoder_cf_scale": 8.0,
33 | "sr_sm": "7",
34 | },
35 | }
36 |
37 | CKPT_PATH = {
38 | "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
39 | "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
40 | "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
41 | }
42 |
43 |
44 | class BaseSampler:
45 | _PRIOR_CLASS = PriorDiffusionModel
46 | _DECODER_CLASS = Text2ImProgressiveModel
47 | _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
48 |
49 | def __init__(
50 | self,
51 | root_dir: str,
52 | sampling_type: str = "fast",
53 | ):
54 | self._root_dir = root_dir
55 |
56 | sampling_type = SAMPLING_CONF[sampling_type]
57 | self._prior_sm = sampling_type["prior_sm"]
58 | self._prior_n_samples = sampling_type["prior_n_samples"]
59 | self._prior_cf_scale = sampling_type["prior_cf_scale"]
60 |
61 | assert self._prior_n_samples == 1
62 |
63 | self._decoder_sm = sampling_type["decoder_sm"]
64 | self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
65 |
66 | self._sr_sm = sampling_type["sr_sm"]
67 |
68 | def __repr__(self):
69 | line = ""
70 | line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
71 | line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
72 | line += f"SR(64->256), sampling method: {self._sr_sm}"
73 |
74 | return line
75 |
76 | def load_clip(self, clip_path: str):
77 | clip = CustomizedCLIP.load_from_checkpoint(
78 | os.path.join(self._root_dir, clip_path)
79 | )
80 | clip = torch.jit.script(clip)
81 | clip.cuda()
82 | clip.eval()
83 |
84 | self._clip = clip
85 | self._tokenizer = CustomizedTokenizer()
86 |
87 | def load_prior(
88 | self,
89 | ckpt_path: str,
90 | clip_stat_path: str,
91 | prior_config: str = "configs/prior_1B_vit_l.yaml"
92 | ):
93 | logging.info(f"Loading prior: {ckpt_path}")
94 |
95 | config = OmegaConf.load(prior_config)
96 | clip_mean, clip_std = torch.load(
97 | os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
98 | )
99 |
100 | prior = self._PRIOR_CLASS.load_from_checkpoint(
101 | config,
102 | self._tokenizer,
103 | clip_mean,
104 | clip_std,
105 | os.path.join(self._root_dir, ckpt_path),
106 | strict=True,
107 | )
108 | prior.cuda()
109 | prior.eval()
110 | logging.info("done.")
111 |
112 | self._prior = prior
113 |
114 | def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
115 | logging.info(f"Loading decoder: {ckpt_path}")
116 |
117 | config = OmegaConf.load(decoder_config)
118 | decoder = self._DECODER_CLASS.load_from_checkpoint(
119 | config,
120 | self._tokenizer,
121 | os.path.join(self._root_dir, ckpt_path),
122 | strict=True,
123 | )
124 | decoder.cuda()
125 | decoder.eval()
126 | logging.info("done.")
127 |
128 | self._decoder = decoder
129 |
130 | def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
131 | logging.info(f"Loading SR(64->256): {ckpt_path}")
132 |
133 | config = OmegaConf.load(sr_config)
134 | sr = self._SR256_CLASS.load_from_checkpoint(
135 | config, os.path.join(self._root_dir, ckpt_path), strict=True
136 | )
137 | sr.cuda()
138 | sr.eval()
139 | logging.info("done.")
140 |
141 | self._sr_64_256 = sr
--------------------------------------------------------------------------------
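A hedged sketch of wiring the sampler components above together. The prior/decoder/SR checkpoint names follow CKPT_PATH; the CLIP checkpoint and clip-stat filenames below are placeholders, since they are not listed in this file:

```
sampler = BaseSampler(root_dir="checkpoints/karlo", sampling_type="fast")

sampler.load_clip("ViT-L-14.pt")                                       # placeholder filename
sampler.load_prior(CKPT_PATH["prior"], clip_stat_path="clip_stat.th")  # placeholder stat file
sampler.load_decoder(CKPT_PATH["decoder"])
sampler.load_sr_64_256(CKPT_PATH["sr_256"])

print(sampler)   # prints the sampling configuration selected above
```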
/scripts/qam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import erfc
3 | import random
4 | import torch
5 | import math
6 | import bitstring
7 |
8 |
9 | '''
10 | This file implements 16-QAM modulation and a noisy-channel simulation.
11 |
12 | We use a K-ary discrete memoryless channel (DMC) whose crossover probability is given by
13 | the bit error rate (BER) of M-ary quadrature amplitude modulation (QAM).
14 |
15 | For the DMC, the channel model is given by (9) in https://ieeexplore.ieee.org/abstract/document/10437659.
16 |
17 | For the crossover probability, we assume an AWGN channel where the BER is a Q-function
18 | of the SNR and M: https://www.etti.unibw.de/labalive/experiment/qam/.
19 |
20 | '''
21 |
22 | # Modulate Tensor in 16QAM transmission and noisy channel conditions
23 | def qam16ModulationTensor(input_tensor,snr_db=10):
24 |
25 | message_shape = input_tensor.shape
26 |
27 | message = input_tensor
28 |
29 | #Convert tensor in bitstream
30 | bit_list = tensor2bin(message)
31 |
32 | #Introduce noise to the bitstream according to SNR
33 | bit_list_noisy = introduce_noise(bit_list,snr=snr_db)
34 |
35 | #Convert bitstream back to tensor
36 | back_to_tensor = bin2tensor(bit_list_noisy)
37 |
38 | return back_to_tensor.reshape(message_shape)
39 |
40 |
41 | # Modulate String in 16QAM transmission and noisy channel conditions
42 | def qam16ModulationString(input_tensor,snr_db=10):
43 |
44 | message = input_tensor
45 |
46 | #Convert string to bitstream
47 | bit_list = list2bin(message)
48 |
49 | #Introduce noise to the bitstream according to SNR
50 | bit_list_noisy = introduce_noise(bit_list,snr=snr_db)
51 |
52 | #Convert bitstream back to list of char
53 | back_to_tensor = bin2list(bit_list_noisy)
54 |
55 | return "".join(back_to_tensor)
56 |
57 |
58 |
59 |
60 |
61 | def introduce_noise(bit_list,snr=10,qam=16):
62 |
63 | # Compute ebno according to SNR
64 | ebno = 10 ** (snr/10)
65 |
66 | # Estimate probability of bit error according to https://www.etti.unibw.de/labalive/experiment/qam/
67 | K = np.sqrt(qam) # 4
68 | M = 2 ** K
69 | Pm = (1 - 1 / np.sqrt(M)) * erfc(np.sqrt(3 / 2 / (M - 1) * K * ebno))
70 | Ps_qam = 1 - (1 - Pm) ** 2
71 | Pb_qam = Ps_qam / K
72 |
73 | bit_flipped = 0
74 | bit_tot = 0
75 | new_list = []
76 | for num in bit_list:
77 | num_new = []
78 | for b in num:
79 |
80 | if random.random() < Pb_qam:
81 | num_new.append(str(1 - int(b))) # Flipping the bit
82 | bit_flipped+=1
83 | else:
84 | num_new.append(b)
85 | bit_tot+=1
86 | new_list.append(''.join(num_new))
87 |
88 | #print(bit_flipped/bit_tot)
89 | return new_list
90 |
91 |
92 |
93 |
94 |
95 | def bin2float(b):
96 | ''' Convert binary string to a float.
97 |
98 | Attributes:
99 | :b: Binary string to transform.
100 | '''
101 |
102 | num = bitstring.BitArray(bin=b).float
103 |
104 | #print(num)
105 | if math.isnan(num) or math.isinf(num):
106 |
107 | num = np.random.randn()
108 |
109 |
110 | if num > 10:
111 |
112 | num=np.random.randn()
113 |
114 | if num < -10:
115 |
116 | num=np.random.randn()
117 |
118 | if num < 1e-2 and num>-1e-2:
119 |
120 | num = np.random.randn()
121 |
122 | return num
123 |
124 |
125 | def float2bin(f):
126 | ''' Convert float to 64-bit binary string.
127 |
128 | Attributes:
129 | :f: Float number to transform.
130 | '''
131 |
132 | f1 = bitstring.BitArray(float=f, length=64)
133 | return f1.bin
134 |
135 |
136 | def tensor2bin(tensor):
137 |
138 | tensor_flattened = tensor.view(-1).numpy()
139 |
140 | bit_list = []
141 | for number in tensor_flattened:
142 | bit_list.append(float2bin(number))
143 |
144 |
145 | return bit_list
146 |
147 |
148 | def bin2tensor(input_list):
149 | tensor_reconstructed = [bin2float(bin) for bin in input_list]
150 | return torch.FloatTensor(tensor_reconstructed)
151 |
152 |
153 | def string2int(char):
154 | return ord(char)
155 |
156 |
157 | def int2bin(int_num):
158 | return '{0:08b}'.format(int_num)
159 |
160 | def int2string(int_num):
161 | return chr(int_num)
162 |
163 | def bin2int(bin_num):
164 | return int(bin_num, 2)
165 |
166 |
167 | def list2bin(input_list):
168 |
169 | bit_list = []
170 | for number in input_list:
171 | bit_list.append(int2bin(string2int(number)))
172 |
173 | return bit_list
174 |
175 | def bin2list(input_list):
176 | list_reconstructed = [int2string(bin2int(bin)) for bin in input_list]
177 | return list_reconstructed
178 |
179 |
180 |
181 |
182 |
183 |
184 |
--------------------------------------------------------------------------------
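A hedged usage sketch of the channel simulation above: a latent tensor and a caption are pushed through the simulated 16-QAM/AWGN channel at 10 dB SNR. The tensor shape and the caption are illustrative only:

```
import torch

z = torch.randn(1, 4, 8, 8)                                   # dummy latent
z_rx = qam16ModulationTensor(z, snr_db=10)                    # same shape, bit-flip corrupted
caption_rx = qam16ModulationString("a dog running on the beach", snr_db=10)

print(z_rx.shape)        # torch.Size([1, 4, 8, 8])
print(caption_rx)        # mostly intact at 10 dB, with occasional corrupted characters
```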
/ldm/modules/karlo/kakao/modules/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 |
6 | import torch as th
7 |
8 | from .gaussian_diffusion import GaussianDiffusion
9 |
10 |
11 | def space_timesteps(num_timesteps, section_counts):
12 | """
13 | Create a list of timesteps to use from an original diffusion process,
14 | given the number of timesteps we want to take from equally-sized portions
15 | of the original process.
16 |
17 |     For example, if there are 300 timesteps and the section counts are [10,15,20],
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 |
21 | :param num_timesteps: the number of diffusion steps in the original
22 | process to divide up.
23 | :param section_counts: either a list of numbers, or a string containing
24 | comma-separated numbers, indicating the step count
25 | per section. As a special case, use "ddimN" where N
26 | is a number of steps to use the striding from the
27 | DDIM paper.
28 | :return: a set of diffusion steps from the original process to use.
29 | """
30 | if isinstance(section_counts, str):
31 | if section_counts.startswith("ddim"):
32 | desired_count = int(section_counts[len("ddim") :])
33 | for i in range(1, num_timesteps):
34 | if len(range(0, num_timesteps, i)) == desired_count:
35 | return set(range(0, num_timesteps, i))
36 | raise ValueError(
37 |                 f"cannot create exactly {desired_count} steps with an integer stride"
38 | )
39 | elif section_counts == "fast27":
40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2")
41 | # Help reduce DDIM artifacts from noisiest timesteps.
42 | steps.remove(num_timesteps - 1)
43 | steps.add(num_timesteps - 3)
44 | return steps
45 | section_counts = [int(x) for x in section_counts.split(",")]
46 | size_per = num_timesteps // len(section_counts)
47 | extra = num_timesteps % len(section_counts)
48 | start_idx = 0
49 | all_steps = []
50 | for i, section_count in enumerate(section_counts):
51 | size = size_per + (1 if i < extra else 0)
52 | if size < section_count:
53 | raise ValueError(
54 | f"cannot divide section of {size} steps into {section_count}"
55 | )
56 | if section_count <= 1:
57 | frac_stride = 1
58 | else:
59 | frac_stride = (size - 1) / (section_count - 1)
60 | cur_idx = 0.0
61 | taken_steps = []
62 | for _ in range(section_count):
63 | taken_steps.append(start_idx + round(cur_idx))
64 | cur_idx += frac_stride
65 | all_steps += taken_steps
66 | start_idx += size
67 | return set(all_steps)
68 |
69 |
70 | class SpacedDiffusion(GaussianDiffusion):
71 | """
72 | A diffusion process which can skip steps in a base diffusion process.
73 |
74 | :param use_timesteps: a collection (sequence or set) of timesteps from the
75 | original diffusion process to retain.
76 | :param kwargs: the kwargs to create the base diffusion process.
77 | """
78 |
79 | def __init__(self, use_timesteps, **kwargs):
80 | self.use_timesteps = set(use_timesteps)
81 | self.original_num_steps = len(kwargs["betas"])
82 |
83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
84 | last_alpha_cumprod = 1.0
85 | new_betas = []
86 | timestep_map = []
87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
88 | if i in self.use_timesteps:
89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90 | last_alpha_cumprod = alpha_cumprod
91 | timestep_map.append(i)
92 | kwargs["betas"] = th.tensor(new_betas).numpy()
93 | super().__init__(**kwargs)
94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
95 |
96 | def p_mean_variance(self, model, *args, **kwargs):
97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | def wrapped(x, ts, **kwargs):
107 | ts_cpu = ts.detach().to("cpu")
108 | return model(
109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
110 | )
111 |
112 | return wrapped
113 |
--------------------------------------------------------------------------------
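A worked example of the respacing described in the space_timesteps docstring above (assuming the package is importable from the repository root):

```
from ldm.modules.karlo.kakao.modules.diffusion.respace import space_timesteps

steps = space_timesteps(300, [10, 15, 20])   # 10 + 15 + 20 timesteps from three 100-step sections
print(len(steps))                            # 45

ddim = space_timesteps(1000, "ddim25")       # DDIM-style striding: every 40th original step
print(sorted(ddim)[:3])                      # [0, 40, 80]
```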
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
--------------------------------------------------------------------------------
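A hedged end-to-end sketch of the I/O helpers above; "input.jpg", the model call and the output name are placeholders:

```
img = read_image("input.jpg")            # HxWx3 RGB float array in [0, 1]
batch = resize_image(img)                # 1x3xH'xW' tensor, sides multiples of 32, longest ~384

# depth = some_midas_model(batch)        # placeholder: run a MiDaS model on `batch`
# depth_np = resize_depth(depth, img.shape[1], img.shape[0])
# write_depth("prediction", depth_np, bits=2)   # writes prediction.pfm and 16-bit prediction.png
```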
/ldm/modules/karlo/kakao/models/prior_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
10 | from ldm.modules.karlo.kakao.modules.xf import PriorTransformer
11 |
12 |
13 | class PriorDiffusionModel(torch.nn.Module):
14 | """
15 | A prior that generates clip image feature based on the text prompt.
16 |
17 | :param config: yaml config to define the decoder.
18 | :param tokenizer: tokenizer used in clip.
19 | :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
20 |     :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
21 | """
22 |
23 | def __init__(self, config, tokenizer, clip_mean, clip_std):
24 | super().__init__()
25 |
26 | self._conf = config
27 | self._model_conf = config.model.hparams
28 | self._diffusion_kwargs = dict(
29 | steps=config.diffusion.steps,
30 | learn_sigma=config.diffusion.learn_sigma,
31 | sigma_small=config.diffusion.sigma_small,
32 | noise_schedule=config.diffusion.noise_schedule,
33 | use_kl=config.diffusion.use_kl,
34 | predict_xstart=config.diffusion.predict_xstart,
35 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
36 | timestep_respacing=config.diffusion.timestep_respacing,
37 | )
38 | self._tokenizer = tokenizer
39 |
40 | self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
41 | self.register_buffer("clip_std", clip_std[None, :], persistent=False)
42 |
43 | causal_mask = self.get_causal_mask()
44 | self.register_buffer("causal_mask", causal_mask, persistent=False)
45 |
46 | self.model = PriorTransformer(
47 | text_ctx=self._model_conf.text_ctx,
48 | xf_width=self._model_conf.xf_width,
49 | xf_layers=self._model_conf.xf_layers,
50 | xf_heads=self._model_conf.xf_heads,
51 | xf_final_ln=self._model_conf.xf_final_ln,
52 | clip_dim=self._model_conf.clip_dim,
53 | )
54 |
55 | cf_token, cf_mask = self.set_cf_text_tensor()
56 | self.register_buffer("cf_token", cf_token, persistent=False)
57 | self.register_buffer("cf_mask", cf_mask, persistent=False)
58 |
59 | @classmethod
60 | def load_from_checkpoint(
61 | cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
62 | ):
63 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
64 |
65 | model = cls(config, tokenizer, clip_mean, clip_std)
66 | model.load_state_dict(ckpt, strict=strict)
67 | return model
68 |
69 | def set_cf_text_tensor(self):
70 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
71 |
72 | def get_sample_fn(self, timestep_respacing):
73 | use_ddim = timestep_respacing.startswith(("ddim", "fast"))
74 |
75 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
76 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
77 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
78 | sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
79 |
80 | return sample_fn
81 |
82 | def get_causal_mask(self):
83 | seq_len = self._model_conf.text_ctx + 4
84 | mask = torch.empty(seq_len, seq_len)
85 | mask.fill_(float("-inf"))
86 | mask.triu_(1)
87 | mask = mask[None, ...]
88 | return mask
89 |
90 | def forward(
91 | self,
92 | txt_feat,
93 | txt_feat_seq,
94 | mask,
95 | cf_guidance_scales=None,
96 | timestep_respacing=None,
97 | denoised_fn=True,
98 | ):
99 | # cfg should be enabled in inference
100 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
101 |
102 | bsz_ = txt_feat.shape[0]
103 | bsz = bsz_ // 2
104 |
105 | def guided_model_fn(x_t, ts, **kwargs):
106 | half = x_t[: len(x_t) // 2]
107 | combined = torch.cat([half, half], dim=0)
108 | model_out = self.model(combined, ts, **kwargs)
109 | eps, rest = (
110 | model_out[:, : int(x_t.shape[1])],
111 | model_out[:, int(x_t.shape[1]) :],
112 | )
113 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
114 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
115 | cond_eps - uncond_eps
116 | )
117 | eps = torch.cat([half_eps, half_eps], dim=0)
118 | return torch.cat([eps, rest], dim=1)
119 |
120 | cond = {
121 | "text_emb": txt_feat,
122 | "text_enc": txt_feat_seq,
123 | "mask": mask,
124 | "causal_mask": self.causal_mask,
125 | }
126 | sample_fn = self.get_sample_fn(timestep_respacing)
127 | sample = sample_fn(
128 | guided_model_fn,
129 | (bsz_, self.model.clip_dim),
130 | noise=None,
131 | device=txt_feat.device,
132 | clip_denoised=False,
133 | denoised_fn=lambda x: torch.clamp(x, -10, 10),
134 | model_kwargs=cond,
135 | )
136 | sample = (sample * self.clip_std) + self.clip_mean
137 |
138 | return sample[:bsz]
139 |
--------------------------------------------------------------------------------
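The guided_model_fn above applies classifier-free guidance, eps_hat = eps_uncond + s * (eps_cond - eps_uncond), on a duplicated half-batch. A tiny numeric sketch of just that combination, using dummy tensors and s = 4.0 as in the "fast"/"default" configs of kakao/template.py:

```
import torch

cond_eps = torch.full((1, 4), 1.0)     # epsilon predicted with the text condition
uncond_eps = torch.full((1, 4), 0.5)   # epsilon predicted with the empty-text condition
s = torch.tensor([4.0])                # cf_guidance_scales

eps_hat = uncond_eps + s.view(-1, 1) * (cond_eps - uncond_eps)
print(eps_hat)                         # tensor([[2.5000, 2.5000, 2.5000, 2.5000]])
```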
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 |             x = x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
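A hedged usage sketch for the MiDaSInference wrapper above; it assumes the dpt_hybrid checkpoint referenced in ISL_PATHS is available locally and that the input batch has already been normalized by the matching MiDaS transform:

    import torch
    from ldm.modules.midas.api import MiDaSInference

    midas = MiDaSInference("dpt_hybrid")   # loads midas_models/dpt_hybrid-midas-501f0c75.pt
    x = torch.randn(1, 3, 384, 384)        # stand-in for a transformed RGB batch
    depth = midas(x)                       # (1, 1, 384, 384) relative depth map
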
/ldm/modules/karlo/kakao/models/clip.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 | # ------------------------------------------------------------------------------------
6 | # Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
7 | # ------------------------------------------------------------------------------------
8 |
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import clip
14 |
15 | from clip.model import CLIP, convert_weights
16 | from clip.simple_tokenizer import SimpleTokenizer, default_bpe
17 |
18 |
19 | """===== Monkey-Patching original CLIP for JIT compile ====="""
20 |
21 |
22 | class LayerNorm(nn.LayerNorm):
23 | """Subclass torch's LayerNorm to handle fp16."""
24 |
25 | def forward(self, x: torch.Tensor):
26 | orig_type = x.dtype
27 | ret = F.layer_norm(
28 | x.type(torch.float32),
29 | self.normalized_shape,
30 | self.weight,
31 | self.bias,
32 | self.eps,
33 | )
34 | return ret.type(orig_type)
35 |
36 |
37 | clip.model.LayerNorm = LayerNorm
38 | delattr(clip.model.CLIP, "forward")
39 |
40 | """===== End of Monkey-Patching ====="""
41 |
42 |
43 | class CustomizedCLIP(CLIP):
44 | def __init__(self, *args, **kwargs):
45 | super().__init__(*args, **kwargs)
46 |
47 | @torch.jit.export
48 | def encode_image(self, image):
49 | return self.visual(image)
50 |
51 | @torch.jit.export
52 | def encode_text(self, text):
53 | # re-define this function to return unpooled text features
54 |
55 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
56 |
57 | x = x + self.positional_embedding.type(self.dtype)
58 | x = x.permute(1, 0, 2) # NLD -> LND
59 | x = self.transformer(x)
60 | x = x.permute(1, 0, 2) # LND -> NLD
61 | x = self.ln_final(x).type(self.dtype)
62 |
63 | x_seq = x
64 | # x.shape = [batch_size, n_ctx, transformer.width]
65 | # take features from the eot embedding (eot_token is the highest number in each sequence)
66 | x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
67 |
68 | return x_out, x_seq
69 |
70 | @torch.jit.ignore
71 | def forward(self, image, text):
72 | super().forward(image, text)
73 |
74 | @classmethod
75 | def load_from_checkpoint(cls, ckpt_path: str):
76 | state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()
77 |
78 | vit = "visual.proj" in state_dict
79 | if vit:
80 | vision_width = state_dict["visual.conv1.weight"].shape[0]
81 | vision_layers = len(
82 | [
83 | k
84 | for k in state_dict.keys()
85 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
86 | ]
87 | )
88 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
89 | grid_size = round(
90 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
91 | )
92 | image_resolution = vision_patch_size * grid_size
93 | else:
94 | counts: list = [
95 | len(
96 | set(
97 | k.split(".")[2]
98 | for k in state_dict
99 | if k.startswith(f"visual.layer{b}")
100 | )
101 | )
102 | for b in [1, 2, 3, 4]
103 | ]
104 | vision_layers = tuple(counts)
105 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
106 | output_width = round(
107 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
108 | )
109 | vision_patch_size = None
110 | assert (
111 | output_width**2 + 1
112 | == state_dict["visual.attnpool.positional_embedding"].shape[0]
113 | )
114 | image_resolution = output_width * 32
115 |
116 | embed_dim = state_dict["text_projection"].shape[1]
117 | context_length = state_dict["positional_embedding"].shape[0]
118 | vocab_size = state_dict["token_embedding.weight"].shape[0]
119 | transformer_width = state_dict["ln_final.weight"].shape[0]
120 | transformer_heads = transformer_width // 64
121 | transformer_layers = len(
122 | set(
123 | k.split(".")[2]
124 | for k in state_dict
125 | if k.startswith("transformer.resblocks")
126 | )
127 | )
128 |
129 | model = cls(
130 | embed_dim,
131 | image_resolution,
132 | vision_layers,
133 | vision_width,
134 | vision_patch_size,
135 | context_length,
136 | vocab_size,
137 | transformer_width,
138 | transformer_heads,
139 | transformer_layers,
140 | )
141 |
142 | for key in ["input_resolution", "context_length", "vocab_size"]:
143 | if key in state_dict:
144 | del state_dict[key]
145 |
146 | convert_weights(model)
147 | model.load_state_dict(state_dict)
148 | model.eval()
149 | model.float()
150 | return model
151 |
152 |
153 | class CustomizedTokenizer(SimpleTokenizer):
154 | def __init__(self):
155 | super().__init__(bpe_path=default_bpe())
156 |
157 | self.sot_token = self.encoder["<|startoftext|>"]
158 | self.eot_token = self.encoder["<|endoftext|>"]
159 |
160 | def padded_tokens_and_mask(self, texts, text_ctx):
161 | assert isinstance(texts, list) and all(
162 | isinstance(elem, str) for elem in texts
163 | ), "texts should be a list of strings"
164 |
165 | all_tokens = [
166 | [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
167 | ]
168 |
169 | mask = [
170 | [True] * min(text_ctx, len(tokens))
171 | + [False] * max(text_ctx - len(tokens), 0)
172 | for tokens in all_tokens
173 | ]
174 | mask = torch.tensor(mask, dtype=torch.bool)
175 | result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
176 | for i, tokens in enumerate(all_tokens):
177 | if len(tokens) > text_ctx:
178 | tokens = tokens[:text_ctx]
179 | tokens[-1] = self.eot_token
180 | result[i, : len(tokens)] = torch.tensor(tokens)
181 |
182 | return result, mask
183 |
--------------------------------------------------------------------------------
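The tokenizer contract above (fixed-length int tokens plus a boolean validity mask) can be exercised without any checkpoint; this sketch only assumes that the OpenAI clip package is installed, since importing the module applies the monkey patches:

    from ldm.modules.karlo.kakao.models.clip import CustomizedTokenizer

    tokenizer = CustomizedTokenizer()
    tokens, mask = tokenizer.padded_tokens_and_mask(["a photo of a corgi"], text_ctx=77)
    # tokens: (1, 77) int tensor, zero-padded after <|endoftext|>; mask: (1, 77) bool tensor
    print(tokens.shape, mask.shape, mask[0].sum())
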
/ldm/modules/karlo/kakao/models/decoder_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
10 | from ldm.modules.karlo.kakao.modules.unet import PLMImUNet
11 |
12 |
13 | class Text2ImProgressiveModel(torch.nn.Module):
14 | """
15 | A decoder that generates 64x64px images based on the text prompt.
16 |
17 | :param config: yaml config to define the decoder.
18 | :param tokenizer: tokenizer used in clip.
19 | """
20 |
21 | def __init__(
22 | self,
23 | config,
24 | tokenizer,
25 | ):
26 | super().__init__()
27 |
28 | self._conf = config
29 | self._model_conf = config.model.hparams
30 | self._diffusion_kwargs = dict(
31 | steps=config.diffusion.steps,
32 | learn_sigma=config.diffusion.learn_sigma,
33 | sigma_small=config.diffusion.sigma_small,
34 | noise_schedule=config.diffusion.noise_schedule,
35 | use_kl=config.diffusion.use_kl,
36 | predict_xstart=config.diffusion.predict_xstart,
37 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
38 | timestep_respacing=config.diffusion.timestep_respacing,
39 | )
40 | self._tokenizer = tokenizer
41 |
42 | self.model = self.create_plm_dec_model()
43 |
44 | cf_token, cf_mask = self.set_cf_text_tensor()
45 | self.register_buffer("cf_token", cf_token, persistent=False)
46 | self.register_buffer("cf_mask", cf_mask, persistent=False)
47 |
48 | @classmethod
49 | def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
50 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
51 |
52 | model = cls(config, tokenizer)
53 | model.load_state_dict(ckpt, strict=strict)
54 | return model
55 |
56 | def create_plm_dec_model(self):
57 | image_size = self._model_conf.image_size
58 | if self._model_conf.channel_mult == "":
59 | if image_size == 256:
60 | channel_mult = (1, 1, 2, 2, 4, 4)
61 | elif image_size == 128:
62 | channel_mult = (1, 1, 2, 3, 4)
63 | elif image_size == 64:
64 | channel_mult = (1, 2, 3, 4)
65 | else:
66 | raise ValueError(f"unsupported image size: {image_size}")
67 | else:
68 | channel_mult = tuple(
69 | int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
70 | )
71 | assert 2 ** (len(channel_mult) + 2) == image_size
72 |
73 | attention_ds = []
74 | for res in self._model_conf.attention_resolutions.split(","):
75 | attention_ds.append(image_size // int(res))
76 |
77 | return PLMImUNet(
78 | text_ctx=self._model_conf.text_ctx,
79 | xf_width=self._model_conf.xf_width,
80 | in_channels=3,
81 | model_channels=self._model_conf.num_channels,
82 | out_channels=6 if self._model_conf.learn_sigma else 3,
83 | num_res_blocks=self._model_conf.num_res_blocks,
84 | attention_resolutions=tuple(attention_ds),
85 | dropout=self._model_conf.dropout,
86 | channel_mult=channel_mult,
87 | num_heads=self._model_conf.num_heads,
88 | num_head_channels=self._model_conf.num_head_channels,
89 | num_heads_upsample=self._model_conf.num_heads_upsample,
90 | use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
91 | resblock_updown=self._model_conf.resblock_updown,
92 | clip_dim=self._model_conf.clip_dim,
93 | clip_emb_mult=self._model_conf.clip_emb_mult,
94 | clip_emb_type=self._model_conf.clip_emb_type,
95 | clip_emb_drop=self._model_conf.clip_emb_drop,
96 | )
97 |
98 | def set_cf_text_tensor(self):
99 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
100 |
101 | def get_sample_fn(self, timestep_respacing):
102 | use_ddim = timestep_respacing.startswith(("ddim", "fast"))
103 |
104 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
105 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
106 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
107 | sample_fn = (
108 | diffusion.ddim_sample_loop_progressive
109 | if use_ddim
110 | else diffusion.p_sample_loop_progressive
111 | )
112 |
113 | return sample_fn
114 |
115 | def forward(
116 | self,
117 | txt_feat,
118 | txt_feat_seq,
119 | tok,
120 | mask,
121 | img_feat=None,
122 | cf_guidance_scales=None,
123 | timestep_respacing=None,
124 | ):
125 |         # classifier-free guidance must be enabled at inference time
126 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
127 | assert img_feat is not None
128 |
129 | bsz = txt_feat.shape[0]
130 | img_sz = self._model_conf.image_size
131 |
132 | def guided_model_fn(x_t, ts, **kwargs):
133 | half = x_t[: len(x_t) // 2]
134 | combined = torch.cat([half, half], dim=0)
135 | model_out = self.model(combined, ts, **kwargs)
136 | eps, rest = model_out[:, :3], model_out[:, 3:]
137 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
138 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
139 | cond_eps - uncond_eps
140 | )
141 | eps = torch.cat([half_eps, half_eps], dim=0)
142 | return torch.cat([eps, rest], dim=1)
143 |
144 | cf_feat = self.model.cf_param.unsqueeze(0)
145 | cf_feat = cf_feat.expand(bsz // 2, -1)
146 | feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
147 |
148 | cond = {
149 | "y": feat,
150 | "txt_feat": txt_feat,
151 | "txt_feat_seq": txt_feat_seq,
152 | "mask": mask,
153 | }
154 | sample_fn = self.get_sample_fn(timestep_respacing)
155 | sample_outputs = sample_fn(
156 | guided_model_fn,
157 | (bsz, 3, img_sz, img_sz),
158 | noise=None,
159 | device=txt_feat.device,
160 | clip_denoised=True,
161 | model_kwargs=cond,
162 | )
163 |
164 | for out in sample_outputs:
165 | sample = out["sample"]
166 | yield sample if cf_guidance_scales is None else sample[
167 | : sample.shape[0] // 2
168 | ]
169 |
170 |
171 | class Text2ImModel(Text2ImProgressiveModel):
172 | def forward(
173 | self,
174 | txt_feat,
175 | txt_feat_seq,
176 | tok,
177 | mask,
178 | img_feat=None,
179 | cf_guidance_scales=None,
180 | timestep_respacing=None,
181 | ):
182 | last_out = None
183 | for out in super().forward(
184 | txt_feat,
185 | txt_feat_seq,
186 | tok,
187 | mask,
188 | img_feat,
189 | cf_guidance_scales,
190 | timestep_respacing,
191 | ):
192 | last_out = out
193 | return last_out
194 |
--------------------------------------------------------------------------------
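The core of guided_model_fn above is the standard classifier-free guidance combination of conditional and unconditional noise predictions. A self-contained sketch on dummy tensors (shapes and the guidance scale are illustrative only):

    import torch

    cond_eps = torch.randn(2, 3, 64, 64)    # eps predicted with the real conditioning
    uncond_eps = torch.randn(2, 3, 64, 64)  # eps predicted with the null ("cf") conditioning
    scales = torch.tensor([3.0, 3.0])       # per-sample cf_guidance_scales

    guided_eps = uncond_eps + scales.view(-1, 1, 1, 1) * (cond_eps - uncond_eps)
    print(guided_eps.shape)                 # torch.Size([2, 3, 64, 64])
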
/scripts/semantic_t2i.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import time
3 | 
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms as transforms
7 | from torch.utils.data import DataLoader
8 | from torchvision.transforms import Resize, Compose
9 | 
10 | import PIL
11 | from PIL import Image
12 | 
13 | from omegaconf import OmegaConf
14 | from pytorch_lightning import seed_everything
15 | from tqdm import tqdm
16 | 
17 | import lpips as lp
18 | from SSIM_PIL import compare_ssim
19 | 
20 | from transformers import pipeline
21 | from diffusers import StableDiffusionPipeline
22 | 
23 | from ldm.util import instantiate_from_config
24 | from scripts.dataset import Only_images_Flickr8kDataset
25 | from scripts.qam import qam16ModulationString
26 | 
27 | 
28 | 
29 | 
30 | 
31 |
32 |
33 | '''
34 |
35 | INIT DATASET AND DATALOADER
36 |
37 | '''
38 | capt_file_path = "path/to/captions.txt" #"G:/Giordano/Flickr8kDataset/captions.txt"
39 | images_dir_path = "path/to/Images" #"G:/Giordano/Flickr8kDataset/Images/"
40 | batch_size = 1
41 |
42 | dataset = Only_images_Flickr8kDataset(images_dir_path)
43 |
44 | test_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
45 |
46 |
47 | '''
48 | MODEL CHECKPOINT
49 |
50 | '''
51 |
52 |
53 | model_ckpt_path = "path/to/model-checkpoint" #"G:/Giordano/stablediffusion/checkpoints/v1-5-pruned.ckpt" #v2-1_512-ema-pruned.ckpt"
54 | config_path = "path/to/model-config" #"G:/Giordano/stablediffusion/configs/stable-diffusion/v1-inference.yaml"
55 |
56 |
57 |
58 | def load_model_from_config(config, ckpt, verbose=False):
59 | print(f"Loading model from {ckpt}")
60 | pl_sd = torch.load(ckpt, map_location="cpu")
61 | if "global_step" in pl_sd:
62 | print(f"Global Step: {pl_sd['global_step']}")
63 | sd = pl_sd["state_dict"]
64 | model = instantiate_from_config(config.model)
65 | m, u = model.load_state_dict(sd, strict=False)
66 | if len(m) > 0 and verbose:
67 | print("missing keys:")
68 | print(m)
69 | if len(u) > 0 and verbose:
70 | print("unexpected keys:")
71 | print(u)
72 |
73 | model.cuda()
74 | model.eval()
75 | return model
76 |
77 |
78 | def load_img(path):
79 | image = Image.open(path).convert("RGB")
80 |     w, h = (512, 512)  # fixed working size instead of image.size
81 | #print(f"loaded input image of size ({w}, {h}) from {path}")
82 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
83 | image = image.resize((w, h), resample=PIL.Image.LANCZOS)
84 | image = np.array(image).astype(np.float32) / 255.0
85 | image = image[None].transpose(0, 3, 1, 2)
86 | image = torch.from_numpy(image)
87 | return 2. * image - 1.
88 |
89 |
90 |
91 |
92 | def test(dataloader,
93 | snr=10,
94 | num_images=100,
95 | sampling_steps = 50,
96 | outpath="outpath"
97 | ):
98 |
99 | blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
100 |
101 | model_id = "runwayml/stable-diffusion-v1-5"
102 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
103 | pipe = pipe.to("cuda")
104 |
105 | transform = Compose([Resize((512,512), antialias=True), transforms.PILToTensor() ])
106 |
107 | lpips = lp.LPIPS(net='alex')
108 |
109 |
110 | sample_path = os.path.join(outpath,f'Test-TEXTONLY-sample-{snr}-{sampling_steps}')
111 |
112 |
113 | os.makedirs(sample_path, exist_ok=True)
114 |
115 | sample_orig_path = os.path.join(outpath,f'Test-TEXTONLY-sample-orig-{snr}-{sampling_steps}')
116 |
117 | os.makedirs(sample_orig_path, exist_ok=True)
118 |
119 | text_path = os.path.join(outpath,f'Test-TEXTONLY-text-{snr}-{sampling_steps}')
120 |
121 | os.makedirs(text_path, exist_ok=True)
122 |
123 |
124 | lpips_values = []
125 | time_values = []
126 | ssim_values = []
127 |
128 | i=0
129 |
130 |
131 |     for batch in tqdm(dataloader, total=num_images):
132 | 
133 |         img_file_path = batch[0]
134 | 
135 |         # Open the original image
136 |         init_image = Image.open(img_file_path)
137 | 
138 |         # Automatically extract a caption with the BLIP model
139 |         prompt_blip = blip(init_image)[0]["generated_text"]
140 | 
141 |         # Save the caption for the CLIP-score computation
142 |         with open(os.path.join(text_path, f"{i}.txt"), "a") as f:
143 |             f.write(prompt_blip)
144 | 
145 | 
146 | 
147 |         # Introduce noise into the text (i.e. simulate a noisy channel)
148 |         prompt_corrupted = qam16ModulationString(prompt_blip, snr)
149 | 
150 |         # Time the image reconstruction
151 |         time_start = time.time()
152 |         # Reconstruct the image from the noisy text caption
153 |         image_generated = pipe(prompt_corrupted, num_inference_steps=sampling_steps).images[0]
154 |         time_finish = time.time()
155 | 
156 |         time_elapsed = time_finish - time_start
157 |         time_values.append(time_elapsed)
158 | 
159 |         # Save images for the subsequent FID and CLIP-score computation
160 |         image_generated.save(os.path.join(sample_path, f'{i}.png'))
161 |         init_image.save(os.path.join(sample_orig_path, f'{i}.png'))
162 | 
163 |         # Compute SSIM
164 |         init_image_copy = init_image.resize((512, 512), resample=PIL.Image.LANCZOS)
165 |         ssim_values.append(compare_ssim(init_image_copy, image_generated))
166 | 
167 |         # Compute LPIPS on [-1, 1] tensors
168 |         image_generated_t = (transform(image_generated) / 255) * 2 - 1
169 |         init_image_t = (transform(init_image) / 255) * 2 - 1
170 |         lp_score = lpips(init_image_t.cpu(), image_generated_t.cpu()).item()
171 |         lpips_values.append(lp_score)
172 | 
173 |         i += 1
174 |         if i == num_images:
175 |             break
176 |
177 | print(f'mean lpips score: {sum(lpips_values)/len(lpips_values)}')
178 |
179 | print(f'mean ssim score: {sum(ssim_values)/len(ssim_values)}')
180 |
181 | print(f'mean time score: {sum(time_values)/len(time_values)}')
182 |
183 | if __name__ == "__main__":
184 |
185 | parser = argparse.ArgumentParser()
186 |
187 | parser.add_argument(
188 | "--outdir",
189 | type=str,
190 | nargs="?",
191 | help="dir to write results to",
192 | default="outputs/img2img-samples"
193 | )
194 |
195 | parser.add_argument(
196 | "--seed",
197 | type=int,
198 | default=42,
199 | help="the seed (for reproducible sampling)",
200 | )
201 |
202 |
203 | opt = parser.parse_args()
204 | seed_everything(opt.seed)
205 |
206 |
207 | os.makedirs(opt.outdir, exist_ok=True)
208 | outpath = opt.outdir
209 |
210 |
211 |     # Start testing at the different channel SNR values
212 | 
213 |     test(test_dataloader, snr=10, num_images=100, outpath=outpath)
214 | 
215 |     test(test_dataloader, snr=8.75, num_images=100, outpath=outpath)
216 | 
217 |     test(test_dataloader, snr=7.50, num_images=100, outpath=outpath)
218 | 
219 |     test(test_dataloader, snr=6.25, num_images=100, outpath=outpath)
220 | 
221 |     test(test_dataloader, snr=5, num_images=100, outpath=outpath)
222 |
223 |
--------------------------------------------------------------------------------
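The LPIPS step in the loop above maps PIL images to [-1, 1] tensors before scoring. A small, self-contained sketch of that preprocessing and scoring on random images (no dataset or checkpoint needed; the score itself is meaningless here, only the shapes and value ranges matter):

    import numpy as np
    import lpips as lp
    from PIL import Image
    import torchvision.transforms as transforms
    from torchvision.transforms import Resize, Compose

    transform = Compose([Resize((512, 512), antialias=True), transforms.PILToTensor()])
    lpips = lp.LPIPS(net='alex')

    img_a = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
    img_b = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))

    a = ((transform(img_a) / 255) * 2 - 1).unsqueeze(0)  # uint8 CHW -> float NCHW in [-1, 1]
    b = ((transform(img_b) / 255) * 2 - 1).unsqueeze(0)
    print(lpips(a, b).item())
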
/ldm/modules/karlo/kakao/modules/xf.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from the repos below:
3 | # (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
4 | # (b) CLIP ViT (https://github.com/openai/CLIP/)
5 | # ------------------------------------------------------------------------------------
6 |
7 | import math
8 |
9 | import torch as th
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .nn import timestep_embedding
14 |
15 |
16 | def convert_module_to_f16(param):
17 | """
18 | Convert primitive modules to float16.
19 | """
20 | if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
21 | param.weight.data = param.weight.data.half()
22 | if param.bias is not None:
23 | param.bias.data = param.bias.data.half()
24 |
25 |
26 | class LayerNorm(nn.LayerNorm):
27 | """
28 | Implementation that supports fp16 inputs but fp32 gains/biases.
29 | """
30 |
31 | def forward(self, x: th.Tensor):
32 | return super().forward(x.float()).to(x.dtype)
33 |
34 |
35 | class MultiheadAttention(nn.Module):
36 | def __init__(self, n_ctx, width, heads):
37 | super().__init__()
38 | self.n_ctx = n_ctx
39 | self.width = width
40 | self.heads = heads
41 | self.c_qkv = nn.Linear(width, width * 3)
42 | self.c_proj = nn.Linear(width, width)
43 | self.attention = QKVMultiheadAttention(heads, n_ctx)
44 |
45 | def forward(self, x, mask=None):
46 | x = self.c_qkv(x)
47 | x = self.attention(x, mask=mask)
48 | x = self.c_proj(x)
49 | return x
50 |
51 |
52 | class MLP(nn.Module):
53 | def __init__(self, width):
54 | super().__init__()
55 | self.width = width
56 | self.c_fc = nn.Linear(width, width * 4)
57 | self.c_proj = nn.Linear(width * 4, width)
58 | self.gelu = nn.GELU()
59 |
60 | def forward(self, x):
61 | return self.c_proj(self.gelu(self.c_fc(x)))
62 |
63 |
64 | class QKVMultiheadAttention(nn.Module):
65 | def __init__(self, n_heads: int, n_ctx: int):
66 | super().__init__()
67 | self.n_heads = n_heads
68 | self.n_ctx = n_ctx
69 |
70 | def forward(self, qkv, mask=None):
71 | bs, n_ctx, width = qkv.shape
72 | attn_ch = width // self.n_heads // 3
73 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
74 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
75 | q, k, v = th.split(qkv, attn_ch, dim=-1)
76 | weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
77 | wdtype = weight.dtype
78 | if mask is not None:
79 | weight = weight + mask[:, None, ...]
80 | weight = th.softmax(weight, dim=-1).type(wdtype)
81 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
82 |
83 |
84 | class ResidualAttentionBlock(nn.Module):
85 | def __init__(
86 | self,
87 | n_ctx: int,
88 | width: int,
89 | heads: int,
90 | ):
91 | super().__init__()
92 |
93 | self.attn = MultiheadAttention(
94 | n_ctx,
95 | width,
96 | heads,
97 | )
98 | self.ln_1 = LayerNorm(width)
99 | self.mlp = MLP(width)
100 | self.ln_2 = LayerNorm(width)
101 |
102 | def forward(self, x, mask=None):
103 | x = x + self.attn(self.ln_1(x), mask=mask)
104 | x = x + self.mlp(self.ln_2(x))
105 | return x
106 |
107 |
108 | class Transformer(nn.Module):
109 | def __init__(
110 | self,
111 | n_ctx: int,
112 | width: int,
113 | layers: int,
114 | heads: int,
115 | ):
116 | super().__init__()
117 | self.n_ctx = n_ctx
118 | self.width = width
119 | self.layers = layers
120 | self.resblocks = nn.ModuleList(
121 | [
122 | ResidualAttentionBlock(
123 | n_ctx,
124 | width,
125 | heads,
126 | )
127 | for _ in range(layers)
128 | ]
129 | )
130 |
131 | def forward(self, x, mask=None):
132 | for block in self.resblocks:
133 | x = block(x, mask=mask)
134 | return x
135 |
136 |
137 | class PriorTransformer(nn.Module):
138 | """
139 |     A causal transformer that conditions on the pooled CLIP text embedding and the per-token text encodings.
140 |
141 | :param text_ctx: number of text tokens to expect.
142 | :param xf_width: width of the transformer.
143 | :param xf_layers: depth of the transformer.
144 | :param xf_heads: heads in the transformer.
145 | :param xf_final_ln: use a LayerNorm after the output layer.
146 | :param clip_dim: dimension of clip feature.
147 | """
148 |
149 | def __init__(
150 | self,
151 | text_ctx,
152 | xf_width,
153 | xf_layers,
154 | xf_heads,
155 | xf_final_ln,
156 | clip_dim,
157 | ):
158 | super().__init__()
159 |
160 | self.text_ctx = text_ctx
161 | self.xf_width = xf_width
162 | self.xf_layers = xf_layers
163 | self.xf_heads = xf_heads
164 | self.clip_dim = clip_dim
165 | self.ext_len = 4
166 |
167 | self.time_embed = nn.Sequential(
168 | nn.Linear(xf_width, xf_width),
169 | nn.SiLU(),
170 | nn.Linear(xf_width, xf_width),
171 | )
172 | self.text_enc_proj = nn.Linear(clip_dim, xf_width)
173 | self.text_emb_proj = nn.Linear(clip_dim, xf_width)
174 | self.clip_img_proj = nn.Linear(clip_dim, xf_width)
175 | self.out_proj = nn.Linear(xf_width, clip_dim)
176 | self.transformer = Transformer(
177 | text_ctx + self.ext_len,
178 | xf_width,
179 | xf_layers,
180 | xf_heads,
181 | )
182 | if xf_final_ln:
183 | self.final_ln = LayerNorm(xf_width)
184 | else:
185 | self.final_ln = None
186 |
187 | self.positional_embedding = nn.Parameter(
188 | th.empty(1, text_ctx + self.ext_len, xf_width)
189 | )
190 | self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
191 |
192 | nn.init.normal_(self.prd_emb, std=0.01)
193 | nn.init.normal_(self.positional_embedding, std=0.01)
194 |
195 | def forward(
196 | self,
197 | x,
198 | timesteps,
199 | text_emb=None,
200 | text_enc=None,
201 | mask=None,
202 | causal_mask=None,
203 | ):
204 | bsz = x.shape[0]
205 | mask = F.pad(mask, (0, self.ext_len), value=True)
206 |
207 | t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
208 | text_enc = self.text_enc_proj(text_enc)
209 | text_emb = self.text_emb_proj(text_emb)
210 | x = self.clip_img_proj(x)
211 |
212 | input_seq = [
213 | text_enc,
214 | text_emb[:, None, :],
215 | t_emb[:, None, :],
216 | x[:, None, :],
217 | self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
218 | ]
219 | input = th.cat(input_seq, dim=1)
220 | input = input + self.positional_embedding.to(input.dtype)
221 |
222 | mask = th.where(mask, 0.0, float("-inf"))
223 | mask = (mask[:, None, :] + causal_mask).to(input.dtype)
224 |
225 | out = self.transformer(input, mask=mask)
226 | if self.final_ln is not None:
227 | out = self.final_ln(out)
228 |
229 | out = self.out_proj(out[:, -1])
230 |
231 | return out
232 |
--------------------------------------------------------------------------------
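A shape-only sketch of QKVMultiheadAttention above, with dummy sizes, to make the two einsum contractions concrete (all numbers are illustrative):

    import torch as th

    bs, n_ctx, heads, width = 2, 8, 4, 64
    attn_ch = width // heads                      # 16 channels per head

    qkv = th.randn(bs, n_ctx, 3 * width)          # what c_qkv would produce
    qkv = qkv.view(bs, n_ctx, heads, -1)          # (2, 8, 4, 48)
    q, k, v = th.split(qkv, attn_ch, dim=-1)      # each (2, 8, 4, 16)

    weight = th.einsum("bthc,bshc->bhts", q, k)   # (2, 4, 8, 8) attention logits
    out = th.einsum("bhts,bshc->bthc", weight.softmax(dim=-1), v).reshape(bs, n_ctx, -1)
    print(out.shape)                              # torch.Size([2, 8, 64])
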
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def autocast(f):
12 | def do_autocast(*args, **kwargs):
13 | with torch.cuda.amp.autocast(enabled=True,
14 | dtype=torch.get_autocast_gpu_dtype(),
15 | cache_enabled=torch.is_autocast_cache_enabled()):
16 | return f(*args, **kwargs)
17 |
18 | return do_autocast
19 |
20 |
21 | def log_txt_as_img(wh, xc, size=10):
22 | # wh a tuple of (width, height)
23 | # xc a list of captions to plot
24 | b = len(xc)
25 | txts = list()
26 | for bi in range(b):
27 | txt = Image.new("RGB", wh, color="white")
28 | draw = ImageDraw.Draw(txt)
29 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
30 | nc = int(40 * (wh[0] / 256))
31 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
32 |
33 | try:
34 | draw.text((0, 0), lines, fill="black", font=font)
35 | except UnicodeEncodeError:
36 | print("Cant encode string for logging. Skipping.")
37 |
38 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
39 | txts.append(txt)
40 | txts = np.stack(txts)
41 | txts = torch.tensor(txts)
42 | return txts
43 |
44 |
45 | def ismap(x):
46 | if not isinstance(x, torch.Tensor):
47 | return False
48 | return (len(x.shape) == 4) and (x.shape[1] > 3)
49 |
50 |
51 | def isimage(x):
52 | if not isinstance(x,torch.Tensor):
53 | return False
54 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
55 |
56 |
57 | def exists(x):
58 | return x is not None
59 |
60 |
61 | def default(val, d):
62 | if exists(val):
63 | return val
64 | return d() if isfunction(d) else d
65 |
66 |
67 | def mean_flat(tensor):
68 | """
69 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
70 | Take the mean over all non-batch dimensions.
71 | """
72 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
73 |
74 |
75 | def count_params(model, verbose=False):
76 | total_params = sum(p.numel() for p in model.parameters())
77 | if verbose:
78 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
79 | return total_params
80 |
81 |
82 | def instantiate_from_config(config):
83 |     if "target" not in config:
84 | if config == '__is_first_stage__':
85 | return None
86 | elif config == "__is_unconditional__":
87 | return None
88 | raise KeyError("Expected key `target` to instantiate.")
89 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
90 |
91 |
92 | def get_obj_from_str(string, reload=False):
93 | module, cls = string.rsplit(".", 1)
94 | if reload:
95 | module_imp = importlib.import_module(module)
96 | importlib.reload(module_imp)
97 | return getattr(importlib.import_module(module, package=None), cls)
98 |
99 |
100 | class AdamWwithEMAandWings(optim.Optimizer):
101 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
102 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
103 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
104 | ema_power=1., param_names=()):
105 | """AdamW that saves EMA versions of the parameters."""
106 | if not 0.0 <= lr:
107 | raise ValueError("Invalid learning rate: {}".format(lr))
108 | if not 0.0 <= eps:
109 | raise ValueError("Invalid epsilon value: {}".format(eps))
110 | if not 0.0 <= betas[0] < 1.0:
111 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
112 | if not 0.0 <= betas[1] < 1.0:
113 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
114 | if not 0.0 <= weight_decay:
115 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
116 | if not 0.0 <= ema_decay <= 1.0:
117 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
118 | defaults = dict(lr=lr, betas=betas, eps=eps,
119 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
120 | ema_power=ema_power, param_names=param_names)
121 | super().__init__(params, defaults)
122 |
123 | def __setstate__(self, state):
124 | super().__setstate__(state)
125 | for group in self.param_groups:
126 | group.setdefault('amsgrad', False)
127 |
128 | @torch.no_grad()
129 | def step(self, closure=None):
130 | """Performs a single optimization step.
131 | Args:
132 | closure (callable, optional): A closure that reevaluates the model
133 | and returns the loss.
134 | """
135 | loss = None
136 | if closure is not None:
137 | with torch.enable_grad():
138 | loss = closure()
139 |
140 | for group in self.param_groups:
141 | params_with_grad = []
142 | grads = []
143 | exp_avgs = []
144 | exp_avg_sqs = []
145 | ema_params_with_grad = []
146 | state_sums = []
147 | max_exp_avg_sqs = []
148 | state_steps = []
149 | amsgrad = group['amsgrad']
150 | beta1, beta2 = group['betas']
151 | ema_decay = group['ema_decay']
152 | ema_power = group['ema_power']
153 |
154 | for p in group['params']:
155 | if p.grad is None:
156 | continue
157 | params_with_grad.append(p)
158 | if p.grad.is_sparse:
159 | raise RuntimeError('AdamW does not support sparse gradients')
160 | grads.append(p.grad)
161 |
162 | state = self.state[p]
163 |
164 | # State initialization
165 | if len(state) == 0:
166 | state['step'] = 0
167 | # Exponential moving average of gradient values
168 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
169 | # Exponential moving average of squared gradient values
170 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
171 | if amsgrad:
172 | # Maintains max of all exp. moving avg. of sq. grad. values
173 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
174 | # Exponential moving average of parameter values
175 | state['param_exp_avg'] = p.detach().float().clone()
176 |
177 | exp_avgs.append(state['exp_avg'])
178 | exp_avg_sqs.append(state['exp_avg_sq'])
179 | ema_params_with_grad.append(state['param_exp_avg'])
180 |
181 | if amsgrad:
182 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
183 |
184 | # update the steps for each param group update
185 | state['step'] += 1
186 | # record the step after step update
187 | state_steps.append(state['step'])
188 |
189 | optim._functional.adamw(params_with_grad,
190 | grads,
191 | exp_avgs,
192 | exp_avg_sqs,
193 | max_exp_avg_sqs,
194 | state_steps,
195 | amsgrad=amsgrad,
196 | beta1=beta1,
197 | beta2=beta2,
198 | lr=group['lr'],
199 | weight_decay=group['weight_decay'],
200 | eps=group['eps'],
201 | maximize=False)
202 |
203 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
204 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
205 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
206 |
207 | return loss
--------------------------------------------------------------------------------
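instantiate_from_config builds an object from a {"target", "params"} mapping via get_obj_from_str. A minimal sketch with a hypothetical config; any importable class works the same way:

    from ldm.util import instantiate_from_config

    cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
    layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(4, 2)
    print(layer)
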
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 | """Rezise the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 | """Normlize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
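A short sketch of how the Resize transform above chooses output sizes with keep_aspect_ratio=True, resize_method="upper_bound" and ensure_multiple_of=32 (the same settings load_midas_transform uses for the midas_v21 models); the 640x480 input size is arbitrary:

    from ldm.modules.midas.midas.transforms import Resize

    resize = Resize(
        384, 384,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=32,
        resize_method="upper_bound",
    )
    # The longer side is capped at 384 and both sides are rounded to multiples
    # of 32 while the aspect ratio is preserved.
    print(resize.get_size(640, 480))  # -> (384, 288)
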
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8 |
9 | from ldm.util import instantiate_from_config
10 | from ldm.modules.ema import LitEma
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | ema_decay=None,
24 | learn_logvar=False
25 | ):
26 | super().__init__()
27 | self.learn_logvar = learn_logvar
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | self.use_ema = ema_decay is not None
43 | if self.use_ema:
44 | self.ema_decay = ema_decay
45 | assert 0. < ema_decay < 1.
46 | self.model_ema = LitEma(self, decay=ema_decay)
47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48 |
49 | if ckpt_path is not None:
50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | sd = torch.load(path, map_location="cpu")["state_dict"]
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def on_train_batch_end(self, *args, **kwargs):
79 | if self.use_ema:
80 | self.model_ema(self)
81 |
82 | def encode(self, x):
83 | h = self.encoder(x)
84 | moments = self.quant_conv(h)
85 | posterior = DiagonalGaussianDistribution(moments)
86 | return posterior
87 |
88 | def decode(self, z):
89 | z = self.post_quant_conv(z)
90 | dec = self.decoder(z)
91 | return dec
92 |
93 | def forward(self, input, sample_posterior=True):
94 | posterior = self.encode(input)
95 | if sample_posterior:
96 | z = posterior.sample()
97 | else:
98 | z = posterior.mode()
99 | dec = self.decode(z)
100 | return dec, posterior
101 |
102 | def get_input(self, batch, k):
103 | x = batch[k]
104 | if len(x.shape) == 3:
105 | x = x[..., None]
106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107 | return x
108 |
109 | def training_step(self, batch, batch_idx, optimizer_idx):
110 | inputs = self.get_input(batch, self.image_key)
111 | reconstructions, posterior = self(inputs)
112 |
113 | if optimizer_idx == 0:
114 | # train encoder+decoder+logvar
115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116 | last_layer=self.get_last_layer(), split="train")
117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119 | return aeloss
120 |
121 | if optimizer_idx == 1:
122 | # train the discriminator
123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124 | last_layer=self.get_last_layer(), split="train")
125 |
126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return discloss
129 |
130 | def validation_step(self, batch, batch_idx):
131 | log_dict = self._validation_step(batch, batch_idx)
132 | with self.ema_scope():
133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134 | return log_dict
135 |
136 | def _validation_step(self, batch, batch_idx, postfix=""):
137 | inputs = self.get_input(batch, self.image_key)
138 | reconstructions, posterior = self(inputs)
139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140 | last_layer=self.get_last_layer(), split="val"+postfix)
141 |
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143 | last_layer=self.get_last_layer(), split="val"+postfix)
144 |
145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146 | self.log_dict(log_dict_ae)
147 | self.log_dict(log_dict_disc)
148 | return self.log_dict
149 |
150 | def configure_optimizers(self):
151 | lr = self.learning_rate
152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154 | if self.learn_logvar:
155 | print(f"{self.__class__.__name__}: Learning logvar")
156 | ae_params_list.append(self.loss.logvar)
157 | opt_ae = torch.optim.Adam(ae_params_list,
158 | lr=lr, betas=(0.5, 0.9))
159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160 | lr=lr, betas=(0.5, 0.9))
161 | return [opt_ae, opt_disc], []
162 |
163 | def get_last_layer(self):
164 | return self.decoder.conv_out.weight
165 |
166 | @torch.no_grad()
167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168 | log = dict()
169 | x = self.get_input(batch, self.image_key)
170 | x = x.to(self.device)
171 | if not only_inputs:
172 | xrec, posterior = self(x)
173 | if x.shape[1] > 3:
174 | # colorize with random projection
175 | assert xrec.shape[1] > 3
176 | x = self.to_rgb(x)
177 | xrec = self.to_rgb(xrec)
178 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179 | log["reconstructions"] = xrec
180 | if log_ema or self.use_ema:
181 | with self.ema_scope():
182 | xrec_ema, posterior_ema = self(x)
183 | if x.shape[1] > 3:
184 | # colorize with random projection
185 | assert xrec_ema.shape[1] > 3
186 | xrec_ema = self.to_rgb(xrec_ema)
187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188 | log["reconstructions_ema"] = xrec_ema
189 | log["inputs"] = x
190 | return log
191 |
192 | def to_rgb(self, x):
193 | assert self.image_key == "segmentation"
194 | if not hasattr(self, "colorize"):
195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196 | x = F.conv2d(x, weight=self.colorize)
197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198 | return x
199 |
200 |
201 | class IdentityFirstStage(torch.nn.Module):
202 | def __init__(self, *args, vq_interface=False, **kwargs):
203 | self.vq_interface = vq_interface
204 | super().__init__()
205 |
206 | def encode(self, x, *args, **kwargs):
207 | return x
208 |
209 | def decode(self, x, *args, **kwargs):
210 | return x
211 |
212 | def quantize(self, x, *args, **kwargs):
213 | if self.vq_interface:
214 | return x, None, [None, None, None]
215 | return x
216 |
217 | def forward(self, x, *args, **kwargs):
218 | return x
219 |
220 |
--------------------------------------------------------------------------------
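The "colorize with random projection" trick used in log_images/to_rgb above can be reproduced standalone: an N-channel map is projected to 3 channels with a fixed random 1x1 convolution and rescaled to [-1, 1]. A self-contained sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 8, 32, 32)                 # e.g. an 8-channel segmentation map
    colorize = torch.randn(3, x.shape[1], 1, 1)  # random projection (a registered buffer in the model)
    rgb = F.conv2d(x, weight=colorize)
    rgb = 2. * (rgb - rgb.min()) / (rgb.max() - rgb.min()) - 1.
    print(rgb.shape, rgb.min().item(), rgb.max().item())  # (1, 3, 32, 32), -1.0, 1.0
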
/ldm/modules/karlo/kakao/sampler.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 |
5 | # source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
6 | # ------------------------------------------------------------------------------------
7 |
8 | from typing import Iterator
9 |
10 | import torch
11 | import torchvision.transforms.functional as TVF
12 | from torchvision.transforms import InterpolationMode
13 |
14 | from .template import BaseSampler, CKPT_PATH
15 |
16 |
17 | class T2ISampler(BaseSampler):
18 | """
19 | A sampler for text-to-image generation.
20 | :param root_dir: directory for model checkpoints.
21 | :param sampling_type: ["default", "fast"]
22 | """
23 |
24 | def __init__(
25 | self,
26 | root_dir: str,
27 | sampling_type: str = "default",
28 | ):
29 | super().__init__(root_dir, sampling_type)
30 |
31 | @classmethod
32 | def from_pretrained(
33 | cls,
34 | root_dir: str,
35 | clip_model_path: str,
36 | clip_stat_path: str,
37 | sampling_type: str = "default",
38 | ):
39 |
40 | model = cls(
41 | root_dir=root_dir,
42 | sampling_type=sampling_type,
43 | )
44 | model.load_clip(clip_model_path)
45 | model.load_prior(
46 | f"{CKPT_PATH['prior']}",
47 | clip_stat_path=clip_stat_path,
48 | prior_config="configs/karlo/prior_1B_vit_l.yaml"
49 | )
50 | model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
51 | model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
52 | return model
53 |
54 | def preprocess(
55 | self,
56 | prompt: str,
57 | bsz: int,
58 | ):
59 | """Setup prompts & cfg scales"""
60 | prompts_batch = [prompt for _ in range(bsz)]
61 |
62 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
63 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
64 |
65 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
66 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
67 |
68 | """ Get CLIP text feature """
69 | clip_model = self._clip
70 | tokenizer = self._tokenizer
71 | max_txt_length = self._prior.model.text_ctx
72 |
73 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
74 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
75 | if not (cf_token.shape == tok.shape):
76 | cf_token = cf_token.expand(tok.shape[0], -1)
77 | cf_mask = cf_mask.expand(tok.shape[0], -1)
78 |
79 | tok = torch.cat([tok, cf_token], dim=0)
80 | mask = torch.cat([mask, cf_mask], dim=0)
81 |
82 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
83 | txt_feat, txt_feat_seq = clip_model.encode_text(tok)
84 |
85 | return (
86 | prompts_batch,
87 | prior_cf_scales_batch,
88 | decoder_cf_scales_batch,
89 | txt_feat,
90 | txt_feat_seq,
91 | tok,
92 | mask,
93 | )
94 |
95 | def __call__(
96 | self,
97 | prompt: str,
98 | bsz: int,
99 | progressive_mode=None,
100 | ) -> Iterator[torch.Tensor]:
101 | assert progressive_mode in ("loop", "stage", "final")
102 | with torch.no_grad(), torch.cuda.amp.autocast():
103 | (
104 | prompts_batch,
105 | prior_cf_scales_batch,
106 | decoder_cf_scales_batch,
107 | txt_feat,
108 | txt_feat_seq,
109 | tok,
110 | mask,
111 | ) = self.preprocess(
112 | prompt,
113 | bsz,
114 | )
115 |
116 | """ Transform CLIP text feature into image feature """
117 | img_feat = self._prior(
118 | txt_feat,
119 | txt_feat_seq,
120 | mask,
121 | prior_cf_scales_batch,
122 | timestep_respacing=self._prior_sm,
123 | )
124 |
125 | """ Generate 64x64px images """
126 | images_64_outputs = self._decoder(
127 | txt_feat,
128 | txt_feat_seq,
129 | tok,
130 | mask,
131 | img_feat,
132 | cf_guidance_scales=decoder_cf_scales_batch,
133 | timestep_respacing=self._decoder_sm,
134 | )
135 |
136 | images_64 = None
137 | for k, out in enumerate(images_64_outputs):
138 | images_64 = out
139 | if progressive_mode == "loop":
140 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
141 | if progressive_mode == "stage":
142 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
143 |
144 | images_64 = torch.clamp(images_64, -1, 1)
145 |
146 | """ Upsample 64x64 to 256x256 """
147 | images_256 = TVF.resize(
148 | images_64,
149 | [256, 256],
150 | interpolation=InterpolationMode.BICUBIC,
151 | antialias=True,
152 | )
153 | images_256_outputs = self._sr_64_256(
154 | images_256, timestep_respacing=self._sr_sm
155 | )
156 |
157 | for k, out in enumerate(images_256_outputs):
158 | images_256 = out
159 | if progressive_mode == "loop":
160 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
161 | if progressive_mode == "stage":
162 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
163 |
164 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
165 |
166 |
167 | class PriorSampler(BaseSampler):
168 | """
169 | A sampler for text-to-image generation, but only the prior.
170 | :param root_dir: directory for model checkpoints.
171 | :param sampling_type: ["default", "fast"]
172 | """
173 |
174 | def __init__(
175 | self,
176 | root_dir: str,
177 | sampling_type: str = "default",
178 | ):
179 | super().__init__(root_dir, sampling_type)
180 |
181 | @classmethod
182 | def from_pretrained(
183 | cls,
184 | root_dir: str,
185 | clip_model_path: str,
186 | clip_stat_path: str,
187 | sampling_type: str = "default",
188 | ):
189 | model = cls(
190 | root_dir=root_dir,
191 | sampling_type=sampling_type,
192 | )
193 | model.load_clip(clip_model_path)
194 | model.load_prior(
195 | f"{CKPT_PATH['prior']}",
196 | clip_stat_path=clip_stat_path,
197 | prior_config="configs/karlo/prior_1B_vit_l.yaml"
198 | )
199 | return model
200 |
201 | def preprocess(
202 | self,
203 | prompt: str,
204 | bsz: int,
205 | ):
206 | """Setup prompts & cfg scales"""
207 | prompts_batch = [prompt for _ in range(bsz)]
208 |
209 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
210 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
211 |
212 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
213 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
214 |
215 | """ Get CLIP text feature """
216 | clip_model = self._clip
217 | tokenizer = self._tokenizer
218 | max_txt_length = self._prior.model.text_ctx
219 |
220 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
221 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
222 | if cf_token.shape != tok.shape:
223 | cf_token = cf_token.expand(tok.shape[0], -1)
224 | cf_mask = cf_mask.expand(tok.shape[0], -1)
225 |
226 | tok = torch.cat([tok, cf_token], dim=0)
227 | mask = torch.cat([mask, cf_mask], dim=0)
228 |
229 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
230 | txt_feat, txt_feat_seq = clip_model.encode_text(tok)
231 |
232 | return (
233 | prompts_batch,
234 | prior_cf_scales_batch,
235 | decoder_cf_scales_batch,
236 | txt_feat,
237 | txt_feat_seq,
238 | tok,
239 | mask,
240 | )
241 |
242 | def __call__(
243 | self,
244 | prompt: str,
245 | bsz: int,
246 | progressive_mode=None,
247 | ) -> Iterator[torch.Tensor]:
248 | assert progressive_mode in ("loop", "stage", "final")
249 | with torch.no_grad(), torch.cuda.amp.autocast():
250 | (
251 | prompts_batch,
252 | prior_cf_scales_batch,
253 | decoder_cf_scales_batch,
254 | txt_feat,
255 | txt_feat_seq,
256 | tok,
257 | mask,
258 | ) = self.preprocess(
259 | prompt,
260 | bsz,
261 | )
262 |
263 | """ Transform CLIP text feature into image feature """
264 | img_feat = self._prior(
265 | txt_feat,
266 | txt_feat_seq,
267 | mask,
268 | prior_cf_scales_batch,
269 | timestep_respacing=self._prior_sm,
270 | )
271 |
272 | yield img_feat
273 |
--------------------------------------------------------------------------------
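
Minimal usage sketch for the PriorSampler defined above (illustration only). The module path, checkpoint directory, and CLIP file names are assumptions rather than values taken from this repository, and a CUDA device is required because the sampler moves tokens to "cuda" internally.

from ldm.modules.karlo.kakao.sampler import PriorSampler  # assumed module path

sampler = PriorSampler.from_pretrained(
    root_dir="checkpoints/karlo",         # assumed checkpoint directory
    clip_model_path="ViT-L-14.pt",        # assumed CLIP weights file
    clip_stat_path="ViT-L-14_stats.th",   # assumed CLIP feature statistics file
    sampling_type="default",
)

# __call__ is a generator; the prior-only sampler yields a single CLIP image embedding
# predicted from the text prompt.
img_feat = next(sampler("a photograph of an astronaut riding a horse", bsz=2, progressive_mode="final"))
print(img_feat.shape)
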
/ldm/modules/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 | ) # ViT-B/16 + ResNet-50 hybrid backbone (vitb_rn50_384)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | # unknown backbone
44 | raise NotImplementedError(f"Backbone '{backbone}' not implemented")
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
--------------------------------------------------------------------------------
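
Illustrative sketch of how the MiDaS blocks above compose (the feature shapes are assumptions): _make_scratch re-projects the backbone feature maps to a common channel width, and FeatureFusionBlock_custom refines a path, optionally adds a skip connection, and upsamples by 2x.

import torch
import torch.nn as nn

from ldm.modules.midas.midas.blocks import _make_scratch, FeatureFusionBlock_custom

# Re-project four backbone feature maps (channel counts as used for "vitb_rn50_384") to width 256.
scratch = _make_scratch([256, 512, 768, 768], 256, groups=1, expand=False)

feat3 = torch.randn(1, 768, 32, 32)   # assumed shallower feature map
feat4 = torch.randn(1, 768, 16, 16)   # assumed deepest feature map
path3 = scratch.layer3_rn(feat3)      # -> (1, 256, 32, 32)
path4 = scratch.layer4_rn(feat4)      # -> (1, 256, 16, 16)

fuse = FeatureFusionBlock_custom(256, nn.ReLU(False), bn=False, align_corners=True)
up4 = fuse(path4)                     # refine and upsample 2x -> (1, 256, 32, 32)
out = fuse(up4, path3)                # add the skip path, refine, upsample again
print(out.shape)                      # torch.Size([1, 256, 64, 64])
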
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = torch.clamp(betas, min=0, max=0.999)  # keep as a tensor so .numpy() below works
36 |
37 | elif schedule == "squaredcos_cap_v2": # used for karlo prior
38 | # return early
39 | return betas_for_alpha_bar(
40 | n_timestep,
41 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
42 | )
43 |
44 | elif schedule == "sqrt_linear":
45 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
46 | elif schedule == "sqrt":
47 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
48 | else:
49 | raise ValueError(f"schedule '{schedule}' unknown.")
50 | return betas.numpy()
51 |
52 |
53 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
54 | if ddim_discr_method == 'uniform':
55 | c = num_ddpm_timesteps // num_ddim_timesteps
56 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
57 | elif ddim_discr_method == 'quad':
58 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
59 | else:
60 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
61 |
62 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
63 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
64 | steps_out = ddim_timesteps + 1
65 | if verbose:
66 | print(f'Selected timesteps for ddim sampler: {steps_out}')
67 | return steps_out
68 |
69 |
70 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
71 | # select alphas for computing the variance schedule
72 | alphas = alphacums[ddim_timesteps]
73 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
74 |
75 | # according to the formula provided in https://arxiv.org/abs/2010.02502
76 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
77 | if verbose:
78 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
79 | print(f'For the chosen value of eta, which is {eta}, '
80 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
81 | return sigmas, alphas, alphas_prev
82 |
83 |
84 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
85 | """
86 | Create a beta schedule that discretizes the given alpha_t_bar function,
87 | which defines the cumulative product of (1-beta) over time from t = [0,1].
88 | :param num_diffusion_timesteps: the number of betas to produce.
89 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
90 | produces the cumulative product of (1-beta) up to that
91 | part of the diffusion process.
92 | :param max_beta: the maximum beta to use; use values lower than 1 to
93 | prevent singularities.
94 | """
95 | betas = []
96 | for i in range(num_diffusion_timesteps):
97 | t1 = i / num_diffusion_timesteps
98 | t2 = (i + 1) / num_diffusion_timesteps
99 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
100 | return np.array(betas)
101 |
102 |
103 | def extract_into_tensor(a, t, x_shape):
104 | b, *_ = t.shape
105 | out = a.gather(-1, t)
106 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
107 |
108 |
109 | def checkpoint(func, inputs, params, flag):
110 | """
111 | Evaluate a function without caching intermediate activations, allowing for
112 | reduced memory at the expense of extra compute in the backward pass.
113 | :param func: the function to evaluate.
114 | :param inputs: the argument sequence to pass to `func`.
115 | :param params: a sequence of parameters `func` depends on but does not
116 | explicitly take as arguments.
117 | :param flag: if False, disable gradient checkpointing.
118 | """
119 | if flag:
120 | args = tuple(inputs) + tuple(params)
121 | return CheckpointFunction.apply(func, len(inputs), *args)
122 | else:
123 | return func(*inputs)
124 |
125 |
126 | class CheckpointFunction(torch.autograd.Function):
127 | @staticmethod
128 | def forward(ctx, run_function, length, *args):
129 | ctx.run_function = run_function
130 | ctx.input_tensors = list(args[:length])
131 | ctx.input_params = list(args[length:])
132 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
133 | "dtype": torch.get_autocast_gpu_dtype(),
134 | "cache_enabled": torch.is_autocast_cache_enabled()}
135 | with torch.no_grad():
136 | output_tensors = ctx.run_function(*ctx.input_tensors)
137 | return output_tensors
138 |
139 | @staticmethod
140 | def backward(ctx, *output_grads):
141 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
142 | with torch.enable_grad(), \
143 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
144 | # Fixes a bug where the first op in run_function modifies the
145 | # Tensor storage in place, which is not allowed for detach()'d
146 | # Tensors.
147 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
148 | output_tensors = ctx.run_function(*shallow_copies)
149 | input_grads = torch.autograd.grad(
150 | output_tensors,
151 | ctx.input_tensors + ctx.input_params,
152 | output_grads,
153 | allow_unused=True,
154 | )
155 | del ctx.input_tensors
156 | del ctx.input_params
157 | del output_tensors
158 | return (None, None) + input_grads
159 |
160 |
161 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
162 | """
163 | Create sinusoidal timestep embeddings.
164 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
165 | These may be fractional.
166 | :param dim: the dimension of the output.
167 | :param max_period: controls the minimum frequency of the embeddings.
168 | :return: an [N x dim] Tensor of positional embeddings.
169 | """
170 | if not repeat_only:
171 | half = dim // 2
172 | freqs = torch.exp(
173 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
174 | ).to(device=timesteps.device)
175 | args = timesteps[:, None].float() * freqs[None]
176 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
177 | if dim % 2:
178 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
179 | else:
180 | embedding = repeat(timesteps, 'b -> b d', d=dim)
181 | return embedding
182 |
183 |
184 | def zero_module(module):
185 | """
186 | Zero out the parameters of a module and return it.
187 | """
188 | for p in module.parameters():
189 | p.detach().zero_()
190 | return module
191 |
192 |
193 | def scale_module(module, scale):
194 | """
195 | Scale the parameters of a module and return it.
196 | """
197 | for p in module.parameters():
198 | p.detach().mul_(scale)
199 | return module
200 |
201 |
202 | def mean_flat(tensor):
203 | """
204 | Take the mean over all non-batch dimensions.
205 | """
206 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
207 |
208 |
209 | def normalization(channels):
210 | """
211 | Make a standard normalization layer.
212 | :param channels: number of input channels.
213 | :return: an nn.Module for normalization.
214 | """
215 | return GroupNorm32(32, channels)
216 |
217 |
218 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
219 | class SiLU(nn.Module):
220 | def forward(self, x):
221 | return x * torch.sigmoid(x)
222 |
223 |
224 | class GroupNorm32(nn.GroupNorm):
225 | def forward(self, x):
226 | return super().forward(x.float()).type(x.dtype)
227 |
228 |
229 | def conv_nd(dims, *args, **kwargs):
230 | """
231 | Create a 1D, 2D, or 3D convolution module.
232 | """
233 | if dims == 1:
234 | return nn.Conv1d(*args, **kwargs)
235 | elif dims == 2:
236 | return nn.Conv2d(*args, **kwargs)
237 | elif dims == 3:
238 | return nn.Conv3d(*args, **kwargs)
239 | raise ValueError(f"unsupported dimensions: {dims}")
240 |
241 |
242 | def linear(*args, **kwargs):
243 | """
244 | Create a linear module.
245 | """
246 | return nn.Linear(*args, **kwargs)
247 |
248 |
249 | def avg_pool_nd(dims, *args, **kwargs):
250 | """
251 | Create a 1D, 2D, or 3D average pooling module.
252 | """
253 | if dims == 1:
254 | return nn.AvgPool1d(*args, **kwargs)
255 | elif dims == 2:
256 | return nn.AvgPool2d(*args, **kwargs)
257 | elif dims == 3:
258 | return nn.AvgPool3d(*args, **kwargs)
259 | raise ValueError(f"unsupported dimensions: {dims}")
260 |
261 |
262 | class HybridConditioner(nn.Module):
263 |
264 | def __init__(self, c_concat_config, c_crossattn_config):
265 | super().__init__()
266 | self.concat_conditioner = instantiate_from_config(c_concat_config)
267 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
268 |
269 | def forward(self, c_concat, c_crossattn):
270 | c_concat = self.concat_conditioner(c_concat)
271 | c_crossattn = self.crossattn_conditioner(c_crossattn)
272 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
273 |
274 |
275 | def noise_like(shape, device, repeat=False):
276 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
277 | noise = lambda: torch.randn(shape, device=device)
278 | return repeat_noise() if repeat else noise()
279 |
--------------------------------------------------------------------------------
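
Short sketch exercising a few of the helpers above; the printed values follow directly from the defaults in the file.

import torch

from ldm.modules.diffusionmodules.util import (
    make_beta_schedule,
    make_ddim_timesteps,
    timestep_embedding,
)

betas = make_beta_schedule("linear", n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
print(betas.shape, betas[0], betas[-1])   # (1000,), ~1e-4, ~2e-2

ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50, num_ddpm_timesteps=1000, verbose=False)
print(ddim_steps[:5])                     # [ 1 21 41 61 81]

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)      # sinusoidal embeddings, one row per timestep
print(emb.shape)                          # torch.Size([3, 320])
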
/scripts/semantic_i2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import optim
4 | from torch.utils.data import Dataset, DataLoader, random_split
5 | from scripts.dataset import Flickr8kDataset,Only_images_Flickr8kDataset
6 | from itertools import islice
7 | from ldm.util import instantiate_from_config
8 | from PIL import Image
9 | import PIL
10 | import torch
11 | import numpy as np
12 | import argparse, os
13 | from omegaconf import OmegaConf
14 | from pytorch_lightning import seed_everything
15 | from imwatermark import WatermarkEncoder
16 | from ldm.models.diffusion.ddim import DDIMSampler
17 | from tqdm import tqdm
18 | import lpips as lp
19 | from einops import rearrange, repeat
20 | from torch import autocast
21 | from tqdm import tqdm, trange
22 | from transformers import pipeline
23 | from scripts.qam import qam16ModulationTensor, qam16ModulationString
24 | import time
25 |
26 | from SSIM_PIL import compare_ssim
27 |
28 | '''
29 |
30 | INIT DATASET AND DATALOADER
31 |
32 | '''
33 | capt_file_path = "path/to/captions.txt" #"G:/Giordano/Flickr8kDataset/captions.txt"
34 | images_dir_path = "path/to/Images" #"G:/Giordano/Flickr8kDataset/Images/"
35 | batch_size = 1
36 |
37 | dataset = Only_images_Flickr8kDataset(images_dir_path)
38 |
39 | test_dataloader=DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)
40 |
41 |
42 | '''
43 | MODEL CHECKPOINT
44 |
45 | '''
46 |
47 |
48 | model_ckpt_path = "path/to/model-checkpoint" #"G:/Giordano/stablediffusion/checkpoints/v1-5-pruned.ckpt" #v2-1_512-ema-pruned.ckpt"
49 | config_path = "path/to/model-config" #"G:/Giordano/stablediffusion/configs/stable-diffusion/v1-inference.yaml"
50 |
51 |
52 |
53 |
54 | def load_model_from_config(config, ckpt, verbose=False):
55 | print(f"Loading model from {ckpt}")
56 | pl_sd = torch.load(ckpt, map_location="cpu")
57 | if "global_step" in pl_sd:
58 | print(f"Global Step: {pl_sd['global_step']}")
59 | sd = pl_sd["state_dict"]
60 | model = instantiate_from_config(config.model)
61 | m, u = model.load_state_dict(sd, strict=False)
62 | if len(m) > 0 and verbose:
63 | print("missing keys:")
64 | print(m)
65 | if len(u) > 0 and verbose:
66 | print("unexpected keys:")
67 | print(u)
68 |
69 | model.cuda()
70 | model.eval()
71 | return model
72 |
73 |
74 | def load_img(path):
75 | image = Image.open(path).convert("RGB")
76 | w, h = (512, 512)  # force a fixed 512x512 working size instead of image.size
77 | #print(f"loaded input image of size ({w}, {h}) from {path}")
78 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
79 | image = image.resize((w, h), resample=PIL.Image.LANCZOS)
80 | image = np.array(image).astype(np.float32) / 255.0
81 | image = image[None].transpose(0, 3, 1, 2)
82 | image = torch.from_numpy(image)
83 | return 2. * image - 1.
84 |
85 |
86 |
87 |
88 |
89 | def test(dataloader,
90 | snr=10,
91 | num_images=100,
92 | batch_size=1,
93 | num_images_per_sample=2,
94 | outpath='',
95 | model=None,
96 | device=None,
97 | sampler=None,
98 | strength=0.8,
99 | ddim_steps=50,
100 | scale=9.0):
101 |
102 | blip = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
103 | i=0
104 |
105 | sampling_steps = int(strength * ddim_steps)
106 | print(f"effective sampling steps: {sampling_steps}")
107 | sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)  # note: check these parameters
108 | sample_path = os.path.join(outpath, f"Test-samples-{snr}-{sampling_steps}")
109 | os.makedirs(sample_path, exist_ok=True)
110 |
111 | text_path = os.path.join(outpath, f"Test-text-samples-{snr}-{sampling_steps}")
112 | os.makedirs(text_path, exist_ok=True)
113 |
114 | sample_orig_path = os.path.join(outpath, f"Test-samples-orig-{snr}-{sampling_steps}")
115 | os.makedirs(sample_orig_path, exist_ok=True)
116 |
117 | lpips = lp.LPIPS(net='alex')
118 | lpips_values = []
119 |
120 | ssim_values = []
121 |
122 | time_values = []
123 |
124 | tq = tqdm(dataloader,total=num_images)
125 | for batch in tq:
126 |
127 | img_file_path = batch[0]
128 |
129 | #Open Image
130 | init_image = Image.open(img_file_path)
131 |
132 | #Automatically extract caption using BLIP model
133 | prompt = blip(init_image)[0]["generated_text"]
134 | prompt_original = prompt
135 |
136 | base_count = len(os.listdir(sample_path))
137 |
138 | assert os.path.isfile(img_file_path)
139 | init_image = load_img(img_file_path).to(device)
140 | init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
141 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
142 |
143 | #print(init_latent.shape,init_latent.type())
144 |
145 | '''
146 | CHANNEL SIMULATION
147 | '''
148 |
149 | init_latent = qam16ModulationTensor(init_latent.cpu(), snr_db=snr).to(device)  # latent after simulated 16-QAM transmission at the given SNR
150 |
151 | prompt = qam16ModulationString(prompt,snr_db=snr) #NOISY BLIP PROMPT
152 |
153 | data = [batch_size * [prompt]]
154 | assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
155 | t_enc = int(strength * ddim_steps)
156 |
157 | precision_scope = autocast
158 | with torch.no_grad():
159 | with precision_scope("cuda"):
160 | with model.ema_scope():
161 | all_samples = list()
162 | for n in range(1):
163 | for prompts in data:
164 | start_time = time.time()
165 | uc = None
166 | if scale != 1.0:
167 | uc = model.get_learned_conditioning(batch_size * [""])
168 | if isinstance(prompts, tuple):
169 | prompts = list(prompts)
170 | c = model.get_learned_conditioning(prompts)
171 | # encode (scaled latent)
172 | z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
173 | # z_enc = init_latent
174 | # decode it
175 | samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
176 | unconditional_conditioning=uc, )
177 | x_samples = model.decode_first_stage(samples)
178 |
179 | x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
180 |
181 | end_time = time.time()
182 | execution_time = end_time - start_time
183 |
184 | time_values.append(execution_time)
185 |
186 | for x_sample in x_samples:
187 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
188 | img = Image.fromarray(x_sample.astype(np.uint8))
189 | #SAVE IMAGE
190 | img.save(os.path.join(sample_path, f"{base_count:05}.png"))
191 | #SAVE TEXT
192 | with open(os.path.join(text_path, f"{base_count:05}.txt"), "a") as f:
193 | f.write(prompt_original)
194 |
195 |
196 | #SAVE ORIGINAL IMAGE
197 | init_image_copy = Image.open(img_file_path)
198 | init_image_copy = init_image_copy.resize((512, 512), resample=PIL.Image.LANCZOS)
199 | init_image_copy.save(os.path.join(sample_orig_path, f"{base_count:05}.png"))
200 |
201 | # Compute SSIM
202 | ssim_values.append(compare_ssim(init_image_copy, img))
203 | base_count += 1
204 | all_samples.append(x_samples)
205 |
206 | #Compute LPIPS
207 | sample_out = (all_samples[0][0] * 2) - 1
208 | lp_score=lpips(init_image[0].cpu(),sample_out.cpu()).item()
209 |
210 | tq.set_postfix(lpips=lp_score)
211 |
212 | if not np.isnan(lp_score):
213 | lpips_values.append(lp_score)
214 |
215 |
216 | i += 1
217 | if i == num_images:
218 | break
219 |
220 |
221 | print(f'mean lpips score at snr={snr} : {sum(lpips_values)/len(lpips_values)}')
222 | print(f'mean ssim score at snr={snr} : {sum(ssim_values)/len(ssim_values)}')
223 | print(f'mean time with sampling iterations {sampling_steps} : {sum(time_values)/len(time_values)}')
224 | return 1
225 |
226 |
227 |
228 | if __name__ == "__main__":
229 |
230 | parser = argparse.ArgumentParser()
231 |
232 | parser.add_argument(
233 | "--outdir",
234 | type=str,
235 | nargs="?",
236 | help="dir to write results to",
237 | default="outputs/img2img-samples"
238 | )
239 |
240 | parser.add_argument(
241 | "--seed",
242 | type=int,
243 | default=42,
244 | help="the seed (for reproducible sampling)",
245 | )
246 |
247 |
248 | opt = parser.parse_args()
249 | seed_everything(opt.seed)
250 |
251 | config = OmegaConf.load(f"{config_path}")
252 | model = load_model_from_config(config, f"{model_ckpt_path}")
253 |
254 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
255 | model = model.to(device)
256 |
257 | sampler = DDIMSampler(model)
258 |
259 | os.makedirs(opt.outdir, exist_ok=True)
260 | outpath = opt.outdir
261 |
262 | # START OF TEST RUNS (one call per SNR value)
263 |
264 | #Strength is used to modulate the number of sampling steps. Steps=50*strength
265 | test(test_dataloader,snr=10,num_images=100,batch_size=1,num_images_per_sample=1,outpath=outpath,
266 | model=model,device=device,sampler=sampler,strength=0.6,scale=9)
267 |
268 | test(test_dataloader,snr=8.75,num_images=100,batch_size=1,num_images_per_sample=1,outpath=outpath,
269 | model=model,device=device,sampler=sampler,strength=0.6,scale=9)
270 |
271 | test(test_dataloader,snr=7.50,num_images=100,batch_size=1,num_images_per_sample=1,outpath=outpath,
272 | model=model,device=device,sampler=sampler,strength=0.6,scale=9)
273 |
274 | test(test_dataloader,snr=6.25,num_images=100,batch_size=1,num_images_per_sample=1,outpath=outpath,
275 | model=model,device=device,sampler=sampler,strength=0.6,scale=9)
276 |
277 | test(test_dataloader,snr=5,num_images=100,batch_size=1,num_images_per_sample=1,outpath=outpath,
278 | model=model,device=device,sampler=sampler,strength=0.6,scale=9)
279 |
280 |
--------------------------------------------------------------------------------
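
The channel simulation in test() above passes both the image latent and the BLIP caption through the 16-QAM routines from scripts/qam.py. As a simplified stand-in (not the repository's implementation), the sketch below injects additive white Gaussian noise at a target SNR, only to show where the corruption enters the pipeline.

import torch

def awgn_channel(latent: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Add white Gaussian noise so that the resulting SNR equals snr_db."""
    signal_power = latent.pow(2).mean()
    noise_power = signal_power / (10 ** (snr_db / 10.0))
    return latent + torch.randn_like(latent) * noise_power.sqrt()

# Dummy 4x64x64 latent, the shape Stable Diffusion uses for a 512x512 input image.
noisy_latent = awgn_channel(torch.randn(1, 4, 64, 64), snr_db=10.0)
print(noisy_latent.shape)
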