├── .gitattributes ├── .gitignore ├── DDIM_ldm ├── DDIMSampler.py ├── DDIM_ldm.py ├── DDIM_ldm_celeb.py ├── DDIM_ldm_coco.py ├── PLMSSampler.py └── README.md ├── README.md ├── callbacks ├── README.md ├── __init__.py ├── celeb_mask │ ├── celeb_mask_loss_fn.py │ ├── celeb_mask_wandb.py │ └── sampling_save_fig.py ├── checkpoint.py ├── coco_layout │ ├── sampling_save_fig.py │ └── wandb.py ├── sampling_save_fig.py ├── schedule_sampler.py ├── utils.py └── wandb.py ├── configs ├── celeb_mask.json ├── cocostuff.json ├── cocostuff_SD1_5.json ├── cocostuff_SD1_5_merge_model.json ├── cocostuff_SD2_1.json ├── cocostuff_no_text.json └── vg.json ├── data ├── __init__.py ├── coco_w_stuff.py ├── face_parsing.py ├── instances_val2017.json ├── random_sampling.py ├── stuff_val2017.json ├── vg.py └── vg_splits.json ├── fid_eval.py ├── figures ├── LD_gradio_demo.gif ├── LD_interacitve_demo.gif └── teaser.png ├── image_editing.ipynb ├── interactive_plotting ├── app.py ├── static │ ├── css │ │ └── style.css │ ├── doc │ │ └── labels.txt │ └── js │ │ └── script.js └── templates │ └── index.html ├── main.py ├── model_utils.py ├── modules ├── bert │ ├── bert_embedder.py │ └── x_transformer.py ├── kl_autoencoder │ └── autoencoder.py ├── openai_unet │ ├── attention.py │ ├── instance_prompt_attention.py │ ├── openaimodel.py │ ├── openaimodel_layout_diffuse.py │ ├── openaimodel_partial_attn_with_text_branch.py │ ├── partial_attention.py │ └── util.py ├── openclip │ └── modules.py └── vqvae │ ├── autoencoder.py │ └── model.py ├── pretrained_models ├── LAION_text2img │ ├── split_model.py │ └── txt2img-1p4B-eval.yaml ├── SD1_5 │ └── split_model.py ├── SD2_1 │ └── split_model.py ├── anything4_5 │ └── split_model.py ├── celeba256 │ ├── config.yaml │ ├── split_model.py │ └── split_model_weights.py ├── counterfeitV25 │ └── split_model.py └── negative │ └── EasyNegative.safetensors ├── requirements.txt ├── run_gradio.py ├── run_gradio_merge.py ├── sampling.ipynb ├── sampling.py ├── sampling_in_background.py ├── scripts ├── convert_jpg.py ├── convert_npz_to_npy.py ├── download_celebMask.sh ├── download_coco.sh ├── download_pretrained_models.sh ├── download_vg.sh ├── eval_scripts │ ├── celeb_mask.sh │ ├── convert_npz_to_npy.sh │ └── fid_coco_layout_ablation.sh ├── preprocess_vg.py ├── remove_empty_file_in_vg.py ├── resize_images.py ├── sampling_scripts │ └── dist_sampling.sh └── train_scripts │ └── dist_train.sh ├── test_utils.py └── train_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | experiments/cocostuff_LayoutDiffuse_SD2_1/latest.ckpt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # myself 2 | experiments/* 3 | datasets/* 4 | !experiments/clean.sh 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *__pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | wandb/ 138 | lightning_logs/ 139 | pretrained_models/**/*.zip 140 | pretrained_models/**/*.ckpt 141 | src/ 142 | tmp/ 143 | sync_to_ec2.sh -------------------------------------------------------------------------------- /DDIM_ldm/DDIMSampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from model_utils import default, make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, make_beta_schedule, extract_into_tensor 5 | 6 | class DDIMSampler(object): 7 | def __init__(self, model, beta_schedule_args={ 8 | "schedule": "linear", 9 | "n_timestep": 1000, 10 | "linear_start": 0.0015, 11 | "linear_end": 0.0195 12 | }, training_target='noise'): 13 | super().__init__() 14 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | self.model = model 16 | self.ddpm_num_timesteps = beta_schedule_args['n_timestep'] 17 | self.make_full_schedule(**beta_schedule_args) 18 | self.training_target = training_target 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda") and self.device == torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | def make_full_schedule(self, **beta_schedule_args): 27 | betas = make_beta_schedule(**beta_schedule_args) 28 | alphas = 1. 
- betas 29 | alphas_cumprod = np.cumprod(alphas, axis=0) 30 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: torch.tensor(x).to(torch.float32).to(self.device) 33 | 34 | self.register_buffer('betas', to_torch(betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 44 | 45 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 46 | self.ddim_timesteps = make_ddim_timesteps( 47 | ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 48 | num_ddpm_timesteps=self.ddpm_num_timesteps, 49 | verbose=verbose 50 | ) 51 | 52 | # ddim sampling parameters 53 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 54 | alphacums=self.alphas_cumprod.cpu(), 55 | ddim_timesteps=self.ddim_timesteps, 56 | eta=ddim_eta,verbose=verbose 57 | ) 58 | self.register_buffer('ddim_sigmas', ddim_sigmas) 59 | self.register_buffer('ddim_alphas', ddim_alphas) 60 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 61 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 62 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 63 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 64 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 65 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 66 | 67 | def q_sample(self, x_start, t, noise=None): 68 | noise = default(noise, lambda: torch.randn_like(x_start)) 69 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 70 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 71 | 72 | def predict_start_from_z_and_v(self, y_t, t, v): 73 | return ( 74 | extract_into_tensor(self.sqrt_alphas_cumprod, t, y_t.shape) * y_t - 75 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y_t.shape) * v 76 | ) 77 | 78 | def predict_eps_from_z_and_v(self, y_t, t, v): 79 | return ( 80 | extract_into_tensor(self.sqrt_alphas_cumprod, t, y_t.shape) * v + 81 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y_t.shape) * y_t 82 | ) 83 | 84 | @torch.no_grad() 85 | def sample(self, 86 | S, 87 | batch_size, 88 | shape, 89 | eta=0., 90 | verbose=True, 91 | x_T=None, 92 | log_every_t=100, 93 | model_kwargs={}, 94 | uncondition_model_kwargs=None, 95 | guidance_scale=1. 
96 | ): 97 | 98 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 99 | # sampling 100 | if shape is None: 101 | shape = x_T.shape[1:] 102 | C, H, W = shape 103 | size = (batch_size, C, H, W) 104 | 105 | samples, intermediates = self.ddim_sampling( 106 | size, 107 | ddim_use_original_steps=False, 108 | x_T=x_T, 109 | log_every_t=log_every_t, 110 | model_kwargs=model_kwargs, 111 | uncondition_model_kwargs=uncondition_model_kwargs, 112 | guidance_scale=guidance_scale 113 | ) 114 | return samples, intermediates 115 | 116 | @torch.no_grad() 117 | def ddim_sampling( 118 | self, 119 | shape, 120 | x_T=None, 121 | ddim_use_original_steps=False, 122 | timesteps=None, 123 | log_every_t=100, 124 | model_kwargs={}, 125 | uncondition_model_kwargs=None, 126 | guidance_scale=1. 127 | ): 128 | device = self.device 129 | b = shape[0] 130 | if x_T is None: 131 | img = torch.randn(shape, device=device) 132 | else: 133 | img = x_T 134 | 135 | if timesteps is None: 136 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 137 | elif timesteps is not None and not ddim_use_original_steps: 138 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 139 | timesteps = self.ddim_timesteps[:subset_end] 140 | 141 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 142 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 143 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 144 | print(f"Running DDIM Sampling with {total_steps} timesteps") 145 | 146 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 147 | 148 | # if guidance_scale > 1 and uncondition_model_kwargs is not None: 149 | # print(f'INFO: guidance scale {guidance_scale} during classifier free guidance with {uncondition_model_kwargs}') 150 | for i, step in enumerate(iterator): 151 | index = total_steps - i - 1 152 | ts = torch.full((b,), step, device=device, dtype=torch.long) 153 | 154 | outs = self.p_sample_ddim( 155 | img, ts, 156 | index=index, 157 | use_original_steps=ddim_use_original_steps, 158 | model_kwargs=model_kwargs, 159 | uncondition_model_kwargs=uncondition_model_kwargs, 160 | guidance_scale=guidance_scale 161 | ) 162 | img, pred_x0 = outs 163 | 164 | if index % log_every_t == 0 or index == total_steps - 1: 165 | intermediates['x_inter'].append(img) 166 | intermediates['pred_x0'].append(pred_x0) 167 | 168 | return img, intermediates 169 | 170 | @torch.no_grad() 171 | def p_sample_ddim( 172 | self, x, t, index, 173 | repeat_noise=False, 174 | use_original_steps=False, 175 | model_kwargs={}, 176 | uncondition_model_kwargs=None, 177 | guidance_scale=1. 
178 | ): 179 | b, *_, device = *x.shape, x.device 180 | 181 | def get_model_output(x, t): 182 | model_output = self.model(x, t, **model_kwargs) 183 | if uncondition_model_kwargs is not None and guidance_scale > 1.: 184 | model_output_uncond = self.model(x, t, **uncondition_model_kwargs) 185 | model_output = model_output_uncond + guidance_scale * (model_output - model_output_uncond) 186 | 187 | if self.training_target == "v": 188 | e_t = self.predict_eps_from_z_and_v(x, t, model_output) 189 | else: 190 | e_t = model_output 191 | return e_t, model_output 192 | 193 | e_t, model_output = get_model_output(x, t) 194 | 195 | alphas = self.alphas_cumprod if use_original_steps else self.ddim_alphas 196 | alphas_prev = self.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 197 | sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 198 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 199 | # select parameters corresponding to the currently considered timestep 200 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 201 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 202 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 203 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 204 | 205 | # current prediction for x_0 206 | if self.training_target != "v": 207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 208 | else: 209 | pred_x0 = self.predict_start_from_z_and_v(x, t, model_output) 210 | # direction pointing to x_t 211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) 213 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 214 | return x_prev, pred_x0 215 | -------------------------------------------------------------------------------- /DDIM_ldm/DDIM_ldm_celeb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .DDIM_ldm import DDIM_LDM_VQVAETraining 4 | 5 | class DDIM_LDM_pretrained_celeb(DDIM_LDM_VQVAETraining): 6 | def process_batch(self, batch, mode='train'): 7 | return super().process_batch(batch['image'], mode) 8 | 9 | class DDIM_LDM_LayoutDiffuse_celeb_mask(DDIM_LDM_pretrained_celeb): 10 | def __init__(self, *args, freeze_pretrained_weights=True, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.freeze_pretrained_weights = freeze_pretrained_weights 13 | 14 | def process_batch(self, batch, mode='train'): 15 | y_t, target, t, x_0, model_kwargs = super().process_batch(batch, mode) 16 | model_kwargs.update({'context': { 17 | 'layout': batch['seg_mask'] 18 | }}) 19 | return y_t, target, t, x_0, model_kwargs 20 | 21 | def initialize_unet(self, unet_init_weights): 22 | if unet_init_weights is not None: 23 | if os.path.exists(unet_init_weights): 24 | print(f'INFO: initialize denoising UNet from {unet_init_weights}, NOTE: without partial attention layers') 25 | model_sd = torch.load(unet_init_weights) 26 | self_model_sd = self.denoise_fn.state_dict() 27 | self_model_params = list(self.denoise_fn.named_parameters()) 28 | self_model_k = list(map(lambda x: x[0], self_model_params)) 29 | self.params_not_pretrained = [] 30 | k_idx = 0 31 | for model_layer_idx, (model_k, model_v) in enumerate(model_sd.items()): 32 | while (self_model_params[k_idx][1].shape != model_v.shape) or (model_k.split('.')[0:2] != 
self_model_k[k_idx].split('.')[0:2]): 33 | self.params_not_pretrained.append(self_model_params[k_idx][1]) 34 | k_idx += 1 35 | self_model_sd[self_model_k[k_idx]] = model_v 36 | k_idx += 1 37 | self.denoise_fn.load_state_dict(self_model_sd) 38 | else: 39 | print(f'WARNING: cannot find pretrained weights {unet_init_weights}, initialize from scratch') 40 | 41 | def training_step(self, batch, batch_idx): 42 | self.clip_denoised = False # during training do not clip to -1 to 1 to prevent grad detached 43 | y_t, y_target, t, raw_image, model_kwargs = self.process_batch(batch, mode='train') 44 | pred = self.denoise_fn(y_t, t, **model_kwargs) 45 | loss, loss_simple, loss_vlb = self.get_loss(pred, y_target, t) 46 | with torch.no_grad(): 47 | if self.training_target == 'noise': 48 | y_0_hat = self.predict_start_from_noise( 49 | y_t, t=t, 50 | noise=pred.detach() 51 | ) 52 | else: 53 | y_0_hat = pred.detach() 54 | 55 | self.log(f'train_loss', loss) 56 | self.log(f'train_loss_simple', loss_simple) 57 | self.log(f'train_loss_vlb', loss_vlb) 58 | if self.learn_logvar: 59 | self.log(f'logvar', self.logvar.data.mean()) 60 | return { 61 | 'loss': loss, 62 | 'raw_image': raw_image, 63 | 'model_input': y_t, 64 | 'model_output': pred, 65 | 'y_0_hat': self.decode_latent_to_image(y_0_hat) 66 | } 67 | 68 | def validation_step(self, batch, batch_idx): 69 | y_t, _, _, y_0_image, model_kwargs = self.process_batch(batch, mode='val') 70 | restored = self.sampling(noise=y_t, model_kwargs=model_kwargs) 71 | sampled = self.sampling(noise=torch.randn_like(y_t), model_kwargs=model_kwargs) 72 | return { 73 | 'y_0_image': y_0_image, 74 | 'restore': restored, 75 | 'sampling': sampled 76 | } 77 | 78 | @torch.no_grad() 79 | def test_step(self, batch, batch_idx): 80 | y_t, _, _, y_0_image, model_kwargs = self.process_batch(batch, mode='val') 81 | sampled = self.sampling(noise=torch.randn_like(y_t), model_kwargs=model_kwargs) 82 | return { 83 | 'sampling': sampled 84 | } 85 | 86 | def configure_optimizers(self): 87 | if self.freeze_pretrained_weights: 88 | assert hasattr(self, 'params_not_pretrained') 89 | print('INFO: pretrained weights are not trainable') 90 | params = self.params_not_pretrained 91 | else: 92 | params = list(self.denoise_fn.parameters()) 93 | if self.learn_logvar: 94 | params = params + [self.logvar] 95 | optimizer = torch.optim.Adam(params, **self.optim_args) 96 | return optimizer -------------------------------------------------------------------------------- /DDIM_ldm/DDIM_ldm_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .DDIM_ldm import DDIM_LDM_VQVAETraining, DDIM_LDM_Text_VQVAETraining 4 | from train_utils import obtain_state_dict_key_mapping 5 | 6 | class DDIM_LDM_pretrained_COCO(DDIM_LDM_VQVAETraining): 7 | def initialize_unet(self, unet_init_weights): 8 | if unet_init_weights is not None: 9 | if os.path.exists(unet_init_weights): 10 | print(f'INFO: initialize denoising UNet from {unet_init_weights}, NOTE: without partial attention layers') 11 | model_sd = torch.load(unet_init_weights) 12 | self_model_sd = self.denoise_fn.state_dict() 13 | self_model_params = list(self.denoise_fn.named_parameters()) 14 | self_model_k = list(map(lambda x: x[0], self_model_params)) 15 | self.params_not_pretrained = [] 16 | self.params_pretrained = [] 17 | for k_idx in range(len(self_model_k)): 18 | this_k = self_model_k[k_idx] 19 | if this_k not in model_sd: 20 | key_in_foundational_model, key_only_in_layout_diffuse = 
obtain_state_dict_key_mapping(this_k) 21 | if key_only_in_layout_diffuse: 22 | self.params_not_pretrained.append(self_model_params[k_idx][1]) 23 | else: 24 | self_model_sd[this_k] = model_sd[key_in_foundational_model] 25 | self.params_pretrained.append(self_model_params[k_idx][1]) 26 | elif (self_model_sd[this_k].shape == model_sd[this_k].shape) or (self_model_sd[this_k].shape == model_sd[this_k].squeeze(-1).shape): 27 | self_model_sd[this_k] = model_sd[this_k] 28 | self.params_pretrained.append(self_model_params[k_idx][1]) 29 | else: 30 | self.params_not_pretrained.append(self_model_params[k_idx][1]) 31 | 32 | self.denoise_fn.load_state_dict(self_model_sd) 33 | else: 34 | print(f'WARNING: {unet_init_weights} does not exist, initialize denoising UNet randomly') 35 | 36 | def configure_optimizers(self): 37 | if self.freeze_pretrained_weights: 38 | assert hasattr(self, 'params_not_pretrained') 39 | print('INFO: pretrained weights are not trainable') 40 | params = self.params_not_pretrained 41 | else: 42 | params = list(self.denoise_fn.parameters()) 43 | if self.learn_logvar: 44 | params = params + [self.logvar] 45 | optimizer = torch.optim.Adam(params, **self.optim_args) 46 | return optimizer 47 | 48 | @torch.no_grad() 49 | def fast_sampling(self, noise, model_kwargs={}): 50 | y_0, y_t_hist = super().fast_sampling( 51 | noise, 52 | model_kwargs=model_kwargs, 53 | uncondition_model_kwargs={'context': torch.empty((1, 0, 5)).to(noise.device)} 54 | ) 55 | return y_0, y_t_hist 56 | 57 | @torch.no_grad() 58 | def test_step(self, batch, batch_idx): 59 | y_t, _, _, y_0_image, model_kwargs = self.process_batch(batch, mode='val') 60 | sampled = self.sampling(noise=torch.randn_like(y_t), model_kwargs=model_kwargs) 61 | return { 62 | 'sampling': sampled 63 | } 64 | 65 | class DDIM_LDM_pretrained_COCO_instance_prompt(DDIM_LDM_pretrained_COCO): 66 | def process_batch(self, batch, mode='train'): 67 | y_t, target, t, y_0, model_kwargs = super().process_batch(batch[0], mode) 68 | model_kwargs.update({'context': torch.tensor(batch[1])}) 69 | return y_t, target, t, y_0, model_kwargs 70 | 71 | class DDIM_LDM_LAION_pretrained_COCO(DDIM_LDM_pretrained_COCO): 72 | def training_step(self, batch, batch_idx): 73 | res_dict = super().training_step(batch, batch_idx) 74 | res_dict['model_input'] = res_dict['model_input'][:, :3] # the LAION pretrained model has 4 channels, for visualization with wandb, we only keep the first 3 channels 75 | return res_dict 76 | 77 | class DDIM_LDM_LAION_pretrained_COCO_instance_prompt(DDIM_LDM_LAION_pretrained_COCO): 78 | def process_batch(self, batch, mode='train'): 79 | y_t, target, t, y_0, model_kwargs = super().process_batch(batch[0], mode) 80 | model_kwargs.update({'context': { 81 | 'layout':torch.tensor(batch[1]) 82 | }}) 83 | return y_t, target, t, y_0, model_kwargs 84 | 85 | class DDIM_LDM_LAION_Text(DDIM_LDM_Text_VQVAETraining): 86 | def initialize_unet(self, unet_init_weights): 87 | if unet_init_weights is not None: 88 | if os.path.exists(unet_init_weights): 89 | print(f'INFO: initialize denoising UNet from {unet_init_weights}, NOTE: without partial attention layers') 90 | model_sd = torch.load(unet_init_weights) 91 | self_model_sd = self.denoise_fn.state_dict() 92 | self_model_params = list(self.denoise_fn.named_parameters()) 93 | self_model_k = list(map(lambda x: x[0], self_model_params)) 94 | self.params_not_pretrained = [] 95 | self.params_pretrained = [] 96 | for k_idx in range(len(self_model_k)): 97 | this_k = self_model_k[k_idx] 98 | if this_k not in model_sd: 99 | 
key_in_foundational_model, key_only_in_layout_diffuse = obtain_state_dict_key_mapping(this_k) 100 | if key_only_in_layout_diffuse: 101 | self.params_not_pretrained.append(self_model_params[k_idx][1]) 102 | else: 103 | self_model_sd[this_k] = model_sd[key_in_foundational_model] 104 | self.params_pretrained.append(self_model_params[k_idx][1]) 105 | elif (self_model_sd[this_k].shape == model_sd[this_k].shape) or (self_model_sd[this_k].shape == model_sd[this_k].squeeze(-1).shape): 106 | self_model_sd[this_k] = model_sd[this_k] 107 | self.params_pretrained.append(self_model_params[k_idx][1]) 108 | else: 109 | self.params_not_pretrained.append(self_model_params[k_idx][1]) 110 | 111 | self.denoise_fn.load_state_dict(self_model_sd) 112 | else: 113 | print(f'WARNING: cannot find {unet_init_weights}, skip initialization') 114 | 115 | def process_batch(self, batch, mode='train'): 116 | y_t, target, t, y_0, model_kwargs = super().process_batch(batch[0], mode) 117 | model_kwargs.update({'context': { 118 | 'layout': torch.tensor(batch[1]), 119 | 'text': self.encode_text(batch[2]) 120 | }}) 121 | return y_t, target, t, y_0, model_kwargs 122 | 123 | def training_step(self, batch, batch_idx): 124 | res_dict = super().training_step(batch, batch_idx) 125 | res_dict['model_input'] = res_dict['model_input'][:, :3] # the LAION pretrained model has 4 channels, for visualization with wandb, we only keep the first 3 channels 126 | return res_dict 127 | 128 | @torch.no_grad() 129 | def fast_sampling(self, noise, model_kwargs={}): 130 | from train_utils import NEGATIVE_PROMPTS, NEGATIVE_PROMPTS_EMBEDDINGS 131 | if NEGATIVE_PROMPTS_EMBEDDINGS is not None: 132 | negative = torch.stack([NEGATIVE_PROMPTS_EMBEDDINGS]*noise.shape[0], dim=0) 133 | else: 134 | negative = self.encode_text([NEGATIVE_PROMPTS]) 135 | y_0, y_t_hist = super().fast_sampling( 136 | noise, 137 | model_kwargs=model_kwargs, 138 | uncondition_model_kwargs={'context': { 139 | 'layout': torch.empty((1, 0, 5)).to(noise.device), 140 | 'text': negative.to(noise.device) 141 | } 142 | } 143 | ) 144 | return y_0, y_t_hist 145 | 146 | class DDIM_LDM_LAION_Text_CKPT_Merge(DDIM_LDM_LAION_Text): 147 | def merge(self, ckpt_path, alpha, interp="sigmoid"): 148 | if interp == "sigmoid": 149 | theta_func = DDIM_LDM_LAION_Text_CKPT_Merge.sigmoid 150 | elif interp == "inv_sigmoid": 151 | theta_func = DDIM_LDM_LAION_Text_CKPT_Merge.inv_sigmoid 152 | elif interp == "add_diff": 153 | theta_func = DDIM_LDM_LAION_Text_CKPT_Merge.add_difference 154 | else: 155 | theta_func = DDIM_LDM_LAION_Text_CKPT_Merge.weighted_sum 156 | HACK_LAYERS_IN_LAYOUT_DIFFUSE = [ 157 | 'output_blocks.5.3.conv.weight', 158 | 'output_blocks.5.3.conv.bias', 159 | 'output_blocks.8.3.conv.weight', 160 | 'output_blocks.8.3.conv.bias' 161 | ] # these layers need to find the corresponding layer in the foundational model with another name 162 | HACK_LAYERS_NEED_TO_BE_IGNORED = [ 163 | 'output_blocks.5.2.conv.weight', 164 | 'output_blocks.5.2.conv.bias', 165 | 'output_blocks.8.2.conv.weight', 166 | 'output_blocks.8.2.conv.bias' 167 | ] # these layers need to be ignored because they are trained with the layout diffuse layers 168 | self_state_dict = self.denoise_fn.state_dict() 169 | new_state_dict = torch.load(ckpt_path, map_location=self.device) 170 | for k in self_state_dict: 171 | if k in HACK_LAYERS_IN_LAYOUT_DIFFUSE: 172 | key_in_foundational_model, _ = obtain_state_dict_key_mapping(k) 173 | theta0 = self_state_dict[k] 174 | theta1 = new_state_dict[key_in_foundational_model] 175 | elif (k in 
new_state_dict) and (k not in HACK_LAYERS_NEED_TO_BE_IGNORED):
176 |                 theta0 = self_state_dict[k]
177 |                 theta1 = new_state_dict[k]
178 |             else:
179 |                 # dummy, do nothing; the LayoutDiffuse layers go through here
180 |                 print('INFO: skip layer', k)
181 |                 continue
182 |             self_state_dict[k] = theta_func(theta0, theta1, None, alpha)
183 | 
184 |         del new_state_dict
185 | 
186 |     @staticmethod
187 |     def weighted_sum(theta0, theta1, theta2, alpha):
188 |         return ((1 - alpha) * theta0) + (alpha * theta1)
189 | 
190 |     # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
191 |     @staticmethod
192 |     def sigmoid(theta0, theta1, theta2, alpha):
193 |         alpha = alpha * alpha * (3 - (2 * alpha))
194 |         return theta0 + ((theta1 - theta0) * alpha)
195 | 
196 |     # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
197 |     @staticmethod
198 |     def inv_sigmoid(theta0, theta1, theta2, alpha):
199 |         import math
200 | 
201 |         alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
202 |         return theta0 + ((theta1 - theta0) * alpha)
203 | 
204 |     @staticmethod
205 |     def add_difference(theta0, theta1, theta2, alpha):
206 |         return theta0 + (theta1 - theta2) * (1.0 - alpha)
--------------------------------------------------------------------------------
/DDIM_ldm/README.md:
--------------------------------------------------------------------------------
1 | ### Custom Training
2 | If you want to train on your own dataset, the following knowledge may help.
3 | #### 1 [main.py](main.py)
4 | The entry point of the program for training. It does the following:
5 | * Creates the denoising/VQVAE/text models specified in the config JSON. Each of these models is a regular `pytorch module`.
6 | * Creates a DDIM training instance, which is a `pytorch lightning module` (e.g., the training class for COCO is `DDIM_LDM_LAION_Text`; you can find the class name in the JSON config file).
7 | * Prepares the dataset and dataloader.
8 | * Creates callbacks for checkpointing and visualization (see the [callbacks README](callbacks/README.md) for details).
9 | * Creates a `pytorch lightning` `Trainer` instance for training.
10 | 
11 | #### 2 [Denoising model](modules)
12 | The denoising model is a [UNet model](modules/openai_unet/openaimodel_layout_diffuse.py) that takes layout information and (optional) text prompts.
13 | 
14 | #### 3 [Latent diffusion model](DDIM_ldm)
15 | This folder contains the code for diffusion.
16 | Class [DDIM_LDM](DDIM_ldm/DDIM_ldm.py) contains the coefficients and functions for the diffusion and denoising processes.
17 | 
18 | Class [DDIM_LDMTraining](DDIM_ldm/DDIM_ldm.py) contains the code for
19 | * Training (needs to follow the PyTorch Lightning grammar)
20 | * Validation/testing (needs to follow the PyTorch Lightning grammar)
21 | * Sampling
22 | * Initializing the optimizer
23 | 
24 | Class [DDIM_LDM_VQVAETraining](DDIM_ldm/DDIM_ldm.py) adds the VQVAE encoder and decoder.
25 | 
26 | Class [DDIM_LDM_Text_VQVAETraining](DDIM_ldm/DDIM_ldm.py) adds the text model.
27 | 
28 | In most cases, you only need to subclass `DDIM_LDM_VQVAETraining` or `DDIM_LDM_Text_VQVAETraining` for customized training.
29 | 
30 | See class `DDIM_LDM_LAION_Text` to understand how to derive these classes for each dataset/task.
31 | 
32 | The functions `training_step`, `validation_step` and `test_step` return a dictionary. This dictionary becomes the `outputs` argument of the callback functions; you can use it for visualization, etc.
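As a concrete illustration of the subclassing described above, below is a minimal sketch (not a verbatim excerpt from this repository) of a dataset-specific class that only overrides `process_batch()`. It assumes your dataloader yields a dict with an `image` tensor and a conditioning tensor, here called `layout`; the structure mirrors `DDIM_LDM_LayoutDiffuse_celeb_mask` in `DDIM_ldm/DDIM_ldm_celeb.py`.

```python
# Minimal sketch, assuming the batch is a dict with 'image' and 'layout' tensors.
# Everything except process_batch is inherited from DDIM_LDM_VQVAETraining.
from DDIM_ldm.DDIM_ldm import DDIM_LDM_VQVAETraining

class DDIM_LDM_MyDataset(DDIM_LDM_VQVAETraining):
    def process_batch(self, batch, mode='train'):
        # the parent class returns (y_t, target, t, y_0, model_kwargs)
        y_t, target, t, y_0, model_kwargs = super().process_batch(batch['image'], mode)
        # pass the conditioning to the denoising UNet via 'context'
        model_kwargs.update({'context': {'layout': batch['layout']}})
        return y_t, target, t, y_0, model_kwargs
```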
33 | 
34 | 
35 | #### 4 [Callbacks](callbacks)
36 | See the callbacks' [readme](../callbacks/README.md).
37 | 
38 | #### 5 [Data](data)
39 | This folder returns a training or validation data loader.
40 | 
41 | In most cases you can use off-the-shelf datasets (e.g., the official ones in `torchvision`). The only thing you need to modify is to overwrite the `process_batch()` function in `DDIM_LDM_VQVAETraining`.
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Code release for [LayoutDiffuse: Adapting Foundational Diffusion Models for Layout-to-Image Generation](https://github.com/cplusx/layout_diffuse)
2 | 
3 | ---
4 | ![teaser](figures/teaser.png)
5 | ## 0. Installation
6 | ### 0.1
7 | Follow the official instructions from [PyTorch](https://pytorch.org/get-started/locally/) and install the correct PyTorch build for your hardware.
8 | 
9 | Then clone this repository and install the dependencies.
10 | ```
11 | git clone https://github.com/cplusx/layout_diffuse.git
12 | pip install -r requirements.txt
13 | ```
14 | 
15 | Optional downloading
16 | 
17 | #### 0.2 Download dataset (optional, only required for training)
18 | 
19 | Download the datasets by running
20 | ```
21 | bash scripts/download_coco.sh
22 | bash scripts/download_vg.sh
23 | bash scripts/download_celebMask.sh
24 | ```
25 | 
26 | This should create a folder at `~/disk2/data` and put all files in that folder.
27 | 
28 | 
29 | *Note for celebMask*:
30 | 1. You might see an error when downloading the celebMask dataset saying that the file has been downloaded too many times. In that case, download it manually from their [website](https://github.com/switchablenorms/CelebAMask-HQ).
31 | 2. You need to use this [script](https://github.com/switchablenorms/CelebAMask-HQ/blob/master/face_parsing/Data_preprocessing/g_mask.py) to convert the celebMask part ground truth to mask ground truth.
32 | 
33 | #### 0.3 (Optional, only required for training)
34 | Download the foundational pretrained models by running
35 | ```
36 | bash scripts/download_pretrained_models.sh {face|ldm|SD1_5|SD2_1|all}
37 | ```
38 | 
39 | #### 0.4 (Optional) log experiments with WandB
40 | Visualization depends on `wandb`; remember to set it up on your server with `wandb login`.
41 | 
42 | 
43 | 
44 | 
45 | ---
46 | 
47 | ## 1. Sampling with trained models
48 | 
49 | Download model weights for [COCO with the SD2.1 backbone](https://huggingface.co/cplusx/LD/resolve/main/LD_SD2_1.ckpt) or [COCO with the SD1.5 backbone @ 20 epochs (still ongoing)](https://huggingface.co/cplusx/LD/resolve/main/LD_SD1_5.ckpt).
50 | 
51 | 
52 | **Benchmarking** results in the paper (COCO-LDM, VG-LDM, CelebMask)
53 | Download the model weights for [COCO](https://huggingface.co/cplusx/LD/resolve/main/cocostuff_ldm.ckpt), [VG](https://huggingface.co/cplusx/LD/resolve/main/vg_ldm.ckpt) or [celebMask](https://huggingface.co/cplusx/LD/resolve/main/celeb_mask.ckpt) and put the weights under the folder `experiments/{cocostuff/cocostuff_no_text/vg/celeb_mask}_LayoutDiffuse`.
54 | 
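For example, a minimal Python sketch for fetching the COCO benchmark weights; the URL and folder name come from the instructions above, while the `latest.ckpt` file name is an assumption borrowed from the SD2.1 checkpoint path and may need to match your config:

```python
import os
import torch

# Assumption: the config expects <experiment folder>/latest.ckpt, mirroring
# experiments/cocostuff_LayoutDiffuse_SD2_1/latest.ckpt; adjust if yours differs.
url = 'https://huggingface.co/cplusx/LD/resolve/main/cocostuff_ldm.ckpt'
dst_dir = 'experiments/cocostuff_LayoutDiffuse'
os.makedirs(dst_dir, exist_ok=True)
torch.hub.download_url_to_file(url, os.path.join(dst_dir, 'latest.ckpt'))
```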
55 | 
56 | ---
57 | 
58 | ## There are three ways to sample from the model:
59 | 
60 | 1. **Recommended**: use the interactive webpage. This is a workaround until Gradio supports bounding box input. You will need Flask to run the server. To obtain better image quality, we use ChatGPT to generate text prompts; set up your OpenAI API key if you want to use this. **NOTE**: If no OpenAI API key is provided, a default text prompt is built by concatenating the class labels (e.g., person, dog, car), and the result may have a semantically meaningless background.
61 | ![Interactive plotting](figures/LD_interacitve_demo.gif)
62 | ```
63 | pip install flask
64 | python sampling_in_background.py -c configs/cocostuff_SD2_1.json --openai_api_key [OPENAI_API_KEY] --model_path [PATH_TO_MODEL, if not given, it will use the default path e.g., "experiments/cocostuff_LayoutDiffuse_SD2_1/latest.ckpt"]
65 | # open another terminal
66 | cd interactive_plotting
67 | export FLASK_APP=app.py
68 | flask run
69 | ```
70 | 
71 | 
72 | 2. Use [Gradio](https://gradio.app/) to run LayoutDiffuse. Gradio does not support bounding box input yet, so we currently support uploading a reference image and generating an image with the same layout. The layout is detected by a YOLOv5 model. **NOTE**: If no OpenAI API key is provided, a default text prompt is built by concatenating the class labels (e.g., person, dog, car), and the result may have a semantically meaningless background.
73 | ```
74 | pip install gradio
75 | python run_gradio.py -c configs/cocostuff_SD2_1.json --openai_api_key [OPENAI_API_KEY] --model_path [PATH_TO_MODEL, if not given, it will use the default path e.g., "experiments/cocostuff_LayoutDiffuse_SD2_1/latest.ckpt"]
76 | ```
77 | ![Gradio plotting](figures/LD_gradio_demo.gif)
78 | 
79 | 3. Sample many images (using the COCO dataset) for benchmarking purposes. Replace `-c` with other config files to sample from other datasets.
80 | See the [notebook for single-image sampling](sampling.ipynb) or run sampling over the dataset:
81 | ```
82 | python sampling.py -c configs/cocostuff_SD2_1.json --model_path [PATH_TO_MODEL, if not given, it will use the default path e.g., "experiments/cocostuff_LayoutDiffuse_SD2_1/latest.ckpt"]
83 | ```
84 | 
85 | ---
86 | 
87 | ### 2. Training
88 | ```
89 | python main.py -c configs/cocostuff_SD1_5.json
90 | ```
91 | You can change the config file to other datasets in `configs`.
92 | 
93 | ---
94 | 
95 | ### 3. Training on custom data
96 | See [code structure](DDIM_ldm/README.md) for details.
97 | ---
98 | 
99 | This code is developed using a variety of resources from [this repository](https://github.com/lucidrains/denoising-diffusion-pytorch)
--------------------------------------------------------------------------------
/callbacks/README.md:
--------------------------------------------------------------------------------
1 | ### [Checkpoint callbacks](checkpoint.py)
2 | * get_epoch_checkpoint: saves a checkpoint every `n` epochs, named `epoch={:04d}.ckpt`
3 | * get_latest_checkpoint: saves the latest checkpoint, named `latest.ckpt`
4 | 
5 | ### Image saving callbacks
6 | This includes `sampling_save_fig.py`, `coco_layout/sampling_save_fig.py` and `celeb_mask/sampling_save_fig.py`. These callbacks save images during sampling (the output of `validation_step()`/`test_step()` is passed to these callbacks), as sketched below.
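A rough sketch of how such a callback consumes the `outputs` dictionary (the callback name here is hypothetical; the `outputs['sampling']['model_output']` key and the `save_figure` helper follow `sampling_save_fig.py`):

```python
import os
from pytorch_lightning.callbacks import Callback
from callbacks.sampling_save_fig import save_figure  # helper defined in this repo

class MyImageSavingCallback(Callback):  # hypothetical name, for illustration only
    def __init__(self, expt_path):
        self.expt_path = expt_path
        self.current_idx = 0

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # 'outputs' is the dict returned by test_step(); it is assumed to hold
        # the sampled images under outputs['sampling']['model_output']
        for image in outputs['sampling']['model_output']:
            save_figure(image, save_path=os.path.join(self.expt_path, f'{self.current_idx:05d}.png'))
            self.current_idx += 1
```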
7 | 8 | ### [WandB callbacks](wandb.py) 9 | Visualize the input and output images 10 | The `outputs` argument is a dictionary that contains return from `train_step()` -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .wandb import * 2 | from .checkpoint import * -------------------------------------------------------------------------------- /callbacks/celeb_mask/celeb_mask_loss_fn.py: -------------------------------------------------------------------------------- 1 | '''Change the loss weight of background as the training continues 2 | At the beginning of each epoch, update loss function 3 | ''' 4 | from pytorch_lightning.callbacks import Callback 5 | from functools import partial 6 | 7 | class CelebBackgroundLossWeightTuner(Callback): 8 | def __init__(self, max_epochs, loss_fn, num_classes): 9 | super().__init__() 10 | self.max_epochs = max_epochs 11 | self.loss_fn = loss_fn 12 | self.num_classes = num_classes 13 | 14 | def on_epoch_start(self, trainer, pl_module): 15 | current_epoch = pl_module.current_epoch 16 | this_bg_weight = current_epoch / self.max_epochs 17 | print(f'INFO: set background weight to {this_bg_weight}') 18 | this_loss_fn = partial( 19 | self.loss_fn, 20 | background_weight=this_bg_weight, 21 | num_classes=self.num_classes 22 | ) 23 | pl_module.loss_fn = this_loss_fn 24 | return super().on_epoch_start(trainer, pl_module) -------------------------------------------------------------------------------- /callbacks/celeb_mask/celeb_mask_wandb.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from ..wandb import WandBImageLogger, clip_image, unnorm 3 | from ..utils import to_mask_if_dim_gt_3 4 | # colorize mask is included in the to_mask... 
func 5 | 6 | class CelebMaskWandBImageLogger(WandBImageLogger): 7 | def on_train_batch_end( 8 | self, trainer, pl_module, 9 | outputs, batch, batch_idx 10 | ): 11 | if batch_idx == 0: 12 | # if raw_image is one-hot mask, it needs to be map to 1 dim for visualization 13 | raw_image = self.tensor2image(clip_image(unnorm( 14 | outputs['raw_image'][:self.max_num_images] 15 | ))) 16 | raw_mask = self.tensor2image(to_mask_if_dim_gt_3(clip_image(unnorm( 17 | outputs['raw_mask'][:self.max_num_images] 18 | )))) 19 | 20 | model_input = clip_image(unnorm( 21 | outputs['model_input'][:self.max_num_images] 22 | )) 23 | model_input_image = self.tensor2image(model_input[:, :3]) 24 | model_input_mask = self.tensor2image(to_mask_if_dim_gt_3( 25 | model_input[:, 3:] 26 | )) 27 | 28 | model_output = clip_image(unnorm( 29 | outputs['model_output'][:self.max_num_images] 30 | )) 31 | model_output_image = self.tensor2image(model_output[:, :3]) 32 | model_output_mask = self.tensor2image(to_mask_if_dim_gt_3( 33 | model_output[:, 3:] 34 | )) 35 | 36 | y_0_hat = clip_image(unnorm( 37 | outputs['y_0_hat'][:self.max_num_images] 38 | )) 39 | y_0_hat_image = self.tensor2image(y_0_hat[:, :3]) 40 | y_0_hat_mask = self.tensor2image(to_mask_if_dim_gt_3( 41 | y_0_hat[:, 3:] 42 | )) 43 | 44 | self.wandb_logger.experiment.log({ 45 | 'train/raw_image': raw_image, 46 | 'train/raw_mask': raw_mask, 47 | 'train/model_input_image': model_input_image, 48 | 'train/model_input_mask': model_input_mask, 49 | 'train/model_output_image': model_output_image, 50 | 'train/model_output_mask': model_output_mask, 51 | 'train/y_0_hat_image': y_0_hat_image, 52 | 'train/y_0_hat_mask': y_0_hat_mask 53 | }) 54 | 55 | 56 | def on_validation_batch_end( 57 | self, trainer, pl_module, 58 | outputs, batch, batch_idx, 59 | dataloader_idx 60 | ): 61 | if batch_idx == 0: 62 | y_0_image = outputs.pop('y_0_image') 63 | y_0_mask = outputs.pop('y_0_mask') 64 | 65 | y_0_image = self.tensor2image(clip_image(unnorm( 66 | y_0_image[:self.max_num_images] 67 | ))) 68 | y_0_mask = self.tensor2image(to_mask_if_dim_gt_3(clip_image(unnorm( 69 | y_0_mask[:self.max_num_images] 70 | )))) 71 | self.wandb_logger.experiment.log({ 72 | 'validation/y_0_image': y_0_image, 73 | 'validation/y_0_mask': y_0_mask 74 | }) 75 | 76 | for output_type, this_outputs in outputs.items(): 77 | y_0_hat = clip_image(unnorm( 78 | this_outputs['model_output'][:self.max_num_images] 79 | )) 80 | y_0_hat_image = self.tensor2image(y_0_hat[:, :3]) 81 | y_0_hat_mask = self.tensor2image(to_mask_if_dim_gt_3(y_0_hat[:, 3:])) 82 | self.wandb_logger.experiment.log({ 83 | f'validation/{output_type}_image': y_0_hat_image, 84 | f'validation/{output_type}_mask': y_0_hat_mask 85 | }) 86 | 87 | 88 | y_t_hist = clip_image(unnorm( 89 | this_outputs['model_history_output'][:self.max_num_images] 90 | )) # bs, time step, im_dim, im_h, im_w 91 | y_t_hist_image = y_t_hist[:, :, :3] 92 | y_t_hist_mask = to_mask_if_dim_gt_3( 93 | y_t_hist[:, :, 3:], dim=2 94 | ) 95 | 96 | # '''log images with `WandbLogger.log_image` 97 | # must convert to a list''' 98 | # self.wandb_logger.log_image( 99 | # key=f'validation/{output_type}_image', 100 | # images=[i for i in self.tensor2numpy(y_0_hat_image)], 101 | # caption=[f'im_{i}' for i in range(len(y_0_hat_image))] 102 | # ) 103 | # self.wandb_logger.log_image( 104 | # key=f'validation/{output_type}_mask', 105 | # images=[i for i in self.tensor2numpy(y_0_hat_mask)], 106 | # caption=[f'im_{i}' for i in range(len(y_0_hat_mask))] 107 | # ) 108 | 109 | '''log predictions as a Table''' 110 | 
columns = [f't={t}' for t in reversed(range(y_t_hist_image.shape[1]))] 111 | data = [] 112 | for this_y_t_hist in y_t_hist_image: 113 | this_Images = [ 114 | wandb.Image(self.tensor2numpy(i)) for i in this_y_t_hist 115 | ] 116 | data.append(this_Images) 117 | self.wandb_logger.log_table( 118 | key=f'validation_table/{output_type}_image', 119 | columns=columns, data=data 120 | ) 121 | 122 | columns = [f't={t}' for t in reversed(range(y_t_hist_mask.shape[1]))] 123 | data = [] 124 | for this_y_t_hist in y_t_hist_mask: 125 | this_Images = [ 126 | wandb.Image(self.tensor2numpy(i)) for i in this_y_t_hist 127 | ] 128 | data.append(this_Images) 129 | self.wandb_logger.log_table( 130 | key=f'validation_table/{output_type}_mask', 131 | columns=columns, data=data 132 | ) 133 | 134 | class CelebMaskEmbeddingWandBImageLogger(WandBImageLogger): 135 | def on_train_batch_end( 136 | self, trainer, pl_module, 137 | outputs, batch, batch_idx 138 | ): 139 | if batch_idx == 0: 140 | # if raw_image is one-hot mask, it needs to be map to 1 dim for visualization 141 | raw_image = self.tensor2image(clip_image(unnorm( 142 | outputs['raw_image'][:self.max_num_images] 143 | ))) 144 | raw_mask = self.tensor2image(to_mask_if_dim_gt_3( 145 | outputs['raw_mask'][:self.max_num_images] 146 | )) 147 | 148 | model_input = outputs['model_input'][:self.max_num_images] 149 | model_input_image = self.tensor2image(clip_image(unnorm(model_input[:, :3]))) 150 | model_input_mask = self.tensor2image(to_mask_if_dim_gt_3( 151 | model_input[:, 3:] 152 | )) # b, 3, h, w 153 | 154 | model_output = outputs['model_output'][:self.max_num_images] 155 | model_output_image = self.tensor2image(clip_image(unnorm(model_output[:, :3]))) 156 | model_output_mask = self.tensor2image(to_mask_if_dim_gt_3( 157 | model_output[:, 3:] 158 | )) # b, 3, h, w 159 | 160 | y_0_hat = outputs['y_0_hat'][:self.max_num_images] 161 | y_0_hat_image = self.tensor2image(clip_image(unnorm(y_0_hat[:, :3]))) 162 | y_0_hat_mask = self.tensor2image(to_mask_if_dim_gt_3( 163 | y_0_hat[:, 3:] 164 | )) # b, num classes, h, w 165 | 166 | self.wandb_logger.experiment.log({ 167 | 'train/raw_image': raw_image, 168 | 'train/raw_mask': raw_mask, 169 | 'train/model_input_image': model_input_image, 170 | 'train/model_input_mask': model_input_mask, 171 | 'train/model_output_image': model_output_image, 172 | 'train/model_output_mask': model_output_mask, 173 | 'train/y_0_hat_image': y_0_hat_image, 174 | 'train/y_0_hat_mask': y_0_hat_mask 175 | }) 176 | 177 | 178 | def on_validation_batch_end( 179 | self, trainer, pl_module, 180 | outputs, batch, batch_idx, 181 | dataloader_idx 182 | ): 183 | if batch_idx == 0: 184 | y_0_image = outputs.pop('y_0_image') 185 | y_0_mask = outputs.pop('y_0_mask') 186 | 187 | y_0_image = self.tensor2image(clip_image(unnorm( 188 | y_0_image[:self.max_num_images] 189 | ))) 190 | y_0_mask = self.tensor2image(to_mask_if_dim_gt_3( 191 | y_0_mask[:self.max_num_images] 192 | )) 193 | self.wandb_logger.experiment.log({ 194 | 'validation/y_0_image': y_0_image, 195 | 'validation/y_0_mask': y_0_mask 196 | }) 197 | 198 | for output_type, this_outputs in outputs.items(): 199 | y_0_hat = this_outputs['model_output'][:self.max_num_images] 200 | y_0_hat_image = clip_image(unnorm(y_0_hat[:, :3])) 201 | y_0_hat_mask = to_mask_if_dim_gt_3(y_0_hat[:, 3:]) # b, num classes, h, w 202 | 203 | y_t_hist = this_outputs['model_history_output'][:self.max_num_images] 204 | y_t_hist_image = clip_image(unnorm(y_t_hist[:, :, :3])) 205 | y_t_hist_mask = to_mask_if_dim_gt_3( 206 | y_t_hist[:, 
:, 3:], dim=2 207 | ) # b, t, num classes, h, w 208 | 209 | '''log images with `WandbLogger.log_image` 210 | must convert to a list''' 211 | self.wandb_logger.log_image( 212 | key=f'validation/{output_type}_image', 213 | images=[i for i in self.tensor2numpy(y_0_hat_image)], 214 | caption=[f'im_{i}' for i in range(len(y_0_hat_image))] 215 | ) 216 | self.wandb_logger.log_image( 217 | key=f'validation/{output_type}_mask', 218 | images=[i for i in self.tensor2numpy(y_0_hat_mask)], 219 | caption=[f'im_{i}' for i in range(len(y_0_hat_mask))] 220 | ) 221 | 222 | '''log predictions as a Table''' 223 | columns = [f't={t}' for t in reversed(range(y_t_hist_image.shape[1]))] 224 | data = [] 225 | for this_y_t_hist in y_t_hist_image: 226 | this_Images = [ 227 | wandb.Image(self.tensor2numpy(i)) for i in this_y_t_hist 228 | ] 229 | data.append(this_Images) 230 | self.wandb_logger.log_table( 231 | key=f'validation_table/{output_type}_image', 232 | columns=columns, data=data 233 | ) 234 | 235 | columns = [f't={t}' for t in reversed(range(y_t_hist_mask.shape[1]))] 236 | data = [] 237 | for this_y_t_hist in y_t_hist_mask: 238 | this_Images = [ 239 | wandb.Image(self.tensor2numpy(i)) for i in this_y_t_hist 240 | ] 241 | data.append(this_Images) 242 | self.wandb_logger.log_table( 243 | key=f'validation_table/{output_type}_mask', 244 | columns=columns, data=data 245 | ) -------------------------------------------------------------------------------- /callbacks/celeb_mask/sampling_save_fig.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import cv2 5 | import numpy as np 6 | from pytorch_lightning.callbacks import Callback 7 | from data.face_parsing import MaskMeshConverter, celebAMask_labels 8 | from ..sampling_save_fig import format_dtype_and_shape, save_figure, save_sampling_history 9 | 10 | def format_image(x): 11 | x = x.cpu() 12 | x = (x + 1) / 2 13 | x = x.clamp(0, 1) 14 | x = x.permute(1,2,0).detach().numpy() 15 | return x 16 | 17 | def save_mask_index(image, save_path): 18 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 19 | image = format_dtype_and_shape(image) 20 | image = image.astype(np.uint8) 21 | cv2.imwrite(save_path, image) 22 | 23 | def save_raw_image_tensor(x, save_path): 24 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 25 | np.savez(save_path, image=x) 26 | 27 | # TODO, change colorizer to the gt color of celeb mask 28 | class MaskColorizer(): 29 | def __init__(self): 30 | self.mask_cvt = MaskMeshConverter( 31 | labels = list(celebAMask_labels.keys()), 32 | mesh_dim=3 33 | ) 34 | def __call__(self, x): 35 | assert len(x.shape) == 3 or len(x.shape) == 4, f'mask should have shape 3 or 4, got {x.shape}' 36 | # input: 1, h, w or time, 1, h, w 37 | x = x.squeeze(-3) # (h, w) or (time, h, w) 38 | x = self.mask_cvt(x) # h, w, 3 or time, h, w, 3 39 | if len(x.shape) == 3: 40 | return x.permute(2, 0, 1) # 3, h, w 41 | elif len(x.shape) == 4: 42 | return x.permute(0, 3, 1, 2) # time, 3, h, w 43 | else: 44 | raise RuntimeError(f'Unknown dim, mask shape is {x.shape}') 45 | 46 | class CelebMaskImageSavingCallback(Callback): 47 | def __init__(self, expt_path, start_idx=0): 48 | self.expt_path = expt_path 49 | self.current_idx = start_idx 50 | self.mask_colorizer = MaskColorizer() 51 | self.repeat_idx = -1 52 | 53 | def save_y_0_hat(self, image, mask, prefix, rank, current_epoch, current_idx, num_gpus=1): 54 | save_figure( 55 | image, 56 | save_path=os.path.join( 57 | self.expt_path, 58 | 
f'epoch_{current_epoch:05d}', 59 | f'image', 60 | f'{rank+num_gpus*current_idx:04d}_{self.repeat_idx:02d}.png') 61 | ) 62 | if mask is not None: 63 | save_figure( 64 | self.mask_colorizer(mask[None]), 65 | save_path=os.path.join( 66 | self.expt_path, 67 | f'epoch_{current_epoch:05d}', 68 | f'mask', 69 | f'{rank+num_gpus*current_idx:04d}.png') 70 | ) 71 | save_mask_index( 72 | mask, 73 | save_path=os.path.join( 74 | self.expt_path, 75 | f'epoch_{current_epoch:05d}', 76 | f'mask_index', 77 | f'{rank+num_gpus*current_idx:04d}.png') 78 | ) 79 | 80 | save_raw_image_tensor( 81 | format_image(image), 82 | save_path=os.path.join( 83 | self.expt_path, 84 | f'epoch_{current_epoch:05d}', 85 | 'raw_tensor', 86 | f'{rank+num_gpus*current_idx:04d}_{self.repeat_idx:02d}') # will add .npz automatically 87 | ) 88 | 89 | def save_y_t_hist(self, y_t_hist, prefix, rank, current_epoch, current_idx): 90 | for hist_image in y_t_hist: 91 | save_sampling_history( 92 | hist_image, 93 | save_path=os.path.join( 94 | self.expt_path, 95 | f'{prefix}_at_{current_epoch:05d}_image', 96 | f'{rank:02d}_{current_idx:05d}.png') 97 | ) 98 | current_idx += 1 99 | 100 | class CelebMaskPartialAttnImageSavingCallback(CelebMaskImageSavingCallback): 101 | 102 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 103 | if batch_idx == 0: 104 | self.repeat_idx += 1 105 | self.current_idx = 0 106 | rank = pl_module.global_rank 107 | current_epoch = pl_module.current_epoch 108 | 109 | y_0_hat = outputs['sampling']['model_output'] 110 | y_t_hist = outputs['sampling']['model_history_output'] 111 | masks = batch['seg_mask'] 112 | for image, mask in zip(y_0_hat, masks): 113 | self.save_y_0_hat( 114 | image, mask, 115 | prefix='sampling', 116 | rank=rank, current_epoch=current_epoch, 117 | current_idx = self.current_idx, 118 | num_gpus=trainer.num_devices 119 | ) 120 | # self.save_y_t_hist( 121 | # y_t_hist, 122 | # prefix='sampling_hist', 123 | # rank=rank, current_epoch=current_epoch, 124 | # current_idx = self.current_idx 125 | # ) 126 | 127 | self.current_idx += 1 128 | 129 | class CelebMaskBaselineImageSavingCallback(CelebMaskImageSavingCallback): 130 | 131 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 132 | if batch_idx == 0: 133 | self.repeat_idx += 1 134 | self.current_idx = 0 135 | rank = pl_module.global_rank 136 | current_epoch = pl_module.current_epoch 137 | 138 | y_0_hat = outputs['sampling']['model_output'] 139 | y_t_hist = outputs['sampling']['model_history_output'] 140 | masks = batch['seg_mask'] 141 | for image, mask in zip(y_0_hat, masks): 142 | self.save_y_0_hat( 143 | image, None, 144 | prefix='sampling', 145 | rank=rank, current_epoch=current_epoch, 146 | current_idx = self.current_idx, 147 | num_gpus=trainer.num_devices 148 | ) 149 | # self.save_y_t_hist( 150 | # y_t_hist, 151 | # prefix='sampling_hist', 152 | # rank=rank, current_epoch=current_epoch, 153 | # current_idx = self.current_idx 154 | # ) 155 | 156 | self.current_idx += 1 -------------------------------------------------------------------------------- /callbacks/checkpoint.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback 2 | import os 3 | 4 | def get_epoch_checkpoint( 5 | expt_path, 6 | every_n_epochs=10, 7 | save_top_k=5 8 | ): 9 | epoch_checkpoint = ModelCheckpoint( 10 | every_n_epochs=every_n_epochs, 11 | save_top_k=save_top_k, 12 | monitor="epoch", 13 | mode="max", 
14 | dirpath=expt_path, 15 | filename="{epoch:04d}", 16 | ) 17 | return epoch_checkpoint 18 | 19 | def get_latest_checkpoint( 20 | expt_path, 21 | every_n_epochs=1 22 | ): 23 | latest_checkpoint = ModelCheckpoint( 24 | every_n_epochs=every_n_epochs, 25 | save_top_k=1, 26 | monitor="epoch", 27 | mode="max", 28 | dirpath=expt_path, 29 | filename="latest", 30 | ) 31 | return latest_checkpoint 32 | 33 | class CheckpointEveryNSteps(Callback): 34 | """ 35 | from https://github.com/Lightning-AI/lightning/issues/2534 36 | Save a checkpoint every N steps, instead of Lightning's default that checkpoints 37 | based on validation loss. 38 | """ 39 | def __init__( 40 | self, 41 | expt_path, 42 | save_step_at, 43 | prefix="N-Step-Checkpoint", 44 | use_modelcheckpoint_filename=False, 45 | ): 46 | """ 47 | Args: 48 | save_step_frequency: how often to save in steps 49 | prefix: add a prefix to the name, only used if 50 | use_modelcheckpoint_filename=False 51 | use_modelcheckpoint_filename: just use the ModelCheckpoint callback's 52 | default filename, don't use ours. 53 | """ 54 | self.expt_path = expt_path 55 | self.save_step_at = save_step_at 56 | self.prefix = prefix 57 | self.use_modelcheckpoint_filename = use_modelcheckpoint_filename 58 | self.saved_steps = [] 59 | 60 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 61 | """ Check if we should save a checkpoint after every train batch """ 62 | epoch = trainer.current_epoch 63 | global_step = trainer.global_step 64 | rank = pl_module.global_rank 65 | if rank == 0: 66 | print(global_step) 67 | if (global_step in self.save_step_at) and (global_step not in self.saved_steps) and (rank == 0): 68 | if self.use_modelcheckpoint_filename: 69 | filename = trainer.checkpoint_callback.filename 70 | else: 71 | filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt" 72 | ckpt_path = os.path.join(self.expt_path, filename) 73 | trainer.save_checkpoint(ckpt_path) 74 | self.saved_steps.append(global_step) 75 | 76 | if global_step == self.save_step_at[-1]: 77 | # training is done 78 | raise 79 | 80 | def get_iteration_checkpoint( 81 | expt_path, 82 | ): 83 | print("INFO: Add iteration callbacks") 84 | return CheckpointEveryNSteps( 85 | expt_path = expt_path, 86 | save_step_at=[100, 200, 500, 1000, 2000] 87 | ) -------------------------------------------------------------------------------- /callbacks/coco_layout/wandb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from ..wandb import WandBImageLogger -------------------------------------------------------------------------------- /callbacks/sampling_save_fig.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import cv2 5 | import numpy as np 6 | from pytorch_lightning.callbacks import Callback 7 | 8 | from .utils import unnorm, clip_image 9 | 10 | def format_dtype_and_shape(x): 11 | if isinstance(x, torch.Tensor): 12 | if len(x.shape) == 3 and x.shape[0] == 3: 13 | x = x.permute(1, 2, 0) 14 | if len(x.shape) == 4 and x.shape[1] == 3: 15 | x = x.permute(0, 2, 3, 1) 16 | x = x.detach().cpu().numpy() 17 | return x 18 | 19 | def save_figure(image, save_path): 20 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 21 | if image.min() < 0: 22 | image = clip_image(unnorm(image)) 23 | image = format_dtype_and_shape(image) 24 | image = (image * 255).astype(np.uint8) 25 | cv2.imwrite(save_path, image[..., ::-1]) 26 | 27 | def 
save_sampling_history(image, save_path): 28 | if image.min() < 0: 29 | image = clip_image(unnorm(image)) 30 | grid_img = torchvision.utils.make_grid(image, nrow=4) 31 | save_figure(grid_img, save_path) 32 | 33 | class BasicImageSavingCallback(Callback): 34 | def __init__(self, expt_path, start_idx=0): 35 | self.expt_path = expt_path 36 | self.current_idx = start_idx 37 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 38 | rank = pl_module.global_rank 39 | current_epoch = pl_module.current_epoch 40 | y_0_hat = outputs['sampling']['model_output'] 41 | y_t_hist = outputs['sampling']['model_history_output'] 42 | for image, hist_image in zip(y_0_hat, y_t_hist): 43 | save_figure( 44 | image, 45 | save_path=os.path.join( 46 | self.expt_path, 47 | f'sampling_at_{current_epoch:05d}', 48 | f'{rank:02d}_{self.current_idx:05d}.png') 49 | ) 50 | save_sampling_history( 51 | hist_image, 52 | save_path=os.path.join( 53 | self.expt_path, 54 | f'sampling_hist_at_{current_epoch:05d}', 55 | f'{rank:02d}_{self.current_idx:05d}.png') 56 | ) 57 | self.current_idx += 1 58 | -------------------------------------------------------------------------------- /callbacks/schedule_sampler.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | from DDIM.schedule_sampler import LossSecondMomentResampler 3 | 4 | '''May use in the future, now directly change schedule sampler in the DDIM.py''' 5 | class ScheduleSamplerCallback(Callback): 6 | '''change the sample weight of different timestamps''' 7 | def on_train_batch_end( 8 | self, trainer, pl_module, 9 | outputs, batch, batch_idx 10 | ): 11 | if isinstance(pl_module.schedule_sampler, LossSecondMomentResampler): 12 | t = outputs['t'] 13 | loss_flat = outputs['loss_flat'] 14 | pl_module.schedule_sampler.update_with_all_losses( 15 | t, loss_flat.detach() 16 | ) -------------------------------------------------------------------------------- /callbacks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.face_parsing import MaskMeshConverter, celebAMask_labels 3 | 4 | def unnorm(x): 5 | '''convert from range [-1, 1] to [0, 1]''' 6 | return (x+1) / 2 7 | 8 | def clip_image(x, min=0., max=1.): 9 | return torch.clamp(x, min=min, max=max) 10 | 11 | def colorize_mask(x, input_is_float=True): 12 | assert len(x.shape) == 4 or len(x.shape) == 5, f'mask should have shape 4 or 5, got {x.shape}' 13 | # input: b, 1, h, w or b, time, 1, h, w 14 | if len(x.shape) >= 4: 15 | x = x.squeeze(-3) # 16 | if input_is_float: 17 | x = (256*x).cpu().to(torch.long) # why * 255 is not correct??? 
18 | mask_cvt = MaskMeshConverter( 19 | labels = list(celebAMask_labels.keys()), 20 | mesh_dim=3 21 | ) 22 | x = mask_cvt(x) # b, h, w, 3 or b, time, h, w, 3 23 | if len(x.shape) == 4: 24 | return x.permute(0, 3, 1, 2) # b, 3, h, w 25 | elif len(x.shape) == 5: 26 | return x.permute(0, 1, 4, 2, 3) # b, time, 3, h, w 27 | else: 28 | raise RuntimeError(f'Unknown dim, mask shape is {x.shape}') 29 | 30 | def to_mask_if_dim_gt_3(x, dim=1, keepdim=True, colorize=True): 31 | # TODO, now function also accept 3D mask (b, h, w), give it a new name 32 | ''' 33 | valid x shape: (b, h, w), (b, 1, h, w), (b, 3, h, w), (b, num class, h, w) 34 | (b, t, 1, h, w), (b, t, 3, h, w), (b, t, num class, h, w) 35 | colorize_mask valid input shapes are (b, 1, h, w) and (b, t, 1, h, w) 36 | ''' 37 | if len(x.shape) == 3: 38 | '''handle (b, h, w)''' 39 | x = x.unsqueeze(1) 40 | if x.shape[dim] > 3: 41 | '''handle (b, num class, h, w) and (b, t, num class, h, w)''' 42 | x = torch.argmax(x, dim=dim, keepdim=keepdim) 43 | if colorize: 44 | if x.shape[-3] != 1: 45 | print(f'WARNING: mask has shape {x.shape}, will not apply colorization since the dim[-3] != 1') 46 | else: 47 | x = colorize_mask(x, input_is_float=False) 48 | return x 49 | -------------------------------------------------------------------------------- /callbacks/wandb.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | import torchvision 4 | from pytorch_lightning.loggers import WandbLogger 5 | from pytorch_lightning.callbacks import Callback 6 | from .utils import unnorm, clip_image 7 | 8 | class WandBImageLogger(Callback): 9 | def __init__( 10 | self, 11 | wandb_logger: WandbLogger=None, 12 | max_num_images: int=16, 13 | ) -> None: 14 | super().__init__() 15 | self.wandb_logger = wandb_logger 16 | self.max_num_images = max_num_images 17 | 18 | # TODO move following two functions to utils.py 19 | def tensor2numpy(self, x): 20 | x = x.float() # handle bf16 21 | '''convert 4D (b, dim, h, w) pytorch tensor to numpy (b, h, w, dim) 22 | or convert 3D (dim, h, w) pytorch tensor to numpy (h, w, dim)''' 23 | if len(x.shape) == 4: 24 | return x.permute(0, 2, 3, 1).detach().cpu().numpy() 25 | else: 26 | return x.permute(1, 2, 0).detach().cpu().numpy() 27 | 28 | def tensor2image(self, x): 29 | x = x.float() # handle bf16 30 | '''convert 4D (b, dim, h, w) pytorch tensor to wandb Image class''' 31 | grid_img = torchvision.utils.make_grid( 32 | x, nrow=4 33 | ).permute(1, 2, 0).detach().cpu().numpy() 34 | img = wandb.Image( 35 | grid_img 36 | ) 37 | return img 38 | 39 | def on_train_batch_end( 40 | self, trainer, pl_module, 41 | outputs, batch, batch_idx 42 | ): 43 | # record images in first batch 44 | if isinstance(outputs, list): 45 | print(outputs) 46 | raise 47 | if batch_idx == 0: 48 | raw_image = self.tensor2image(clip_image(unnorm( 49 | outputs['raw_image'][:self.max_num_images] 50 | ))) 51 | model_input = self.tensor2image(clip_image(unnorm( 52 | outputs['model_input'][:self.max_num_images] 53 | ))) 54 | model_output = self.tensor2image(clip_image(unnorm( 55 | outputs['model_output'][:self.max_num_images] 56 | ))) 57 | y_0_hat = self.tensor2image(clip_image(unnorm( 58 | outputs['y_0_hat'][:self.max_num_images] 59 | ))) 60 | self.wandb_logger.experiment.log({ 61 | 'train/raw_image': raw_image, 62 | 'train/model_input': model_input, 63 | 'train/model_output': model_output, 64 | 'train/y_0_hat': y_0_hat 65 | }) 66 | 67 | def on_validation_batch_end( 68 | self, trainer, pl_module, 69 | outputs, 
batch, batch_idx, dataloader_idx 70 | ): 71 | """Called when the validation batch ends.""" 72 | if batch_idx == 0: 73 | y_0 = self.tensor2image(clip_image(unnorm( 74 | outputs['y_0_image'][:self.max_num_images] 75 | ))) 76 | self.wandb_logger.experiment.log({ 77 | f'validation/raw_image': y_0, 78 | }) 79 | outputs.pop('y_0_image') 80 | '''outputs has result of restoration and sampling''' 81 | for output_type, this_outputs in outputs.items(): 82 | y_0_hat = self.tensor2image(clip_image(unnorm( 83 | this_outputs['model_output'][:self.max_num_images] 84 | ))) 85 | y_t_hist = unnorm( 86 | this_outputs['model_history_output'][:self.max_num_images] 87 | ) # bs, time step, im_dim, im_h, im_w 88 | 89 | self.wandb_logger.experiment.log({ 90 | f'validation/{output_type}': y_0_hat, 91 | }) 92 | 93 | '''log predictions as a Table''' 94 | columns = [f't={t}' for t in reversed(range(y_t_hist.shape[1]))] 95 | data = [] 96 | for this_y_t_hist in y_t_hist: 97 | this_Images = [ 98 | wandb.Image(self.tensor2numpy(i)) for i in this_y_t_hist 99 | ] 100 | data.append(this_Images) 101 | self.wandb_logger.log_table( 102 | key=f'validation_table/{output_type}', 103 | columns=columns, data=data 104 | ) 105 | 106 | class WandBVAEImageLogger(Callback): 107 | def __init__( 108 | self, 109 | wandb_logger: WandbLogger=None, 110 | max_num_images: int=16, 111 | ) -> None: 112 | super().__init__() 113 | self.wandb_logger = wandb_logger 114 | self.max_num_images = max_num_images 115 | 116 | def tensor2image(self, x): 117 | '''convert 4D (b, dim, h, w) pytorch tensor to wandb Image class''' 118 | grid_img = torchvision.utils.make_grid( 119 | x, nrow=4 120 | ).permute(1, 2, 0).detach().cpu().numpy() 121 | img = wandb.Image( 122 | grid_img 123 | ) 124 | return img 125 | 126 | def on_train_batch_end( 127 | self, trainer, pl_module, 128 | outputs, batch, batch_idx 129 | ): 130 | if batch_idx == 0: 131 | raw_image = self.tensor2image(clip_image(unnorm( 132 | outputs['raw_image'][:self.max_num_images] 133 | ))) 134 | model_output = self.tensor2image(clip_image(unnorm( 135 | outputs['model_output'][:self.max_num_images] 136 | ))) 137 | self.wandb_logger.experiment.log({ 138 | 'train/raw_image': raw_image, 139 | 'train/model_output': model_output, 140 | }) 141 | 142 | def on_validation_batch_end( 143 | self, trainer, pl_module, 144 | outputs, batch, batch_idx, dataloader_idx 145 | ): 146 | """Called when the validation batch ends.""" 147 | if batch_idx == 0: 148 | x = self.tensor2image(clip_image(unnorm( 149 | outputs['raw_image'][:self.max_num_images] 150 | ))) 151 | self.wandb_logger.experiment.log({ 152 | f'validation/raw_image': x, 153 | }) 154 | o = self.tensor2image(clip_image(unnorm( 155 | outputs['model_output'][:self.max_num_images] 156 | ))) 157 | self.wandb_logger.experiment.log({ 158 | f'validation/model_output': o, 159 | }) -------------------------------------------------------------------------------- /configs/celeb_mask.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "celeb_mask_LayoutDiffuse", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 200, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "diffusion": { 14 | "model": "DDIM_ldm.DDIM_ldm_celeb.DDIM_LDM_LayoutDiffuse_celeb_mask", 15 | "model_args": { 16 | "loss_fn": "mse", 17 | "training_target": "noise", 18 | "beta_schedule_args": { 19 | 
"schedule": "linear", 20 | "n_timestep": 1000, 21 | "linear_start": 0.0015, 22 | "linear_end": 0.0195 23 | }, 24 | "optim_args" :{ 25 | "lr": 5e-5, 26 | "weight_decay": 0 27 | }, 28 | "unet_init_weights": "pretrained_models/celeba256/unet.ckpt", 29 | "vqvae_init_weights": "pretrained_models/celeba256/vqvae.ckpt", 30 | "freeze_pretrained_weights": false, 31 | "use_fast_sampling": true, 32 | "fast_sampler": "plms", 33 | "fast_sampling_steps": 100, 34 | "clip_denoised": false 35 | } 36 | }, 37 | "denoising_model": { 38 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 39 | "model_args": { 40 | "image_size": 64, 41 | "in_channels": 3, 42 | "model_channels": 224, 43 | "out_channels": 3, 44 | "num_res_blocks": 2, 45 | "attention_resolutions": [2, 4, 8], 46 | "channel_mult": [1, 2, 3, 4], 47 | "num_head_channels": 32, 48 | "use_checkpoint": false, 49 | "prompt_dim": 128, 50 | "num_prompt": 64, 51 | "instance_prompt_attn_type": "segmentation", 52 | "instance_attn_res": [2, 4], 53 | "instance_prompt_args": { 54 | "num_classes": 19, 55 | "embedding_dim": 128 56 | }, 57 | "verbose": true 58 | } 59 | }, 60 | "vqvae_model": { 61 | "model": "modules.vqvae.autoencoder.VQModelInterface", 62 | "model_args": { 63 | "embed_dim": 3, 64 | "n_embed": 8192, 65 | "ddconfig": { 66 | "double_z": false, 67 | "z_channels": 3, 68 | "resolution": 256, 69 | "in_channels": 3, 70 | "out_ch": 3, 71 | "ch": 128, 72 | "ch_mult": [1, 2, 4], 73 | "num_res_blocks": 2, 74 | "attn_resolutions": [], 75 | "dropout": 0.0 76 | }, 77 | "lossconfig": { 78 | "target": "torch.nn.Identity" 79 | } 80 | } 81 | }, 82 | "data": { 83 | "dataset": "celeb_mask", 84 | "root": "/home/ubuntu/disk2/data/face/CelebAMask-HQ", 85 | "image_size": 256, 86 | "down_resolutions": [1,2,4,8,16], 87 | "train_args": { 88 | "split": "train", 89 | "data_len": -1 90 | }, 91 | "val_args": { 92 | "split": "val", 93 | "data_len": 4 94 | }, 95 | "batch_size": 1, 96 | "val_batch_size": 1 97 | }, 98 | "save_model_config": { 99 | "every_n_epochs": 5, 100 | "save_top_k": 5 101 | }, 102 | "sampling_args": { 103 | "sampling_w_noise": false, 104 | "image_size": 64, 105 | "in_channel": 3, 106 | "num_samples": -1, 107 | "callbacks": [ 108 | "callbacks.celeb_mask.sampling_save_fig.CelebMaskPartialAttnImageSavingCallback" 109 | ] 110 | } 111 | } -------------------------------------------------------------------------------- /configs/cocostuff.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "cocostuff_LayoutDiffuse", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 1000, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3,4,5,6,7], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_Text", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 3e-5, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/LAION_text2img/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/LAION_text2img/vqvae.ckpt", 33 | "text_model_init_weights": "pretrained_models/LAION_text2img/bert.ckpt", 34 | "freeze_pretrained_weights": false, 35 | "use_fast_sampling": true, 36 | 
"fast_sampling_steps": 20, 37 | "fast_sampler": "plms", 38 | "guidance_scale": 5.0, 39 | "clip_denoised": false, 40 | "scale_factor": 0.18215 41 | } 42 | }, 43 | "denoising_model": { 44 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 45 | "model_args": { 46 | "image_size": 32, 47 | "in_channels": 4, 48 | "model_channels": 320, 49 | "out_channels": 4, 50 | "num_res_blocks": 2, 51 | "attention_resolutions": [1, 2, 4], 52 | "channel_mult": [1, 2, 4, 4], 53 | "num_heads": 8, 54 | "use_spatial_transformer": true, 55 | "transformer_depth": 1, 56 | "use_checkpoint": true, 57 | "legacy": false, 58 | "prompt_dim": 128, 59 | "num_prompt": 64, 60 | "image_in_kv": true, 61 | "text_context_dim": 1280, 62 | "instance_prompt_attn_type": "layout_partial_v2", 63 | "instance_attn_res": [1, 2], 64 | "instance_prompt_args": { 65 | "num_classes": 181, 66 | "embedding_dim": 128 67 | }, 68 | "verbose": true 69 | } 70 | }, 71 | "vqvae_model": { 72 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 73 | "model_args": { 74 | "embed_dim": 4, 75 | "ddconfig": { 76 | "double_z": true, 77 | "z_channels": 4, 78 | "resolution": 256, 79 | "in_channels": 3, 80 | "out_ch": 3, 81 | "ch": 128, 82 | "ch_mult": [1, 2, 4, 4], 83 | "num_res_blocks": 2, 84 | "attn_resolutions": [], 85 | "dropout": 0.0 86 | }, 87 | "lossconfig": { 88 | "target": "torch.nn.Identity" 89 | } 90 | } 91 | }, 92 | "text_model": { 93 | "model": "modules.bert.bert_embedder.BERTEmbedder", 94 | "model_args": { 95 | "n_embed": 1280, 96 | "n_layer": 32 97 | } 98 | }, 99 | "data": { 100 | "dataset": "coco_stuff_layout_caption_label", 101 | "root": "/home/ubuntu/disk2/data/COCO", 102 | "image_size": 256, 103 | "dataset_args": { 104 | "train_empty_string": 0, 105 | "val_empty_string": 0 106 | }, 107 | "train_args": { 108 | "split": "train", 109 | "data_len": -1 110 | }, 111 | "val_args": { 112 | "split": "val", 113 | "data_len": 1 114 | }, 115 | "batch_size": 1, 116 | "val_batch_size": 1 117 | }, 118 | "sampling_args": { 119 | "sampling_w_noise": false, 120 | "image_size": 32, 121 | "in_channel": 4, 122 | "num_samples": -1, 123 | "callbacks": [ 124 | "callbacks.coco_layout.sampling_save_fig.COCOLayoutImageSavingCallback" 125 | ] 126 | } 127 | } -------------------------------------------------------------------------------- /configs/cocostuff_SD1_5.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "cocostuff_LayoutDiffuse_SD1_5", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 1000, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_Text", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 1e-4, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/SD1_5/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/SD1_5/vqvae.ckpt", 33 | "text_model_init_weights": "pretrained_models/SD1_5/clip.ckpt", 34 | "freeze_pretrained_weights": true, 35 | "use_fast_sampling": true, 36 | "fast_sampling_steps": 20, 37 | "fast_sampler": "plms", 38 | "guidance_scale": 7.5, 39 | 
"clip_denoised": false, 40 | "scale_factor": 0.18215 41 | } 42 | }, 43 | "denoising_model": { 44 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 45 | "model_args": { 46 | "image_size": 32, 47 | "in_channels": 4, 48 | "model_channels": 320, 49 | "out_channels": 4, 50 | "num_res_blocks": 2, 51 | "attention_resolutions": [1, 2, 4], 52 | "channel_mult": [1, 2, 4, 4], 53 | "num_heads": 8, 54 | "use_spatial_transformer": true, 55 | "transformer_depth": 1, 56 | "use_checkpoint": true, 57 | "legacy": false, 58 | "prompt_dim": 128, 59 | "num_prompt": 0, 60 | "image_in_kv": false, 61 | "text_context_dim": 768, 62 | "instance_prompt_attn_type": "layout_partial_v2", 63 | "instance_attn_res": [2,4], 64 | "instance_prompt_args": { 65 | "num_classes": 181, 66 | "embedding_dim": 128 67 | }, 68 | "verbose": true 69 | } 70 | }, 71 | "vqvae_model": { 72 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 73 | "model_args": { 74 | "embed_dim": 4, 75 | "ddconfig": { 76 | "double_z": true, 77 | "z_channels": 4, 78 | "resolution": 256, 79 | "in_channels": 3, 80 | "out_ch": 3, 81 | "ch": 128, 82 | "ch_mult": [1, 2, 4, 4], 83 | "num_res_blocks": 2, 84 | "attn_resolutions": [], 85 | "dropout": 0.0 86 | }, 87 | "lossconfig": { 88 | "target": "torch.nn.Identity" 89 | } 90 | } 91 | }, 92 | "text_model": { 93 | "model": "modules.openclip.modules.FrozenCLIPEmbedder", 94 | "model_args": { 95 | "freeze": true 96 | } 97 | }, 98 | "data": { 99 | "dataset": "coco_stuff_layout_caption_label", 100 | "root": "/home/ubuntu/disk2/data/COCO", 101 | "image_size": 512, 102 | "dataset_args": { 103 | "train_empty_string": 0, 104 | "val_empty_string": 0 105 | }, 106 | "train_args": { 107 | "split": "train", 108 | "data_len": -1 109 | }, 110 | "val_args": { 111 | "split": "val", 112 | "data_len": 1 113 | }, 114 | "batch_size": 1, 115 | "val_batch_size": 1 116 | }, 117 | "sampling_args": { 118 | "sampling_w_noise": false, 119 | "image_size": 64, 120 | "in_channel": 4, 121 | "num_samples": -1, 122 | "callbacks": [ 123 | "callbacks.coco_layout.sampling_save_fig.COCOLayoutImageSavingCallback" 124 | ] 125 | } 126 | } -------------------------------------------------------------------------------- /configs/cocostuff_SD1_5_merge_model.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "cocostuff_LayoutDiffuse_SD1_5", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 1000, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_Text_CKPT_Merge", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 1e-4, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/SD1_5/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/SD1_5/vqvae.ckpt", 33 | "text_model_init_weights": "pretrained_models/SD1_5/clip.ckpt", 34 | "freeze_pretrained_weights": true, 35 | "use_fast_sampling": true, 36 | "fast_sampling_steps": 20, 37 | "fast_sampler": "plms", 38 | "guidance_scale": 7.5, 39 | "clip_denoised": false, 40 | "scale_factor": 0.18215 41 | } 42 | }, 43 | "denoising_model": { 
44 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 45 | "model_args": { 46 | "image_size": 32, 47 | "in_channels": 4, 48 | "model_channels": 320, 49 | "out_channels": 4, 50 | "num_res_blocks": 2, 51 | "attention_resolutions": [1, 2, 4], 52 | "channel_mult": [1, 2, 4, 4], 53 | "num_heads": 8, 54 | "use_spatial_transformer": true, 55 | "transformer_depth": 1, 56 | "use_checkpoint": true, 57 | "legacy": false, 58 | "prompt_dim": 128, 59 | "num_prompt": 0, 60 | "image_in_kv": false, 61 | "text_context_dim": 768, 62 | "instance_prompt_attn_type": "layout_partial_v2", 63 | "instance_attn_res": [2,4], 64 | "instance_prompt_args": { 65 | "num_classes": 181, 66 | "embedding_dim": 128 67 | }, 68 | "verbose": true 69 | } 70 | }, 71 | "vqvae_model": { 72 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 73 | "model_args": { 74 | "embed_dim": 4, 75 | "ddconfig": { 76 | "double_z": true, 77 | "z_channels": 4, 78 | "resolution": 256, 79 | "in_channels": 3, 80 | "out_ch": 3, 81 | "ch": 128, 82 | "ch_mult": [1, 2, 4, 4], 83 | "num_res_blocks": 2, 84 | "attn_resolutions": [], 85 | "dropout": 0.0 86 | }, 87 | "lossconfig": { 88 | "target": "torch.nn.Identity" 89 | } 90 | } 91 | }, 92 | "text_model": { 93 | "model": "modules.openclip.modules.FrozenCLIPEmbedder", 94 | "model_args": { 95 | "freeze": true 96 | } 97 | }, 98 | "data": { 99 | "dataset": "coco_stuff_layout_caption_label", 100 | "root": "/home/ubuntu/disk2/data/COCO", 101 | "image_size": 512, 102 | "dataset_args": { 103 | "train_empty_string": 0, 104 | "val_empty_string": 0 105 | }, 106 | "train_args": { 107 | "split": "train", 108 | "data_len": -1 109 | }, 110 | "val_args": { 111 | "split": "val", 112 | "data_len": 1 113 | }, 114 | "batch_size": 1, 115 | "val_batch_size": 1 116 | }, 117 | "sampling_args": { 118 | "sampling_w_noise": false, 119 | "image_size": 64, 120 | "in_channel": 4, 121 | "num_samples": -1, 122 | "callbacks": [ 123 | "callbacks.coco_layout.sampling_save_fig.COCOLayoutImageSavingCallback" 124 | ] 125 | } 126 | } -------------------------------------------------------------------------------- /configs/cocostuff_SD2_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "cocostuff_LayoutDiffuse_SD2_1", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 1000, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_Text", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 3e-5, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/SD2_1/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/SD2_1/vqvae.ckpt", 33 | "text_model_init_weights": "pretrained_models/SD2_1/clip.ckpt", 34 | "freeze_pretrained_weights": true, 35 | "use_fast_sampling": true, 36 | "fast_sampling_steps": 20, 37 | "fast_sampler": "plms", 38 | "guidance_scale": 5, 39 | "clip_denoised": false, 40 | "scale_factor": 0.18215 41 | } 42 | }, 43 | "denoising_model": { 44 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 45 | "model_args": { 46 | "image_size": 32, 
47 | "in_channels": 4, 48 | "model_channels": 320, 49 | "out_channels": 4, 50 | "num_res_blocks": 2, 51 | "attention_resolutions": [1, 2, 4], 52 | "channel_mult": [1, 2, 4, 4], 53 | "num_head_channels": 64, 54 | "use_spatial_transformer": true, 55 | "transformer_depth": 1, 56 | "use_checkpoint": true, 57 | "legacy": false, 58 | "prompt_dim": 128, 59 | "num_prompt": 0, 60 | "image_in_kv": false, 61 | "text_context_dim": 1024, 62 | "instance_prompt_attn_type": "layout_partial_v2", 63 | "instance_attn_res": [2,4], 64 | "instance_prompt_args": { 65 | "num_classes": 181, 66 | "embedding_dim": 128 67 | }, 68 | "verbose": true 69 | } 70 | }, 71 | "vqvae_model": { 72 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 73 | "model_args": { 74 | "embed_dim": 4, 75 | "ddconfig": { 76 | "double_z": true, 77 | "z_channels": 4, 78 | "resolution": 256, 79 | "in_channels": 3, 80 | "out_ch": 3, 81 | "ch": 128, 82 | "ch_mult": [1, 2, 4, 4], 83 | "num_res_blocks": 2, 84 | "attn_resolutions": [], 85 | "dropout": 0.0 86 | }, 87 | "lossconfig": { 88 | "target": "torch.nn.Identity" 89 | } 90 | } 91 | }, 92 | "text_model": { 93 | "model": "modules.openclip.modules.FrozenOpenCLIPEmbedder", 94 | "model_args": { 95 | "freeze": true, 96 | "layer": "penultimate" 97 | } 98 | }, 99 | "data": { 100 | "dataset": "coco_stuff_layout_caption_label", 101 | "root": "/home/ubuntu/disk2/data/COCO", 102 | "image_size": 512, 103 | "dataset_args": { 104 | "train_empty_string": 0, 105 | "val_empty_string": 0 106 | }, 107 | "train_args": { 108 | "split": "train", 109 | "data_len": -1 110 | }, 111 | "val_args": { 112 | "split": "val", 113 | "data_len": 1 114 | }, 115 | "batch_size": 1, 116 | "val_batch_size": 1 117 | }, 118 | "sampling_args": { 119 | "sampling_w_noise": false, 120 | "image_size": 64, 121 | "in_channel": 4, 122 | "num_samples": -1, 123 | "callbacks": [ 124 | "callbacks.coco_layout.sampling_save_fig.COCOLayoutImageSavingCallback" 125 | ] 126 | } 127 | } -------------------------------------------------------------------------------- /configs/cocostuff_no_text.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "cocostuff_no_text_LayoutDiffuse", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 60, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3,4,5,6,7], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_pretrained_COCO_instance_prompt", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 3e-5, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/LAION_text2img/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/LAION_text2img/vqvae.ckpt", 33 | "freeze_pretrained_weights": false, 34 | "use_fast_sampling": true, 35 | "fast_sampler": "plms", 36 | "fast_sampling_steps": 100, 37 | "clip_denoised": false, 38 | "scale_factor": 0.18215 39 | } 40 | }, 41 | "denoising_model": { 42 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 43 | "model_args": { 44 | "image_size": 32, 45 | "in_channels": 4, 46 | "model_channels": 320, 47 | "out_channels": 4, 48 | "num_res_blocks": 2, 49 | 
"attention_resolutions": [1, 2, 4], 50 | "channel_mult": [1, 2, 4, 4], 51 | "num_heads": 8, 52 | "use_spatial_transformer": true, 53 | "transformer_depth": 1, 54 | "use_checkpoint": true, 55 | "legacy": false, 56 | "prompt_dim": 128, 57 | "num_prompt": 64, 58 | "instance_prompt_attn_type": "layout_partial_v2", 59 | "instance_attn_res": [1, 2], 60 | "instance_prompt_args": { 61 | "num_classes": 181, 62 | "embedding_dim": 128 63 | }, 64 | "verbose": true 65 | } 66 | }, 67 | "vqvae_model": { 68 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 69 | "model_args": { 70 | "embed_dim": 4, 71 | "ddconfig": { 72 | "double_z": true, 73 | "z_channels": 4, 74 | "resolution": 256, 75 | "in_channels": 3, 76 | "out_ch": 3, 77 | "ch": 128, 78 | "ch_mult": [1, 2, 4, 4], 79 | "num_res_blocks": 2, 80 | "attn_resolutions": [], 81 | "dropout": 0.0 82 | }, 83 | "lossconfig": { 84 | "target": "torch.nn.Identity" 85 | } 86 | } 87 | }, 88 | "data": { 89 | "dataset": "coco_stuff_layout", 90 | "root": "/home/ubuntu/disk2/data/COCO", 91 | "image_size": 256, 92 | "train_args": { 93 | "split": "train", 94 | "data_len": -1 95 | }, 96 | "val_args": { 97 | "split": "val", 98 | "data_len": 1 99 | }, 100 | "batch_size": 1, 101 | "val_batch_size": 1 102 | }, 103 | "sampling_args": { 104 | "sampling_w_noise": false, 105 | "image_size": 32, 106 | "in_channel": 4, 107 | "num_samples": -1, 108 | "callbacks": [ 109 | "callbacks.coco_layout.sampling_save_fig.COCOLayoutImageSavingCallback" 110 | ] 111 | } 112 | } -------------------------------------------------------------------------------- /configs/vg.json: -------------------------------------------------------------------------------- 1 | { 2 | "expt_name": "vg_LayoutDiffuse", 3 | "expt_dir": "experiments", 4 | "trainer_args": { 5 | "max_epochs": 150, 6 | "accelerator": "gpu", 7 | "devices": [0,1,2,3], 8 | "limit_val_batches": 1, 9 | "strategy": "ddp", 10 | "accumulate_grad_batches": 32, 11 | "check_val_every_n_epoch": 1 12 | }, 13 | "callbacks": [ 14 | "callbacks.WandBImageLogger" 15 | ], 16 | "diffusion": { 17 | "model": "DDIM_ldm.DDIM_ldm_coco.DDIM_LDM_LAION_Text", 18 | "model_args": { 19 | "loss_fn": "mse", 20 | "training_target": "noise", 21 | "beta_schedule_args": { 22 | "schedule": "linear", 23 | "n_timestep": 1000, 24 | "linear_start": 0.00085, 25 | "linear_end": 0.012 26 | }, 27 | "optim_args": { 28 | "lr": 2e-6, 29 | "weight_decay": 0 30 | }, 31 | "unet_init_weights": "pretrained_models/LAION_text2img/unet.ckpt", 32 | "vqvae_init_weights": "pretrained_models/LAION_text2img/vqvae.ckpt", 33 | "text_model_init_weights": "pretrained_models/LAION_text2img/bert.ckpt", 34 | "freeze_pretrained_weights": false, 35 | "use_fast_sampling": true, 36 | "fast_sampling_steps": 100, 37 | "fast_sampler": "plms", 38 | "clip_denoised": false, 39 | "scale_factor": 0.18215 40 | } 41 | }, 42 | "denoising_model": { 43 | "model": "modules.openai_unet.openaimodel_layout_diffuse.UNetModel", 44 | "model_args": { 45 | "image_size": 32, 46 | "in_channels": 4, 47 | "model_channels": 320, 48 | "out_channels": 4, 49 | "num_res_blocks": 2, 50 | "attention_resolutions": [1, 2, 4], 51 | "channel_mult": [1, 2, 4, 4], 52 | "num_heads": 8, 53 | "use_spatial_transformer": true, 54 | "transformer_depth": 1, 55 | "use_checkpoint": true, 56 | "legacy": false, 57 | "prompt_dim": 128, 58 | "num_prompt": 64, 59 | "image_in_kv": true, 60 | "text_context_dim": 1280, 61 | "instance_prompt_attn_type": "layout_partial_v2", 62 | "instance_attn_res": [1, 2], 63 | "instance_prompt_args": { 64 | 
"num_classes": 181, 65 | "embedding_dim": 128 66 | }, 67 | "verbose": true 68 | } 69 | }, 70 | "vqvae_model": { 71 | "model": "modules.kl_autoencoder.autoencoder.AutoencoderKL", 72 | "model_args": { 73 | "embed_dim": 4, 74 | "ddconfig": { 75 | "double_z": true, 76 | "z_channels": 4, 77 | "resolution": 256, 78 | "in_channels": 3, 79 | "out_ch": 3, 80 | "ch": 128, 81 | "ch_mult": [1, 2, 4, 4], 82 | "num_res_blocks": 2, 83 | "attn_resolutions": [], 84 | "dropout": 0.0 85 | }, 86 | "lossconfig": { 87 | "target": "torch.nn.Identity" 88 | } 89 | } 90 | }, 91 | "text_model": { 92 | "model": "modules.bert.bert_embedder.BERTEmbedder", 93 | "model_args": { 94 | "n_embed": 1280, 95 | "n_layer": 32 96 | } 97 | }, 98 | "data": { 99 | "dataset": "vg_layout_label", 100 | "root": "/home/ubuntu/disk2/data/VG", 101 | "image_size": 256, 102 | "train_args": { 103 | "split": "train", 104 | "data_len": -1 105 | }, 106 | "val_args": { 107 | "split": "val", 108 | "data_len": 1 109 | }, 110 | "batch_size": 1, 111 | "val_batch_size": 1 112 | }, 113 | "sampling_args": { 114 | "sampling_w_noise": false, 115 | "image_size": 32, 116 | "in_channel": 4, 117 | "num_samples": -1, 118 | "callbacks": [ 119 | "callbacks.coco_layout.sampling_save_fig.VGLayoutImageSavingCallback" 120 | ] 121 | } 122 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from model_utils import default 2 | 3 | def get_dataset(**kwargs): 4 | dataset = kwargs['dataset'] 5 | if dataset in ['celeb_mask']: 6 | from .face_parsing import CelebAMaskHQ, get_train_transform, get_test_transform 7 | root = kwargs['root'] 8 | image_size = kwargs['image_size'] 9 | train_set = CelebAMaskHQ( 10 | root, 11 | dual_transforms=get_train_transform(image_size), 12 | **kwargs['train_args'] 13 | ) 14 | val_set = CelebAMaskHQ( 15 | root, 16 | dual_transforms=get_test_transform(image_size), 17 | **kwargs['val_args'] 18 | ) 19 | elif dataset == 'coco_stuff_layout': 20 | from .coco_w_stuff import get_cocostuff_dataset 21 | root = kwargs['root'] 22 | image_size = kwargs['image_size'] 23 | train_set, val_set = get_cocostuff_dataset( 24 | root, image_size 25 | ) 26 | elif dataset == 'coco_stuff_layout_caption': 27 | from .coco_w_stuff import get_cocostuff_caption_dataset 28 | root = kwargs['root'] 29 | image_size = kwargs['image_size'] 30 | train_set, val_set = get_cocostuff_caption_dataset( 31 | root, image_size, **kwargs['dataset_args'] 32 | ) 33 | elif dataset == 'coco_stuff_layout_caption_label': 34 | from .coco_w_stuff import get_cocostuff_caption_label_dataset 35 | root = kwargs['root'] 36 | image_size = kwargs['image_size'] 37 | train_set, val_set = get_cocostuff_caption_label_dataset( 38 | root, image_size, **kwargs['dataset_args'] 39 | ) 40 | elif dataset == 'vg_layout_label': 41 | from .vg import get_vg_caption_dataset 42 | root = kwargs['root'] 43 | image_size = kwargs['image_size'] 44 | train_set, val_set = get_vg_caption_dataset( 45 | root, image_size 46 | ) 47 | else: 48 | raise NotImplementedError(f'got {dataset}') 49 | 50 | return train_set, val_set -------------------------------------------------------------------------------- /data/face_parsing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.nn.functional as F 6 | import albumentations as A 7 | import math 8 | from PIL import Image 9 | from 
torchvision.datasets.vision import VisionDataset 10 | 11 | celebAMask_label_list = ['background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'] 12 | celebAMask_labels = {i: v for i, v in enumerate(celebAMask_label_list)} 13 | flip_mapping = torch.tensor([-1] * len(list(celebAMask_labels.keys())), dtype=torch.long) 14 | for i, x in enumerate(celebAMask_labels.keys()): 15 | flip_mapping[x] = i 16 | flip_mapping[4] = 5; flip_mapping[5] = 4 17 | flip_mapping[6] = 7; flip_mapping[7] = 6 18 | flip_mapping[8] = 9; flip_mapping[9] = 8 19 | 20 | class MaskMeshConverter(torch.nn.Module): 21 | ''' 22 | convert a segmentation mask to multiple channels using mesh grid 23 | ''' 24 | def __init__(self, labels, mesh_dim=3): 25 | super().__init__() 26 | self.labels = labels 27 | num_grid_each_dim = math.ceil(len(labels)**(1/mesh_dim)) 28 | mesh_d = torch.meshgrid( 29 | *[torch.linspace(0,1,num_grid_each_dim)]*mesh_dim 30 | ) 31 | mesh_d = [i.reshape(-1) for i in mesh_d] 32 | self.mesh = torch.stack(mesh_d, dim=-1) 33 | self.mesh_embedding = torch.nn.Embedding(len(self.mesh), mesh_dim) 34 | self.mesh_embedding.weight.data = self.mesh 35 | 36 | # maps index in mask to index in mesh 37 | assert torch.tensor(labels).min() >= 0 38 | index_map = torch.tensor([-1] * (torch.tensor(labels).max() + 1), dtype=torch.int) 39 | reverse_index_map = torch.tensor([-1] * (torch.tensor(labels).max() + 1), dtype=torch.int) 40 | for i, x in enumerate(labels): 41 | index_map[x] = i 42 | reverse_index_map[i] = x 43 | self.register_buffer('index_map', index_map) 44 | self.register_buffer('reverse_index_map', reverse_index_map) 45 | 46 | def index_mask_to_nd_mesh(self, mask): 47 | mesh_idx_mask = self.index_map[mask] 48 | embedding = self.mesh_embedding(mesh_idx_mask).detach() 49 | return embedding 50 | 51 | def nd_mesh_to_index_mask(self, mesh): 52 | mesh_size = mesh.size() 53 | mesh = mesh.view(mesh_size[0], -1, mesh_size[-1]) # bs, hxw, mesh_dim 54 | mesh_dist_to_embedding = torch.cdist( 55 | mesh, 56 | self.mesh_embedding.weight.data[None].expand( 57 | mesh_size[0], -1, -1 58 | ), 59 | p=1 60 | ) 61 | mesh_nn = torch.argmin( 62 | mesh_dist_to_embedding[:, :, :len(self.labels)], 63 | dim=-1, keepdim=True 64 | ).view(*mesh_size[:-1]) # bs, h, w 65 | return self.reverse_index_map[mesh_nn] 66 | 67 | 68 | def __call__(self, mask): 69 | return self.index_mask_to_nd_mesh(mask) 70 | 71 | class MaskOnehotConverter(torch.nn.Module): 72 | def __init__(self, labels): 73 | '''Two step mapping to handle incontinuous index case (e.g. 
255 for ignore)''' 74 | super().__init__() 75 | self.num_classes = len(labels) 76 | # maps index in mask to index in one-hot 77 | assert torch.tensor(labels).min() >= 0 78 | index_map = torch.tensor([-1] * (torch.tensor(labels).max() + 1), dtype=torch.int) 79 | reverse_index_map = torch.tensor([-1] * (torch.tensor(labels).max() + 1), dtype=torch.int) 80 | for i, x in enumerate(labels): 81 | index_map[x] = i 82 | reverse_index_map[i] = x 83 | self.register_buffer('index_map', index_map) 84 | self.register_buffer('reverse_index_map', reverse_index_map) 85 | 86 | def index_mask_to_one_hot(self, mask): 87 | continous_idx_mask = self.index_map[mask].to(torch.long) 88 | one_hot_tensor = F.one_hot(continous_idx_mask, num_classes=self.num_classes) 89 | return one_hot_tensor.to(torch.float) 90 | 91 | def one_hot_to_index_mask(self, one_hot_tensor): 92 | # one_hot_tensor: b, dim, h, w 93 | continous_idx_mask = torch.argmax(one_hot_tensor, dim=1).to(torch.long) 94 | mask = self.reverse_index_map[continous_idx_mask] 95 | return mask 96 | 97 | def __call__(self, mask): 98 | return self.index_mask_to_one_hot(mask) 99 | 100 | def get_train_transform(image_size): 101 | train_transform = A.Compose([ 102 | A.PadIfNeeded(min_height=image_size, min_width=image_size), 103 | A.Resize(width=image_size, height=image_size, interpolation=cv2.INTER_AREA), 104 | # A.HorizontalFlip(p=0.5), # disable it since we need to modify left and right index for eyes, eyebrows and ears. Move this function to dataset. 105 | # A.RandomBrightnessContrast(p=0.2), # maybe not good for face generation? 106 | ]) 107 | return train_transform 108 | 109 | def get_test_transform(image_size): 110 | test_transform = A.Compose([ 111 | A.PadIfNeeded(min_height=image_size, min_width=image_size), 112 | A.Resize(width=image_size, height=image_size, interpolation=cv2.INTER_AREA), 113 | ]) 114 | return test_transform 115 | 116 | class CelebAMaskHQ(VisionDataset): 117 | def __init__( 118 | self, root, split='train', data_len=-1, 119 | transform=None, target_transform=None, 120 | dual_transforms=None, 121 | ): 122 | ''' 123 | root=/home/ubuntu/disk2/data/face/CelebAMask-HQ 124 | Remember to preprocess dataset with https://github.com/switchablenorms/CelebAMask-HQ/blob/master/face_parsing/Data_preprocessing/g_mask.py 125 | ''' 126 | super().__init__(root, transform=transform, target_transform=target_transform) 127 | assert split in ['train', 'val'], f'got {split}' 128 | self.split = split 129 | self.img_dir = os.path.join(root, 'CelebA-HQ-img') 130 | self.mask_dir = os.path.join(root, 'CelebAMaskHQ-mask') 131 | with open(os.path.join(root, f'{split}.txt')) as IN: 132 | self.keys = [i.strip() for i in IN] 133 | if data_len > 0: 134 | self.keys = self.keys[:data_len] 135 | self.dual_transforms = dual_transforms 136 | 137 | def _load_image(self, image_name): 138 | image_path = os.path.join( 139 | self.img_dir, 140 | f'{image_name}.jpg' 141 | ) 142 | image = cv2.imread(image_path)[...,::-1] 143 | return image 144 | 145 | def _load_mask(self, image_name): 146 | image_path = os.path.join( 147 | self.mask_dir, 148 | f'{image_name}.png' 149 | ) 150 | image = np.array(Image.open(image_path)) 151 | return image 152 | 153 | def _flip(self, image, mask): 154 | if self.split == 'train' and np.random.rand() < 0.5: 155 | image = torch.flip(image, dims=[1]) 156 | mask = torch.flip(mask, dims=[1]) 157 | mask = flip_mapping[mask.to(torch.long)] 158 | return image, mask 159 | 160 | def _process_mask(self, mask): 161 | return mask.to(torch.long) 162 | 163 | def 
__getitem__(self, index): 164 | this_key = self.keys[index] 165 | image = self._load_image(this_key) 166 | image = (image).astype(np.float32) / 255. 167 | mask = self._load_mask(this_key) 168 | 169 | transformed = self.dual_transforms(image=image, masks=[mask]) 170 | image = torch.from_numpy(transformed['image']) 171 | mask = torch.from_numpy(transformed['masks'][0]) 172 | 173 | # flip during training with correct mask index 174 | image, mask = self._flip(image, mask) 175 | 176 | mask = self._process_mask(mask) 177 | # h, w, dim -> dim, h, w 178 | image = image.permute(2, 0, 1) 179 | 180 | image = (image - 0.5) * 2 181 | 182 | ret = {} 183 | ret['image'] = image # return original image and mask for visualization 184 | ret['seg_mask'] = mask 185 | return ret 186 | 187 | def __len__(self): 188 | """Return the number of images.""" 189 | return len(self.keys) 190 | -------------------------------------------------------------------------------- /data/random_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | class RandomNoise(Dataset): 5 | def __init__(self, h, w, dim, length=500): 6 | self.h = h 7 | self.w = w 8 | self.dim = dim 9 | self.length = length 10 | 11 | def __len__(self): 12 | return self.length 13 | 14 | def __getitem__(self, index): 15 | return torch.randn(self.dim, self.h, self.w) -------------------------------------------------------------------------------- /fid_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from cleanfid import fid 4 | from PIL import ImageFile, PngImagePlugin 5 | ImageFile.LOAD_TRUNCATED_IMAGES = True 6 | PngImagePlugin.MAX_TEXT_CHUNK = 1048576 * 10 7 | 8 | def resize_a_folder(folder_path, size): 9 | os.system(f'python scripts/resize_images.py --indir {folder_path} --size {size}') 10 | return folder_path + f'-{size}' 11 | 12 | def convert_a_folder_to_jpg(folder_path): 13 | os.system(f'python scripts/convert_jpg.py --indir {folder_path}') 14 | return folder_path + f'-jpg' 15 | 16 | def evaluate(args): 17 | src_path = args.src 18 | if args.cvt_jpg_s: 19 | src_path = convert_a_folder_to_jpg(src_path) 20 | if args.resize_s: 21 | src_path = resize_a_folder(args.src, args.target_size) 22 | 23 | dst_path = args.dst 24 | if args.cvt_jpg_d: 25 | dst_path = convert_a_folder_to_jpg(dst_path) 26 | if args.resize_d: 27 | dst_path = resize_a_folder(args.dst, args.target_size) 28 | 29 | fid_score = fid.compute_fid(src_path, dst_path) 30 | print('FID of : {}'.format(fid_score)) 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('-s', '--src', type=str, default='', help='Ground truth images directory') 35 | parser.add_argument('--resize_s', action='store_true') 36 | parser.add_argument('--cvt_jpg_s', action='store_true') 37 | parser.add_argument('-d', '--dst', type=str, help='Generate images directory') 38 | parser.add_argument('--resize_d', action='store_true') 39 | parser.add_argument('--cvt_jpg_d', action='store_true') 40 | parser.add_argument('--target_size', type=int, default=256) 41 | 42 | ''' parser configs ''' 43 | args = parser.parse_args() 44 | 45 | evaluate(args) -------------------------------------------------------------------------------- /figures/LD_gradio_demo.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cplusx/layout_diffuse/9666cb867313aa693775f6134442dea3734565a5/figures/LD_gradio_demo.gif -------------------------------------------------------------------------------- /figures/LD_interacitve_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cplusx/layout_diffuse/9666cb867313aa693775f6134442dea3734565a5/figures/LD_interacitve_demo.gif -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cplusx/layout_diffuse/9666cb867313aa693775f6134442dea3734565a5/figures/teaser.png -------------------------------------------------------------------------------- /interactive_plotting/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from flask import Flask, render_template, request, jsonify, make_response, send_file 4 | from PIL import Image 5 | 6 | app = Flask(__name__) 7 | 8 | rectangles = [] 9 | 10 | @app.route('/') 11 | def index(): 12 | return render_template('index.html') 13 | 14 | @app.route('/image/') 15 | def get_image(image_name): 16 | print('INFO: image_name =', image_name) 17 | image_path = os.path.join('tmp', image_name) 18 | resize_image(image_path, target_height=384) 19 | return send_file(image_path, mimetype='image/jpg') 20 | 21 | @app.route('/get_sd_images', methods=['GET', 'POST']) 22 | def get_sd_images(): 23 | if request.method == 'POST': 24 | data = request.get_json() 25 | rectangles = data.get('rectangles', []) 26 | # do something with the rectangles data 27 | this_hash = save_rectangles(rectangles) 28 | try: 29 | # wait for the wait_for_image coroutine to complete and return its result 30 | image_path = wait_for_image(this_hash, timeout=200) 31 | return jsonify({'image_path': image_path}) 32 | except TimeoutError: 33 | return make_response(jsonify({'error': 'Timeout waiting for image'}), 404) 34 | except Exception as e: 35 | print(e) 36 | return make_response(jsonify({'error': 'Error in wait_for_image coroutine'}), 500) 37 | 38 | return make_response(jsonify({'message': 'Method not allowed'}), 405) 39 | 40 | def wait_for_image(hash_value, timeout=60): 41 | # set up a file system watcher to monitor for the image file 42 | watched_folder = 'tmp' 43 | watched_file = hash_value + '.jpg' 44 | timeout = time.time() + timeout 45 | while time.time() < timeout: 46 | for file_name in os.listdir(watched_folder): 47 | if file_name == watched_file or file_name == 'hamburger_pic.jpeg': 48 | save_path = os.path.join(watched_folder, watched_file) 49 | return watched_file 50 | time.sleep(1) 51 | 52 | raise TimeoutError('Timeout waiting for image') 53 | 54 | def save_rectangles(rectangles, save_dir='tmp'): 55 | # save rectangles to a file, the input is a list of dictionary 56 | # e.g., [{'x1': 107, 'y1': 271, 'x2': 386, 'y2': 407, 'color': '#2605af', 'class': 'car'}] 57 | 58 | # make dir if not existing 59 | if not os.path.exists(save_dir): 60 | os.makedirs(save_dir) 61 | 62 | # convert pixel values to relative size 63 | for rectangle in rectangles: 64 | x1, y1, x2, y2 = rectangle['x1'], rectangle['y1'], rectangle['x2'], rectangle['y2'] 65 | w, h = 512, 512 # canvas size is 512x512 pixels 66 | rectangle['x'] = x1 / w 67 | rectangle['y'] = y1 / h 68 | rectangle['w'] = (x2 - x1) / w 69 | rectangle['h'] = (y2 - y1) / h 70 | rectangle.pop('x1', None) 71 | 
rectangle.pop('y1', None) 72 | rectangle.pop('x2', None) 73 | rectangle.pop('y2', None) 74 | 75 | # compute a hash for saving file name 76 | hash_name = str(hash(str(rectangles)))[1:11] 77 | save_name = hash_name + '.txt' 78 | save_path = os.path.join(save_dir, save_name) 79 | 80 | # save rectangles to file in specified format 81 | with open(save_path, 'w') as f: 82 | for rectangle in rectangles: 83 | x, y, w, h = rectangle['x'], rectangle['y'], rectangle['w'], rectangle['h'] 84 | class_id = rectangle['class'] 85 | f.write(f"{x},{y},{w},{h},{class_id}\n") 86 | 87 | return hash_name 88 | 89 | def resize_image(image_path, target_height=256): 90 | img = Image.open(image_path) 91 | height_percent = target_height / float(img.size[1]) 92 | width_size = int((float(img.size[0]) * float(height_percent))) 93 | img = img.resize((width_size, target_height), Image.ANTIALIAS) 94 | img.save(image_path) 95 | 96 | if __name__ == '__main__': 97 | app.run(debug=True, host='0.0.0.0') -------------------------------------------------------------------------------- /interactive_plotting/static/css/style.css: -------------------------------------------------------------------------------- 1 | /* Change the background color */ 2 | body { 3 | background-color: #F2F2F2; 4 | } 5 | 6 | /* Style the header */ 7 | h1 { 8 | font-size: 36px; 9 | margin-top: 20px; 10 | margin-bottom: 20px; 11 | text-align: center; 12 | } 13 | 14 | /* Style the buttons */ 15 | button { 16 | padding: 10px 20px; 17 | font-size: 20px; 18 | background-color: #4CAF50; 19 | color: white; 20 | border: none; 21 | border-radius: 4px; 22 | cursor: pointer; 23 | } 24 | 25 | button:hover { 26 | background-color: #3e8e41; 27 | } 28 | 29 | #clearbtn { 30 | margin-right: 20px; 31 | } 32 | 33 | #submit { 34 | margin-left: 20px; 35 | } 36 | 37 | /* Style the canvas and image container */ 38 | #canvas-container { 39 | display: flex; 40 | flex-direction: row; 41 | justify-content: space-between; 42 | align-items: flex-start; 43 | } 44 | 45 | #canvas { 46 | position: relative; 47 | width: 512px; 48 | height: 512px; 49 | border: 1px solid #ddd; 50 | background-color: #fff; 51 | order: 1; 52 | } 53 | 54 | #image-container { 55 | max-height: 512px; 56 | overflow-y: auto; 57 | max-width: 800px; 58 | margin-top: 20px; 59 | } 60 | 61 | /* Style the class selector */ 62 | #class-selector { 63 | width: 250px; 64 | padding: 20px; 65 | border: 1px solid #ddd; 66 | background-color: #fff; 67 | box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2); 68 | float: right; 69 | order: 2; 70 | font-size: 16px; 71 | } 72 | 73 | #class-selector h3 { 74 | font-size: 24px; 75 | margin-top: 0; 76 | margin-bottom: 20px; 77 | } 78 | 79 | #class-selector li { 80 | list-style: none; 81 | margin-bottom: 10px; 82 | } 83 | 84 | #class-selector input[type="radio"] { 85 | margin-right: 10px; 86 | cursor: pointer; 87 | } 88 | 89 | #class-selector label { 90 | font-size: 16px; 91 | cursor: pointer; 92 | } 93 | 94 | #class-selector label:hover { 95 | text-decoration: underline; 96 | } 97 | 98 | /* Responsive styles for small screens */ 99 | @media only screen and (max-width: 768px) { 100 | #canvas-container { 101 | flex-direction: column; 102 | align-items: center; 103 | } 104 | 105 | #class-selector { 106 | width: 100%; 107 | margin-top: 20px; 108 | } 109 | } 110 | 111 | #class-selector-list { 112 | max-height: 400px; 113 | overflow-y: auto; 114 | } -------------------------------------------------------------------------------- /interactive_plotting/static/doc/labels.txt: 
-------------------------------------------------------------------------------- 1 | 0: unlabeled 2 | 1: person 3 | 2: bicycle 4 | 3: car 5 | 4: motorcycle 6 | 5: airplane 7 | 6: bus 8 | 7: train 9 | 8: truck 10 | 9: boat 11 | 10: traffic light 12 | 11: fire hydrant 13 | 12: street sign 14 | 13: stop sign 15 | 14: parking meter 16 | 15: bench 17 | 16: bird 18 | 17: cat 19 | 18: dog 20 | 19: horse 21 | 20: sheep 22 | 21: cow 23 | 22: elephant 24 | 23: bear 25 | 24: zebra 26 | 25: giraffe 27 | 26: hat 28 | 27: backpack 29 | 28: umbrella 30 | 29: shoe 31 | 30: eye glasses 32 | 31: handbag 33 | 32: tie 34 | 33: suitcase 35 | 34: frisbee 36 | 35: skis 37 | 36: snowboard 38 | 37: sports ball 39 | 38: kite 40 | 39: baseball bat 41 | 40: baseball glove 42 | 41: skateboard 43 | 42: surfboard 44 | 43: tennis racket 45 | 44: bottle 46 | 45: plate 47 | 46: wine glass 48 | 47: cup 49 | 48: fork 50 | 49: knife 51 | 50: spoon 52 | 51: bowl 53 | 52: banana 54 | 53: apple 55 | 54: sandwich 56 | 55: orange 57 | 56: broccoli 58 | 57: carrot 59 | 58: hot dog 60 | 59: pizza 61 | 60: donut 62 | 61: cake 63 | 62: chair 64 | 63: couch 65 | 64: potted plant 66 | 65: bed 67 | 66: mirror 68 | 67: dining table 69 | 68: window 70 | 69: desk 71 | 70: toilet 72 | 71: door 73 | 72: tv 74 | 73: laptop 75 | 74: mouse 76 | 75: remote 77 | 76: keyboard 78 | 77: cell phone 79 | 78: microwave 80 | 79: oven 81 | 80: toaster 82 | 81: sink 83 | 82: refrigerator 84 | 83: blender 85 | 84: book 86 | 85: clock 87 | 86: vase 88 | 87: scissors 89 | 88: teddy bear 90 | 89: hair drier 91 | 90: toothbrush 92 | 91: hair brush 93 | 92: banner 94 | 93: blanket 95 | 94: branch 96 | 95: bridge 97 | 96: building-other 98 | 97: bush 99 | 98: cabinet 100 | 99: cage 101 | 100: cardboard 102 | 101: carpet 103 | 102: ceiling-other 104 | 103: ceiling-tile 105 | 104: cloth 106 | 105: clothes 107 | 106: clouds 108 | 107: counter 109 | 108: cupboard 110 | 109: curtain 111 | 110: desk-stuff 112 | 111: dirt 113 | 112: door-stuff 114 | 113: fence 115 | 114: floor-marble 116 | 115: floor-other 117 | 116: floor-stone 118 | 117: floor-tile 119 | 118: floor-wood 120 | 119: flower 121 | 120: fog 122 | 121: food-other 123 | 122: fruit 124 | 123: furniture-other 125 | 124: grass 126 | 125: gravel 127 | 126: ground-other 128 | 127: hill 129 | 128: house 130 | 129: leaves 131 | 130: light 132 | 131: mat 133 | 132: metal 134 | 133: mirror-stuff 135 | 134: moss 136 | 135: mountain 137 | 136: mud 138 | 137: napkin 139 | 138: net 140 | 139: paper 141 | 140: pavement 142 | 141: pillow 143 | 142: plant-other 144 | 143: plastic 145 | 144: platform 146 | 145: playingfield 147 | 146: railing 148 | 147: railroad 149 | 148: river 150 | 149: road 151 | 150: rock 152 | 151: roof 153 | 152: rug 154 | 153: salad 155 | 154: sand 156 | 155: sea 157 | 156: shelf 158 | 157: sky-other 159 | 158: skyscraper 160 | 159: snow 161 | 160: solid-other 162 | 161: stairs 163 | 162: stone 164 | 163: straw 165 | 164: structural-other 166 | 165: table 167 | 166: tent 168 | 167: textile-other 169 | 168: towel 170 | 169: tree 171 | 170: vegetable 172 | 171: wall-brick 173 | 172: wall-concrete 174 | 173: wall-other 175 | 174: wall-panel 176 | 175: wall-stone 177 | 176: wall-tile 178 | 177: wall-wood 179 | 178: water-other 180 | 179: waterdrops 181 | 180: window-blind 182 | 181: window-other 183 | 182: wood 184 | -------------------------------------------------------------------------------- /interactive_plotting/static/js/script.js: 
-------------------------------------------------------------------------------- 1 | var color = 'black'; 2 | var rectangles = []; 3 | 4 | function init() { 5 | generateClassSelector(); 6 | } 7 | 8 | function clearCanvas() { 9 | rectangles = []; 10 | // remove any existing image 11 | $('#image-container').remove(); 12 | redrawCanvas(); 13 | } 14 | 15 | function undo() { 16 | var lastRect = rectangles.pop(); 17 | if (lastRect) { 18 | $('#canvas').children().last().remove(); 19 | } 20 | } 21 | 22 | function getSDImages() { 23 | $.ajax({ 24 | type: 'POST', 25 | url: '/get_sd_images', 26 | contentType: 'application/json', 27 | data: JSON.stringify({rectangles: rectangles}), 28 | success: function(data) { 29 | // create a new div element 30 | var imageDiv = document.createElement('div'); 31 | imageDiv.id = 'image-container'; 32 | imageDiv.style.border = '1px solid black'; 33 | 34 | // create a new img element and set its src attribute 35 | var image = document.createElement('img'); 36 | console.log(data) 37 | url = '/image/' + data['image_path'] 38 | image.src = url; 39 | 40 | // remove any existing image 41 | $('#image-container').remove(); 42 | 43 | // append the img element to the new div element 44 | imageDiv.appendChild(image); 45 | 46 | // append the new div element to the body 47 | document.body.appendChild(imageDiv); 48 | } 49 | }); 50 | } 51 | 52 | function createRectangle(x1, y1, x2, y2, color, class_name=$("input[name='class']:checked").val()) { 53 | var rect = $('
').css({ 54 | position: 'absolute', 55 | left: Math.min(x1, x2) + 'px', 56 | top: Math.min(y1, y2) + 'px', 57 | width: Math.abs(x2 - x1) + 'px', 58 | height: Math.abs(y2 - y1) + 'px', 59 | border: '2px solid ' + color 60 | }); 61 | 62 | var colorbox = $('
').css({ 63 | position: 'absolute', 64 | left: '0px', 65 | top: '0px', 66 | width: '50px', 67 | height: '20px', 68 | background: 'none', 69 | color: color, 70 | 'text-align': 'center', 71 | 'font-size': '12px', 72 | 'line-height': '20px' 73 | }).text(class_name); 74 | 75 | rect.append(colorbox); 76 | var rect_config = { 77 | x1: Math.max(0, Math.min(x1, x2)), 78 | y1: Math.max(0, Math.min(y1, y2)), 79 | x2: Math.min(511, Math.max(x1, x2)), 80 | y2: Math.min(511, Math.max(y1, y2)), 81 | color: color, 82 | class: class_name 83 | }; 84 | 85 | return {rect, rect_config}; 86 | } 87 | 88 | function drawRectangle(rect) { 89 | $('#canvas').append(rect); 90 | } 91 | 92 | function updateRectangle(rect, x1, y1, x2, y2) { 93 | rect.css({ 94 | left: Math.min(x1, x2) + 'px', 95 | top: Math.min(y1, y2) + 'px', 96 | width: Math.abs(x2 - x1) + 'px', 97 | height: Math.abs(y2 - y1) + 'px' 98 | }); 99 | 100 | return rect; 101 | } 102 | 103 | function redrawCanvas() { 104 | $('#canvas').empty(); 105 | for (var i = 0; i < rectangles.length; i++) { 106 | var rect = rectangles[i]; 107 | res = createRectangle( 108 | rect.x1, rect.y1, rect.x2, rect.y2, rect.color, rect.class 109 | ); 110 | thisRect = res.rect 111 | drawRectangle(thisRect) 112 | } 113 | } 114 | 115 | $(function() { 116 | var isDrawing = false; 117 | 118 | $('#canvas').mousedown(function(event) { 119 | startX = event.offsetX; 120 | startY = event.offsetY; 121 | isDrawing = true; 122 | res = createRectangle(startX, startY, startX, startY, color); 123 | currentRect = res.rect; 124 | currentRectConfig = res.rect_config; 125 | }); 126 | 127 | $('#canvas').mousemove(function(event) { 128 | if (isDrawing) { 129 | // clear the canvas and redraw any existing rectangles 130 | redrawCanvas() 131 | // get the current end point and draw a new rectangle 132 | endX = event.offsetX; 133 | endY = event.offsetY; 134 | currentRect = updateRectangle(currentRect, startX, startY, endX, endY); 135 | drawRectangle(currentRect) 136 | } 137 | }); 138 | 139 | $('#canvas').mouseup(function(event) { 140 | if (isDrawing) { 141 | isDrawing = false; 142 | 143 | // get the final end point and create a new rectangle 144 | // endX = event.offsetX; 145 | // endY = event.offsetY; 146 | res = createRectangle(startX, startY, endX, endY, color); 147 | currentRect = res.rect; 148 | currentRectConfig = res.rect_config; 149 | rectangles.push(currentRectConfig); 150 | } 151 | }); 152 | 153 | $('#clearbtn').click(clearCanvas); 154 | // Add event listener to Plot button 155 | $('#submit').click(getSDImages); 156 | // redraw the canvas on page load to display any existing rectangles 157 | redrawCanvas(); 158 | }); 159 | 160 | function generateClassSelector() { 161 | // Make a GET request to read the labels file 162 | const xhr = new XMLHttpRequest(); 163 | xhr.open('GET', '/static/doc/labels.txt'); 164 | xhr.onreadystatechange = function() { 165 | if (xhr.readyState === 4 && xhr.status === 200) { 166 | // Split the labels into an array 167 | const labels = xhr.responseText.trim().split('\n'); 168 | const classes = {}; 169 | // Loop through each label to generate a unique color and create a radio button 170 | for (let i = 1; i < labels.length; i++) { 171 | const [classId, className] = labels[i].split(': '); 172 | classes[className] = `#${(Math.random()*0xFFFFFF<<0).toString(16).padStart(6, '0')}`; 173 | const input = document.createElement('input'); 174 | input.type = 'radio'; 175 | input.name = 'class'; 176 | input.value = className; 177 | const label = document.createElement('label'); 178 | 
label.htmlFor = className; 179 | label.innerText = className; 180 | label.style.color = classes[className]; 181 | const li = document.createElement('li'); 182 | li.appendChild(input); 183 | li.appendChild(label); 184 | document.querySelector('#class-selector-list').appendChild(li); 185 | } 186 | // Attach event listener to radio buttons to update label color 187 | $("input[name='class']").on('change', function() { 188 | const className = $("input[name='class']:checked").val(); 189 | color = classes[className]; 190 | }); 191 | } 192 | }; 193 | xhr.send(); 194 | } 195 | -------------------------------------------------------------------------------- /interactive_plotting/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | LayoutDiffuse Interactive Plotter 5 | 6 | 7 | 8 | 9 | 10 | 11 |
12-26 | [HTML tags stripped in this dump; remaining visible text: the "LayoutDiffuse Interactive Plotter" page heading and the "Select a class:" label for the #class-selector-list; the page also hosts the #canvas drawing area and the #clearbtn / #submit buttons referenced by script.js]
    27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from torch.utils.data import DataLoader 6 | from pytorch_lightning import Trainer, seed_everything 7 | from data import get_dataset 8 | from train_utils import get_models, get_DDPM, get_logger_and_callbacks 9 | 10 | if __name__ == '__main__': 11 | seed_everything(42) 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | '-c', '--config', type=str, 15 | default='config/train.json') 16 | parser.add_argument( 17 | '-r', '--resume', action="store_true" 18 | ) 19 | parser.add_argument( 20 | '-n', '--nnode', type=int, default=1 21 | ) 22 | 23 | ''' parser configs ''' 24 | args_raw = parser.parse_args() 25 | with open(args_raw.config, 'r') as IN: 26 | args = json.load(IN) 27 | args['resume'] = args_raw.resume 28 | args['nnode'] = args_raw.nnode 29 | expt_name = args['expt_name'] 30 | expt_dir = args['expt_dir'] 31 | expt_path = os.path.join(expt_dir, expt_name) 32 | os.makedirs(expt_path, exist_ok=True) 33 | 34 | '''1. create denoising model''' 35 | models = get_models(args) 36 | 37 | diffusion_configs = args['diffusion'] 38 | ddpm_model = get_DDPM( 39 | diffusion_configs=diffusion_configs, 40 | log_args=args, 41 | **models 42 | ) 43 | 44 | '''2. dataset and dataloader''' 45 | data_args = args['data'] 46 | train_set, val_set = get_dataset(**data_args) 47 | train_loader = DataLoader( 48 | train_set, batch_size=data_args['batch_size'], shuffle=True, 49 | num_workers=4*len(args['trainer_args']['devices']), pin_memory=True 50 | ) 51 | val_loader = DataLoader( 52 | val_set, batch_size=data_args['val_batch_size'], 53 | num_workers=len(args['trainer_args']['devices']), pin_memory=True 54 | ) 55 | '''3. create callbacks''' 56 | wandb_logger, callbacks = get_logger_and_callbacks(expt_name, expt_path, args) 57 | 58 | '''4. trainer''' 59 | trainer_args = { 60 | "max_epochs": 1000, 61 | "accelerator": "gpu", 62 | "devices": [0], 63 | "limit_val_batches": 1, 64 | "strategy": "ddp", 65 | "check_val_every_n_epoch": 1, 66 | "num_nodes": args['nnode'] 67 | # "benchmark" :True 68 | } 69 | config_trainer_args = args['trainer_args'] if args.get('trainer_args') is not None else {} 70 | trainer_args.update(config_trainer_args) 71 | print(f'Training args are {trainer_args}') 72 | trainer = Trainer( 73 | logger = wandb_logger, 74 | callbacks = callbacks, 75 | **trainer_args 76 | ) 77 | '''5. 
start training''' 78 | if args['resume']: 79 | print('INFO: Try to resume from checkpoint') 80 | ckpt_path = os.path.join(expt_path, 'latest.ckpt') 81 | if os.path.exists(ckpt_path): 82 | print(f'INFO: Found checkpoint {ckpt_path}') 83 | # ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict'] 84 | # ddpm_model.load_state_dict(ckpt) 85 | else: 86 | ckpt_path = None 87 | else: 88 | ckpt_path = None 89 | trainer.fit( 90 | ddpm_model, train_loader, val_loader, 91 | ckpt_path=ckpt_path 92 | ) 93 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import numpy as np 4 | from inspect import isfunction 5 | 6 | def instantiate_from_config(config): 7 | if not "target" in config: 8 | raise KeyError("Expected key `target` to instantiate.") 9 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 10 | 11 | 12 | def get_obj_from_str(string, reload=False): 13 | module, cls = string.rsplit(".", 1) 14 | if reload: 15 | module_imp = importlib.import_module(module) 16 | importlib.reload(module_imp) 17 | return getattr(importlib.import_module(module, package=None), cls) 18 | 19 | def exists(x): 20 | return x is not None 21 | 22 | def default(val, d): 23 | if exists(val): 24 | return val 25 | return d() if isfunction(d) else d 26 | 27 | def noise_like(shape, device, repeat=False): 28 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 29 | noise = lambda: torch.randn(shape, device=device) 30 | return repeat_noise() if repeat else noise() 31 | 32 | def extract_into_tensor(a, t, x_shape): 33 | b, *_ = t.shape 34 | out = a.gather(-1, t) 35 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 36 | 37 | def right_pad_dims_to(x, t): 38 | padding_dims = x.ndim - t.ndim 39 | if padding_dims <= 0: 40 | return t 41 | return t.view(*t.shape, *((1,) * padding_dims)) 42 | 43 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 44 | if schedule == "linear": 45 | betas = ( 46 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 47 | ) 48 | 49 | elif schedule == "cosine": 50 | timesteps = ( 51 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 52 | ) 53 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 54 | alphas = torch.cos(alphas).pow(2) 55 | alphas = alphas / alphas[0] 56 | betas = 1 - alphas[1:] / alphas[:-1] 57 | betas = np.clip(betas, a_min=0, a_max=0.999) 58 | 59 | elif schedule == "sqrt_linear": 60 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 61 | elif schedule == "sqrt": 62 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 63 | else: 64 | raise ValueError(f"schedule '{schedule}' unknown.") 65 | return betas.numpy() 66 | 67 | 68 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 69 | if ddim_discr_method == 'uniform': 70 | c = num_ddpm_timesteps // num_ddim_timesteps 71 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 72 | elif ddim_discr_method == 'quad': 73 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 74 | else: 75 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 76 | 77 | # assert 
ddim_timesteps.shape[0] == num_ddim_timesteps 78 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 79 | steps_out = ddim_timesteps + 1 80 | if verbose: 81 | print(f'Selected timesteps for ddim sampler: {steps_out}') 82 | return steps_out 83 | 84 | 85 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 86 | # select alphas for computing the variance schedule 87 | alphas = alphacums[ddim_timesteps] 88 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 89 | 90 | # according the the formula provided in https://arxiv.org/abs/2010.02502 91 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 92 | if verbose: 93 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 94 | print(f'For the chosen value of eta, which is {eta}, ' 95 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 96 | return sigmas, alphas, alphas_prev 97 | 98 | 99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 100 | """ 101 | Create a beta schedule that discretizes the given alpha_t_bar function, 102 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 103 | :param num_diffusion_timesteps: the number of betas to produce. 104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 105 | produces the cumulative product of (1-beta) up to that 106 | part of the diffusion process. 107 | :param max_beta: the maximum beta to use; use values lower than 1 to 108 | prevent singularities. 109 | """ 110 | betas = [] 111 | for i in range(num_diffusion_timesteps): 112 | t1 = i / num_diffusion_timesteps 113 | t2 = (i + 1) / num_diffusion_timesteps 114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 115 | return np.array(betas) -------------------------------------------------------------------------------- /modules/bert/bert_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .x_transformer import TransformerWrapper, Encoder 3 | 4 | class BERTTokenizer(torch.nn.Module): 5 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 6 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 7 | super().__init__() 8 | from transformers import BertTokenizerFast # TODO: add to reuquirements 9 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 10 | self.device = device 11 | self.vq_interface = vq_interface 12 | self.max_length = max_length 13 | 14 | def forward(self, text): 15 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 16 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 17 | tokens = batch_encoding["input_ids"].to(self.device) 18 | return tokens 19 | 20 | @torch.no_grad() 21 | def encode(self, text): 22 | tokens = self(text) 23 | if not self.vq_interface: 24 | return tokens 25 | return None, None, [None, None, tokens] 26 | 27 | def decode(self, text): 28 | return text 29 | 30 | 31 | class BERTEmbedder(torch.nn.Module): 32 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 33 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 34 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 35 | super().__init__() 36 | self.use_tknz_fn = use_tokenizer 37 | if self.use_tknz_fn: 38 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 39 | self.device = device 40 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 41 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 42 | emb_dropout=embedding_dropout) 43 | 44 | def forward(self, text): 45 | if self.use_tknz_fn: 46 | tokens = self.tknz_fn(text)#.to(self.device) 47 | else: 48 | tokens = text 49 | z = self.transformer(tokens, return_embeddings=True) 50 | return z 51 | 52 | def encode(self, text): 53 | # output of length 77 54 | return self(text) -------------------------------------------------------------------------------- /modules/kl_autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from modules.vqvae.model import Encoder, Decoder 7 | 8 | from model_utils import instantiate_from_config 9 | 10 | class DiagonalGaussianDistribution(object): 11 | def __init__(self, parameters, deterministic=False): 12 | self.parameters = parameters 13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 15 | self.deterministic = deterministic 16 | self.std = torch.exp(0.5 * self.logvar) 17 | self.var = torch.exp(self.logvar) 18 | if self.deterministic: 19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 20 | 21 | def sample(self): 22 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 23 | return x 24 | 25 | def kl(self, other=None): 26 | if self.deterministic: 27 | return torch.Tensor([0.]) 28 | else: 29 | if other is None: 30 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 31 | + self.var - 1.0 - self.logvar, 32 | dim=[1, 2, 3]) 33 | else: 34 | return 0.5 * torch.sum( 35 | torch.pow(self.mean - other.mean, 2) / other.var 36 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 37 | dim=[1, 2, 3]) 38 | 39 | def nll(self, sample, dims=[1,2,3]): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | logtwopi = np.log(2.0 * np.pi) 43 | return 0.5 * torch.sum( 44 | logtwopi + self.logvar + torch.pow(sample - 
self.mean, 2) / self.var, 45 | dim=dims) 46 | 47 | def mode(self): 48 | return self.mean 49 | 50 | class AutoencoderKL(pl.LightningModule): 51 | def __init__(self, 52 | ddconfig, 53 | lossconfig, 54 | embed_dim, 55 | ckpt_path=None, 56 | ignore_keys=[], 57 | image_key="image", 58 | colorize_nlabels=None, 59 | monitor=None, 60 | ): 61 | super().__init__() 62 | self.image_key = image_key 63 | self.encoder = Encoder(**ddconfig) 64 | self.decoder = Decoder(**ddconfig) 65 | self.loss = instantiate_from_config(lossconfig) 66 | assert ddconfig["double_z"] 67 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 68 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 69 | self.embed_dim = embed_dim 70 | if colorize_nlabels is not None: 71 | assert type(colorize_nlabels)==int 72 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 73 | if monitor is not None: 74 | self.monitor = monitor 75 | if ckpt_path is not None: 76 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu")["state_dict"] 80 | keys = list(sd.keys()) 81 | for k in keys: 82 | for ik in ignore_keys: 83 | if k.startswith(ik): 84 | print("Deleting key {} from state_dict.".format(k)) 85 | del sd[k] 86 | self.load_state_dict(sd, strict=False) 87 | print(f"Restored from {path}") 88 | 89 | def encode(self, x): 90 | h = self.encoder(x) 91 | moments = self.quant_conv(h) 92 | posterior = DiagonalGaussianDistribution(moments) 93 | # TODO check if need to put sample into DDIM_ldm class 94 | enc = posterior.sample() 95 | return enc #posterior 96 | 97 | def decode(self, z): 98 | z = self.post_quant_conv(z) 99 | dec = self.decoder(z) 100 | return dec 101 | 102 | def forward(self, input, sample_posterior=True): 103 | posterior = self.encode(input) 104 | if sample_posterior: 105 | z = posterior.sample() 106 | else: 107 | z = posterior.mode() 108 | dec = self.decode(z) 109 | return dec, posterior 110 | 111 | def get_input(self, batch, k): 112 | x = batch[k] 113 | if len(x.shape) == 3: 114 | x = x[..., None] 115 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 116 | return x 117 | 118 | def training_step(self, batch, batch_idx, optimizer_idx): 119 | inputs = self.get_input(batch, self.image_key) 120 | reconstructions, posterior = self(inputs) 121 | 122 | if optimizer_idx == 0: 123 | # train encoder+decoder+logvar 124 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 125 | last_layer=self.get_last_layer(), split="train") 126 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return aeloss 129 | 130 | if optimizer_idx == 1: 131 | # train the discriminator 132 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 133 | last_layer=self.get_last_layer(), split="train") 134 | 135 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 136 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 137 | return discloss 138 | 139 | def validation_step(self, batch, batch_idx): 140 | inputs = self.get_input(batch, self.image_key) 141 | reconstructions, posterior = self(inputs) 142 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, 
self.global_step, 143 | last_layer=self.get_last_layer(), split="val") 144 | 145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 146 | last_layer=self.get_last_layer(), split="val") 147 | 148 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 149 | self.log_dict(log_dict_ae) 150 | self.log_dict(log_dict_disc) 151 | return self.log_dict 152 | 153 | def configure_optimizers(self): 154 | lr = self.learning_rate 155 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 156 | list(self.decoder.parameters())+ 157 | list(self.quant_conv.parameters())+ 158 | list(self.post_quant_conv.parameters()), 159 | lr=lr, betas=(0.5, 0.9)) 160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 161 | lr=lr, betas=(0.5, 0.9)) 162 | return [opt_ae, opt_disc], [] 163 | 164 | def get_last_layer(self): 165 | return self.decoder.conv_out.weight 166 | 167 | @torch.no_grad() 168 | def log_images(self, batch, only_inputs=False, **kwargs): 169 | log = dict() 170 | x = self.get_input(batch, self.image_key) 171 | x = x.to(self.device) 172 | if not only_inputs: 173 | xrec, posterior = self(x) 174 | if x.shape[1] > 3: 175 | # colorize with random projection 176 | assert xrec.shape[1] > 3 177 | x = self.to_rgb(x) 178 | xrec = self.to_rgb(xrec) 179 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 180 | log["reconstructions"] = xrec 181 | log["inputs"] = x 182 | return log 183 | 184 | def to_rgb(self, x): 185 | assert self.image_key == "segmentation" 186 | if not hasattr(self, "colorize"): 187 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 188 | x = F.conv2d(x, weight=self.colorize) 189 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 190 | return x -------------------------------------------------------------------------------- /modules/openai_unet/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
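# Example of how the schedule helpers below (mirrored in model_utils.py) combine
# when a 1000-step DDPM schedule is subsampled for DDIM sampling; the concrete
# numbers are illustrative, with linear_start/linear_end taken from the
# celeba256 config in this repo:
#
#   betas = make_beta_schedule('linear', 1000, linear_start=0.0015, linear_end=0.0195)
#   alphacums = np.cumprod(1. - betas)                      # cumulative alpha_bar_t
#   ddim_steps = make_ddim_timesteps('uniform', 50, 1000)   # [1, 21, 41, ..., 981]
#   sigmas, a_t, a_prev = make_ddim_sampling_parameters(alphacums, ddim_steps, eta=0.)
#
# With eta = 0 every sigma_t is zero, i.e. fully deterministic DDIM updates; larger
# eta re-introduces stochasticity through
# sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).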
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from model_utils import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /modules/openclip/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | 9 | 10 | class AbstractEncoder(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def encode(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | 18 | class IdentityEncoder(AbstractEncoder): 19 | 20 | def encode(self, x): 21 | return x 22 | 23 | 24 | class ClassEmbedder(nn.Module): 25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 26 | super().__init__() 27 | self.key = key 28 | self.embedding = nn.Embedding(n_classes, embed_dim) 29 | self.n_classes = n_classes 30 | self.ucg_rate = ucg_rate 31 | 32 | def forward(self, batch, key=None, disable_dropout=False): 33 | if key is 
None: 34 | key = self.key 35 | # this is for use in crossattn 36 | c = batch[key][:, None] 37 | if self.ucg_rate > 0. and not disable_dropout: 38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 39 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 40 | c = c.long() 41 | c = self.embedding(c) 42 | return c 43 | 44 | def get_unconditional_conditioning(self, bs, device="cuda"): 45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 46 | uc = torch.ones((bs,), device=device) * uc_class 47 | uc = {self.key: uc} 48 | return uc 49 | 50 | 51 | def disabled_train(self, mode=True): 52 | """Overwrite model.train with this function to make sure train/eval mode 53 | does not change anymore.""" 54 | return self 55 | 56 | 57 | class FrozenT5Embedder(AbstractEncoder): 58 | """Uses the T5 transformer encoder for text""" 59 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 60 | super().__init__() 61 | self.tokenizer = T5Tokenizer.from_pretrained(version) 62 | self.transformer = T5EncoderModel.from_pretrained(version) 63 | self.device = device 64 | self.max_length = max_length # TODO: typical value? 65 | if freeze: 66 | self.freeze() 67 | 68 | def freeze(self): 69 | self.transformer = self.transformer.eval() 70 | #self.train = disabled_train 71 | for param in self.parameters(): 72 | param.requires_grad = False 73 | 74 | def forward(self, text): 75 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 76 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 77 | tokens = batch_encoding["input_ids"].to(self.device) 78 | outputs = self.transformer(input_ids=tokens) 79 | 80 | z = outputs.last_hidden_state 81 | return z 82 | 83 | def encode(self, text): 84 | return self(text) 85 | 86 | 87 | class FrozenCLIPEmbedder(AbstractEncoder): 88 | """Uses the CLIP transformer encoder for text (from huggingface)""" 89 | LAYERS = [ 90 | "last", 91 | "pooled", 92 | "hidden" 93 | ] 94 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 95 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 96 | super().__init__() 97 | assert layer in self.LAYERS 98 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 99 | self.transformer = CLIPTextModel.from_pretrained(version) 100 | self.device = device 101 | self.max_length = max_length 102 | if freeze: 103 | self.freeze() 104 | self.layer = layer 105 | self.layer_idx = layer_idx 106 | if layer == "hidden": 107 | assert layer_idx is not None 108 | assert 0 <= abs(layer_idx) <= 12 109 | 110 | def freeze(self): 111 | self.transformer = self.transformer.eval() 112 | #self.train = disabled_train 113 | for param in self.parameters(): 114 | param.requires_grad = False 115 | 116 | def forward(self, text): 117 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 118 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 119 | tokens = batch_encoding["input_ids"].to(self.device) 120 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 121 | if self.layer == "last": 122 | z = outputs.last_hidden_state 123 | elif self.layer == "pooled": 124 | z = outputs.pooler_output[:, None, :] 125 | else: 126 | z = outputs.hidden_states[self.layer_idx] 127 | return z 128 | 129 | def 
encode(self, text): 130 | return self(text) 131 | 132 | 133 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 134 | """ 135 | Uses the OpenCLIP transformer encoder for text 136 | """ 137 | LAYERS = [ 138 | #"pooled", 139 | "last", 140 | "penultimate" 141 | ] 142 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 143 | freeze=True, layer="last"): 144 | super().__init__() 145 | assert layer in self.LAYERS 146 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 147 | del model.visual 148 | self.model = model 149 | 150 | self.device = device 151 | self.max_length = max_length 152 | if freeze: 153 | self.freeze() 154 | self.layer = layer 155 | if self.layer == "last": 156 | self.layer_idx = 0 157 | elif self.layer == "penultimate": 158 | self.layer_idx = 1 159 | else: 160 | raise NotImplementedError() 161 | 162 | def freeze(self): 163 | self.model = self.model.eval() 164 | for param in self.parameters(): 165 | param.requires_grad = False 166 | 167 | def forward(self, text): 168 | tokens = open_clip.tokenize(text) 169 | z = self.encode_with_transformer(tokens.to(self.device)) 170 | return z 171 | 172 | def encode_with_transformer(self, text): 173 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 174 | x = x + self.model.positional_embedding 175 | x = x.permute(1, 0, 2) # NLD -> LND 176 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 177 | x = x.permute(1, 0, 2) # LND -> NLD 178 | x = self.model.ln_final(x) 179 | return x 180 | 181 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 182 | for i, r in enumerate(self.model.transformer.resblocks): 183 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 184 | break 185 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 186 | x = checkpoint(r, x, attn_mask) 187 | else: 188 | x = r(x, attn_mask=attn_mask) 189 | return x 190 | 191 | def encode(self, text): 192 | return self(text) 193 | 194 | 195 | class FrozenCLIPT5Encoder(AbstractEncoder): 196 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 197 | clip_max_length=77, t5_max_length=77): 198 | super().__init__() 199 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 200 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 201 | # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 202 | # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 203 | 204 | def encode(self, text): 205 | return self(text) 206 | 207 | def forward(self, text): 208 | clip_z = self.clip_encoder.encode(text) 209 | t5_z = self.t5_encoder.encode(text) 210 | return [clip_z, t5_z] 211 | 212 | 213 | -------------------------------------------------------------------------------- /pretrained_models/LAION_text2img/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_sd = torch.load("model.ckpt") 4 | sd = pl_sd["state_dict"] 5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'} 8 | 9 | torch.save(unet_sd, 'unet.ckpt') 10 | torch.save(vq_sd, 
'vqvae.ckpt') 11 | torch.save(cond_sd, 'bert.ckpt') -------------------------------------------------------------------------------- /pretrained_models/LAION_text2img/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /pretrained_models/SD1_5/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_sd = torch.load("model.ckpt") 4 | sd = pl_sd["state_dict"] 5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'} 8 | 9 | torch.save(unet_sd, 'unet.ckpt') 10 | torch.save(vq_sd, 'vqvae.ckpt') 11 | torch.save(cond_sd, 'clip.ckpt') -------------------------------------------------------------------------------- /pretrained_models/SD2_1/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_sd = torch.load("model.ckpt") 4 | sd = pl_sd["state_dict"] 5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 7 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'} 8 | 9 | torch.save(unet_sd, 'unet.ckpt') 10 | torch.save(vq_sd, 'vqvae.ckpt') 11 | torch.save(cond_sd, 'clip.ckpt') -------------------------------------------------------------------------------- /pretrained_models/anything4_5/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | sd = torch.load("model.ckpt") 4 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 5 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 6 | cond_sd = {k[17:]: v for k, v in sd.items() if 
k[:16]=='cond_stage_model'} 7 | 8 | torch.save(unet_sd, 'unet.ckpt') 9 | torch.save(vq_sd, 'vqvae.ckpt') 10 | torch.save(cond_sd, 'clip.ckpt') 11 | -------------------------------------------------------------------------------- /pretrained_models/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /pretrained_models/celeba256/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_sd = torch.load("model.ckpt") 4 | sd = pl_sd["state_dict"] 5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 7 | 8 | torch.save(unet_sd, 'unet.ckpt') 9 | torch.save(vq_sd, 'vqvae.ckpt') -------------------------------------------------------------------------------- /pretrained_models/celeba256/split_model_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_sd = torch.load("model.ckpt") 4 | sd = pl_sd["state_dict"] 5 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 6 | vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 7 | 8 | torch.save(unet_sd, 'unet.ckpt') 9 | torch.save(vq_sd, 'vqvae.ckpt') -------------------------------------------------------------------------------- /pretrained_models/counterfeitV25/split_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors import safe_open 3 | 4 | def load_safetensors(file_path): 5 | tensors = {} 6 | with safe_open(file_path, framework="pt", device="cpu") as f: 7 | for key in f.keys(): 8 | tensors[key] = f.get_tensor(key) 9 | return tensors 10 | 11 | sd = load_safetensors("counterfeitV25Pruned.safetensors") 12 | unet_sd = {k[22:]: v for k, v in sd.items() if k[:21]=='model.diffusion_model'} 13 | 
vq_sd = {k[18:]: v for k, v in sd.items() if k[:17]=='first_stage_model'} 14 | cond_sd = {k[17:]: v for k, v in sd.items() if k[:16]=='cond_stage_model'} 15 | 16 | torch.save(unet_sd, 'unet.ckpt') 17 | torch.save(vq_sd, 'vqvae.ckpt') 18 | torch.save(cond_sd, 'clip.ckpt') 19 | -------------------------------------------------------------------------------- /pretrained_models/negative/EasyNegative.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cplusx/layout_diffuse/9666cb867313aa693775f6134442dea3734565a5/pretrained_models/negative/EasyNegative.safetensors -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.2.1 2 | h5py 3 | imageio 4 | lmdb 5 | matplotlib 6 | opencv-python 7 | pillow 8 | pytorch-lightning 9 | scikit-image 10 | scikit-learn 11 | scipy 12 | #torch==1.12.1+cu116 13 | #torchvision==0.13.0+cu116 14 | tqdm 15 | wandb 16 | clean-fid 17 | einops 18 | pycocotools 19 | perceiver-pytorch 20 | transformers 21 | gdown 22 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 23 | open-clip-torch 24 | openai -------------------------------------------------------------------------------- /run_gradio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import gradio as gr 4 | import os 5 | import torch 6 | import json 7 | from train_utils import get_models, get_DDPM 8 | import logging 9 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 10 | from data.coco_w_stuff import get_coco_id_mapping 11 | import numpy as np 12 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights 13 | 14 | coco_id_to_name = get_coco_id_mapping() 15 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()} 16 | 17 | args = parse_test_args() 18 | ddpm_model = load_test_models(args) 19 | load_model_weights(ddpm_model=ddpm_model, args=args) 20 | 21 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 22 | ddpm_model = ddpm_model.to(device) 23 | ddpm_model.text_fn = ddpm_model.text_fn.to(device) 24 | ddpm_model.text_fn.device = device 25 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device) 26 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device) 27 | 28 | yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device) 29 | 30 | def obtain_bbox_from_yolo(image): 31 | H, W = image.shape[:2] 32 | results = yolo_model(image) 33 | # convert results to [x, y, w, h, object_name] 34 | xyxy_conf_cls = results.xyxy[0].detach().cpu().numpy() 35 | bboxes = [] 36 | for x1, y1, x2, y2, conf, cls_idx in xyxy_conf_cls: 37 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 38 | cls_name = yolo_model.names[int(cls_idx)] 39 | if conf >= 0.5: 40 | bboxes.append([x1 / W, y1 / H, (x2 - x1) / W, (y2 - y1) / H, cls_name]) 41 | return bboxes 42 | 43 | def save_bboxes(bboxes, save_dir): 44 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 45 | file_name = str(hash(str(current_time)))[1:10] 46 | os.makedirs(save_dir, exist_ok=True) 47 | save_path = os.path.join(save_dir, f'{file_name}.txt') 48 | with open(save_path, 'w') as OUT: 49 | for bbox in bboxes: 50 | OUT.write(','.join([str(x) for x in bbox])) 51 | OUT.write('\n') 52 | return save_path 53 | 54 | def 
sample_images(ref_image): 55 | bboxes = obtain_bbox_from_yolo(ref_image) 56 | bbox_path = save_bboxes(bboxes, 'tmp') 57 | image, image_with_bbox, canvas_with_bbox = sample_one_image( 58 | bbox_path, 59 | ddpm_model, 60 | device, 61 | coco_name_to_id, coco_id_to_name, 62 | api_key=args['openai_api_key'], 63 | image_size=ref_image.shape[:2], 64 | additional_caption=args['additional_caption'] 65 | ) 66 | os.remove(bbox_path) 67 | if image is None: 68 | # Return a placeholder image and a message 69 | placeholder = np.zeros((ref_image.shape[0], ref_image.shape[1], 3), dtype=np.uint8) 70 | message = "No object found in the image" 71 | return message, placeholder, placeholder, placeholder 72 | else: 73 | return "Success", image, image_with_bbox, canvas_with_bbox 74 | 75 | # Define the Gradio interface with a message component 76 | input_image = gr.inputs.Image() 77 | output_images = [gr.outputs.Image(type='numpy') for i in range(3)] 78 | message = gr.outputs.Textbox(label="Information", type="text") 79 | interface = gr.Interface( 80 | fn=sample_images, 81 | inputs=input_image, 82 | outputs=[message] + output_images, 83 | capture_session=True, 84 | title="LayoutDiffuse", 85 | description="Drop a reference image to generate a new image with the same layout", 86 | allow_flagging=False, 87 | live=False 88 | ) 89 | 90 | interface.launch(share=True) 91 | -------------------------------------------------------------------------------- /run_gradio_merge.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import gradio as gr 3 | import os 4 | import torch 5 | import logging 6 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 7 | from data.coco_w_stuff import get_coco_id_mapping 8 | import numpy as np 9 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights 10 | 11 | coco_id_to_name = get_coco_id_mapping() 12 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()} 13 | 14 | args = parse_test_args() 15 | ddpm_model = load_test_models(args) 16 | load_model_weights(ddpm_model=ddpm_model, args=args) 17 | 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | ddpm_model = ddpm_model.to(device) 20 | ddpm_model.text_fn = ddpm_model.text_fn.to(device) 21 | ddpm_model.text_fn.device = device 22 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device) 23 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device) 24 | 25 | # ddpm_model.merge('pretrained_models/anything4_5/unet.ckpt', alpha=1.) 26 | # ddpm_model.merge('pretrained_models/counterfeitV25/unet.ckpt', alpha=1.) 
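# The commented-out `merge` calls above blend a community UNet checkpoint into the
# loaded denoiser. The actual implementation lives in the DDPM class and is not
# shown in this file; the helper below is only a rough sketch of such a merge
# (per-tensor linear interpolation over matching keys -- the function name and the
# alpha semantics are assumptions, not the repo's API).
def linear_merge_state_dict(base_sd, other_ckpt_path, alpha=1.0):
    # alpha=0 keeps the base weights, alpha=1 takes the other checkpoint for shared keys
    other_sd = torch.load(other_ckpt_path, map_location='cpu')
    merged = {}
    for k, v in base_sd.items():
        if k in other_sd and other_sd[k].shape == v.shape:
            merged[k] = (1. - alpha) * v + alpha * other_sd[k].to(v)
        else:
            merged[k] = v  # keys missing from the other checkpoint stay unchanged
    return merged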
27 | 28 | yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device) 29 | 30 | def obtain_bbox_from_yolo(image): 31 | H, W = image.shape[:2] 32 | results = yolo_model(image) 33 | # convert results to [x, y, w, h, object_name] 34 | xyxy_conf_cls = results.xyxy[0].detach().cpu().numpy() 35 | bboxes = [] 36 | for x1, y1, x2, y2, conf, cls_idx in xyxy_conf_cls: 37 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 38 | cls_name = yolo_model.names[int(cls_idx)] 39 | if conf >= 0.5: 40 | bboxes.append([x1 / W, y1 / H, (x2 - x1) / W, (y2 - y1) / H, cls_name]) 41 | return bboxes 42 | 43 | def save_bboxes(bboxes, save_dir): 44 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 45 | file_name = str(hash(str(current_time)))[1:10] 46 | os.makedirs(save_dir, exist_ok=True) 47 | save_path = os.path.join(save_dir, f'{file_name}.txt') 48 | with open(save_path, 'w') as OUT: 49 | for bbox in bboxes: 50 | OUT.write(','.join([str(x) for x in bbox])) 51 | OUT.write('\n') 52 | return save_path 53 | 54 | def sample_images(ref_image, user_input): 55 | bboxes = obtain_bbox_from_yolo(ref_image) 56 | bbox_path = save_bboxes(bboxes, 'tmp') 57 | image, image_with_bbox, canvas_with_bbox = sample_one_image( 58 | bbox_path, 59 | ddpm_model, 60 | device, 61 | coco_name_to_id, coco_id_to_name, 62 | api_key=args['openai_api_key'], 63 | image_size=ref_image.shape[:2], 64 | additional_caption=args['additional_caption'] + user_input 65 | ) 66 | os.remove(bbox_path) 67 | if image is None: 68 | # Return a placeholder image and a message 69 | placeholder = np.zeros((ref_image.shape[0], ref_image.shape[1], 3), dtype=np.uint8) 70 | message = "No object found in the image" 71 | return message, placeholder, placeholder, placeholder 72 | else: 73 | return "Success", image, image_with_bbox, canvas_with_bbox 74 | 75 | # Define the Gradio interface with a message component 76 | input_image = gr.inputs.Image() 77 | input_text = gr.inputs.Textbox(type='text', label='Additional caption') 78 | output_images = [gr.outputs.Image(type='numpy') for i in range(3)] 79 | message = gr.outputs.Textbox(label="Information", type="text") 80 | interface = gr.Interface( 81 | fn=sample_images, 82 | inputs=[input_image, input_text], 83 | outputs=[message] + output_images, 84 | capture_session=True, 85 | title="LayoutDiffuse", 86 | description="Drop a reference image to generate a new image with the same layout", 87 | allow_flagging=False, 88 | live=False 89 | ) 90 | 91 | interface.launch(share=True) 92 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from pytorch_lightning import Trainer 6 | from train_utils import get_models, get_DDPM 7 | from test_utils import load_model_weights 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | '-c', '--config', type=str, 13 | default='config/train.json') 14 | parser.add_argument( 15 | '-n', '--num_repeat', type=int, 16 | default=1, help='the number of images for each condition') 17 | parser.add_argument( 18 | '-e', '--epoch', type=int, 19 | default=None, help='which epoch to evaluate, if None, will use the latest') 20 | parser.add_argument( 21 | '--nnode', type=int, default=1 22 | ) 23 | parser.add_argument( 24 | '--model_path', type=str, 25 | default=None, help='model path for generating layout diffuse, if not provided, will 
use the latest.ckpt') 26 | 27 | ''' parser configs ''' 28 | args_raw = parser.parse_args() 29 | with open(args_raw.config, 'r') as IN: 30 | args = json.load(IN) 31 | args.update(vars(args_raw)) 32 | # args['gpu_ids'] = [0] # DEBUG 33 | expt_name = args['expt_name'] 34 | expt_dir = args['expt_dir'] 35 | expt_path = os.path.join(expt_dir, expt_name) 36 | os.makedirs(expt_path, exist_ok=True) 37 | 38 | '''1. create denoising model''' 39 | denoise_args = args['denoising_model']['model_args'] 40 | models = get_models(args) 41 | 42 | diffusion_configs = args['diffusion'] 43 | ddpm_model = get_DDPM( 44 | diffusion_configs=diffusion_configs, 45 | log_args=args, 46 | **models 47 | ) 48 | 49 | '''2. create a dataloader which generates''' 50 | from test_utils import get_test_dataset, get_test_callbacks 51 | test_dataset, test_loader = get_test_dataset(args) 52 | 53 | '''3. callbacks''' 54 | callbacks = get_test_callbacks(args, expt_path) 55 | 56 | '''4. load checkpoint''' 57 | print('INFO: loading checkpoint') 58 | if args['model_path'] is not None: 59 | ckpt_path = args['model_path'] 60 | else: 61 | expt_path = os.path.join(args['expt_dir'], args['expt_name']) 62 | if args['epoch'] is None: 63 | ckpt_to_use = 'latest.ckpt' 64 | else: 65 | ckpt_to_use = f'epoch={args["epoch"]:04d}.ckpt' 66 | ckpt_path = os.path.join(expt_path, ckpt_to_use) 67 | print(ckpt_path) 68 | if os.path.exists(ckpt_path): 69 | print(f'INFO: Found checkpoint {ckpt_path}') 70 | # ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict'] 71 | ''' DEBUG ''' 72 | # ckpt_denoise_fn = {k.replace('denoise_fn.', ''): v for k, v in ckpt.items() if 'denoise_fn' in k} 73 | # ddpm_model.denoise_fn.load_state_dict(ckpt_denoise_fn) 74 | # ddpm_model.load_state_dict(ckpt) 75 | else: 76 | ckpt_path = None 77 | raise RuntimeError('Cannot do inference without pretrained checkpoint') 78 | 79 | '''5. trianer''' 80 | trainer_args = { 81 | "max_epochs": 1000, 82 | "accelerator": "gpu", 83 | "devices": [0], 84 | "limit_val_batches": 1, 85 | "strategy": "ddp", 86 | "check_val_every_n_epoch": 1, 87 | "num_nodes": args['nnode'] 88 | # "benchmark" :True 89 | } 90 | config_trainer_args = args['trainer_args'] if args.get('trainer_args') is not None else {} 91 | trainer_args.update(config_trainer_args) 92 | print(f'Training args are {trainer_args}') 93 | trainer = Trainer( 94 | callbacks = callbacks, 95 | **trainer_args 96 | ) 97 | 98 | '''6. 
start sampling''' 99 | '''use trainer for sampling, you need a image saver callback to save images, useful for generate many images''' 100 | num_loop = args['num_repeat'] 101 | for _ in range(num_loop): 102 | # trainer.test(ddpm_model, test_loader) # DEBUG 103 | trainer.test(ddpm_model, test_loader, ckpt_path=ckpt_path) 104 | -------------------------------------------------------------------------------- /sampling_in_background.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.coco_w_stuff import get_coco_id_mapping 4 | import numpy as np 5 | import cv2 6 | import time 7 | from test_utils import sample_one_image, parse_test_args, load_test_models, load_model_weights 8 | coco_id_to_name = get_coco_id_mapping() 9 | coco_name_to_id = {v: int(k) for k, v in coco_id_to_name.items()} 10 | 11 | if __name__ == '__main__': 12 | args = parse_test_args() 13 | ddpm_model = load_test_models(args) 14 | load_model_weights(ddpm_model=ddpm_model, args=args) 15 | 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | ddpm_model = ddpm_model.to(device) 18 | ddpm_model.text_fn = ddpm_model.text_fn.to(device) 19 | ddpm_model.text_fn.device = device 20 | ddpm_model.denoise_fn = ddpm_model.denoise_fn.to(device) 21 | ddpm_model.vqvae_fn = ddpm_model.vqvae_fn.to(device) 22 | 23 | while True: 24 | # read file in the folder. If there is a file, sample the image and save it to the folder "flask_images_sampled" and remove the file from the folder "flask_images_to_sample" 25 | 26 | from glob import glob 27 | files_to_sample = glob('interactive_plotting/tmp/*.txt') 28 | for f in files_to_sample: 29 | print('INFO: processing file', f) 30 | image, image_with_bbox, canvas_with_bbox = sample_one_image( 31 | f, ddpm_model, device, 32 | class_name_to_id=coco_name_to_id, 33 | class_id_to_name=coco_id_to_name, 34 | api_key=args['openai_api_key'], 35 | additional_caption=args['additional_caption'] 36 | ) 37 | # save the image 38 | cat_image = np.concatenate([image, image_with_bbox, canvas_with_bbox], axis=1) 39 | cv2.imwrite(f.replace('.txt', '.jpg'), (cat_image[..., ::-1] * 255).astype(np.uint8)) 40 | # remove the file 41 | os.remove(f) 42 | 43 | time.sleep(1) -------------------------------------------------------------------------------- /scripts/convert_jpg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | from tqdm import tqdm 5 | 6 | def read_convert_and_save(img_path, save_path): 7 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 8 | image = cv2.imread(img_path) 9 | cv2.imwrite(save_path, image) 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--indir', type=str) 14 | 15 | ''' parser configs ''' 16 | args = parser.parse_args() 17 | 18 | in_dir = args.indir 19 | out_dir = os.path.join( 20 | os.path.dirname(in_dir), 21 | os.path.basename(in_dir) + f'-jpg' 22 | ) 23 | 24 | image_names = os.listdir(in_dir) 25 | 26 | for image_name in tqdm(image_names, desc='convert image to jpg'): 27 | if not (image_name.endswith('.jpg') or image_name.endswith('.png')): 28 | continue 29 | img_path = os.path.join(in_dir, image_name) 30 | save_img_path = os.path.join(out_dir, image_name.replace('.png', '.jpg')) 31 | if os.path.exists(save_img_path): 32 | continue 33 | read_convert_and_save(img_path, save_img_path) -------------------------------------------------------------------------------- 
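sampling_in_background.py above polls interactive_plotting/tmp/ for layout .txt files dropped by the interactive plotting app and renders each one. As a minimal sketch of driving it by hand, the snippet below writes such a file, assuming sample_one_image accepts the same comma-separated rows that save_bboxes in run_gradio.py produces (x, y, w, h normalized to [0, 1] plus a COCO class name); the file name and box values here are made up:

import os

boxes = [
    (0.10, 0.55, 0.35, 0.40, 'dog'),
    (0.55, 0.20, 0.40, 0.60, 'person'),
]
os.makedirs('interactive_plotting/tmp', exist_ok=True)
with open('interactive_plotting/tmp/example_layout.txt', 'w') as OUT:
    for x, y, w, h, name in boxes:
        OUT.write(','.join(str(v) for v in (x, y, w, h, name)) + '\n')

While sampling_in_background.py is running, it picks the file up, saves the concatenated result as example_layout.jpg next to it, and deletes the .txt file.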
/scripts/convert_npz_to_npy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import glob 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-s', '--src', type=str, default='', help='source images directory') 9 | args = parser.parse_args() 10 | 11 | indir = args.src 12 | outdir = indir+'_npy' 13 | os.makedirs(outdir, exist_ok=True) 14 | 15 | npz_files = glob.glob(indir + '/*.npz') 16 | print(len(npz_files)) 17 | for npz_file in tqdm(npz_files): 18 | out_path = npz_file.replace(indir, outdir) 19 | out_path = out_path.replace('npz', 'npy') 20 | image = np.load(npz_file)['image'] 21 | 22 | with open(out_path, 'wb') as OUT: 23 | np.save(OUT, image*255) -------------------------------------------------------------------------------- /scripts/download_celebMask.sh: -------------------------------------------------------------------------------- 1 | DIR=~/disk2/data/CelebAMask-HQ 2 | mkdir -p $DIR 3 | 4 | cd $DIR 5 | gdown https://drive.google.com/uc?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv -------------------------------------------------------------------------------- /scripts/download_coco.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/disk2/data/COCO 2 | cd ~/disk2/data/COCO 3 | wget http://images.cocodataset.org/zips/train2017.zip 4 | wget http://images.cocodataset.org/zips/val2017.zip 5 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 6 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip 7 | ls *.zip | while read f; do 8 | unzip $f; 9 | done 10 | -------------------------------------------------------------------------------- /scripts/download_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | download_face() { 4 | mkdir -p pretrained_models/celeba256 5 | wget -O pretrained_models/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 6 | cd pretrained_models/celeba256 7 | unzip -o celeba-256.zip 8 | python split_model.py 9 | } 10 | 11 | download_ldm() { 12 | mkdir -p pretrained_models/LAION_text2img 13 | wget -O pretrained_models/LAION_text2img/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt 14 | cd pretrained_models/LAION_text2img 15 | python split_model.py 16 | } 17 | 18 | download_sd1_5() { 19 | mkdir -p pretrained_models/SD1_5 20 | wget -O pretrained_models/SD1_5/model.ckpt https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt 21 | cd pretrained_models/SD1_5 22 | python split_model.py 23 | } 24 | 25 | download_sd2_1() { 26 | mkdir -p pretrained_models/SD2_1 27 | wget -O pretrained_models/SD2_1/model.ckpt https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-nonema-pruned.ckpt 28 | cd pretrained_models/SD2_1 29 | python split_model.py 30 | } 31 | 32 | download_all() { 33 | download_face 34 | cd ../.. 35 | download_ldm 36 | cd ../.. 37 | download_sd1_5 38 | cd ../.. 39 | download_sd2_1 40 | } 41 | 42 | case $1 in 43 | "face") 44 | download_face 45 | ;; 46 | "ldm") 47 | download_ldm 48 | ;; 49 | "SD1_5") 50 | download_sd1_5 51 | ;; 52 | "SD2_1") 53 | download_sd2_1 54 | ;; 55 | "all") 56 | download_all 57 | ;; 58 | *) 59 | echo "Invalid argument. 
Usage: bash download.sh [face|ldm|SD1_5|SD2_1|all]" 60 | ;; 61 | esac 62 | -------------------------------------------------------------------------------- /scripts/download_vg.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | VG_DIR=~/disk2/data/VG 16 | mkdir -p $VG_DIR 17 | 18 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip 19 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip 20 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip 21 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt 22 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt 23 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip 24 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip 25 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip 26 | 27 | unzip $VG_DIR/objects.json.zip -d $VG_DIR 28 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR 29 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR 30 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR 31 | unzip $VG_DIR/images.zip -d $VG_DIR/images 32 | unzip $VG_DIR/images2.zip -d $VG_DIR/images 33 | 34 | python scripts/preprocess_vg.py -------------------------------------------------------------------------------- /scripts/eval_scripts/celeb_mask.sh: -------------------------------------------------------------------------------- 1 | # python fid_eval.py \ 2 | # --dataset celeb_mask \ 3 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img-train-256x256 \ 4 | # -d experiments/celeb_mask_ldm_partial_attn/sampling_at_00279_image 5 | 6 | run_once () { 7 | res=$1 8 | epoch=$2 9 | python fid_eval.py \ 10 | -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \ 11 | --resize_s \ 12 | -d experiments/celeb_mask_ldm_${res}_samples/epoch_${epoch}/image > tmp/data_efficiency_res_${res}_epoch_${epoch}.txt 13 | } 14 | 15 | # CUDA_VISIBLE_DEVICES=1 run_once 128 "00099" && run_once 128 "00199" && run_once 128 "00499" && run_once 128 "00999" && run_once 128 "01999" & 16 | 17 | # CUDA_VISIBLE_DEVICES=2 run_once 256 "00049" && run_once 256 "00099" && run_once 256 "00249" && run_once 256 "00499" && run_once 256 "00999" & 18 | 19 | # CUDA_VISIBLE_DEVICES=3 run_once 512 "00024" && run_once 512 "00049" && run_once 512 "00124" && run_once 512 "00249" && run_once 512 "00499" & 20 | 21 | CUDA_VISIBLE_DEVICES=0 run_once 1024 "00012" && run_once 1024 "00024" && run_once 1024 "00062" && run_once 1024 "00124" && run_once 1024 "00249" & 22 | 23 | CUDA_VISIBLE_DEVICES=1 run_once 2048 "00006" && run_once 2048 "00012" && run_once 2048 "00031" && run_once 2048 
"00062" && run_once 2048 "00124" & 24 | 25 | # seq 4 5 10| while read e; do 26 | # python fid_eval.py \ 27 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \ 28 | # --resize_s \ 29 | # -d experiments/celeb_mask_ldm_v2/epoch_0000$e/image > tmp/celeb_v2_fid_e_$e.txt 30 | # done 31 | 32 | # seq 14 5 30 | while read e; do 33 | # python fid_eval.py \ 34 | # -s /home/ubuntu/disk2/data/face/CelebAMask-HQ/CelebA-HQ-img \ 35 | # --resize_s \ 36 | # -d experiments/celeb_mask_ldm_v2/epoch_000$e/image > tmp/celeb_v2_fid_e_$e.txt 37 | # done -------------------------------------------------------------------------------- /scripts/eval_scripts/convert_npz_to_npy.sh: -------------------------------------------------------------------------------- 1 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00009_plms_100_5.0/raw_tensor 2 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00029_plms_100_5.0/raw_tensor 3 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn/epoch_00059_plms_100_5.0/raw_tensor 4 | 5 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00059_plms_200_5.0/raw_tensor 6 | seq 4 5 10 | while read e; do 7 | python scripts/convert_npz_to_npy.py -s experiments/celeb_mask_ldm_v2/epoch_0000$e/raw_tensor 8 | done 9 | 10 | seq 14 5 30 | while read e; do 11 | python scripts/convert_npz_to_npy.py -s experiments/celeb_mask_ldm_v2/epoch_000$e/raw_tensor 12 | done 13 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00009/raw_tensor 14 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00029/raw_tensor 15 | 16 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00009_plms_100_5.0/raw_tensor 17 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00029_plms_100_5.0/raw_tensor 18 | # python scripts/convert_npz_to_npy.py -s experiments/laion_ldm_cocostuff_layout_no_caption/epoch_00059_plms_100_5.0/raw_tensor 19 | -------------------------------------------------------------------------------- /scripts/eval_scripts/fid_coco_layout_ablation.sh: -------------------------------------------------------------------------------- 1 | run_once () { 2 | expt=$1 3 | appendix=$3 4 | echo "Score for ${expt}, epoch $2" 5 | python fid_eval.py \ 6 | -s /home/ubuntu/disk2/data/COCO/train2017 \ 7 | --resize_s \ 8 | -d experiments/${expt}/epoch_000$2$appendix/raw_tensor_npy 9 | # -d experiments/${expt}/epoch_000$2$appendix/sample_image 10 | } 11 | 12 | expt="laion_ldm_cocostuff_layout_no_caption" 13 | appendix="_plms_100_5.0" 14 | epoch="09" 15 | CUDA_VISIBLE_DEVICES=5 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 16 | 17 | expt="laion_ldm_cocostuff_layout_no_caption" 18 | appendix="_plms_100_5.0" 19 | epoch="29" 20 | CUDA_VISIBLE_DEVICES=6 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 21 | 22 | expt="laion_ldm_cocostuff_layout_no_caption" 23 | appendix="_plms_100_5.0" 24 | epoch="59" 25 | CUDA_VISIBLE_DEVICES=7 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 26 | 27 | # expt="laion_ldm_cocostuff_layout_caption_v9" 28 | # appendix="_plms_200_5.0" 29 | # epoch="59" 30 | # CUDA_VISIBLE_DEVICES=6 
run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 31 | 32 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn" 33 | # appendix="_plms_100_5.0" 34 | # epoch="09" 35 | # CUDA_VISIBLE_DEVICES=6 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 36 | 37 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn" 38 | # appendix="_plms_100_5.0" 39 | # epoch="29" 40 | # CUDA_VISIBLE_DEVICES=5 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & 41 | 42 | # expt="laion_ldm_cocostuff_layout_caption_ablation_no_instance_attn" 43 | # appendix="_plms_100_5.0" 44 | # epoch="59" 45 | # CUDA_VISIBLE_DEVICES=7 run_once $expt $epoch $appendix >> tmp/fid_${expt}_${epoch}.txt 2>&1 & -------------------------------------------------------------------------------- /scripts/remove_empty_file_in_vg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from glob import glob 4 | 5 | VG_DIR = '/home/ubuntu/disk2/data/VG/images' 6 | # VG_DIR = 'experiments/laion_ldm_cocostuff_layout_caption_v9/epoch_00059_plms_200_5.0/sampled_256_cropped_224' 7 | image_paths = glob(VG_DIR+'/**/*.jpg') + glob(VG_DIR+'/**/*.png') 8 | 9 | for path in image_paths: 10 | try: 11 | Image.open(path) 12 | except: 13 | print(f'{path} failed, remove it') 14 | os.system(f'rm {path}') -------------------------------------------------------------------------------- /scripts/resize_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import glob 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from PIL import ImageFile 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | def process_image(img_path, save_path, size, mode): 11 | print('save image to ', save_path) 12 | img = Image.open(img_path) 13 | img = img.resize((size, size), mode) 14 | img = img.save(save_path) 15 | 16 | def read_resize_and_save(img_path, save_path, size, mode=Image.BICUBIC): 17 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 18 | if img_path.endswith('.png') or img_path.endswith('.jpg'): 19 | process_image(img_path, save_path, size, mode) 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--indir', type=str) 24 | parser.add_argument('--size', type=int) 25 | 26 | ''' parser configs ''' 27 | args = parser.parse_args() 28 | size = args.size 29 | 30 | in_dir = args.indir 31 | out_dir = os.path.join( 32 | os.path.dirname(in_dir), 33 | os.path.basename(in_dir) + f'-{size}' 34 | ) 35 | 36 | image_names = glob.glob(in_dir + '/*.jpg') + glob.glob(in_dir + '/*.png') + glob.glob(in_dir + '/**/*.jpg') + glob.glob(in_dir + '/**/*.png') 37 | 38 | for image_name in tqdm(image_names): 39 | save_img_path = image_name.replace(in_dir, out_dir) 40 | if image_name.endswith('.jpg'): 41 | save_img_path = save_img_path.replace('.jpg', '.png') 42 | if os.path.exists(save_img_path): 43 | continue 44 | try: 45 | read_resize_and_save(image_name, save_img_path, size, mode=Image.BICUBIC) 46 | except: 47 | print(image_name, 'is broken') -------------------------------------------------------------------------------- /scripts/sampling_scripts/dist_sampling.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=1 2 | export MKL_NUM_THREADS=1 3 | export NNODE=3 4 | torchrun \ 5 | --nnodes=$NNODE \ 6 | --nproc_per_node 8 \ 7 | 
--rdzv_id v9_dist_sample \ 8 | --rdzv_backend c10d \ 9 | --rdzv_endpoint $1:29500 \ 10 | sampling.py -c $2 --nnode $NNODE -e $3 -n $4 11 | 12 | # usage: bash scripts/sampling_scripts/dist_sampling.sh \ 13 | # 172.31.0.139 configs/laion_cocostuff_text_v9.json \ 14 | # 59 5 # this is machine 1 ip address 15 | -------------------------------------------------------------------------------- /scripts/train_scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=1 2 | export MKL_NUM_THREADS=1 3 | export NNODE=4 4 | torchrun \ 5 | --nnodes=$NNODE \ 6 | --nproc_per_node 4 \ 7 | --rdzv_id v9_dist \ 8 | --rdzv_backend c10d \ 9 | --rdzv_endpoint $1:29500 \ 10 | main.py -c $2 -n $NNODE -r 11 | 12 | # usage: bash scripts/train_scripts/dist_train.sh 172.31.42.68 -------------------------------------------------------------------------------- /test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from data.random_sampling import RandomNoise 6 | from model_utils import default, get_obj_from_str 7 | from callbacks.coco_layout.sampling_save_fig import ColorMapping, plot_bbox_without_overlap, plot_bounding_box 8 | import cv2 9 | 10 | def get_test_dataset(args): 11 | sampling_args = args['sampling_args'] 12 | sampling_w_noise = default(sampling_args.get('sampling_w_noise'), False) 13 | if sampling_w_noise: 14 | test_dataset = RandomNoise( 15 | sampling_args['image_size'], 16 | sampling_args['image_size'], 17 | sampling_args['in_channel'], 18 | sampling_args['num_samples'] 19 | ) 20 | else: 21 | from data import get_dataset 22 | args['data']['val_args']['data_len'] = sampling_args['num_samples'] 23 | _, test_dataset = get_dataset(**args['data']) 24 | test_loader = DataLoader(test_dataset, batch_size=args['data']['batch_size'], num_workers=4, shuffle=False) 25 | return test_dataset, test_loader 26 | 27 | def get_test_callbacks(args, expt_path): 28 | sampling_args = args['sampling_args'] 29 | callbacks = [] 30 | callbacks_obj = sampling_args.get('callbacks') 31 | for target in callbacks_obj: 32 | callbacks.append( 33 | get_obj_from_str(target)(expt_path) 34 | ) 35 | return callbacks 36 | 37 | def postprocess_image(batched_x, batched_bbox, class_id_to_name, image_callback=lambda x: x): 38 | x = batched_x[0] 39 | bbox = batched_bbox[0] 40 | x = x.permute(1, 2, 0).detach().cpu().numpy().clip(-1, 1) 41 | x = (x + 1) / 2 42 | x = image_callback(x) 43 | image_with_bbox = overlap_image_with_bbox(x, bbox, class_id_to_name) 44 | canvas_with_bbox = overlap_image_with_bbox(np.ones_like(x), bbox, class_id_to_name) 45 | return x, image_with_bbox, canvas_with_bbox 46 | 47 | def overlap_image_with_bbox(image, bbox, class_id_to_name): 48 | label_color_mapper = ColorMapping(id_class_mapping=class_id_to_name) 49 | image_with_bbox = plot_bbox_without_overlap( 50 | image.copy(), 51 | bbox, 52 | label_color_mapper 53 | ) if len(bbox) <= 10 else None 54 | if image_with_bbox is not None: 55 | return image_with_bbox 56 | return plot_bounding_box( 57 | image.copy(), 58 | bbox, 59 | label_color_mapper 60 | ) 61 | 62 | def generate_completion(caption, api_key, additional_caption=''): 63 | import openai 64 | # check if api_key is valid 65 | def validate_api_key(api_key): 66 | import re 67 | regex = "^sk-[a-zA-Z0-9]{48}$" # regex pattern for OpenAI API key 68 | if not isinstance(api_key, str): 69 | return None 70 | if not 
re.match(regex, api_key): 71 | return None 72 | return api_key 73 | openai.api_key = validate_api_key(api_key) 74 | if openai.api_key is None: 75 | print('WARNING: invalid OpenAI API key, using default caption') 76 | return caption 77 | prompt = 'Describe a scene with the following words: ' + caption + '. Use the above words to generate a prompt for drawing with a diffusion model. Use at least 30 words and at most 80 words and include all given words. The final image should look nice and be related to the given words' 78 | 79 | response = openai.ChatCompletion.create( 80 | model="gpt-3.5-turbo", 81 | messages=[{ 82 | "role": "user", 83 | "content": prompt 84 | }] 85 | ) 86 | 87 | return response.choices[0].message.content.strip() + additional_caption 88 | 89 | def concatenate_class_labels_to_caption(objects, class_id_to_name, api_key=None, additional_caption=''): 90 | # if you want to add an additional description for styles, add it to additional_caption 91 | caption = '' 92 | for i in objects: 93 | caption += class_id_to_name[i[4]+1] + ', ' 94 | caption = caption.rstrip(', ') 95 | if api_key is not None: 96 | caption = generate_completion(caption, api_key=api_key, additional_caption=additional_caption) 97 | print('INFO: using openai text completion and the generated caption is: \n', caption) 98 | else: 99 | caption = caption + additional_caption 100 | print('INFO: using default caption: \n', caption) 101 | return caption 102 | 103 | def sample_one_image(bbox_path, ddpm_model, device, class_name_to_id, class_id_to_name, api_key=None, image_size=(512, 512), additional_caption=''): 104 | # each line of the text file is: x, y, w, h, class_name 105 | with open(bbox_path, 'r') as IN: 106 | raw_objects = [i.strip().split(',') for i in IN] 107 | objects = [] 108 | for i in raw_objects: 109 | i[0] = float(i[0]) 110 | i[1] = float(i[1]) 111 | i[2] = float(i[2]) 112 | i[3] = float(i[3]) 113 | class_name = i[4].strip() 114 | if class_name in class_name_to_id: 115 | # skip objects whose class name does not appear in COCO 116 | i[4] = int(class_name_to_id[class_name]) - 1 117 | objects.append(i) 118 | if len(objects) == 0: 119 | return None, None, None 120 | batch = [] 121 | image_resizer = ImageResizer() 122 | new_h, new_w = image_resizer.get_proper_size(image_size) 123 | batch.append(torch.randn(1, 3, new_h, new_w).to(device)) 124 | batch.append(torch.from_numpy(np.array(objects)).to(device).unsqueeze(0)) 125 | batch.append(( 126 | concatenate_class_labels_to_caption(objects, class_id_to_name, api_key, additional_caption), 127 | )) 128 | res = ddpm_model.test_step(batch, 0) # we pass a full batch, but only the text and layout are used when sampling 129 | sampled_images = res['sampling']['model_output'] 130 | return postprocess_image(sampled_images, batch[1], class_id_to_name, image_callback=lambda x: image_resizer.to_original_size(x)) 131 | 132 | 133 | class ImageResizer: 134 | def __init__(self): 135 | self.original_size = None 136 | 137 | def to_proper_size(self, img): 138 | # Get the new height and width that can be divided by 64 139 | new_h, new_w = self.get_proper_size(img.shape[:2]) 140 | 141 | # Resize the image using OpenCV's resize function 142 | resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) 143 | 144 | return resized 145 | 146 | def to_original_size(self, img): 147 | # Resize the image to original size using OpenCV's resize function 148 | resized = cv2.resize(img, (self.original_size[1], self.original_size[0]), interpolation=cv2.INTER_AREA) 149 | 150 |
return resized 151 | 152 | def get_proper_size(self, size): 153 | self.original_size = size 154 | # Calculate the new height and width that can be divided by 64 155 | if size[0] % 64 == 0: 156 | new_h = size[0] 157 | else: 158 | new_h = size[0] + (64 - size[0] % 64) 159 | 160 | if size[1] % 64 == 0: 161 | new_w = size[1] 162 | else: 163 | new_w = size[1] + (64 - size[1] % 64) 164 | 165 | return new_h, new_w 166 | 167 | def parse_test_args(): 168 | import argparse 169 | import json 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument( 172 | '-c', '--config', type=str, 173 | default='config/train.json') 174 | parser.add_argument( 175 | '-e', '--epoch', type=int, 176 | default=None, help='which epoch to evaluate, if None, will use the latest') 177 | parser.add_argument( 178 | '--openai_api_key', type=str, 179 | default=None, help='openai api key for generating text prompt') 180 | parser.add_argument( 181 | '--model_path', type=str, 182 | default=None, help='model path for generating layout diffuse, if not provided, will use the latest.ckpt') 183 | parser.add_argument( 184 | '--additional_caption', type=str, 185 | default='', help='additional caption for the generated image') 186 | 187 | ''' parser configs ''' 188 | args_raw = parser.parse_args() 189 | with open(args_raw.config, 'r') as IN: 190 | args = json.load(IN) 191 | args.update(vars(args_raw)) 192 | return args 193 | 194 | def load_test_models(args): 195 | from train_utils import get_models, get_DDPM 196 | models = get_models(args) 197 | 198 | diffusion_configs = args['diffusion'] 199 | ddpm_model = get_DDPM( 200 | diffusion_configs=diffusion_configs, 201 | log_args=args, 202 | **models 203 | ) 204 | return ddpm_model 205 | 206 | def load_model_weights(ddpm_model, args): 207 | print('INFO: loading checkpoint') 208 | if args['model_path'] is not None: 209 | ckpt_path = args['model_path'] 210 | else: 211 | expt_path = os.path.join(args['expt_dir'], args['expt_name']) 212 | if args['epoch'] is None: 213 | ckpt_to_use = 'latest.ckpt' 214 | else: 215 | ckpt_to_use = f'epoch={args["epoch"]:04d}.ckpt' 216 | ckpt_path = os.path.join(expt_path, ckpt_to_use) 217 | print(ckpt_path) 218 | if os.path.exists(ckpt_path): 219 | print(f'INFO: Found checkpoint {ckpt_path}') 220 | ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict'] 221 | ddpm_model.load_state_dict(ckpt) 222 | else: 223 | ckpt_path = None 224 | raise RuntimeError('Cannot do inference without pretrained checkpoint') -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | # from modules.vae.vae import BetaVAE 3 | from pytorch_lightning.loggers import WandbLogger 4 | from callbacks import get_epoch_checkpoint, get_latest_checkpoint, get_iteration_checkpoint 5 | from model_utils import instantiate_from_config, get_obj_from_str 6 | 7 | def get_models(args): 8 | denoise_model = args['denoising_model']['model'] 9 | denoise_args = args['denoising_model']['model_args'] 10 | denoise_fn = instantiate_from_config({ 11 | 'target': denoise_model, 12 | 'params': denoise_args 13 | }) 14 | model_dict = { 15 | 'denoise_fn': denoise_fn, 16 | } 17 | 18 | if args.get('vqvae_model'): 19 | vq_model = args['vqvae_model']['model'] 20 | vq_args = args['vqvae_model']['model_args'] 21 | vqvae_fn = instantiate_from_config({ 22 | 'target': vq_model, 23 | 'params': vq_args 24 | }) 25 | 26 | model_dict['vqvae_fn'] = vqvae_fn 27 | 28 | if 
args.get('text_model'): 29 | text_model = args['text_model']['model'] 30 | text_args = args['text_model']['model_args'] 31 | text_fn = instantiate_from_config({ 32 | 'target': text_model, 33 | 'params': text_args 34 | }) 35 | 36 | model_dict['text_fn'] = text_fn 37 | 38 | return model_dict 39 | 40 | def get_DDPM(diffusion_configs, log_args={}, **models): 41 | diffusion_model_class = diffusion_configs['model'] 42 | diffusion_args = diffusion_configs['model_args'] 43 | DDPM_model = get_obj_from_str(diffusion_model_class) 44 | ddpm_model = DDPM_model( 45 | log_args=log_args, 46 | **models, 47 | **diffusion_args 48 | ) 49 | return ddpm_model 50 | 51 | 52 | def get_logger_and_callbacks(expt_name, expt_path, args): 53 | callbacks = [] 54 | # 3.1 checkpoint callbacks 55 | save_model_config = args.get('save_model_config', {}) 56 | epoch_checkpoint = get_epoch_checkpoint(expt_path, **save_model_config) 57 | latest_checkpoint = get_latest_checkpoint(expt_path) 58 | callbacks.append(epoch_checkpoint) 59 | callbacks.append(latest_checkpoint) 60 | 61 | # 3.2 wandb logger 62 | wandb_logger = WandbLogger( 63 | project=expt_name, 64 | ) 65 | iteration_callbacks = args.get('iteration_callbacks') 66 | if iteration_callbacks: 67 | callbacks.append(get_iteration_checkpoint(expt_path)) 68 | config_callbacks = args.get('callbacks') 69 | if config_callbacks is not None: 70 | for callback in config_callbacks: 71 | print(f'Initiate callback {callback}') 72 | callbacks.append( 73 | get_obj_from_str(callback)( 74 | wandb_logger=wandb_logger, 75 | max_num_images=8 76 | ) 77 | ) 78 | else: 79 | from callbacks import WandBImageLogger 80 | print(f'INFO: got {expt_name}, will use default image logger') 81 | wandb_callback = WandBImageLogger( 82 | wandb_logger=wandb_logger, 83 | max_num_images=8 84 | ) 85 | callbacks.append(wandb_callback) 86 | 87 | return wandb_logger, callbacks 88 | 89 | if os.path.exists('negative/EasyNegative.safetensors'): 90 | from safetensors import safe_open 91 | with safe_open('negative/EasyNegative.safetensors', framework="pt", device="cpu") as f: 92 | NEGATIVE_PROMPTS_EMBEDDINGS = f.get_tensor('emb_params') 93 | else: 94 | NEGATIVE_PROMPTS_EMBEDDINGS = None 95 | NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless" 96 | 97 | def obtain_state_dict_key_mapping(key_in_layout_diffuse): 98 | key_only_in_layout_diffuse = False 99 | if key_in_layout_diffuse == 'output_blocks.5.3.conv.weight': 100 | key_in_foundational_model = 'output_blocks.5.2.conv.weight' 101 | elif key_in_layout_diffuse == 'output_blocks.5.3.conv.bias': 102 | 
key_in_foundational_model = 'output_blocks.5.2.conv.bias' 103 | elif key_in_layout_diffuse == 'output_blocks.8.3.conv.weight': 104 | key_in_foundational_model = 'output_blocks.8.2.conv.weight' 105 | elif key_in_layout_diffuse == 'output_blocks.8.3.conv.bias': 106 | key_in_foundational_model = 'output_blocks.8.2.conv.bias' 107 | else: 108 | key_in_foundational_model = key_in_layout_diffuse 109 | key_only_in_layout_diffuse = True 110 | return key_in_foundational_model, key_only_in_layout_diffuse --------------------------------------------------------------------------------
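The helper above remaps the few output_blocks conv keys whose indices shift once LayoutDiffuse inserts its adapter layers into the UNet. A minimal sketch of how such a mapping could be used to seed a LayoutDiffuse denoiser from a foundational-model state dict is given below; build_init_state_dict, layout_diffuse_denoiser and foundational_sd are hypothetical names for illustration, not the loading code this repository actually ships.

from train_utils import obtain_state_dict_key_mapping

def build_init_state_dict(layout_diffuse_denoiser, foundational_sd):
    # start from the freshly initialized LayoutDiffuse weights
    init_sd = dict(layout_diffuse_denoiser.state_dict())
    for k in list(init_sd.keys()):
        # map each LayoutDiffuse key to its (possibly re-indexed) foundational key
        src_key, _ = obtain_state_dict_key_mapping(k)
        if src_key in foundational_sd:
            init_sd[k] = foundational_sd[src_key]
        # keys that exist only in LayoutDiffuse (e.g. the added attention layers)
        # keep their fresh initialization
    return init_sd

# usage (hypothetical):
# layout_diffuse_denoiser.load_state_dict(build_init_state_dict(layout_diffuse_denoiser, foundational_sd))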