├── LICENSE
├── README.md
├── git_demo.png
├── model_diffusion.py
├── plot_train.py
├── requirements.txt
├── test.py
├── train.py
├── train_fid_model.py
└── util
    ├── backbone.py
    ├── constraint.py
    ├── data_util.py
    ├── datasets
    │   ├── __init__.py
    │   ├── base.py
    │   ├── crello.py
    │   ├── dataset.py
    │   ├── load_data.py
    │   ├── magazine.py
    │   ├── publaynet.py
    │   └── rico.py
    ├── diffusion_utils.py
    ├── ema.py
    ├── metric.py
    ├── seq_util.py
    ├── util.py
    └── visualization.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jian Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Aligned Layout Generation via Diffusion Model with Aesthetic Constraints 2 | Source code of the layout generation model [LACE](https://arxiv.org/abs/2402.04754). 3 | ![image](git_demo.png) 4 | ## 1. Installation 5 | ### 1.1 Prepare environment 6 | Install the required packages for Python 3.9 or later: 7 | ``` 8 | conda create --name LACE python=3.9 9 | conda activate LACE 10 | python -m pip install -r requirements.txt 11 | ``` 12 | 13 | ### 1.2 Checkpoints 14 | Download the **trained checkpoints** for the diffusion model and the FID model from [Hugging Face](https://huggingface.co/datasets/puar-playground/LACE/tree/main) or through the command line: 15 | ``` 16 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/model.tar.gz 17 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/fid.tar.gz 18 | tar -xvzf model.tar.gz 19 | tar -xvzf fid.tar.gz 20 | ``` 21 | Model hyper-parameters:<br>
22 | For PubLayNet: `--dim_transformer 1024 --nhead 16 --nlayer 4 --feature_dim 2048`<br>
23 | For Rico13 and Rico25: `--dim_transformer 512 --nhead 16 --nlayer 4 --feature_dim 2048`<br>
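These flags map onto the constructor of the `Diffusion` class in `model_diffusion.py`. Below is a minimal loading sketch for the PubLayNet checkpoint, following the calls in `test.py`; the extraction path `./model` is an assumption based on the checkpoint filename used there, and `seq_dim` is the number of element classes plus one padding class plus the four box coordinates (10 for PubLayNet):
```
import torch
from model_diffusion import Diffusion

# PubLayNet hyper-parameters from the settings above.
# seq_dim = 5 element classes + 1 padding class + 4 box coordinates = 10.
model = Diffusion(num_timesteps=1000, nhead=16, dim_transformer=1024,
                  feature_dim=2048, seq_dim=10, num_layers=4,
                  device='cuda:0', ddim_num_steps=100)

# Checkpoint path as used in test.py (assumes model.tar.gz was extracted to ./model).
state_dict = torch.load('./model/publaynet_best.pt', map_location='cpu')
model.load_diffusion_net(state_dict)

# Unconditional DDIM sampling: returns bounding boxes, labels, and padding masks.
bbox, label, mask = model.reverse_ddim(batch_size=4, stochastic=True)
```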
24 | 25 | ### 1.3 Datasets 26 | The datasets are also available at: 27 | ``` 28 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/datasets.tar.gz 29 | tar -xvzf datasets.tar.gz 30 | ``` 31 | Alternatively, you can download each dataset from its source and prepare it as follows: 32 | * [PubLayNet](https://developer.ibm.com/exchanges/data/all/publaynet/): Download `labels.tar.gz` and decompress it into the `./dataset/publaynet-max25/raw` folder.<br>
33 | * [Rico](https://www.kaggle.com/datasets/onurgunes1993/rico-dataset): Download `rico_dataset_v0.1_semantic_annotations.zip` and decompress it into the `./dataset/rico25-max25/raw` folder.<br>
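In the scripts, the datasets are read through `util/datasets/load_data.init_dataset`. A minimal loading sketch, mirroring how `train.py` and `test.py` call it (the keyword values shown are the defaults used in those scripts):
```
from util.datasets.load_data import init_dataset

# Load the PubLayNet training split the same way train.py / test.py do.
# On the first call, the raw annotations are converted into a `processed` folder (see the note below).
train_dataset, train_loader = init_dataset('publaynet', './datasets',
                                           batch_size=256, split='train', shuffle=True)
print('num_classes:', train_dataset.num_classes)
```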
34 | 35 | When a dataset is initialized for the first time, a new folder called `processed` will be created (e.g., at `./dataset/magazine-max25/processed`) containing the formatted dataset for future use. The training splits of the smaller datasets (Rico and Magazine) are duplicated to reach a reasonable epoch size. 36 | 37 | 38 | 39 | ## 2. Testing 40 | Run the python script `test.py` to test. Please run `python test.py -h` for a detailed explanation of the arguments.<br>
41 | For PubLayNet: 42 | ``` 43 | python test.py --dataset publaynet --experiment all --device cuda:0 --dim_transformer 1024 --nhead 16 --batch_size 2048 --beautify 44 | ``` 45 | For Rico: 46 | ``` 47 | python test.py --dataset rico25 --experiment all --device cuda:0 --dim_transformer 512 --nhead 16 --batch_size 2048 --beautify 48 | ``` 49 | 50 | 51 | ## 3. Training 52 | Run the python script `train.py` to train.<br>
53 | The script takes several command-line arguments. Please run `python train.py -h` for a detailed explanation.<br>
54 | Example command for training:
55 | ``` 56 | python train.py --device cuda:1 --dataset rico25 --no-load_pre --lr 1e-6 --n_save_epoch 10 57 | ``` 58 | 59 | ## Reference 60 | ``` 61 | @inproceedings{ 62 | chen2024towards, 63 | title={Towards Aligned Layout Generation via Diffusion Model with Aesthetic Constraints}, 64 | author={Jian Chen and Ruiyi Zhang and Yufan Zhou and Changyou Chen}, 65 | booktitle={The Twelfth International Conference on Learning Representations}, 66 | year={2024}, 67 | url={https://openreview.net/forum?id=kJ0qp9Xdsh} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /git_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/puar-playground/LACE/3df36879a1e80cce58affa4aadeeb768f676c7f1/git_demo.png -------------------------------------------------------------------------------- /model_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from util.diffusion_utils import * 4 | import torch.nn.functional as F 5 | from typing import Dict, List, Optional, Tuple, Union 6 | from util.backbone import TransformerEncoder 7 | from util.visualization import save_image 8 | import matplotlib.pyplot as plt 9 | import matplotlib 10 | # matplotlib.use('TkAgg') 11 | torch.manual_seed(1) 12 | 13 | 14 | class Diffusion(nn.Module): 15 | def __init__(self, num_timesteps=1000, nhead=8, feature_dim=2048, dim_transformer=512, seq_dim=10, num_layers=4, device='cuda', 16 | beta_schedule='cosine', ddim_num_steps=50, condition='None'): 17 | super().__init__() 18 | self.device = device 19 | self.num_timesteps = num_timesteps 20 | betas = make_beta_schedule(schedule=beta_schedule, num_timesteps=self.num_timesteps, start=0.0001, end=0.02) 21 | betas = self.betas = betas.float().to(self.device) 22 | self.betas_sqrt = torch.sqrt(betas) 23 | alphas = 1.0 - betas 24 | self.alphas = alphas 25 | self.one_minus_betas_sqrt = torch.sqrt(alphas) 26 | self.alphas_cumprod = alphas.cumprod(dim=0) 27 | self.alphas_bar_sqrt = torch.sqrt(self.alphas_cumprod) 28 | self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_cumprod) 29 | alphas_cumprod_prev = torch.cat([torch.ones(1).to(self.device), self.alphas_cumprod[:-1]], dim=0) 30 | self.alphas_cumprod_prev = alphas_cumprod_prev 31 | self.posterior_mean_coeff_1 = (betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) 32 | self.posterior_mean_coeff_2 = (torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - self.alphas_cumprod)) 33 | posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) 34 | self.posterior_variance = posterior_variance 35 | self.logvar = betas.log() 36 | 37 | self.condition = condition 38 | self.seq_dim = seq_dim 39 | self.num_class = seq_dim - 4 40 | 41 | self.model = TransformerEncoder(num_layers=num_layers, dim_seq=seq_dim, dim_transformer=dim_transformer, nhead=nhead, 42 | dim_feedforward=feature_dim, diffusion_step=num_timesteps, device=device) 43 | 44 | self.ddim_num_steps = ddim_num_steps 45 | self.make_ddim_schedule(ddim_num_steps) 46 | 47 | def make_ddim_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.): 48 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 49 | num_ddpm_timesteps=self.num_timesteps) 50 | 51 | assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 52 | to_torch = 
lambda x: x.clone().detach().to(torch.float32).to(self.device) 53 | 54 | self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(self.alphas_cumprod))) 55 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - self.alphas_cumprod))) 56 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - self.alphas_cumprod))) 57 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod))) 58 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod - 1))) 59 | 60 | # ddim sampling parameters 61 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod, 62 | ddim_timesteps=self.ddim_timesteps, 63 | eta=ddim_eta) 64 | self.register_buffer('ddim_sigmas', ddim_sigmas) 65 | self.register_buffer('ddim_alphas', ddim_alphas) 66 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 67 | self.register_buffer('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas)) 68 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 69 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 70 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 71 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 72 | 73 | def load_diffusion_net(self, net_state_dict): 74 | # new_states = dict() 75 | # for k in net_state_dict.keys(): 76 | # if 'layer_out' not in k and 'layer_in' not in k: 77 | # new_states[k] = net_state_dict[k] 78 | self.model.load_state_dict(net_state_dict, strict=True) 79 | 80 | def sample_t(self, size=(1,), t_max=None): 81 | """Samples batches of time steps to use.""" 82 | if t_max is None: 83 | t_max = int(self.num_timesteps) - 1 84 | 85 | t = torch.randint(low=0, high=t_max, size=size, device=self.device) 86 | 87 | return t.to(self.device) 88 | 89 | def forward_t(self, l_0_batch, t, real_mask, reparam=False): 90 | 91 | batch_size = l_0_batch.shape[0] 92 | e = torch.randn_like(l_0_batch).to(l_0_batch.device) 93 | 94 | l_t_noise = q_sample(l_0_batch, self.alphas_bar_sqrt, 95 | self.one_minus_alphas_bar_sqrt, t, noise=e) 96 | 97 | # cond c 98 | l_t_input_c = l_0_batch.clone() 99 | l_t_input_c[:, :, self.num_class:] = l_t_noise[:, :, self.num_class:] 100 | 101 | # cond cwh 102 | l_t_input_cwh = l_0_batch.clone() 103 | l_t_input_cwh[:, :, self.num_class:self.num_class+2] = l_t_noise[:, :, self.num_class:self.num_class+2] 104 | 105 | # cond complete 106 | fix_mask = rand_fix(batch_size, real_mask, ratio=0.2) 107 | l_t_input_complete = l_t_noise.clone() 108 | l_t_input_complete[fix_mask] = l_0_batch[fix_mask] 109 | 110 | l_t_input_all = torch.cat([l_t_noise, l_t_input_c, l_t_input_cwh, l_t_input_complete], dim=0) 111 | e_all = torch.cat([e, e, e, e], dim=0) 112 | t_all = torch.cat([t, t, t, t], dim=0) 113 | 114 | eps_theta = self.model(l_t_input_all, timestep=t_all) 115 | 116 | if reparam: 117 | sqrt_one_minus_alpha_bar_t = extract(self.one_minus_alphas_bar_sqrt, t_all, l_t_input_all) 118 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt() 119 | l_0_generate_reparam = 1 / sqrt_alpha_bar_t * (l_t_input_all - eps_theta * sqrt_one_minus_alpha_bar_t).to(self.device) 120 | 121 | return eps_theta, e_all, l_0_generate_reparam 122 | else: 123 | return eps_theta, e_all, None 124 | 125 | def reverse(self, batch_size, only_last_sample=True, stochastic=True): 126 | 127 | self.model.eval() 128 | layout_t_0 = p_sample_loop(self.model, batch_size, 129 | 
self.num_timesteps, self.alphas, 130 | self.one_minus_alphas_bar_sqrt, 131 | only_last_sample=only_last_sample, stochastic=stochastic) 132 | 133 | bbox, label, mask = self.finalize(layout_t_0) 134 | 135 | return bbox, label, mask 136 | 137 | def reverse_ddim(self, batch_size=4, stochastic=True, save_inter=False, max_len=25): 138 | 139 | self.model.eval() 140 | layout_t_0, intermediates = ddim_sample_loop(self.model, batch_size, self.ddim_timesteps, self.ddim_alphas, 141 | self.ddim_alphas_prev, self.ddim_sigmas, stochastic=stochastic, 142 | seq_len=max_len, seq_dim=self.seq_dim) 143 | 144 | bbox, label, mask = self.finalize(layout_t_0, self.num_class) 145 | 146 | if not save_inter: 147 | return bbox, label, mask 148 | 149 | else: 150 | for i, layout_t in enumerate(intermediates['y_inter']): 151 | bbox, label, mask = self.finalize(layout_t, self.num_class) 152 | a = save_image(bbox, label, mask, draw_label=True) 153 | plt.figure(figsize=[15, 20]) 154 | plt.imshow(a) 155 | plt.tight_layout() 156 | plt.savefig(f'./plot/inter_{i}.png') 157 | plt.close() 158 | 159 | return bbox, label, mask 160 | 161 | 162 | @staticmethod 163 | def finalize(layout, num_class): 164 | layout[:, :, num_class:] = torch.clamp(layout[:, :, num_class:], min=-1, max=1) / 2 + 0.5 165 | bbox = layout[:, :, num_class:] 166 | label = torch.argmax(layout[:, :, :num_class], dim=2) 167 | mask = (label != num_class-1).clone().detach() 168 | 169 | return bbox, label, mask 170 | 171 | def conditional_reverse_ddim(self, real_layout, cond='c', ratio=0.2, stochastic=True): 172 | 173 | self.model.eval() 174 | layout_t_0, intermediates = \ 175 | ddim_cond_sample_loop(self.model, real_layout, self.ddim_timesteps, self.ddim_alphas, 176 | self.ddim_alphas_prev, self.ddim_sigmas, stochastic=stochastic, cond=cond, 177 | ratio=ratio) 178 | 179 | bbox, label, mask = self.finalize(layout_t_0, self.num_class) 180 | 181 | return bbox, label, mask 182 | 183 | def refinement_reverse_ddim(self, noisy_layout): 184 | self.model.eval() 185 | layout_t_0, intermediates = \ 186 | ddim_refine_sample_loop(self.model, noisy_layout, self.ddim_timesteps, self.ddim_alphas, 187 | self.ddim_alphas_prev, self.ddim_sigmas) 188 | 189 | bbox, label, mask = self.finalize(layout_t_0, self.num_class) 190 | 191 | return bbox, label, mask 192 | 193 | 194 | 195 | if __name__ == "__main__": 196 | 197 | model = Diffusion(num_timesteps=1000, nhead=8, dim_transformer=1024, 198 | feature_dim=2048, seq_dim=10, num_layers=4, 199 | device='cpu', ddim_num_steps=200, embed_type='pos') 200 | 201 | print(pow(model.one_minus_alphas_bar_sqrt[201], 2)) 202 | 203 | 204 | print(sum(model.ddim_timesteps <= 201)) 205 | timesteps = model.ddim_timesteps 206 | total_steps = sum(model.ddim_timesteps <= 201) 207 | time_range = np.flip(timesteps[:total_steps]) 208 | print(time_range) 209 | -------------------------------------------------------------------------------- /plot_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | matplotlib.use('TkAgg') 10 | from util.datasets.load_data import init_dataset 11 | from util.visualization import save_image 12 | from util.seq_util import sparse_to_dense, pad_until 13 | import argparse 14 | 15 | if __name__ == "__main__": 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--nepoch", default=500, help="number of 
training epochs", type=int) 19 | parser.add_argument("--batch_size", default=4, help="batch_size", type=int) 20 | parser.add_argument("--device", default='cuda', help="which GPU to use", type=str) 21 | parser.add_argument("--num_workers", default=1, help="num_workers", type=int) 22 | parser.add_argument("--num_plot", default=50, help="num_workers", type=int) 23 | parser.add_argument("--dataset", default='magazine', 24 | help="choose from [publaynet, rico13, rico25]", type=str) 25 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str) 26 | args = parser.parse_args() 27 | 28 | # prepare data 29 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size, 30 | split='train', shuffle=False) 31 | if not os.path.exists(f'./plot/{args.dataset}_train/'): 32 | os.mkdir(f'./plot/{args.dataset}_train/') 33 | 34 | with tqdm(enumerate(train_loader, 1), total=args.num_plot, desc=f'load data', 35 | ncols=120) as pbar: 36 | 37 | for i, data in pbar: 38 | 39 | if i > args.num_plot: 40 | break 41 | 42 | bbox, label, _, mask = sparse_to_dense(data) 43 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25) 44 | 45 | a = save_image(bbox, label, mask, draw_label=True, dataset=f'{args.dataset}') 46 | plt.figure(figsize=[15, 20]) 47 | plt.imshow(a) 48 | plt.tight_layout() 49 | plt.savefig(f'./plot/{args.dataset}_train/{i}.png') 50 | plt.close() 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | fsspec==2023.4.0 3 | imageio==2.28.0 4 | matplotlib==3.7.1 5 | numpy==1.24.3 6 | omegaconf==2.3.0 7 | Pillow==10.0.0 8 | prdc==0.2 9 | pycocotools==2.0.6 10 | pytorch_fid==0.3.0 11 | scipy==1.10.1 12 | seaborn==0.12.2 13 | torch==2.0.0 14 | torch_geometric==2.3.1 15 | torchvision==0.15.1 16 | tqdm==4.65.0 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fid.model import load_fidnet_v3 3 | from util.metric import compute_generative_model_scores, compute_maximum_iou, compute_overlap, compute_alignment 4 | import pickle as pk 5 | from tqdm import tqdm 6 | from util.datasets.load_data import init_dataset 7 | from util.visualization import save_image 8 | from util.constraint import * 9 | import matplotlib.pyplot as plt 10 | import matplotlib 11 | # matplotlib.use('TkAgg') 12 | from util.seq_util import sparse_to_dense, loader_to_list, pad_until 13 | import argparse 14 | from model_diffusion import Diffusion 15 | 16 | 17 | def test_fid_feat(dataset_name, device='cuda', batch_size=20): 18 | 19 | if os.path.exists(f'./fid/feature/fid_feat_test_{dataset_name}.pk'): 20 | feats_test = pk.load(open(f'./fid/feature/fid_feat_test_{dataset_name}.pk', 'rb')) 21 | return feats_test 22 | 23 | # prepare dataset 24 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size, 25 | split='test', shuffle=False, transform=None) 26 | 27 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device) 28 | feats_test = [] 29 | 30 | with tqdm(enumerate(main_dataloader), total=len(main_dataloader), desc=f'Get feature for FID', 31 | ncols=200) as pbar: 32 | 33 | for i, data in pbar: 34 | 35 | bbox, label, _, mask = sparse_to_dense(data) 36 | label, bbox, mask = label.to(device), bbox.to(device), 
mask.to(device) 37 | padding_mask = ~mask 38 | 39 | with torch.set_grad_enabled(False): 40 | feat = fid_model.extract_features(bbox, label, padding_mask) 41 | feats_test.append(feat.detach().cpu()) 42 | 43 | pk.dump(feats_test, open(f'./fid/feature/fid_feat_test_{dataset_name}.pk', 'wb')) 44 | 45 | return feats_test 46 | 47 | 48 | def test_layout_uncond(model, batch_size=128, dataset_name='publaynet', test_plot=False, 49 | save_dir='./plot/test', beautify=False): 50 | 51 | model.eval() 52 | device = model.device 53 | n_batch_dict = {'publaynet': int(44 * 256 / batch_size), 'rico13': int(17 * 256 / batch_size), 54 | 'rico25': int(17 * 256 / batch_size), 'magazine': int(512 / batch_size), 55 | 'crello': int(2560 / batch_size)} 56 | n_batch = n_batch_dict[dataset_name] 57 | 58 | # prepare dataset 59 | main_dataset, _ = init_dataset(dataset_name, './datasets', batch_size=batch_size, split='test') 60 | 61 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device) 62 | feats_test = test_fid_feat(dataset_name, device=device, batch_size=20) 63 | feats_generate = [] 64 | 65 | align_sum = 0 66 | overlap_sum = 0 67 | with torch.no_grad(): 68 | for i in tqdm(range(n_batch), desc='uncond testing', ncols=200, total=n_batch): 69 | bbox_generated, label, mask = model.reverse_ddim(batch_size=batch_size, stochastic=True, save_inter=False) 70 | if beautify and dataset_name=='publaynet': 71 | bbox_generated, mask = post_process(bbox_generated, mask, w_o=1) 72 | elif beautify and (dataset_name=='rico25' or dataset_name=='rico13'): 73 | bbox_generated, mask = post_process(bbox_generated, mask, w_o=0) 74 | 75 | padding_mask = ~mask 76 | 77 | label[mask == False] = 0 78 | 79 | if torch.isnan(bbox_generated[0, 0, 0]): 80 | print('not a number error') 81 | return None 82 | 83 | # accumulate align and overlap 84 | align_norm = compute_alignment(bbox_generated, mask) 85 | align_sum += torch.mean(align_norm) 86 | overlap_score = compute_overlap(bbox_generated, mask) 87 | overlap_sum += torch.mean(overlap_score) 88 | 89 | 90 | with torch.set_grad_enabled(False): 91 | feat = fid_model.extract_features(bbox_generated, label, padding_mask) 92 | feats_generate.append(feat.cpu()) 93 | 94 | if test_plot and i <= 10: 95 | img = save_image(bbox_generated[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name) 96 | plt.figure(figsize=[12, 12]) 97 | plt.imshow(img) 98 | plt.tight_layout() 99 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_{i}.png')) 100 | # plt.close() 101 | 102 | result = compute_generative_model_scores(feats_test, feats_generate) 103 | fid = result['fid'] 104 | 105 | align_final = 100 * align_sum / n_batch 106 | overlap_final = 100 * overlap_sum / n_batch 107 | 108 | print(f'uncond, align: {align_final}, fid: {fid}, overlap: {overlap_final}') 109 | 110 | return align_final, fid, overlap_final 111 | 112 | 113 | def test_layout_cond(model, batch_size=256, cond='c', dataset_name='publaynet', seq_dim=10, 114 | test_plot=False, save_dir='./plot/test', beautify=False): 115 | 116 | assert cond in {'c', 'cwh', 'complete'} 117 | model.eval() 118 | device = model.device 119 | 120 | # prepare dataset 121 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size, 122 | split='test', shuffle=False, transform=None) 123 | 124 | layouts_main = loader_to_list(main_dataloader) 125 | layout_generated = [] 126 | 127 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device) 128 | feats_test = test_fid_feat(dataset_name, device=device, 
batch_size=20) 129 | feats_generate = [] 130 | 131 | align_sum = 0 132 | overlap_sum = 0 133 | with torch.no_grad(): 134 | 135 | with tqdm(enumerate(main_dataloader), total=len(main_dataloader), desc=f'cond: {cond} generation', 136 | ncols=200) as pbar: 137 | 138 | for i, data in pbar: 139 | 140 | bbox, label, _, mask = sparse_to_dense(data) 141 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25) 142 | 143 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device) 144 | 145 | # shift to center 146 | bbox_in = 2 * (bbox - 0.5).to(device) 147 | 148 | # set mask to label 5 149 | label[mask == False] = seq_dim - 5 150 | 151 | label_oh = torch.nn.functional.one_hot(label, num_classes=seq_dim - 4).to(device) 152 | real_layout = torch.cat((label_oh, bbox_in), dim=2).to(device) 153 | 154 | bbox_generated, label_generated, mask_generated = model.conditional_reverse_ddim(real_layout, cond=cond) 155 | 156 | if beautify and dataset_name == 'publaynet': 157 | bbox_generated, mask_generated = post_process(bbox_generated, mask_generated, w_o=1) 158 | elif beautify and (dataset_name == 'rico25' or dataset_name == 'rico13'): 159 | bbox_generated, mask_generated = post_process(bbox_generated, mask_generated, w_o=0) 160 | 161 | padding_mask = ~mask_generated 162 | 163 | # test for errors 164 | if torch.isnan(bbox[0, 0, 0]): 165 | print('not a number error') 166 | return None 167 | 168 | # accumulate align and overlap 169 | align_norm = compute_alignment(bbox_generated, mask) 170 | align_sum += torch.mean(align_norm) 171 | overlap_score = compute_overlap(bbox_generated, mask) 172 | overlap_sum += torch.mean(overlap_score) 173 | 174 | # record for max_iou 175 | label_generated[label_generated == seq_dim - 5] = 0 176 | for j in range(bbox.shape[0]): 177 | mask_single = mask_generated[j, :] 178 | bbox_single = bbox_generated[j, mask_single, :] 179 | label_single = label_generated[j, mask_single] 180 | 181 | layout_generated.append((bbox_single.to('cpu').numpy(), label_single.to('cpu').numpy())) 182 | 183 | # record for FID 184 | with torch.set_grad_enabled(False): 185 | feat = fid_model.extract_features(bbox_generated, label_generated, padding_mask) 186 | feats_generate.append(feat.cpu()) 187 | 188 | if test_plot and i <= 10: 189 | img = save_image(bbox_generated[:9], label_generated[:9], mask_generated[:9], 190 | draw_label=False, dataset=dataset_name) 191 | plt.figure(figsize=[12, 12]) 192 | plt.imshow(img) 193 | plt.tight_layout() 194 | plt.savefig(f'./plot/test/cond_{cond}_{dataset_name}_{i}.png') 195 | plt.close() 196 | 197 | img = save_image(bbox[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name) 198 | plt.figure(figsize=[12, 12]) 199 | plt.imshow(img) 200 | plt.tight_layout() 201 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_real.png')) 202 | plt.close() 203 | 204 | maxiou = compute_maximum_iou(layouts_main, layout_generated) 205 | result = compute_generative_model_scores(feats_test, feats_generate) 206 | fid = result['fid'] 207 | 208 | align_final = 100 * align_sum / len(main_dataloader) 209 | overlap_final = 100 * overlap_sum / len(main_dataloader) 210 | 211 | print(f'cond {cond}, align: {align_final}, fid: {fid}, maxiou: {maxiou}, overlap: {overlap_final}') 212 | 213 | return align_final, fid, maxiou, overlap_final 214 | 215 | 216 | def test_layout_refine(model, batch_size=256, dataset_name='publaynet', seq_dim=10, 217 | test_plot=False, save_dir='./plot/test', beautify=False): 218 | 219 | model.eval() 220 | device = model.device 221 | n_batch_dict 
= {'publaynet': 44, 'rico13': 17, 'rico25': 17, 'magazine': 2, 'crello': 10} 222 | n_batch = n_batch_dict[dataset_name] 223 | 224 | # prepare dataset 225 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size, 226 | split='test', shuffle=False, transform=None) 227 | 228 | layouts_main = loader_to_list(main_dataloader) 229 | layout_generated = [] 230 | 231 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device) 232 | feats_test = test_fid_feat(dataset_name, device=device, batch_size=20) 233 | feats_generate = [] 234 | 235 | align_sum = 0 236 | overlap_sum = 0 237 | with torch.no_grad(): 238 | 239 | with tqdm(enumerate(main_dataloader), total=min(n_batch, len(main_dataloader)), desc=f'refine generation', 240 | ncols=200) as pbar: 241 | 242 | for i, data in pbar: 243 | if i == min(n_batch, len(main_dataloader)): 244 | break 245 | 246 | bbox, label, _, mask = sparse_to_dense(data) 247 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25) 248 | 249 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device) 250 | 251 | # shift to center 252 | bbox_noisy = torch.clamp(bbox + 0.1 * torch.randn_like(bbox), min=0, max=1) 253 | bbox_in_noisy = 2 * (bbox_noisy - 0.5).to(device) 254 | # 255 | # set mask to label 5 256 | label[mask == False] = seq_dim - 5 257 | 258 | label_oh = torch.nn.functional.one_hot(label, num_classes=seq_dim - 4).to(device) 259 | noisy_layout = torch.cat((label_oh, bbox_in_noisy), dim=2).to(device) 260 | 261 | bbox_refined, _, _ = model.refinement_reverse_ddim(noisy_layout) 262 | 263 | if beautify and dataset_name == 'publaynet': 264 | bbox_refined, mask = post_process(bbox_refined, mask, w_o=1) 265 | elif beautify and (dataset_name == 'rico25' or dataset_name == 'rico13'): 266 | bbox_refined, mask = post_process(bbox_refined, mask, w_o=0) 267 | padding_mask = ~mask 268 | 269 | # accumulate align and overlap 270 | align_norm = compute_alignment(bbox_refined, mask) 271 | align_sum += torch.mean(align_norm) 272 | overlap_score = compute_overlap(bbox_refined, mask) 273 | overlap_sum += torch.mean(overlap_score) 274 | 275 | # record for max_iou 276 | label[label == seq_dim - 5] = 0 277 | 278 | for j in range(bbox_refined.shape[0]): 279 | mask_single = mask[j, :] 280 | bbox_single = bbox_refined[j, mask_single, :] 281 | label_single = label[j, mask_single] 282 | 283 | layout_generated.append((bbox_single.to('cpu').numpy(), label_single.to('cpu').numpy())) 284 | 285 | # record for FID 286 | with torch.set_grad_enabled(False): 287 | feat = fid_model.extract_features(bbox_refined, label, padding_mask) 288 | feats_generate.append(feat.cpu()) 289 | 290 | 291 | if test_plot and i <= 10: 292 | img = save_image(bbox_refined[:9], label[:9], mask[:9], 293 | draw_label=False, dataset=dataset_name) 294 | plt.figure(figsize=[12, 12]) 295 | plt.imshow(img) 296 | plt.tight_layout() 297 | plt.savefig(f'./plot/test/refine_{dataset_name}_{i}.png') 298 | plt.close() 299 | 300 | img = save_image(bbox[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name) 301 | plt.figure(figsize=[12, 12]) 302 | plt.imshow(img) 303 | plt.tight_layout() 304 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_real.png')) 305 | plt.close() 306 | 307 | maxiou = compute_maximum_iou(layouts_main, layout_generated) 308 | result = compute_generative_model_scores(feats_test, feats_generate) 309 | fid = result['fid'] 310 | 311 | align_final = 100 * align_sum / len(main_dataloader) 312 | overlap_final = 100 * overlap_sum / 
len(main_dataloader) 313 | 314 | print(f'refine, align: {align_final}, fid: {fid}, maxiou: {maxiou}, overlap: {overlap_final}') 315 | return align_final, fid, maxiou, overlap_final 316 | 317 | 318 | def test_all(model, dataset_name='publaynet', seq_dim=10, test_plot=False, save_dir='./plot/test', batch_size=256, 319 | beautify=False): 320 | 321 | align_uncond, fid_uncond, overlap_uncond = test_layout_uncond(model, batch_size=batch_size, dataset_name=dataset_name, 322 | test_plot=test_plot, save_dir=save_dir, beautify=beautify) 323 | align_c, fid_c, maxiou_c, overlap_c = test_layout_cond(model, batch_size=batch_size, cond='c', 324 | dataset_name=dataset_name, seq_dim=seq_dim, 325 | test_plot=test_plot, save_dir=save_dir, beautify=beautify) 326 | align_cwh, fid_cwh, maxiou_cwh, overlap_cwh = test_layout_cond(model, batch_size=batch_size, cond='cwh', 327 | dataset_name=dataset_name, seq_dim=seq_dim, 328 | test_plot=test_plot, save_dir=save_dir, beautify=beautify) 329 | align_complete, fid_complete, maxiou_complete, overlap_complete = test_layout_cond(model, batch_size=batch_size, 330 | cond='complete', dataset_name=dataset_name, 331 | seq_dim=seq_dim, test_plot=test_plot, 332 | save_dir=save_dir, beautify=beautify) 333 | align_r, fid_r, maxiou_r, overlap_r = test_layout_refine(model, batch_size=batch_size, 334 | dataset_name=dataset_name, seq_dim=seq_dim, 335 | test_plot=test_plot, save_dir=save_dir, beautify=beautify) 336 | 337 | # fid_total = fid_uncond + fid_c + fid_cwh + fid_complete 338 | # print(f'total fid: {fid_total}') 339 | # return fid_total 340 | 341 | 342 | if __name__ == "__main__": 343 | 344 | parser = argparse.ArgumentParser() 345 | parser.add_argument("--batch_size", default=256, help="batch_size", type=int) 346 | parser.add_argument("--device", default='cpu', help="which GPU to use", type=str) 347 | parser.add_argument("--dataset", default='publaynet', 348 | help="choose from [publaynet, rico13, rico25]", type=str) 349 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str) 350 | parser.add_argument("--num_workers", default=4, help="num_workers", type=int) 351 | parser.add_argument("--feature_dim", default=2048, help="feature_dim", type=int) 352 | parser.add_argument("--dim_transformer", default=1024, help="dim_transformer", type=int) 353 | parser.add_argument("--nhead", default=16, help="nhead attention", type=int) 354 | parser.add_argument("--nlayer", default=4, help="nlayer", type=int) 355 | parser.add_argument("--experiment", default='c', help="experiment setting [uncond, c, cwh, complete, all]", type=str) 356 | parser.add_argument('--plot', default=False, action=argparse.BooleanOptionalAction) 357 | parser.add_argument('--beautify', default=False, action=argparse.BooleanOptionalAction) 358 | parser.add_argument("--plot_save_dir", default='./plot/test', help="dir to save generated plot of layouts", type=str) 359 | args = parser.parse_args() 360 | 361 | # prepare data 362 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size, 363 | split='train', shuffle=True) 364 | num_class = train_dataset.num_classes + 1 365 | 366 | # set up model 367 | model_ddpm = Diffusion(num_timesteps=1000, nhead=args.nhead, dim_transformer=args.dim_transformer, 368 | feature_dim=args.feature_dim, seq_dim=num_class + 4, num_layers=args.nlayer, 369 | device=args.device, ddim_num_steps=100) 370 | 371 | state_dict = torch.load(f'./model/{args.dataset}_best.pt', map_location='cpu') 372 | 
model_ddpm.load_diffusion_net(state_dict) 373 | 374 | if args.experiment == 'uncond': 375 | test_layout_uncond(model_ddpm, batch_size=args.batch_size, 376 | dataset_name=args.dataset, test_plot=args.plot, 377 | save_dir=args.plot_save_dir, beautify=args.beautify) 378 | elif args.experiment in ['c', 'cwh', 'complete']: 379 | test_layout_cond(model_ddpm, batch_size=args.batch_size, cond=args.experiment, 380 | dataset_name=args.dataset, seq_dim=num_class + 4, 381 | test_plot=args.plot, save_dir=args.plot_save_dir, beautify=args.beautify) 382 | elif args.experiment == 'refine': 383 | test_layout_refine(model_ddpm, batch_size=args.batch_size, 384 | dataset_name=args.dataset, seq_dim=num_class + 4, 385 | test_plot=args.plot, save_dir=args.plot_save_dir, beautify=args.beautify) 386 | elif args.experiment == 'all': 387 | test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, test_plot=args.plot, 388 | save_dir=args.plot_save_dir, batch_size=args.batch_size, beautify=args.beautify) 389 | else: 390 | raise Exception('experiment setting undefined') 391 | 392 | 393 | 394 | 395 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from util.datasets.load_data import init_dataset 4 | from util.visualization import save_image 5 | from util.seq_util import sparse_to_dense, pad_until 6 | from model_diffusion import Diffusion 7 | from util.ema import EMA 8 | import argparse 9 | import pickle as pk 10 | import torch.optim as optim 11 | from util.constraint import * 12 | import math 13 | import os 14 | from test import test_all 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--nepoch", default=None, help="number of training epochs", type=int) 21 | parser.add_argument("--start_epoch", default=0, help="start epoch", type=int) 22 | parser.add_argument("--batch_size", default=256, help="batch_size", type=int) 23 | parser.add_argument("--lr", default=1e-5, help="learning rate", type=float) 24 | parser.add_argument("--sample_t_max", default=999, help="maximum t in training", type=int) 25 | parser.add_argument("--dataset", default='publaynet', 26 | help="choose from [publaynet, rico13, rico25, magazine, crello]", type=str) 27 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str) 28 | parser.add_argument("--num_workers", default=4, help="num_workers", type=int) 29 | parser.add_argument("--n_save_epoch", default=50, help="number of epochs to do test and save model", type=int) 30 | parser.add_argument("--feature_dim", default=2048, help="feature_dim", type=int) 31 | parser.add_argument("--dim_transformer", default=1024, help="dim_transformer", type=int) 32 | parser.add_argument("--embed_type", default='pos', help="embed type for transformer, pos or time", type=str) 33 | parser.add_argument("--nhead", default=16, help="nhead attention", type=int) 34 | parser.add_argument("--nlayer", default=4, help="nlayer", type=int) 35 | parser.add_argument("--align_weight", default=1, help="the weight of alignment constraint", type=float) 36 | parser.add_argument("--align_type", default='local', help="local or global alignment constraint", type=str) 37 | parser.add_argument("--overlap_weight", default=1, help="the weight of overlap constraint", type=float) 38 | parser.add_argument('--load_pre', default=False, action=argparse.BooleanOptionalAction) 39 | 
parser.add_argument('--beautify', default=False, action=argparse.BooleanOptionalAction) 40 | parser.add_argument('--enable_test', default=True, action=argparse.BooleanOptionalAction) 41 | parser.add_argument("--gpu_devices", default=[0, 2, 3], type=int, nargs='+', help="") 42 | parser.add_argument("--device", default=None, help="which cuda to use", type=str) 43 | args = parser.parse_args() 44 | 45 | if args.device is None: 46 | gpu_devices = ','.join([str(id) for id in args.gpu_devices]) 47 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices 48 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 49 | else: 50 | device = args.device 51 | 52 | print(f'load_pre: {args.load_pre}, enable_test: {args.enable_test}, embed: {args.embed_type}') 53 | print(f'dim_transformer: {args.dim_transformer}, n_layers: {args.nlayer}, nhead: {args.nhead}') 54 | print(f'align_type: {args.align_type}, align_weight: {args.align_weight}, overlap_weight: {args.overlap_weight}') 55 | print(f'device: {args.device}') 56 | 57 | # prepare data 58 | if args.embed_type == 'pos': 59 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size, 60 | split='train', shuffle=True, transform=None) 61 | else: 62 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size, 63 | split='train', shuffle=True) 64 | 65 | num_class = train_dataset.num_classes + 1 66 | 67 | # set up model 68 | model_ddpm = Diffusion(num_timesteps=1000, nhead=args.nhead, dim_transformer=args.dim_transformer, 69 | feature_dim=args.feature_dim, seq_dim=num_class + 4, num_layers=args.nlayer, 70 | device=device, ddim_num_steps=200) 71 | 72 | if args.load_pre: 73 | # state_dict = torch.load(f'./model/{args.embed_type}_{args.dataset}_1024_recent.pt', map_location='cpu') 74 | state_dict = torch.load(f'./model/publaynet_best.pt', map_location='cpu') 75 | model_ddpm.load_diffusion_net(state_dict) 76 | 77 | if args.device is None: 78 | print('using DataParallel') 79 | model_ddpm.model = nn.DataParallel(model_ddpm.model).to(device) 80 | else: 81 | print('using single gpu') 82 | model_ddpm.to(device) 83 | 84 | if args.load_pre: 85 | fid_best = test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, batch_size=args.batch_size, 86 | beautify=args.beautify) 87 | # fid_best = 1e10 88 | else: 89 | fid_best = 1e10 90 | 91 | # optimizer 92 | optimizer = optim.Adam(model_ddpm.model.parameters(), lr=args.lr, weight_decay=0.0, betas=(0.9, 0.999), amsgrad=False, eps=1e-08) 93 | mse_loss = nn.MSELoss() 94 | 95 | ema_helper = EMA(mu=0.9999) 96 | ema_helper.register(model_ddpm.model) 97 | 98 | 99 | for epoch in range(args.start_epoch, args.nepoch): 100 | model_ddpm.model.train() 101 | 102 | if (epoch) % args.n_save_epoch == 0 and epoch != 0: 103 | 104 | # model_path = f'./model/{args.embed_type}_{args.dataset}_1024_recent.pt' 105 | # states = model_ddpm.model.module.state_dict() 106 | # torch.save(states, model_path) 107 | 108 | if args.enable_test: 109 | fid_total = test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, batch_size=args.batch_size, beautify=False) 110 | # print(f'previous best fid: {fid_best}') 111 | # if fid_total < fid_best: 112 | # # model_path = f'./model/{args.embed_type}_{args.dataset}_1024_lowest.pt' 113 | # # torch.save(states, model_path) 114 | # fid_best = fid_total 115 | # print('New lowest fid model, saved') 116 | 117 | with tqdm(enumerate(train_loader), total=len(train_loader), desc=f'train diffusion epoch {epoch}', ncols=200) as pbar: 
118 | 119 | for i, data in pbar: 120 | bbox, label, _, mask = sparse_to_dense(data) 121 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25) 122 | 123 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device) 124 | 125 | # shift to center 126 | bbox_in = 2 * (bbox - 0.5).to(args.device) 127 | 128 | # set mask to label 5 129 | label[mask==False] = num_class - 1 130 | 131 | label_oh = torch.nn.functional.one_hot(label, num_classes=num_class).to(args.device) 132 | 133 | # concat label with bbox and get a 10 dim 134 | layout_input = torch.cat((label_oh, bbox_in), dim=2).to(args.device) 135 | 136 | t = model_ddpm.sample_t([bbox.shape[0]], t_max=args.sample_t_max) 137 | t_all = torch.cat([t, t, t, t], dim=0) 138 | 139 | eps_theta, e, b_0_reparam = model_ddpm.forward_t(layout_input, t=t, real_mask=mask, reparam=True) 140 | 141 | # compute b_0 reparameterization 142 | bbox_rep = torch.clamp(b_0_reparam[:, :, num_class:], min=-1, max=1) / 2 + 0.5 143 | mask_4 = torch.cat([mask, mask, mask, mask], dim=0) 144 | bbox_4 = torch.cat([bbox, bbox, bbox, bbox], dim=0) 145 | 146 | # compute alignment loss 147 | if args.align_type == 'global': 148 | # global alignment 149 | align_loss = mean_alignment_error(bbox_rep, bbox_4, mask_4) 150 | else: 151 | # local alignment 152 | _, align_loss = layout_alignment(bbox_rep, mask_4, xy_only=False) 153 | align_loss = 20 * align_loss 154 | 155 | # compute piou and pdist 156 | piou = PIoU_xywh(bbox_rep, mask=mask_4.to(torch.float32), xy_only=False) 157 | pdist = Pdist(bbox_rep) 158 | 159 | # compute piou loss with temporal weight 160 | overlap_loss = torch.mean(piou, dim=[1, 2]) + torch.mean(piou.ne(0) * torch.exp(-pdist), dim=[1, 2]) 161 | # overlap_loss = torch.mean(piou, dim=[1, 2]) 162 | 163 | # reconstruction loss 164 | layout_input_all = torch.cat([layout_input, layout_input, layout_input, layout_input], dim=0) 165 | reconstruct_loss = mse_loss(layout_input_all[:, :, num_class:], b_0_reparam[:, :, num_class:]) 166 | # _, giou = GIoU_xywh(b_0_reparam[:, :, num_class:], layout_input_all[:, :, num_class:]) 167 | # reconstruct_loss = (1 - 1 * torch.mean(giou)) 168 | 169 | # combine constraints with temporal weight 170 | weight = constraint_temporal_weight(t_all, schedule='const') 171 | constraint_loss = torch.mean((args.align_weight * align_loss + args.overlap_weight * overlap_loss) 172 | * weight) 173 | 174 | # compute diffusion loss 175 | diffusion_loss = mse_loss(e, eps_theta) 176 | 177 | # total loss 178 | loss = diffusion_loss + constraint_loss + reconstruct_loss 179 | 180 | pbar.set_postfix({'diffusion': diffusion_loss.item(), 'align': torch.mean(align_loss).item(), 181 | 'overlap': torch.mean(overlap_loss).item(), 'reconstruct': reconstruct_loss.item()}) 182 | 183 | # optimize 184 | optimizer.zero_grad() 185 | loss.backward() 186 | torch.nn.utils.clip_grad_norm_(model_ddpm.model.parameters(), 1.0) 187 | optimizer.step() 188 | ema_helper.update(model_ddpm.model) 189 | 190 | -------------------------------------------------------------------------------- /train_fid_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torchvision.transforms as T 10 | from torch_geometric.utils import to_dense_batch 11 | from util.data_util import AddNoiseToBBox, LexicographicOrder 12 | from util.fid.model import FIDNetV3 
13 | from util.datasets.load_data import init_dataset 14 | 15 | 16 | def save_checkpoint(state, is_best, out_dir): 17 | out_path = Path(out_dir) / "checkpoint.pth.tar" 18 | torch.save(state, out_path) 19 | 20 | if is_best: 21 | best_path = Path(out_dir) / "model_best.pth.tar" 22 | shutil.copyfile(out_path, best_path) 23 | 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--batch_size", type=int, default=64, help="input batch size") 28 | parser.add_argument("--dataset", default='crello', help="choose from [magazine, rico13, rico25, publaynet, crello]", 29 | type=str) 30 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str) 31 | parser.add_argument("--device", default='cpu', help="which GPU to use", type=str) 32 | parser.add_argument("--out_dir", type=str, default="./fid/FIDNetV3/") 33 | parser.add_argument( 34 | "--iteration", 35 | type=int, 36 | default=int(2e5), 37 | help="number of iterations to train for", 38 | ) 39 | parser.add_argument( 40 | "--lr", type=float, default=3e-4, help="learning rate, default=3e-4" 41 | ) 42 | parser.add_argument("--seed", type=int, help="manual seed") 43 | args = parser.parse_args() 44 | print(args) 45 | 46 | prefix = "FIDNetV3" 47 | out_dir = Path(os.path.join(args.out_dir, args.dataset + '-max25')) 48 | out_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | transform = T.Compose( 51 | [ 52 | T.RandomApply([AddNoiseToBBox()], 0.5), 53 | LexicographicOrder(), 54 | ] 55 | ) 56 | 57 | train_dataset, train_dataloader = init_dataset(args.dataset, args.data_dir, batch_size=64, split='train', 58 | transform=transform, shuffle=True) 59 | val_dataset, val_dataloader = init_dataset(args.dataset, args.data_dir, batch_size=64, split='test', 60 | transform=transform, shuffle=False) 61 | print('num_classes', train_dataset.num_classes) 62 | 63 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 64 | model = FIDNetV3(num_label=train_dataset.num_classes, max_bbox=25).to(device) 65 | 66 | # setup optimizer 67 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 68 | 69 | criterion_bce = nn.BCEWithLogitsLoss(reduction="none") 70 | criterion_label = nn.CrossEntropyLoss(reduction="none") 71 | criterion_bbox = nn.MSELoss(reduction="none") 72 | 73 | def proc_batch(batch): 74 | batch = batch.to(device) 75 | bbox, _ = to_dense_batch(batch.x, batch.batch) 76 | label, mask = to_dense_batch(batch.y, batch.batch) 77 | padding_mask = ~mask 78 | 79 | is_real = batch.attr["NoiseAdded"].float() 80 | return bbox, label, padding_mask, mask, is_real 81 | 82 | iteration = 0 83 | best_loss = 1e8 84 | max_epoch = args.iteration * args.batch_size / len(train_dataset) 85 | max_epoch = torch.ceil(torch.tensor(max_epoch)).int().item() 86 | for epoch in range(max_epoch): 87 | model.train() 88 | train_loss = { 89 | "Loss_BCE": 0, 90 | "Loss_Label": 0, 91 | "Loss_BBox": 0, 92 | } 93 | 94 | for i, batch in enumerate(train_dataloader): 95 | 96 | bbox, label, padding_mask, mask, is_real = proc_batch(batch) 97 | model.zero_grad() 98 | 99 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask) 100 | 101 | loss_bce = criterion_bce(logit, is_real) 102 | loss_label = criterion_label(logit_cls[mask], label[mask]) 103 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1) 104 | loss = loss_bce.mean() + loss_label.mean() + 10 * loss_bbox.mean() 105 | loss.backward() 106 | 107 | optimizer.step() 108 | 109 | loss_bce_mean = loss_bce.mean().item() 110 | train_loss["Loss_BCE"] += 
loss_bce.sum().item() 111 | loss_label_mean = loss_label.mean().item() 112 | train_loss["Loss_Label"] += loss_label.sum().item() 113 | loss_bbox_mean = loss_bbox.mean().item() 114 | train_loss["Loss_BBox"] += loss_bbox.sum().item() 115 | 116 | if i % 100 == 0: 117 | log_prefix = f"[{epoch}/{max_epoch}][{i}/{len(train_dataset) // args.batch_size}]" 118 | log = f"Loss: {loss.item():E}\tBCE: {loss_bce_mean:E}\tLabel: {loss_label_mean:E}\tBBox: {loss_bbox_mean:E}" 119 | print(f"{log_prefix}\t{log}") 120 | 121 | iteration += 1 122 | 123 | for key in train_loss.keys(): 124 | train_loss[key] /= len(train_dataset) 125 | 126 | model.eval() 127 | with torch.no_grad(): 128 | val_loss = { 129 | "Loss_BCE": 0, 130 | "Loss_Label": 0, 131 | "Loss_BBox": 0, 132 | } 133 | 134 | for i, batch in enumerate(val_dataloader): 135 | bbox, label, padding_mask, mask, is_real = proc_batch(batch) 136 | 137 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask) 138 | 139 | loss_bce = criterion_bce(logit, is_real) 140 | loss_label = criterion_label(logit_cls[mask], label[mask]) 141 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1) 142 | 143 | val_loss["Loss_BCE"] += loss_bce.sum().item() 144 | val_loss["Loss_Label"] += loss_label.sum().item() 145 | val_loss["Loss_BBox"] += loss_bbox.sum().item() 146 | 147 | for key in val_loss.keys(): 148 | val_loss[key] /= len(val_dataset) 149 | 150 | tag_scalar_dict = { 151 | "train": sum(train_loss.values()), 152 | "val": sum(val_loss.values()), 153 | } 154 | for key in train_loss.keys(): 155 | tag_scalar_dict = {"train": train_loss[key], "val": val_loss[key]} 156 | 157 | # do checkpointing 158 | val_loss = sum(val_loss.values()) 159 | is_best = val_loss < best_loss 160 | best_loss = min(val_loss, best_loss) 161 | 162 | save_checkpoint( 163 | { 164 | "epoch": epoch + 1, 165 | "state_dict": model.state_dict(), 166 | "best_loss": best_loss, 167 | "optimizer": optimizer.state_dict(), 168 | }, 169 | is_best, 170 | out_dir, 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | main() 176 | -------------------------------------------------------------------------------- /util/backbone.py: -------------------------------------------------------------------------------- 1 | # Implement TransformerEncoder that can consider timesteps as optional args for Diffusion. 
2 | 3 | import copy 4 | import math 5 | from typing import Callable, Optional, Union 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from einops.layers.torch import Rearrange 10 | from torch import Tensor, nn 11 | 12 | 13 | def _get_clones(module, N): 14 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 15 | 16 | 17 | def _gelu2(x): 18 | return x * F.sigmoid(1.702 * x) 19 | 20 | 21 | def _get_activation_fn(activation): 22 | if activation == "relu": 23 | return F.relu 24 | elif activation == "gelu": 25 | return F.gelu 26 | elif activation == "gelu2": 27 | return _gelu2 28 | else: 29 | raise RuntimeError( 30 | "activation should be relu/gelu/gelu2, not {}".format(activation) 31 | ) 32 | 33 | 34 | class SinusoidalPosEmb(nn.Module): 35 | def __init__(self, num_steps: int, dim: int, rescale_steps: int = 4000): 36 | super().__init__() 37 | self.dim = dim 38 | self.num_steps = float(num_steps) 39 | self.rescale_steps = float(rescale_steps) 40 | 41 | def forward(self, x: Tensor): 42 | x = x / self.num_steps * self.rescale_steps 43 | device = x.device 44 | half_dim = self.dim // 2 45 | emb = math.log(10000) / (half_dim - 1) 46 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 47 | emb = x[:, None] * emb[None, :] 48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 49 | return emb 50 | 51 | 52 | class _AdaNorm(nn.Module): 53 | def __init__( 54 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 55 | ): 56 | super().__init__() 57 | if "abs" in emb_type: 58 | self.emb = SinusoidalPosEmb(max_timestep, n_embd) 59 | elif "mlp" in emb_type: 60 | self.emb = nn.Sequential( 61 | Rearrange("b -> b 1"), 62 | nn.Linear(1, n_embd // 2), 63 | nn.ReLU(), 64 | nn.Linear(n_embd // 2, n_embd), 65 | ) 66 | else: 67 | self.emb = nn.Embedding(max_timestep, n_embd) 68 | self.silu = nn.SiLU() 69 | self.linear = nn.Linear(n_embd, n_embd * 2) 70 | 71 | 72 | class AdaLayerNorm(_AdaNorm): 73 | def __init__( 74 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 75 | ): 76 | super().__init__(n_embd, max_timestep, emb_type) 77 | self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) 78 | 79 | def forward(self, x: Tensor, timestep: int): 80 | 81 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) 82 | scale, shift = torch.chunk(emb, 2, dim=2) 83 | x = self.layernorm(x) * (1 + scale) + shift 84 | return x 85 | 86 | 87 | class AdaInsNorm(_AdaNorm): 88 | def __init__( 89 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 90 | ): 91 | super().__init__(n_embd, max_timestep, emb_type) 92 | self.instancenorm = nn.InstanceNorm1d(n_embd) 93 | 94 | def forward(self, x, timestep): 95 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) 96 | scale, shift = torch.chunk(emb, 2, dim=2) 97 | x = ( 98 | self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale) 99 | + shift 100 | ) 101 | return x 102 | 103 | 104 | class Block(nn.Module): 105 | """an unassuming Transformer block""" 106 | 107 | def __init__( 108 | self, 109 | d_model=512, 110 | nhead=8, 111 | dim_feedforward: int = 2048, 112 | dropout: float = 0.0, 113 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 114 | batch_first: bool = True, 115 | norm_first: bool = True, 116 | device=None, 117 | dtype=None, 118 | # extension for diffusion 119 | diffusion_step: int = 100, 120 | timestep_type: str = 'adalayernorm', 121 | ) -> None: 122 | super().__init__() 123 | 124 | assert norm_first # minGPT-based implementations are 
designed for prenorm only 125 | assert timestep_type in [ 126 | None, 127 | "adalayernorm", 128 | "adainnorm", 129 | "adalayernorm_abs", 130 | "adainnorm_abs", 131 | "adalayernorm_mlp", 132 | "adainnorm_mlp", 133 | ] 134 | layer_norm_eps = 1e-5 # fixed 135 | 136 | self.norm_first = norm_first 137 | self.diffusion_step = diffusion_step 138 | self.timestep_type = timestep_type 139 | 140 | factory_kwargs = {"device": device, "dtype": dtype} 141 | self.self_attn = torch.nn.MultiheadAttention( 142 | d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs 143 | ) 144 | 145 | # Implementation of Feedforward model 146 | self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) 147 | self.dropout = nn.Dropout(dropout) 148 | self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) 149 | 150 | if timestep_type is not None: 151 | if "adalayernorm" in timestep_type: 152 | self.norm1 = AdaLayerNorm(d_model, diffusion_step, timestep_type) 153 | elif "adainnorm" in timestep_type: 154 | self.norm1 = AdaInsNorm(d_model, diffusion_step, timestep_type) 155 | else: 156 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 157 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 158 | self.dropout1 = nn.Dropout(dropout) 159 | self.dropout2 = nn.Dropout(dropout) 160 | 161 | if isinstance(activation, str): 162 | self.activation = _get_activation_fn(activation) 163 | else: 164 | self.activation = activation 165 | 166 | def forward( 167 | self, 168 | src: Tensor, 169 | src_mask: Optional[Tensor] = None, 170 | src_key_padding_mask: Optional[Tensor] = None, 171 | timestep: Tensor = None, 172 | ) -> Tensor: 173 | x = src 174 | if self.norm_first: 175 | if self.timestep_type is not None: 176 | x = self.norm1(x, timestep) 177 | else: 178 | x = self.norm1(x) 179 | x = x + self._sa_block(x, src_mask, src_key_padding_mask) 180 | x = x + self._ff_block(self.norm2(x)) 181 | else: 182 | x = x + self._sa_block(x, src_mask, src_key_padding_mask) 183 | if self.timestep_type is not None: 184 | x = self.norm1(x, timestep) 185 | else: 186 | x = self.norm1(x) 187 | x = self.norm2(x + self._ff_block(x)) 188 | 189 | return x 190 | 191 | # self-attention block 192 | def _sa_block( 193 | self, 194 | x: Tensor, 195 | attn_mask: Optional[Tensor], 196 | key_padding_mask: Optional[Tensor], 197 | ) -> Tensor: 198 | x = self.self_attn( 199 | x, 200 | x, 201 | x, 202 | attn_mask=attn_mask, 203 | key_padding_mask=key_padding_mask, 204 | need_weights=False, 205 | )[0] 206 | return self.dropout1(x) 207 | 208 | # feed forward block 209 | def _ff_block(self, x: Tensor) -> Tensor: 210 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 211 | return self.dropout2(x) 212 | 213 | 214 | class TransformerEncoder(nn.Module): 215 | """ 216 | Close to torch.nn.TransformerEncoder, but with timestep support for diffusion 217 | """ 218 | 219 | __constants__ = ["norm"] 220 | 221 | def __init__(self, num_layers=4, dim_seq=10, dim_transformer=512, nhead=8, dim_feedforward=2048, 222 | diffusion_step=100, device='cuda'): 223 | super(TransformerEncoder, self).__init__() 224 | 225 | self.pos_encoder = SinusoidalPosEmb(num_steps=25, dim=dim_transformer).to(device) 226 | pos_i = torch.tensor([i for i in range(25)]).to(device) 227 | self.pos_embed = self.pos_encoder(pos_i) 228 | 229 | self.layer_in = nn.Linear(in_features=dim_seq, out_features=dim_transformer).to(device) 230 | encoder_layer = Block(d_model=dim_transformer, nhead=nhead, dim_feedforward=dim_feedforward, 
diffusion_step=diffusion_step) 231 | self.layers = _get_clones(encoder_layer, num_layers).to(device) 232 | self.num_layers = num_layers 233 | self.layer_out = nn.Linear(in_features=dim_transformer, out_features=dim_seq).to(device) 234 | 235 | def forward( 236 | self, 237 | src: Tensor, 238 | mask: Optional[Tensor] = None, 239 | src_key_padding_mask: Optional[Tensor] = None, 240 | timestep: Tensor = None, 241 | ) -> Tensor: 242 | output = src 243 | 244 | output = self.layer_in(output) 245 | output = F.softplus(output) 246 | output = output + self.pos_embed 247 | 248 | for i, mod in enumerate(self.layers): 249 | output = mod( 250 | output, 251 | src_mask=mask, 252 | src_key_padding_mask=src_key_padding_mask, 253 | timestep=timestep, 254 | ) 255 | 256 | if i < self.num_layers - 1: 257 | output = F.softplus(output) 258 | 259 | output = self.layer_out(output) 260 | 261 | return output 262 | 263 | -------------------------------------------------------------------------------- /util/constraint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | import einops 7 | import imageio 8 | torch.manual_seed(0) 9 | import math 10 | 11 | 12 | def xywh_2_ltrb(bbox_xywh): 13 | 14 | bbox_ltrb = torch.zeros(bbox_xywh.shape).to(bbox_xywh.device) 15 | bbox_xy = torch.abs(bbox_xywh[:, :, :2]) 16 | bbox_wh = torch.abs(bbox_xywh[:, :, 2:]) 17 | bbox_ltrb[:, :, :2] = bbox_xy - 0.5 * bbox_wh 18 | bbox_ltrb[:, :, 2:] = bbox_xy + 0.5 * bbox_wh 19 | return bbox_ltrb 20 | 21 | 22 | def ltrb_2_xywh(bbox_ltrb): 23 | bbox_xywh = torch.zeros(bbox_ltrb.shape) 24 | bbox_wh = torch.abs(bbox_ltrb[:, :, 2:] - bbox_ltrb[:, :, :2]) 25 | bbox_xy = bbox_ltrb[:, :, :2] + 0.5 * bbox_wh 26 | bbox_xywh[:, :, :2] = bbox_xy 27 | bbox_xywh[:, :, 2:] = bbox_wh 28 | return bbox_xywh 29 | 30 | 31 | def xywh_to_ltrb_split(bbox): 32 | xc, yc, w, h = bbox 33 | x1 = xc - w / 2 34 | y1 = yc - h / 2 35 | x2 = xc + w / 2 36 | y2 = yc + h / 2 37 | return [x1, y1, x2, y2] 38 | 39 | 40 | def rand_bbox_ltrb(batch_shape): 41 | 42 | bbox_lt = torch.rand(batch_shape + [2]) 43 | bbox_wh_max = 1 - bbox_lt 44 | bbox_wh_weight = torch.rand(batch_shape).unsqueeze(-1).repeat([1 for _ in range(len(batch_shape))] + [2]) 45 | 46 | bbox_wh = 1 * bbox_wh_weight * bbox_wh_max 47 | bbox_rb = bbox_lt + bbox_wh 48 | 49 | bbox = torch.cat([bbox_lt, bbox_rb], dim=-1) 50 | return bbox 51 | 52 | 53 | def rand_bbox_xywh(batch_shape): 54 | 55 | bbox_ltrb = rand_bbox_ltrb(batch_shape) 56 | bbox_xywh = ltrb_2_xywh(bbox_ltrb) 57 | return bbox_xywh 58 | 59 | 60 | def GIoU_ltrb(bbox_1, bbox_2): 61 | 62 | # step 1 calculate area of bbox_1 and bbox_2 63 | a_1 = (bbox_1[:, :, 2] - bbox_1[:, :, 0]) * (bbox_1[:, :, 3] - bbox_1[:, :, 1]) 64 | a_2 = (bbox_2[:, :, 2] - bbox_2[:, :, 0]) * (bbox_2[:, :, 3] - bbox_2[:, :, 1]) 65 | 66 | # step 2.1 compute intersection I bbox 67 | bbox = torch.cat([bbox_1.unsqueeze(-1), bbox_2.unsqueeze(-1)], dim=-1) 68 | bbox_I_lt = torch.max(bbox, dim=-1)[0][:, :, :2] 69 | bbox_I_rb = torch.min(bbox, dim=-1)[0][:, :, 2:] 70 | 71 | # step 2.2 compute area of I 72 | a_I = F.relu((bbox_I_rb[:, :, 0] - bbox_I_lt[:, :, 0])) * F.relu((bbox_I_rb[:, :, 1] - bbox_I_lt[:, :, 1])) 73 | 74 | # step 3.1 compute smallest enclosing box C 75 | bbox_C_lt = torch.min(bbox, dim=-1)[0][:, :, :2] 76 | bbox_C_rb = torch.max(bbox, dim=-1)[0][:, :, 2:] 77 | 78 | # step 3.2 compute area of C 79 | a_C = (bbox_C_rb[:, 
:, 0] - bbox_C_lt[:, :, 0]) * (bbox_C_rb[:, :, 1] - bbox_C_lt[:, :, 1]) 80 | 81 | # step 4 compute IoU 82 | a_U = (a_1 + a_2 - a_I) 83 | iou = a_I / (a_U + 1e-10) 84 | 85 | # step 5 copute giou 86 | giou = iou - (a_C - a_U) / (a_C + 1e-10) 87 | 88 | return iou, giou 89 | 90 | 91 | def GIoU_xywh(bbox_pred, bbox_true, xy_only=False): 92 | 93 | if xy_only: 94 | wh = torch.abs(bbox_pred[:, :, 2:].clone().detach()) 95 | bbox = torch.cat([bbox_pred[:, :, :2], wh], dim=2) 96 | else: 97 | bbox = bbox_pred 98 | 99 | bbox_pred_ltrb = xywh_2_ltrb(torch.abs(bbox)) 100 | bbox_true_ltrb = xywh_2_ltrb(torch.abs(bbox_true)) 101 | return GIoU_ltrb(bbox_pred_ltrb, bbox_true_ltrb) 102 | 103 | 104 | def PIoU_ltrb(bbox_ltrb, mask=None): 105 | 106 | n_box = bbox_ltrb.shape[1] 107 | device = bbox_ltrb.device 108 | 109 | # compute area of bboxes 110 | area_bbox = (bbox_ltrb[:, :, 2] - bbox_ltrb[:, :, 0]) * (bbox_ltrb[:, :, 3] - bbox_ltrb[:, :, 1]) 111 | area_bbox_psum = area_bbox.unsqueeze(-1) + area_bbox.unsqueeze(-2) 112 | 113 | # compute pairwise intersection 114 | x1y1 = bbox_ltrb[:, :, [0, 1]] 115 | x1y1 = torch.swapaxes(x1y1, 1, 2) 116 | x1y1_I = torch.max(x1y1.unsqueeze(-1), x1y1.unsqueeze(-2)) 117 | 118 | x2y2 = bbox_ltrb[:, :, [2, 3]] 119 | x2y2 = torch.swapaxes(x2y2, 1, 2) 120 | x2y2_I = torch.min(x2y2.unsqueeze(-1), x2y2.unsqueeze(-2)) 121 | # compute area of Is 122 | wh_I = F.relu(x2y2_I - x1y1_I) 123 | area_I = wh_I[:, 0, :, :] * wh_I[:, 1, :, :] 124 | 125 | # compute pairwise IoU 126 | piou = area_I / (area_bbox_psum - area_I + 1e-10) 127 | 128 | piou.masked_fill_(torch.eye(n_box, n_box).to(torch.bool).to(device), 0) 129 | 130 | if mask is not None: 131 | mask = mask.unsqueeze(2) 132 | select_mask = torch.matmul(mask, torch.transpose(mask, dim0=1, dim1=2)) 133 | piou = piou * select_mask.to(device) 134 | 135 | return piou 136 | 137 | 138 | def PIoU_xywh(bbox_xywh, mask=None, xy_only=True): 139 | 140 | if xy_only: 141 | wh = torch.abs(bbox_xywh[:, :, 2:].clone().detach()) 142 | bbox = torch.cat([bbox_xywh[:, :, :2], wh], dim=2) 143 | bbox_ltrb = xywh_2_ltrb(bbox) 144 | else: 145 | bbox_ltrb = xywh_2_ltrb(bbox_xywh) 146 | 147 | return PIoU_ltrb(bbox_ltrb, mask) 148 | 149 | 150 | def Pdist(bbox): 151 | xy = bbox[:, :, :2] 152 | pdist_m = torch.cdist(xy, xy, p=2) 153 | 154 | return pdist_m 155 | 156 | def layout_alignment(bbox, mask, xy_only=False, mode='all'): 157 | """ 158 | alignment metrics in Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020) 159 | https://arxiv.org/abs/2009.05284 160 | """ 161 | 162 | if xy_only: 163 | wh = torch.abs(bbox[:, :, 2:].clone().detach()) 164 | bbox = torch.cat([bbox[:, :, :2], wh], dim=2) 165 | 166 | bbox = bbox.permute(2, 0, 1) 167 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox) 168 | xc, yc = bbox[0], bbox[1] 169 | if mode == 'all': 170 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1) 171 | elif mode == 'partial': 172 | X = torch.stack([xl, xc, yt, yb], dim=1) 173 | else: 174 | raise Exception('mode must be all or partial') 175 | 176 | X = X.unsqueeze(-1) - X.unsqueeze(-2) 177 | idx = torch.arange(X.size(2), device=X.device) 178 | X[:, :, idx, idx] = 1.0 179 | X = X.abs().permute(0, 2, 1, 3) 180 | X[~mask] = 1.0 181 | 182 | X = X.min(-1).values.min(-1).values 183 | X.masked_fill_(X.eq(1.0), 0.0) 184 | X = -torch.log(1 - X) 185 | 186 | score = einops.reduce(X, "b s -> b", reduction="sum") 187 | score_normalized = score / einops.reduce(mask, "b s -> b", reduction="sum") 188 | score_normalized[torch.isnan(score_normalized)] = 0.0 189 | 190 | return 
score, score_normalized 191 | 192 | 193 | def layout_alignment_matrix(bbox, mask): 194 | bbox = bbox.permute(2, 0, 1) 195 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox) 196 | xc, yc = bbox[0], bbox[1] 197 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1) 198 | X = X.unsqueeze(-1) - X.unsqueeze(-2) 199 | idx = torch.arange(X.size(2), device=X.device) 200 | X[:, :, idx, idx] = 1.0 201 | X = X.abs().permute(0, 2, 1, 3) 202 | X[~mask] = 1.0 203 | return X 204 | 205 | 206 | def mean_alignment_error(bbox_target, bbox_true, mask_true, th=1e-5, xy_only=False): 207 | """ 208 | mean coordinate difference error for aligned positions, a function for a batch 209 | tau_t: misalignment tolerance threshold 210 | th: threshold for alignment error in real-data 211 | mask_true: indices where the coordinate difference is smaller than th 212 | """ 213 | 214 | if xy_only: 215 | wh = torch.abs(bbox_target[:, :, 2:].clone().detach()) 216 | bbox = torch.cat([bbox_target[:, :, :2], wh], dim=2) 217 | else: 218 | bbox = bbox_target 219 | 220 | align_score_target = layout_alignment_matrix(bbox, mask_true) 221 | 222 | align_score_true = layout_alignment_matrix(bbox_true, mask_true) 223 | align_mask = (align_score_true < th).clone().detach() 224 | 225 | selected_difference = align_score_target * align_mask 226 | 227 | mae = einops.reduce(selected_difference, "n a b c -> n", reduction="sum") 228 | 229 | return mae 230 | 231 | 232 | def constraint_temporal_weight(t, schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2): 233 | if schedule == "linear": 234 | w = 1 - torch.linspace(start, end, num_timesteps) 235 | elif schedule == "const": 236 | w = 1 - 4 * end * torch.ones(num_timesteps) 237 | elif schedule == "quad": 238 | w = 1 - torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2 239 | elif schedule == "jsd": 240 | w = 1 - 1.0 / torch.linspace(num_timesteps, 1, num_timesteps) 241 | elif schedule == "sigmoid": 242 | betas = torch.linspace(-6, 6, num_timesteps) 243 | w = 1 - torch.sigmoid(betas) * (end - start) + start 244 | elif schedule == "cosine" or schedule == "cosine_reverse": 245 | max_beta = 0.999 246 | cosine_s = 0.008 247 | w = 1 - torch.tensor( 248 | [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / ( 249 | math.cos((i / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2), max_beta) for i in 250 | range(num_timesteps)]) 251 | elif schedule == "cosine_anneal": 252 | w = 1 - torch.tensor( 253 | [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in 254 | range(num_timesteps)]) 255 | 256 | weight = w.cumprod(dim=0).to(t.device) 257 | 258 | return weight[t] 259 | 260 | 261 | def post_process(bbox, mask_generated, xy_only=False, w_o=1, w_a=1): 262 | 263 | # print('beautify') 264 | if torch.sum(mask_generated) == 1: 265 | return bbox, mask_generated 266 | 267 | if xy_only: 268 | wh = torch.abs(bbox[:, :, 2:].clone().detach()) 269 | bbox_in = torch.cat([bbox[:, :, :2], wh], dim=2) 270 | else: 271 | bbox_in = bbox 272 | 273 | bbox_in[:, :, [0, 2]] *= 10 / 4 274 | bbox_in[:, :, [1, 3]] *= 10 / 6 275 | 276 | bbox_initial = bbox_in.clone().detach() 277 | mse_loss = nn.MSELoss() 278 | 279 | bbox_p = nn.Parameter(bbox_in) 280 | optimizer = optim.Adam([bbox_p], lr=1e-4, weight_decay=0.0, betas=(0.9, 0.999), amsgrad=False, eps=1e-08) 281 | with torch.enable_grad(): 282 | for i in range(1000): 283 | bbox_1 = torch.relu(bbox_p) 284 | align_score_target = layout_alignment_matrix(bbox_1, mask_generated) 285 | align_mask = 
(align_score_target < 1/64).clone().detach() 286 | align_loss = torch.mean(align_score_target * align_mask) 287 | 288 | piou_m = PIoU_xywh(bbox_1, mask=mask_generated.to(torch.float32), xy_only=True) 289 | piou = torch.mean(piou_m) 290 | 291 | mse = mse_loss(bbox_1, bbox_initial) 292 | loss = 1 * mse + w_a * align_loss + w_o * piou 293 | optimizer.zero_grad() 294 | loss.backward() 295 | torch.nn.utils.clip_grad_norm_([bbox_p], 1.0) 296 | optimizer.step() 297 | 298 | a, _ = torch.min(bbox_1[:, :, [2, 3]], dim=2) 299 | mask_generated = mask_generated * (a > 0.01) 300 | 301 | bbox_out = torch.relu(bbox_p) 302 | bbox_out[:, :, [0, 2]] *= 4 / 10 303 | bbox_out[:, :, [1, 3]] *= 6 / 10 304 | 305 | return bbox_out, mask_generated 306 | 307 | 308 | if __name__ == "__main__": 309 | 310 | w = constraint_temporal_weight(torch.tensor([225]), schedule="const") 311 | print(w) -------------------------------------------------------------------------------- /util/data_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from enum import IntEnum 3 | from itertools import combinations, product 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as T 9 | from torch import BoolTensor, FloatTensor, LongTensor 10 | from torch_geometric.utils import to_dense_batch 11 | from .constraint import xywh_to_ltrb_split 12 | 13 | 14 | class RelSize(IntEnum): 15 | UNKNOWN = 0 16 | SMALLER = 1 17 | EQUAL = 2 18 | LARGER = 3 19 | 20 | 21 | class RelLoc(IntEnum): 22 | UNKNOWN = 4 23 | LEFT = 5 24 | TOP = 6 25 | RIGHT = 7 26 | BOTTOM = 8 27 | CENTER = 9 28 | 29 | 30 | REL_SIZE_ALPHA = 0.1 31 | 32 | 33 | def detect_size_relation(b1, b2): 34 | a1 = b1[2] * b1[3] 35 | a2 = b2[2] * b2[3] 36 | alpha = REL_SIZE_ALPHA 37 | if (1 - alpha) * a1 < a2 < (1 + alpha) * a1: 38 | return RelSize.EQUAL 39 | elif a1 < a2: 40 | return RelSize.LARGER 41 | else: 42 | return RelSize.SMALLER 43 | 44 | 45 | def detect_loc_relation(b1, b2, is_canvas=False): 46 | if is_canvas: 47 | yc = b2[1] 48 | if yc < 1.0 / 3: 49 | return RelLoc.TOP 50 | elif yc < 2.0 / 3: 51 | return RelLoc.CENTER 52 | else: 53 | return RelLoc.BOTTOM 54 | 55 | else: 56 | l1, t1, r1, b1 = xywh_to_ltrb_split(b1) 57 | l2, t2, r2, b2 = xywh_to_ltrb_split(b2) 58 | 59 | if b2 <= t1: 60 | return RelLoc.TOP 61 | elif b1 <= t2: 62 | return RelLoc.BOTTOM 63 | elif r2 <= l1: 64 | return RelLoc.LEFT 65 | elif r1 <= l2: 66 | return RelLoc.RIGHT 67 | else: 68 | # might not be necessary 69 | return RelLoc.CENTER 70 | 71 | 72 | def get_rel_text(rel, canvas=False): 73 | if type(rel) == RelSize: 74 | index = rel - RelSize.UNKNOWN - 1 75 | if canvas: 76 | return [ 77 | "within canvas", 78 | "spread over canvas", 79 | "out of canvas", 80 | ][index] 81 | 82 | else: 83 | return [ 84 | "larger than", 85 | "equal to", 86 | "smaller than", 87 | ][index] 88 | 89 | else: 90 | index = rel - RelLoc.UNKNOWN - 1 91 | if canvas: 92 | return [ 93 | "", 94 | "at top", 95 | "", 96 | "at bottom", 97 | "at middle", 98 | ][index] 99 | 100 | else: 101 | return [ 102 | "right to", 103 | "below", 104 | "left to", 105 | "above", 106 | "around", 107 | ][index] 108 | 109 | 110 | # transform 111 | class AddCanvasElement: 112 | x = torch.tensor([[0.5, 0.5, 1.0, 1.0]], dtype=torch.float) 113 | y = torch.tensor([0], dtype=torch.long) 114 | 115 | def __call__(self, data): 116 | flag = data.attr["has_canvas_element"].any().item() 117 | assert not flag 118 | if not flag: 119 | # device = data.x.device 120 | # x, 
y = self.x.to(device), self.y.to(device) 121 | data.x = torch.cat([self.x, data.x], dim=0) 122 | data.y = torch.cat([self.y, data.y + 1], dim=0) 123 | data.attr = data.attr.copy() 124 | data.attr["has_canvas_element"] = True 125 | return data 126 | 127 | 128 | class AddRelationConstraints: 129 | def __init__(self, seed=None, edge_ratio=0.1, use_v1=False): 130 | self.edge_ratio = edge_ratio 131 | self.use_v1 = use_v1 132 | self.generator = random.Random() 133 | if seed is not None: 134 | self.generator.seed(seed) 135 | 136 | def __call__(self, data): 137 | N = data.x.size(0) 138 | has_canvas = data.attr["has_canvas_element"] 139 | 140 | rel_all = list(product(range(2), combinations(range(N), 2))) 141 | size = int(len(rel_all) * self.edge_ratio) 142 | rel_sample = set(self.generator.sample(rel_all, size)) 143 | 144 | edge_index, edge_attr = [], [] 145 | rel_unk = 1 << RelSize.UNKNOWN | 1 << RelLoc.UNKNOWN 146 | for i, j in combinations(range(N), 2): 147 | bi, bj = data.x[i], data.x[j] 148 | canvas = data.y[i] == 0 and has_canvas 149 | 150 | if self.use_v1: 151 | if (0, (i, j)) in rel_sample: 152 | rel_size = 1 << detect_size_relation(bi, bj) 153 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas) 154 | else: 155 | rel_size = 1 << RelSize.UNKNOWN 156 | rel_loc = 1 << RelLoc.UNKNOWN 157 | else: 158 | if (0, (i, j)) in rel_sample: 159 | rel_size = 1 << detect_size_relation(bi, bj) 160 | else: 161 | rel_size = 1 << RelSize.UNKNOWN 162 | 163 | if (1, (i, j)) in rel_sample: 164 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas) 165 | else: 166 | rel_loc = 1 << RelLoc.UNKNOWN 167 | 168 | rel = rel_size | rel_loc 169 | if rel != rel_unk: 170 | edge_index.append((i, j)) 171 | edge_attr.append(rel) 172 | 173 | data.edge_index = torch.as_tensor(edge_index).long() 174 | data.edge_index = data.edge_index.t().contiguous() 175 | data.edge_attr = torch.as_tensor(edge_attr).long() 176 | 177 | return data 178 | 179 | 180 | class RandomOrder: 181 | def __call__(self, data): 182 | assert not data.attr["has_canvas_element"] 183 | device = data.x.device 184 | N = data.x.size(0) 185 | idx = torch.randperm(N, device=device) 186 | data.x, data.y = data.x[idx], data.y[idx] 187 | return data 188 | 189 | 190 | class SortByLabel: 191 | def __call__(self, data): 192 | assert not data.attr["has_canvas_element"] 193 | idx = data.y.sort().indices 194 | data.x, data.y = data.x[idx], data.y[idx] 195 | return data 196 | 197 | 198 | class LexicographicOrder: 199 | def __call__(self, data): 200 | assert not data.attr["has_canvas_element"] 201 | x, y, _, _ = xywh_to_ltrb_split(data.x.t()) 202 | _zip = zip(*sorted(enumerate(zip(y, x)), key=lambda c: c[1:])) 203 | idx = list(list(_zip)[0]) 204 | 205 | data.x_orig, data.y_orig = data.x, data.y 206 | data.x, data.y = data.x[idx], data.y[idx] 207 | return data 208 | 209 | 210 | class AddNoiseToBBox: 211 | def __init__(self, std: float = 0.05): 212 | self.std = float(std) 213 | 214 | def __call__(self, data): 215 | noise = torch.normal(0, self.std, size=data.x.size(), device=data.x.device) 216 | data.x_orig = data.x.clone() 217 | data.x = data.x + noise 218 | data.attr = data.attr.copy() 219 | data.attr["NoiseAdded"][0] = True 220 | return data 221 | 222 | 223 | class HorizontalFlip: 224 | def __call__(self, data): 225 | data.x = data.x.clone() 226 | data.x[:, 0] = 1 - data.x[:, 0] 227 | return data 228 | 229 | 230 | # def compose_transform(transforms): 231 | # module = sys.modules[__name__] 232 | # transform_list = [] 233 | # for t in transforms: 234 | # # parse args 235 | # if "(" 
in t and ")" in t: 236 | # args = t[t.index("(") + 1 : t.index(")")] 237 | # t = t[: t.index("(")] 238 | # regex = re.compile(r"\b(\w+)=(.*?)(?=\s\w+=\s*|$)") 239 | # args = dict(regex.findall(args)) 240 | # for k in args: 241 | # try: 242 | # args[k] = float(args[k]) 243 | # except: 244 | # pass 245 | # else: 246 | # args = {} 247 | # if isinstance(t, str): 248 | # if hasattr(module, t): 249 | # transform_list.append(getattr(module, t)(**args)) 250 | # else: 251 | # raise NotImplementedError 252 | # else: 253 | # raise NotImplementedError 254 | # return T.Compose(transform_list) 255 | 256 | 257 | def compose_transform(transforms: List[str]) -> T.Compose: 258 | """ 259 | Compose transforms, optionally with args (e.g., AddRelationConstraints(edge_ratio=0.1)) 260 | """ 261 | transform_list = [] 262 | for t in transforms: 263 | if "(" in t and ")" in t: 264 | pass 265 | else: 266 | t += "()" 267 | transform_list.append(eval(t)) 268 | return T.Compose(transform_list) 269 | 270 | 271 | def sparse_to_dense( 272 | batch, 273 | device: torch.device = torch.device("cpu"), 274 | remove_canvas: bool = False, 275 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]: 276 | batch = batch.to(device) 277 | bbox, _ = to_dense_batch(batch.x, batch.batch) 278 | label, mask = to_dense_batch(batch.y, batch.batch) 279 | 280 | if remove_canvas: 281 | bbox = bbox[:, 1:].contiguous() 282 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform 283 | label = label.clamp(min=0) 284 | mask = mask[:, 1:].contiguous() 285 | 286 | padding_mask = ~mask 287 | return bbox, label, padding_mask, mask 288 | 289 | 290 | def loader_to_list( 291 | loader: torch.utils.data.dataloader.DataLoader, 292 | ) -> List[Tuple[np.ndarray, np.ndarray]]: 293 | layouts = [] 294 | for batch in loader: 295 | bbox, label, _, mask = sparse_to_dense(batch) 296 | for i in range(len(label)): 297 | valid = mask[i].numpy() 298 | layouts.append((bbox[i].numpy()[valid], label[i].numpy()[valid])) 299 | return layouts 300 | 301 | 302 | def split_num_samples(N: int, batch_size: int) -> List[int]: 303 | quontinent = N // batch_size 304 | remainder = N % batch_size 305 | dataloader = quontinent * [batch_size] 306 | if remainder > 0: 307 | dataloader.append(remainder) 308 | return dataloader 309 | -------------------------------------------------------------------------------- /util/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .publaynet import PubLayNetDataset 2 | from .rico import Rico25Dataset 3 | 4 | _DATASETS = [ 5 | Rico25Dataset, 6 | PubLayNetDataset, 7 | ] 8 | DATASETS = {d.name: d for d in _DATASETS} 9 | -------------------------------------------------------------------------------- /util/datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fsspec 4 | import seaborn as sns 5 | import torch 6 | from fsspec.core import url_to_fs 7 | 8 | from .dataset import InMemoryDataset 9 | 10 | 11 | class BaseDataset(InMemoryDataset): 12 | name = None 13 | labels = [] 14 | _label2index = None 15 | _index2label = None 16 | _colors = None 17 | 18 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 19 | assert split in ["train", "val", "test"] 20 | name = f"{self.name}-max{max_seq_length}" 21 | self.max_seq_length = max_seq_length 22 | super().__init__(os.path.join(dir, name), transform) 23 | idx = self.processed_file_names.index("{}.pt".format(split)) 24 | 25 | with 
fsspec.open(self.processed_paths[idx], "rb") as file_obj: 26 | self.data, self.slices = torch.load(file_obj) 27 | 28 | @property 29 | def label2index(self): 30 | if self._label2index is None: 31 | self._label2index = dict() 32 | for idx, label in enumerate(self.labels): 33 | self._label2index[label] = idx 34 | return self._label2index 35 | 36 | @property 37 | def index2label(self): 38 | if self._index2label is None: 39 | self._index2label = dict() 40 | for idx, label in enumerate(self.labels): 41 | self._index2label[idx] = label 42 | return self._index2label 43 | 44 | @property 45 | def colors(self): 46 | if self._colors is None: 47 | n_colors = self.num_classes 48 | colors = sns.color_palette("husl", n_colors=n_colors) 49 | self._colors = [tuple(map(lambda x: int(x * 255), c)) for c in colors] 50 | return self._colors 51 | 52 | @property 53 | def raw_file_names(self): 54 | fs, _ = url_to_fs(self.raw_dir) 55 | if not fs.exists(self.raw_dir): 56 | return [] 57 | file_names = [f.split("/")[-1] for f in fs.ls(self.raw_dir)] 58 | return file_names 59 | 60 | @property 61 | def processed_file_names(self): 62 | return ["train.pt", "val.pt", "test.pt"] 63 | 64 | def download(self): 65 | raise FileNotFoundError("See dataset/README.md") 66 | 67 | def process(self): 68 | raise NotImplementedError 69 | 70 | def get_original_images(self): 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /util/datasets/crello.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import random 4 | random.seed(1) 5 | from fsspec.core import url_to_fs 6 | import numpy as np 7 | import torch 8 | from torch_geometric.data import Data 9 | from .base import BaseDataset 10 | import datasets 11 | 12 | def ltrb_2_xywh(bbox_ltrb): 13 | bbox_xywh = torch.zeros(bbox_ltrb.shape) 14 | bbox_wh = torch.abs(bbox_ltrb[:, 2:] - bbox_ltrb[:, :2]) 15 | bbox_xy = bbox_ltrb[:, :2] + 0.5 * bbox_wh 16 | bbox_xywh[:, :2] = bbox_xy 17 | bbox_xywh[:, 2:] = bbox_wh 18 | return bbox_xywh 19 | 20 | class CrelloDataset(BaseDataset): 21 | name = "crello" 22 | label_names = ['coloredBackground', 'imageElement', 'maskElement', 'svgElement', 'textElement'] 23 | label_dict = {k: i for i, k in enumerate(label_names)} 24 | 25 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 26 | super().__init__(dir, split, max_seq_length, transform) 27 | 28 | def download(self): 29 | _ = datasets.load_dataset("cyberagent/crello") 30 | 31 | def process(self): 32 | fs, _ = url_to_fs(self.raw_dir) 33 | for split_i, split in enumerate(['train', 'validation', 'test']): 34 | 35 | dataset = datasets.load_dataset("cyberagent/crello", split=split) 36 | # label_names = dataset.features['type'].feature.names 37 | width_list = [int(x) for x in dataset.features['canvas_width'].names] 38 | height_list = [int(x) for x in dataset.features['canvas_height'].names] 39 | 40 | data_list = [] 41 | for i, elements in tqdm(enumerate(dataset), total=len(dataset), desc=f'split: {split}', ncols=150): 42 | labels = torch.tensor(elements['type']).to(torch.long) 43 | wi = elements['canvas_width'] 44 | hi = elements['canvas_height'] 45 | W = width_list[wi] 46 | H = height_list[hi] 47 | left = torch.tensor(elements['left']) 48 | top = torch.tensor(elements['top']) 49 | width = torch.tensor(elements['width']) 50 | height = torch.tensor(elements['height']) 51 | 52 | def is_valid(x1, y1, width, height): 53 | 54 | if torch.min(width) <= 0 
or torch.min(height) <= 0: 55 | return None, False 56 | 57 | x2, y2 = x1 + width, y1 + height 58 | bboxes = torch.stack([x1, y1, x2, y2], dim=1) 59 | if torch.max(bboxes) > 1 or torch.min(bboxes) < 0: 60 | bboxes_ltbr = torch.clamp(bboxes, min=0, max=1) 61 | bboxes = ltrb_2_xywh(bboxes_ltbr) 62 | return bboxes, False 63 | 64 | return bboxes, True 65 | 66 | bboxes, filtered = is_valid(left, top, width, height) 67 | if bboxes is None: 68 | continue 69 | 70 | if bboxes.shape[0] == 0 or bboxes.shape[0] > self.max_seq_length: 71 | continue 72 | 73 | data = Data(x=bboxes, y=labels) 74 | data.attr = { 75 | "name": elements['id'], 76 | "width": W, 77 | "height": H, 78 | "filtered": filtered, 79 | "has_canvas_element": False, 80 | "NoiseAdded": False, 81 | } 82 | data_list.append(data) 83 | 84 | with fs.open(self.processed_paths[split_i], "wb") as file_obj: 85 | if split_i == 0: 86 | print('duplicate training split (x10) for more batches per epoch') 87 | torch.save(self.collate(data_list * 10), file_obj) 88 | else: 89 | torch.save(self.collate(data_list), file_obj) 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | 95 | data_dir = '../../download/datasets' 96 | # get_train_labels(data_dir) 97 | # get_val_test_labels(data_dir) 98 | 99 | # train_dataset = MagazineDataset(dir=data_dir, split='train') 100 | # print(len(train_dataset)) 101 | # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=4) 102 | # for bbox, label, mask in tqdm(train_loader): 103 | # print(bbox[0]) 104 | # print(label[0]) 105 | # print(mask[0]) 106 | # break 107 | 108 | -------------------------------------------------------------------------------- /util/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # slightly modified to cope with remote folders/files starting with gs:// 2 | import copy 3 | import os.path as osp 4 | import re 5 | import sys 6 | import warnings 7 | from collections.abc import Mapping, Sequence 8 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 9 | import numpy as np 10 | import torch 11 | import torch.utils.data 12 | from fsspec.core import url_to_fs 13 | from torch import Tensor 14 | from torch_geometric.data import Data 15 | from torch_geometric.data.collate import collate 16 | from torch_geometric.data.dataset import Dataset 17 | from torch_geometric.data.makedirs import makedirs 18 | from torch_geometric.data.separate import separate 19 | 20 | IndexType = Union[slice, Tensor, np.ndarray, Sequence] 21 | 22 | 23 | class Dataset(torch.utils.data.Dataset): 24 | r"""Dataset base class for creating graph datasets. 25 | See `here `__ for the accompanying tutorial. 27 | 28 | Args: 29 | root (string, optional): Root directory where the dataset should be 30 | saved. (optional: :obj:`None`) 31 | transform (callable, optional): A function/transform that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a transformed 33 | version. The data object will be transformed before every access. 34 | (default: :obj:`None`) 35 | pre_transform (callable, optional): A function/transform that takes in 36 | an :obj:`torch_geometric.data.Data` object and returns a 37 | transformed version. The data object will be transformed before 38 | being saved to disk. 
(default: :obj:`None`) 39 | pre_filter (callable, optional): A function that takes in an 40 | :obj:`torch_geometric.data.Data` object and returns a boolean 41 | value, indicating whether the data object should be included in the 42 | final dataset. (default: :obj:`None`) 43 | """ 44 | 45 | @property 46 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 47 | r"""The name of the files in the :obj:`self.raw_dir` folder that must 48 | be present in order to skip downloading.""" 49 | raise NotImplementedError 50 | 51 | @property 52 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 53 | r"""The name of the files in the :obj:`self.processed_dir` folder that 54 | must be present in order to skip processing.""" 55 | raise NotImplementedError 56 | 57 | def download(self): 58 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" 59 | raise NotImplementedError 60 | 61 | def process(self): 62 | r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" 63 | raise NotImplementedError 64 | 65 | def len(self) -> int: 66 | r"""Returns the number of graphs stored in the dataset.""" 67 | raise NotImplementedError 68 | 69 | def get(self, idx: int) -> Data: 70 | r"""Gets the data object at index :obj:`idx`.""" 71 | raise NotImplementedError 72 | 73 | def __init__( 74 | self, 75 | root: Optional[str] = None, 76 | transform: Optional[Callable] = None, 77 | pre_transform: Optional[Callable] = None, 78 | pre_filter: Optional[Callable] = None, 79 | ): 80 | super().__init__() 81 | 82 | if isinstance(root, str): 83 | if not root.startswith("gs://"): 84 | root = osp.expanduser(osp.normpath(root)) 85 | 86 | self.root = root 87 | self.transform = transform 88 | self.pre_transform = pre_transform 89 | self.pre_filter = pre_filter 90 | self._indices: Optional[Sequence] = None 91 | 92 | if self.download.__qualname__.split(".")[0] != "Dataset": 93 | self._download() 94 | 95 | if self.process.__qualname__.split(".")[0] != "Dataset": 96 | fs, _ = url_to_fs(self.processed_dir) 97 | if not all(fs.exists(p) for p in self.processed_paths): 98 | self._process() 99 | 100 | def indices(self) -> Sequence: 101 | return range(self.len()) if self._indices is None else self._indices 102 | 103 | @property 104 | def raw_dir(self) -> str: 105 | return osp.join(self.root, "raw") 106 | 107 | @property 108 | def processed_dir(self) -> str: 109 | return osp.join(self.root, "processed") 110 | 111 | @property 112 | def num_node_features(self) -> int: 113 | r"""Returns the number of features per node in the dataset.""" 114 | data = self[0] 115 | data = data[0] if isinstance(data, tuple) else data 116 | if hasattr(data, "num_node_features"): 117 | return data.num_node_features 118 | raise AttributeError( 119 | f"'{data.__class__.__name__}' object has no " 120 | f"attribute 'num_node_features'" 121 | ) 122 | 123 | @property 124 | def num_features(self) -> int: 125 | r"""Returns the number of features per node in the dataset. 
126 | Alias for :py:attr:`~num_node_features`.""" 127 | return self.num_node_features 128 | 129 | @property 130 | def num_edge_features(self) -> int: 131 | r"""Returns the number of features per edge in the dataset.""" 132 | data = self[0] 133 | data = data[0] if isinstance(data, tuple) else data 134 | if hasattr(data, "num_edge_features"): 135 | return data.num_edge_features 136 | raise AttributeError( 137 | f"'{data.__class__.__name__}' object has no " 138 | f"attribute 'num_edge_features'" 139 | ) 140 | 141 | @property 142 | def raw_paths(self) -> List[str]: 143 | r"""The absolute filepaths that must be present in order to skip 144 | downloading.""" 145 | files = to_list(self.raw_file_names) 146 | return [osp.join(self.raw_dir, f) for f in files] 147 | 148 | @property 149 | def processed_paths(self) -> List[str]: 150 | r"""The absolute filepaths that must be present in order to skip 151 | processing.""" 152 | files = to_list(self.processed_file_names) 153 | return [osp.join(self.processed_dir, f) for f in files] 154 | 155 | def _download(self): 156 | if files_exist(self.raw_paths): # pragma: no cover 157 | return 158 | 159 | makedirs(self.raw_dir) 160 | self.download() 161 | 162 | def _process(self): 163 | f = osp.join(self.processed_dir, "pre_transform.pt") 164 | fs, _ = url_to_fs(self.processed_dir) 165 | if fs.exists(f): 166 | with fs.open(f, "rb") as file_obj: 167 | x = torch.load(file_obj) 168 | if x != _repr(self.pre_transform): 169 | warnings.warn( 170 | f"The `pre_transform` argument differs from the one used in " 171 | f"the pre-processed version of this dataset. If you want to " 172 | f"make use of another pre-processing technique, make sure to " 173 | f"delete '{self.processed_dir}' first" 174 | ) 175 | 176 | f = osp.join(self.processed_dir, "pre_filter.pt") 177 | if fs.exists(f): 178 | with fs.open(f, "rb") as file_obj: 179 | x = torch.load(file_obj) 180 | if x != _repr(self.pre_filter): 181 | warnings.warn( 182 | "The `pre_filter` argument differs from the one used in " 183 | "the pre-processed version of this dataset. If you want to " 184 | "make use of another pre-filtering technique, make sure to " 185 | f"delete '{self.processed_dir}' first" 186 | ) 187 | 188 | if files_exist(self.processed_paths): # pragma: no cover 189 | return 190 | 191 | print("Processing...", file=sys.stderr) 192 | 193 | makedirs(self.processed_dir) 194 | self.process() 195 | 196 | path = osp.join(self.processed_dir, "pre_transform.pt") 197 | with fs.open(path, "wb") as file_obj: 198 | torch.save(_repr(self.pre_transform), file_obj) 199 | path = osp.join(self.processed_dir, "pre_filter.pt") 200 | with fs.open(path, "wb") as file_obj: 201 | torch.save(_repr(self.pre_filter), file_obj) 202 | 203 | print("Done!", file=sys.stderr) 204 | 205 | def __len__(self) -> int: 206 | r"""The number of examples in the dataset.""" 207 | return len(self.indices()) 208 | 209 | def __getitem__( 210 | self, 211 | idx: Union[int, np.integer, IndexType], 212 | ) -> Union["Dataset", Data]: 213 | r"""In case :obj:`idx` is of type integer, will return the data object 214 | at index :obj:`idx` (and transforms it in case :obj:`transform` is 215 | present). 
216 | In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a 217 | tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or 218 | bool, will return a subset of the dataset at the specified indices.""" 219 | if ( 220 | isinstance(idx, (int, np.integer)) 221 | or (isinstance(idx, Tensor) and idx.dim() == 0) 222 | or (isinstance(idx, np.ndarray) and np.isscalar(idx)) 223 | ): 224 | data = self.get(self.indices()[idx]) 225 | data = data if self.transform is None else self.transform(data) 226 | return data 227 | 228 | else: 229 | return self.index_select(idx) 230 | 231 | def index_select(self, idx: IndexType) -> "Dataset": 232 | r"""Creates a subset of the dataset from specified indices :obj:`idx`. 233 | Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a 234 | list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type 235 | long or bool.""" 236 | indices = self.indices() 237 | 238 | if isinstance(idx, slice): 239 | indices = indices[idx] 240 | 241 | elif isinstance(idx, Tensor) and idx.dtype == torch.long: 242 | return self.index_select(idx.flatten().tolist()) 243 | 244 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool: 245 | idx = idx.flatten().nonzero(as_tuple=False) 246 | return self.index_select(idx.flatten().tolist()) 247 | 248 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: 249 | return self.index_select(idx.flatten().tolist()) 250 | 251 | elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: 252 | idx = idx.flatten().nonzero()[0] 253 | return self.index_select(idx.flatten().tolist()) 254 | 255 | elif isinstance(idx, Sequence) and not isinstance(idx, str): 256 | indices = [indices[i] for i in idx] 257 | 258 | else: 259 | raise IndexError( 260 | f"Only slices (':'), list, tuples, torch.tensor and " 261 | f"np.ndarray of dtype long or bool are valid indices (got " 262 | f"'{type(idx).__name__}')" 263 | ) 264 | 265 | dataset = copy.copy(self) 266 | dataset._indices = indices 267 | return dataset 268 | 269 | def shuffle( 270 | self, 271 | return_perm: bool = False, 272 | ) -> Union["Dataset", Tuple["Dataset", Tensor]]: 273 | r"""Randomly shuffles the examples in the dataset. 274 | 275 | Args: 276 | return_perm (bool, optional): If set to :obj:`True`, will also 277 | return the random permutation used to shuffle the dataset. 278 | (default: :obj:`False`) 279 | """ 280 | perm = torch.randperm(len(self)) 281 | dataset = self.index_select(perm) 282 | return (dataset, perm) if return_perm is True else dataset 283 | 284 | def __repr__(self) -> str: 285 | arg_repr = str(len(self)) if len(self) > 1 else "" 286 | return f"{self.__class__.__name__}({arg_repr})" 287 | 288 | 289 | def to_list(value: Any) -> Sequence: 290 | if isinstance(value, Sequence) and not isinstance(value, str): 291 | return value 292 | else: 293 | return [value] 294 | 295 | 296 | def files_exist(files: List[str]) -> bool: 297 | # NOTE: We return `False` in case `files` is empty, leading to a 298 | # re-processing of files on every instantiation. 299 | if len(files) == 0: 300 | return False 301 | else: 302 | fs, _ = url_to_fs(files[0]) 303 | return all([fs.exists(f) for f in files]) 304 | 305 | 306 | def _repr(obj: Any) -> str: 307 | if obj is None: 308 | return "None" 309 | return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) 310 | 311 | 312 | class InMemoryDataset(Dataset): 313 | r"""Dataset base class for creating graph datasets which easily fit 314 | into CPU memory. 315 | Inherits from :class:`torch_geometric.data.Dataset`. 
316 | See `here `__ for the accompanying 318 | tutorial. 319 | 320 | Args: 321 | root (string, optional): Root directory where the dataset should be 322 | saved. (default: :obj:`None`) 323 | transform (callable, optional): A function/transform that takes in an 324 | :obj:`torch_geometric.data.Data` object and returns a transformed 325 | version. The data object will be transformed before every access. 326 | (default: :obj:`None`) 327 | pre_transform (callable, optional): A function/transform that takes in 328 | an :obj:`torch_geometric.data.Data` object and returns a 329 | transformed version. The data object will be transformed before 330 | being saved to disk. (default: :obj:`None`) 331 | pre_filter (callable, optional): A function that takes in an 332 | :obj:`torch_geometric.data.Data` object and returns a boolean 333 | value, indicating whether the data object should be included in the 334 | final dataset. (default: :obj:`None`) 335 | """ 336 | 337 | @property 338 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 339 | raise NotImplementedError 340 | 341 | @property 342 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 343 | raise NotImplementedError 344 | 345 | def __init__( 346 | self, 347 | root: Optional[str] = None, 348 | transform: Optional[Callable] = None, 349 | pre_transform: Optional[Callable] = None, 350 | pre_filter: Optional[Callable] = None, 351 | ): 352 | super().__init__(root, transform, pre_transform, pre_filter) 353 | self.data = None 354 | self.slices = None 355 | self._data_list: Optional[List[Data]] = None 356 | 357 | @property 358 | def num_classes(self) -> int: 359 | r"""Returns the number of classes in the dataset.""" 360 | y = self.data.y 361 | if y is None: 362 | return 0 363 | elif y.numel() == y.size(0) and not torch.is_floating_point(y): 364 | return int(self.data.y.max()) + 1 365 | elif y.numel() == y.size(0) and torch.is_floating_point(y): 366 | return torch.unique(y).numel() 367 | else: 368 | return self.data.y.size(-1) 369 | 370 | def len(self) -> int: 371 | if self.slices is None: 372 | return 1 373 | for _, value in nested_iter(self.slices): 374 | return len(value) - 1 375 | return 0 376 | 377 | def get(self, idx: int) -> Data: 378 | if self.len() == 1: 379 | return copy.copy(self.data) 380 | 381 | if not hasattr(self, "_data_list") or self._data_list is None: 382 | self._data_list = self.len() * [None] 383 | elif self._data_list[idx] is not None: 384 | return copy.copy(self._data_list[idx]) 385 | 386 | data = separate( 387 | cls=self.data.__class__, 388 | batch=self.data, 389 | idx=idx, 390 | slice_dict=self.slices, 391 | decrement=False, 392 | ) 393 | 394 | self._data_list[idx] = copy.copy(data) 395 | 396 | return data 397 | 398 | @staticmethod 399 | def collate(data_list: List[Data]) -> Tuple[Data, Optional[Dict[str, Tensor]]]: 400 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects 401 | to the internal storage format of 402 | :class:`~torch_geometric.data.InMemoryDataset`.""" 403 | if len(data_list) == 1: 404 | return data_list[0], None 405 | 406 | data, slices, _ = collate( 407 | data_list[0].__class__, 408 | data_list=data_list, 409 | increment=False, 410 | add_batch=False, 411 | ) 412 | 413 | return data, slices 414 | 415 | def copy(self, idx: Optional[IndexType] = None) -> "InMemoryDataset": 416 | r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given, 417 | will clone the full dataset. Otherwise, will only clone a subset of the 418 | dataset from indices :obj:`idx`. 
419 | Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or 420 | :obj:`np.ndarray` of type long or bool. 421 | """ 422 | if idx is None: 423 | data_list = [self.get(i) for i in self.indices()] 424 | else: 425 | data_list = [self.get(i) for i in self.index_select(idx).indices()] 426 | 427 | dataset = copy.copy(self) 428 | dataset._indices = None 429 | dataset._data_list = None 430 | dataset.data, dataset.slices = self.collate(data_list) 431 | return dataset 432 | 433 | 434 | def nested_iter(mapping: Mapping) -> Iterable: 435 | for key, value in mapping.items(): 436 | if isinstance(value, Mapping): 437 | for inner_key, inner_value in nested_iter(value): 438 | yield inner_key, inner_value 439 | else: 440 | yield key, value 441 | -------------------------------------------------------------------------------- /util/datasets/load_data.py: -------------------------------------------------------------------------------- 1 | from util.datasets.publaynet import PubLayNetDataset 2 | from util.datasets.magazine import MagazineDataset 3 | from util.datasets.rico import Rico25Dataset, Rico5Dataset, Rico13Dataset 4 | from util.datasets.crello import CrelloDataset 5 | import torch 6 | import torch_geometric 7 | from util.seq_util import sparse_to_dense, pad_until 8 | 9 | dataset_dict = {'publaynet': PubLayNetDataset, 'rico13': Rico13Dataset, 10 | 'rico25': Rico25Dataset, 'magazine': MagazineDataset, 'crello': CrelloDataset} 11 | 12 | 13 | def init_dataset(dataset_name, data_dir, batch_size=128, split='train', transform=None, shuffle=False): 14 | 15 | main_dataset = dataset_dict[dataset_name](dir=data_dir, split=split, max_seq_length=25, transform=transform) 16 | main_loader = torch_geometric.loader.DataLoader(main_dataset, shuffle=shuffle, batch_size=batch_size, 17 | num_workers=4, pin_memory=True) 18 | return main_dataset, main_loader 19 | 20 | 21 | if __name__=="__main__": 22 | 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | data_dir = '/Users/chenjian/Documents/Projects/Layout/LayoutDiff/datasets' 25 | 26 | train_dataset, train_loader = init_dataset('crello', data_dir, batch_size=120, split='train') 27 | print(train_dataset.num_classes) 28 | 29 | for data in train_loader: 30 | bbox, label, _, mask = sparse_to_dense(data) 31 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25) 32 | bbox = bbox.to(device) 33 | 34 | print(bbox[0]) 35 | break 36 | 37 | # 38 | 39 | -------------------------------------------------------------------------------- /util/datasets/magazine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import random 4 | random.seed(1) 5 | from fsspec.core import url_to_fs 6 | import numpy as np 7 | import torch 8 | from torch_geometric.data import Data 9 | import torch_geometric 10 | from .base import BaseDataset 11 | import torchvision.transforms as transforms 12 | import xml.etree.ElementTree as ET 13 | 14 | magazine_cmap = { 15 | "text": (254, 231, 44), 16 | "image": (27, 187, 146), 17 | "headline": (255, 0, 0), 18 | "text-over-image": (0, 102, 255), 19 | "headline-over-image": (204, 0, 255), 20 | "background": (200, 200, 200), 21 | } 22 | 23 | label_dict = {"text": 0, "image": 1, "headline": 2, "text-over-image": 3, "headline-over-image": 4, "background": 5} 24 | 25 | 26 | class MagazineDataset(BaseDataset): 27 | name = "magazine" 28 | labels = ["text", "image", "headline", "text-over-image", "headline-over-image", "background"] 29 | 30 | def __init__(self, dir: 
str, split: str, max_seq_length: int, transform=None): 31 | super().__init__(dir, split, max_seq_length, transform) 32 | 33 | def download(self): 34 | # super().download() 35 | pass 36 | 37 | def process(self): 38 | 39 | fs, _ = url_to_fs(self.raw_dir) 40 | # self.data_root 41 | self.colormap = magazine_cmap 42 | self.labels = ("text", "image", "headline", "text-over-image", "headline-over-image", "background") 43 | 44 | file_folder = os.path.join(self.raw_dir, 'annotations') 45 | 46 | file_cls_wise = {'fashion': [], 'food': [], 'science': [], 'news': [], 'travel': [], 'wedding': []} 47 | for file_name in sorted(os.listdir(file_folder)): 48 | cls = file_name.split('_')[0] 49 | file_cls_wise[cls].append(file_name) 50 | 51 | file_lists = dict() 52 | file_lists['train'] = sum([file_cls_wise[k][:int(0.9 * len(file_cls_wise[k]))] for k in file_cls_wise.keys()], []) 53 | file_lists['test'] = sum([file_cls_wise[k][int(0.9 * len(file_cls_wise[k])):] for k in file_cls_wise.keys()], []) 54 | 55 | for split in ['train', 'test']: 56 | file_list = file_lists[split] 57 | data_list = [] 58 | 59 | for file_name in file_list: 60 | boxes = [] 61 | labels = [] 62 | 63 | tree = ET.parse(os.path.join(file_folder, file_name)) 64 | root = tree.getroot() 65 | try: 66 | for layout in root.findall('layout'): 67 | if len(layout.findall('element')) > self.max_seq_length: 68 | continue 69 | for i, element in enumerate(layout.findall('element')): 70 | label = element.get('label') 71 | c = label_dict[label] 72 | px = [int(i) for i in element.get('polygon_x').split(" ")] 73 | py = [int(i) for i in element.get('polygon_y').split(" ")] 74 | # get center coordinate (x,y), width and height (w, h) 75 | x = (px[2] + px[0]) / 2 / 225 76 | y = (py[2] + py[0]) / 2 / 300 77 | w = (px[2] - px[0]) / 225 78 | h = (py[2] - py[0]) / 300 79 | boxes.append([x, y, w, h]) 80 | labels.append(c) 81 | 82 | except: 83 | continue 84 | 85 | if len(labels) == 0: 86 | continue 87 | 88 | boxes = torch.tensor(boxes, dtype=torch.float) 89 | labels = torch.tensor(labels, dtype=torch.long) 90 | 91 | data = Data(x=boxes, y=labels) 92 | data.attr = { 93 | "name": file_name, 94 | "width": 225, 95 | "height": 300, 96 | "filtered": True, 97 | "has_canvas_element": False, 98 | "NoiseAdded": False, 99 | } 100 | data_list.append(data) 101 | 102 | if split == "train": 103 | train_list = data_list 104 | else: 105 | test_list = data_list 106 | 107 | # train 90% / test 10% 108 | # self.processed_paths: [train. 
val, test] 109 | with fs.open(self.processed_paths[0], "wb") as file_obj: 110 | print('duplicate training split (x100) for more batches per epoch') 111 | torch.save(self.collate(train_list * 100), file_obj) 112 | with fs.open(self.processed_paths[1], "wb") as file_obj: 113 | torch.save([], file_obj) 114 | with fs.open(self.processed_paths[2], "wb") as file_obj: 115 | torch.save(self.collate(test_list), file_obj) 116 | -------------------------------------------------------------------------------- /util/datasets/publaynet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from fsspec.core import url_to_fs 4 | from torch_geometric.data import Data 5 | from tqdm import tqdm 6 | from .base import BaseDataset 7 | from typing import List, Tuple, Union 8 | import numpy as np 9 | from torch import Tensor 10 | from collections.abc import Mapping, Sequence 11 | IndexType = Union[slice, Tensor, np.ndarray, Sequence] 12 | from torch import BoolTensor, FloatTensor, LongTensor 13 | from torch_geometric.utils import to_dense_batch 14 | 15 | def sparse_to_dense( 16 | batch, 17 | device: torch.device = torch.device("cpu"), 18 | remove_canvas: bool = False, 19 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]: 20 | batch = batch.to(device) 21 | bbox, _ = to_dense_batch(batch.x, batch.batch) 22 | label, mask = to_dense_batch(batch.y, batch.batch) 23 | 24 | if remove_canvas: 25 | bbox = bbox[:, 1:].contiguous() 26 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform 27 | label = label.clamp(min=0) 28 | mask = mask[:, 1:].contiguous() 29 | 30 | padding_mask = ~mask 31 | return bbox, label, padding_mask, mask 32 | 33 | 34 | class PubLayNetDataset(BaseDataset): 35 | name = "publaynet" 36 | labels = [ 37 | "text", 38 | "title", 39 | "list", 40 | "table", 41 | "figure", 42 | ] 43 | 44 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 45 | super().__init__(dir, split, max_seq_length, transform) 46 | 47 | def download(self): 48 | # super().download() 49 | pass 50 | 51 | def process(self): 52 | from pycocotools.coco import COCO 53 | 54 | fs, _ = url_to_fs(self.raw_dir) 55 | 56 | # if self.raw_dir.startswith("gs://"): 57 | # raise NotImplementedError 58 | 59 | for split_publaynet in ["train", "val"]: 60 | data_list = [] 61 | coco = COCO( 62 | os.path.join(self.raw_dir, "publaynet", f"{split_publaynet}.json") 63 | ) 64 | for img_id in tqdm(sorted(coco.getImgIds())): 65 | ann_img = coco.loadImgs(img_id) 66 | W = float(ann_img[0]["width"]) 67 | H = float(ann_img[0]["height"]) 68 | name = ann_img[0]["file_name"] 69 | if H < W: 70 | continue 71 | 72 | def is_valid(element): 73 | x1, y1, width, height = element["bbox"] 74 | x2, y2 = x1 + width, y1 + height 75 | if x1 < 0 or y1 < 0 or W < x2 or H < y2: 76 | return False 77 | 78 | if x2 <= x1 or y2 <= y1: 79 | return False 80 | 81 | return True 82 | 83 | elements = coco.loadAnns(coco.getAnnIds(imgIds=[img_id])) 84 | _elements = list(filter(is_valid, elements)) 85 | filtered = len(elements) != len(_elements) 86 | elements = _elements 87 | 88 | N = len(elements) 89 | if N == 0 or self.max_seq_length < N: 90 | continue 91 | 92 | boxes = [] 93 | labels = [] 94 | 95 | for element in elements: 96 | # bbox 97 | x1, y1, width, height = element["bbox"] 98 | xc = x1 + width / 2.0 99 | yc = y1 + height / 2.0 100 | b = [xc / W, yc / H, width / W, height / H] 101 | boxes.append(b) 102 | 103 | # label 104 | l = coco.cats[element["category_id"]]["name"] 105 | 
labels.append(self.label2index[l]) 106 | 107 | boxes = torch.tensor(boxes, dtype=torch.float) 108 | labels = torch.tensor(labels, dtype=torch.long) 109 | 110 | data = Data(x=boxes, y=labels) 111 | data.attr = { 112 | "name": name, 113 | "width": W, 114 | "height": H, 115 | "filtered": filtered, 116 | "has_canvas_element": False, 117 | "NoiseAdded": False, 118 | } 119 | data_list.append(data) 120 | 121 | if split_publaynet == "train": 122 | train_list = data_list 123 | else: 124 | val_list = data_list 125 | 126 | # shuffle train with seed 127 | generator = torch.Generator().manual_seed(0) 128 | indices = torch.randperm(len(train_list), generator=generator) 129 | train_list = [train_list[i] for i in indices] 130 | 131 | # train_list -> train 95% / val 5% 132 | # val_list -> test 100% 133 | s = int(len(train_list) * 0.95) 134 | with fs.open(self.processed_paths[0], "wb") as file_obj: 135 | torch.save(self.collate(train_list[:s]), file_obj) 136 | with fs.open(self.processed_paths[1], "wb") as file_obj: 137 | torch.save(self.collate(train_list[s:]), file_obj) 138 | with fs.open(self.processed_paths[2], "wb") as file_obj: 139 | torch.save(self.collate(val_list), file_obj) 140 | 141 | 142 | if __name__ == "__main__": 143 | 144 | dataset = PubLayNetDataset(dir='./download/datasets', max_seq_length=25, split="test", transform=None) 145 | 146 | -------------------------------------------------------------------------------- /util/datasets/rico.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from pathlib import Path 5 | from zipfile import ZipFile 6 | 7 | import numpy as np 8 | import torch 9 | from fsspec.core import url_to_fs 10 | from PIL import Image, ImageDraw 11 | from torch_geometric.data import Data 12 | from tqdm import tqdm 13 | from util.seq_util import sparse_to_dense 14 | from util.constraint import xywh_to_ltrb_split 15 | from .base import BaseDataset 16 | 17 | _rico5_labels = [ 18 | "Text", 19 | "Text Button", 20 | "Toolbar", 21 | "Image", 22 | "Icon", 23 | ] 24 | 25 | _rico13_labels = [ 26 | "Toolbar", 27 | "Image", 28 | "Text", 29 | "Icon", 30 | "Text Button", 31 | "Input", 32 | "List Item", 33 | "Advertisement", 34 | "Pager Indicator", 35 | "Web View", 36 | "Background Image", 37 | "Drawer", 38 | "Modal", 39 | ] 40 | 41 | _rico25_labels = [ 42 | "Text", 43 | "Image", 44 | "Icon", 45 | "Text Button", 46 | "List Item", 47 | "Input", 48 | "Background Image", 49 | "Card", 50 | "Web View", 51 | "Radio Button", 52 | "Drawer", 53 | "Checkbox", 54 | "Advertisement", 55 | "Modal", 56 | "Pager Indicator", 57 | "Slider", 58 | "On/Off Switch", 59 | "Button Bar", 60 | "Toolbar", 61 | "Number Stepper", 62 | "Multi-Tab", 63 | "Date Picker", 64 | "Map View", 65 | "Video", 66 | "Bottom Navigation", 67 | ] 68 | 69 | 70 | def append_child(element, elements): 71 | if "children" in element.keys(): 72 | for child in element["children"]: 73 | elements.append(child) 74 | elements = append_child(child, elements) 75 | return elements 76 | 77 | 78 | class _RicoDataset(BaseDataset): 79 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 80 | super().__init__(dir, split, max_seq_length, transform) 81 | 82 | def process(self): 83 | data_list = [] 84 | raw_file = os.path.join( 85 | self.raw_dir, "rico_dataset_v0.1_semantic_annotations.zip" 86 | ) 87 | 88 | fs, _ = url_to_fs(self.raw_dir) 89 | with fs.open(raw_file, "rb") as f, ZipFile(f) as z: 90 | names = sorted([n for n in z.namelist() if 
n.endswith(".json")]) 91 | for name in tqdm(names, total=len(names), ncols=150, desc='prepare Rico'): 92 | ann = json.loads(z.open(name).read()) 93 | 94 | B = ann["bounds"] 95 | W, H = float(B[2]), float(B[3]) 96 | 97 | if B[0] != 0 or B[1] != 0 or H < W: 98 | continue 99 | 100 | def is_valid(element): 101 | if element["componentLabel"] not in set(self.labels): 102 | return False 103 | 104 | x1, y1, x2, y2 = element["bounds"] 105 | if x1 < 0 or y1 < 0 or W < x2 or H < y2: 106 | return False 107 | 108 | if x2 <= x1 or y2 <= y1: 109 | return False 110 | 111 | return True 112 | 113 | elements = append_child(ann, []) 114 | _elements = list(filter(is_valid, elements)) 115 | filtered = len(elements) != len(_elements) 116 | elements = _elements 117 | N = len(elements) 118 | if N == 0 or self.max_seq_length < N: 119 | continue 120 | 121 | # only for debugging slice-based preprocessing 122 | # elements = append_child(ann, []) 123 | # filtered = False 124 | # if len(elements) == 0: 125 | # continue 126 | # elements = elements[:self.max_seq_length] 127 | 128 | boxes = [] 129 | labels = [] 130 | 131 | for element in elements: 132 | # bbox 133 | x1, y1, x2, y2 = element["bounds"] 134 | xc = (x1 + x2) / 2.0 135 | yc = (y1 + y2) / 2.0 136 | width = x2 - x1 137 | height = y2 - y1 138 | b = [xc / W, yc / H, width / W, height / H] 139 | boxes.append(b) 140 | 141 | # label 142 | l = element["componentLabel"] 143 | labels.append(self.label2index[l]) 144 | 145 | boxes = torch.tensor(boxes, dtype=torch.float) 146 | labels = torch.tensor(labels, dtype=torch.long) 147 | 148 | data = Data(x=boxes, y=labels) 149 | data.attr = { 150 | "name": name, 151 | "width": W, 152 | "height": H, 153 | "filtered": filtered, 154 | "has_canvas_element": False, 155 | "NoiseAdded": False, 156 | } 157 | data_list.append(data) 158 | 159 | # shuffle with seed 160 | generator = torch.Generator().manual_seed(0) 161 | indices = torch.randperm(len(data_list), generator=generator) 162 | data_list = [data_list[i] for i in indices] 163 | 164 | # train 85% / val 5% / test 10% 165 | N = len(data_list) 166 | s = [int(N * 0.85), int(N * 0.90)] 167 | 168 | with fs.open(self.processed_paths[0], "wb") as file_obj: 169 | print('duplicate training split (x10) for more batches per epoch') 170 | torch.save(self.collate(data_list[: s[0]] * 10), file_obj) 171 | with fs.open(self.processed_paths[1], "wb") as file_obj: 172 | torch.save(self.collate(data_list[s[0]:s[1]]), file_obj) 173 | with fs.open(self.processed_paths[2], "wb") as file_obj: 174 | torch.save(self.collate(data_list[s[1]:]), file_obj) 175 | 176 | def download(self): 177 | pass 178 | 179 | def get_original_resource(self, batch) -> Image: 180 | assert not self.raw_dir.startswith("gs://") 181 | bbox, _, _, _ = sparse_to_dense(batch) 182 | 183 | img_bg, img_original, cropped_patches = [], [], [] 184 | names = batch.attr["name"] 185 | if isinstance(names, str): 186 | names = [names] 187 | 188 | for i, name in enumerate(names): 189 | name = Path(name).name.replace(".json", ".jpg") 190 | img = Image.open(Path(self.raw_dir) / "combined" / name) 191 | img_original.append(copy.deepcopy(img)) 192 | 193 | W, H = img.size 194 | ltrb = xywh_to_ltrb_split(bbox[i].T.numpy()) 195 | left, right = (ltrb[0] * W).astype(np.uint32), (ltrb[2] * W).astype( 196 | np.uint32 197 | ) 198 | top, bottom = (ltrb[1] * H).astype(np.uint32), (ltrb[3] * H).astype( 199 | np.uint32 200 | ) 201 | draw = ImageDraw.Draw(img) 202 | patches = [] 203 | for (l, r, t, b) in zip(left, right, top, bottom): 204 | 
patches.append(img.crop((l, t, r, b))) 205 | # draw.rectangle([(l, t), (r, b)], fill=(255, 0, 0)) 206 | draw.rectangle([(l, t), (r, b)], fill=(255, 255, 255)) 207 | img_bg.append(img) 208 | cropped_patches.append(patches) 209 | # if len(patches) < S: 210 | # for i in range(S - len(patches)): 211 | # patches.append(Image.new("RGB", (0, 0))) 212 | 213 | return { 214 | "img_bg": img_bg, 215 | "img_original": img_original, 216 | "cropped_patches": cropped_patches, 217 | } 218 | 219 | # read from uncompressed data (the last line takes infinite time, so not used now..) 220 | # raw_file = os.path.join(self.raw_dir, "unique_uis.tar.gz") 221 | # with tarfile.open(raw_file) as f: 222 | # # return gzip.GzipFile(fileobj=f.extractfile(f"combined/{name}")).read() 223 | # return gzip.GzipFile(fileobj=f.extractfile(f"combined/hoge")).read() 224 | 225 | 226 | class Rico5Dataset(_RicoDataset): 227 | name = "rico5" 228 | labels = _rico5_labels 229 | 230 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 231 | super().__init__(dir, split, max_seq_length, transform) 232 | 233 | 234 | # Constrained Graphic Layout Generation via Latent Optimization (ACMMM2021) 235 | class Rico13Dataset(_RicoDataset): 236 | name = "rico13" 237 | labels = _rico13_labels 238 | 239 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 240 | super().__init__(dir, split, max_seq_length, transform) 241 | 242 | 243 | class Rico25Dataset(_RicoDataset): 244 | name = "rico25" 245 | labels = _rico25_labels 246 | 247 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 248 | super().__init__(dir, split, max_seq_length, transform) 249 | -------------------------------------------------------------------------------- /util/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import einops 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch.nn.functional as F 8 | torch.manual_seed(3) 9 | import sys 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | import matplotlib.pyplot as plt 15 | 16 | def softmax_prefix(input_tensor, n): 17 | 18 | # Extract the first 6 entries on the second dimension 19 | sub_tensor = input_tensor[:, :, :n] 20 | 21 | # Apply softmax along the second dimension 22 | softmax_sub_tensor = F.softmax(sub_tensor, dim=2) 23 | 24 | # Replace the first 6 entries with the softmax results 25 | output_tensor = input_tensor.clone() 26 | output_tensor[:, :, :n] = softmax_sub_tensor 27 | 28 | return output_tensor 29 | 30 | 31 | def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2): 32 | if schedule == "linear": 33 | betas = torch.linspace(start, end, num_timesteps) 34 | elif schedule == "const": 35 | betas = end * torch.ones(num_timesteps) 36 | elif schedule == "quad": 37 | betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2 38 | elif schedule == "jsd": 39 | betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps) 40 | elif schedule == "sigmoid": 41 | betas = torch.linspace(-6, 6, num_timesteps) 42 | betas = torch.sigmoid(betas) * (end - start) + start 43 | elif schedule == "cosine" or schedule == "cosine_reverse": 44 | max_beta = 0.999 45 | cosine_s = 0.008 46 | betas = torch.tensor( 47 | [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / ( 48 | math.cos((i / num_timesteps + cosine_s) / (1 + 
cosine_s) * math.pi / 2) ** 2), max_beta) for i in 49 | range(num_timesteps)]) 50 | elif schedule == "cosine_anneal": 51 | betas = torch.tensor( 52 | [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in 53 | range(num_timesteps)]) 54 | return betas 55 | 56 | 57 | def extract(input, t, x): 58 | shape = x.shape 59 | out = torch.gather(input, 0, t.to(input.device)) 60 | reshape = [t.shape[0]] + [1] * (len(shape) - 1) 61 | return out.reshape(*reshape) 62 | 63 | 64 | # Forward functions 65 | def q_sample(y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise=None): 66 | 67 | if noise is None: 68 | noise = torch.randn_like(y).to(y.device) 69 | sqrt_alpha_bar_t = extract(alphas_bar_sqrt, t, y) 70 | 71 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y) 72 | # q(y_t | y_0, x) 73 | y_t = sqrt_alpha_bar_t * y + sqrt_one_minus_alpha_bar_t * noise 74 | 75 | return y_t 76 | 77 | 78 | # Reverse function -- sample y_{t-1} given y_t 79 | def p_sample(model, y_t, t, alphas, one_minus_alphas_bar_sqrt, stochastic=True): 80 | """ 81 | Reverse diffusion process sampling -- one time step. 82 | y: sampled y at time step t, y_t. 83 | """ 84 | device = next(model.parameters()).device 85 | z = stochastic * torch.randn_like(y_t) 86 | t = torch.tensor([t]).to(device) 87 | alpha_t = extract(alphas, t, y_t) 88 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y_t) 89 | sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y_t) 90 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt() 91 | sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt() 92 | # y_t_m_1 posterior mean component coefficients 93 | gamma_0 = (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square()) 94 | gamma_1 = (sqrt_one_minus_alpha_bar_t_m_1.square()) * (alpha_t.sqrt()) / (sqrt_one_minus_alpha_bar_t.square()) 95 | 96 | eps_theta = model(y_t, timestep=t).to(device).detach() 97 | 98 | # y_0 reparameterization 99 | y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - eps_theta * sqrt_one_minus_alpha_bar_t).to(device) 100 | 101 | # posterior mean 102 | y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y_t 103 | 104 | # posterior variance 105 | beta_t_hat = (sqrt_one_minus_alpha_bar_t_m_1.square()) / (sqrt_one_minus_alpha_bar_t.square()) * (1 - alpha_t) 106 | y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device) 107 | 108 | 109 | return y_t_m_1 110 | 111 | 112 | # Reverse function -- sample y_0 given y_1 113 | def p_sample_t_1to0(model, y_t, one_minus_alphas_bar_sqrt): 114 | device = next(model.parameters()).device 115 | t = torch.tensor([0]).to(device) # corresponding to timestep 1 (i.e., t=1 in diffusion models) 116 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y_t) 117 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt() 118 | eps_theta = model(y_t, timestep=t).to(device).detach() 119 | 120 | # y_0 reparameterization 121 | y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - eps_theta * sqrt_one_minus_alpha_bar_t).to(device) 122 | 123 | y_t_m_1 = y_0_reparam.to(device) 124 | 125 | return y_t_m_1 126 | 127 | 128 | def p_sample_loop(model, batch_size, n_steps, alphas, one_minus_alphas_bar_sqrt, 129 | only_last_sample=True, stochastic=True): 130 | num_t, l_p_seq = None, None 131 | 132 | device = next(model.parameters()).device 133 | 134 | l_t = stochastic * torch.randn_like(torch.zeros([batch_size, 25, 10])).to(device) 135 | if only_last_sample: 136 | num_t = 
1 137 | else: 138 | # y_p_seq = [y_t] 139 | l_p_seq = torch.zeros([batch_size, 25, 10, n_steps+1]).to(device) 140 | l_p_seq[:, :, :, n_steps] = l_t 141 | 142 | for t in reversed(range(1, n_steps-1)): 143 | 144 | l_t = p_sample(model, l_t, t, alphas, one_minus_alphas_bar_sqrt, stochastic=stochastic) # y_{t-1} 145 | 146 | if only_last_sample: 147 | num_t += 1 148 | else: 149 | # y_p_seq.append(y_t) 150 | l_p_seq[:, :, :, t] = l_t 151 | 152 | 153 | if only_last_sample: 154 | l_0 = p_sample_t_1to0(model, l_t, one_minus_alphas_bar_sqrt) 155 | return l_0 156 | else: 157 | # assert len(y_p_seq) == n_steps 158 | l_0 = p_sample_t_1to0(model, l_p_seq[:, :, :, 1], one_minus_alphas_bar_sqrt) 159 | # y_p_seq.append(y_0) 160 | l_p_seq[:, :, :, 0] = l_0 161 | 162 | return l_0, l_p_seq 163 | 164 | 165 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps): 166 | if ddim_discr_method == 'uniform': 167 | c = num_ddpm_timesteps // num_ddim_timesteps 168 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 169 | elif ddim_discr_method == 'quad': 170 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 171 | elif ddim_discr_method == 'new': 172 | c = (num_ddpm_timesteps - 50) // (num_ddim_timesteps - 50) 173 | ddim_timesteps = np.asarray(list(range(0, 50)) + list(range(50, num_ddpm_timesteps - 50, c))) 174 | else: 175 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 176 | 177 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 178 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 179 | steps_out = ddim_timesteps + 1 180 | # print(f'Selected timesteps for ddim sampler: {steps_out}') 181 | 182 | return steps_out 183 | 184 | 185 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta): 186 | # select alphas for computing the variance schedule 187 | device = alphacums.device 188 | alphas = alphacums[ddim_timesteps] 189 | alphas_prev = torch.tensor([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()).to(device) 190 | 191 | sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 192 | # print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 193 | # print(f'For the chosen value of eta, which is {eta}, ' 194 | # f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 195 | return sigmas, alphas, alphas_prev 196 | 197 | 198 | def ddim_sample_loop(model, batch_size, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas, stochastic=True, 199 | seq_len=25, seq_dim=10): 200 | device = next(model.parameters()).device 201 | 202 | b_t = 1 * stochastic * torch.randn_like(torch.zeros([batch_size, seq_len, seq_dim])).to(device) 203 | 204 | intermediates = {'y_inter': [b_t], 'pred_y0': [b_t]} 205 | time_range = np.flip(timesteps) 206 | total_steps = timesteps.shape[0] 207 | # print(f"Running DDIM Sampling with {total_steps} timesteps") 208 | 209 | for i, step in enumerate(time_range): 210 | index = total_steps - i - 1 211 | t = torch.full((batch_size,), step, device=device, dtype=torch.long) 212 | 213 | b_t, pred_y0 = ddim_sample_step(model, b_t, t, index, ddim_alphas, 214 | ddim_alphas_prev, ddim_sigmas) 215 | 216 | intermediates['y_inter'].append(b_t) 217 | intermediates['pred_y0'].append(pred_y0) 218 | 219 | return b_t, intermediates 220 | 221 | 222 | def rand_fix(batch_size, mask, ratio=0.2, n_elements=25, stochastic=True): 
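# Descriptive note on rand_fix (grounded in the function body and its use in ddim_cond_sample_loop):
# it builds a [batch_size, n_elements] boolean mask marking which layout slots stay frozen during
# conditional "completion" sampling. With stochastic=True, a keep-probability is drawn once
# (uniform in [0, ratio]) and each valid element (mask == True) is kept with that probability;
# with stochastic=False, a fixed pattern of three slots is used instead. ddim_cond_sample_loop
# applies this mask for cond='complete', copying the frozen entries from the real layout at
# every DDIM step.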
223 | 224 | if stochastic: 225 | indices = (torch.rand([batch_size, n_elements]) <= torch.rand([1]).item() * ratio).to(mask.device) * mask.to(torch.bool) 226 | else: 227 | a = torch.tensor([False, False, True, False, False, True, True, False, False, False, 228 | False, False, False, False, False, False, False, False, False, False, 229 | False, False, False, False, False]) 230 | indices = einops.repeat(a, "l -> n l", n=batch_size) * mask.to(torch.bool) 231 | 232 | return indices 233 | 234 | 235 | def ddim_cond_sample_loop(model, real_layout, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas, stochastic=True, cond='c', ratio=0.2): 236 | 237 | device = next(model.parameters()).device 238 | batch_size, seq_len, seq_dim = real_layout.shape 239 | num_class = seq_dim - 4 240 | 241 | 242 | real_label = torch.argmax(real_layout[:, :, :num_class], dim=2) 243 | real_mask = (real_label != num_class-1).clone().detach() 244 | 245 | # cond mask 246 | if cond == 'complete': 247 | fix_mask = rand_fix(batch_size, real_mask, ratio=ratio, stochastic=True) 248 | elif cond == 'c': 249 | fix_mask = torch.zeros([batch_size, seq_len, seq_dim]).to(torch.bool) 250 | fix_mask[:, :, :num_class] = True 251 | elif cond == 'cwh': 252 | fix_mask = torch.zeros([batch_size, seq_len, seq_dim]).to(torch.bool) 253 | fix_ind = [x for x in range(num_class)] + [num_class + 2, num_class + 3] 254 | fix_mask[:, :, fix_ind] = True 255 | else: 256 | raise Exception('cond must be c, cwh, or complete') 257 | 258 | l_t = 1 * stochastic * torch.randn_like(torch.zeros([batch_size, seq_len, seq_dim])).to(device) 259 | 260 | intermediates = {'y_inter': [l_t], 'pred_y0': [l_t]} 261 | time_range = np.flip(timesteps) 262 | total_steps = timesteps.shape[0] 263 | # noise = 1 * torch.randn_like(real_layout).to(device) 264 | 265 | for i, step in enumerate(time_range): 266 | index = total_steps - i - 1 267 | t = torch.full((batch_size,), step, device=device, dtype=torch.long) 268 | 269 | l_t[fix_mask] = real_layout[fix_mask] 270 | 271 | # # plot inter 272 | # print('hi here') 273 | # l_t_label = torch.argmax(l_t[:, :, :6], dim=2) 274 | # l_t_mask = (l_t_label != 5).clone().detach() 275 | # l_bbox = torch.clamp(l_t[:1, :, 6:], min=-1, max=1) / 2 + 0.5 276 | # img = save_image(l_bbox[:4], l_t_label[:4], l_t_mask[:4], draw_label=False) 277 | # plt.figure(figsize=[12, 12]) 278 | # plt.imshow(img) 279 | # plt.tight_layout() 280 | # plt.savefig(f'./plot/test/conditional_test_t_{t[0]}.png') 281 | # plt.close() 282 | 283 | l_t, pred_y0 = ddim_sample_step(model, l_t, t, index, ddim_alphas, 284 | ddim_alphas_prev, ddim_sigmas) 285 | 286 | l_t[fix_mask] = real_layout[fix_mask] 287 | 288 | intermediates['y_inter'].append(l_t) 289 | intermediates['pred_y0'].append(pred_y0) 290 | 291 | return l_t, intermediates 292 | 293 | 294 | def ddim_refine_sample_loop(model, noisy_layout, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas): 295 | 296 | device = next(model.parameters()).device 297 | batch_size, seq_len, seq_dim = noisy_layout.shape 298 | l_t = noisy_layout 299 | 300 | intermediates = {'y_inter': [l_t], 'pred_y0': [l_t]} 301 | total_steps = sum(timesteps <= 201) 302 | time_range = np.flip(timesteps[:total_steps]) 303 | 304 | # noise = 1 * torch.randn_like(real_layout).to(device) 305 | for i, step in enumerate(time_range): 306 | index = total_steps - i - 1 307 | t = torch.full((batch_size,), step, device=device, dtype=torch.long) 308 | 309 | l_t, pred_y0 = ddim_sample_step(model, l_t, t, index, ddim_alphas, 310 | ddim_alphas_prev, ddim_sigmas) 311 | 312 
| intermediates['y_inter'].append(l_t) 313 | intermediates['pred_y0'].append(pred_y0) 314 | 315 | return l_t, intermediates 316 | 317 | def ddim_sample_step(model, l_t, t, index, ddim_alphas, ddim_alphas_prev, ddim_sigmas): 318 | 319 | device = next(model.parameters()).device 320 | e_t = model(l_t, timestep=t).to(device).detach() 321 | 322 | sqrt_one_minus_alphas = torch.sqrt(1. - ddim_alphas) 323 | # select parameters corresponding to the currently considered timestep 324 | a_t = torch.full(e_t.shape, ddim_alphas[index], device=device) 325 | a_t_m_1 = torch.full(e_t.shape, ddim_alphas_prev[index], device=device) 326 | sigma_t = torch.full(e_t.shape, ddim_sigmas[index], device=device) 327 | sqrt_one_minus_at = torch.full(e_t.shape, sqrt_one_minus_alphas[index], device=device) 328 | 329 | # direction pointing to x_t 330 | dir_b_t = (1. - a_t_m_1 - sigma_t ** 2).sqrt() * e_t 331 | noise = sigma_t * torch.randn_like(l_t).to(device) 332 | 333 | # reparameterize x_0 334 | b_0_reparam = (l_t - sqrt_one_minus_at * e_t) / a_t.sqrt() 335 | 336 | # compute b_t_m_1 337 | b_t_m_1 = a_t_m_1.sqrt() * b_0_reparam + 1 * dir_b_t + noise 338 | 339 | return b_t_m_1, b_0_reparam 340 | 341 | 342 | # def compute_piou_grad(layout_in): 343 | # 344 | # layout = layout_in.clone().detach() 345 | # 346 | # # convert centralized layout bbox representation [-1, 1] to wh representation [0, 1] 347 | # layout[:, :, 6:] = torch.clamp(layout[:, :, 6:], min=-1, max=1) / 2 + 0.5 348 | # # print('layout generated:', layout_t_0[0, :10, :]) 349 | # bbox = layout[:, :, 6:] 350 | # label = torch.argmax(layout[:, :, :6], dim=2) 351 | # mask = (label != 5).clone().detach() 352 | # 353 | # # compute grad iou 354 | # piou_gradient, mpiou = guidance_grad(bbox, mask=mask.to(torch.float), xy_only=True) 355 | # 356 | # piou_gradient = 2 * (piou_gradient - 0.5) 357 | # 358 | # return piou_gradient, mpiou 359 | 360 | 361 | 362 | if __name__ == "__main__": 363 | 364 | # t = make_ddim_timesteps('new', 200, 1000) 365 | # print(t) 366 | 367 | m = rand_fix(batch_size=2, mask=torch.tensor([1 for _ in range(25)]), n_elements=25) 368 | print(m) 369 | -------------------------------------------------------------------------------- /util/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class EMA(object): 4 | def __init__(self, mu=0.999): 5 | self.mu = mu 6 | self.shadow = {} 7 | 8 | def register(self, module): 9 | for name, param in module.named_parameters(): 10 | if param.requires_grad: 11 | self.shadow[name] = param.data.clone() 12 | 13 | def update(self, module): 14 | for name, param in module.named_parameters(): 15 | if param.requires_grad: 16 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 17 | 18 | def ema(self, module): 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | param.data.copy_(self.shadow[name].data) 22 | 23 | def ema_copy(self, module): 24 | module_copy = type(module)(module.config).to(module.config.device) 25 | module_copy.load_state_dict(module.state_dict()) 26 | self.ema(module_copy) 27 | return module_copy 28 | 29 | def state_dict(self): 30 | return self.shadow 31 | 32 | def load_state_dict(self, state_dict): 33 | self.shadow = state_dict 34 | -------------------------------------------------------------------------------- /util/metric.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from functools import partial 3 | from itertools import chain 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributions as tdist 9 | from einops import rearrange, reduce, repeat 10 | from prdc import compute_prdc 11 | from pytorch_fid.fid_score import calculate_frechet_distance 12 | from scipy.optimize import linear_sum_assignment 13 | from scipy.stats import wasserstein_distance 14 | from torch import BoolTensor, FloatTensor 15 | from torch_geometric.utils import to_dense_adj 16 | from .data_util import RelLoc, RelSize, detect_loc_relation, detect_size_relation 17 | from .constraint import xywh_to_ltrb_split 18 | 19 | Feats = Union[FloatTensor, List[FloatTensor]] 20 | Layout = Tuple[np.ndarray, np.ndarray] 21 | 22 | # set True to disable parallel computing by multiprocessing (typically for debug) 23 | # DISABLED = False 24 | DISABLED = True 25 | 26 | 27 | def __to_numpy_array(feats: Feats) -> np.ndarray: 28 | if isinstance(feats, list): 29 | # flatten list of batch-processed features 30 | if isinstance(feats[0], FloatTensor): 31 | feats = [x.detach().cpu().numpy() for x in feats] 32 | else: 33 | feats = feats.detach().cpu().numpy() 34 | return np.concatenate(feats) 35 | 36 | 37 | def compute_generative_model_scores( 38 | feats_real: Feats, 39 | feats_fake: Feats, 40 | ) -> Dict[str, float]: 41 | """ 42 | Compute precision, recall, density, coverage, and FID. 43 | """ 44 | feats_real = __to_numpy_array(feats_real) 45 | feats_fake = __to_numpy_array(feats_fake) 46 | 47 | mu_real = np.mean(feats_real, axis=0) 48 | sigma_real = np.cov(feats_real, rowvar=False) 49 | mu_fake = np.mean(feats_fake, axis=0) 50 | sigma_fake = np.cov(feats_fake, rowvar=False) 51 | 52 | results = compute_prdc( 53 | real_features=feats_real, fake_features=feats_fake, nearest_k=5 54 | ) 55 | results["fid"] = calculate_frechet_distance( 56 | mu_real, sigma_real, mu_fake, sigma_fake 57 | ) 58 | 59 | return results 60 | 61 | 62 | def compute_violation(bbox_flatten, data): 63 | """ 64 | Compute relation violation accuracy as in LayoutGAN++ [Kikuchi+, ACMMM'21]. 
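    For each annotated edge, the predicted boxes are checked against the specified size and
    location relations (detect_size_relation / detect_loc_relation); the returned value is,
    per layout in the batch, the fraction of those relations that are violated (lower is better).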
65 | """ 66 | device = data.x.device 67 | failures, valid = [], [] 68 | 69 | _zip = zip(data.edge_attr, data.edge_index.t()) 70 | for gt, (i, j) in _zip: 71 | failure, _valid = 0, 0 72 | b1, b2 = bbox_flatten[i], bbox_flatten[j] 73 | 74 | # size relation 75 | if ~gt & 1 << RelSize.UNKNOWN: 76 | pred = detect_size_relation(b1, b2) 77 | failure += (gt & 1 << pred).eq(0).long() 78 | _valid += 1 79 | 80 | # loc relation 81 | if ~gt & 1 << RelLoc.UNKNOWN: 82 | canvas = data.y[i].eq(0) 83 | pred = detect_loc_relation(b1, b2, canvas) 84 | failure += (gt & 1 << pred).eq(0).long() 85 | _valid += 1 86 | 87 | failures.append(failure) 88 | valid.append(_valid) 89 | 90 | failures = torch.as_tensor(failures).to(device) 91 | failures = to_dense_adj(data.edge_index, data.batch, failures) 92 | valid = torch.as_tensor(valid).to(device) 93 | valid = to_dense_adj(data.edge_index, data.batch, valid) 94 | 95 | return failures.sum((1, 2)) / valid.sum((1, 2)) 96 | 97 | 98 | def compute_alignment(bbox: FloatTensor, mask: BoolTensor) -> FloatTensor: 99 | """ 100 | Computes some alignment metrics that are different to each other in previous works. 101 | Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020) 102 | https://arxiv.org/abs/2009.05284 103 | """ 104 | S = bbox.size(1) 105 | 106 | bbox = bbox.permute(2, 0, 1) 107 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox) 108 | xc, yc = bbox[0], bbox[1] 109 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1) 110 | X = X.unsqueeze(-1) - X.unsqueeze(-2) 111 | idx = torch.arange(X.size(2), device=X.device) 112 | X[:, :, idx, idx] = 1.0 113 | X = X.abs().permute(0, 2, 1, 3) 114 | X[~mask] = 1.0 115 | X = X.min(-1).values.min(-1).values 116 | X.masked_fill_(X.eq(1.0), 0.0) 117 | X = -torch.log(1 - X) 118 | 119 | # original 120 | # return X.sum(-1) / mask.float().sum(-1) 121 | 122 | score = reduce(X, "b s -> b", reduction="sum") 123 | score_normalized = score / reduce(mask, "b s -> b", reduction="sum") 124 | score_normalized[torch.isnan(score_normalized)] = 0.0 125 | 126 | Y = torch.stack([xl, xc, xr], dim=1) 127 | Y = rearrange(Y, "b x s -> b x 1 s") - rearrange(Y, "b x s -> b x s 1") 128 | 129 | batch_mask = rearrange(~mask, "b s -> b 1 s") | rearrange(~mask, "b s -> b s 1") 130 | idx = torch.arange(S, device=Y.device) 131 | batch_mask[:, idx, idx] = True 132 | batch_mask = repeat(batch_mask, "b s1 s2 -> b x s1 s2", x=3) 133 | Y[batch_mask] = 1.0 134 | 135 | # Y = rearrange(Y.abs(), "b x s1 s2 -> b s1 x s2") 136 | # Y = reduce(Y, "b x s1 s2 -> b x", "min") 137 | # Y = rearrange(Y.abs(), " -> b s1 x s2") 138 | Y = reduce(Y.abs(), "b x s1 s2 -> b s1", "min") 139 | Y[Y == 1.0] = 0.0 140 | score_Y = reduce(Y, "b s -> b", "sum") 141 | 142 | results = { 143 | "alignment-ACLayoutGAN": score, 144 | "alignment-LayoutGAN++": score_normalized, 145 | "alignment-NDN": score_Y, 146 | } 147 | return results["alignment-LayoutGAN++"] 148 | 149 | 150 | def compute_overlap(bbox: FloatTensor, mask: BoolTensor): 151 | """ 152 | Based on 153 | (i) Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020) 154 | https://arxiv.org/abs/2009.05284 155 | (ii) LAYOUTGAN: GENERATING GRAPHIC LAYOUTS WITH WIREFRAME DISCRIMINATORS (ICLR2019) 156 | https://arxiv.org/abs/1901.06767 157 | "percentage of total overlapping area among any two bounding boxes inside the whole page." 158 | At least BLT authors seems to sum. 
(in the MSCOCO case, it surpasses 1.0) 159 | """ 160 | B, S = mask.size() 161 | bbox = bbox.masked_fill(~mask.unsqueeze(-1), 0) 162 | bbox = bbox.permute(2, 0, 1) 163 | 164 | l1, t1, r1, b1 = xywh_to_ltrb_split(bbox.unsqueeze(-1)) 165 | l2, t2, r2, b2 = xywh_to_ltrb_split(bbox.unsqueeze(-2)) 166 | a1 = (r1 - l1) * (b1 - t1) 167 | 168 | # intersection 169 | l_max = torch.maximum(l1, l2) 170 | r_min = torch.minimum(r1, r2) 171 | t_max = torch.maximum(t1, t2) 172 | b_min = torch.minimum(b1, b2) 173 | cond = (l_max < r_min) & (t_max < b_min) 174 | ai = torch.where(cond, (r_min - l_max) * (b_min - t_max), torch.zeros_like(a1[0])) 175 | 176 | # diag_mask = torch.eye(a1.size(1), dtype=torch.bool, device=a1.device) 177 | # ai = ai.masked_fill(diag_mask, 0) 178 | batch_mask = rearrange(~mask, "b s -> b 1 s") | rearrange(~mask, "b s -> b s 1") 179 | idx = torch.arange(S, device=ai.device) 180 | batch_mask[:, idx, idx] = True 181 | ai = ai.masked_fill(batch_mask, 0) 182 | 183 | ar = torch.nan_to_num(ai / a1) # (B, S, S) 184 | 185 | # original 186 | # return ar.sum(dim=(1, 2)) / mask.float().sum(-1) 187 | 188 | # fixed to avoid the case with single bbox 189 | score = reduce(ar, "b s1 s2 -> b", reduction="sum") 190 | score_normalized = score / reduce(mask, "b s -> b", reduction="sum") 191 | score_normalized[torch.isnan(score_normalized)] = 0.0 192 | 193 | ids = torch.arange(S) 194 | ii, jj = torch.meshgrid(ids, ids, indexing="ij") 195 | ai[repeat(ii >= jj, "s1 s2 -> b s1 s2", b=B)] = 0.0 196 | overlap = reduce(ai, "b s1 s2 -> b", reduction="sum") 197 | 198 | results = { 199 | "overlap-ACLayoutGAN": score, 200 | "overlap-LayoutGAN++": score_normalized, 201 | "overlap-LayoutGAN": overlap, 202 | } 203 | return results["overlap-LayoutGAN++"] 204 | 205 | 206 | def compute_iou( 207 | box_1: Union[np.ndarray, FloatTensor], 208 | box_2: Union[np.ndarray, FloatTensor], 209 | generalized: bool = False, 210 | ) -> Union[np.ndarray, FloatTensor]: 211 | # box_1: [N, 4] box_2: [N, 4] 212 | 213 | if isinstance(box_1, np.ndarray): 214 | lib = np 215 | elif isinstance(box_1, FloatTensor): 216 | lib = torch 217 | else: 218 | raise NotImplementedError(type(box_1)) 219 | 220 | l1, t1, r1, b1 = xywh_to_ltrb_split(box_1.T) 221 | l2, t2, r2, b2 = xywh_to_ltrb_split(box_2.T) 222 | a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2) 223 | 224 | # intersection 225 | l_max = lib.maximum(l1, l2) 226 | r_min = lib.minimum(r1, r2) 227 | t_max = lib.maximum(t1, t2) 228 | b_min = lib.minimum(b1, b2) 229 | cond = (l_max < r_min) & (t_max < b_min) 230 | ai = lib.where(cond, (r_min - l_max) * (b_min - t_max), lib.zeros_like(a1[0])) 231 | 232 | au = a1 + a2 - ai 233 | iou = ai / au 234 | 235 | if not generalized: 236 | return iou 237 | 238 | # outer region 239 | l_min = lib.minimum(l1, l2) 240 | r_max = lib.maximum(r1, r2) 241 | t_min = lib.minimum(t1, t2) 242 | b_max = lib.maximum(b1, b2) 243 | ac = (r_max - l_min) * (b_max - t_min) 244 | 245 | giou = iou - (ac - au) / ac 246 | 247 | return giou 248 | 249 | 250 | def compute_perceptual_iou( 251 | box_1: Union[np.ndarray, FloatTensor], 252 | box_2: Union[np.ndarray, FloatTensor], 253 | ) -> Union[np.ndarray, FloatTensor]: 254 | """ 255 | Computes 'Perceptual' IoU [Kong+, BLT'22] 256 | """ 257 | # box_1: [N, 4] box_2: [N, 4] 258 | 259 | if isinstance(box_1, np.ndarray): 260 | lib = np 261 | elif isinstance(box_1, FloatTensor): 262 | lib = torch 263 | else: 264 | raise NotImplementedError(type(box_1)) 265 | 266 | l1, t1, r1, b1 = xywh_to_ltrb_split(box_1.T) 267 | l2, t2, r2, b2 = 
xywh_to_ltrb_split(box_2.T) 268 | a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2) 269 | 270 | # intersection 271 | l_max = lib.maximum(l1, l2) 272 | r_min = lib.minimum(r1, r2) 273 | t_max = lib.maximum(t1, t2) 274 | b_min = lib.minimum(b1, b2) 275 | cond = (l_max < r_min) & (t_max < b_min) 276 | ai = lib.where(cond, (r_min - l_max) * (b_min - t_max), lib.zeros_like(a1[0])) 277 | 278 | # numpy-only procedure in this part 279 | if isinstance(box_1, FloatTensor): 280 | unique_box_1 = np.unique(box_1.numpy(), axis=0) 281 | else: 282 | unique_box_1 = np.unique(box_1, axis=0) 283 | N = 32 284 | l1, t1, r1, b1 = [ 285 | (x * N).round().astype(np.int32).clip(0, N) 286 | for x in xywh_to_ltrb_split(unique_box_1.T) 287 | ] 288 | canvas = np.zeros((N, N)) 289 | for (l, t, r, b) in zip(l1, t1, r1, b1): 290 | canvas[t:b, l:r] = 1 291 | global_area_union = canvas.sum() / (N**2) 292 | 293 | if global_area_union > 0.0: 294 | iou = ai / global_area_union 295 | return iou 296 | else: 297 | return lib.zeros((1,)) 298 | 299 | 300 | def __compute_maximum_iou_for_layout(layout_1: Layout, layout_2: Layout) -> float: 301 | score = 0.0 302 | (bi, li), (bj, lj) = layout_1, layout_2 303 | N = len(bi) 304 | for l in list(set(li.tolist())): 305 | _bi = bi[np.where(li == l)] 306 | _bj = bj[np.where(lj == l)] 307 | n = len(_bi) 308 | ii, jj = np.meshgrid(range(n), range(n)) 309 | ii, jj = ii.flatten(), jj.flatten() 310 | iou = compute_iou(_bi[ii], _bj[jj]).reshape(n, n) 311 | # note: maximize is supported only when scipy >= 1.4 312 | ii, jj = linear_sum_assignment(iou, maximize=True) 313 | score += iou[ii, jj].sum().item() 314 | return score / N 315 | 316 | 317 | def __compute_maximum_iou(layouts_1_and_2: Tuple[List[Layout]]) -> np.ndarray: 318 | layouts_1, layouts_2 = layouts_1_and_2 319 | N, M = len(layouts_1), len(layouts_2) 320 | ii, jj = np.meshgrid(range(N), range(M)) 321 | ii, jj = ii.flatten(), jj.flatten() 322 | scores = np.asarray( 323 | [ 324 | __compute_maximum_iou_for_layout(layouts_1[i], layouts_2[j]) 325 | for i, j in zip(ii, jj) 326 | ] 327 | ).reshape(N, M) 328 | ii, jj = linear_sum_assignment(scores, maximize=True) 329 | return scores[ii, jj] 330 | 331 | 332 | def __get_cond2layouts(layout_list: List[Layout]) -> Dict[str, List[Layout]]: 333 | out = dict() 334 | for bs, ls in layout_list: 335 | cond_key = str(sorted(ls.tolist())) 336 | if cond_key not in out.keys(): 337 | out[cond_key] = [(bs, ls)] 338 | else: 339 | out[cond_key].append((bs, ls)) 340 | return out 341 | 342 | 343 | def compute_maximum_iou( 344 | layouts_1: List[Layout], 345 | layouts_2: List[Layout], 346 | disable_parallel: bool = DISABLED, 347 | n_jobs: Optional[int] = None, 348 | ): 349 | """ 350 | Computes Maximum IoU [Kikuchi+, ACMMM'21] 351 | """ 352 | c2bl_1 = __get_cond2layouts(layouts_1) 353 | keys_1 = set(c2bl_1.keys()) 354 | c2bl_2 = __get_cond2layouts(layouts_2) 355 | keys_2 = set(c2bl_2.keys()) 356 | keys = list(keys_1.intersection(keys_2)) 357 | args = [(c2bl_1[key], c2bl_2[key]) for key in keys] 358 | # to check actual number of layouts for evaluation 359 | # ans = 0 360 | # for x in args: 361 | # ans += len(x[0]) 362 | if disable_parallel: 363 | scores = [__compute_maximum_iou(a) for a in args] 364 | else: 365 | with multiprocessing.Pool(n_jobs) as p: 366 | scores = p.map(__compute_maximum_iou, args) 367 | scores = np.asarray(list(chain.from_iterable(scores))) 368 | if len(scores) == 0: 369 | return 0.0 370 | else: 371 | return scores.mean().item() 372 | 373 | 374 | def __compute_average_iou(layout: Layout, 
perceptual: bool = False) -> float: 375 | bbox, _ = layout 376 | N = bbox.shape[0] 377 | if N in [0, 1]: 378 | return 0.0 # no overlap in principle 379 | 380 | ii, jj = np.meshgrid(range(N), range(N)) 381 | ii, jj = ii.flatten(), jj.flatten() 382 | is_non_diag = ii != jj # IoU for diag is always 1.0 383 | ii, jj = ii[is_non_diag], jj[is_non_diag] 384 | 385 | if perceptual: 386 | iou = compute_perceptual_iou(bbox[ii], bbox[jj]) 387 | else: 388 | iou = compute_iou(bbox[ii], bbox[jj]) 389 | 390 | # pick all pairs of overlapped objects 391 | cond = iou > np.finfo(np.float32).eps # to avoid very-small nonzero 392 | # return iou.mean().item() 393 | if len(iou[cond]) > 0: 394 | return iou[cond].mean().item() 395 | else: 396 | return 0.0 397 | 398 | 399 | def compute_average_iou( 400 | layouts: List[Layout], 401 | disable_parallel: bool = DISABLED, 402 | n_jobs: Optional[int] = None, 403 | ) -> Dict[str, float]: 404 | """ 405 | Compute IoU between overlapping objects for each layout. 406 | Note that the lower is better unlike popular IoU. 407 | 408 | Reference: 409 | Variational Transformer Networks for Layout Generation (CVPR2021) 410 | https://arxiv.org/abs/2104.02416 411 | Reference: (perceptual version) 412 | BLT: Bidirectional Layout Transformer for Controllable Layout Generation (ECCV2022) 413 | https://arxiv.org/abs/2112.05112 414 | """ 415 | func1 = partial(__compute_average_iou, perceptual=True) 416 | func2 = partial(__compute_average_iou, perceptual=False) 417 | 418 | # single-thread process for debugging 419 | if disable_parallel: 420 | scores1 = [func1(l) for l in layouts] 421 | scores2 = [func2(l) for l in layouts] 422 | else: 423 | with multiprocessing.Pool(n_jobs) as p1: 424 | scores1 = p1.map(func1, layouts) 425 | with multiprocessing.Pool(n_jobs) as p2: 426 | scores2 = p2.map(func2, layouts) 427 | results = { 428 | "average_iou-BLT": np.array(scores1).mean().item(), 429 | "average_iou-VTN": np.array(scores2).mean().item(), 430 | } 431 | return results 432 | 433 | 434 | def __compute_bbox_sim( 435 | bboxes_1: np.ndarray, 436 | category_1: np.int64, 437 | bboxes_2: np.ndarray, 438 | category_2: np.int64, 439 | C_S: float = 2.0, 440 | C: float = 0.5, 441 | ) -> float: 442 | # bboxes from diffrent categories never match 443 | if category_1 != category_2: 444 | return 0.0 445 | 446 | cx1, cy1, w1, h1 = bboxes_1 447 | cx2, cy2, w2, h2 = bboxes_2 448 | 449 | delta_c = np.sqrt(np.power(cx1 - cx2, 2) + np.power(cy1 - cy2, 2)) 450 | delta_s = np.abs(w1 - w2) + np.abs(h1 - h2) 451 | area = np.minimum(w1 * h1, w2 * h2) 452 | alpha = np.power(np.clip(area, 0.0, None), C) 453 | 454 | weight = alpha * np.power(2.0, -1.0 * delta_c - C_S * delta_s) 455 | return weight 456 | 457 | 458 | def __compute_docsim_between_two_layouts( 459 | layouts_1_layouts_2: Tuple[List[Layout]], 460 | max_diff_thresh: int = 3, 461 | ) -> float: 462 | layouts_1, layouts_2 = layouts_1_layouts_2 463 | bboxes_1, categories_1 = layouts_1 464 | bboxes_2, categories_2 = layouts_2 465 | 466 | N, M = len(bboxes_1), len(bboxes_2) 467 | if N >= M + max_diff_thresh or N <= M - max_diff_thresh: 468 | return 0.0 469 | 470 | ii, jj = np.meshgrid(range(N), range(M)) 471 | ii, jj = ii.flatten(), jj.flatten() 472 | scores = np.asarray( 473 | [ 474 | __compute_bbox_sim( 475 | bboxes_1[i], categories_1[i], bboxes_2[j], categories_2[j] 476 | ) 477 | for i, j in zip(ii, jj) 478 | ] 479 | ).reshape(N, M) 480 | ii, jj = linear_sum_assignment(scores, maximize=True) 481 | 482 | if len(scores[ii, jj]) == 0: 483 | # sometimes, predicted bboxes 
are somehow filtered. 484 | return 0.0 485 | else: 486 | return scores[ii, jj].mean() 487 | 488 | 489 | def compute_docsim( 490 | layouts_gt: List[Layout], 491 | layouts_generated: List[Layout], 492 | disable_parallel: bool = DISABLED, 493 | n_jobs: Optional[int] = None, 494 | ) -> float: 495 | """ 496 | Compute layout-to-layout similarity and average over layout pairs. 497 | Note that this is different from layouts-to-layouts similarity. 498 | """ 499 | args = list(zip(layouts_gt, layouts_generated)) 500 | if disable_parallel: 501 | scores = [] 502 | for arg in args: 503 | scores.append(__compute_docsim_between_two_layouts(arg)) 504 | else: 505 | with multiprocessing.Pool(n_jobs) as p: 506 | scores = p.map(__compute_docsim_between_two_layouts, args) 507 | return np.array(scores).mean() 508 | 509 | 510 | def _compute_wasserstein_distance_class( 511 | layouts_1: List[Layout], 512 | layouts_2: List[Layout], 513 | n_categories: int = 25, 514 | ) -> float: 515 | categories_1 = np.concatenate([l[1] for l in layouts_1]) 516 | counts = np.array( 517 | [np.count_nonzero(categories_1 == i) for i in range(n_categories)] 518 | ) 519 | prob_1 = counts / np.sum(counts) 520 | 521 | categories_2 = np.concatenate([l[1] for l in layouts_2]) 522 | counts = np.array( 523 | [np.count_nonzero(categories_2 == i) for i in range(n_categories)] 524 | ) 525 | prob_2 = counts / np.sum(counts) 526 | return np.absolute(prob_1 - prob_2).sum() 527 | 528 | 529 | def _compute_wasserstein_distance_bbox( 530 | layouts_1: List[Layout], 531 | layouts_2: List[Layout], 532 | ) -> float: 533 | bboxes_1 = np.concatenate([l[0] for l in layouts_1]).T 534 | bboxes_2 = np.concatenate([l[0] for l in layouts_2]).T 535 | 536 | # simple 1-dimensional wasserstein for (cx, cy, w, h) independently 537 | N = 4 538 | ans = 0.0 539 | for i in range(N): 540 | ans += wasserstein_distance(bboxes_1[i], bboxes_2[i]) 541 | ans /= N 542 | 543 | return ans 544 | 545 | 546 | def compute_wasserstein_distance( 547 | layouts_1: List[Layout], 548 | layouts_2: List[Layout], 549 | n_classes: int = 25, 550 | ) -> Dict[str, float]: 551 | w_class = _compute_wasserstein_distance_class(layouts_1, layouts_2, n_classes) 552 | w_bbox = _compute_wasserstein_distance_bbox(layouts_1, layouts_2) 553 | return { 554 | "wdist_class": w_class, 555 | "wdist_bbox": w_bbox, 556 | } 557 | 558 | 559 | if __name__ == "__main__": 560 | layouts = [ 561 | ( 562 | np.array( 563 | [ 564 | [0.2, 0.2, 0.4, 0.4], 565 | ] 566 | ), 567 | np.zeros((1,)), 568 | ) 569 | ] 570 | print(compute_average_iou(layouts)) 571 | -------------------------------------------------------------------------------- /util/seq_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Tuple 4 | from torch import BoolTensor, FloatTensor, LongTensor 5 | from torch_geometric.utils import to_dense_batch 6 | 7 | def sparse_to_dense( 8 | batch, 9 | device: torch.device = torch.device("cpu"), 10 | remove_canvas: bool = False, 11 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]: 12 | batch = batch.to(device) 13 | bbox, _ = to_dense_batch(batch.x, batch.batch) 14 | label, mask = to_dense_batch(batch.y, batch.batch) 15 | 16 | if remove_canvas: 17 | bbox = bbox[:, 1:].contiguous() 18 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform 19 | label = label.clamp(min=0) 20 | mask = mask[:, 1:].contiguous() 21 | 22 | padding_mask = ~mask 23 | return bbox, label, padding_mask, mask 24 | 25 | 26 | 
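# Illustrative usage sketch of the padding helpers below (hypothetical shapes, not from the repo):
# pad_sequence pads the second (sequence) dimension with a constant value up to max_seq_length
# and returns the input unchanged if it is already long enough; pad_until applies it to the
# label, bbox, and mask tensors together, e.g.
#   label = torch.zeros(2, 3, dtype=torch.long)    # [B=2, S=3]
#   bbox = torch.zeros(2, 3, 4)                    # [B=2, S=3, 4]
#   mask = torch.ones(2, 3, dtype=torch.bool)      # [B=2, S=3]
#   label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)  # -> S=25 for all three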
def pad_sequence(seq: LongTensor, max_seq_length: int, value) -> LongTensor: 27 | S = seq.shape[1] 28 | new_shape = list(seq.shape) 29 | s = max_seq_length - S 30 | if s > 0: 31 | new_shape[1] = s 32 | pad = torch.full(new_shape, value, dtype=seq.dtype) 33 | new_seq = torch.cat([seq, pad], dim=1) 34 | else: 35 | new_seq = seq 36 | 37 | return new_seq 38 | 39 | 40 | def pad_until(label: LongTensor, bbox: FloatTensor, mask: BoolTensor, max_seq_length: int 41 | ) -> Tuple[LongTensor, FloatTensor, BoolTensor]: 42 | label = pad_sequence(label, max_seq_length, 0) 43 | bbox = pad_sequence(bbox, max_seq_length, 0) 44 | mask = pad_sequence(mask, max_seq_length, False) 45 | return label, bbox, mask 46 | 47 | 48 | def _to(inputs, device): 49 | """ 50 | recursively send tensor to the specified device 51 | """ 52 | outputs = {} 53 | for k, v in inputs.items(): 54 | if isinstance(v, dict): 55 | outputs[k] = _to(v, device) 56 | elif isinstance(v, torch.Tensor): 57 | outputs[k] = v.to(device) 58 | return outputs 59 | 60 | def loader_to_list( 61 | loader: torch.utils.data.dataloader.DataLoader, 62 | ) -> List[Tuple[np.ndarray, np.ndarray]]: 63 | layouts = [] 64 | for batch in loader: 65 | bbox, label, _, mask = sparse_to_dense(batch) 66 | for i in range(len(label)): 67 | valid = mask[i].numpy() 68 | layouts.append((bbox[i].numpy()[valid], label[i].numpy()[valid])) 69 | return layouts -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from torch import BoolTensor, FloatTensor, LongTensor 8 | 9 | 10 | def set_seed(seed: int): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | 15 | 16 | def convert_xywh_to_ltrb(bbox: Union[np.ndarray, FloatTensor]): 17 | xc, yc, w, h = bbox 18 | x1 = xc - w / 2 19 | y1 = yc - h / 2 20 | x2 = xc + w / 2 21 | y2 = yc + h / 2 22 | return [x1, y1, x2, y2] 23 | 24 | 25 | def batch_topk_mask( 26 | scores: FloatTensor, 27 | topk: LongTensor, 28 | mask: Optional[BoolTensor] = None, 29 | ) -> Tuple[BoolTensor, FloatTensor]: 30 | assert scores.ndim == 2 and topk.ndim == 1 and scores.size(0) == topk.size(0) 31 | if mask is not None: 32 | assert mask.size() == scores.size() 33 | assert (scores.size(1) >= topk).all() 34 | 35 | # ignore scores where mask = False by setting extreme values 36 | if mask is not None: 37 | const = -1.0 * float("Inf") 38 | const = torch.full_like(scores, fill_value=const) 39 | scores = torch.where(mask, scores, const) 40 | 41 | sorted_values, _ = torch.sort(scores, dim=-1, descending=True) 42 | topk = rearrange(topk, "b -> b 1") 43 | 44 | k_th_scores = torch.gather(sorted_values, dim=1, index=topk) 45 | 46 | topk_mask = scores > k_th_scores 47 | return topk_mask, k_th_scores 48 | 49 | 50 | def batch_shuffle_index( 51 | batch_size: int, 52 | feature_length: int, 53 | mask: Optional[BoolTensor] = None, 54 | ) -> LongTensor: 55 | """ 56 | Note: masked part may be shuffled because of unpredictable behaviour of sorting [inf, ..., inf] 57 | """ 58 | if mask: 59 | assert mask.size() == [batch_size, feature_length] 60 | scores = torch.rand((batch_size, feature_length)) 61 | if mask: 62 | scores[~mask] = float("Inf") 63 | _, indices = torch.sort(scores, dim=1) 64 | return indices 65 | 66 | 67 | if __name__ == "__main__": 68 | scores = torch.arange(6).view(2, 3).float() 69 | 
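# Note on batch_topk_mask (added for clarity, based on the implementation above): it first pushes
# positions with mask == False to -inf so they are never selected, then returns, per row, a boolean
# mask of the entries whose score exceeds the value ranked topk-th (0-indexed, descending) --
# roughly the top-k scores when there are no ties -- together with those threshold scores.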
# topk = torch.arange(2) + 1 70 | topk = torch.full((2,), 3) 71 | mask = torch.full((2, 3), False) 72 | # mask[1, 2] = False 73 | print(batch_topk_mask(scores, topk, mask=mask)) 74 | -------------------------------------------------------------------------------- /util/visualization.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from collections import Counter 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as T 8 | import torchvision.utils as vutils 9 | from einops import rearrange 10 | from numpy.typing import NDArray 11 | from PIL import Image, ImageDraw, ImageFont 12 | from torch import BoolTensor, FloatTensor, LongTensor 13 | 14 | from .util import convert_xywh_to_ltrb 15 | 16 | label_names = dict() 17 | label_names['rico25'] = ("Text", "Image", "Icon", "Text Button", "List Item", "Input", "Background Image", "Card", 18 | "Web View", "Radio Button", "Drawer", "Checkbox", "Advertisement", "Modal", "Pager Indicator", 19 | "Slider", "On/Off Switch", "Button Bar", "Toolbar", "Number Stepper", "Multi-Tab", "Date Picker", 20 | "Map View", "Video", "Bottom Navigation") 21 | 22 | label_names['rico13'] = ("Toolbar", "Image", "Text", "Icon", "Text Button", "Input", "List Item", "Advertisement", 23 | "Pager Indicator", "Web View", "Background Image", "Drawer", "Modal") 24 | 25 | label_names['magazine'] = ('text', 'image', 'headline', 'text-over-image', 'headline-over-image', 'background') 26 | label_names['publaynet'] = ('text', 'title', 'list', 'table', 'figure') 27 | label_names['crello'] = ['coloredBackground', 'imageElement', 'maskElement', 'svgElement', 'textElement'] 28 | color_6 = ((246, 112, 136), (173, 156, 49), (51, 176, 122), (56, 168, 197), (204, 121, 244), (204, 50, 144)) 29 | 30 | def color_extend(n_color): 31 | 32 | colors = ((246, 112, 136), (173, 156, 49), (51, 176, 122), (56, 168, 197), (204, 121, 244)) 33 | batch_size = np.ceil(n_color / 4).astype(int) 34 | 35 | color_batch_1 = np.linspace(colors[0], colors[1], batch_size)[:-1].astype(int) 36 | color_batch_2 = np.linspace(colors[1], colors[2], batch_size)[:-1].astype(int) 37 | color_batch_3 = np.linspace(colors[2], colors[3], batch_size)[:-1].astype(int) 38 | color_batch_4 = np.linspace(colors[3], colors[4], batch_size).astype(int) 39 | colors_long = np.concatenate([color_batch_1, color_batch_2, color_batch_3, color_batch_4]) 40 | colors_long = tuple(map(tuple, colors_long)) 41 | return colors_long 42 | 43 | def convert_layout_to_image( 44 | boxes: FloatTensor, 45 | labels: LongTensor, 46 | colors: List[Tuple[int]], 47 | canvas_size: Optional[Tuple[int]] = (60, 40), 48 | resources: Optional[Dict] = None, 49 | names: Optional[Tuple[str]] = None, 50 | index: bool = False, 51 | **kwargs, 52 | ): 53 | H, W = canvas_size 54 | if names or index: 55 | # font = ImageFont.truetype("LiberationSerif-Regular", W // 10) 56 | font = ImageFont.load_default() 57 | 58 | if resources: 59 | img = resources["img_bg"].resize((W, H)) 60 | else: 61 | img = Image.new("RGB", (int(W), int(H)), color=(255, 255, 255)) 62 | draw = ImageDraw.Draw(img, "RGBA") 63 | 64 | # draw from larger boxes 65 | a = [b[2] * b[3] for b in boxes] 66 | indices = sorted(range(len(a)), key=lambda i: a[i], reverse=True) 67 | 68 | for i in indices: 69 | bbox, label = boxes[i], labels[i] 70 | if isinstance(label, LongTensor): 71 | label = label.item() 72 | 73 | c_fill = colors[label] + (100,) 74 | x1, y1, x2, y2 = 
convert_xywh_to_ltrb(bbox) 75 | x1, x2 = x1 * (W - 1), x2 * (W - 1) 76 | y1, y2 = y1 * (H - 1), y2 * (H - 1) 77 | 78 | if resources: 79 | patch = resources["cropped_patches"][i] 80 | # round coordinates for exact size match for rendering images 81 | x1, x2 = int(x1), int(x2) 82 | y1, y2 = int(y1), int(y2) 83 | w, h = x2 - x1, y2 - y1 84 | patch = patch.resize((w, h)) 85 | img.paste(patch, (x1, y1)) 86 | else: 87 | draw.rectangle([x1, y1, x2, y2], outline=colors[label], fill=c_fill) 88 | if names: 89 | # draw.text((x1, y1), names[label], colors[label], font=font) 90 | draw.text((x1, y1), names[label], "black", font=font) 91 | elif index: 92 | draw.text((x1, y1), str(int(i % (len(labels)/2))), "black", font=font) 93 | 94 | return img 95 | 96 | 97 | def save_image( 98 | batch_boxes: FloatTensor, 99 | batch_labels: LongTensor, 100 | batch_mask: BoolTensor, 101 | out_path: Optional[Union[pathlib.PosixPath, str]] = None, 102 | canvas_size: Optional[Tuple[int]] = (360, 240), 103 | nrow: Optional[int] = None, 104 | batch_resources: Optional[Dict] = None, 105 | use_grid: bool = True, 106 | draw_label: bool = False, 107 | draw_index: bool = False, 108 | dataset: str = 'publaynet' 109 | ): 110 | # batch_boxes: [B, N, 4] 111 | # batch_labels: [B, N] 112 | # batch_mask: [B, N] 113 | 114 | assert dataset in ['rico13', 'rico25', 'publaynet', 'magazine', 'crello'] 115 | 116 | if isinstance(out_path, pathlib.PosixPath): 117 | out_path = str(out_path) 118 | 119 | if dataset == 'rico13': 120 | colors = color_extend(13) 121 | elif dataset == 'rico25': 122 | colors = color_extend(25) 123 | elif dataset in ['publaynet', 'magazine', 'crello']: 124 | colors = color_6 125 | 126 | if not draw_label: 127 | names = None 128 | else: 129 | names = label_names[dataset] 130 | 131 | # raise Exception('dataset must be rico or publaynet') 132 | 133 | imgs = [] 134 | B = batch_boxes.size(0) 135 | to_tensor = T.ToTensor() 136 | for i in range(B): 137 | mask_i = batch_mask[i] 138 | boxes = batch_boxes[i][mask_i] 139 | labels = batch_labels[i][mask_i] 140 | if batch_resources: 141 | resources = {k: v[i] for (k, v) in batch_resources.items()} 142 | img = convert_layout_to_image(boxes, labels, colors, canvas_size, resources, names=names, index=draw_index) 143 | else: 144 | img = convert_layout_to_image(boxes, labels, colors, canvas_size, names=names, index=draw_index) 145 | imgs.append(to_tensor(img)) 146 | image = torch.stack(imgs) 147 | 148 | if nrow is None: 149 | nrow = int(np.ceil(np.sqrt(B))) 150 | 151 | if out_path: 152 | vutils.save_image(image, out_path, normalize=False, nrow=nrow) 153 | else: 154 | if use_grid: 155 | return torch_to_numpy_image( 156 | vutils.make_grid(image, normalize=False, nrow=nrow) 157 | ) 158 | else: 159 | return image 160 | 161 | 162 | def save_label( 163 | labels: Union[LongTensor, list], 164 | names: List[str], 165 | colors: List[Tuple[int]], 166 | out_path: Optional[Union[pathlib.PosixPath, str]] = None, 167 | **kwargs 168 | # canvas_size: Optional[Tuple[int]] = (60, 40), 169 | ): 170 | space, pad = 12, 12 171 | x_offset, y_offset = 500, 100 172 | 173 | img = Image.new("RGBA", (1000, 1000)) 174 | # fnt = ImageFont.truetype("LiberationSerif-Regular", 40) 175 | fnt = ImageFont.load_default() 176 | # fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 32) 177 | fnt_sm = ImageFont.load_default() 178 | d = ImageDraw.Draw(img) 179 | 180 | if isinstance(labels, LongTensor): 181 | labels = labels.tolist() 182 | 183 | cnt = Counter(labels) 184 | for l in range(len(colors)): 185 | if l not in 
cnt.keys(): 186 | continue 187 | 188 | text = names[l] 189 | use_multiline = False 190 | 191 | if cnt[l] > 1: 192 | add_width = d.textsize(f" × {cnt[l]}", font=fnt)[0] + pad 193 | else: 194 | add_width = 0 195 | 196 | width = d.textsize(text, font=fnt)[0] 197 | bbox = d.textbbox( 198 | (x_offset - (width + add_width) / 2, y_offset), text, font=fnt 199 | ) 200 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 201 | d.rectangle(bbox, fill=colors[l]) 202 | 203 | _x_offset = x_offset - (width + add_width) / 2 204 | if cnt[l] > 1: 205 | d.text( 206 | (_x_offset + width + pad, y_offset), 207 | f" × {cnt[l]}", 208 | font=fnt, 209 | fill=(0, 0, 0), 210 | ) 211 | 212 | d.text((_x_offset, y_offset), text, font=fnt, fill=(255, 255, 255)) 213 | 214 | y_offset = y_offset + bbox[3] - bbox[1] + space 215 | 216 | # crop 217 | bbox = img.getbbox() 218 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 219 | img = img.crop(bbox) 220 | 221 | # add white background 222 | out = Image.new("RGB", img.size, color=(255, 255, 255)) 223 | out.paste(img, mask=img) 224 | # pil_size = canvas_size[::-1] # (H, W) -> (W, H) 225 | # out = out.resize(pil_size) 226 | if out_path: 227 | out.save(out_path) 228 | else: 229 | return np.array(out) 230 | 231 | 232 | def save_label_with_size( 233 | labels: LongTensor, 234 | boxes: FloatTensor, 235 | names: List[str], 236 | colors: List[Tuple[int]], 237 | out_path: Optional[Union[pathlib.PosixPath, str]] = None, 238 | # canvas_size: Optional[Tuple[int]] = (60, 40), 239 | **kwargs, 240 | ): 241 | space, pad = 12, 12 242 | x_offset, y_offset = 500, 100 243 | B = 32 244 | 245 | img = Image.new("RGBA", (1000, 1000)) 246 | # fnt = ImageFont.truetype("LiberationSerif-Regular", 40) 247 | fnt = ImageFont.load_default() 248 | # fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 32) 249 | fnt_sm = ImageFont.load_default() 250 | d = ImageDraw.Draw(img) 251 | 252 | for i, l in enumerate(labels): 253 | w, h = [int(x) for x in (boxes[i].clip(1 / B, 1.0) * B).long()[2:]] 254 | text = f"{names[l]} ({w},{h})" 255 | add_width = 0 256 | 257 | width = d.textsize(text, font=fnt)[0] 258 | bbox = d.textbbox( 259 | (x_offset - (width + add_width) / 2, y_offset), text, font=fnt 260 | ) 261 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 262 | d.rectangle(bbox, fill=colors[l]) 263 | 264 | _x_offset = x_offset - (width + add_width) / 2 265 | d.text((_x_offset, y_offset), text, font=fnt, fill=(255, 255, 255)) 266 | y_offset = y_offset + bbox[3] - bbox[1] + space 267 | 268 | # crop 269 | bbox = img.getbbox() 270 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 271 | img = img.crop(bbox) 272 | 273 | # add white background 274 | out = Image.new("RGB", img.size, color=(255, 255, 255)) 275 | out.paste(img, mask=img) 276 | # pil_size = canvas_size[::-1] # (H, W) -> (W, H) 277 | # out = out.resize(pil_size) 278 | if out_path: 279 | out.save(out_path) 280 | else: 281 | return np.array(out) 282 | 283 | 284 | def torch_to_numpy_image(input_th: FloatTensor) -> np.ndarray: 285 | """ 286 | Args 287 | input_th: (C, H, W), in [0.0, 1/0], torch image 288 | Returns 289 | output_npy: (H, W, C), in {0, 1, ..., 255}, numpy image 290 | """ 291 | x = (input_th * 255.0).clamp(0, 255) 292 | x = rearrange(x, "c h w -> h w c") 293 | output_npy = x.numpy().astype(np.uint8) 294 | return output_npy 295 | 296 | 297 | def save_relation( 298 | label_with_canvas: LongTensor, 299 | edge_attr: LongTensor, 300 | names: List[str], 301 | colors: List[Tuple[int]], 
302 | out_path: Optional[Union[pathlib.PosixPath, str]] = None, 303 | **kwargs, 304 | ): 305 | from trainer.data.util import ( # lazy load to avoid circular import 306 | RelLoc, 307 | RelSize, 308 | get_rel_text, 309 | ) 310 | 311 | pairs, triplets = [], [] 312 | relations = list(RelSize) + list(RelLoc) 313 | for rel_value in relations: 314 | if rel_value in [RelSize.UNKNOWN, RelLoc.UNKNOWN]: 315 | continue 316 | cond = edge_attr & 1 << rel_value 317 | ii, jj = np.where(cond.numpy() > 0) 318 | for i, j in zip(ii, jj): 319 | li = label_with_canvas[i] - 1 320 | lj = label_with_canvas[j] - 1 321 | 322 | if i == 0: 323 | rel = get_rel_text(rel_value, canvas=True) 324 | pairs.append((lj, rel, None)) 325 | else: 326 | rel = get_rel_text(rel_value, canvas=False) 327 | triplets.append((li, rel, lj)) 328 | 329 | triplets = pairs + triplets 330 | 331 | space, pad = 6, 6 332 | img = Image.new("RGBA", (1000, 1000)) 333 | fnt = ImageFont.truetype("LiberationSerif-Regular", 20) 334 | fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 16) 335 | d = ImageDraw.Draw(img) 336 | 337 | def draw_text(x_offset, y_offset, text, color=None, first=False): 338 | if color is None: 339 | d.text((x_offset, y_offset), text, font=fnt, fill=(0, 0, 0)) 340 | x_offset = x_offset + d.textsize(text, font=fnt)[0] + space 341 | else: 342 | x_offset = x_offset + pad 343 | 344 | use_multiline = False 345 | bbox = d.textbbox((x_offset, y_offset), text, font=fnt) 346 | if bbox[2] - bbox[0] > 120 and " " in text: 347 | use_multiline = True 348 | h_old = d.textsize(text, font=fnt)[1] 349 | text = text.replace(" ", "\n") 350 | h_new = d.multiline_textsize(text, font=fnt_sm)[1] 351 | h_diff = h_new - h_old 352 | if first: 353 | y_offset = y_offset + h_diff / 2 354 | bbox = d.multiline_textbbox( 355 | (x_offset, y_offset - h_diff / 2), text, font=fnt_sm 356 | ) 357 | 358 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 359 | d.rectangle(bbox, fill=color) 360 | 361 | if use_multiline: 362 | d.multiline_text( 363 | (x_offset, y_offset - h_diff / 2), 364 | text, 365 | align="center", 366 | font=fnt_sm, 367 | fill=(255, 255, 255), 368 | ) 369 | text_width = d.multiline_textsize(text, font=fnt_sm)[0] 370 | else: 371 | d.text((x_offset, y_offset), text, font=fnt, fill=(255, 255, 255)) 372 | text_width = d.textsize(text, font=fnt)[0] 373 | 374 | x_offset = x_offset + text_width + space + pad 375 | return x_offset, y_offset 376 | 377 | for i, (l1, rel, l2) in enumerate(triplets): 378 | x_offset, y_offset = 20, 40 * (i + 1) 379 | x_offset, y_offset = draw_text( 380 | x_offset, y_offset, names[l1], colors[l1], first=True 381 | ) 382 | x_offset, y_offset = draw_text(x_offset, y_offset, rel) 383 | if l2 is not None: 384 | draw_text(x_offset, y_offset, names[l2], colors[l2]) 385 | 386 | # crop 387 | bbox = img.getbbox() 388 | if bbox is not None: 389 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad) 390 | img = img.crop(bbox) 391 | 392 | # add white background 393 | out = Image.new("RGB", img.size, color=(255, 255, 255)) 394 | out.paste(img, mask=img) 395 | 396 | if out_path: 397 | out.save(out_path) 398 | else: 399 | return np.array(out) 400 | 401 | 402 | def save_gif( 403 | images: List[NDArray], 404 | out_path: str, 405 | **kwargs, 406 | ): 407 | assert images[0].ndim == 4 408 | to_pil = T.ToPILImage() 409 | for i in range(len(images[0])): 410 | tmp_images = [to_pil(image[i]) for image in images] 411 | tmp_images[0].save( 412 | # f"tmp/animation/{i}.gif", 413 | out_path.format(i), 414 | save_all=True, 
415 | append_images=tmp_images[1:], 416 | optimize=False, 417 | duration=200, 418 | loop=0, 419 | ) 420 | 421 | 422 | 423 | if __name__ == "__main__": 424 | 425 | print(len(label_names['rico25']))  # number of Rico-25 labels defined above 426 | --------------------------------------------------------------------------------