├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── lsun.py
│   │   └── imagenet.py
│   ├── models
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── ddim.cpython-38.pyc
│   │   │   │   ├── ddpm.cpython-38.pyc
│   │   │   │   ├── plms.cpython-38.pyc
│   │   │   │   └── __init__.cpython-38.pyc
│   │   │   ├── classifier.py
│   │   │   ├── _ddim.py
│   │   │   └── ddim.py
│   │   ├── __pycache__
│   │   │   └── autoencoder.cpython-38.pyc
│   │   └── autoencoder.py
│   ├── modules
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   └── modules.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── util.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── openaimodel.cpython-38.pyc
│   │   │   └── util.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   ├── __pycache__
│   │   │   ├── ema.cpython-38.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   └── x_transformer.cpython-38.pyc
│   │   ├── image_degradation
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── __init__.py
│   │   ├── ema.py
│   │   ├── attention.py
│   │   └── x_transformer.py
│   ├── __pycache__
│   │   ├── casa.cpython-38.pyc
│   │   └── util.cpython-38.pyc
│   ├── lr_scheduler.py
│   └── util.py
├── teaser.png
├── environment.yaml
├── README.md
├── configs
│   └── stable-diffusion
│       └── v1-inference.yaml
└── utils.py
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/teaser.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/__pycache__/casa.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/__pycache__/casa.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/x_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/__pycache__/x_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Mao/Initial-Noise-Construction/HEAD/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.6.1
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - torch-fidelity==0.3.0
24 | - transformers==4.19.2
25 | - kornia==0.7.2
26 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
27 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
28 | - -e .
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
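For reference, a minimal sketch of how this interface is meant to be subclassed: a concrete `IterableDataset` only needs to implement `__iter__` and report its length through `num_records`. The subclass name and the record layout below are illustrative, not part of the repo.

```python
# Hypothetical subclass of Txt2ImgIterableBaseDataset; record layout is illustrative only.
import numpy as np
from ldm.data.base import Txt2ImgIterableBaseDataset


class DummyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    def __iter__(self):
        # Yield one dict per record; keys follow the usual image/caption convention.
        ids = self.valid_ids if self.valid_ids is not None else range(self.num_records)
        for idx in ids:
            yield {
                "image": np.zeros((self.size, self.size, 3), dtype=np.float32),
                "caption": f"dummy caption {idx}",
            }


ds = DummyTxt2ImgDataset(num_records=4, size=256)
for example in ds:
    print(example["caption"], example["image"].shape)
```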
/README.md:
--------------------------------------------------------------------------------
1 | # The Lottery Ticket Hypothesis in Denoising: Towards Semantic-Driven Initialization (ECCV 2024)
2 |
3 | ## [Project Page] [Paper]
4 |
5 | Refer to our previous work for more discussion of initial noise in diffusion models!
6 | - Guided Image Synthesis via Initial Image Editing in Diffusion Model (ACM MM 2023)
7 |
8 | 
9 |
10 |
11 | ## Setup
12 |
13 | Our codebase is built on [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
14 | and shares its dependencies and model architecture.
15 |
16 | ### Creating a Conda Environment
17 |
18 | ```
19 | conda env create -f environment.yaml
20 | conda activate ldm
21 | ```
22 |
23 | ### Downloading StableDiffusion Weights
24 |
25 | Download the StableDiffusion weights from the [CompVis organization at Hugging Face](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)
26 | (download the `sd-v1-4.ckpt` file), and link them:
27 | ```
28 | mkdir -p models/ldm/stable-diffusion-v1/
29 | ln -s <path/to/sd-v1-4.ckpt> models/ldm/stable-diffusion-v1/model.ckpt
30 | ```
31 | ## Hands on
32 |
33 | Play with the [hands-on notebook](./hands_on_ECCV.ipynb) to try our approach right away; refer to [utils.py](./utils.py) for the implementation.
34 |
35 | ## Citation
36 | ```
37 | @inproceedings{mao2024theLottery,
38 | title={The Lottery Ticket Hypothesis in Denoising: Towards Semantic-Driven Initialization},
39 | author={Mao, Jiafeng and Wang, Xueting and Aizawa, Kiyoharu},
40 | booktitle={ECCV},
41 | year={2024}
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
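Once the checkpoint is linked as above, loading it with the inference config follows the usual CompVis pattern. The `load_model_from_config` helper below is written here for illustration (it is not part of this file dump); it only relies on `OmegaConf` and `ldm.util.instantiate_from_config`, both of which ship with this codebase, and uses `strict=False` as a lenient sketch.

```python
# Hedged sketch of loading the linked checkpoint with the inference config.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config


def load_model_from_config(config_path, ckpt_path, device="cuda"):
    config = OmegaConf.load(config_path)
    state = torch.load(ckpt_path, map_location="cpu")
    sd = state["state_dict"] if "state_dict" in state else state
    model = instantiate_from_config(config.model)   # resolves the `target:` entry
    model.load_state_dict(sd, strict=False)
    return model.to(device).eval()


model = load_model_from_config(
    "configs/stable-diffusion/v1-inference.yaml",
    "models/ldm/stable-diffusion-v1/model.ckpt",
)
```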
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 | save_map: True
46 |
47 | first_stage_config:
48 | target: ldm.models.autoencoder.AutoencoderKL
49 | params:
50 | embed_dim: 4
51 | monitor: val/rec_loss
52 | ddconfig:
53 | double_z: true
54 | z_channels: 4
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 2
62 | - 4
63 | - 4
64 | num_res_blocks: 2
65 | attn_resolutions: []
66 | dropout: 0.0
67 | lossconfig:
68 | target: torch.nn.Identity
69 |
70 | cond_stage_config:
71 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72 |
--------------------------------------------------------------------------------
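Every `target:` entry in this config is a dotted import path resolved by `ldm.util.instantiate_from_config`, with the sibling `params:` block passed as keyword arguments. A quick sanity check (a sketch, assuming it is run from the repo root) also confirms the spatial bookkeeping: a `ch_mult` of length 4 means three 2x downsampling steps in the first stage, so 512x512 images become the 64x64, 4-channel latents the UNet expects.

```python
# Quick sanity check on the inference config (sketch; run from the repo root).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
params = cfg.model.params

# Each `target:` is resolved by ldm.util.instantiate_from_config.
print(cfg.model.target)                      # ldm.models.diffusion.ddpm.LatentDiffusion
print(params.cond_stage_config.target)       # ldm.modules.encoders.modules.FrozenCLIPEmbedder

# First-stage downsampling factor: one 2x step per ch_mult entry after the first.
f = 2 ** (len(params.first_stage_config.params.ddconfig.ch_mult) - 1)
print(f)                                     # 8 -> 512x512 pixels map to 64x64 latents
print(params.image_size, params.channels)    # 64 4 (matches the UNet's in/out channels)
```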
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
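A minimal usage sketch for `LitEma` with a toy model (the model and loop below are illustrative): the shadow weights are updated after every optimizer step, and the `store` / `copy_to` / `restore` trio swaps the EMA weights in for evaluation and back out again, as the docstrings above describe.

```python
# Minimal LitEma usage sketch; the linear model and loop are toy stand-ins.
import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(3):                       # training loop
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                           # update the shadow (EMA) weights after each step

# Evaluate with EMA weights, then restore the raw training weights.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation / save a checkpoint here ...
ema.restore(model.parameters())
```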
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
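`DiagonalGaussianDistribution` expects the encoder output with mean and log-variance stacked along the channel axis, i.e. `2 * z_channels` channels. A short sketch with a random stand-in for the encoder output:

```python
# Sketch: DiagonalGaussianDistribution over a toy "encoder output".
import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 64, 64)           # (batch, 2 * z_channels, h, w), here z_channels=4
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                        # reparameterised sample
print(z.shape)                                # torch.Size([2, 4, 64, 64])
print(posterior.kl().shape)                   # per-sample KL vs. N(0, I): torch.Size([2])
print(posterior.mode().shape)                 # the mean: torch.Size([2, 4, 64, 64])
```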
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
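A usage sketch for these dataset classes, assuming the LSUN file lists and images have been placed under `data/lsun/` as in the upstream latent-diffusion setup; each example carries an HWC image scaled to [-1, 1].

```python
# Usage sketch (assumes data/lsun/church_outdoor_train.txt and the images exist).
from torch.utils.data import DataLoader
from ldm.data.lsun import LSUNChurchesTrain

dataset = LSUNChurchesTrain(size=256)          # center-crop + resize to 256x256
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["image"].shape)                    # torch.Size([4, 256, 256, 3]), values in [-1, 1]
```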
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
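These schedulers return a learning-rate multiplier for a base LR of 1.0, so they are meant to be wrapped in `torch.optim.lr_scheduler.LambdaLR`. A sketch wiring up `LambdaLinearScheduler` with the `scheduler_config` values from `configs/stable-diffusion/v1-inference.yaml` (the optimizer and model are toy stand-ins):

```python
# Sketch: LambdaLinearScheduler (the scheduler_config in v1-inference.yaml) with LambdaLR.
import torch
from ldm.lr_scheduler import LambdaLinearScheduler

sched_fn = LambdaLinearScheduler(
    warm_up_steps=[10000], f_start=[1.e-6], f_max=[1.], f_min=[1.],
    cycle_lengths=[10000000000000],
)

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1.0e-4)      # base_learning_rate
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn.schedule)

for step in range(3):
    opt.step()
    scheduler.step()
    # Ramps linearly from ~1e-10 towards 1e-4 over the first 10k steps.
    print(step, opt.param_groups[0]["lr"])
```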
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
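A hedged sketch of the two-pass call pattern this loss is built for (the same generator/discriminator split used by the autoencoder's training step): `optimizer_idx=0` computes the pixel + LPIPS NLL, KL, and adaptive GAN term; `optimizer_idx=1` trains the discriminator on detached tensors. It requires the taming-transformers package and the LPIPS weights it downloads; the tiny decoder and posterior below are toy stand-ins, not the repo's autoencoder.

```python
# Hedged sketch of the generator (idx 0) / discriminator (idx 1) passes; toy tensors only.
import torch
from torch import nn
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1.0e-6, disc_weight=0.5)

decoder = nn.Conv2d(4, 3, 3, padding=1)                      # stand-in for the decoder
inputs = torch.randn(2, 3, 64, 64)
posteriors = DiagonalGaussianDistribution(torch.randn(2, 8, 16, 16))
reconstructions = decoder(nn.functional.interpolate(posteriors.sample(), scale_factor=4))

# Pass 0: autoencoder/generator update (NLL + KL + adaptive GAN term).
ae_loss, ae_log = loss_fn(inputs, reconstructions, posteriors, optimizer_idx=0,
                          global_step=0, last_layer=decoder.weight, split="train")

# Pass 1: discriminator update on detached inputs/reconstructions.
d_loss, d_log = loss_fn(inputs, reconstructions, posteriors, optimizer_idx=1,
                        global_step=0, split="train")
print(ae_log["train/rec_loss"], d_log["train/disc_loss"])
```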
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 | import torch.nn.functional as F
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
205 | def preprocess_prompts(prompts):
206 | if isinstance(prompts, (list, tuple)):
207 | return [p.lower().strip().strip(".").strip() for p in prompts]
208 | elif isinstance(prompts, str):
209 | return prompts.lower().strip().strip(".").strip()
210 | else:
211 | raise NotImplementedError
212 |
213 | def block_single_pixel(value, scale_factor=4):
214 | e = torch.zeros(256)
215 | e[value] = 1
216 | e = rearrange(e, '(w h)-> w h', w=16)
217 | e_resized = F.interpolate(e.reshape(1,1,16,16), scale_factor=scale_factor)[0][0]
218 | e_resized = rearrange(e_resized, 'w h -> (w h)')
219 | return torch.where(e_resized==1)[0]
220 |
221 | def priority_for_2_class(pixel_avai, cls_mask):
222 | assert len(pixel_avai) == 2
223 | rank_1 = cls_mask[0][pixel_avai[0].sort(descending=True)[1]]
224 | rank_2 = cls_mask[1][pixel_avai[1].sort(descending=True)[1]]
225 |
226 | priority_for_1 = torch.cat((rank_1, rank_2.flip(dims=(0,))),0)
227 | priority_for_2 = torch.cat((rank_2, rank_1.flip(dims=(0,))),0)
228 | return [priority_for_1, priority_for_2]
229 |
230 | def priority_for_1_class(pixel_avai, cls_mask):
231 | assert len(pixel_avai) == 1
232 | rank_1 = cls_mask[0][pixel_avai[0].sort(descending=True)[1]]
233 | return [rank_1]
234 |
235 | def flatten(maps):
236 | return rearrange(maps, 'w h -> (w h)')
237 |
238 | def diff(t1, t2):
239 | combined = torch.cat((t1, t2))
240 | uniques, counts = combined.unique(return_counts=True)
241 | difference = uniques[counts == 1]
242 | return difference
--------------------------------------------------------------------------------
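The helpers at the bottom of this file are specific to this repo's initial-noise construction. `block_single_pixel` maps one cell of a 16x16 grid to the matching `scale_factor` x `scale_factor` block of indices in the flattened upscaled map (64x64 for `scale_factor=4`), presumably so that regions of the initial latent can be assigned per concept (cf. `priority_for_2_class`); the interpretation is mine, but the index arithmetic below can be checked directly.

```python
# Sketch of the noise-construction helpers; the "one 16x16 cell -> one 4x4 latent block"
# reading is an interpretation, the index values themselves are just what the code computes.
import torch
from ldm.util import block_single_pixel, flatten, diff

idx = block_single_pixel(value=0, scale_factor=4)
print(idx.numel())   # 16 indices: rows 0-3, columns 0-3 of the flattened 64x64 map
print(idx[:4])       # tensor([0, 1, 2, 3]) -- the top-left row of that 4x4 block

print(flatten(torch.zeros(16, 16)).shape)   # torch.Size([256])

a, b = torch.arange(10), torch.arange(5)
print(diff(a, b))    # tensor([5, 6, 7, 8, 9]) -- elements appearing in exactly one input
```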
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 | from ldm.util import exists  # used in forward() below
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if not exists(codebook_loss):
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
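The two standalone helpers in this file are easy to check in isolation: `adopt_weight` holds the GAN weight at `value` until `global_step` reaches the threshold (`disc_start`), and `measure_perplexity` approaches `n_embed` when codebook usage is uniform.

```python
# Quick checks of the standalone helpers in this file.
import torch
from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

print(adopt_weight(1.0, global_step=100,   threshold=50001))   # 0.0 (before disc_start)
print(adopt_weight(1.0, global_step=60000, threshold=50001))   # 1.0 (after disc_start)

indices = torch.arange(1024) % 16            # uniform usage of 16 codes
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(perplexity, cluster_use)               # ~16.0, 16
```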
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 | from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 | class FrozenCLIPEmbedder(AbstractEncoder):
138 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140 | super().__init__()
141 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
142 | self.transformer = CLIPTextModel.from_pretrained(version)
143 | self.device = device
144 | self.max_length = max_length
145 | self.freeze()
146 |
147 | def freeze(self):
148 | self.transformer = self.transformer.eval()
149 | for param in self.parameters():
150 | param.requires_grad = False
151 |
152 | def forward(self, text):
153 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155 | tokens = batch_encoding["input_ids"].to(self.device)
156 | outputs = self.transformer(input_ids=tokens)
157 |
158 | z = outputs.last_hidden_state
159 | return z
160 |
161 | def encode(self, text):
162 | return self(text)
163 |
164 |
165 | class FrozenCLIPTextEmbedder(nn.Module):
166 | """
167 | Uses the CLIP transformer encoder for text.
168 | """
169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170 | super().__init__()
171 | self.model, _ = clip.load(version, jit=False, device="cpu")
172 | self.device = device
173 | self.max_length = max_length
174 | self.n_repeat = n_repeat
175 | self.normalize = normalize
176 |
177 | def freeze(self):
178 | self.model = self.model.eval()
179 | for param in self.parameters():
180 | param.requires_grad = False
181 |
182 | def forward(self, text):
183 | tokens = clip.tokenize(text).to(self.device)
184 | z = self.model.encode_text(tokens)
185 | if self.normalize:
186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187 | return z
188 |
189 | def encode(self, text):
190 | z = self(text)
191 | if z.ndim==2:
192 | z = z[:, None, :]
193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194 | return z
195 |
196 |
197 | class FrozenClipImageEmbedder(nn.Module):
198 | """
199 | Uses the CLIP image encoder.
200 | """
201 | def __init__(
202 | self,
203 | model,
204 | jit=False,
205 | device='cuda' if torch.cuda.is_available() else 'cpu',
206 | antialias=False,
207 | ):
208 | super().__init__()
209 | self.model, _ = clip.load(name=model, device=device, jit=jit)
210 |
211 | self.antialias = antialias
212 |
213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215 |
216 | def preprocess(self, x):
217 | # normalize to [0,1]
218 | x = kornia.geometry.resize(x, (224, 224),
219 | interpolation='bicubic',align_corners=True,
220 | antialias=self.antialias)
221 | x = (x + 1.) / 2.
222 | # renormalize according to clip
223 | x = kornia.enhance.normalize(x, self.mean, self.std)
224 | return x
225 |
226 | def forward(self, x):
227 | # x is assumed to be in range [-1,1]
228 | return self.model.encode_image(self.preprocess(x))
229 |
230 |
231 | if __name__ == "__main__":
232 | from ldm.util import count_params
233 | model = FrozenCLIPEmbedder()
234 | count_params(model, verbose=True)
--------------------------------------------------------------------------------
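`FrozenCLIPEmbedder` is the `cond_stage_config` target in `v1-inference.yaml`. A sketch of it turning a batch of prompts into the (B, 77, 768) context tensor consumed by the UNet's cross-attention; it downloads `openai/clip-vit-large-patch14` from Hugging Face on first use, and `device="cpu"` keeps the sketch GPU-free.

```python
# Sketch: encoding prompts with FrozenCLIPEmbedder (the cond_stage_config target).
import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(device="cpu")
with torch.no_grad():
    context = embedder.encode(["a photograph of an astronaut riding a horse",
                               "a cat and a dog"])
print(context.shape)                          # torch.Size([2, 77, 768]) -> UNet context_dim=768
```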
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
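# --- Editor's illustrative sketch (not part of the original file) ---
# How a beta schedule is typically turned into the cumulative alpha products
# ("alphacums") consumed by make_ddim_sampling_parameters() below:
#
#   betas = make_beta_schedule("linear", n_timestep=1000)   # np.ndarray, shape (1000,)
#   alphas = 1.0 - betas
#   alphas_cumprod = np.cumprod(alphas, axis=0)             # \bar{alpha}_t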
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
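# --- Editor's note (illustrative, not part of the original file) ---
# Example: make_ddim_timesteps('uniform', num_ddim_timesteps=50, num_ddpm_timesteps=1000)
# uses stride c = 20, giving [0, 20, ..., 980], and the +1 shift yields [1, 21, ..., 981].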
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
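# --- Editor's note (illustrative, not part of the original file) ---
# The sigma computed above follows the DDIM paper (arXiv:2010.02502):
#   sigma_t(eta) = eta * sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
# so eta = 0 gives deterministic DDIM sampling and eta = 1 recovers DDPM-like stochasticity.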
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
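# --- Editor's illustrative sketch (not part of the original file) ---
# The cosine schedule of Nichol & Dhariwal can be expressed through alpha_bar, e.g.
#   betas = betas_for_alpha_bar(
#       1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)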
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
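# --- Editor's illustrative sketch (not part of the original file; `block`, `x`, `emb` are hypothetical) ---
#   def run(x, emb):
#       return block(x, emb)
#   out = checkpoint(run, (x, emb), block.parameters(), flag=True)
# With flag=True the activations inside `run` are recomputed during backward by
# CheckpointFunction below instead of being stored in the forward pass.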
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
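# --- Editor's note (illustrative, not part of the original file) ---
# Example: timestep_embedding(torch.tensor([0, 10, 999]), dim=128) returns a
# (3, 128) tensor whose first 64 columns are cos(t * freq_i) and last 64 are
# sin(t * freq_i), with freq_i log-spaced from 1 down to roughly 1 / max_period.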
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
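# --- Editor's note (illustrative, not part of the original file) ---
# With repeat=True a single noise sample of shape (1, *shape[1:]) is tiled along
# the batch dimension, so every batch element receives identical noise.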
--------------------------------------------------------------------------------
/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
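    # --- Editor's note (illustrative, not part of the original file) ---
    # compute_top_k(logits, labels, k=5) above returns the fraction of samples whose
    # true label is among the 5 highest-scoring classes (top-5 accuracy); with
    # reduction="none" it returns one 0/1 value per sample instead.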
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/_ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9 | extract_into_tensor
10 |
11 |
12 | class DDIMSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28 | alphas_cumprod = self.model.alphas_cumprod
29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31 |
32 | self.register_buffer('betas', to_torch(self.model.betas))
33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35 |
36 | # calculations for diffusion q(x_t | x_{t-1}) and others
37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42 |
43 | # ddim sampling parameters
44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45 | ddim_timesteps=self.ddim_timesteps,
46 | eta=ddim_eta,verbose=verbose)
47 | self.register_buffer('ddim_sigmas', ddim_sigmas)
48 | self.register_buffer('ddim_alphas', ddim_alphas)
49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55 |
56 | @torch.no_grad()
57 | def sample(self,
58 | S,
59 | batch_size,
60 | shape,
61 | conditioning=None,
62 | callback=None,
63 | normals_sequence=None,
64 | img_callback=None,
65 | quantize_x0=False,
66 | eta=0.,
67 | mask=None,
68 | x0=None,
69 | temperature=1.,
70 | noise_dropout=0.,
71 | score_corrector=None,
72 | corrector_kwargs=None,
73 | verbose=True,
74 | x_T=None,
75 | log_every_t=100,
76 | unconditional_guidance_scale=1.,
77 | unconditional_conditioning=None,
78 |                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
79 | **kwargs
80 | ):
81 | if conditioning is not None:
82 | if isinstance(conditioning, dict):
83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
84 | if cbs != batch_size:
85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
86 | else:
87 | if conditioning.shape[0] != batch_size:
88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
89 |
90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
91 | # sampling
92 | C, H, W = shape
93 | size = (batch_size, C, H, W)
94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
95 |
96 | samples, intermediates = self.ddim_sampling(conditioning, size,
97 | callback=callback,
98 | img_callback=img_callback,
99 | quantize_denoised=quantize_x0,
100 | mask=mask, x0=x0,
101 | ddim_use_original_steps=False,
102 | noise_dropout=noise_dropout,
103 | temperature=temperature,
104 | score_corrector=score_corrector,
105 | corrector_kwargs=corrector_kwargs,
106 | x_T=x_T,
107 | log_every_t=log_every_t,
108 | unconditional_guidance_scale=unconditional_guidance_scale,
109 | unconditional_conditioning=unconditional_conditioning,
110 | )
111 | return samples, intermediates
112 |
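    # --- Editor's illustrative sketch (not part of the original file) ---
    # `model`, `cond` and `uc` are assumed to come from a loaded LatentDiffusion
    # checkpoint and its cond_stage_model; they are not defined here.
    #
    #   sampler = DDIMSampler(model)
    #   samples, _ = sampler.sample(S=50, batch_size=4, shape=(4, 64, 64),
    #                               conditioning=cond, eta=0.0,
    #                               unconditional_guidance_scale=7.5,
    #                               unconditional_conditioning=uc)
    # where (4, 64, 64) is e.g. the latent shape used for 512x512 images.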
113 | @torch.no_grad()
114 | def ddim_sampling(self, cond, shape,
115 | x_T=None, ddim_use_original_steps=False,
116 | callback=None, timesteps=None, quantize_denoised=False,
117 | mask=None, x0=None, img_callback=None, log_every_t=100,
118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
120 | device = self.model.betas.device
121 | b = shape[0]
122 | if x_T is None:
123 | img = torch.randn(shape, device=device)
124 | else:
125 | img = x_T
126 |
127 | if timesteps is None:
128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
129 | elif timesteps is not None and not ddim_use_original_steps:
130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
131 | timesteps = self.ddim_timesteps[:subset_end]
132 |
133 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
136 | print(f"Running DDIM Sampling with {total_steps} timesteps")
137 |
138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
139 |
140 | for i, step in enumerate(iterator):
141 | index = total_steps - i - 1
142 | ts = torch.full((b,), step, device=device, dtype=torch.long)
143 |
144 | if mask is not None:
145 | assert x0 is not None
146 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
147 | img = img_orig * mask + (1. - mask) * img
148 |
149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
150 | quantize_denoised=quantize_denoised, temperature=temperature,
151 | noise_dropout=noise_dropout, score_corrector=score_corrector,
152 | corrector_kwargs=corrector_kwargs,
153 | unconditional_guidance_scale=unconditional_guidance_scale,
154 | unconditional_conditioning=unconditional_conditioning)
155 | img, pred_x0 = outs
156 | if callback: callback(i)
157 | if img_callback: img_callback(pred_x0, i)
158 |
159 | if index % log_every_t == 0 or index == total_steps - 1:
160 | intermediates['x_inter'].append(img)
161 | intermediates['pred_x0'].append(pred_x0)
162 |
163 | return img, intermediates
164 |
165 | @torch.no_grad()
166 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
167 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
168 | unconditional_guidance_scale=1., unconditional_conditioning=None):
169 | b, *_, device = *x.shape, x.device
170 |
171 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
172 | e_t = self.model.apply_model(x, t, c)
173 | else:
174 | x_in = torch.cat([x] * 2)
175 | t_in = torch.cat([t] * 2)
176 | c_in = torch.cat([unconditional_conditioning, c])
177 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
178 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
179 |
180 | if score_corrector is not None:
181 | assert self.model.parameterization == "eps"
182 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
183 |
184 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
185 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
186 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
187 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
188 | # select parameters corresponding to the currently considered timestep
189 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
190 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
191 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
192 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
193 |
194 | # current prediction for x_0
195 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
196 | if quantize_denoised:
197 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
198 | # direction pointing to x_t
199 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
200 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
201 | if noise_dropout > 0.:
202 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
203 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
204 | return x_prev, pred_x0
205 |
206 | @torch.no_grad()
207 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
208 | # fast, but does not allow for exact reconstruction
209 | # t serves as an index to gather the correct alphas
210 | if use_original_steps:
211 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
212 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
213 | else:
214 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
215 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
216 |
217 | if noise is None:
218 | noise = torch.randn_like(x0)
219 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
220 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
221 |
222 | @torch.no_grad()
223 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
224 | use_original_steps=False):
225 |
226 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
227 | timesteps = timesteps[:t_start]
228 |
229 | time_range = np.flip(timesteps)
230 | total_steps = timesteps.shape[0]
231 | print(f"Running DDIM Sampling with {total_steps} timesteps")
232 |
233 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
234 | x_dec = x_latent
235 | for i, step in enumerate(iterator):
236 | index = total_steps - i - 1
237 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
238 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239 | unconditional_guidance_scale=unconditional_guidance_scale,
240 | unconditional_conditioning=unconditional_conditioning)
241 | return x_dec
--------------------------------------------------------------------------------
/ldm/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os, yaml, pickle, shutil, tarfile, glob
2 | import cv2
3 | import albumentations
4 | import PIL
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | from omegaconf import OmegaConf
8 | from functools import partial
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, Subset
12 |
13 | import taming.data.utils as tdu
14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15 | from taming.data.imagenet import ImagePaths
16 |
17 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18 |
19 |
20 | def synset2idx(path_to_yaml="data/index_synset.yaml"):
21 | with open(path_to_yaml) as f:
22 |         di2s = yaml.safe_load(f)  # safe_load avoids the missing-Loader error raised by modern PyYAML
23 | return dict((v,k) for k,v in di2s.items())
24 |
25 |
26 | class ImageNetBase(Dataset):
27 | def __init__(self, config=None):
28 | self.config = config or OmegaConf.create()
29 | if not type(self.config)==dict:
30 | self.config = OmegaConf.to_container(self.config)
31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33 | self._prepare()
34 | self._prepare_synset_to_human()
35 | self._prepare_idx_to_synset()
36 | self._prepare_human_to_integer_label()
37 | self._load()
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, i):
43 | return self.data[i]
44 |
45 | def _prepare(self):
46 | raise NotImplementedError()
47 |
48 | def _filter_relpaths(self, relpaths):
49 | ignore = set([
50 | "n06596364_9591.JPEG",
51 | ])
52 | relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53 | if "sub_indices" in self.config:
54 | indices = str_to_indices(self.config["sub_indices"])
55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57 | files = []
58 | for rpath in relpaths:
59 | syn = rpath.split("/")[0]
60 | if syn in synsets:
61 | files.append(rpath)
62 | return files
63 | else:
64 | return relpaths
65 |
66 | def _prepare_synset_to_human(self):
67 | SIZE = 2655750
68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69 | self.human_dict = os.path.join(self.root, "synset_human.txt")
70 | if (not os.path.exists(self.human_dict) or
71 | not os.path.getsize(self.human_dict)==SIZE):
72 | download(URL, self.human_dict)
73 |
74 | def _prepare_idx_to_synset(self):
75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77 | if (not os.path.exists(self.idx2syn)):
78 | download(URL, self.idx2syn)
79 |
80 | def _prepare_human_to_integer_label(self):
81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83 | if (not os.path.exists(self.human2integer)):
84 | download(URL, self.human2integer)
85 | with open(self.human2integer, "r") as f:
86 | lines = f.read().splitlines()
87 | assert len(lines) == 1000
88 | self.human2integer_dict = dict()
89 | for line in lines:
90 | value, key = line.split(":")
91 | self.human2integer_dict[key] = int(value)
92 |
93 | def _load(self):
94 | with open(self.txt_filelist, "r") as f:
95 | self.relpaths = f.read().splitlines()
96 | l1 = len(self.relpaths)
97 | self.relpaths = self._filter_relpaths(self.relpaths)
98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99 |
100 | self.synsets = [p.split("/")[0] for p in self.relpaths]
101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102 |
103 | unique_synsets = np.unique(self.synsets)
104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105 | if not self.keep_orig_class_label:
106 | self.class_labels = [class_dict[s] for s in self.synsets]
107 | else:
108 | self.class_labels = [self.synset2idx[s] for s in self.synsets]
109 |
110 | with open(self.human_dict, "r") as f:
111 | human_dict = f.read().splitlines()
112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113 |
114 | self.human_labels = [human_dict[s] for s in self.synsets]
115 |
116 | labels = {
117 | "relpath": np.array(self.relpaths),
118 | "synsets": np.array(self.synsets),
119 | "class_label": np.array(self.class_labels),
120 | "human_label": np.array(self.human_labels),
121 | }
122 |
123 | if self.process_images:
124 | self.size = retrieve(self.config, "size", default=256)
125 | self.data = ImagePaths(self.abspaths,
126 | labels=labels,
127 | size=self.size,
128 | random_crop=self.random_crop,
129 | )
130 | else:
131 | self.data = self.abspaths
132 |
133 |
134 | class ImageNetTrain(ImageNetBase):
135 | NAME = "ILSVRC2012_train"
136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138 | FILES = [
139 | "ILSVRC2012_img_train.tar",
140 | ]
141 | SIZES = [
142 | 147897477120,
143 | ]
144 |
145 | def __init__(self, process_images=True, data_root=None, **kwargs):
146 | self.process_images = process_images
147 | self.data_root = data_root
148 | super().__init__(**kwargs)
149 |
150 | def _prepare(self):
151 | if self.data_root:
152 | self.root = os.path.join(self.data_root, self.NAME)
153 | else:
154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156 |
157 | self.datadir = os.path.join(self.root, "data")
158 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
159 | self.expected_length = 1281167
160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161 | default=True)
162 | if not tdu.is_prepared(self.root):
163 | # prep
164 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
165 |
166 | datadir = self.datadir
167 | if not os.path.exists(datadir):
168 | path = os.path.join(self.root, self.FILES[0])
169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170 | import academictorrents as at
171 | atpath = at.get(self.AT_HASH, datastore=self.root)
172 | assert atpath == path
173 |
174 | print("Extracting {} to {}".format(path, datadir))
175 | os.makedirs(datadir, exist_ok=True)
176 | with tarfile.open(path, "r:") as tar:
177 | tar.extractall(path=datadir)
178 |
179 | print("Extracting sub-tars.")
180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181 | for subpath in tqdm(subpaths):
182 | subdir = subpath[:-len(".tar")]
183 | os.makedirs(subdir, exist_ok=True)
184 | with tarfile.open(subpath, "r:") as tar:
185 | tar.extractall(path=subdir)
186 |
187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189 | filelist = sorted(filelist)
190 | filelist = "\n".join(filelist)+"\n"
191 | with open(self.txt_filelist, "w") as f:
192 | f.write(filelist)
193 |
194 | tdu.mark_prepared(self.root)
195 |
196 |
197 | class ImageNetValidation(ImageNetBase):
198 | NAME = "ILSVRC2012_validation"
199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202 | FILES = [
203 | "ILSVRC2012_img_val.tar",
204 | "validation_synset.txt",
205 | ]
206 | SIZES = [
207 | 6744924160,
208 | 1950000,
209 | ]
210 |
211 | def __init__(self, process_images=True, data_root=None, **kwargs):
212 | self.data_root = data_root
213 | self.process_images = process_images
214 | super().__init__(**kwargs)
215 |
216 | def _prepare(self):
217 | if self.data_root:
218 | self.root = os.path.join(self.data_root, self.NAME)
219 | else:
220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222 | self.datadir = os.path.join(self.root, "data")
223 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
224 | self.expected_length = 50000
225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226 | default=False)
227 | if not tdu.is_prepared(self.root):
228 | # prep
229 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
230 |
231 | datadir = self.datadir
232 | if not os.path.exists(datadir):
233 | path = os.path.join(self.root, self.FILES[0])
234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235 | import academictorrents as at
236 | atpath = at.get(self.AT_HASH, datastore=self.root)
237 | assert atpath == path
238 |
239 | print("Extracting {} to {}".format(path, datadir))
240 | os.makedirs(datadir, exist_ok=True)
241 | with tarfile.open(path, "r:") as tar:
242 | tar.extractall(path=datadir)
243 |
244 | vspath = os.path.join(self.root, self.FILES[1])
245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246 | download(self.VS_URL, vspath)
247 |
248 | with open(vspath, "r") as f:
249 | synset_dict = f.read().splitlines()
250 | synset_dict = dict(line.split() for line in synset_dict)
251 |
252 | print("Reorganizing into synset folders")
253 | synsets = np.unique(list(synset_dict.values()))
254 | for s in synsets:
255 | os.makedirs(os.path.join(datadir, s), exist_ok=True)
256 | for k, v in synset_dict.items():
257 | src = os.path.join(datadir, k)
258 | dst = os.path.join(datadir, v)
259 | shutil.move(src, dst)
260 |
261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263 | filelist = sorted(filelist)
264 | filelist = "\n".join(filelist)+"\n"
265 | with open(self.txt_filelist, "w") as f:
266 | f.write(filelist)
267 |
268 | tdu.mark_prepared(self.root)
269 |
270 |
271 |
272 | class ImageNetSR(Dataset):
273 | def __init__(self, size=None,
274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275 | random_crop=True):
276 | """
277 |         ImageNet super-resolution dataloader.
278 |         Performs the following ops in order:
279 |         1. crops a square patch of side length s from the image (random or center crop)
280 |         2. resizes the crop to `size` with cv2 area interpolation
281 |         3. degrades the resized crop with degradation_fn
282 |
283 | :param size: resizing to size after cropping
284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285 | :param downscale_f: Low Resolution Downsample factor
286 | :param min_crop_f: determines crop size s,
287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288 | :param max_crop_f: ""
289 | :param data_root:
290 | :param random_crop:
291 | """
292 | self.base = self.get_base()
293 | assert size
294 | assert (size / downscale_f).is_integer()
295 | self.size = size
296 | self.LR_size = int(size / downscale_f)
297 | self.min_crop_f = min_crop_f
298 | self.max_crop_f = max_crop_f
299 | assert(max_crop_f <= 1.)
300 | self.center_crop = not random_crop
301 |
302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303 |
304 |         self.pil_interpolation = False  # gets reset later in case interp_op is from pillow
305 |
306 | if degradation == "bsrgan":
307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308 |
309 | elif degradation == "bsrgan_light":
310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311 |
312 | else:
313 | interpolation_fn = {
314 | "cv_nearest": cv2.INTER_NEAREST,
315 | "cv_bilinear": cv2.INTER_LINEAR,
316 | "cv_bicubic": cv2.INTER_CUBIC,
317 | "cv_area": cv2.INTER_AREA,
318 | "cv_lanczos": cv2.INTER_LANCZOS4,
319 | "pil_nearest": PIL.Image.NEAREST,
320 | "pil_bilinear": PIL.Image.BILINEAR,
321 | "pil_bicubic": PIL.Image.BICUBIC,
322 | "pil_box": PIL.Image.BOX,
323 | "pil_hamming": PIL.Image.HAMMING,
324 | "pil_lanczos": PIL.Image.LANCZOS,
325 | }[degradation]
326 |
327 | self.pil_interpolation = degradation.startswith("pil_")
328 |
329 | if self.pil_interpolation:
330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331 |
332 | else:
333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334 | interpolation=interpolation_fn)
335 |
336 | def __len__(self):
337 | return len(self.base)
338 |
339 | def __getitem__(self, i):
340 | example = self.base[i]
341 | image = Image.open(example["file_path_"])
342 |
343 | if not image.mode == "RGB":
344 | image = image.convert("RGB")
345 |
346 | image = np.array(image).astype(np.uint8)
347 |
348 | min_side_len = min(image.shape[:2])
349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350 | crop_side_len = int(crop_side_len)
351 |
352 | if self.center_crop:
353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354 |
355 | else:
356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357 |
358 | image = self.cropper(image=image)["image"]
359 | image = self.image_rescaler(image=image)["image"]
360 |
361 | if self.pil_interpolation:
362 | image_pil = PIL.Image.fromarray(image)
363 | LR_image = self.degradation_process(image_pil)
364 | LR_image = np.array(LR_image).astype(np.uint8)
365 |
366 | else:
367 | LR_image = self.degradation_process(image=image)["image"]
368 |
369 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371 |
372 | return example
373 |
374 |
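# --- Editor's illustrative sketch (not part of the original file) ---
# The concrete subclasses below supply get_base(); a typical instantiation is
#   dset = ImageNetSRValidation(size=256, degradation="bsrgan_light", downscale_f=4)
#   item = dset[0]   # dict with "image" (256x256x3) and "LR_image" (64x64x3), both in [-1, 1]
# assuming the pickled index files referenced in get_base() are present under data/.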
375 | class ImageNetSRTrain(ImageNetSR):
376 | def __init__(self, **kwargs):
377 | super().__init__(**kwargs)
378 |
379 | def get_base(self):
380 | with open("data/imagenet_train_hr_indices.p", "rb") as f:
381 | indices = pickle.load(f)
382 | dset = ImageNetTrain(process_images=False,)
383 | return Subset(dset, indices)
384 |
385 |
386 | class ImageNetSRValidation(ImageNetSR):
387 | def __init__(self, **kwargs):
388 | super().__init__(**kwargs)
389 |
390 | def get_base(self):
391 | with open("data/imagenet_val_hr_indices.p", "rb") as f:
392 | indices = pickle.load(f)
393 | dset = ImageNetValidation(process_images=False,)
394 | return Subset(dset, indices)
395 |
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import pdb
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 | from einops import rearrange, repeat
8 |
9 | from ldm.modules.diffusionmodules.util import checkpoint
10 | import numpy as np
11 | import scipy.stats as st
12 |
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 |
18 | def uniq(arr):
19 |     return {el: True for el in arr}.keys()
20 |
21 |
22 | def default(val, d):
23 | if exists(val):
24 | return val
25 | return d() if isfunction(d) else d
26 |
27 |
28 | def max_neg_value(t):
29 | return -torch.finfo(t.dtype).max
30 |
31 |
32 | def init_(tensor):
33 | dim = tensor.shape[-1]
34 | std = 1 / math.sqrt(dim)
35 | tensor.uniform_(-std, std)
36 | return tensor
37 |
38 |
39 | # feedforward
40 | class GEGLU(nn.Module):
41 | def __init__(self, dim_in, dim_out):
42 | super().__init__()
43 | self.proj = nn.Linear(dim_in, dim_out * 2)
44 |
45 | def forward(self, x):
46 | x, gate = self.proj(x).chunk(2, dim=-1)
47 | return x * F.gelu(gate)
48 |
49 |
50 | class FeedForward(nn.Module):
51 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
52 | super().__init__()
53 | inner_dim = int(dim * mult)
54 | dim_out = default(dim_out, dim)
55 | project_in = nn.Sequential(
56 | nn.Linear(dim, inner_dim),
57 | nn.GELU()
58 | ) if not glu else GEGLU(dim, inner_dim)
59 |
60 | self.net = nn.Sequential(
61 | project_in,
62 | nn.Dropout(dropout),
63 | nn.Linear(inner_dim, dim_out)
64 | )
65 |
66 | def forward(self, x):
67 | return self.net(x)
68 |
69 |
70 | def zero_module(module):
71 | """
72 | Zero out the parameters of a module and return it.
73 | """
74 | for p in module.parameters():
75 | p.detach().zero_()
76 | return module
77 |
78 |
79 | def Normalize(in_channels):
80 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
81 |
82 |
83 | class LinearAttention(nn.Module):
84 | def __init__(self, dim, heads=4, dim_head=32):
85 | super().__init__()
86 | self.heads = heads
87 | hidden_dim = dim_head * heads
88 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
89 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
90 |
91 | def forward(self, x):
92 | b, c, h, w = x.shape
93 | qkv = self.to_qkv(x)
94 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
95 | k = k.softmax(dim=-1)
96 | context = torch.einsum('bhdn,bhen->bhde', k, v)
97 | out = torch.einsum('bhde,bhdn->bhen', context, q)
98 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
99 | return self.to_out(out)
100 |
101 |
102 | class SpatialSelfAttention(nn.Module):
103 | def __init__(self, in_channels):
104 | super().__init__()
105 | self.in_channels = in_channels
106 |
107 | self.norm = Normalize(in_channels)
108 | self.q = torch.nn.Conv2d(in_channels,
109 | in_channels,
110 | kernel_size=1,
111 | stride=1,
112 | padding=0)
113 | self.k = torch.nn.Conv2d(in_channels,
114 | in_channels,
115 | kernel_size=1,
116 | stride=1,
117 | padding=0)
118 | self.v = torch.nn.Conv2d(in_channels,
119 | in_channels,
120 | kernel_size=1,
121 | stride=1,
122 | padding=0)
123 | self.proj_out = torch.nn.Conv2d(in_channels,
124 | in_channels,
125 | kernel_size=1,
126 | stride=1,
127 | padding=0)
128 |
129 | def forward(self, x):
130 | h_ = x
131 | h_ = self.norm(h_)
132 | q = self.q(h_)
133 | k = self.k(h_)
134 | v = self.v(h_)
135 |
136 | # compute attention
137 | b,c,h,w = q.shape
138 | q = rearrange(q, 'b c h w -> b (h w) c')
139 | k = rearrange(k, 'b c h w -> b c (h w)')
140 | w_ = torch.einsum('bij,bjk->bik', q, k)
141 |
142 | w_ = w_ * (int(c)**(-0.5))
143 | w_ = torch.nn.functional.softmax(w_, dim=2)
144 |
145 | # attend to values
146 | v = rearrange(v, 'b c h w -> b c (h w)')
147 | w_ = rearrange(w_, 'b i j -> b j i')
148 | h_ = torch.einsum('bij,bjk->bik', v, w_)
149 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
150 | h_ = self.proj_out(h_)
151 |
152 | return x+h_
153 |
154 |
155 | class CrossAttention(nn.Module):
156 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., struct_attn=False, save_map=False):
157 | super().__init__()
158 | inner_dim = dim_head * heads
159 | context_dim = default(context_dim, query_dim)
160 |
161 | self.scale = dim_head ** -0.5
162 | self.heads = heads
163 |
164 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
165 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
166 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
167 |
168 | self.to_out = nn.Sequential(
169 | nn.Linear(inner_dim, query_dim),
170 | nn.Dropout(dropout)
171 | )
172 |
173 | self.struct_attn = struct_attn
174 | self.save_map = save_map
175 |
176 | def schedule(self, total_step=50):
177 | # from 4.6 to 0
178 | seq = {}
179 | seq['linear'] = (torch.arange(total_step).flip(0) / total_step) * 4.6
180 | seq['log_smooth'] = torch.log(torch.arange(total_step).flip(0) * 100 + 1)
181 | seq['log_rapid'] = - torch.log(torch.arange(total_step) + 0.001)/1.5
182 | return seq
183 |
184 | def extra_mask_generation(self, down_scale, mask_cond):
185 | hw = int(512/down_scale)
186 |         attn_extra = torch.zeros([8, int(hw ** 2), 77]).cuda()
187 | attn_extra = rearrange(attn_extra, 'b (h w) l -> b h w l', h = hw)
188 | obj_infs = mask_cond['object_infs']
189 | for category in obj_infs:
190 | for bbox in category['loc']:
191 | x1, y1, x2, y2 = [xy/down_scale for xy in bbox]
192 | x1 = int(max(x1 - 1, 0))
193 | y1 = int(max(y1 - 1, 0))
194 | x2 = int(min(x2 + 1, hw))
195 | y2 = int(min(y2 + 1, hw))
196 | if mask_cond['para']['soft']:
197 | w = x2 - x1
198 | h = y2 - y1
199 | if w > h:
200 | offset = int(w/2)
201 | w_step = 1
202 | h_step = w/h
203 | else:
204 | offset = int(h/2)
205 | h_step = 1
206 | w_step = h/w
207 | if w > 1 and h > 1:
208 | y, x = np.mgrid[-offset:(offset+0.2):h_step, -offset:(offset+0.2):w_step]
209 | pos = np.empty(y.shape + (2,))
210 | pos[:, :, 0] = (x/offset) * mask_cond['para']['L_soft']
211 | pos[:, :, 1] = (y/offset) * mask_cond['para']['L_soft']
212 | rv = st.multivariate_normal([0, 0], [[1, 0], [0, 1]])
213 | _, tw, th = attn_extra[:, y1:y2, x1:x2, 0].shape
214 | value = torch.from_numpy(rv.pdf(pos)/(rv.pdf(pos).max())).cuda()[:tw, :th]
215 | for p in category['prompt']:
216 | if mask_cond['para']['neg_out']:
217 | attn_extra[:, :, :, p] = -9999999999
218 | attn_extra[:, y1:y2, x1:x2, p] = value
219 | else:
220 | for p in category['prompt']:
221 | if mask_cond['para']['neg_out']:
222 | attn_extra[:, :, :, p] = -9999999999
223 | attn_extra[:, y1:y2, x1:x2, p] = 1
224 | else:
225 | for p in category['prompt']:
226 | if mask_cond['para']['neg_out']:
227 | attn_extra[:, :, :, p] = -9999999999
228 | attn_extra[:, y1:y2, x1:x2, p] = 1
229 | attn_extra = rearrange(attn_extra, 'b h w l -> b (h w) l')
230 | return attn_extra
231 |
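    # --- Editor's illustrative sketch (not part of the original file) ---
    # Structure of the mask_cond dict as read by extra_mask_generation() above and
    # _masked_qkv() below; all concrete values here are made-up examples:
    #
    #   mask_cond = {
    #       'is_use': True,
    #       'object_infs': [
    #           {'loc': [[128, 128, 384, 384]],   # bounding boxes in 512x512 pixel coords
    #            'prompt': [2, 3]},               # text-token indices (< 77) to emphasize
    #       ],
    #       'para': {'soft': True, 'L_soft': 2.0, 'neg_out': False, 'w_dot': 0.5},
    #   }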
232 | def forward(self, x, context=None, t=None, mask_cond=None, mask=None):
233 | h = self.heads
234 | q = self.to_q(x)
235 | down_scale = 512 / math.sqrt(x.shape[1])
236 |
237 | if t is not None:
238 | t = 50 - t - 1
239 | weights = self.schedule(total_step=50)
240 | weight = weights['linear'][t]
241 | if isinstance(context, list):
242 | if self.struct_attn:
243 | out = self.masked_qkv(q, context, weight, down_scale, mask_cond, mask)
244 | else:
245 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context
246 | out = self.normal_qkv(q, context, mask)
247 | else:
248 | context = default(context, x)
249 | out = self.normal_qkv(q, context, mask)
250 |
251 | return self.to_out(out)
252 |
253 | def masked_qkv(self, q, context, weight, down_scale, mask_cond, mask):
254 |         """
255 |         context: [uc_context, {'k': [key contexts], 'v': [value contexts]}]
256 |         """
257 | uc_context = context[0]
258 | context_k, context_v = context[1]['k'], context[1]['v']
259 |
260 | if isinstance(context_k, list) and isinstance(context_v, list):
261 | out = self._masked_qkv(q, uc_context, context_k, context_v, mask_cond, weight, down_scale, mask)
262 | else:
263 | raise NotImplementedError
264 |
265 | return out
266 |
267 | def _masked_qkv(self, q, uc_context, context_k, context_v, mask_cond, weight, down_scale, mask):
268 | h = self.heads
269 |
270 | assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
271 | true_bs = uc_context.size(0) * h
272 |
273 | k_uc, v_uc = self.get_kv(uc_context)
274 | k_c = [self.to_k(c_k) for c_k in context_k]
275 | v_c = [self.to_v(c_v) for c_v in context_v]
276 |
277 |
278 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
279 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h)
280 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h)
281 |
282 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c]
283 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c]
284 |
285 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale
286 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c]
287 |
288 | if mask_cond['is_use']:
289 | attn_extra = self.extra_mask_generation(down_scale, mask_cond)
290 | w_dot = mask_cond['para']['w_dot']
291 | w = w_dot * weight * sim_c[0].max()
292 | sim_c[0] = sim_c[0] + w * attn_extra
293 |
294 | attn_uc = sim_uc.softmax(dim=-1)
295 |
296 | attn_c = [sim.softmax(dim=-1) for sim in sim_c]
297 | if self.save_map and sim_uc.size(1) != sim_uc.size(2):
298 | self.save_attn_maps(attn_c)
299 | if mask_cond['is_use']:
300 | self.save_extra_attn_maps([attn_extra])
301 |
302 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc)
303 | n_keys, n_values = len(k_c), len(v_c)
304 | if n_keys == n_values:
305 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c)
306 | else:
307 | assert n_keys == 1 or n_values == 1
308 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn in attn_c for v in v_c]) / (n_keys * n_values)
309 |
310 | out = torch.cat([out_uc, out_c], dim=0)
311 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
312 |
313 | return out
314 |
315 |
316 | def normal_qkv(self, q, context, mask):
317 | h = self.heads
318 |
319 | k = self.to_k(context)
320 | v = self.to_v(context)
321 |
322 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
323 |
324 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
325 |
326 | if exists(mask):
327 | mask = rearrange(mask, 'b ... -> b (...)')
328 | max_neg_value = -torch.finfo(sim.dtype).max
329 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
330 | sim.masked_fill_(~mask, max_neg_value)
331 |
332 | attn = sim.softmax(dim=-1)
333 |
334 | if self.save_map and sim.size(1) != sim.size(2):
335 | self.save_attn_maps(attn.chunk(2)[1])
336 |
337 | out = einsum('b i j, b j d -> b i d', attn, v)
338 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
339 |
340 | return out
341 |
342 | def get_kv(self, context):
343 | return self.to_k(context), self.to_v(context)
344 |
345 | def save_attn_maps(self, attn):
346 | h = self.heads
347 | if isinstance(attn, list):
348 | height = width = int(math.sqrt(attn[0].size(1)))
349 | self.attn_maps = [rearrange(m.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:40].cpu() for m in attn]
350 | else:
351 | height = width = int(math.sqrt(attn.size(1)))
352 | self.attn_maps = rearrange(attn.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:40].cpu()
353 | def save_extra_attn_maps(self, attn):
354 | h = self.heads
355 | if isinstance(attn, list):
356 | height = width = int(math.sqrt(attn[0].size(1)))
357 | self.attn_extra = [rearrange(m.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:40].cpu() for m in attn]
358 | else:
359 | height = width = int(math.sqrt(attn.size(1)))
360 | self.attn_extra = rearrange(attn.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:40].cpu()
361 |
362 |
363 | class BasicTransformerBlock(nn.Module):
364 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, struct_attn=False, save_map=False):
365 | super().__init__()
366 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
367 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
368 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
369 | heads=n_heads, dim_head=d_head, dropout=dropout,
370 |                                     struct_attn=struct_attn, save_map=save_map) # is self-attn if context is None
371 | self.norm1 = nn.LayerNorm(dim)
372 | self.norm2 = nn.LayerNorm(dim)
373 | self.norm3 = nn.LayerNorm(dim)
374 | self.checkpoint = checkpoint
375 |
376 | def forward(self, x, context=None, t=None, mask_cond=None):
377 | return checkpoint(self._forward, (x, context, t, mask_cond), self.parameters(), self.checkpoint)
378 |
379 | def _forward(self, x, context=None, t=None, mask_cond=None):
380 | x = self.attn1(self.norm1(x)) + x
381 | x = self.attn2(self.norm2(x), context=context, t=t, mask_cond=mask_cond) + x
382 | x = self.ff(self.norm3(x)) + x
383 | return x
384 |
385 |
386 | class SpatialTransformer(nn.Module):
387 | """
388 | Transformer block for image-like data.
389 | First, project the input (aka embedding)
390 | and reshape to b, t, d.
391 |     Then apply the standard transformer blocks.
392 |     Finally, reshape back to the image layout.
393 | """
394 | def __init__(self, in_channels, n_heads, d_head,
395 | depth=1, dropout=0., context_dim=None, struct_attn=False, save_map=False):
396 | super().__init__()
397 | self.in_channels = in_channels
398 | inner_dim = n_heads * d_head
399 | self.norm = Normalize(in_channels)
400 |
401 | self.proj_in = nn.Conv2d(in_channels,
402 | inner_dim,
403 | kernel_size=1,
404 | stride=1,
405 | padding=0)
406 |
407 | self.transformer_blocks = nn.ModuleList(
408 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, struct_attn=struct_attn, save_map=save_map)
409 | for d in range(depth)]
410 | )
411 |
412 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
413 | in_channels,
414 | kernel_size=1,
415 | stride=1,
416 | padding=0))
417 | self.struct_attn = struct_attn
418 |
419 | def forward(self, x, context=None, t=None, mask_cond=None):
420 | # note: if no context is given, cross-attention defaults to self-attention
421 | b, c, h, w = x.shape
422 | x_in = x
423 | x = self.norm(x)
424 | x = self.proj_in(x)
425 | x = rearrange(x, 'b c h w -> b (h w) c')
426 | for block in self.transformer_blocks:
427 | x = block(x, context=context, t=t, mask_cond=mask_cond)
428 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
429 | x = self.proj_out(x)
430 | return x + x_in
--------------------------------------------------------------------------------
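The structured cross-attention above is driven by two small dictionaries rather than a plain conditioning tensor. A minimal sketch of the expected inputs, assuming only a loaded model that exposes get_learned_conditioning; the prompt is illustrative, and the real call sites are in utils.py further below:

emb = model.get_learned_conditioning(["a cat and a dog."])  # typically (1, 77, 768) for the v1 CLIP text encoder
cond = {'k': [emb], 'v': [emb]}   # lists of key/value contexts consumed by _masked_qkv
mask_cond = {'is_use': False}     # False skips the extra attention bias entirely
# When is_use is True, _masked_qkv also reads mask_cond['para']['w_dot'] and adds
# w_dot * weight * sim.max() * extra_mask to the conditional attention logits.
# With save_map=True, attention maps are stored as (batch, heads, height, width, token),
# truncated to the first 40 text tokens.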
/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | from collections import defaultdict
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | from functools import partial
8 |
9 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
10 |
11 |
12 | class DDIMSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28 | alphas_cumprod = self.model.alphas_cumprod
29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31 |
32 | self.register_buffer('betas', to_torch(self.model.betas))
33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35 |
36 | # calculations for diffusion q(x_t | x_{t-1}) and others
37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42 |
43 | # ddim sampling parameters
44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45 | ddim_timesteps=self.ddim_timesteps,
46 | eta=ddim_eta,verbose=verbose)
47 | self.register_buffer('ddim_sigmas', ddim_sigmas)
48 | self.register_buffer('ddim_alphas', ddim_alphas)
49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55 |
56 | @torch.no_grad()
57 | def sample(self,
58 | S,
59 | batch_size,
60 | shape,
61 | conditioning=None,
62 | callback=None,
63 | normals_sequence=None,
64 | img_callback=None,
65 | quantize_x0=False,
66 | eta=0.,
67 | mask=None,
68 | x0=None,
69 | temperature=1.,
70 | noise_dropout=0.,
71 | score_corrector=None,
72 | corrector_kwargs=None,
73 | verbose=True,
74 | x_T=None,
75 | log_every_t=100,
76 | unconditional_guidance_scale=1.,
77 | unconditional_conditioning=None,
78 |                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
79 | skip=False,
80 | quiet=False,
81 | mask_cond = None,
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | assert isinstance(conditioning, dict)
86 |
87 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
88 | # sampling
89 | C, H, W = shape
90 | size = (batch_size, C, H, W)
91 |
92 | samples, intermediates = self.ddim_sampling(conditioning, size,
93 | callback=callback,
94 | img_callback=img_callback,
95 | quantize_denoised=quantize_x0,
96 | mask=mask, x0=x0,
97 | ddim_use_original_steps=False,
98 | noise_dropout=noise_dropout,
99 | temperature=temperature,
100 | score_corrector=score_corrector,
101 | corrector_kwargs=corrector_kwargs,
102 | x_T=x_T,
103 | log_every_t=log_every_t,
104 | unconditional_guidance_scale=unconditional_guidance_scale,
105 | unconditional_conditioning=unconditional_conditioning,
106 | skip=skip,
107 | mask_cond = mask_cond,
108 | quiet=quiet,
109 | **kwargs
110 | )
111 | return samples, intermediates
112 |
113 | @torch.no_grad()
114 | def ddim_sampling(self, cond, shape,
115 | x_T=None, ddim_use_original_steps=False,
116 | callback=None, timesteps=None, quantize_denoised=False,
117 | mask=None, x0=None, img_callback=None, log_every_t=100,
118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, mask_cond=None,
119 | unconditional_guidance_scale=1., unconditional_conditioning=None, skip=False, quiet=False, **kwargs):
120 | device = self.model.betas.device
121 | b = shape[0]
122 | if x_T is None:
123 | img = torch.randn(shape, device=device)
124 | else:
125 | img = x_T
126 |
127 | if timesteps is None:
128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
129 | elif timesteps is not None and not ddim_use_original_steps:
130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
131 | timesteps = self.ddim_timesteps[:subset_end]
132 |
133 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
134 | if skip:
135 | return img, intermediates
136 |
137 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
138 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
139 |
140 | if not quiet:
141 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
142 | else:
143 | iterator = time_range
144 | old_eps = []
145 | self.attn_maps = defaultdict(list)
146 |
147 | for i, step in enumerate(iterator):
148 | index = total_steps - i - 1
149 | ts = torch.full((b,), step, device=device, dtype=torch.long)
150 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
151 |
152 | if mask is not None:
153 | assert x0 is not None
154 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
155 | img = img_orig * mask + (1. - mask) * img
156 |
157 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
158 | quantize_denoised=quantize_denoised, temperature=temperature,
159 | noise_dropout=noise_dropout, score_corrector=score_corrector,
160 | corrector_kwargs=corrector_kwargs,
161 | unconditional_guidance_scale=unconditional_guidance_scale,
162 | unconditional_conditioning=unconditional_conditioning,
163 | mask_cond=mask_cond,
164 | old_eps=old_eps, t_next=ts_next)
165 | img, pred_x0 = outs
166 | if callback: callback(i)
167 | if img_callback: img_callback(pred_x0, i)
168 |
169 | if index % log_every_t == 0 or index == total_steps - 1:
170 | intermediates['x_inter'].append(img)
171 | intermediates['pred_x0'].append(pred_x0)
172 |
173 | if kwargs.get('save_attn_maps', False):
174 | for name, module in self.model.model.diffusion_model.named_modules():
175 | module_name = type(module).__name__
176 | if module_name == 'CrossAttention' and 'attn2' in name:
177 | self.attn_maps[name].append(module.attn_maps)
178 | if mask_cond['is_use']:
179 | self.attn_maps[name+'_extra'].append(module.attn_extra)
180 |
181 | return img, intermediates
182 |
183 | @torch.no_grad()
184 | def get_attention(self,
185 | S,
186 | batch_size,
187 | shape,
188 | conditioning=None,
189 | callback=None,
190 | normals_sequence=None,
191 | img_callback=None,
192 | quantize_x0=False,
193 | eta=0.,
194 | mask=None,
195 | x0=None,
196 | temperature=1.,
197 | noise_dropout=0.,
198 | score_corrector=None,
199 | corrector_kwargs=None,
200 | verbose=True,
201 | x_T=None,
202 | log_every_t=100,
203 | unconditional_guidance_scale=1.,
204 | unconditional_conditioning=None,
205 |                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
206 | skip=False,
207 | quiet=False,
208 | mask_cond = None,
209 | **kwargs
210 | ):
211 | if conditioning is not None:
212 | assert isinstance(conditioning, dict)
213 |
214 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
215 | # sampling
216 | C, H, W = shape
217 | size = (batch_size, C, H, W)
218 |
219 | self.get_attention_(conditioning, size,
220 | callback=callback,
221 | img_callback=img_callback,
222 | quantize_denoised=quantize_x0,
223 | mask=mask, x0=x0,
224 | ddim_use_original_steps=False,
225 | noise_dropout=noise_dropout,
226 | temperature=temperature,
227 | score_corrector=score_corrector,
228 | corrector_kwargs=corrector_kwargs,
229 | x_T=x_T,
230 | log_every_t=log_every_t,
231 | unconditional_guidance_scale=unconditional_guidance_scale,
232 | unconditional_conditioning=unconditional_conditioning,
233 | skip=skip,
234 | mask_cond = mask_cond,
235 | quiet=quiet,
236 | **kwargs
237 | )
238 |
239 | @torch.no_grad()
240 | def get_attention_(self, cond, shape,
241 | x_T=None, ddim_use_original_steps=False,
242 | callback=None, timesteps=None, quantize_denoised=False,
243 | mask=None, x0=None, img_callback=None, log_every_t=100,
244 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, mask_cond=None,
245 | unconditional_guidance_scale=1., unconditional_conditioning=None, skip=False, quiet=False, **kwargs):
246 | device = self.model.betas.device
247 | b = shape[0]
248 | if x_T is None:
249 | img = torch.randn(shape, device=device)
250 | else:
251 | img = x_T
252 |
253 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
254 |
255 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
256 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
257 |
258 | iterator = time_range
259 | old_eps = []
260 | self.attn_maps = defaultdict(list)
261 |
262 | for i, step in enumerate(iterator):
263 | index = total_steps - i - 1
264 | ts = torch.full((b,), step, device=device, dtype=torch.long)
265 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
266 |
267 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
268 | quantize_denoised=quantize_denoised, temperature=temperature,
269 | noise_dropout=noise_dropout, score_corrector=score_corrector,
270 | corrector_kwargs=corrector_kwargs,
271 | unconditional_guidance_scale=unconditional_guidance_scale,
272 | unconditional_conditioning=unconditional_conditioning,
273 | mask_cond=mask_cond,
274 | old_eps=old_eps, t_next=ts_next)
275 |
276 | if kwargs.get('save_attn_maps', False):
277 | for name, module in self.model.model.diffusion_model.named_modules():
278 | module_name = type(module).__name__
279 | if module_name == 'CrossAttention' and 'attn2' in name:
280 | self.attn_maps[name].append(module.attn_maps)
281 | break
282 |
283 | @torch.no_grad()
284 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
285 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,mask_cond=None,
286 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
287 | b, *_, device = *x.shape, x.device
288 |
289 | #def get_model_output(x, t):
290 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
291 | e_t = self.model.apply_model(x, t, c)
292 | else:
293 | x_in = torch.cat([x] * 2)
294 | t_in = torch.cat([t] * 2)
295 | if isinstance(c, (list, dict)):
296 | c_in = [unconditional_conditioning, c]
297 | else:
298 | c_in = torch.cat([unconditional_conditioning, c])
299 |
300 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, mask_cond=mask_cond).chunk(2)
301 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
302 |
303 | if score_corrector is not None:
304 | assert self.model.parameterization == "eps"
305 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
306 |
307 | # return e_t
308 |
309 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
310 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
311 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
312 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
313 |
314 | #def get_x_prev_and_pred_x0(e_t, index):
315 | # select parameters corresponding to the currently considered timestep
316 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
317 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
318 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
319 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
320 |
321 | # current prediction for x_0
322 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
323 | if quantize_denoised:
324 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
325 | # direction pointing to x_t
326 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
327 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
328 | if noise_dropout > 0.:
329 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
330 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
331 | return x_prev, pred_x0
332 |
333 | #e_t = get_model_output(x, t)
334 | #if len(old_eps) == 0:
335 | # # Pseudo Improved Euler (2nd order)
336 | # x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
337 | # e_t_next = get_model_output(x_prev, t_next)
338 | # e_t_prime = (e_t + e_t_next) / 2
339 | #elif len(old_eps) == 1:
340 | # # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
341 | # e_t_prime = (3 * e_t - old_eps[-1]) / 2
342 | #elif len(old_eps) == 2:
343 |         #    # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
344 | # e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
345 | #elif len(old_eps) >= 3:
346 |         #    # 4th order Pseudo Linear Multistep (Adams-Bashforth)
347 | # e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
348 |
349 | #x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
350 |
351 | #return x_prev, pred_x0, e_t
352 |
--------------------------------------------------------------------------------
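For reference, the arithmetic in p_sample_ddim above is the standard DDIM update, restated compactly (sqrt_one_minus_at is sqrt(1 - a_t), and sigma_t is controlled by eta; eta=0.0, as used in utils.py, makes the sampler deterministic):

pred_x0 = (x_t - (1 - a_t).sqrt() * e_t) / a_t.sqrt()         # current estimate of x_0
dir_xt  = (1 - a_prev - sigma_t ** 2).sqrt() * e_t            # direction pointing back to x_t
x_prev  = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * noise  # next latent in the trajectory

A typical invocation, mirroring utils.sampling below; emb and uc are text embeddings, x_T is an optional pre-built initial latent, and all concrete values are illustrative:

sampler = DDIMSampler(model)
samples, _ = sampler.sample(S=50, batch_size=1, shape=[4, 64, 64],
                            conditioning={'k': [emb], 'v': [emb]},
                            unconditional_conditioning=uc,
                            unconditional_guidance_scale=7.5,
                            eta=0.0, x_T=x_T, quiet=True,
                            mask_cond={'is_use': False}, save_attn_maps=True)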
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np  # VQModel.get_input uses np.random.choice for per-batch resizing
3 | import pytorch_lightning as pl
4 | import torch.nn.functional as F
5 | from contextlib import contextmanager
6 | from packaging import version  # pytorch-lightning version check in _validation_step
7 | from torch.optim.lr_scheduler import LambdaLR  # used when a scheduler_config is provided
8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
10 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11 | from ldm.modules.ema import LitEma  # required by the use_ema path in VQModel
12 | from ldm.util import instantiate_from_config
13 |
14 | class VQModel(pl.LightningModule):
15 | def __init__(self,
16 | ddconfig,
17 | lossconfig,
18 | n_embed,
19 | embed_dim,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | image_key="image",
23 | colorize_nlabels=None,
24 | monitor=None,
25 | batch_resize_range=None,
26 | scheduler_config=None,
27 | lr_g_factor=1.0,
28 | remap=None,
29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
30 | use_ema=False
31 | ):
32 | super().__init__()
33 | self.embed_dim = embed_dim
34 | self.n_embed = n_embed
35 | self.image_key = image_key
36 | self.encoder = Encoder(**ddconfig)
37 | self.decoder = Decoder(**ddconfig)
38 | self.loss = instantiate_from_config(lossconfig)
39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40 | remap=remap,
41 | sane_index_shape=sane_index_shape)
42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44 | if colorize_nlabels is not None:
45 | assert type(colorize_nlabels)==int
46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47 | if monitor is not None:
48 | self.monitor = monitor
49 | self.batch_resize_range = batch_resize_range
50 | if self.batch_resize_range is not None:
51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52 |
53 | self.use_ema = use_ema
54 | if self.use_ema:
55 | self.model_ema = LitEma(self)
56 |             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")
57 |
58 | if ckpt_path is not None:
59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60 | self.scheduler_config = scheduler_config
61 | self.lr_g_factor = lr_g_factor
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | missing, unexpected = self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88 | if len(missing) > 0:
89 | print(f"Missing Keys: {missing}")
90 | print(f"Unexpected Keys: {unexpected}")
91 |
92 | def on_train_batch_end(self, *args, **kwargs):
93 | if self.use_ema:
94 | self.model_ema(self)
95 |
96 | def encode(self, x):
97 | h = self.encoder(x)
98 | h = self.quant_conv(h)
99 | quant, emb_loss, info = self.quantize(h)
100 | return quant, emb_loss, info
101 |
102 | def encode_to_prequant(self, x):
103 | h = self.encoder(x)
104 | h = self.quant_conv(h)
105 | return h
106 |
107 | def decode(self, quant):
108 | quant = self.post_quant_conv(quant)
109 | dec = self.decoder(quant)
110 | return dec
111 |
112 | def decode_code(self, code_b):
113 | quant_b = self.quantize.embed_code(code_b)
114 | dec = self.decode(quant_b)
115 | return dec
116 |
117 | def forward(self, input, return_pred_indices=False):
118 | quant, diff, (_,_,ind) = self.encode(input)
119 | dec = self.decode(quant)
120 | if return_pred_indices:
121 | return dec, diff, ind
122 | return dec, diff
123 |
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129 | if self.batch_resize_range is not None:
130 | lower_size = self.batch_resize_range[0]
131 | upper_size = self.batch_resize_range[1]
132 | if self.global_step <= 4:
133 | # do the first few batches with max size to avoid later oom
134 | new_resize = upper_size
135 | else:
136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137 | if new_resize != x.shape[2]:
138 | x = F.interpolate(x, size=new_resize, mode="bicubic")
139 | x = x.detach()
140 | return x
141 |
142 | def training_step(self, batch, batch_idx, optimizer_idx):
143 | # https://github.com/pytorch/pytorch/issues/37142
144 | # try not to fool the heuristics
145 | x = self.get_input(batch, self.image_key)
146 | xrec, qloss, ind = self(x, return_pred_indices=True)
147 |
148 | if optimizer_idx == 0:
149 | # autoencode
150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151 | last_layer=self.get_last_layer(), split="train",
152 | predicted_indices=ind)
153 |
154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155 | return aeloss
156 |
157 | if optimizer_idx == 1:
158 | # discriminator
159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160 | last_layer=self.get_last_layer(), split="train")
161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162 | return discloss
163 |
164 | def validation_step(self, batch, batch_idx):
165 | log_dict = self._validation_step(batch, batch_idx)
166 | with self.ema_scope():
167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168 | return log_dict
169 |
170 | def _validation_step(self, batch, batch_idx, suffix=""):
171 | x = self.get_input(batch, self.image_key)
172 | xrec, qloss, ind = self(x, return_pred_indices=True)
173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174 | self.global_step,
175 | last_layer=self.get_last_layer(),
176 | split="val"+suffix,
177 | predicted_indices=ind
178 | )
179 |
180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181 | self.global_step,
182 | last_layer=self.get_last_layer(),
183 | split="val"+suffix,
184 | predicted_indices=ind
185 | )
186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187 | self.log(f"val{suffix}/rec_loss", rec_loss,
188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189 | self.log(f"val{suffix}/aeloss", aeloss,
190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191 | if version.parse(pl.__version__) >= version.parse('1.4.0'):
192 | del log_dict_ae[f"val{suffix}/rec_loss"]
193 | self.log_dict(log_dict_ae)
194 | self.log_dict(log_dict_disc)
195 | return self.log_dict
196 |
197 | def configure_optimizers(self):
198 | lr_d = self.learning_rate
199 | lr_g = self.lr_g_factor*self.learning_rate
200 | print("lr_d", lr_d)
201 | print("lr_g", lr_g)
202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203 | list(self.decoder.parameters())+
204 | list(self.quantize.parameters())+
205 | list(self.quant_conv.parameters())+
206 | list(self.post_quant_conv.parameters()),
207 | lr=lr_g, betas=(0.5, 0.9))
208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209 | lr=lr_d, betas=(0.5, 0.9))
210 |
211 | if self.scheduler_config is not None:
212 | scheduler = instantiate_from_config(self.scheduler_config)
213 |
214 | print("Setting up LambdaLR scheduler...")
215 | scheduler = [
216 | {
217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218 | 'interval': 'step',
219 | 'frequency': 1
220 | },
221 | {
222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223 | 'interval': 'step',
224 | 'frequency': 1
225 | },
226 | ]
227 | return [opt_ae, opt_disc], scheduler
228 | return [opt_ae, opt_disc], []
229 |
230 | def get_last_layer(self):
231 | return self.decoder.conv_out.weight
232 |
233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234 | log = dict()
235 | x = self.get_input(batch, self.image_key)
236 | x = x.to(self.device)
237 | if only_inputs:
238 | log["inputs"] = x
239 | return log
240 | xrec, _ = self(x)
241 | if x.shape[1] > 3:
242 | # colorize with random projection
243 | assert xrec.shape[1] > 3
244 | x = self.to_rgb(x)
245 | xrec = self.to_rgb(xrec)
246 | log["inputs"] = x
247 | log["reconstructions"] = xrec
248 | if plot_ema:
249 | with self.ema_scope():
250 | xrec_ema, _ = self(x)
251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252 | log["reconstructions_ema"] = xrec_ema
253 | return log
254 |
255 | def to_rgb(self, x):
256 | assert self.image_key == "segmentation"
257 | if not hasattr(self, "colorize"):
258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259 | x = F.conv2d(x, weight=self.colorize)
260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261 | return x
262 |
263 |
264 | class VQModelInterface(VQModel):
265 | def __init__(self, embed_dim, *args, **kwargs):
266 | super().__init__(embed_dim=embed_dim, *args, **kwargs)
267 | self.embed_dim = embed_dim
268 |
269 | def encode(self, x):
270 | h = self.encoder(x)
271 | h = self.quant_conv(h)
272 | return h
273 |
274 | def decode(self, h, force_not_quantize=False):
275 | # also go through quantization layer
276 | if not force_not_quantize:
277 | quant, emb_loss, info = self.quantize(h)
278 | else:
279 | quant = h
280 | quant = self.post_quant_conv(quant)
281 | dec = self.decoder(quant)
282 | return dec
283 |
284 |
285 | class AutoencoderKL(pl.LightningModule):
286 | def __init__(self,
287 | ddconfig,
288 | lossconfig,
289 | embed_dim,
290 | ckpt_path=None,
291 | ignore_keys=[],
292 | image_key="image",
293 | colorize_nlabels=None,
294 | monitor=None,
295 | ):
296 | super().__init__()
297 | self.image_key = image_key
298 | self.encoder = Encoder(**ddconfig)
299 | self.decoder = Decoder(**ddconfig)
300 | self.loss = instantiate_from_config(lossconfig)
301 | assert ddconfig["double_z"]
302 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304 | self.embed_dim = embed_dim
305 | if colorize_nlabels is not None:
306 | assert type(colorize_nlabels)==int
307 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308 | if monitor is not None:
309 | self.monitor = monitor
310 | if ckpt_path is not None:
311 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312 |
313 | def init_from_ckpt(self, path, ignore_keys=list()):
314 | sd = torch.load(path, map_location="cpu")["state_dict"]
315 | keys = list(sd.keys())
316 | for k in keys:
317 | for ik in ignore_keys:
318 | if k.startswith(ik):
319 | print("Deleting key {} from state_dict.".format(k))
320 | del sd[k]
321 | self.load_state_dict(sd, strict=False)
322 | print(f"Restored from {path}")
323 |
324 | def encode(self, x):
325 | h = self.encoder(x)
326 | moments = self.quant_conv(h)
327 | posterior = DiagonalGaussianDistribution(moments)
328 | return posterior
329 |
330 | def decode(self, z):
331 | z = self.post_quant_conv(z)
332 | dec = self.decoder(z)
333 | return dec
334 |
335 | def forward(self, input, sample_posterior=True):
336 | posterior = self.encode(input)
337 | if sample_posterior:
338 | z = posterior.sample()
339 | else:
340 | z = posterior.mode()
341 | dec = self.decode(z)
342 | return dec, posterior
343 |
344 | def get_input(self, batch, k):
345 | x = batch[k]
346 | if len(x.shape) == 3:
347 | x = x[..., None]
348 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349 | return x
350 |
351 | def training_step(self, batch, batch_idx, optimizer_idx):
352 | inputs = self.get_input(batch, self.image_key)
353 | reconstructions, posterior = self(inputs)
354 |
355 | if optimizer_idx == 0:
356 | # train encoder+decoder+logvar
357 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358 | last_layer=self.get_last_layer(), split="train")
359 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361 | return aeloss
362 |
363 | if optimizer_idx == 1:
364 | # train the discriminator
365 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366 | last_layer=self.get_last_layer(), split="train")
367 |
368 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370 | return discloss
371 |
372 | def validation_step(self, batch, batch_idx):
373 | inputs = self.get_input(batch, self.image_key)
374 | reconstructions, posterior = self(inputs)
375 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376 | last_layer=self.get_last_layer(), split="val")
377 |
378 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379 | last_layer=self.get_last_layer(), split="val")
380 |
381 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382 | self.log_dict(log_dict_ae)
383 | self.log_dict(log_dict_disc)
384 | return self.log_dict
385 |
386 | def configure_optimizers(self):
387 | lr = self.learning_rate
388 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389 | list(self.decoder.parameters())+
390 | list(self.quant_conv.parameters())+
391 | list(self.post_quant_conv.parameters()),
392 | lr=lr, betas=(0.5, 0.9))
393 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394 | lr=lr, betas=(0.5, 0.9))
395 | return [opt_ae, opt_disc], []
396 |
397 | def get_last_layer(self):
398 | return self.decoder.conv_out.weight
399 |
400 | @torch.no_grad()
401 | def log_images(self, batch, only_inputs=False, **kwargs):
402 | log = dict()
403 | x = self.get_input(batch, self.image_key)
404 | x = x.to(self.device)
405 | if not only_inputs:
406 | xrec, posterior = self(x)
407 | if x.shape[1] > 3:
408 | # colorize with random projection
409 | assert xrec.shape[1] > 3
410 | x = self.to_rgb(x)
411 | xrec = self.to_rgb(xrec)
412 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413 | log["reconstructions"] = xrec
414 | log["inputs"] = x
415 | return log
416 |
417 | def to_rgb(self, x):
418 | assert self.image_key == "segmentation"
419 | if not hasattr(self, "colorize"):
420 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421 | x = F.conv2d(x, weight=self.colorize)
422 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423 | return x
424 |
425 |
426 | class IdentityFirstStage(torch.nn.Module):
427 | def __init__(self, *args, vq_interface=False, **kwargs):
428 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429 | super().__init__()
430 |
431 | def encode(self, x, *args, **kwargs):
432 | return x
433 |
434 | def decode(self, x, *args, **kwargs):
435 | return x
436 |
437 | def quantize(self, x, *args, **kwargs):
438 | if self.vq_interface:
439 | return x, None, [None, None, None]
440 | return x
441 |
442 | def forward(self, x, *args, **kwargs):
443 | return x
444 |
--------------------------------------------------------------------------------
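A minimal first-stage round trip with AutoencoderKL, assuming a configured instance ae and an image batch x of shape (B, 3, H, W) scaled to roughly [-1, 1], the convention the clamp in utils.py assumes on the decode side. For the v1 inference config the latent has 4 channels at H/8 x W/8 (C=4, f=8 in utils.sampling):

posterior = ae.encode(x)   # DiagonalGaussianDistribution over the latent
z = posterior.sample()     # stochastic code; posterior.mode() is the deterministic choice
x_rec = ae.decode(z)       # (B, 3, H, W) reconstruction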
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 | import matplotlib.pyplot as plt
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from PIL import Image
9 | from tqdm import tqdm, trange
10 | from einops import rearrange
11 | from torchvision.utils import make_grid
12 | from torch import autocast
13 | from ldm.util import instantiate_from_config
14 | from itertools import combinations
15 |
16 |
17 | def load_model_from_config(config, ckpt, verbose=False):
18 | print(f"Loading model from {ckpt}")
19 | pl_sd = torch.load(ckpt, map_location="cpu")
20 | if "global_step" in pl_sd:
21 | print(f"Global Step: {pl_sd['global_step']}")
22 | sd = pl_sd["state_dict"]
23 | model = instantiate_from_config(config.model)
24 | m, u = model.load_state_dict(sd, strict=False)
25 | if len(m) > 0 and verbose:
26 | print("missing keys:")
27 | print(m)
28 | if len(u) > 0 and verbose:
29 | print("unexpected keys:")
30 | print(u)
31 |
32 | model.cuda()
33 | model.eval()
34 | return model
35 |
36 | def sampling(model, sampler, prompt, n_samples, scale=7.5, steps=50, conjunction=False, mask_cond=None, img=None):
37 | H = W = 512
38 | C = 4
39 | f = 8
40 | precision_scope = autocast
41 | with torch.no_grad():
42 | with precision_scope("cuda"):
43 | with model.ema_scope():
44 | all_samples = list()
45 | for n in range(n_samples):
46 | for bid, p in enumerate(prompt):
47 |
48 | uc = model.get_learned_conditioning([""])
49 | _c = model.get_learned_conditioning(p)
50 | c = {'k': [_c], 'v': [_c]}
51 | shape = [C, H // f, W // f]
52 |
53 | samples_ddim, _ = sampler.sample(S=steps,
54 | conditioning=c,
55 | batch_size=1,
56 | shape=shape,
57 | verbose=False,
58 | unconditional_guidance_scale=scale,
59 | unconditional_conditioning=uc,
60 | eta=0.0,
61 | x_T=img,
62 | quiet=True,
63 | mask_cond = mask_cond,
64 | save_attn_maps=True)
65 | x_samples_ddim = model.decode_first_stage(samples_ddim)
66 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
67 | x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
68 |
69 | x_checked_image = x_samples_ddim
70 | x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
71 | all_samples.append(x_checked_image_torch)
72 | return all_samples
73 |
74 | def diff(t1, t2):
75 | combined = torch.cat((t1, t2))
76 | uniques, counts = combined.unique(return_counts=True)
77 | difference = uniques[counts == 1]
78 | return difference
79 |
80 | def intersection(t1, t2):
81 | i = np.intersect1d(t1, t2)
82 | return torch.from_numpy(i)
83 |
84 | def block(value, scale_factor=4):
85 | vs = []
86 | for v in value:
87 | e = torch.zeros(256)
88 | e[v] = 1
89 | e = rearrange(e, '(w h)-> w h', w=16)
90 | e_resized = F.interpolate(e.reshape(1,1,16,16), scale_factor=scale_factor)[0][0]
91 | e_resized = rearrange(e_resized, 'w h -> (w h)')
92 | vs.append(torch.where(e_resized==1)[0])
93 | return vs
94 |
95 | def image_to_blocks(img):
96 | # input: [1,4,64,64] image
97 |     # output: list of blocks, length is 256
98 | # block : [1, 4, 4, 4]
99 | blocks = []
100 | for i in range(16):
101 | for j in range(16):
102 | block = img[:, :, i * 4: (i + 1) * 4, j * 4: (j + 1) * 4]
103 | blocks.append(block)
104 | return blocks
105 |
106 | def generate(model, sampler, img_, prompt, ind=None):
107 | mask_cond = {
108 | 'is_use': False,
109 | }
110 | ddim_steps = 50
111 | n_samples = 1
112 | scale = 7.5
113 | all_samples = sampling(model, sampler, prompt,
114 | n_samples, scale,
115 | ddim_steps, mask_cond=mask_cond, conjunction=False, img=img_)
116 | grid = torch.stack(all_samples, 0)
117 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
118 | grid = make_grid(grid, nrow=int(np.sqrt(n_samples)))
119 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
120 | img = Image.fromarray(grid.astype(np.uint8))
121 | attn_maps = [item[0][0] for item in sampler.attn_maps['input_blocks.8.1.transformer_blocks.0.attn2']]
122 | maps = [torch.mean(item, axis=0) for item in attn_maps]
123 | maps = [rearrange(item, 'w h d -> d w h')[None,:] for item in maps]
124 | maps = rearrange(torch.cat(maps,dim=0), 't word w h -> word t w h')
125 | if ind is not None:
126 | plt.subplot(1, 5, 1)
127 | plt.imshow(img)
128 | plt.axis("off")
129 | plt.subplot(1, 5, 2)
130 | plt.imshow(maps[ind[0]][0],cmap = 'gray')
131 | plt.axis("off")
132 | plt.subplot(1, 5, 3)
133 | plt.imshow(maps[ind[1]][0],cmap = 'gray')
134 | plt.axis("off")
135 | plt.subplot(1, 5, 4)
136 | plt.imshow(maps[ind[0]][-1],cmap = 'gray')
137 | plt.axis("off")
138 | plt.subplot(1, 5, 5)
139 | plt.imshow(maps[ind[1]][-1],cmap = 'gray')
140 | plt.axis("off")
141 | plt.show()
142 | else:
143 | plt.imshow(img)
144 | plt.axis("off")
145 | plt.show()
146 |
147 | def preprocess_prompts(prompts):
148 | if isinstance(prompts, (list, tuple)):
149 | return [p.lower().strip().strip(".").strip() for p in prompts]
150 | elif isinstance(prompts, str):
151 | return prompts.lower().strip().strip(".").strip()
152 | else:
153 | raise NotImplementedError
154 |
155 | def intersection(t1, t2):
156 | i = np.intersect1d(t1, t2)
157 | return torch.from_numpy(i)
158 |
159 | def attention_to_score(attns):
160 | # input: [16, 16] attention maps
161 |     # output: list of scores (one flattened 256-entry list per map)
162 | scores = []
163 | for attn in attns:
164 | scores.append(rearrange(attn, 'w h -> (w h)').tolist())
165 | return scores
166 |
167 | def score_normalize(scores):
168 | std = torch.std(scores, unbiased=False)
169 | mean = torch.mean(scores)
170 | scores = (scores - mean)/std
171 | scores = (scores - torch.min(scores)) / (torch.max(scores) - torch.min(scores))
172 | return scores
173 |
174 |
175 | class pixel_block_base:
176 | def __init__(self, model, sampler, labels):
177 | self.model = model
178 | self.sampler = sampler
179 | self.H = 512
180 | self.W = 512
181 | self.C = 4
182 | self.f = 8
183 | self.normalize = False
184 | self.shape = [1, self.C, self.H // self.f, self.W // self.f]
185 | self.cond = {'is_use': False}
186 |
187 | self.base = {}
188 | self.base['blocks'] = []
189 | for w in labels:
190 | self.base[w] = {}
191 |
192 | self.labels = labels
193 | self.combinations = list(combinations(labels, 2))
194 | for pair in self.combinations:
195 | self.base[pair[0]][pair[1]] = torch.tensor([])
196 | self.base[pair[1]][pair[0]] = torch.tensor([])
197 |
198 | self.prompt = []
199 | for pair in self.combinations:
200 | self.prompt.append('a ' + pair[0] + ' and a ' + pair[1] + '.')
201 |
202 | def _get_attention(self, prompt, img, scale=7.5, steps=50):
203 | precision_scope = autocast
204 | with torch.no_grad():
205 | with precision_scope("cuda"):
206 | with self.model.ema_scope():
207 | for bid, p in enumerate(prompt):
208 | p = preprocess_prompts(p)
209 | uc = self.model.get_learned_conditioning([""])
210 | kv = self.model.get_learned_conditioning(p[0])
211 | c = {'k':[kv], 'v': [kv]}
212 | shape = [self.C, self.H // self.f, self.W // self.f]
213 | self.sampler.get_attention(S=steps,
214 | conditioning=c,
215 | batch_size=1,
216 | shape=shape,
217 | verbose=False,
218 | unconditional_guidance_scale=scale,
219 | unconditional_conditioning=uc,
220 | eta=0.0,
221 | x_T=img,
222 | quiet=True,
223 | mask_cond=self.cond,
224 | save_attn_maps=True)
225 | all_attn_maps = [item[0][0] for item in self.sampler.attn_maps['input_blocks.8.1.transformer_blocks.0.attn2']]
226 | avg_maps = [torch.mean(item, axis=0) for item in all_attn_maps]
227 | avg_maps = [rearrange(item, 'w h d -> d w h')[None,:] for item in avg_maps]
228 | avg_maps = rearrange(torch.cat(avg_maps,dim=0), 't word w h -> word t w h')
229 | return avg_maps
230 |
231 |
232 | def make_base(self, n_img):
233 | for i in range(n_img):
234 | img_ = torch.randn(self.shape).cuda()
235 | _blocks = image_to_blocks(img_.clone().cpu())
236 | self.base['blocks'].append(torch.stack(_blocks, dim=0))
237 | for p in range(len(self.prompt)):
238 | caption = [[self.prompt[p]]]
239 | w = self.combinations[p]
240 |
241 | avg_maps = self._get_attention(caption, img_)
242 |
243 | m = [avg_maps[2][0]]
244 | _scores = attention_to_score(m)[0]
245 | self.base[w[0]][w[1]] = torch.cat([self.base[w[0]][w[1]], score_normalize(torch.tensor(_scores))], dim=0)
246 | m = [avg_maps[5][0]]
247 | _scores = attention_to_score(m)[0]
248 | self.base[w[1]][w[0]] = torch.cat([self.base[w[1]][w[0]], score_normalize(torch.tensor(_scores))], dim=0)
249 | self.base['blocks'] = torch.cat(self.base['blocks'], dim=0)
250 |
251 |         # For each label, aggregate an overall score: sum label a's scores from every pairing with b, c, d, ..., then average over the pairings.
252 | for k in self.labels:
253 | avg = 0
254 | counter = 0
255 | for w in self.base[k].keys():
256 | counter = counter + 1
257 | avg = avg + self.base[k][w]
258 | self.base[k]['average'] = avg / counter
259 |
260 | def make_base_by_list(self, n_img, comb_list):
261 |         # comb_list: list of label pairs, e.g. [('a', 'b'), ('a', 'c')]
262 |
263 | for i in trange(n_img):
264 | img_ = torch.randn(self.shape).cuda()
265 | _blocks = image_to_blocks(img_.clone().cpu())
266 | self.base['blocks'].append(torch.stack(_blocks, dim=0))
267 | for p in trange(len(comb_list)):
268 | caption = [['A ' + comb_list[p][0] + ' and a ' + comb_list[p][1]+ '.'] ]
269 | w = comb_list[p]
270 | avg_maps = self._get_attention(caption, img_)
271 |
272 | m = [avg_maps[2][0]]
273 | _scores = attention_to_score(m)[0]
274 | self.base[w[0]][w[1]] = torch.cat([self.base[w[0]][w[1]], score_normalize(torch.tensor(_scores))], dim=0)
275 |
276 | m = [avg_maps[5][0]]
277 | _scores = attention_to_score(m)[0]
278 | self.base[w[1]][w[0]] = torch.cat([self.base[w[1]][w[0]], score_normalize(torch.tensor(_scores))], dim=0)
279 | self.base['blocks'] = torch.cat(self.base['blocks'], dim=0)
280 |
281 |
282 |
283 | def generate_region_mask(self, regions):
284 | region_mask = []
285 | for r in regions:
286 | z = torch.zeros(16,16)
287 | z[r[1]:r[3], r[0]:r[2]] = 1
288 | region_mask.append(z)
289 | region_mask = rearrange(torch.stack(region_mask, dim = 0), 'n w h -> n (w h)')
290 | if len(regions) == 2:
291 | mask_1 = torch.where(region_mask[0] == 1)[0]
292 | mask_2 = torch.where(region_mask[1] == 1)[0]
293 | if mask_1.shape[0] > mask_2.shape[0]:
294 | region_mask[0][intersection(mask_1, mask_2)] = 0
295 | else:
296 | region_mask[1][intersection(mask_1, mask_2)] = 0
297 | if len(regions) == 3:
298 | mask_1 = torch.where(region_mask[0] == 1)[0]
299 | mask_2 = torch.where(region_mask[1] == 1)[0]
300 | mask_3 = torch.where(region_mask[2] == 1)[0]
301 | if mask_1.shape[0] > mask_2.shape[0]:
302 | region_mask[0][intersection(mask_1, mask_2)] = 0
303 | else:
304 | region_mask[1][intersection(mask_1, mask_2)] = 0
305 |
306 | if mask_1.shape[0] > mask_3.shape[0]:
307 | region_mask[0][intersection(mask_1, mask_3)] = 0
308 | else:
309 | region_mask[2][intersection(mask_1, mask_3)] = 0
310 |
311 | if mask_2.shape[0] > mask_3.shape[0]:
312 | region_mask[1][intersection(mask_2, mask_3)] = 0
313 | else:
314 | region_mask[2][intersection(mask_2, mask_3)] = 0
315 |
316 | bg_mask = 1 - region_mask.sum(axis=0)
317 | return region_mask, bg_mask
318 |
319 | def make_img(self, region_mask, bg_mask, obj_blocks, bg_blocks, recalibration):
320 | # input : masks indicating locations and blocks selected for corresponding contents
321 | img = torch.zeros([1,4,64,64])
322 | for i in range(region_mask.shape[0]):
323 | num = torch.where(region_mask[i] != 0)[0].shape[0]
324 | r = rearrange(region_mask[i], '(w h) -> w h', w=16)
325 | total_num = obj_blocks[i].shape[0]
326 | if num > total_num:
327 | sampled_index = []
328 | while num > total_num:
329 | sampled_index = sampled_index + random.sample(range(0, total_num), total_num)
330 | num = num - total_num
331 | sampled_index = sampled_index + random.sample(range(0, total_num), num)
332 | else:
333 | sampled_index = random.sample(range(0, total_num), num)
334 | positions = (r == 1).nonzero(as_tuple=False)
335 | selected_blocks_obj = obj_blocks[i][sampled_index]
336 |
337 | # recalibration
338 | if recalibration:
339 | print(selected_blocks_obj.mean())
340 | print((selected_blocks_obj - selected_blocks_obj.mean()).pow(2).mean())
341 |
342 | for j in range(len(positions)):
343 | p, q = positions[j]
344 | img[:, :, 4 * p : 4 * p + 4, 4 * q : 4 * q + 4] = selected_blocks_obj[j]
345 |
346 | bg_num = torch.where(bg_mask != 0)[0].shape[0]
347 | r = rearrange(bg_mask, '(w h) -> w h', w=16)
348 | bg_total_num = bg_blocks.shape[0]
349 | if bg_num > bg_total_num:
350 | sampled_index = []
351 | while bg_num > bg_total_num:
352 | sampled_index = sampled_index + random.sample(range(0, bg_total_num), bg_total_num)
353 | bg_num = bg_num - bg_total_num
354 | sampled_index = sampled_index + random.sample(range(0, bg_total_num), bg_num)
355 | else:
356 | sampled_index = random.sample(range(0, bg_total_num), bg_num)
357 | bg_positions = (r == 1).nonzero(as_tuple=False)
358 | selected_blocks_bg = bg_blocks[sampled_index]
359 | for i in range(len(bg_positions)):
360 | p, q = bg_positions[i]
361 | img[:, :, 4 * p : 4 * p + 4, 4 * q : 4 * q + 4] = selected_blocks_bg[i]
362 | return img
363 |
364 | def product_image(self, words, regions, t_pos_1 = 0.5, t_bg_1 = 0.3, t_pos_2 = 0.3, t_neg_2 = 0.3, t_bg_2 = 0.1, recalibration = False):
365 | if len(words) == 1:
366 | # Fetch pre-collected pixel blocks and their scores
367 | word = words[0]
368 | scores = self.base[word]['average']
369 | blocks = self.base['blocks']
370 | # select blocks for obj
371 | # threshold [TO DO: sort?]
372 | blocks_index = torch.where(scores > t_pos_1)[0].numpy()
373 | # select blocks for bg
374 | blocks_index_bg = torch.where(scores < t_bg_1)[0].numpy()
375 | obj_blocks = [blocks[blocks_index]]
376 | bg_blocks = blocks[blocks_index_bg]
377 | region_mask, bg_mask = self.generate_region_mask(regions)
378 | img = self.make_img(region_mask, bg_mask, obj_blocks, bg_blocks, recalibration)
379 |
380 | elif len(words) == 2:
381 | score_1 = self.base[words[0]][words[1]]
382 | score_2 = self.base[words[1]][words[0]]
383 | blocks = self.base['blocks']
384 | # for class 1 :
385 | blocks_index_1 = intersection(torch.where(score_1 > t_pos_2)[0].numpy(), torch.where(score_2 < t_neg_2)[0].numpy())
386 | # for class 2 :
387 | blocks_index_2 = intersection(torch.where(score_2 > t_pos_2)[0].numpy(), torch.where(score_1 < t_neg_2)[0].numpy())
388 | # for background :
389 | blocks_index_bg = intersection(torch.where(score_1 < t_bg_2)[0].numpy(), torch.where(score_2 < t_bg_2)[0].numpy())
390 | obj_blocks = [blocks[blocks_index_1], blocks[blocks_index_2]]
391 | bg_blocks = blocks[blocks_index_bg]
392 | region_mask, bg_mask = self.generate_region_mask(regions)
393 | img = self.make_img(region_mask, bg_mask, obj_blocks, bg_blocks, recalibration)
394 |
395 | elif len(words) == 3:
396 | score_1_2 = self.base[words[0]][words[1]]
397 | score_1_3 = self.base[words[0]][words[2]]
398 | # for class 1 :
399 | blocks_index_1 = intersection(torch.where(score_1_2 > t_pos_2)[0].numpy(), torch.where(score_1_3 > t_pos_2)[0].numpy())
400 |
401 | score_2_1 = self.base[words[1]][words[0]]
402 | score_2_3 = self.base[words[1]][words[2]]
403 | # for class 2 :
404 | blocks_index_2 = intersection(torch.where(score_2_1 > t_pos_2)[0].numpy(), torch.where(score_2_3 > t_pos_2)[0].numpy())
405 |
406 | score_3_1 = self.base[words[2]][words[0]]
407 | score_3_2 = self.base[words[2]][words[1]]
408 |             # for class 3 :
409 | blocks_index_3 = intersection(torch.where(score_3_1 > t_pos_2)[0].numpy(), torch.where(score_3_2 > t_pos_2)[0].numpy())
410 |
411 | blocks = self.base['blocks']
412 | score_1 = self.base[words[0]]['average']
413 | score_2 = self.base[words[1]]['average']
414 | score_3 = self.base[words[2]]['average']
415 |
416 | # for background :
417 | blocks_index_bg = intersection(torch.where(score_1 < t_bg_2)[0].numpy(), torch.where(score_2 < t_bg_2)[0].numpy())
418 | blocks_index_bg = intersection(blocks_index_bg, torch.where(score_3 < t_bg_2)[0].numpy())
419 |
420 | obj_blocks = [blocks[blocks_index_1], blocks[blocks_index_2], blocks[blocks_index_3]]
421 | bg_blocks = blocks[blocks_index_bg]
422 | region_mask, bg_mask = self.generate_region_mask(regions)
423 | img = self.make_img(region_mask, bg_mask, obj_blocks, bg_blocks, recalibration)
424 | return img, region_mask
--------------------------------------------------------------------------------
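Putting the utilities above together, a hedged end-to-end sketch; labels, regions, and n_img are illustrative, while model and sampler come from load_model_from_config and DDIMSampler. Region boxes are (x0, y0, x1, y1) on the 16x16 block grid used by generate_region_mask, and ind=[2, 5] picks the two noun tokens of the 'a X and a Y.' template, the same positions make_base reads from avg_maps:

labels = ['cat', 'dog']                         # hypothetical label set
base = pixel_block_base(model, sampler, labels)
base.make_base(n_img=4)                         # collect pixel blocks and attention scores

regions = [(0, 4, 8, 12), (8, 4, 16, 12)]       # left box for 'cat', right box for 'dog'
img, region_mask = base.product_image(['cat', 'dog'], regions)

generate(model, sampler, img.cuda(), [['a cat and a dog.']], ind=[2, 5])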
/ldm/modules/x_transformer.py:
--------------------------------------------------------------------------------
1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from functools import partial
6 | from inspect import isfunction
7 | from collections import namedtuple
8 | from einops import rearrange, repeat, reduce
9 |
10 | # constants
11 |
12 | DEFAULT_DIM_HEAD = 64
13 |
14 | Intermediates = namedtuple('Intermediates', [
15 | 'pre_softmax_attn',
16 | 'post_softmax_attn'
17 | ])
18 |
19 | LayerIntermediates = namedtuple('Intermediates', [
20 | 'hiddens',
21 | 'attn_intermediates'
22 | ])
23 |
24 |
25 | class AbsolutePositionalEmbedding(nn.Module):
26 | def __init__(self, dim, max_seq_len):
27 | super().__init__()
28 | self.emb = nn.Embedding(max_seq_len, dim)
29 | self.init_()
30 |
31 | def init_(self):
32 | nn.init.normal_(self.emb.weight, std=0.02)
33 |
34 | def forward(self, x):
35 | n = torch.arange(x.shape[1], device=x.device)
36 | return self.emb(n)[None, :, :]
37 |
38 |
39 | class FixedPositionalEmbedding(nn.Module):
40 | def __init__(self, dim):
41 | super().__init__()
42 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43 | self.register_buffer('inv_freq', inv_freq)
44 |
45 | def forward(self, x, seq_dim=1, offset=0):
46 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
47 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
48 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
49 | return emb[None, :, :]
50 |
51 |
52 | # helpers
53 |
54 | def exists(val):
55 | return val is not None
56 |
57 |
58 | def default(val, d):
59 | if exists(val):
60 | return val
61 | return d() if isfunction(d) else d
62 |
63 |
64 | def always(val):
65 | def inner(*args, **kwargs):
66 | return val
67 | return inner
68 |
69 |
70 | def not_equals(val):
71 | def inner(x):
72 | return x != val
73 | return inner
74 |
75 |
76 | def equals(val):
77 | def inner(x):
78 | return x == val
79 | return inner
80 |
81 |
82 | def max_neg_value(tensor):
83 | return -torch.finfo(tensor.dtype).max
84 |
85 |
86 | # keyword argument helpers
87 |
88 | def pick_and_pop(keys, d):
89 | values = list(map(lambda key: d.pop(key), keys))
90 | return dict(zip(keys, values))
91 |
92 |
93 | def group_dict_by_key(cond, d):
94 | return_val = [dict(), dict()]
95 | for key in d.keys():
96 | match = bool(cond(key))
97 | ind = int(not match)
98 | return_val[ind][key] = d[key]
99 | return (*return_val,)
100 |
101 |
102 | def string_begins_with(prefix, str):
103 | return str.startswith(prefix)
104 |
105 |
106 | def group_by_key_prefix(prefix, d):
107 | return group_dict_by_key(partial(string_begins_with, prefix), d)
108 |
109 |
110 | def groupby_prefix_and_trim(prefix, d):
111 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
112 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
113 | return kwargs_without_prefix, kwargs
114 |
115 |
116 | # classes
117 | class Scale(nn.Module):
118 | def __init__(self, value, fn):
119 | super().__init__()
120 | self.value = value
121 | self.fn = fn
122 |
123 | def forward(self, x, **kwargs):
124 | x, *rest = self.fn(x, **kwargs)
125 | return (x * self.value, *rest)
126 |
127 |
128 | class Rezero(nn.Module):
129 | def __init__(self, fn):
130 | super().__init__()
131 | self.fn = fn
132 | self.g = nn.Parameter(torch.zeros(1))
133 |
134 | def forward(self, x, **kwargs):
135 | x, *rest = self.fn(x, **kwargs)
136 | return (x * self.g, *rest)
137 |
138 |
139 | class ScaleNorm(nn.Module):
140 | def __init__(self, dim, eps=1e-5):
141 | super().__init__()
142 | self.scale = dim ** -0.5
143 | self.eps = eps
144 | self.g = nn.Parameter(torch.ones(1))
145 |
146 | def forward(self, x):
147 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
148 | return x / norm.clamp(min=self.eps) * self.g
149 |
150 |
151 | class RMSNorm(nn.Module):
152 | def __init__(self, dim, eps=1e-8):
153 | super().__init__()
154 | self.scale = dim ** -0.5
155 | self.eps = eps
156 | self.g = nn.Parameter(torch.ones(dim))
157 |
158 | def forward(self, x):
159 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
160 | return x / norm.clamp(min=self.eps) * self.g
161 |
162 |
163 | class Residual(nn.Module):
164 | def forward(self, x, residual):
165 | return x + residual
166 |
167 |
168 | class GRUGating(nn.Module):
169 | def __init__(self, dim):
170 | super().__init__()
171 | self.gru = nn.GRUCell(dim, dim)
172 |
173 | def forward(self, x, residual):
174 | gated_output = self.gru(
175 | rearrange(x, 'b n d -> (b n) d'),
176 | rearrange(residual, 'b n d -> (b n) d')
177 | )
178 |
179 | return gated_output.reshape_as(x)
180 |
181 |
182 | # feedforward
183 |
184 | class GEGLU(nn.Module):
185 | def __init__(self, dim_in, dim_out):
186 | super().__init__()
187 | self.proj = nn.Linear(dim_in, dim_out * 2)
188 |
189 | def forward(self, x):
190 | x, gate = self.proj(x).chunk(2, dim=-1)
191 | return x * F.gelu(gate)
192 |
193 |
194 | class FeedForward(nn.Module):
195 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
196 | super().__init__()
197 | inner_dim = int(dim * mult)
198 | dim_out = default(dim_out, dim)
199 | project_in = nn.Sequential(
200 | nn.Linear(dim, inner_dim),
201 | nn.GELU()
202 | ) if not glu else GEGLU(dim, inner_dim)
203 |
204 | self.net = nn.Sequential(
205 | project_in,
206 | nn.Dropout(dropout),
207 | nn.Linear(inner_dim, dim_out)
208 | )
209 |
210 | def forward(self, x):
211 | return self.net(x)
212 |
213 |
214 | # attention.
215 | class Attention(nn.Module):
216 | def __init__(
217 | self,
218 | dim,
219 | dim_head=DEFAULT_DIM_HEAD,
220 | heads=8,
221 | causal=False,
222 | mask=None,
223 | talking_heads=False,
224 | sparse_topk=None,
225 | use_entmax15=False,
226 | num_mem_kv=0,
227 | dropout=0.,
228 | on_attn=False
229 | ):
230 | super().__init__()
231 | if use_entmax15:
232 | raise NotImplementedError("Check out entmax activation instead of softmax activation!")
233 | self.scale = dim_head ** -0.5
234 | self.heads = heads
235 | self.causal = causal
236 | self.mask = mask
237 |
238 | inner_dim = dim_head * heads
239 |
240 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
241 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
242 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
243 | self.dropout = nn.Dropout(dropout)
244 |
245 | # talking heads
246 | self.talking_heads = talking_heads
247 | if talking_heads:
248 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
250 |
251 | # explicit topk sparse attention
252 | self.sparse_topk = sparse_topk
253 |
254 | # entmax
255 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax
256 | self.attn_fn = F.softmax
257 |
258 | # add memory key / values
259 | self.num_mem_kv = num_mem_kv
260 | if num_mem_kv > 0:
261 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
263 |
264 | # attention on attention
265 | self.attn_on_attn = on_attn
266 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
267 |
268 | def forward(
269 | self,
270 | x,
271 | context=None,
272 | mask=None,
273 | context_mask=None,
274 | rel_pos=None,
275 | sinusoidal_emb=None,
276 | prev_attn=None,
277 | mem=None
278 | ):
279 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
280 | kv_input = default(context, x)
281 |
282 | q_input = x
283 | k_input = kv_input
284 | v_input = kv_input
285 |
286 | if exists(mem):
287 | k_input = torch.cat((mem, k_input), dim=-2)
288 | v_input = torch.cat((mem, v_input), dim=-2)
289 |
290 | if exists(sinusoidal_emb):
291 | # in shortformer, the query would start at a position offset depending on the past cached memory
292 | offset = k_input.shape[-2] - q_input.shape[-2]
293 | q_input = q_input + sinusoidal_emb(q_input, offset=offset)
294 | k_input = k_input + sinusoidal_emb(k_input)
295 |
296 | q = self.to_q(q_input)
297 | k = self.to_k(k_input)
298 | v = self.to_v(v_input)
299 |
300 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
301 |
302 | input_mask = None
303 | if any(map(exists, (mask, context_mask))):
304 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
305 | k_mask = q_mask if not exists(context) else context_mask
306 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
307 | q_mask = rearrange(q_mask, 'b i -> b () i ()')
308 | k_mask = rearrange(k_mask, 'b j -> b () () j')
309 | input_mask = q_mask * k_mask
310 |
311 | if self.num_mem_kv > 0:
312 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
313 | k = torch.cat((mem_k, k), dim=-2)
314 | v = torch.cat((mem_v, v), dim=-2)
315 | if exists(input_mask):
316 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
317 |
318 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
319 | mask_value = max_neg_value(dots)
320 |
321 | if exists(prev_attn):
322 | dots = dots + prev_attn
323 |
324 | pre_softmax_attn = dots
325 |
326 | if talking_heads:
327 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
328 |
329 | if exists(rel_pos):
330 | dots = rel_pos(dots)
331 |
332 | if exists(input_mask):
333 | dots.masked_fill_(~input_mask, mask_value)
334 | del input_mask
335 |
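    |         # causal masking: build an upper-triangular mask over the (i, i) query/key grid,
    |         # then left-pad it with False so any prepended memory keys (j > i) stay visible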
336 | if self.causal:
337 | i, j = dots.shape[-2:]
338 | r = torch.arange(i, device=device)
339 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
340 | mask = F.pad(mask, (j - i, 0), value=False)
341 | dots.masked_fill_(mask, mask_value)
342 | del mask
343 |
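    |         # explicit top-k sparse attention: keep only the k largest logits per query and
    |         # mask out the rest before the softmax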
344 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
345 | top, _ = dots.topk(self.sparse_topk, dim=-1)
346 | vk = top[..., -1].unsqueeze(-1).expand_as(dots)
347 | mask = dots < vk
348 | dots.masked_fill_(mask, mask_value)
349 | del mask
350 |
351 | attn = self.attn_fn(dots, dim=-1)
352 | post_softmax_attn = attn
353 |
354 | attn = self.dropout(attn)
355 |
356 | if talking_heads:
357 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
358 |
359 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
360 | out = rearrange(out, 'b h n d -> b n (h d)')
361 |
362 | intermediates = Intermediates(
363 | pre_softmax_attn=pre_softmax_attn,
364 | post_softmax_attn=post_softmax_attn
365 | )
366 |
367 | return self.to_out(out), intermediates
368 |
369 |
370 | class AttentionLayers(nn.Module):
371 | def __init__(
372 | self,
373 | dim,
374 | depth,
375 | heads=8,
376 | causal=False,
377 | cross_attend=False,
378 | only_cross=False,
379 | use_scalenorm=False,
380 | use_rmsnorm=False,
381 | use_rezero=False,
382 | rel_pos_num_buckets=32,
383 | rel_pos_max_distance=128,
384 | position_infused_attn=False,
385 | custom_layers=None,
386 | sandwich_coef=None,
387 | par_ratio=None,
388 | residual_attn=False,
389 | cross_residual_attn=False,
390 | macaron=False,
391 | pre_norm=True,
392 | gate_residual=False,
393 | **kwargs
394 | ):
395 | super().__init__()
396 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
397 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
398 |
399 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
400 |
401 | self.dim = dim
402 | self.depth = depth
403 | self.layers = nn.ModuleList([])
404 |
405 | self.has_pos_emb = position_infused_attn
406 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
407 | self.rotary_pos_emb = always(None)
408 |
409 |         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
410 | self.rel_pos = None
411 |
412 | self.pre_norm = pre_norm
413 |
414 | self.residual_attn = residual_attn
415 | self.cross_residual_attn = cross_residual_attn
416 |
417 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
418 | norm_class = RMSNorm if use_rmsnorm else norm_class
419 | norm_fn = partial(norm_class, dim)
420 |
421 | norm_fn = nn.Identity if use_rezero else norm_fn
422 | branch_fn = Rezero if use_rezero else None
423 |
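    |         # layer-type codes: 'a' = self-attention, 'c' = cross-attention, 'f' = feedforward;
    |         # a standard block is ('a', 'f'), with 'c' inserted when cross-attending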
424 | if cross_attend and not only_cross:
425 | default_block = ('a', 'c', 'f')
426 | elif cross_attend and only_cross:
427 | default_block = ('c', 'f')
428 | else:
429 | default_block = ('a', 'f')
430 |
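    |         # macaron-style blocks prepend a feedforward so each block runs FFN -> attention -> FFN,
    |         # with the feedforward layers scaled by 0.5 (see Scale(0.5, layer) below)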
431 | if macaron:
432 | default_block = ('f',) + default_block
433 |
434 | if exists(custom_layers):
435 | layer_types = custom_layers
436 | elif exists(par_ratio):
437 | par_depth = depth * len(default_block)
438 | assert 1 < par_ratio <= par_depth, 'par ratio out of range'
439 | default_block = tuple(filter(not_equals('f'), default_block))
440 | par_attn = par_depth // par_ratio
441 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
442 | par_width = (depth_cut + depth_cut // par_attn) // par_attn
443 | assert len(default_block) <= par_width, 'default block is too large for par_ratio'
444 | par_block = default_block + ('f',) * (par_width - len(default_block))
445 | par_head = par_block * par_attn
446 | layer_types = par_head + ('f',) * (par_depth - len(par_head))
447 | elif exists(sandwich_coef):
448 |             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient must be greater than 0 and no greater than the depth'
449 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
450 | else:
451 | layer_types = default_block * depth
452 |
453 | self.layer_types = layer_types
454 | self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
455 |
456 | for layer_type in self.layer_types:
457 | if layer_type == 'a':
458 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
459 | elif layer_type == 'c':
460 | layer = Attention(dim, heads=heads, **attn_kwargs)
461 | elif layer_type == 'f':
462 | layer = FeedForward(dim, **ff_kwargs)
463 | layer = layer if not macaron else Scale(0.5, layer)
464 | else:
465 |                 raise ValueError(f'invalid layer type {layer_type}')
466 |
467 | if isinstance(layer, Attention) and exists(branch_fn):
468 | layer = branch_fn(layer)
469 |
470 | if gate_residual:
471 | residual_fn = GRUGating(dim)
472 | else:
473 | residual_fn = Residual()
474 |
475 | self.layers.append(nn.ModuleList([
476 | norm_fn(),
477 | layer,
478 | residual_fn
479 | ]))
480 |
481 | def forward(
482 | self,
483 | x,
484 | context=None,
485 | mask=None,
486 | context_mask=None,
487 | mems=None,
488 | return_hiddens=False
489 | ):
490 | hiddens = []
491 | intermediates = []
492 | prev_attn = None
493 | prev_cross_attn = None
494 |
495 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
496 |
497 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
498 | is_last = ind == (len(self.layers) - 1)
499 |
500 | if layer_type == 'a':
501 | hiddens.append(x)
502 | layer_mem = mems.pop(0)
503 |
504 | residual = x
505 |
506 | if self.pre_norm:
507 | x = norm(x)
508 |
509 | if layer_type == 'a':
510 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
511 | prev_attn=prev_attn, mem=layer_mem)
512 | elif layer_type == 'c':
513 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
514 | elif layer_type == 'f':
515 | out = block(x)
516 |
517 | x = residual_fn(out, residual)
518 |
519 | if layer_type in ('a', 'c'):
520 | intermediates.append(inter)
521 |
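    |             # residual attention (RealFormer-style): carry this layer's pre-softmax logits
    |             # forward as an additive prior for the next attention layer of the same kind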
522 | if layer_type == 'a' and self.residual_attn:
523 | prev_attn = inter.pre_softmax_attn
524 | elif layer_type == 'c' and self.cross_residual_attn:
525 | prev_cross_attn = inter.pre_softmax_attn
526 |
527 | if not self.pre_norm and not is_last:
528 | x = norm(x)
529 |
530 | if return_hiddens:
531 | intermediates = LayerIntermediates(
532 | hiddens=hiddens,
533 | attn_intermediates=intermediates
534 | )
535 |
536 | return x, intermediates
537 |
538 | return x
539 |
540 |
541 | class Encoder(AttentionLayers):
542 | def __init__(self, **kwargs):
543 | assert 'causal' not in kwargs, 'cannot set causality on encoder'
544 | super().__init__(causal=False, **kwargs)
545 |
546 |
547 |
548 | class TransformerWrapper(nn.Module):
549 | def __init__(
550 | self,
551 | *,
552 | num_tokens,
553 | max_seq_len,
554 | attn_layers,
555 | emb_dim=None,
556 |         max_mem_len=0,  # int: a float default would break the memory slicing in forward()
557 | emb_dropout=0.,
558 | num_memory_tokens=None,
559 | tie_embedding=False,
560 | use_pos_emb=True
561 | ):
562 | super().__init__()
563 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
564 |
565 | dim = attn_layers.dim
566 | emb_dim = default(emb_dim, dim)
567 |
568 | self.max_seq_len = max_seq_len
569 | self.max_mem_len = max_mem_len
570 | self.num_tokens = num_tokens
571 |
572 | self.token_emb = nn.Embedding(num_tokens, emb_dim)
573 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
574 | use_pos_emb and not attn_layers.has_pos_emb) else always(0)
575 | self.emb_dropout = nn.Dropout(emb_dropout)
576 |
577 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
578 | self.attn_layers = attn_layers
579 | self.norm = nn.LayerNorm(dim)
580 |
581 | self.init_()
582 |
583 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
584 |
585 | # memory tokens (like [cls]) from Memory Transformers paper
586 | num_memory_tokens = default(num_memory_tokens, 0)
587 | self.num_memory_tokens = num_memory_tokens
588 | if num_memory_tokens > 0:
589 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
590 |
591 | # let funnel encoder know number of memory tokens, if specified
592 | if hasattr(attn_layers, 'num_memory_tokens'):
593 | attn_layers.num_memory_tokens = num_memory_tokens
594 |
595 | def init_(self):
596 | nn.init.normal_(self.token_emb.weight, std=0.02)
597 |
598 | def forward(
599 | self,
600 | x,
601 | return_embeddings=False,
602 | mask=None,
603 | return_mems=False,
604 | return_attn=False,
605 | mems=None,
606 | **kwargs
607 | ):
608 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
609 | x = self.token_emb(x)
610 | x += self.pos_emb(x)
611 | x = self.emb_dropout(x)
612 |
613 | x = self.project_emb(x)
614 |
615 | if num_mem > 0:
616 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
617 | x = torch.cat((mem, x), dim=1)
618 |
619 | # auto-handle masking after appending memory tokens
620 | if exists(mask):
621 | mask = F.pad(mask, (num_mem, 0), value=True)
622 |
623 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
624 | x = self.norm(x)
625 |
626 | mem, x = x[:, :num_mem], x[:, num_mem:]
627 |
628 | out = self.to_logits(x) if not return_embeddings else x
629 |
630 | if return_mems:
631 | hiddens = intermediates.hiddens
632 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
633 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
634 | return out, new_mems
635 |
636 | if return_attn:
637 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
638 | return out, attn_maps
639 |
640 | return out
641 |
642 |
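    | # --- Editor's note -------------------------------------------------------------
    | # A minimal usage sketch (not part of the upstream file), roughly how the BERT-style
    | # text embedder in ldm/modules/encoders/modules.py wires these pieces together;
    | # the dimensions below are illustrative assumptions:
    | #
    | #   model = TransformerWrapper(
    | #       num_tokens=30522,                                # vocab size
    | #       max_seq_len=77,                                  # context length
    | #       attn_layers=Encoder(dim=512, depth=6, heads=8),  # non-causal stack
    | #   )
    | #   tokens = torch.randint(0, 30522, (2, 77))
    | #   embeddings = model(tokens, return_embeddings=True)   # -> (2, 77, 512)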
--------------------------------------------------------------------------------