').css({
54 | position: 'absolute',
55 | left: Math.min(x1, x2) + 'px',
56 | top: Math.min(y1, y2) + 'px',
57 | width: Math.abs(x2 - x1) + 'px',
58 | height: Math.abs(y2 - y1) + 'px',
59 | border: '2px solid ' + color
60 | });
61 |
62 | var colorbox = $('<div/>').css({
63 | position: 'absolute',
64 | left: '0px',
65 | top: '0px',
66 | width: '50px',
67 | height: '20px',
68 | background: 'none',
69 | color: color,
70 | 'text-align': 'center',
71 | 'font-size': '12px',
72 | 'line-height': '20px'
73 | }).text(class_name);
74 |
75 | rect.append(colorbox);
76 | var rect_config = {
77 | x1: Math.max(0, Math.min(x1, x2)),
78 | y1: Math.max(0, Math.min(y1, y2)),
79 | x2: Math.min(511, Math.max(x1, x2)),
80 | y2: Math.min(511, Math.max(y1, y2)),
81 | color: color,
82 | class: class_name
83 | };
84 |
85 | return {rect, rect_config};
86 | }
87 |
88 | function drawRectangle(rect) {
89 | $('#canvas').append(rect);
90 | }
91 |
92 | function updateRectangle(rect, x1, y1, x2, y2) {
93 | rect.css({
94 | left: Math.min(x1, x2) + 'px',
95 | top: Math.min(y1, y2) + 'px',
96 | width: Math.abs(x2 - x1) + 'px',
97 | height: Math.abs(y2 - y1) + 'px'
98 | });
99 |
100 | return rect;
101 | }
102 |
103 | function redrawCanvas() {
104 | $('#canvas').empty();
105 | for (var i = 0; i < rectangles.length; i++) {
106 | var rect = rectangles[i];
107 | var res = createRectangle(
108 | rect.x1, rect.y1, rect.x2, rect.y2, rect.color, rect.class
109 | );
110 | var thisRect = res.rect;
111 | drawRectangle(thisRect);
112 | }
113 | }
114 |
115 | $(function() {
116 | var isDrawing = false;
117 |
118 | $('#canvas').mousedown(function(event) {
119 | startX = event.offsetX;
120 | startY = event.offsetY;
121 | isDrawing = true;
122 | res = createRectangle(startX, startY, startX, startY, color);
123 | currentRect = res.rect;
124 | currentRectConfig = res.rect_config;
125 | });
126 |
127 | $('#canvas').mousemove(function(event) {
128 | if (isDrawing) {
129 | // clear the canvas and redraw any existing rectangles
130 | redrawCanvas();
131 | // get the current end point and draw a new rectangle
132 | endX = event.offsetX;
133 | endY = event.offsetY;
134 | currentRect = updateRectangle(currentRect, startX, startY, endX, endY);
135 | drawRectangle(currentRect);
136 | }
137 | });
138 |
139 | $('#canvas').mouseup(function(event) {
140 | if (isDrawing) {
141 | isDrawing = false;
142 |
143 | // get the final end point and create a new rectangle
144 | // endX = event.offsetX;
145 | // endY = event.offsetY;
146 | res = createRectangle(startX, startY, endX, endY, color);
147 | currentRect = res.rect;
148 | currentRectConfig = res.rect_config;
149 | rectangles.push(currentRectConfig);
150 | }
151 | });
152 |
153 | $('#clearbtn').click(clearCanvas);
154 | // Add event listener to Plot button
155 | $('#submit').click(getSDImages);
156 | // redraw the canvas on page load to display any existing rectangles
157 | redrawCanvas();
158 | });
159 |
160 | function generateClassSelector() {
161 | // Make a GET request to read the labels file
162 | const xhr = new XMLHttpRequest();
163 | xhr.open('GET', '/static/doc/labels.txt');
164 | xhr.onreadystatechange = function() {
165 | if (xhr.readyState === 4 && xhr.status === 200) {
166 | // Split the labels into an array
167 | const labels = xhr.responseText.trim().split('\n');
168 | const classes = {};
169 | // Loop through each label to generate a unique color and create a radio button
170 | for (let i = 1; i < labels.length; i++) {
171 | const [classId, className] = labels[i].split(': ');
172 | classes[className] = `#${(Math.random()*0xFFFFFF<<0).toString(16).padStart(6, '0')}`;
173 | const input = document.createElement('input');
174 | input.type = 'radio';
175 | input.name = 'class';
176 | input.value = className; input.id = className; // give the input an id so label.htmlFor below resolves
177 | const label = document.createElement('label');
178 | label.htmlFor = className;
179 | label.innerText = className;
180 | label.style.color = classes[className];
181 | const li = document.createElement('li');
182 | li.appendChild(input);
183 | li.appendChild(label);
184 | document.querySelector('#class-selector-list').appendChild(li);
185 | }
186 | // Attach event listener to radio buttons to update label color
187 | $("input[name='class']").on('change', function() {
188 | const className = $("input[name='class']:checked").val();
189 | color = classes[className];
190 | });
191 | }
192 | };
193 | xhr.send();
194 | }
195 |
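
Note: the loop in generateClassSelector starts at index 1, so the first line of /static/doc/labels.txt is skipped, and every remaining line is parsed as "<id>: <name>". A made-up example of the expected format (these labels are illustrative, not taken from the repo):

    0: unlabeled
    1: person
    2: bicycle
    3: car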
--------------------------------------------------------------------------------
/interactive_plotting/templates/index.html:
--------------------------------------------------------------------------------
(HTML markup lost in extraction; recoverable text: page title/heading "LayoutDiffuse Interactive Plotter" and the label "Select a class:". The accompanying script expects elements with ids #canvas, #class-selector-list, #clearbtn and #submit.)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import json
5 | from torch.utils.data import DataLoader
6 | from pytorch_lightning import Trainer, seed_everything
7 | from data import get_dataset
8 | from train_utils import get_models, get_DDPM, get_logger_and_callbacks
9 |
10 | if __name__ == '__main__':
11 | seed_everything(42)
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | '-c', '--config', type=str,
15 | default='config/train.json')
16 | parser.add_argument(
17 | '-r', '--resume', action="store_true"
18 | )
19 | parser.add_argument(
20 | '-n', '--nnode', type=int, default=1
21 | )
22 |
23 |     ''' parse configs '''
24 | args_raw = parser.parse_args()
25 | with open(args_raw.config, 'r') as IN:
26 | args = json.load(IN)
27 | args['resume'] = args_raw.resume
28 | args['nnode'] = args_raw.nnode
29 | expt_name = args['expt_name']
30 | expt_dir = args['expt_dir']
31 | expt_path = os.path.join(expt_dir, expt_name)
32 | os.makedirs(expt_path, exist_ok=True)
33 |
34 | '''1. create denoising model'''
35 | models = get_models(args)
36 |
37 | diffusion_configs = args['diffusion']
38 | ddpm_model = get_DDPM(
39 | diffusion_configs=diffusion_configs,
40 | log_args=args,
41 | **models
42 | )
43 |
44 | '''2. dataset and dataloader'''
45 | data_args = args['data']
46 | train_set, val_set = get_dataset(**data_args)
47 | train_loader = DataLoader(
48 | train_set, batch_size=data_args['batch_size'], shuffle=True,
49 | num_workers=4*len(args['trainer_args']['devices']), pin_memory=True
50 | )
51 | val_loader = DataLoader(
52 | val_set, batch_size=data_args['val_batch_size'],
53 | num_workers=len(args['trainer_args']['devices']), pin_memory=True
54 | )
55 | '''3. create callbacks'''
56 | wandb_logger, callbacks = get_logger_and_callbacks(expt_name, expt_path, args)
57 |
58 | '''4. trainer'''
59 | trainer_args = {
60 | "max_epochs": 1000,
61 | "accelerator": "gpu",
62 | "devices": [0],
63 | "limit_val_batches": 1,
64 | "strategy": "ddp",
65 | "check_val_every_n_epoch": 1,
66 | "num_nodes": args['nnode']
67 | # "benchmark" :True
68 | }
69 | config_trainer_args = args['trainer_args'] if args.get('trainer_args') is not None else {}
70 | trainer_args.update(config_trainer_args)
71 | print(f'Training args are {trainer_args}')
72 | trainer = Trainer(
73 | logger = wandb_logger,
74 | callbacks = callbacks,
75 | **trainer_args
76 | )
77 | '''5. start training'''
78 | if args['resume']:
79 | print('INFO: Try to resume from checkpoint')
80 | ckpt_path = os.path.join(expt_path, 'latest.ckpt')
81 | if os.path.exists(ckpt_path):
82 | print(f'INFO: Found checkpoint {ckpt_path}')
83 | # ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
84 | # ddpm_model.load_state_dict(ckpt)
85 | else:
86 | ckpt_path = None
87 | else:
88 | ckpt_path = None
89 | trainer.fit(
90 | ddpm_model, train_loader, val_loader,
91 | ckpt_path=ckpt_path
92 | )
93 |
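
For orientation, a partial sketch (written as a Python dict) of the configuration fields this script reads directly; the key names come from the code above, the values are placeholders, and get_models/get_dataset consume additional fields not shown here:

    config = {
        "expt_name": "my_experiment",                    # experiment folder name
        "expt_dir": "experiments",                       # parent directory for outputs
        "diffusion": {},                                 # forwarded to get_DDPM as diffusion_configs
        "data": {"batch_size": 8, "val_batch_size": 4},  # also carries kwargs for get_dataset
        "trainer_args": {"devices": [0]},                # overrides the default Trainer args above
    }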
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch
3 | import numpy as np
4 | from inspect import isfunction
5 |
6 | def instantiate_from_config(config):
7 |     if "target" not in config:
8 | raise KeyError("Expected key `target` to instantiate.")
9 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
10 |
11 |
12 | def get_obj_from_str(string, reload=False):
13 | module, cls = string.rsplit(".", 1)
14 | if reload:
15 | module_imp = importlib.import_module(module)
16 | importlib.reload(module_imp)
17 | return getattr(importlib.import_module(module, package=None), cls)
18 |
19 | def exists(x):
20 | return x is not None
21 |
22 | def default(val, d):
23 | if exists(val):
24 | return val
25 | return d() if isfunction(d) else d
26 |
27 | def noise_like(shape, device, repeat=False):
28 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
29 | noise = lambda: torch.randn(shape, device=device)
30 | return repeat_noise() if repeat else noise()
31 |
32 | def extract_into_tensor(a, t, x_shape):
33 | b, *_ = t.shape
34 | out = a.gather(-1, t)
35 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
36 |
37 | def right_pad_dims_to(x, t):
38 | padding_dims = x.ndim - t.ndim
39 | if padding_dims <= 0:
40 | return t
41 | return t.view(*t.shape, *((1,) * padding_dims))
42 |
43 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
44 | if schedule == "linear":
45 | betas = (
46 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
47 | )
48 |
49 | elif schedule == "cosine":
50 | timesteps = (
51 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
52 | )
53 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
54 | alphas = torch.cos(alphas).pow(2)
55 | alphas = alphas / alphas[0]
56 | betas = 1 - alphas[1:] / alphas[:-1]
57 | betas = np.clip(betas, a_min=0, a_max=0.999)
58 |
59 | elif schedule == "sqrt_linear":
60 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
61 | elif schedule == "sqrt":
62 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
63 | else:
64 | raise ValueError(f"schedule '{schedule}' unknown.")
65 | return betas.numpy()
66 |
67 |
68 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
69 | if ddim_discr_method == 'uniform':
70 | c = num_ddpm_timesteps // num_ddim_timesteps
71 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
72 | elif ddim_discr_method == 'quad':
73 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
74 | else:
75 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
76 |
77 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
78 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
79 | steps_out = ddim_timesteps + 1
80 | if verbose:
81 | print(f'Selected timesteps for ddim sampler: {steps_out}')
82 | return steps_out
83 |
84 |
85 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
86 | # select alphas for computing the variance schedule
87 | alphas = alphacums[ddim_timesteps]
88 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
89 |
90 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
91 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
92 | if verbose:
93 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
94 | print(f'For the chosen value of eta, which is {eta}, '
95 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
96 | return sigmas, alphas, alphas_prev
97 |
98 |
99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
100 | """
101 | Create a beta schedule that discretizes the given alpha_t_bar function,
102 | which defines the cumulative product of (1-beta) over time from t = [0,1].
103 | :param num_diffusion_timesteps: the number of betas to produce.
104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105 | produces the cumulative product of (1-beta) up to that
106 | part of the diffusion process.
107 | :param max_beta: the maximum beta to use; use values lower than 1 to
108 | prevent singularities.
109 | """
110 | betas = []
111 | for i in range(num_diffusion_timesteps):
112 | t1 = i / num_diffusion_timesteps
113 | t2 = (i + 1) / num_diffusion_timesteps
114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115 | return np.array(betas)
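
A minimal usage sketch for instantiate_from_config; the target class and params below are illustrative (any importable "module.Class" path works):

    from model_utils import instantiate_from_config

    config = {
        "target": "torch.nn.Conv2d",  # dotted "module.Class" path
        "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
    }
    conv = instantiate_from_config(config)  # equivalent to torch.nn.Conv2d(**params)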
--------------------------------------------------------------------------------
/modules/bert/bert_embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .x_transformer import TransformerWrapper, Encoder
3 |
4 | class BERTTokenizer(torch.nn.Module):
5 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
6 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
7 | super().__init__()
8 |         from transformers import BertTokenizerFast  # TODO: add to requirements
9 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
10 | self.device = device
11 | self.vq_interface = vq_interface
12 | self.max_length = max_length
13 |
14 | def forward(self, text):
15 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
16 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
17 | tokens = batch_encoding["input_ids"].to(self.device)
18 | return tokens
19 |
20 | @torch.no_grad()
21 | def encode(self, text):
22 | tokens = self(text)
23 | if not self.vq_interface:
24 | return tokens
25 | return None, None, [None, None, tokens]
26 |
27 | def decode(self, text):
28 | return text
29 |
30 |
31 | class BERTEmbedder(torch.nn.Module):
32 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
33 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
34 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
35 | super().__init__()
36 | self.use_tknz_fn = use_tokenizer
37 | if self.use_tknz_fn:
38 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
39 | self.device = device
40 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
41 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
42 | emb_dropout=embedding_dropout)
43 |
44 | def forward(self, text):
45 | if self.use_tknz_fn:
46 | tokens = self.tknz_fn(text)#.to(self.device)
47 | else:
48 | tokens = text
49 | z = self.transformer(tokens, return_embeddings=True)
50 | return z
51 |
52 | def encode(self, text):
53 | # output of length 77
54 | return self(text)
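
A small usage sketch for BERTTokenizer; it downloads the bert-base-uncased tokenizer from HuggingFace on first use, and the CPU device here is illustrative:

    from modules.bert.bert_embedder import BERTTokenizer

    tok = BERTTokenizer(device="cpu", vq_interface=False, max_length=77)
    tokens = tok(["a photo of a cat"])  # LongTensor of WordPiece ids, shape [1, 77]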
--------------------------------------------------------------------------------
/modules/kl_autoencoder/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from modules.vqvae.model import Encoder, Decoder
7 |
8 | from model_utils import instantiate_from_config
9 |
10 | class DiagonalGaussianDistribution(object):
11 | def __init__(self, parameters, deterministic=False):
12 | self.parameters = parameters
13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
15 | self.deterministic = deterministic
16 | self.std = torch.exp(0.5 * self.logvar)
17 | self.var = torch.exp(self.logvar)
18 | if self.deterministic:
19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
20 |
21 | def sample(self):
22 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
23 | return x
24 |
25 | def kl(self, other=None):
26 | if self.deterministic:
27 | return torch.Tensor([0.])
28 | else:
29 | if other is None:
30 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
31 | + self.var - 1.0 - self.logvar,
32 | dim=[1, 2, 3])
33 | else:
34 | return 0.5 * torch.sum(
35 | torch.pow(self.mean - other.mean, 2) / other.var
36 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
37 | dim=[1, 2, 3])
38 |
39 | def nll(self, sample, dims=[1,2,3]):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | logtwopi = np.log(2.0 * np.pi)
43 | return 0.5 * torch.sum(
44 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
45 | dim=dims)
46 |
47 | def mode(self):
48 | return self.mean
49 |
50 | class AutoencoderKL(pl.LightningModule):
51 | def __init__(self,
52 | ddconfig,
53 | lossconfig,
54 | embed_dim,
55 | ckpt_path=None,
56 | ignore_keys=[],
57 | image_key="image",
58 | colorize_nlabels=None,
59 | monitor=None,
60 | ):
61 | super().__init__()
62 | self.image_key = image_key
63 | self.encoder = Encoder(**ddconfig)
64 | self.decoder = Decoder(**ddconfig)
65 | self.loss = instantiate_from_config(lossconfig)
66 | assert ddconfig["double_z"]
67 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
68 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
69 | self.embed_dim = embed_dim
70 | if colorize_nlabels is not None:
71 | assert type(colorize_nlabels)==int
72 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
73 | if monitor is not None:
74 | self.monitor = monitor
75 | if ckpt_path is not None:
76 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path}")
88 |
89 | def encode(self, x):
90 | h = self.encoder(x)
91 | moments = self.quant_conv(h)
92 | posterior = DiagonalGaussianDistribution(moments)
93 |         # TODO: check whether sampling should move into the DDIM_ldm class; note that this returns a sampled latent, not the posterior, while forward() below still expects a posterior
94 | enc = posterior.sample()
95 | return enc #posterior
96 |
97 | def decode(self, z):
98 | z = self.post_quant_conv(z)
99 | dec = self.decoder(z)
100 | return dec
101 |
102 | def forward(self, input, sample_posterior=True):
103 | posterior = self.encode(input)
104 | if sample_posterior:
105 | z = posterior.sample()
106 | else:
107 | z = posterior.mode()
108 | dec = self.decode(z)
109 | return dec, posterior
110 |
111 | def get_input(self, batch, k):
112 | x = batch[k]
113 | if len(x.shape) == 3:
114 | x = x[..., None]
115 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
116 | return x
117 |
118 | def training_step(self, batch, batch_idx, optimizer_idx):
119 | inputs = self.get_input(batch, self.image_key)
120 | reconstructions, posterior = self(inputs)
121 |
122 | if optimizer_idx == 0:
123 | # train encoder+decoder+logvar
124 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
125 | last_layer=self.get_last_layer(), split="train")
126 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return aeloss
129 |
130 | if optimizer_idx == 1:
131 | # train the discriminator
132 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
133 | last_layer=self.get_last_layer(), split="train")
134 |
135 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
136 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
137 | return discloss
138 |
139 | def validation_step(self, batch, batch_idx):
140 | inputs = self.get_input(batch, self.image_key)
141 | reconstructions, posterior = self(inputs)
142 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
143 | last_layer=self.get_last_layer(), split="val")
144 |
145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
146 | last_layer=self.get_last_layer(), split="val")
147 |
148 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
149 | self.log_dict(log_dict_ae)
150 | self.log_dict(log_dict_disc)
151 | return self.log_dict
152 |
153 | def configure_optimizers(self):
154 | lr = self.learning_rate
155 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
156 | list(self.decoder.parameters())+
157 | list(self.quant_conv.parameters())+
158 | list(self.post_quant_conv.parameters()),
159 | lr=lr, betas=(0.5, 0.9))
160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
161 | lr=lr, betas=(0.5, 0.9))
162 | return [opt_ae, opt_disc], []
163 |
164 | def get_last_layer(self):
165 | return self.decoder.conv_out.weight
166 |
167 | @torch.no_grad()
168 | def log_images(self, batch, only_inputs=False, **kwargs):
169 | log = dict()
170 | x = self.get_input(batch, self.image_key)
171 | x = x.to(self.device)
172 | if not only_inputs:
173 | xrec, posterior = self(x)
174 | if x.shape[1] > 3:
175 | # colorize with random projection
176 | assert xrec.shape[1] > 3
177 | x = self.to_rgb(x)
178 | xrec = self.to_rgb(xrec)
179 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
180 | log["reconstructions"] = xrec
181 | log["inputs"] = x
182 | return log
183 |
184 | def to_rgb(self, x):
185 | assert self.image_key == "segmentation"
186 | if not hasattr(self, "colorize"):
187 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
188 | x = F.conv2d(x, weight=self.colorize)
189 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
190 | return x
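
A small sketch of DiagonalGaussianDistribution, which splits its input into mean/log-variance halves along dim=1; shapes are illustrative and the import assumes the repo root is on PYTHONPATH:

    import torch
    from modules.kl_autoencoder.autoencoder import DiagonalGaussianDistribution

    moments = torch.randn(2, 8, 32, 32)                # 2 * latent channels: [mean | logvar]
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()                             # shape [2, 4, 32, 32]
    kl = posterior.kl()                                # per-sample KL to N(0, I), shape [2]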
--------------------------------------------------------------------------------
/modules/openai_unet/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from model_utils import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 | # according the the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
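
A quick sketch of timestep_embedding's behaviour; the tensor values and the import path (matching this file's location) are illustrative:

    import torch
    from modules.openai_unet.util import timestep_embedding

    t = torch.tensor([0, 10, 100, 999])   # one timestep index per batch element
    emb = timestep_embedding(t, dim=128)  # concatenated cos/sin features, shape [4, 128]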
--------------------------------------------------------------------------------
/modules/openclip/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6 |
7 | import open_clip
8 |
9 |
10 | class AbstractEncoder(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def encode(self, *args, **kwargs):
15 | raise NotImplementedError
16 |
17 |
18 | class IdentityEncoder(AbstractEncoder):
19 |
20 | def encode(self, x):
21 | return x
22 |
23 |
24 | class ClassEmbedder(nn.Module):
25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
26 | super().__init__()
27 | self.key = key
28 | self.embedding = nn.Embedding(n_classes, embed_dim)
29 | self.n_classes = n_classes
30 | self.ucg_rate = ucg_rate
31 |
32 | def forward(self, batch, key=None, disable_dropout=False):
33 | if key is None:
34 | key = self.key
35 | # this is for use in crossattn
36 | c = batch[key][:, None]
37 | if self.ucg_rate > 0. and not disable_dropout:
38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
39 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
40 | c = c.long()
41 | c = self.embedding(c)
42 | return c
43 |
44 | def get_unconditional_conditioning(self, bs, device="cuda"):
45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
46 | uc = torch.ones((bs,), device=device) * uc_class
47 | uc = {self.key: uc}
48 | return uc
49 |
50 |
51 | def disabled_train(self, mode=True):
52 | """Overwrite model.train with this function to make sure train/eval mode
53 | does not change anymore."""
54 | return self
55 |
56 |
57 | class FrozenT5Embedder(AbstractEncoder):
58 | """Uses the T5 transformer encoder for text"""
59 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
60 | super().__init__()
61 | self.tokenizer = T5Tokenizer.from_pretrained(version)
62 | self.transformer = T5EncoderModel.from_pretrained(version)
63 | self.device = device
64 | self.max_length = max_length # TODO: typical value?
65 | if freeze:
66 | self.freeze()
67 |
68 | def freeze(self):
69 | self.transformer = self.transformer.eval()
70 | #self.train = disabled_train
71 | for param in self.parameters():
72 | param.requires_grad = False
73 |
74 | def forward(self, text):
75 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
76 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
77 | tokens = batch_encoding["input_ids"].to(self.device)
78 | outputs = self.transformer(input_ids=tokens)
79 |
80 | z = outputs.last_hidden_state
81 | return z
82 |
83 | def encode(self, text):
84 | return self(text)
85 |
86 |
87 | class FrozenCLIPEmbedder(AbstractEncoder):
88 | """Uses the CLIP transformer encoder for text (from huggingface)"""
89 | LAYERS = [
90 | "last",
91 | "pooled",
92 | "hidden"
93 | ]
94 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
95 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
96 | super().__init__()
97 | assert layer in self.LAYERS
98 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
99 | self.transformer = CLIPTextModel.from_pretrained(version)
100 | self.device = device
101 | self.max_length = max_length
102 | if freeze:
103 | self.freeze()
104 | self.layer = layer
105 | self.layer_idx = layer_idx
106 | if layer == "hidden":
107 | assert layer_idx is not None
108 | assert 0 <= abs(layer_idx) <= 12
109 |
110 | def freeze(self):
111 | self.transformer = self.transformer.eval()
112 | #self.train = disabled_train
113 | for param in self.parameters():
114 | param.requires_grad = False
115 |
116 | def forward(self, text):
117 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
118 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
119 | tokens = batch_encoding["input_ids"].to(self.device)
120 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
121 | if self.layer == "last":
122 | z = outputs.last_hidden_state
123 | elif self.layer == "pooled":
124 | z = outputs.pooler_output[:, None, :]
125 | else:
126 | z = outputs.hidden_states[self.layer_idx]
127 | return z
128 |
129 | def encode(self, text):
130 | return self(text)
131 |
132 |
133 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
134 | """
135 | Uses the OpenCLIP transformer encoder for text
136 | """
137 | LAYERS = [
138 | #"pooled",
139 | "last",
140 | "penultimate"
141 | ]
142 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
143 | freeze=True, layer="last"):
144 | super().__init__()
145 | assert layer in self.LAYERS
146 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
147 | del model.visual
148 | self.model = model
149 |
150 | self.device = device
151 | self.max_length = max_length
152 | if freeze:
153 | self.freeze()
154 | self.layer = layer
155 | if self.layer == "last":
156 | self.layer_idx = 0
157 | elif self.layer == "penultimate":
158 | self.layer_idx = 1
159 | else:
160 | raise NotImplementedError()
161 |
162 | def freeze(self):
163 | self.model = self.model.eval()
164 | for param in self.parameters():
165 | param.requires_grad = False
166 |
167 | def forward(self, text):
168 | tokens = open_clip.tokenize(text)
169 | z = self.encode_with_transformer(tokens.to(self.device))
170 | return z
171 |
172 | def encode_with_transformer(self, text):
173 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
174 | x = x + self.model.positional_embedding
175 | x = x.permute(1, 0, 2) # NLD -> LND
176 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
177 | x = x.permute(1, 0, 2) # LND -> NLD
178 | x = self.model.ln_final(x)
179 | return x
180 |
181 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
182 | for i, r in enumerate(self.model.transformer.resblocks):
183 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
184 | break
185 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
186 | x = checkpoint(r, x, attn_mask)
187 | else:
188 | x = r(x, attn_mask=attn_mask)
189 | return x
190 |
191 | def encode(self, text):
192 | return self(text)
193 |
194 |
195 | class FrozenCLIPT5Encoder(AbstractEncoder):
196 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
197 | clip_max_length=77, t5_max_length=77):
198 | super().__init__()
199 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
200 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
201 | # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
202 | # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
203 |
204 | def encode(self, text):
205 | return self(text)
206 |
207 | def forward(self, text):
208 | clip_z = self.clip_encoder.encode(text)
209 | t5_z = self.t5_encoder.encode(text)
210 | return [clip_z, t5_z]
211 |
212 |
213 |
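
A usage sketch for FrozenCLIPEmbedder; it downloads the openai/clip-vit-large-patch14 weights from HuggingFace on first use, and the CPU device here is illustrative:

    import torch
    from modules.openclip.modules import FrozenCLIPEmbedder

    text_encoder = FrozenCLIPEmbedder(device="cpu")    # layer="last" by default
    with torch.no_grad():
        z = text_encoder.encode(["a photo of a cat"])  # shape [1, 77, 768]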
--------------------------------------------------------------------------------
/pretrained_models/LAION_text2img/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_sd = torch.load("model.ckpt")
4 | sd = pl_sd["state_dict"]
5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'}
8 |
9 | torch.save(unet_sd, 'unet.ckpt')
10 | torch.save(vq_sd, 'vqvae.ckpt')
11 | torch.save(cond_sd, 'bert.ckpt')
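
The integer slices above strip the submodule prefixes ('model.diffusion_model.', 'first_stage_model.', 'cond_stage_model.') from the checkpoint keys; an equivalent, more explicit sketch of the same idea:

    prefix = 'model.diffusion_model.'  # len(prefix) == 22, matching k[22:]
    unet_sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}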
--------------------------------------------------------------------------------
/pretrained_models/LAION_text2img/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/pretrained_models/SD1_5/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_sd = torch.load("model.ckpt")
4 | sd = pl_sd["state_dict"]
5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'}
8 |
9 | torch.save(unet_sd, 'unet.ckpt')
10 | torch.save(vq_sd, 'vqvae.ckpt')
11 | torch.save(cond_sd, 'clip.ckpt')
--------------------------------------------------------------------------------
/pretrained_models/SD2_1/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_sd = torch.load("model.ckpt")
4 | sd = pl_sd["state_dict"]
5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'}
8 |
9 | torch.save(unet_sd, 'unet.ckpt')
10 | torch.save(vq_sd, 'vqvae.ckpt')
11 | torch.save(cond_sd, 'clip.ckpt')
--------------------------------------------------------------------------------
/pretrained_models/anything4_5/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | sd = torch.load("model.ckpt")
4 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
5 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
6 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'}
7 |
8 | torch.save(unet_sd, 'unet.ckpt')
9 | torch.save(vq_sd, 'vqvae.ckpt')
10 | torch.save(cond_sd, 'clip.ckpt')
11 |
--------------------------------------------------------------------------------
/pretrained_models/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/pretrained_models/celeba256/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_sd = torch.load("model.ckpt")
4 | sd = pl_sd["state_dict"]
5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
7 |
8 | torch.save(unet_sd, 'unet.ckpt')
9 | torch.save(vq_sd, 'vqvae.ckpt')
--------------------------------------------------------------------------------
/pretrained_models/celeba256/split_model_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_sd = torch.load("model.ckpt")
4 | sd = pl_sd["state_dict"]
5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
7 |
8 | torch.save(unet_sd, 'unet.ckpt')
9 | torch.save(vq_sd, 'vqvae.ckpt')
--------------------------------------------------------------------------------
/pretrained_models/counterfeitV25/split_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors import safe_open
3 |
4 | def load_safetensors(file_path):
5 | tensors = {}
6 | with safe_open(file_path, framework="pt", device="cpu") as f:
7 | for key in f.keys():
8 | tensors[key] = f.get_tensor(key)
9 | return tensors
10 |
11 | sd = load_safetensors("counterfeitV25Pruned.safetensors")
12 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'}
13 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'}
14 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'}
15 |
16 | torch.save(unet_sd, 'unet.ckpt')
17 | torch.save(vq_sd, 'vqvae.ckpt')
18 | torch.save(cond_sd, 'clip.ckpt')
19 |
--------------------------------------------------------------------------------
/pretrained_models/negative/EasyNegative.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cplusx/layout_diffuse/9666cb867313aa693775f6134442dea3734565a5/pretrained_models/negative/EasyNegative.safetensors
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.2.1
2 | h5py
3 | imageio
4 | lmdb
5 | matplotlib
6 | opencv-python
7 | pillow
8 | pytorch-lightning
9 | scikit-image
10 | scikit-learn
11 | scipy
12 | #torch==1.12.1+cu116
13 | #torchvision==0.13.0+cu116
14 | tqdm
15 | wandb
16 | clean-fid
17 | einops
18 | pycocotools
19 | perceiver-pytorch
20 | transformers
21 | gdown
22 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
23 | open-clip-torch
24 | openai
--------------------------------------------------------------------------------
/run_gradio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import gradio as gr
4 | import os
5 | import torch
6 | import json
7 | from train_utils import get_models, get_DDPM
8 | import logging
9 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
10 | from data.coco_w_stuff import get_coco_id_mapping
11 | import numpy as np
12 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights
13 |
14 | coco_id_to_name = get_coco_id_mapping()
15 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()}
16 |
17 | args = parse_test_args()
18 | ddpm_model = load_test_models(args)
19 | load_model_weights(ddpm_model=ddpm_model, args=args)
20 |
21 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
22 | ddpm_model = ddpm_model.to(device)
23 | ddpm_model.text_fn = ddpm_model.text_fn.to(device)
24 | ddpm_model.text_fn.device = device
25 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device)
26 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device)
27 |
28 | yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
29 |
30 | def obtain_bbox_from_yolo(image):
31 | H, W = image.shape[:2]
32 | results = yolo_model(image)
33 | # convert results to [x, y, w, h, object_name]
34 | xyxy_conf_cls = results.xyxy[0].detach().cpu().numpy()
35 | bboxes = []
36 | for x1, y1, x2, y2, conf, cls_idx in xyxy_conf_cls:
37 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
38 | cls_name = yolo_model.names[int(cls_idx)]
39 | if conf >= 0.5:
40 | bboxes.append([x1 / W, y1 / H, (x2 - x1) / W, (y2 - y1) / H, cls_name])
41 | return bboxes
42 |
43 | def save_bboxes(bboxes, save_dir):
44 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
45 | file_name = str(hash(str(current_time)))[1:10]
46 | os.makedirs(save_dir, exist_ok=True)
47 | save_path = os.path.join(save_dir, f'{file_name}.txt')
48 | with open(save_path, 'w') as OUT:
49 | for bbox in bboxes:
50 | OUT.write(','.join([str(x) for x in bbox]))
51 | OUT.write('\n')
52 | return save_path
53 |
54 | def sample_images(ref_image):
55 | bboxes = obtain_bbox_from_yolo(ref_image)
56 | bbox_path = save_bboxes(bboxes, 'tmp')
57 | image, image_with_bbox, canvas_with_bbox = sample_one_image(
58 | bbox_path,
59 | ddpm_model,
60 | device,
61 | coco_name_to_id, coco_id_to_name,
62 | api_key=args['openai_api_key'],
63 | image_size=ref_image.shape[:2],
64 | additional_caption=args['additional_caption']
65 | )
66 | os.remove(bbox_path)
67 | if image is None:
68 | # Return a placeholder image and a message
69 | placeholder = np.zeros((ref_image.shape[0], ref_image.shape[1], 3), dtype=np.uint8)
70 | message = "No object found in the image"
71 | return message, placeholder, placeholder, placeholder
72 | else:
73 | return "Success", image, image_with_bbox, canvas_with_bbox
74 |
75 | # Define the Gradio interface with a message component
76 | input_image = gr.inputs.Image()
77 | output_images = [gr.outputs.Image(type='numpy') for i in range(3)]
78 | message = gr.outputs.Textbox(label="Information", type="text")
79 | interface = gr.Interface(
80 | fn=sample_images,
81 | inputs=input_image,
82 | outputs=[message] + output_images,
83 | capture_session=True,
84 | title="LayoutDiffuse",
85 | description="Drop a reference image to generate a new image with the same layout",
86 | allow_flagging=False,
87 | live=False
88 | )
89 |
90 | interface.launch(share=True)
91 |
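
For reference, save_bboxes writes one detected box per line as comma-separated values "x,y,w,h,class_name", with coordinates normalized by the image width/height; an illustrative (made-up) line:

    0.12,0.30,0.25,0.40,dog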
--------------------------------------------------------------------------------
/run_gradio_merge.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import gradio as gr
3 | import os
4 | import torch
5 | import logging
6 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
7 | from data.coco_w_stuff import get_coco_id_mapping
8 | import numpy as np
9 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights
10 |
11 | coco_id_to_name = get_coco_id_mapping()
12 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()}
13 |
14 | args = parse_test_args()
15 | ddpm_model = load_test_models(args)
16 | load_model_weights(ddpm_model=ddpm_model, args=args)
17 |
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | ddpm_model = ddpm_model.to(device)
20 | ddpm_model.text_fn = ddpm_model.text_fn.to(device)
21 | ddpm_model.text_fn.device = device
22 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device)
23 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device)
24 |
25 | # ddpm_model.merge('pretrained_models/anything4_5/unet.ckpt', alpha=1.)
26 | # ddpm_model.merge('pretrained_models/counterfeitV25/unet.ckpt', alpha=1.)
27 |
28 | yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
29 |
30 | def obtain_bbox_from_yolo(image):
31 | H, W = image.shape[:2]
32 | results = yolo_model(image)
33 | # convert results to [x, y, w, h, object_name]
34 | xyxy_conf_cls = results.xyxy[0].detach().cpu().numpy()
35 | bboxes = []
36 | for x1, y1, x2, y2, conf, cls_idx in xyxy_conf_cls:
37 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
38 | cls_name = yolo_model.names[int(cls_idx)]
39 | if conf >= 0.5:
40 | bboxes.append([x1 / W, y1 / H, (x2 - x1) / W, (y2 - y1) / H, cls_name])
41 | return bboxes
42 |
43 | def save_bboxes(bboxes, save_dir):
44 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
45 | file_name = str(hash(str(current_time)))[1:10]
46 | os.makedirs(save_dir, exist_ok=True)
47 | save_path = os.path.join(save_dir, f'{file_name}.txt')
48 | with open(save_path, 'w') as OUT:
49 | for bbox in bboxes:
50 | OUT.write(','.join([str(x) for x in bbox]))
51 | OUT.write('\n')
52 | return save_path
53 |
54 | def sample_images(ref_image, user_input):
55 | bboxes = obtain_bbox_from_yolo(ref_image)
56 | bbox_path = save_bboxes(bboxes, 'tmp')
57 | image, image_with_bbox, canvas_with_bbox = sample_one_image(
58 | bbox_path,
59 | ddpm_model,
60 | device,
61 | coco_name_to_id, coco_id_to_name,
62 | api_key=args['openai_api_key'],
63 | image_size=ref_image.shape[:2],
64 | additional_caption=args['additional_caption'] + user_input
65 | )
66 | os.remove(bbox_path)
67 | if image is None:
68 | # Return a placeholder image and a message
69 | placeholder = np.zeros((ref_image.shape[0], ref_image.shape[1], 3), dtype=np.uint8)
70 | message = "No object found in the image"
71 | return message, placeholder, placeholder, placeholder
72 | else:
73 | return "Success", image, image_with_bbox, canvas_with_bbox
74 |
75 | # Define the Gradio interface with a message component
76 | input_image = gr.inputs.Image()
77 | input_text = gr.inputs.Textbox(type='text', label='Additional caption')
78 | output_images = [gr.outputs.Image(type='numpy') for i in range(3)]
79 | message = gr.outputs.Textbox(label="Information", type="text")
80 | interface = gr.Interface(
81 | fn=sample_images,
82 | inputs=[input_image, input_text],
83 | outputs=[message] + output_images,
84 | capture_session=True,
85 | title="LayoutDiffuse",
86 | description="Drop a reference image to generate a new image with the same layout",
87 | allow_flagging=False,
88 | live=False
89 | )
90 |
91 | interface.launch(share=True)
92 |
--------------------------------------------------------------------------------
/sampling.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import json
5 | from pytorch_lightning import Trainer
6 | from train_utils import get_models, get_DDPM
7 | from test_utils import load_model_weights
8 |
9 | if __name__ == '__main__':
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | '-c', '--config', type=str,
13 | default='config/train.json')
14 | parser.add_argument(
15 | '-n', '--num_repeat', type=int,
16 | default=1, help='the number of images for each condition')
17 | parser.add_argument(
18 | '-e', '--epoch', type=int,
19 | default=None, help='which epoch to evaluate, if None, will use the latest')
20 | parser.add_argument(
21 | '--nnode', type=int, default=1
22 | )
23 | parser.add_argument(
24 | '--model_path', type=str,
25 |         default=None, help='path to the LayoutDiffuse checkpoint; if not provided, latest.ckpt in the experiment folder is used')
26 |
27 | ''' parser configs '''
28 | args_raw = parser.parse_args()
29 | with open(args_raw.config, 'r') as IN:
30 | args = json.load(IN)
31 | args.update(vars(args_raw))
32 | # args['gpu_ids'] = [0] # DEBUG
33 | expt_name = args['expt_name']
34 | expt_dir = args['expt_dir']
35 | expt_path = os.path.join(expt_dir, expt_name)
36 | os.makedirs(expt_path, exist_ok=True)
37 |
38 | '''1. create denoising model'''
39 | denoise_args = args['denoising_model']['model_args']
40 | models = get_models(args)
41 |
42 | diffusion_configs = args['diffusion']
43 | ddpm_model = get_DDPM(
44 | diffusion_configs=diffusion_configs,
45 | log_args=args,
46 | **models
47 | )
48 |
49 |     '''2. create a dataloader that provides the sampling inputs (random noise or the validation set)'''
50 | from test_utils import get_test_dataset, get_test_callbacks
51 | test_dataset, test_loader = get_test_dataset(args)
52 |
53 | '''3. callbacks'''
54 | callbacks = get_test_callbacks(args, expt_path)
55 |
56 | '''4. load checkpoint'''
57 | print('INFO: loading checkpoint')
58 | if args['model_path'] is not None:
59 | ckpt_path = args['model_path']
60 | else:
61 | expt_path = os.path.join(args['expt_dir'], args['expt_name'])
62 | if args['epoch'] is None:
63 | ckpt_to_use = 'latest.ckpt'
64 | else:
65 | ckpt_to_use = f'epoch={args["epoch"]:04d}.ckpt'
66 | ckpt_path = os.path.join(expt_path, ckpt_to_use)
67 | print(ckpt_path)
68 | if os.path.exists(ckpt_path):
69 | print(f'INFO: Found checkpoint {ckpt_path}')
70 | # ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
71 | ''' DEBUG '''
72 | # ckpt_denoise_fn = {k.replace('denoise_fn.', ''): v for k, v in ckpt.items() if 'denoise_fn' in k}
73 | # ddpm_model.denoise_fn.load_state_dict(ckpt_denoise_fn)
74 | # ddpm_model.load_state_dict(ckpt)
75 | else:
76 | ckpt_path = None
77 |         raise RuntimeError('Cannot do inference without a pretrained checkpoint')
78 |
79 |     '''5. trainer'''
80 | trainer_args = {
81 | "max_epochs": 1000,
82 | "accelerator": "gpu",
83 | "devices": [0],
84 | "limit_val_batches": 1,
85 | "strategy": "ddp",
86 | "check_val_every_n_epoch": 1,
87 | "num_nodes": args['nnode']
88 | # "benchmark" :True
89 | }
90 | config_trainer_args = args['trainer_args'] if args.get('trainer_args') is not None else {}
91 | trainer_args.update(config_trainer_args)
92 |     print(f'Trainer args are {trainer_args}')
93 | trainer = Trainer(
94 | callbacks = callbacks,
95 | **trainer_args
96 | )
97 |
98 | '''6. start sampling'''
99 |     '''use the trainer for sampling; an image-saver callback is needed to save the results, which is useful when generating many images'''
100 | num_loop = args['num_repeat']
101 | for _ in range(num_loop):
102 | # trainer.test(ddpm_model, test_loader) # DEBUG
103 | trainer.test(ddpm_model, test_loader, ckpt_path=ckpt_path)
104 |
--------------------------------------------------------------------------------
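Note that /sampling.py (and parse_test_args in /test_utils.py) loads the JSON config first and then calls args.update(vars(args_raw)), so every key produced by argparse, including defaults, takes precedence over the same key in the config. A minimal illustration of that precedence, with made-up keys:

# Minimal illustration of the config/CLI merge used above (keys are made up).
config = {'expt_name': 'demo', 'epoch': 10}      # as if read with json.load
cli = {'epoch': 59, 'model_path': None}          # as if produced by vars(parser.parse_args())

args = dict(config)
args.update(cli)                                 # CLI keys win, including argparse defaults

assert args == {'expt_name': 'demo', 'epoch': 59, 'model_path': None}
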
/sampling_in_background.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data.coco_w_stuff import get_coco_id_mapping
4 | import numpy as np
5 | import cv2
6 | import time
7 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights
8 | coco_id_to_name = get_coco_id_mapping()
9 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()}
10 |
11 | if __name__ == '__main__':
12 | args = parse_test_args()
13 | ddpm_model = load_test_models(args)
14 | load_model_weights(ddpm_model=ddpm_model, args=args)
15 |
16 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
17 | ddpm_model = ddpm_model.to(device)
18 | ddpm_model.text_fn = ddpm_model.text_fn.to(device)
19 | ddpm_model.text_fn.device = device
20 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device)
21 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device)
22 |
23 | while True:
24 |         # poll the folder for layout .txt files; for each one, sample an image, save it as a .jpg next to the .txt, then remove the .txt
25 |
26 | from glob import glob
27 | files_to_sample = glob('interactive_plotting/tmp/*.txt')
28 | for f in files_to_sample:
29 | print('INFO: processing file', f)
30 | image, image_with_bbox, canvas_with_bbox = sample_one_image(
31 | f, ddpm_model, device,
32 | class_name_to_id=coco_name_to_id,
33 | class_id_to_name=coco_id_to_name,
34 | api_key=args['openai_api_key'],
35 | additional_caption=args['additional_caption']
36 | )
37 | # save the image
38 | cat_image = np.concatenate([image, image_with_bbox, canvas_with_bbox], axis=1)
39 | cv2.imwrite(f.replace('.txt', '.jpg'), (cat_image[..., ::-1] * 255).astype(np.uint8))
40 | # remove the file
41 | os.remove(f)
42 |
43 | time.sleep(1)
--------------------------------------------------------------------------------
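The loop above watches interactive_plotting/tmp/ and replaces each layout .txt with a sampled .jpg of the same name. A hypothetical client can therefore drop a file and poll for the image; this sketch assumes the background sampler is running and uses an illustrative file name and layout:

import os
import time

# Drop a layout file into the watched folder, then wait for the sampled image.
layout_path = 'interactive_plotting/tmp/example_layout.txt'
os.makedirs(os.path.dirname(layout_path), exist_ok=True)
with open(layout_path, 'w') as OUT:
    OUT.write('0.1,0.2,0.3,0.4,person\n')

result_path = layout_path.replace('.txt', '.jpg')
while not os.path.exists(result_path):
    time.sleep(1)
print('sampled image saved to', result_path)
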
/scripts/convert_jpg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import cv2
4 | from tqdm import tqdm
5 |
6 | def read_convert_and_save(img_path, save_path):
7 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
8 | image = cv2.imread(img_path)
9 | cv2.imwrite(save_path, image)
10 |
11 | if __name__ == '__main__':
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--indir', type=str)
14 |
15 | ''' parser configs '''
16 | args = parser.parse_args()
17 |
18 | in_dir = args.indir
19 | out_dir = os.path.join(
20 | os.path.dirname(in_dir),
21 |         os.path.basename(in_dir) + '-jpg'
22 | )
23 |
24 | image_names = os.listdir(in_dir)
25 |
26 | for image_name in tqdm(image_names, desc='convert image to jpg'):
27 | if not (image_name.endswith('.jpg') or image_name.endswith('.png')):
28 | continue
29 | img_path = os.path.join(in_dir, image_name)
30 | save_img_path = os.path.join(out_dir, image_name.replace('.png', '.jpg'))
31 | if os.path.exists(save_img_path):
32 | continue
33 | read_convert_and_save(img_path, save_img_path)
--------------------------------------------------------------------------------
/scripts/convert_npz_to_npy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import glob
5 | from tqdm import tqdm
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('-s', '--src', type=str, default='', help='source images directory')
9 | args = parser.parse_args()
10 |
11 | indir = args.src
12 | outdir = indir+'_npy'
13 | os.makedirs(outdir, exist_ok=True)
14 |
15 | npz_files = glob.glob(indir + '/*.npz')
16 | print(len(npz_files))
17 | for npz_file in tqdm(npz_files):
18 | out_path = npz_file.replace(indir, outdir)
19 | out_path = out_path.replace('npz', 'npy')
20 | image = np.load(npz_file)['image']
21 |
22 | with open(out_path, 'wb') as OUT:
23 | np.save(OUT, image*255)
--------------------------------------------------------------------------------
/scripts/download_celebMask.sh:
--------------------------------------------------------------------------------
1 | DIR=~/disk2/data/CelebAMask-HQ
2 | mkdir -p $DIR
3 |
4 | cd $DIR
5 | gdown https://drive.google.com/uc?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv
--------------------------------------------------------------------------------
/scripts/download_coco.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ~/disk2/data/COCO
2 | cd ~/disk2/data/COCO
3 | wget http://images.cocodataset.org/zips/train2017.zip
4 | wget http://images.cocodataset.org/zips/val2017.zip
5 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
6 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
7 | ls *.zip | while read f; do
8 | unzip $f;
9 | done
10 |
--------------------------------------------------------------------------------
/scripts/download_pretrained_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | download_face() {
4 | mkdir -p pretrained_models/celeba256
5 | wget -O pretrained_models/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
6 | cd pretrained_models/celeba256
7 | unzip -o celeba-256.zip
8 | python split_model.py
9 | }
10 |
11 | download_ldm() {
12 | mkdir -p pretrained_models/LAION_text2img
13 | wget -O pretrained_models/LAION_text2img/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
14 | cd pretrained_models/LAION_text2img
15 | python split_model.py
16 | }
17 |
18 | download_sd1_5() {
19 | mkdir -p pretrained_models/SD1_5
20 | wget -O pretrained_models/SD1_5/model.ckpt https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
21 | cd pretrained_models/SD1_5
22 | python split_model.py
23 | }
24 |
25 | download_sd2_1() {
26 | mkdir -p pretrained_models/SD2_1
27 | wget -O pretrained_models/SD2_1/model.ckpt https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-nonema-pruned.ckpt
28 | cd pretrained_models/SD2_1
29 | python split_model.py
30 | }
31 |
32 | download_all() {
33 | download_face
34 | cd ../..
35 | download_ldm
36 | cd ../..
37 | download_sd1_5
38 | cd ../..
39 | download_sd2_1
40 | }
41 |
42 | case $1 in
43 | "face")
44 | download_face
45 | ;;
46 | "ldm")
47 | download_ldm
48 | ;;
49 | "SD1_5")
50 | download_sd1_5
51 | ;;
52 | "SD2_1")
53 | download_sd2_1
54 | ;;
55 | "all")
56 | download_all
57 | ;;
58 | *)
59 | echo "Invalid argument. Usage: bash download.sh [face|ldm|SD1_5|SD2_1|all]"
60 | ;;
61 | esac
62 |
--------------------------------------------------------------------------------
/scripts/download_vg.sh:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | VG_DIR=~/disk2/data/VG
16 | mkdir -p $VG_DIR
17 |
18 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip
19 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip
20 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip
21 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt
22 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt
23 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip
24 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip
25 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip
26 |
27 | unzip $VG_DIR/objects.json.zip -d $VG_DIR
28 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR
29 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR
30 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR
31 | unzip $VG_DIR/images.zip -d $VG_DIR/images
32 | unzip $VG_DIR/images2.zip -d $VG_DIR/images
33 |
34 | python scripts/preprocess_vg.py
--------------------------------------------------------------------------------
/scripts/eval_scripts/celeb_mask.sh:
--------------------------------------------------------------------------------
1 | # python fid_eval.py \
2 | # --dataset celeb_mask \
3 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img-train-256x256 \
4 | # -d experiments/celeb_mask_ldm_partial_attn/sampling_at_00279_image
5 |
6 | run_once () {
7 | res=$1
8 | epoch=$2
9 | python fid_eval.py \
10 | -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \
11 | --resize_s \
12 | -d experiments/celeb_mask_ldm_${res}_samples/epoch_${epoch}/image > tmp/data_efficiency_res_${res}_epoch_${epoch}.txt
13 | }
14 |
15 | # CUDA_VISIBLE_DEVICES=1 run_once 128 "00099" && run_once 128 "00199" && run_once 128 "00499" && run_once 128 "00999" && run_once 128 "01999" &
16 |
17 | # CUDA_VISIBLE_DEVICES=2 run_once 256 "00049" && run_once 256 "00099" && run_once 256 "00249" && run_once 256 "00499" && run_once 256 "00999" &
18 |
19 | # CUDA_VISIBLE_DEVICES=3 run_once 512 "00024" && run_once 512 "00049" && run_once 512 "00124" && run_once 512 "00249" && run_once 512 "00499" &
20 |
21 | CUDA_VISIBLE_DEVICES=0 run_once 1024 "00012" && run_once 1024 "00024" && run_once 1024 "00062" && run_once 1024 "00124" && run_once 1024 "00249" &
22 |
23 | CUDA_VISIBLE_DEVICES=1 run_once 2048 "00006" && run_once 2048 "00012" && run_once 2048 "00031" && run_once 2048 "00062" && run_once 2048 "00124" &
24 |
25 | # seq 4 5 10| while read e; do
26 | # python fid_eval.py \
27 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \
28 | # --resize_s \
29 | # -d experiments/celeb_mask_ldm_v2/epoch_0000$e/image > tmp/celeb_v2_fid_e_$e.txt
30 | # done
31 |
32 | # seq 14 5 30 | while read e; do
33 | # python fid_eval.py \
34 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \
35 | # --resize_s \
36 | # -d experiments/celeb_mask_ldm_v2/epoch_000$e/image > tmp/celeb_v2_fid_e_$e.txt
37 | # done
--------------------------------------------------------------------------------
/scripts/eval_scripts/convert_npz_to_npy.sh:
--------------------------------------------------------------------------------
1 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00009_plms_100_5.0/raw_tensor
2 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00029_plms_100_5.0/raw_tensor
3 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00059_plms_100_5.0/raw_tensor
4 |
5 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00059_plms_200_5.0/raw_tensor
6 | seq 4 5 10 | while read e; do
7 | python scripts/convert_npz_to_npy.py -s experiments/celeb_mask_ldm_v2/epoch_0000$e/raw_tensor
8 | done
9 |
10 | seq 14 5 30 | while read e; do
11 | python scripts/convert_npz_to_npy.py -s experiments/celeb_mask_ldm_v2/epoch_000$e/raw_tensor
12 | done
13 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00009/raw_tensor
14 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00029/raw_tensor
15 |
16 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00009_plms_100_5.0/raw_tensor
17 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00029_plms_100_5.0/raw_tensor
18 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00059_plms_100_5.0/raw_tensor
19 |
--------------------------------------------------------------------------------
/scripts/eval_scripts/fid_coco_layout_ablation.sh:
--------------------------------------------------------------------------------
1 | run_once () {
2 | expt=$1
3 | appendix=$3
4 | echo "Score for ${expt}, epoch $2"
5 | python fid_eval.py \
6 | -s /home/ubuntu/disk2/data/COCO/train2017 \
7 | --resize_s \
8 | -d experiments/${expt}/epoch_000$2$appendix/raw_tensor_npy
9 | # -d experiments/${expt}/epoch_000$2$appendix/sample_image
10 | }
11 |
12 | expt="laion_ldm_cocostuff_layout_no_caption"
13 | appendix="_plms_100_5.0"
14 | epoch="09"
15 | CUDA_VISIBLE_DEVICES=5 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
16 |
17 | expt="laion_ldm_cocostuff_layout_no_caption"
18 | appendix="_plms_100_5.0"
19 | epoch="29"
20 | CUDA_VISIBLE_DEVICES=6 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
21 |
22 | expt="laion_ldm_cocostuff_layout_no_caption"
23 | appendix="_plms_100_5.0"
24 | epoch="59"
25 | CUDA_VISIBLE_DEVICES=7 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
26 |
27 | # expt="laion_ldm_cocostuff_layout_caption_v9"
28 | # appendix="_plms_200_5.0"
29 | # epoch="59"
30 | # CUDA_VISIBLE_DEVICES=6 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
31 |
32 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn"
33 | # appendix="_plms_100_5.0"
34 | # epoch="09"
35 | # CUDA_VISIBLE_DEVICES=6 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
36 |
37 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn"
38 | # appendix="_plms_100_5.0"
39 | # epoch="29"
40 | # CUDA_VISIBLE_DEVICES=5 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
41 |
42 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn"
43 | # appendix="_plms_100_5.0"
44 | # epoch="59"
45 | # CUDA_VISIBLE_DEVICES=7 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 &
--------------------------------------------------------------------------------
/scripts/remove_empty_file_in_vg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from glob import glob
4 |
5 | VG_DIR = '/home/ubuntu/disk2/data/VG/images'
6 | # VG_DIR = 'experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00059_plms_200_5.0/sampled_256_cropped_224'
7 | image_paths = glob(VG_DIR+'/**/*.jpg') + glob(VG_DIR+'/**/*.png')
8 |
9 | for path in image_paths:
10 | try:
11 | Image.open(path)
12 |     except Exception:
13 |         print(f'{path} failed to open, removing it')
14 |         os.remove(path)
--------------------------------------------------------------------------------
/scripts/resize_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import glob
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from PIL import ImageFile
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 |
10 | def process_image(img_path, save_path, size, mode):
11 | print('save image to ', save_path)
12 | img = Image.open(img_path)
13 | img = img.resize((size, size), mode)
14 |     img.save(save_path)
15 |
16 | def read_resize_and_save(img_path, save_path, size, mode=Image.BICUBIC):
17 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
18 | if img_path.endswith('.png') or img_path.endswith('.jpg'):
19 | process_image(img_path, save_path, size, mode)
20 |
21 | if __name__ == '__main__':
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--indir', type=str)
24 | parser.add_argument('--size', type=int)
25 |
26 | ''' parser configs '''
27 | args = parser.parse_args()
28 | size = args.size
29 |
30 | in_dir = args.indir
31 | out_dir = os.path.join(
32 | os.path.dirname(in_dir),
33 | os.path.basename(in_dir) + f'-{size}'
34 | )
35 |
36 | image_names = glob.glob(in_dir + '/*.jpg') + glob.glob(in_dir + '/*.png') + glob.glob(in_dir + '/**/*.jpg') + glob.glob(in_dir + '/**/*.png')
37 |
38 | for image_name in tqdm(image_names):
39 | save_img_path = image_name.replace(in_dir, out_dir)
40 | if image_name.endswith('.jpg'):
41 | save_img_path = save_img_path.replace('.jpg', '.png')
42 | if os.path.exists(save_img_path):
43 | continue
44 | try:
45 | read_resize_and_save(image_name, save_img_path, size, mode=Image.BICUBIC)
46 |         except Exception:
47 | print(image_name, 'is broken')
--------------------------------------------------------------------------------
/scripts/sampling_scripts/dist_sampling.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | export MKL_NUM_THREADS=1
3 | export NNODE=3
4 | torchrun \
5 | --nnodes=$NNODE \
6 | --nproc_per_node 8 \
7 | --rdzv_id v9_dist_sample \
8 | --rdzv_backend c10d \
9 | --rdzv_endpoint $1:29500 \
10 | sampling.py -c $2 --nnode $NNODE -e $3 -n $4
11 |
12 | # usage: bash scripts/sampling_scripts/dist_sampling.sh \
13 | #   172.31.0.139 configs/laion_cocostuff_text_v9.json 59 5
14 | # arguments: master node (machine 1) IP address, config file, epoch, number of images per condition
15 |
--------------------------------------------------------------------------------
/scripts/train_scripts/dist_train.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | export MKL_NUM_THREADS=1
3 | export NNODE=4
4 | torchrun \
5 | --nnodes=$NNODE \
6 | --nproc_per_node 4 \
7 | --rdzv_id v9_dist \
8 | --rdzv_backend c10d \
9 | --rdzv_endpoint $1:29500 \
10 | main.py -c $2 -n $NNODE -r
11 |
12 | # usage: bash scripts/train_scripts/dist_train.sh 172.31.42.68 configs/<config>.json  # first argument is the master node IP address
--------------------------------------------------------------------------------
/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from data.random_sampling import RandomNoise
6 | from model_utils import default, get_obj_from_str
7 | from callbacks.coco_layout.sampling_save_fig import ColorMapping, plot_bbox_without_overlap, plot_bounding_box
8 | import cv2
9 |
10 | def get_test_dataset(args):
11 | sampling_args = args['sampling_args']
12 | sampling_w_noise = default(sampling_args.get('sampling_w_noise'), False)
13 | if sampling_w_noise:
14 | test_dataset = RandomNoise(
15 | sampling_args['image_size'],
16 | sampling_args['image_size'],
17 | sampling_args['in_channel'],
18 | sampling_args['num_samples']
19 | )
20 | else:
21 | from data import get_dataset
22 | args['data']['val_args']['data_len'] = sampling_args['num_samples']
23 | _, test_dataset = get_dataset(**args['data'])
24 | test_loader = DataLoader(test_dataset, batch_size=args['data']['batch_size'], num_workers=4, shuffle=False)
25 | return test_dataset, test_loader
26 |
27 | def get_test_callbacks(args, expt_path):
28 | sampling_args = args['sampling_args']
29 | callbacks = []
30 | callbacks_obj = sampling_args.get('callbacks')
31 | for target in callbacks_obj:
32 | callbacks.append(
33 | get_obj_from_str(target)(expt_path)
34 | )
35 | return callbacks
36 |
37 | def postprocess_image(batched_x, batched_bbox, class_id_to_name, image_callback=lambda x: x):
38 | x = batched_x[0]
39 | bbox = batched_bbox[0]
40 | x = x.permute(1, 2, 0).detach().cpu().numpy().clip(-1, 1)
41 | x = (x + 1) / 2
42 | x = image_callback(x)
43 | image_with_bbox = overlap_image_with_bbox(x, bbox, class_id_to_name)
44 | canvas_with_bbox = overlap_image_with_bbox(np.ones_like(x), bbox, class_id_to_name)
45 | return x, image_with_bbox, canvas_with_bbox
46 |
47 | def overlap_image_with_bbox(image, bbox, class_id_to_name):
48 | label_color_mapper = ColorMapping(id_class_mapping=class_id_to_name)
49 | image_with_bbox = plot_bbox_without_overlap(
50 | image.copy(),
51 | bbox,
52 | label_color_mapper
53 | ) if len(bbox) <= 10 else None
54 | if image_with_bbox is not None:
55 | return image_with_bbox
56 | return plot_bounding_box(
57 | image.copy(),
58 | bbox,
59 | label_color_mapper
60 | )
61 |
62 | def generate_completion(caption, api_key, additional_caption=''):
63 | import openai
64 | # check if api_key is valid
65 | def validate_api_key(api_key):
66 | import re
67 | regex = "^sk-[a-zA-Z0-9]{48}$" # regex pattern for OpenAI API key
68 | if not isinstance(api_key, str):
69 | return None
70 | if not re.match(regex, api_key):
71 | return None
72 | return api_key
73 | openai.api_key = validate_api_key(api_key)
74 | if openai.api_key is None:
75 | print('WARNING: invalid OpenAI API key, using default caption')
76 | return caption
77 |     prompt = 'Describe a scene with the following words: ' + caption + '. Use the above words to generate a prompt for drawing with a diffusion model. Use at least 30 words and at most 80 words and include all given words. The final image should look nice and be related to the given words'
78 |
79 | response = openai.ChatCompletion.create(
80 | model="gpt-3.5-turbo",
81 | messages=[{
82 | "role": "user",
83 | "content": prompt
84 | }]
85 | )
86 |
87 | return response.choices[0].message.content.strip() + additional_caption
88 |
89 | def concatenate_class_labels_to_caption(objects, class_id_to_name, api_key=None, additional_caption=''):
90 |     # to add an extra description (e.g. a style), pass it via additional_caption
91 | caption = ''
92 | for i in objects:
93 | caption += class_id_to_name[i[4]+1] + ', '
94 | caption = caption.rstrip(', ')
95 | if api_key is not None:
96 | caption = generate_completion(caption, api_key=api_key, additional_caption=additional_caption)
97 | print('INFO: using openai text completion and the generated caption is: \n', caption)
98 | else:
99 | caption = caption + additional_caption
100 | print('INFO: using default caption: \n', caption)
101 | return caption
102 |
103 | def sample_one_image(bbox_path, ddpm_model, device, class_name_to_id, class_id_to_name, api_key=None, image_size=(512, 512), additional_caption=''):
104 |     # each line of the text file is: x, y, w, h, class_name
105 | with open(bbox_path, 'r') as IN:
106 | raw_objects = [i.strip().split(',') for i in IN]
107 | objects = []
108 | for i in raw_objects:
109 | i[0] = float(i[0])
110 | i[1] = float(i[1])
111 | i[2] = float(i[2])
112 | i[3] = float(i[3])
113 | class_name = i[4].strip()
114 | if class_name in class_name_to_id:
115 |             # skip objects whose class name is not in COCO (detections with no matching COCO id)
116 | i[4] = int(class_name_to_id[class_name]) - 1
117 | objects.append(i)
118 | if len(objects) == 0:
119 | return None, None, None
120 | batch = []
121 | image_resizer = ImageResizer()
122 | new_h, new_w = image_resizer.get_proper_size(image_size)
123 | batch.append(torch.randn(1, 3, new_h, new_w).to(device))
124 | batch.append(torch.from_numpy(np.array(objects)).to(device).unsqueeze(0))
125 | batch.append((
126 | concatenate_class_labels_to_caption(objects, class_id_to_name, api_key, additional_caption),
127 | ))
128 |     res = ddpm_model.test_step(batch, 0) # we pass a full batch, but only the text and layout are used when sampling
129 | sampled_images = res['sampling']['model_output']
130 | return postprocess_image(sampled_images, batch[1], class_id_to_name, image_callback=lambda x: image_resizer.to_original_size(x))
131 |
132 |
133 | class ImageResizer:
134 | def __init__(self):
135 | self.original_size = None
136 |
137 | def to_proper_size(self, img):
138 | # Get the new height and width that can be divided by 64
139 | new_h, new_w = self.get_proper_size(img.shape[:2])
140 |
141 | # Resize the image using OpenCV's resize function
142 | resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
143 |
144 | return resized
145 |
146 | def to_original_size(self, img):
147 | # Resize the image to original size using OpenCV's resize function
148 | resized = cv2.resize(img, (self.original_size[1], self.original_size[0]), interpolation=cv2.INTER_AREA)
149 |
150 | return resized
151 |
152 | def get_proper_size(self, size):
153 | self.original_size = size
154 | # Calculate the new height and width that can be divided by 64
155 | if size[0] % 64 == 0:
156 | new_h = size[0]
157 | else:
158 | new_h = size[0] + (64 - size[0] % 64)
159 |
160 | if size[1] % 64 == 0:
161 | new_w = size[1]
162 | else:
163 | new_w = size[1] + (64 - size[1] % 64)
164 |
165 | return new_h, new_w
166 |
167 | def parse_test_args():
168 | import argparse
169 | import json
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument(
172 | '-c', '--config', type=str,
173 | default='config/train.json')
174 | parser.add_argument(
175 | '-e', '--epoch', type=int,
176 | default=None, help='which epoch to evaluate, if None, will use the latest')
177 | parser.add_argument(
178 | '--openai_api_key', type=str,
179 | default=None, help='openai api key for generating text prompt')
180 | parser.add_argument(
181 | '--model_path', type=str,
182 |         default=None, help='path to the LayoutDiffuse checkpoint; if not provided, latest.ckpt in the experiment folder is used')
183 | parser.add_argument(
184 | '--additional_caption', type=str,
185 | default='', help='additional caption for the generated image')
186 |
187 | ''' parser configs '''
188 | args_raw = parser.parse_args()
189 | with open(args_raw.config, 'r') as IN:
190 | args = json.load(IN)
191 | args.update(vars(args_raw))
192 | return args
193 |
194 | def load_test_models(args):
195 | from train_utils import get_models, get_DDPM
196 | models = get_models(args)
197 |
198 | diffusion_configs = args['diffusion']
199 | ddpm_model = get_DDPM(
200 | diffusion_configs=diffusion_configs,
201 | log_args=args,
202 | **models
203 | )
204 | return ddpm_model
205 |
206 | def load_model_weights(ddpm_model, args):
207 | print('INFO: loading checkpoint')
208 | if args['model_path'] is not None:
209 | ckpt_path = args['model_path']
210 | else:
211 | expt_path = os.path.join(args['expt_dir'], args['expt_name'])
212 | if args['epoch'] is None:
213 | ckpt_to_use = 'latest.ckpt'
214 | else:
215 | ckpt_to_use = f'epoch={args["epoch"]:04d}.ckpt'
216 | ckpt_path = os.path.join(expt_path, ckpt_to_use)
217 | print(ckpt_path)
218 | if os.path.exists(ckpt_path):
219 | print(f'INFO: Found checkpoint {ckpt_path}')
220 | ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
221 | ddpm_model.load_state_dict(ckpt)
222 | else:
223 | ckpt_path = None
224 |         raise RuntimeError('Cannot do inference without a pretrained checkpoint')
--------------------------------------------------------------------------------
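ImageResizer.get_proper_size above rounds each side up to the next multiple of 64, presumably so that the VQ-VAE latent and the U-Net downsampling divide evenly (an assumption about the model strides, not stated in the code), and to_original_size resizes the result back. A small worked example of the rounding:

# Worked example of ImageResizer's rounding: each side goes up to the next
# multiple of 64, so a 480x640 request is sampled at 512x640 and resized back.
def round_up_to_64(n):
    return n if n % 64 == 0 else n + (64 - n % 64)

assert round_up_to_64(480) == 512   # 480 % 64 == 32, so 32 is added
assert round_up_to_64(640) == 640   # already a multiple of 64
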
/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | # from modules.vae.vae import BetaVAE
3 | from pytorch_lightning.loggers import WandbLogger
4 | from callbacks import get_epoch_checkpoint, get_latest_checkpoint, get_iteration_checkpoint
5 | from model_utils import instantiate_from_config, get_obj_from_str
6 |
7 | def get_models(args):
8 | denoise_model = args['denoising_model']['model']
9 | denoise_args = args['denoising_model']['model_args']
10 | denoise_fn = instantiate_from_config({
11 | 'target': denoise_model,
12 | 'params': denoise_args
13 | })
14 | model_dict = {
15 | 'denoise_fn': denoise_fn,
16 | }
17 |
18 | if args.get('vqvae_model'):
19 | vq_model = args['vqvae_model']['model']
20 | vq_args = args['vqvae_model']['model_args']
21 | vqvae_fn = instantiate_from_config({
22 | 'target': vq_model,
23 | 'params': vq_args
24 | })
25 |
26 | model_dict['vqvae_fn'] = vqvae_fn
27 |
28 | if args.get('text_model'):
29 | text_model = args['text_model']['model']
30 | text_args = args['text_model']['model_args']
31 | text_fn = instantiate_from_config({
32 | 'target': text_model,
33 | 'params': text_args
34 | })
35 |
36 | model_dict['text_fn'] = text_fn
37 |
38 | return model_dict
39 |
40 | def get_DDPM(diffusion_configs, log_args={}, **models):
41 | diffusion_model_class = diffusion_configs['model']
42 | diffusion_args = diffusion_configs['model_args']
43 | DDPM_model = get_obj_from_str(diffusion_model_class)
44 | ddpm_model = DDPM_model(
45 | log_args=log_args,
46 | **models,
47 | **diffusion_args
48 | )
49 | return ddpm_model
50 |
51 |
52 | def get_logger_and_callbacks(expt_name, expt_path, args):
53 | callbacks = []
54 | # 3.1 checkpoint callbacks
55 | save_model_config = args.get('save_model_config', {})
56 | epoch_checkpoint = get_epoch_checkpoint(expt_path, **save_model_config)
57 | latest_checkpoint = get_latest_checkpoint(expt_path)
58 | callbacks.append(epoch_checkpoint)
59 | callbacks.append(latest_checkpoint)
60 |
61 | # 3.2 wandb logger
62 | wandb_logger = WandbLogger(
63 | project=expt_name,
64 | )
65 | iteration_callbacks = args.get('iteration_callbacks')
66 | if iteration_callbacks:
67 | callbacks.append(get_iteration_checkpoint(expt_path))
68 | config_callbacks = args.get('callbacks')
69 | if config_callbacks is not None:
70 | for callback in config_callbacks:
71 | print(f'Initiate callback {callback}')
72 | callbacks.append(
73 | get_obj_from_str(callback)(
74 | wandb_logger=wandb_logger,
75 | max_num_images=8
76 | )
77 | )
78 | else:
79 | from callbacks import WandBImageLogger
80 | print(f'INFO: got {expt_name}, will use default image logger')
81 | wandb_callback = WandBImageLogger(
82 | wandb_logger=wandb_logger,
83 | max_num_images=8
84 | )
85 | callbacks.append(wandb_callback)
86 |
87 | return wandb_logger, callbacks
88 |
89 | if os.path.exists('negative/EasyNegative.safetensors'):
90 | from safetensors import safe_open
91 | with safe_open('negative/EasyNegative.safetensors', framework="pt", device="cpu") as f:
92 | NEGATIVE_PROMPTS_EMBEDDINGS = f.get_tensor('emb_params')
93 | else:
94 | NEGATIVE_PROMPTS_EMBEDDINGS = None
95 | NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless"
96 |
97 | def obtain_state_dict_key_mapping(key_in_layout_diffuse):
98 | key_only_in_layout_diffuse = False
99 | if key_in_layout_diffuse == 'output_blocks.5.3.conv.weight':
100 | key_in_foundational_model = 'output_blocks.5.2.conv.weight'
101 | elif key_in_layout_diffuse == 'output_blocks.5.3.conv.bias':
102 | key_in_foundational_model = 'output_blocks.5.2.conv.bias'
103 | elif key_in_layout_diffuse == 'output_blocks.8.3.conv.weight':
104 | key_in_foundational_model = 'output_blocks.8.2.conv.weight'
105 | elif key_in_layout_diffuse == 'output_blocks.8.3.conv.bias':
106 | key_in_foundational_model = 'output_blocks.8.2.conv.bias'
107 | else:
108 | key_in_foundational_model = key_in_layout_diffuse
109 | key_only_in_layout_diffuse = True
110 | return key_in_foundational_model, key_only_in_layout_diffuse
--------------------------------------------------------------------------------
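get_models and get_DDPM above instantiate every component from a nested config dict via instantiate_from_config / get_obj_from_str. A hedged sketch of the shape that config is expected to have; the class paths and parameters below are placeholders, the real values live in the repo's config/*.json files:

# Placeholder sketch of the config consumed by get_models / get_DDPM.
example_args = {
    'denoising_model': {
        'model': 'modules.some_unet.UNet',        # placeholder target class path
        'model_args': {'in_channels': 3},         # placeholder params
    },
    'vqvae_model': {                              # optional; adds 'vqvae_fn'
        'model': 'modules.some_vae.VQVAE',
        'model_args': {},
    },
    'text_model': {                               # optional; adds 'text_fn'
        'model': 'modules.some_text.TextEncoder',
        'model_args': {},
    },
    'diffusion': {
        'model': 'modules.some_ddpm.DDPM',
        'model_args': {'timesteps': 1000},
    },
}
# With real class paths, the wiring mirrors sampling.py:
# models = get_models(example_args)
# ddpm_model = get_DDPM(example_args['diffusion'], log_args=example_args, **models)
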