├── LICENSE
├── README.md
├── git_demo.png
├── model_diffusion.py
├── plot_train.py
├── requirements.txt
├── test.py
├── train.py
├── train_fid_model.py
└── util
├── backbone.py
├── constraint.py
├── data_util.py
├── datasets
├── __init__.py
├── base.py
├── crello.py
├── dataset.py
├── load_data.py
├── magazine.py
├── publaynet.py
└── rico.py
├── diffusion_utils.py
├── ema.py
├── metric.py
├── seq_util.py
├── util.py
└── visualization.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jian Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Aligned Layout Generation via Diffusion Model with Aesthetic Constraints
2 | Source code of the layout generation model [LACE](https://arxiv.org/abs/2402.04754).
3 | 
4 | ## 1. Installation
5 | ### 1.1 Prepare environment
6 | Install the packages for Python 3.9 or later:
7 | ```
8 | conda create --name LACE python=3.9
9 | conda activate LACE
10 | python -m pip install -r requirements.txt
11 | ```
12 |
13 | ### 1.2 Checkpoints
14 | Download the **trained checkpoints** for the diffusion model and the FID model from [Hugging Face](https://huggingface.co/datasets/puar-playground/LACE/tree/main) or via the command line:
15 | ```
16 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/model.tar.gz
17 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/fid.tar.gz
18 | tar -xvzf model.tar.gz
19 | tar -xvzf fid.tar.gz
20 | ```
21 | Model hyper-parameters:
22 | * PubLayNet: `--dim_transformer 1024 --nhead 16 --nlayer 4 --feature_dim 2048`
23 | * Rico13 and Rico25: `--dim_transformer 512 --nhead 16 --nlayer 4 --feature_dim 2048`
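These flags map onto the `Diffusion` constructor used in `test.py` and `train.py`. A minimal sketch for PubLayNet (its five element classes plus one empty class give `seq_dim = 6 + 4 = 10`):
```
from model_diffusion import Diffusion

# Sketch only: mirrors how test.py builds the PubLayNet model from the flags above.
model = Diffusion(num_timesteps=1000, nhead=16, dim_transformer=1024,
                  feature_dim=2048, seq_dim=10, num_layers=4,
                  device='cuda:0', ddim_num_steps=100)
```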
24 |
25 | ### 1.3 Datasets
26 | The datasets are also available at:
27 | ```
28 | wget https://huggingface.co/datasets/puar-playground/LACE/resolve/main/datasets.tar.gz
29 | tar -xvzf datasets.tar.gz
30 | ```
31 | Alternatively, you can download each dataset from its source and prepare it as follows:
32 | * [PubLayNet](https://developer.ibm.com/exchanges/data/all/publaynet/): Download `labels.tar.gz` and decompress it into the `./datasets/publaynet-max25/raw` folder.
33 | * [Rico](https://www.kaggle.com/datasets/onurgunes1993/rico-dataset): Download `rico_dataset_v0.1_semantic_annotations.zip` and decompress it into the `./datasets/rico25-max25/raw` folder.
34 |
35 | When a dataset is initialized for the first time, a new folder called `processed` is created (e.g., `./datasets/magazine-max25/processed`) containing the formatted dataset for future use. The training splits of the smaller datasets (Rico and Magazine) are duplicated to reach a reasonable epoch size.
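A minimal sketch of loading a dataset through the repo's loader (the first call triggers the one-time preprocessing described above):
```
from util.datasets.load_data import init_dataset

# First call builds ./datasets/publaynet-max25/processed; later calls reuse it.
train_dataset, train_loader = init_dataset('publaynet', './datasets', batch_size=256,
                                           split='train', shuffle=True)
print(train_dataset.num_classes)  # 5 element classes for PubLayNet (excluding the empty class)
```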
36 |
37 |
38 |
39 | ## 2. Testing
40 | Run the script `test.py` to test. Run `python test.py -h` for a detailed explanation of the arguments.
41 | For PubLayNet:
42 | ```
43 | python test.py --dataset publaynet --experiment all --device cuda:0 --dim_transformer 1024 --nhead 16 --batch_size 2048 --beautify
44 | ```
45 | For Rico:
46 | ```
47 | python test.py --dataset rico25 --experiment all --device cuda:0 --dim_transformer 512 --nhead 16 --batch_size 2048 --beautify
48 | ```
49 |
50 |
51 | ## 3. Training
52 | Run the script `train.py` to train.
53 | The script takes several command-line arguments. Run `python train.py -h` for a detailed explanation.
54 | Example command for training:
55 | ```
56 | python train.py --device cuda:1 --dataset rico25 --no-load_pre --lr 1e-6 --n_save_epoch 10
57 | ```
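As a sketch of what `train.py` optimizes, each step combines the standard diffusion (noise-prediction) loss with the aesthetic-constraint terms and a box-reconstruction term; `--align_weight` and `--overlap_weight` set the constraint weights, and a temporal weight w(t) modulates the constraints over diffusion steps:

$$\mathcal{L} = \mathcal{L}_{\text{diffusion}} + \mathbb{E}\big[\,w(t)\,\big(\lambda_{\text{align}}\,\mathcal{L}_{\text{align}} + \lambda_{\text{overlap}}\,\mathcal{L}_{\text{overlap}}\big)\big] + \mathcal{L}_{\text{reconstruct}}$$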
58 |
59 | ## Reference
60 | ```
61 | @inproceedings{
62 | chen2024towards,
63 | title={Towards Aligned Layout Generation via Diffusion Model with Aesthetic Constraints},
64 | author={Jian Chen and Ruiyi Zhang and Yufan Zhou and Changyou Chen},
65 | booktitle={The Twelfth International Conference on Learning Representations},
66 | year={2024},
67 | url={https://openreview.net/forum?id=kJ0qp9Xdsh}
68 | }
69 | ```
70 |
--------------------------------------------------------------------------------
/git_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/puar-playground/LACE/3df36879a1e80cce58affa4aadeeb768f676c7f1/git_demo.png
--------------------------------------------------------------------------------
/model_diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from util.diffusion_utils import *
4 | import torch.nn.functional as F
5 | from typing import Dict, List, Optional, Tuple, Union
6 | from util.backbone import TransformerEncoder
7 | from util.visualization import save_image
8 | import matplotlib.pyplot as plt
9 | import matplotlib
10 | # matplotlib.use('TkAgg')
11 | torch.manual_seed(1)
12 |
13 |
14 | class Diffusion(nn.Module):
15 | def __init__(self, num_timesteps=1000, nhead=8, feature_dim=2048, dim_transformer=512, seq_dim=10, num_layers=4, device='cuda',
16 | beta_schedule='cosine', ddim_num_steps=50, condition='None'):
17 | super().__init__()
18 | self.device = device
19 | self.num_timesteps = num_timesteps
20 | betas = make_beta_schedule(schedule=beta_schedule, num_timesteps=self.num_timesteps, start=0.0001, end=0.02)
21 | betas = self.betas = betas.float().to(self.device)
22 | self.betas_sqrt = torch.sqrt(betas)
23 | alphas = 1.0 - betas
24 | self.alphas = alphas
25 | self.one_minus_betas_sqrt = torch.sqrt(alphas)
26 | self.alphas_cumprod = alphas.cumprod(dim=0)
27 | self.alphas_bar_sqrt = torch.sqrt(self.alphas_cumprod)
28 | self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_cumprod)
29 | alphas_cumprod_prev = torch.cat([torch.ones(1).to(self.device), self.alphas_cumprod[:-1]], dim=0)
30 | self.alphas_cumprod_prev = alphas_cumprod_prev
31 | self.posterior_mean_coeff_1 = (betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
32 | self.posterior_mean_coeff_2 = (torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - self.alphas_cumprod))
33 | posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
34 | self.posterior_variance = posterior_variance
35 | self.logvar = betas.log()
36 |
37 | self.condition = condition
38 | self.seq_dim = seq_dim
39 | self.num_class = seq_dim - 4
40 |
41 | self.model = TransformerEncoder(num_layers=num_layers, dim_seq=seq_dim, dim_transformer=dim_transformer, nhead=nhead,
42 | dim_feedforward=feature_dim, diffusion_step=num_timesteps, device=device)
43 |
44 | self.ddim_num_steps = ddim_num_steps
45 | self.make_ddim_schedule(ddim_num_steps)
46 |
47 | def make_ddim_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
48 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
49 | num_ddpm_timesteps=self.num_timesteps)
50 |
51 | assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
52 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
53 |
54 | self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(self.alphas_cumprod)))
55 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - self.alphas_cumprod)))
56 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - self.alphas_cumprod)))
57 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod)))
58 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod - 1)))
59 |
60 | # ddim sampling parameters
61 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod,
62 | ddim_timesteps=self.ddim_timesteps,
63 | eta=ddim_eta)
64 | self.register_buffer('ddim_sigmas', ddim_sigmas)
65 | self.register_buffer('ddim_alphas', ddim_alphas)
66 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
67 | self.register_buffer('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas))
68 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
69 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
70 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
71 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
72 |
73 | def load_diffusion_net(self, net_state_dict):
74 | # new_states = dict()
75 | # for k in net_state_dict.keys():
76 | # if 'layer_out' not in k and 'layer_in' not in k:
77 | # new_states[k] = net_state_dict[k]
78 | self.model.load_state_dict(net_state_dict, strict=True)
79 |
80 | def sample_t(self, size=(1,), t_max=None):
81 | """Samples batches of time steps to use."""
82 | if t_max is None:
83 | t_max = int(self.num_timesteps) - 1
84 |
85 | t = torch.randint(low=0, high=t_max, size=size, device=self.device)
86 |
87 | return t.to(self.device)
88 |
89 | def forward_t(self, l_0_batch, t, real_mask, reparam=False):
90 |
91 | batch_size = l_0_batch.shape[0]
92 | e = torch.randn_like(l_0_batch).to(l_0_batch.device)
93 |
94 | l_t_noise = q_sample(l_0_batch, self.alphas_bar_sqrt,
95 | self.one_minus_alphas_bar_sqrt, t, noise=e)
96 |
97 | # cond c
98 | l_t_input_c = l_0_batch.clone()
99 | l_t_input_c[:, :, self.num_class:] = l_t_noise[:, :, self.num_class:]
100 |
101 | # cond cwh
102 | l_t_input_cwh = l_0_batch.clone()
103 | l_t_input_cwh[:, :, self.num_class:self.num_class+2] = l_t_noise[:, :, self.num_class:self.num_class+2]
104 |
105 | # cond complete
106 | fix_mask = rand_fix(batch_size, real_mask, ratio=0.2)
107 | l_t_input_complete = l_t_noise.clone()
108 | l_t_input_complete[fix_mask] = l_0_batch[fix_mask]
109 |
110 | l_t_input_all = torch.cat([l_t_noise, l_t_input_c, l_t_input_cwh, l_t_input_complete], dim=0)
111 | e_all = torch.cat([e, e, e, e], dim=0)
112 | t_all = torch.cat([t, t, t, t], dim=0)
113 |
114 | eps_theta = self.model(l_t_input_all, timestep=t_all)
115 |
116 | if reparam:
117 | sqrt_one_minus_alpha_bar_t = extract(self.one_minus_alphas_bar_sqrt, t_all, l_t_input_all)
118 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
119 | l_0_generate_reparam = 1 / sqrt_alpha_bar_t * (l_t_input_all - eps_theta * sqrt_one_minus_alpha_bar_t).to(self.device)
120 |
121 | return eps_theta, e_all, l_0_generate_reparam
122 | else:
123 | return eps_theta, e_all, None
124 |
125 | def reverse(self, batch_size, only_last_sample=True, stochastic=True):
126 |
127 | self.model.eval()
128 | layout_t_0 = p_sample_loop(self.model, batch_size,
129 | self.num_timesteps, self.alphas,
130 | self.one_minus_alphas_bar_sqrt,
131 | only_last_sample=only_last_sample, stochastic=stochastic)
132 |
133 | bbox, label, mask = self.finalize(layout_t_0, self.num_class)
134 |
135 | return bbox, label, mask
136 |
137 | def reverse_ddim(self, batch_size=4, stochastic=True, save_inter=False, max_len=25):
138 |
139 | self.model.eval()
140 | layout_t_0, intermediates = ddim_sample_loop(self.model, batch_size, self.ddim_timesteps, self.ddim_alphas,
141 | self.ddim_alphas_prev, self.ddim_sigmas, stochastic=stochastic,
142 | seq_len=max_len, seq_dim=self.seq_dim)
143 |
144 | bbox, label, mask = self.finalize(layout_t_0, self.num_class)
145 |
146 | if not save_inter:
147 | return bbox, label, mask
148 |
149 | else:
150 | for i, layout_t in enumerate(intermediates['y_inter']):
151 | bbox, label, mask = self.finalize(layout_t, self.num_class)
152 | a = save_image(bbox, label, mask, draw_label=True)
153 | plt.figure(figsize=[15, 20])
154 | plt.imshow(a)
155 | plt.tight_layout()
156 | plt.savefig(f'./plot/inter_{i}.png')
157 | plt.close()
158 |
159 | return bbox, label, mask
160 |
161 |
162 | @staticmethod
163 | def finalize(layout, num_class):
164 | layout[:, :, num_class:] = torch.clamp(layout[:, :, num_class:], min=-1, max=1) / 2 + 0.5
165 | bbox = layout[:, :, num_class:]
166 | label = torch.argmax(layout[:, :, :num_class], dim=2)
167 | mask = (label != num_class-1).clone().detach()
168 |
169 | return bbox, label, mask
170 |
171 | def conditional_reverse_ddim(self, real_layout, cond='c', ratio=0.2, stochastic=True):
172 |
173 | self.model.eval()
174 | layout_t_0, intermediates = \
175 | ddim_cond_sample_loop(self.model, real_layout, self.ddim_timesteps, self.ddim_alphas,
176 | self.ddim_alphas_prev, self.ddim_sigmas, stochastic=stochastic, cond=cond,
177 | ratio=ratio)
178 |
179 | bbox, label, mask = self.finalize(layout_t_0, self.num_class)
180 |
181 | return bbox, label, mask
182 |
183 | def refinement_reverse_ddim(self, noisy_layout):
184 | self.model.eval()
185 | layout_t_0, intermediates = \
186 | ddim_refine_sample_loop(self.model, noisy_layout, self.ddim_timesteps, self.ddim_alphas,
187 | self.ddim_alphas_prev, self.ddim_sigmas)
188 |
189 | bbox, label, mask = self.finalize(layout_t_0, self.num_class)
190 |
191 | return bbox, label, mask
192 |
193 |
194 |
195 | if __name__ == "__main__":
196 |
197 | model = Diffusion(num_timesteps=1000, nhead=8, dim_transformer=1024,
198 | feature_dim=2048, seq_dim=10, num_layers=4,
199 | device='cpu', ddim_num_steps=200)
200 |
201 | print(pow(model.one_minus_alphas_bar_sqrt[201], 2))
202 |
203 |
204 | print(sum(model.ddim_timesteps <= 201))
205 | timesteps = model.ddim_timesteps
206 | total_steps = sum(model.ddim_timesteps <= 201)
207 | time_range = np.flip(timesteps[:total_steps])
208 | print(time_range)
209 |
--------------------------------------------------------------------------------
/plot_train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tqdm import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import matplotlib.pyplot as plt
8 | import matplotlib
9 | matplotlib.use('TkAgg')
10 | from util.datasets.load_data import init_dataset
11 | from util.visualization import save_image
12 | from util.seq_util import sparse_to_dense, pad_until
13 | import argparse
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--nepoch", default=500, help="number of training epochs", type=int)
19 | parser.add_argument("--batch_size", default=4, help="batch_size", type=int)
20 | parser.add_argument("--device", default='cuda', help="which GPU to use", type=str)
21 | parser.add_argument("--num_workers", default=1, help="num_workers", type=int)
22 | parser.add_argument("--num_plot", default=50, help="num_workers", type=int)
23 | parser.add_argument("--dataset", default='magazine',
24 | help="choose from [publaynet, rico13, rico25]", type=str)
25 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str)
26 | args = parser.parse_args()
27 |
28 | # prepare data
29 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size,
30 | split='train', shuffle=False)
31 | if not os.path.exists(f'./plot/{args.dataset}_train/'):
32 | os.mkdir(f'./plot/{args.dataset}_train/')
33 |
34 | with tqdm(enumerate(train_loader, 1), total=args.num_plot, desc=f'load data',
35 | ncols=120) as pbar:
36 |
37 | for i, data in pbar:
38 |
39 | if i > args.num_plot:
40 | break
41 |
42 | bbox, label, _, mask = sparse_to_dense(data)
43 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)
44 |
45 | a = save_image(bbox, label, mask, draw_label=True, dataset=f'{args.dataset}')
46 | plt.figure(figsize=[15, 20])
47 | plt.imshow(a)
48 | plt.tight_layout()
49 | plt.savefig(f'./plot/{args.dataset}_train/{i}.png')
50 | plt.close()
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.1
2 | fsspec==2023.4.0
3 | imageio==2.28.0
4 | matplotlib==3.7.1
5 | numpy==1.24.3
6 | omegaconf==2.3.0
7 | Pillow==10.0.0
8 | prdc==0.2
9 | pycocotools==2.0.6
10 | pytorch_fid==0.3.0
11 | scipy==1.10.1
12 | seaborn==0.12.2
13 | torch==2.0.0
14 | torch_geometric==2.3.1
15 | torchvision==0.15.1
16 | tqdm==4.65.0
17 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from fid.model import load_fidnet_v3
3 | from util.metric import compute_generative_model_scores, compute_maximum_iou, compute_overlap, compute_alignment
4 | import pickle as pk
5 | from tqdm import tqdm
6 | from util.datasets.load_data import init_dataset
7 | from util.visualization import save_image
8 | from util.constraint import *
9 | import matplotlib.pyplot as plt
10 | import matplotlib
11 | # matplotlib.use('TkAgg')
12 | from util.seq_util import sparse_to_dense, loader_to_list, pad_until
13 | import argparse
14 | from model_diffusion import Diffusion
15 |
16 |
17 | def test_fid_feat(dataset_name, device='cuda', batch_size=20):
18 |
19 | if os.path.exists(f'./fid/feature/fid_feat_test_{dataset_name}.pk'):
20 | feats_test = pk.load(open(f'./fid/feature/fid_feat_test_{dataset_name}.pk', 'rb'))
21 | return feats_test
22 |
23 | # prepare dataset
24 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size,
25 | split='test', shuffle=False, transform=None)
26 |
27 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device)
28 | feats_test = []
29 |
30 | with tqdm(enumerate(main_dataloader), total=len(main_dataloader), desc=f'Get feature for FID',
31 | ncols=200) as pbar:
32 |
33 | for i, data in pbar:
34 |
35 | bbox, label, _, mask = sparse_to_dense(data)
36 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device)
37 | padding_mask = ~mask
38 |
39 | with torch.set_grad_enabled(False):
40 | feat = fid_model.extract_features(bbox, label, padding_mask)
41 | feats_test.append(feat.detach().cpu())
42 |
43 | pk.dump(feats_test, open(f'./fid/feature/fid_feat_test_{dataset_name}.pk', 'wb'))
44 |
45 | return feats_test
46 |
47 |
48 | def test_layout_uncond(model, batch_size=128, dataset_name='publaynet', test_plot=False,
49 | save_dir='./plot/test', beautify=False):
50 |
51 | model.eval()
52 | device = model.device
53 | n_batch_dict = {'publaynet': int(44 * 256 / batch_size), 'rico13': int(17 * 256 / batch_size),
54 | 'rico25': int(17 * 256 / batch_size), 'magazine': int(512 / batch_size),
55 | 'crello': int(2560 / batch_size)}
56 | n_batch = n_batch_dict[dataset_name]
57 |
58 | # prepare dataset
59 | main_dataset, _ = init_dataset(dataset_name, './datasets', batch_size=batch_size, split='test')
60 |
61 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device)
62 | feats_test = test_fid_feat(dataset_name, device=device, batch_size=20)
63 | feats_generate = []
64 |
65 | align_sum = 0
66 | overlap_sum = 0
67 | with torch.no_grad():
68 | for i in tqdm(range(n_batch), desc='uncond testing', ncols=200, total=n_batch):
69 | bbox_generated, label, mask = model.reverse_ddim(batch_size=batch_size, stochastic=True, save_inter=False)
70 | if beautify and dataset_name=='publaynet':
71 | bbox_generated, mask = post_process(bbox_generated, mask, w_o=1)
72 | elif beautify and (dataset_name=='rico25' or dataset_name=='rico13'):
73 | bbox_generated, mask = post_process(bbox_generated, mask, w_o=0)
74 |
75 | padding_mask = ~mask
76 |
77 | label[mask == False] = 0
78 |
79 | if torch.isnan(bbox_generated[0, 0, 0]):
80 | print('not a number error')
81 | return None
82 |
83 | # accumulate align and overlap
84 | align_norm = compute_alignment(bbox_generated, mask)
85 | align_sum += torch.mean(align_norm)
86 | overlap_score = compute_overlap(bbox_generated, mask)
87 | overlap_sum += torch.mean(overlap_score)
88 |
89 |
90 | with torch.set_grad_enabled(False):
91 | feat = fid_model.extract_features(bbox_generated, label, padding_mask)
92 | feats_generate.append(feat.cpu())
93 |
94 | if test_plot and i <= 10:
95 | img = save_image(bbox_generated[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name)
96 | plt.figure(figsize=[12, 12])
97 | plt.imshow(img)
98 | plt.tight_layout()
99 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_{i}.png'))
100 | # plt.close()
101 |
102 | result = compute_generative_model_scores(feats_test, feats_generate)
103 | fid = result['fid']
104 |
105 | align_final = 100 * align_sum / n_batch
106 | overlap_final = 100 * overlap_sum / n_batch
107 |
108 | print(f'uncond, align: {align_final}, fid: {fid}, overlap: {overlap_final}')
109 |
110 | return align_final, fid, overlap_final
111 |
112 |
113 | def test_layout_cond(model, batch_size=256, cond='c', dataset_name='publaynet', seq_dim=10,
114 | test_plot=False, save_dir='./plot/test', beautify=False):
115 |
116 | assert cond in {'c', 'cwh', 'complete'}
117 | model.eval()
118 | device = model.device
119 |
120 | # prepare dataset
121 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size,
122 | split='test', shuffle=False, transform=None)
123 |
124 | layouts_main = loader_to_list(main_dataloader)
125 | layout_generated = []
126 |
127 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device)
128 | feats_test = test_fid_feat(dataset_name, device=device, batch_size=20)
129 | feats_generate = []
130 |
131 | align_sum = 0
132 | overlap_sum = 0
133 | with torch.no_grad():
134 |
135 | with tqdm(enumerate(main_dataloader), total=len(main_dataloader), desc=f'cond: {cond} generation',
136 | ncols=200) as pbar:
137 |
138 | for i, data in pbar:
139 |
140 | bbox, label, _, mask = sparse_to_dense(data)
141 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)
142 |
143 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device)
144 |
145 | # shift to center
146 | bbox_in = 2 * (bbox - 0.5).to(device)
147 |
148 | # set padded positions to the empty-class label
149 | label[mask == False] = seq_dim - 5
150 |
151 | label_oh = torch.nn.functional.one_hot(label, num_classes=seq_dim - 4).to(device)
152 | real_layout = torch.cat((label_oh, bbox_in), dim=2).to(device)
153 |
154 | bbox_generated, label_generated, mask_generated = model.conditional_reverse_ddim(real_layout, cond=cond)
155 |
156 | if beautify and dataset_name == 'publaynet':
157 | bbox_generated, mask_generated = post_process(bbox_generated, mask_generated, w_o=1)
158 | elif beautify and (dataset_name == 'rico25' or dataset_name == 'rico13'):
159 | bbox_generated, mask_generated = post_process(bbox_generated, mask_generated, w_o=0)
160 |
161 | padding_mask = ~mask_generated
162 |
163 | # test for errors
164 | if torch.isnan(bbox[0, 0, 0]):
165 | print('not a number error')
166 | return None
167 |
168 | # accumulate align and overlap
169 | align_norm = compute_alignment(bbox_generated, mask)
170 | align_sum += torch.mean(align_norm)
171 | overlap_score = compute_overlap(bbox_generated, mask)
172 | overlap_sum += torch.mean(overlap_score)
173 |
174 | # record for max_iou
175 | label_generated[label_generated == seq_dim - 5] = 0
176 | for j in range(bbox.shape[0]):
177 | mask_single = mask_generated[j, :]
178 | bbox_single = bbox_generated[j, mask_single, :]
179 | label_single = label_generated[j, mask_single]
180 |
181 | layout_generated.append((bbox_single.to('cpu').numpy(), label_single.to('cpu').numpy()))
182 |
183 | # record for FID
184 | with torch.set_grad_enabled(False):
185 | feat = fid_model.extract_features(bbox_generated, label_generated, padding_mask)
186 | feats_generate.append(feat.cpu())
187 |
188 | if test_plot and i <= 10:
189 | img = save_image(bbox_generated[:9], label_generated[:9], mask_generated[:9],
190 | draw_label=False, dataset=dataset_name)
191 | plt.figure(figsize=[12, 12])
192 | plt.imshow(img)
193 | plt.tight_layout()
194 | plt.savefig(f'./plot/test/cond_{cond}_{dataset_name}_{i}.png')
195 | plt.close()
196 |
197 | img = save_image(bbox[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name)
198 | plt.figure(figsize=[12, 12])
199 | plt.imshow(img)
200 | plt.tight_layout()
201 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_real.png'))
202 | plt.close()
203 |
204 | maxiou = compute_maximum_iou(layouts_main, layout_generated)
205 | result = compute_generative_model_scores(feats_test, feats_generate)
206 | fid = result['fid']
207 |
208 | align_final = 100 * align_sum / len(main_dataloader)
209 | overlap_final = 100 * overlap_sum / len(main_dataloader)
210 |
211 | print(f'cond {cond}, align: {align_final}, fid: {fid}, maxiou: {maxiou}, overlap: {overlap_final}')
212 |
213 | return align_final, fid, maxiou, overlap_final
214 |
215 |
216 | def test_layout_refine(model, batch_size=256, dataset_name='publaynet', seq_dim=10,
217 | test_plot=False, save_dir='./plot/test', beautify=False):
218 |
219 | model.eval()
220 | device = model.device
221 | n_batch_dict = {'publaynet': 44, 'rico13': 17, 'rico25': 17, 'magazine': 2, 'crello': 10}
222 | n_batch = n_batch_dict[dataset_name]
223 |
224 | # prepare dataset
225 | main_dataset, main_dataloader = init_dataset(dataset_name, './datasets', batch_size=batch_size,
226 | split='test', shuffle=False, transform=None)
227 |
228 | layouts_main = loader_to_list(main_dataloader)
229 | layout_generated = []
230 |
231 | fid_model = load_fidnet_v3(main_dataset, './fid/FIDNetV3', device=device)
232 | feats_test = test_fid_feat(dataset_name, device=device, batch_size=20)
233 | feats_generate = []
234 |
235 | align_sum = 0
236 | overlap_sum = 0
237 | with torch.no_grad():
238 |
239 | with tqdm(enumerate(main_dataloader), total=min(n_batch, len(main_dataloader)), desc=f'refine generation',
240 | ncols=200) as pbar:
241 |
242 | for i, data in pbar:
243 | if i == min(n_batch, len(main_dataloader)):
244 | break
245 |
246 | bbox, label, _, mask = sparse_to_dense(data)
247 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)
248 |
249 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device)
250 |
251 | # add Gaussian noise to the boxes, then shift to center
252 | bbox_noisy = torch.clamp(bbox + 0.1 * torch.randn_like(bbox), min=0, max=1)
253 | bbox_in_noisy = 2 * (bbox_noisy - 0.5).to(device)
254 |
255 | # set padded positions to the empty-class label
256 | label[mask == False] = seq_dim - 5
257 |
258 | label_oh = torch.nn.functional.one_hot(label, num_classes=seq_dim - 4).to(device)
259 | noisy_layout = torch.cat((label_oh, bbox_in_noisy), dim=2).to(device)
260 |
261 | bbox_refined, _, _ = model.refinement_reverse_ddim(noisy_layout)
262 |
263 | if beautify and dataset_name == 'publaynet':
264 | bbox_refined, mask = post_process(bbox_refined, mask, w_o=1)
265 | elif beautify and (dataset_name == 'rico25' or dataset_name == 'rico13'):
266 | bbox_refined, mask = post_process(bbox_refined, mask, w_o=0)
267 | padding_mask = ~mask
268 |
269 | # accumulate align and overlap
270 | align_norm = compute_alignment(bbox_refined, mask)
271 | align_sum += torch.mean(align_norm)
272 | overlap_score = compute_overlap(bbox_refined, mask)
273 | overlap_sum += torch.mean(overlap_score)
274 |
275 | # record for max_iou
276 | label[label == seq_dim - 5] = 0
277 |
278 | for j in range(bbox_refined.shape[0]):
279 | mask_single = mask[j, :]
280 | bbox_single = bbox_refined[j, mask_single, :]
281 | label_single = label[j, mask_single]
282 |
283 | layout_generated.append((bbox_single.to('cpu').numpy(), label_single.to('cpu').numpy()))
284 |
285 | # record for FID
286 | with torch.set_grad_enabled(False):
287 | feat = fid_model.extract_features(bbox_refined, label, padding_mask)
288 | feats_generate.append(feat.cpu())
289 |
290 |
291 | if test_plot and i <= 10:
292 | img = save_image(bbox_refined[:9], label[:9], mask[:9],
293 | draw_label=False, dataset=dataset_name)
294 | plt.figure(figsize=[12, 12])
295 | plt.imshow(img)
296 | plt.tight_layout()
297 | plt.savefig(f'./plot/test/refine_{dataset_name}_{i}.png')
298 | plt.close()
299 |
300 | img = save_image(bbox[:9], label[:9], mask[:9], draw_label=False, dataset=dataset_name)
301 | plt.figure(figsize=[12, 12])
302 | plt.imshow(img)
303 | plt.tight_layout()
304 | plt.savefig(os.path.join(save_dir, f'{dataset_name}_real.png'))
305 | plt.close()
306 |
307 | maxiou = compute_maximum_iou(layouts_main, layout_generated)
308 | result = compute_generative_model_scores(feats_test, feats_generate)
309 | fid = result['fid']
310 |
311 | align_final = 100 * align_sum / len(main_dataloader)
312 | overlap_final = 100 * overlap_sum / len(main_dataloader)
313 |
314 | print(f'refine, align: {align_final}, fid: {fid}, maxiou: {maxiou}, overlap: {overlap_final}')
315 | return align_final, fid, maxiou, overlap_final
316 |
317 |
318 | def test_all(model, dataset_name='publaynet', seq_dim=10, test_plot=False, save_dir='./plot/test', batch_size=256,
319 | beautify=False):
320 |
321 | align_uncond, fid_uncond, overlap_uncond = test_layout_uncond(model, batch_size=batch_size, dataset_name=dataset_name,
322 | test_plot=test_plot, save_dir=save_dir, beautify=beautify)
323 | align_c, fid_c, maxiou_c, overlap_c = test_layout_cond(model, batch_size=batch_size, cond='c',
324 | dataset_name=dataset_name, seq_dim=seq_dim,
325 | test_plot=test_plot, save_dir=save_dir, beautify=beautify)
326 | align_cwh, fid_cwh, maxiou_cwh, overlap_cwh = test_layout_cond(model, batch_size=batch_size, cond='cwh',
327 | dataset_name=dataset_name, seq_dim=seq_dim,
328 | test_plot=test_plot, save_dir=save_dir, beautify=beautify)
329 | align_complete, fid_complete, maxiou_complete, overlap_complete = test_layout_cond(model, batch_size=batch_size,
330 | cond='complete', dataset_name=dataset_name,
331 | seq_dim=seq_dim, test_plot=test_plot,
332 | save_dir=save_dir, beautify=beautify)
333 | align_r, fid_r, maxiou_r, overlap_r = test_layout_refine(model, batch_size=batch_size,
334 | dataset_name=dataset_name, seq_dim=seq_dim,
335 | test_plot=test_plot, save_dir=save_dir, beautify=beautify)
336 |
337 | # fid_total = fid_uncond + fid_c + fid_cwh + fid_complete
338 | # print(f'total fid: {fid_total}')
339 | # return fid_total
340 |
341 |
342 | if __name__ == "__main__":
343 |
344 | parser = argparse.ArgumentParser()
345 | parser.add_argument("--batch_size", default=256, help="batch_size", type=int)
346 | parser.add_argument("--device", default='cpu', help="which GPU to use", type=str)
347 | parser.add_argument("--dataset", default='publaynet',
348 | help="choose from [publaynet, rico13, rico25]", type=str)
349 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str)
350 | parser.add_argument("--num_workers", default=4, help="num_workers", type=int)
351 | parser.add_argument("--feature_dim", default=2048, help="feature_dim", type=int)
352 | parser.add_argument("--dim_transformer", default=1024, help="dim_transformer", type=int)
353 | parser.add_argument("--nhead", default=16, help="nhead attention", type=int)
354 | parser.add_argument("--nlayer", default=4, help="nlayer", type=int)
355 | parser.add_argument("--experiment", default='c', help="experiment setting [uncond, c, cwh, complete, all]", type=str)
356 | parser.add_argument('--plot', default=False, action=argparse.BooleanOptionalAction)
357 | parser.add_argument('--beautify', default=False, action=argparse.BooleanOptionalAction)
358 | parser.add_argument("--plot_save_dir", default='./plot/test', help="dir to save generated plot of layouts", type=str)
359 | args = parser.parse_args()
360 |
361 | # prepare data
362 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size,
363 | split='train', shuffle=True)
364 | num_class = train_dataset.num_classes + 1
365 |
366 | # set up model
367 | model_ddpm = Diffusion(num_timesteps=1000, nhead=args.nhead, dim_transformer=args.dim_transformer,
368 | feature_dim=args.feature_dim, seq_dim=num_class + 4, num_layers=args.nlayer,
369 | device=args.device, ddim_num_steps=100)
370 |
371 | state_dict = torch.load(f'./model/{args.dataset}_best.pt', map_location='cpu')
372 | model_ddpm.load_diffusion_net(state_dict)
373 |
374 | if args.experiment == 'uncond':
375 | test_layout_uncond(model_ddpm, batch_size=args.batch_size,
376 | dataset_name=args.dataset, test_plot=args.plot,
377 | save_dir=args.plot_save_dir, beautify=args.beautify)
378 | elif args.experiment in ['c', 'cwh', 'complete']:
379 | test_layout_cond(model_ddpm, batch_size=args.batch_size, cond=args.experiment,
380 | dataset_name=args.dataset, seq_dim=num_class + 4,
381 | test_plot=args.plot, save_dir=args.plot_save_dir, beautify=args.beautify)
382 | elif args.experiment == 'refine':
383 | test_layout_refine(model_ddpm, batch_size=args.batch_size,
384 | dataset_name=args.dataset, seq_dim=num_class + 4,
385 | test_plot=args.plot, save_dir=args.plot_save_dir, beautify=args.beautify)
386 | elif args.experiment == 'all':
387 | test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, test_plot=args.plot,
388 | save_dir=args.plot_save_dir, batch_size=args.batch_size, beautify=args.beautify)
389 | else:
390 | raise Exception('experiment setting undefined')
391 |
392 |
393 |
394 |
395 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from util.datasets.load_data import init_dataset
4 | from util.visualization import save_image
5 | from util.seq_util import sparse_to_dense, pad_until
6 | from model_diffusion import Diffusion
7 | from util.ema import EMA
8 | import argparse
9 | import pickle as pk
10 | import torch.optim as optim
11 | from util.constraint import *
12 | import math
13 | import os
14 | from test import test_all
15 |
16 |
17 | if __name__ == "__main__":
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--nepoch", default=None, help="number of training epochs", type=int)
21 | parser.add_argument("--start_epoch", default=0, help="start epoch", type=int)
22 | parser.add_argument("--batch_size", default=256, help="batch_size", type=int)
23 | parser.add_argument("--lr", default=1e-5, help="learning rate", type=float)
24 | parser.add_argument("--sample_t_max", default=999, help="maximum t in training", type=int)
25 | parser.add_argument("--dataset", default='publaynet',
26 | help="choose from [publaynet, rico13, rico25, magazine, crello]", type=str)
27 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str)
28 | parser.add_argument("--num_workers", default=4, help="num_workers", type=int)
29 | parser.add_argument("--n_save_epoch", default=50, help="number of epochs to do test and save model", type=int)
30 | parser.add_argument("--feature_dim", default=2048, help="feature_dim", type=int)
31 | parser.add_argument("--dim_transformer", default=1024, help="dim_transformer", type=int)
32 | parser.add_argument("--embed_type", default='pos', help="embed type for transformer, pos or time", type=str)
33 | parser.add_argument("--nhead", default=16, help="nhead attention", type=int)
34 | parser.add_argument("--nlayer", default=4, help="nlayer", type=int)
35 | parser.add_argument("--align_weight", default=1, help="the weight of alignment constraint", type=float)
36 | parser.add_argument("--align_type", default='local', help="local or global alignment constraint", type=str)
37 | parser.add_argument("--overlap_weight", default=1, help="the weight of overlap constraint", type=float)
38 | parser.add_argument('--load_pre', default=False, action=argparse.BooleanOptionalAction)
39 | parser.add_argument('--beautify', default=False, action=argparse.BooleanOptionalAction)
40 | parser.add_argument('--enable_test', default=True, action=argparse.BooleanOptionalAction)
41 | parser.add_argument("--gpu_devices", default=[0, 2, 3], type=int, nargs='+', help="")
42 | parser.add_argument("--device", default=None, help="which cuda to use", type=str)
43 | args = parser.parse_args()
44 |
45 | if args.device is None:
46 | gpu_devices = ','.join([str(id) for id in args.gpu_devices])
47 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
48 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
49 | else:
50 | device = args.device
51 |
52 | print(f'load_pre: {args.load_pre}, enable_test: {args.enable_test}, embed: {args.embed_type}')
53 | print(f'dim_transformer: {args.dim_transformer}, n_layers: {args.nlayer}, nhead: {args.nhead}')
54 | print(f'align_type: {args.align_type}, align_weight: {args.align_weight}, overlap_weight: {args.overlap_weight}')
55 | print(f'device: {device}')
56 |
57 | # prepare data
58 | if args.embed_type == 'pos':
59 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size,
60 | split='train', shuffle=True, transform=None)
61 | else:
62 | train_dataset, train_loader = init_dataset(args.dataset, args.data_dir, batch_size=args.batch_size,
63 | split='train', shuffle=True)
64 |
65 | num_class = train_dataset.num_classes + 1
66 |
67 | # set up model
68 | model_ddpm = Diffusion(num_timesteps=1000, nhead=args.nhead, dim_transformer=args.dim_transformer,
69 | feature_dim=args.feature_dim, seq_dim=num_class + 4, num_layers=args.nlayer,
70 | device=device, ddim_num_steps=200)
71 |
72 | if args.load_pre:
73 | # state_dict = torch.load(f'./model/{args.embed_type}_{args.dataset}_1024_recent.pt', map_location='cpu')
74 | state_dict = torch.load(f'./model/{args.dataset}_best.pt', map_location='cpu')
75 | model_ddpm.load_diffusion_net(state_dict)
76 |
77 | if args.device is None:
78 | print('using DataParallel')
79 | model_ddpm.model = nn.DataParallel(model_ddpm.model).to(device)
80 | else:
81 | print('using single gpu')
82 | model_ddpm.to(device)
83 |
84 | if args.load_pre:
85 | fid_best = test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, batch_size=args.batch_size,
86 | beautify=args.beautify)
87 | # fid_best = 1e10
88 | else:
89 | fid_best = 1e10
90 |
91 | # optimizer
92 | optimizer = optim.Adam(model_ddpm.model.parameters(), lr=args.lr, weight_decay=0.0, betas=(0.9, 0.999), amsgrad=False, eps=1e-08)
93 | mse_loss = nn.MSELoss()
94 |
95 | ema_helper = EMA(mu=0.9999)
96 | ema_helper.register(model_ddpm.model)
97 |
98 |
99 | for epoch in range(args.start_epoch, args.nepoch):
100 | model_ddpm.model.train()
101 |
102 | if (epoch) % args.n_save_epoch == 0 and epoch != 0:
103 |
104 | # model_path = f'./model/{args.embed_type}_{args.dataset}_1024_recent.pt'
105 | # states = model_ddpm.model.module.state_dict()
106 | # torch.save(states, model_path)
107 |
108 | if args.enable_test:
109 | fid_total = test_all(model_ddpm, dataset_name=args.dataset, seq_dim=num_class + 4, batch_size=args.batch_size, beautify=False)
110 | # print(f'previous best fid: {fid_best}')
111 | # if fid_total < fid_best:
112 | # # model_path = f'./model/{args.embed_type}_{args.dataset}_1024_lowest.pt'
113 | # # torch.save(states, model_path)
114 | # fid_best = fid_total
115 | # print('New lowest fid model, saved')
116 |
117 | with tqdm(enumerate(train_loader), total=len(train_loader), desc=f'train diffusion epoch {epoch}', ncols=200) as pbar:
118 |
119 | for i, data in pbar:
120 | bbox, label, _, mask = sparse_to_dense(data)
121 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)
122 |
123 | label, bbox, mask = label.to(device), bbox.to(device), mask.to(device)
124 |
125 | # shift to center
126 | bbox_in = 2 * (bbox - 0.5).to(device)
127 |
128 | # set padded positions to the empty-class label
129 | label[mask==False] = num_class - 1
130 |
131 | label_oh = torch.nn.functional.one_hot(label, num_classes=num_class).to(device)
132 |
133 | # concat label with bbox and get a 10 dim
134 | layout_input = torch.cat((label_oh, bbox_in), dim=2).to(device)
135 |
136 | t = model_ddpm.sample_t([bbox.shape[0]], t_max=args.sample_t_max)
137 | t_all = torch.cat([t, t, t, t], dim=0)
138 |
139 | eps_theta, e, b_0_reparam = model_ddpm.forward_t(layout_input, t=t, real_mask=mask, reparam=True)
140 |
141 | # compute b_0 reparameterization
142 | bbox_rep = torch.clamp(b_0_reparam[:, :, num_class:], min=-1, max=1) / 2 + 0.5
143 | mask_4 = torch.cat([mask, mask, mask, mask], dim=0)
144 | bbox_4 = torch.cat([bbox, bbox, bbox, bbox], dim=0)
145 |
146 | # compute alignment loss
147 | if args.align_type == 'global':
148 | # global alignment
149 | align_loss = mean_alignment_error(bbox_rep, bbox_4, mask_4)
150 | else:
151 | # local alignment
152 | _, align_loss = layout_alignment(bbox_rep, mask_4, xy_only=False)
153 | align_loss = 20 * align_loss
154 |
155 | # compute piou and pdist
156 | piou = PIoU_xywh(bbox_rep, mask=mask_4.to(torch.float32), xy_only=False)
157 | pdist = Pdist(bbox_rep)
158 |
159 | # compute piou loss with temporal weight
160 | overlap_loss = torch.mean(piou, dim=[1, 2]) + torch.mean(piou.ne(0) * torch.exp(-pdist), dim=[1, 2])
161 | # overlap_loss = torch.mean(piou, dim=[1, 2])
162 |
163 | # reconstruction loss
164 | layout_input_all = torch.cat([layout_input, layout_input, layout_input, layout_input], dim=0)
165 | reconstruct_loss = mse_loss(layout_input_all[:, :, num_class:], b_0_reparam[:, :, num_class:])
166 | # _, giou = GIoU_xywh(b_0_reparam[:, :, num_class:], layout_input_all[:, :, num_class:])
167 | # reconstruct_loss = (1 - 1 * torch.mean(giou))
168 |
169 | # combine constraints with temporal weight
170 | weight = constraint_temporal_weight(t_all, schedule='const')
171 | constraint_loss = torch.mean((args.align_weight * align_loss + args.overlap_weight * overlap_loss)
172 | * weight)
173 |
174 | # compute diffusion loss
175 | diffusion_loss = mse_loss(e, eps_theta)
176 |
177 | # total loss
178 | loss = diffusion_loss + constraint_loss + reconstruct_loss
179 |
180 | pbar.set_postfix({'diffusion': diffusion_loss.item(), 'align': torch.mean(align_loss).item(),
181 | 'overlap': torch.mean(overlap_loss).item(), 'reconstruct': reconstruct_loss.item()})
182 |
183 | # optimize
184 | optimizer.zero_grad()
185 | loss.backward()
186 | torch.nn.utils.clip_grad_norm_(model_ddpm.model.parameters(), 1.0)
187 | optimizer.step()
188 | ema_helper.update(model_ddpm.model)
189 |
190 |
--------------------------------------------------------------------------------
/train_fid_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torchvision.transforms as T
10 | from torch_geometric.utils import to_dense_batch
11 | from util.data_util import AddNoiseToBBox, LexicographicOrder
12 | from util.fid.model import FIDNetV3
13 | from util.datasets.load_data import init_dataset
14 |
15 |
16 | def save_checkpoint(state, is_best, out_dir):
17 | out_path = Path(out_dir) / "checkpoint.pth.tar"
18 | torch.save(state, out_path)
19 |
20 | if is_best:
21 | best_path = Path(out_dir) / "model_best.pth.tar"
22 | shutil.copyfile(out_path, best_path)
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--batch_size", type=int, default=64, help="input batch size")
28 | parser.add_argument("--dataset", default='crello', help="choose from [magazine, rico13, rico25, publaynet, crello]",
29 | type=str)
30 | parser.add_argument("--data_dir", default='./datasets', help="dir of datasets", type=str)
31 | parser.add_argument("--device", default='cpu', help="which GPU to use", type=str)
32 | parser.add_argument("--out_dir", type=str, default="./fid/FIDNetV3/")
33 | parser.add_argument(
34 | "--iteration",
35 | type=int,
36 | default=int(2e5),
37 | help="number of iterations to train for",
38 | )
39 | parser.add_argument(
40 | "--lr", type=float, default=3e-4, help="learning rate, default=3e-4"
41 | )
42 | parser.add_argument("--seed", type=int, help="manual seed")
43 | args = parser.parse_args()
44 | print(args)
45 |
46 | prefix = "FIDNetV3"
47 | out_dir = Path(os.path.join(args.out_dir, args.dataset + '-max25'))
48 | out_dir.mkdir(parents=True, exist_ok=True)
49 |
50 | transform = T.Compose(
51 | [
52 | T.RandomApply([AddNoiseToBBox()], 0.5),
53 | LexicographicOrder(),
54 | ]
55 | )
56 |
57 | train_dataset, train_dataloader = init_dataset(args.dataset, args.data_dir, batch_size=64, split='train',
58 | transform=transform, shuffle=True)
59 | val_dataset, val_dataloader = init_dataset(args.dataset, args.data_dir, batch_size=64, split='test',
60 | transform=transform, shuffle=False)
61 | print('num_classes', train_dataset.num_classes)
62 |
63 | device = torch.device(args.device if torch.cuda.is_available() else "cpu")
64 | model = FIDNetV3(num_label=train_dataset.num_classes, max_bbox=25).to(device)
65 |
66 | # setup optimizer
67 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
68 |
69 | criterion_bce = nn.BCEWithLogitsLoss(reduction="none")
70 | criterion_label = nn.CrossEntropyLoss(reduction="none")
71 | criterion_bbox = nn.MSELoss(reduction="none")
72 |
73 | def proc_batch(batch):
74 | batch = batch.to(device)
75 | bbox, _ = to_dense_batch(batch.x, batch.batch)
76 | label, mask = to_dense_batch(batch.y, batch.batch)
77 | padding_mask = ~mask
78 |
79 | is_real = batch.attr["NoiseAdded"].float()
80 | return bbox, label, padding_mask, mask, is_real
81 |
82 | iteration = 0
83 | best_loss = 1e8
84 | max_epoch = args.iteration * args.batch_size / len(train_dataset)
85 | max_epoch = torch.ceil(torch.tensor(max_epoch)).int().item()
86 | for epoch in range(max_epoch):
87 | model.train()
88 | train_loss = {
89 | "Loss_BCE": 0,
90 | "Loss_Label": 0,
91 | "Loss_BBox": 0,
92 | }
93 |
94 | for i, batch in enumerate(train_dataloader):
95 |
96 | bbox, label, padding_mask, mask, is_real = proc_batch(batch)
97 | model.zero_grad()
98 |
99 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask)
100 |
101 | loss_bce = criterion_bce(logit, is_real)
102 | loss_label = criterion_label(logit_cls[mask], label[mask])
103 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1)
104 | loss = loss_bce.mean() + loss_label.mean() + 10 * loss_bbox.mean()
105 | loss.backward()
106 |
107 | optimizer.step()
108 |
109 | loss_bce_mean = loss_bce.mean().item()
110 | train_loss["Loss_BCE"] += loss_bce.sum().item()
111 | loss_label_mean = loss_label.mean().item()
112 | train_loss["Loss_Label"] += loss_label.sum().item()
113 | loss_bbox_mean = loss_bbox.mean().item()
114 | train_loss["Loss_BBox"] += loss_bbox.sum().item()
115 |
116 | if i % 100 == 0:
117 | log_prefix = f"[{epoch}/{max_epoch}][{i}/{len(train_dataset) // args.batch_size}]"
118 | log = f"Loss: {loss.item():E}\tBCE: {loss_bce_mean:E}\tLabel: {loss_label_mean:E}\tBBox: {loss_bbox_mean:E}"
119 | print(f"{log_prefix}\t{log}")
120 |
121 | iteration += 1
122 |
123 | for key in train_loss.keys():
124 | train_loss[key] /= len(train_dataset)
125 |
126 | model.eval()
127 | with torch.no_grad():
128 | val_loss = {
129 | "Loss_BCE": 0,
130 | "Loss_Label": 0,
131 | "Loss_BBox": 0,
132 | }
133 |
134 | for i, batch in enumerate(val_dataloader):
135 | bbox, label, padding_mask, mask, is_real = proc_batch(batch)
136 |
137 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask)
138 |
139 | loss_bce = criterion_bce(logit, is_real)
140 | loss_label = criterion_label(logit_cls[mask], label[mask])
141 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1)
142 |
143 | val_loss["Loss_BCE"] += loss_bce.sum().item()
144 | val_loss["Loss_Label"] += loss_label.sum().item()
145 | val_loss["Loss_BBox"] += loss_bbox.sum().item()
146 |
147 | for key in val_loss.keys():
148 | val_loss[key] /= len(val_dataset)
149 |
150 | tag_scalar_dict = {
151 | "train": sum(train_loss.values()),
152 | "val": sum(val_loss.values()),
153 | }
154 | for key in train_loss.keys():
155 | tag_scalar_dict = {"train": train_loss[key], "val": val_loss[key]}
156 |
157 | # do checkpointing
158 | val_loss = sum(val_loss.values())
159 | is_best = val_loss < best_loss
160 | best_loss = min(val_loss, best_loss)
161 |
162 | save_checkpoint(
163 | {
164 | "epoch": epoch + 1,
165 | "state_dict": model.state_dict(),
166 | "best_loss": best_loss,
167 | "optimizer": optimizer.state_dict(),
168 | },
169 | is_best,
170 | out_dir,
171 | )
172 |
173 |
174 | if __name__ == "__main__":
175 | main()
176 |
--------------------------------------------------------------------------------
/util/backbone.py:
--------------------------------------------------------------------------------
1 | # Implement TransformerEncoder that can consider timesteps as optional args for Diffusion.
2 |
3 | import copy
4 | import math
5 | from typing import Callable, Optional, Union
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from einops.layers.torch import Rearrange
10 | from torch import Tensor, nn
11 |
12 |
13 | def _get_clones(module, N):
14 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
15 |
16 |
17 | def _gelu2(x):
18 | return x * F.sigmoid(1.702 * x)
19 |
20 |
21 | def _get_activation_fn(activation):
22 | if activation == "relu":
23 | return F.relu
24 | elif activation == "gelu":
25 | return F.gelu
26 | elif activation == "gelu2":
27 | return _gelu2
28 | else:
29 | raise RuntimeError(
30 | "activation should be relu/gelu/gelu2, not {}".format(activation)
31 | )
32 |
33 |
34 | class SinusoidalPosEmb(nn.Module):
35 | def __init__(self, num_steps: int, dim: int, rescale_steps: int = 4000):
36 | super().__init__()
37 | self.dim = dim
38 | self.num_steps = float(num_steps)
39 | self.rescale_steps = float(rescale_steps)
40 |
41 | def forward(self, x: Tensor):
42 | x = x / self.num_steps * self.rescale_steps
43 | device = x.device
44 | half_dim = self.dim // 2
45 | emb = math.log(10000) / (half_dim - 1)
46 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
47 | emb = x[:, None] * emb[None, :]
48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
49 | return emb
50 |
51 |
52 | class _AdaNorm(nn.Module):
53 | def __init__(
54 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs"
55 | ):
56 | super().__init__()
57 | if "abs" in emb_type:
58 | self.emb = SinusoidalPosEmb(max_timestep, n_embd)
59 | elif "mlp" in emb_type:
60 | self.emb = nn.Sequential(
61 | Rearrange("b -> b 1"),
62 | nn.Linear(1, n_embd // 2),
63 | nn.ReLU(),
64 | nn.Linear(n_embd // 2, n_embd),
65 | )
66 | else:
67 | self.emb = nn.Embedding(max_timestep, n_embd)
68 | self.silu = nn.SiLU()
69 | self.linear = nn.Linear(n_embd, n_embd * 2)
70 |
71 |
72 | class AdaLayerNorm(_AdaNorm):
73 | def __init__(
74 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs"
75 | ):
76 | super().__init__(n_embd, max_timestep, emb_type)
77 | self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)
78 |
79 | def forward(self, x: Tensor, timestep: int):
80 |
81 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1)
82 | scale, shift = torch.chunk(emb, 2, dim=2)
83 | x = self.layernorm(x) * (1 + scale) + shift
84 | return x
85 |
86 |
87 | class AdaInsNorm(_AdaNorm):
88 | def __init__(
89 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs"
90 | ):
91 | super().__init__(n_embd, max_timestep, emb_type)
92 | self.instancenorm = nn.InstanceNorm1d(n_embd)
93 |
94 | def forward(self, x, timestep):
95 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1)
96 | scale, shift = torch.chunk(emb, 2, dim=2)
97 | x = (
98 | self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale)
99 | + shift
100 | )
101 | return x
102 |
103 |
104 | class Block(nn.Module):
105 | """an unassuming Transformer block"""
106 |
107 | def __init__(
108 | self,
109 | d_model=512,
110 | nhead=8,
111 | dim_feedforward: int = 2048,
112 | dropout: float = 0.0,
113 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
114 | batch_first: bool = True,
115 | norm_first: bool = True,
116 | device=None,
117 | dtype=None,
118 | # extension for diffusion
119 | diffusion_step: int = 100,
120 | timestep_type: str = 'adalayernorm',
121 | ) -> None:
122 | super().__init__()
123 |
124 | assert norm_first # minGPT-based implementations are designed for prenorm only
125 | assert timestep_type in [
126 | None,
127 | "adalayernorm",
128 | "adainnorm",
129 | "adalayernorm_abs",
130 | "adainnorm_abs",
131 | "adalayernorm_mlp",
132 | "adainnorm_mlp",
133 | ]
134 | layer_norm_eps = 1e-5 # fixed
135 |
136 | self.norm_first = norm_first
137 | self.diffusion_step = diffusion_step
138 | self.timestep_type = timestep_type
139 |
140 | factory_kwargs = {"device": device, "dtype": dtype}
141 | self.self_attn = torch.nn.MultiheadAttention(
142 | d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
143 | )
144 |
145 | # Implementation of Feedforward model
146 | self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
147 | self.dropout = nn.Dropout(dropout)
148 | self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
149 |
150 | if timestep_type is not None:
151 | if "adalayernorm" in timestep_type:
152 | self.norm1 = AdaLayerNorm(d_model, diffusion_step, timestep_type)
153 | elif "adainnorm" in timestep_type:
154 | self.norm1 = AdaInsNorm(d_model, diffusion_step, timestep_type)
155 | else:
156 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
157 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
158 | self.dropout1 = nn.Dropout(dropout)
159 | self.dropout2 = nn.Dropout(dropout)
160 |
161 | if isinstance(activation, str):
162 | self.activation = _get_activation_fn(activation)
163 | else:
164 | self.activation = activation
165 |
166 | def forward(
167 | self,
168 | src: Tensor,
169 | src_mask: Optional[Tensor] = None,
170 | src_key_padding_mask: Optional[Tensor] = None,
171 | timestep: Tensor = None,
172 | ) -> Tensor:
173 | x = src
174 | if self.norm_first:
175 | if self.timestep_type is not None:
176 | x = self.norm1(x, timestep)
177 | else:
178 | x = self.norm1(x)
179 | x = x + self._sa_block(x, src_mask, src_key_padding_mask)
180 | x = x + self._ff_block(self.norm2(x))
181 | else:
182 | x = x + self._sa_block(x, src_mask, src_key_padding_mask)
183 | if self.timestep_type is not None:
184 | x = self.norm1(x, timestep)
185 | else:
186 | x = self.norm1(x)
187 | x = self.norm2(x + self._ff_block(x))
188 |
189 | return x
190 |
191 | # self-attention block
192 | def _sa_block(
193 | self,
194 | x: Tensor,
195 | attn_mask: Optional[Tensor],
196 | key_padding_mask: Optional[Tensor],
197 | ) -> Tensor:
198 | x = self.self_attn(
199 | x,
200 | x,
201 | x,
202 | attn_mask=attn_mask,
203 | key_padding_mask=key_padding_mask,
204 | need_weights=False,
205 | )[0]
206 | return self.dropout1(x)
207 |
208 | # feed forward block
209 | def _ff_block(self, x: Tensor) -> Tensor:
210 | x = self.linear2(self.dropout(self.activation(self.linear1(x))))
211 | return self.dropout2(x)
212 |
213 |
214 | class TransformerEncoder(nn.Module):
215 | """
216 | Close to torch.nn.TransformerEncoder, but with timestep support for diffusion
217 | """
218 |
219 | __constants__ = ["norm"]
220 |
221 | def __init__(self, num_layers=4, dim_seq=10, dim_transformer=512, nhead=8, dim_feedforward=2048,
222 | diffusion_step=100, device='cuda'):
223 | super(TransformerEncoder, self).__init__()
224 |
225 | self.pos_encoder = SinusoidalPosEmb(num_steps=25, dim=dim_transformer).to(device)
226 | pos_i = torch.tensor([i for i in range(25)]).to(device)
227 | self.pos_embed = self.pos_encoder(pos_i)
228 |
229 | self.layer_in = nn.Linear(in_features=dim_seq, out_features=dim_transformer).to(device)
230 | encoder_layer = Block(d_model=dim_transformer, nhead=nhead, dim_feedforward=dim_feedforward, diffusion_step=diffusion_step)
231 | self.layers = _get_clones(encoder_layer, num_layers).to(device)
232 | self.num_layers = num_layers
233 | self.layer_out = nn.Linear(in_features=dim_transformer, out_features=dim_seq).to(device)
234 |
235 | def forward(
236 | self,
237 | src: Tensor,
238 | mask: Optional[Tensor] = None,
239 | src_key_padding_mask: Optional[Tensor] = None,
240 |         timestep: Optional[Tensor] = None,
241 | ) -> Tensor:
242 | output = src
243 |
244 | output = self.layer_in(output)
245 | output = F.softplus(output)
246 | output = output + self.pos_embed
247 |
248 | for i, mod in enumerate(self.layers):
249 | output = mod(
250 | output,
251 | src_mask=mask,
252 | src_key_padding_mask=src_key_padding_mask,
253 | timestep=timestep,
254 | )
255 |
256 | if i < self.num_layers - 1:
257 | output = F.softplus(output)
258 |
259 | output = self.layer_out(output)
260 |
261 | return output
262 |
263 |
--------------------------------------------------------------------------------
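A short usage sketch of the encoder above (not part of the repository). It assumes the default `Block` settings consume a per-sample diffusion timestep index, and uses toy dimensions; the trained models use the hyper-parameters listed in the README.
```
import torch
from util.backbone import TransformerEncoder

# toy dimensions for illustration; dim_seq is the per-element feature size,
# and 25 is the fixed maximum number of layout elements
model = TransformerEncoder(num_layers=2, dim_seq=10, dim_transformer=128,
                           nhead=8, dim_feedforward=256,
                           diffusion_step=100, device='cpu')

x = torch.randn(2, 25, 10)        # (batch, elements, dim_seq)
t = torch.randint(0, 100, (2,))   # one diffusion timestep index per sample
out = model(x, timestep=t)        # output sequence, same shape as the input
print(out.shape)                  # torch.Size([2, 25, 10])
```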
/util/constraint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import numpy as np
6 | import einops
7 | import imageio
8 | torch.manual_seed(0)
9 | import math
10 |
11 |
12 | def xywh_2_ltrb(bbox_xywh):
13 |
14 | bbox_ltrb = torch.zeros(bbox_xywh.shape).to(bbox_xywh.device)
15 | bbox_xy = torch.abs(bbox_xywh[:, :, :2])
16 | bbox_wh = torch.abs(bbox_xywh[:, :, 2:])
17 | bbox_ltrb[:, :, :2] = bbox_xy - 0.5 * bbox_wh
18 | bbox_ltrb[:, :, 2:] = bbox_xy + 0.5 * bbox_wh
19 | return bbox_ltrb
20 |
21 |
22 | def ltrb_2_xywh(bbox_ltrb):
23 | bbox_xywh = torch.zeros(bbox_ltrb.shape)
24 | bbox_wh = torch.abs(bbox_ltrb[:, :, 2:] - bbox_ltrb[:, :, :2])
25 | bbox_xy = bbox_ltrb[:, :, :2] + 0.5 * bbox_wh
26 | bbox_xywh[:, :, :2] = bbox_xy
27 | bbox_xywh[:, :, 2:] = bbox_wh
28 | return bbox_xywh
29 |
30 |
31 | def xywh_to_ltrb_split(bbox):
32 | xc, yc, w, h = bbox
33 | x1 = xc - w / 2
34 | y1 = yc - h / 2
35 | x2 = xc + w / 2
36 | y2 = yc + h / 2
37 | return [x1, y1, x2, y2]
38 |
39 |
40 | def rand_bbox_ltrb(batch_shape):
41 |
42 | bbox_lt = torch.rand(batch_shape + [2])
43 | bbox_wh_max = 1 - bbox_lt
44 | bbox_wh_weight = torch.rand(batch_shape).unsqueeze(-1).repeat([1 for _ in range(len(batch_shape))] + [2])
45 |
46 | bbox_wh = 1 * bbox_wh_weight * bbox_wh_max
47 | bbox_rb = bbox_lt + bbox_wh
48 |
49 | bbox = torch.cat([bbox_lt, bbox_rb], dim=-1)
50 | return bbox
51 |
52 |
53 | def rand_bbox_xywh(batch_shape):
54 |
55 | bbox_ltrb = rand_bbox_ltrb(batch_shape)
56 | bbox_xywh = ltrb_2_xywh(bbox_ltrb)
57 | return bbox_xywh
58 |
59 |
60 | def GIoU_ltrb(bbox_1, bbox_2):
61 |
62 | # step 1 calculate area of bbox_1 and bbox_2
63 | a_1 = (bbox_1[:, :, 2] - bbox_1[:, :, 0]) * (bbox_1[:, :, 3] - bbox_1[:, :, 1])
64 | a_2 = (bbox_2[:, :, 2] - bbox_2[:, :, 0]) * (bbox_2[:, :, 3] - bbox_2[:, :, 1])
65 |
66 | # step 2.1 compute intersection I bbox
67 | bbox = torch.cat([bbox_1.unsqueeze(-1), bbox_2.unsqueeze(-1)], dim=-1)
68 | bbox_I_lt = torch.max(bbox, dim=-1)[0][:, :, :2]
69 | bbox_I_rb = torch.min(bbox, dim=-1)[0][:, :, 2:]
70 |
71 | # step 2.2 compute area of I
72 | a_I = F.relu((bbox_I_rb[:, :, 0] - bbox_I_lt[:, :, 0])) * F.relu((bbox_I_rb[:, :, 1] - bbox_I_lt[:, :, 1]))
73 |
74 | # step 3.1 compute smallest enclosing box C
75 | bbox_C_lt = torch.min(bbox, dim=-1)[0][:, :, :2]
76 | bbox_C_rb = torch.max(bbox, dim=-1)[0][:, :, 2:]
77 |
78 | # step 3.2 compute area of C
79 | a_C = (bbox_C_rb[:, :, 0] - bbox_C_lt[:, :, 0]) * (bbox_C_rb[:, :, 1] - bbox_C_lt[:, :, 1])
80 |
81 | # step 4 compute IoU
82 | a_U = (a_1 + a_2 - a_I)
83 | iou = a_I / (a_U + 1e-10)
84 |
85 | # step 5 copute giou
86 | giou = iou - (a_C - a_U) / (a_C + 1e-10)
87 |
88 | return iou, giou
89 |
90 |
91 | def GIoU_xywh(bbox_pred, bbox_true, xy_only=False):
92 |
93 | if xy_only:
94 | wh = torch.abs(bbox_pred[:, :, 2:].clone().detach())
95 | bbox = torch.cat([bbox_pred[:, :, :2], wh], dim=2)
96 | else:
97 | bbox = bbox_pred
98 |
99 | bbox_pred_ltrb = xywh_2_ltrb(torch.abs(bbox))
100 | bbox_true_ltrb = xywh_2_ltrb(torch.abs(bbox_true))
101 | return GIoU_ltrb(bbox_pred_ltrb, bbox_true_ltrb)
102 |
103 |
104 | def PIoU_ltrb(bbox_ltrb, mask=None):
105 |
106 | n_box = bbox_ltrb.shape[1]
107 | device = bbox_ltrb.device
108 |
109 | # compute area of bboxes
110 | area_bbox = (bbox_ltrb[:, :, 2] - bbox_ltrb[:, :, 0]) * (bbox_ltrb[:, :, 3] - bbox_ltrb[:, :, 1])
111 | area_bbox_psum = area_bbox.unsqueeze(-1) + area_bbox.unsqueeze(-2)
112 |
113 | # compute pairwise intersection
114 | x1y1 = bbox_ltrb[:, :, [0, 1]]
115 | x1y1 = torch.swapaxes(x1y1, 1, 2)
116 | x1y1_I = torch.max(x1y1.unsqueeze(-1), x1y1.unsqueeze(-2))
117 |
118 | x2y2 = bbox_ltrb[:, :, [2, 3]]
119 | x2y2 = torch.swapaxes(x2y2, 1, 2)
120 | x2y2_I = torch.min(x2y2.unsqueeze(-1), x2y2.unsqueeze(-2))
121 | # compute area of Is
122 | wh_I = F.relu(x2y2_I - x1y1_I)
123 | area_I = wh_I[:, 0, :, :] * wh_I[:, 1, :, :]
124 |
125 | # compute pairwise IoU
126 | piou = area_I / (area_bbox_psum - area_I + 1e-10)
127 |
128 | piou.masked_fill_(torch.eye(n_box, n_box).to(torch.bool).to(device), 0)
129 |
130 | if mask is not None:
131 | mask = mask.unsqueeze(2)
132 | select_mask = torch.matmul(mask, torch.transpose(mask, dim0=1, dim1=2))
133 | piou = piou * select_mask.to(device)
134 |
135 | return piou
136 |
137 |
138 | def PIoU_xywh(bbox_xywh, mask=None, xy_only=True):
139 |
140 | if xy_only:
141 | wh = torch.abs(bbox_xywh[:, :, 2:].clone().detach())
142 | bbox = torch.cat([bbox_xywh[:, :, :2], wh], dim=2)
143 | bbox_ltrb = xywh_2_ltrb(bbox)
144 | else:
145 | bbox_ltrb = xywh_2_ltrb(bbox_xywh)
146 |
147 | return PIoU_ltrb(bbox_ltrb, mask)
148 |
149 |
150 | def Pdist(bbox):
151 | xy = bbox[:, :, :2]
152 | pdist_m = torch.cdist(xy, xy, p=2)
153 |
154 | return pdist_m
155 |
156 | def layout_alignment(bbox, mask, xy_only=False, mode='all'):
157 | """
158 | alignment metrics in Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020)
159 | https://arxiv.org/abs/2009.05284
160 | """
161 |
162 | if xy_only:
163 | wh = torch.abs(bbox[:, :, 2:].clone().detach())
164 | bbox = torch.cat([bbox[:, :, :2], wh], dim=2)
165 |
166 | bbox = bbox.permute(2, 0, 1)
167 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox)
168 | xc, yc = bbox[0], bbox[1]
169 | if mode == 'all':
170 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1)
171 | elif mode == 'partial':
172 | X = torch.stack([xl, xc, yt, yb], dim=1)
173 | else:
174 | raise Exception('mode must be all or partial')
175 |
176 | X = X.unsqueeze(-1) - X.unsqueeze(-2)
177 | idx = torch.arange(X.size(2), device=X.device)
178 | X[:, :, idx, idx] = 1.0
179 | X = X.abs().permute(0, 2, 1, 3)
180 | X[~mask] = 1.0
181 |
182 | X = X.min(-1).values.min(-1).values
183 | X.masked_fill_(X.eq(1.0), 0.0)
184 | X = -torch.log(1 - X)
185 |
186 | score = einops.reduce(X, "b s -> b", reduction="sum")
187 | score_normalized = score / einops.reduce(mask, "b s -> b", reduction="sum")
188 | score_normalized[torch.isnan(score_normalized)] = 0.0
189 |
190 | return score, score_normalized
191 |
192 |
193 | def layout_alignment_matrix(bbox, mask):
194 | bbox = bbox.permute(2, 0, 1)
195 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox)
196 | xc, yc = bbox[0], bbox[1]
197 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1)
198 | X = X.unsqueeze(-1) - X.unsqueeze(-2)
199 | idx = torch.arange(X.size(2), device=X.device)
200 | X[:, :, idx, idx] = 1.0
201 | X = X.abs().permute(0, 2, 1, 3)
202 | X[~mask] = 1.0
203 | return X
204 |
205 |
206 | def mean_alignment_error(bbox_target, bbox_true, mask_true, th=1e-5, xy_only=False):
207 | """
208 |     mean coordinate-difference error at positions that are aligned in the real data,
209 |     computed over a whole batch
210 |     th: threshold below which a coordinate difference in the real data counts as aligned
211 |     mask_true: valid-element mask of the real layouts (one boolean per element)
212 | """
213 |
214 | if xy_only:
215 | wh = torch.abs(bbox_target[:, :, 2:].clone().detach())
216 | bbox = torch.cat([bbox_target[:, :, :2], wh], dim=2)
217 | else:
218 | bbox = bbox_target
219 |
220 | align_score_target = layout_alignment_matrix(bbox, mask_true)
221 |
222 | align_score_true = layout_alignment_matrix(bbox_true, mask_true)
223 | align_mask = (align_score_true < th).clone().detach()
224 |
225 | selected_difference = align_score_target * align_mask
226 |
227 | mae = einops.reduce(selected_difference, "n a b c -> n", reduction="sum")
228 |
229 | return mae
230 |
231 |
232 | def constraint_temporal_weight(t, schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2):
233 | if schedule == "linear":
234 | w = 1 - torch.linspace(start, end, num_timesteps)
235 | elif schedule == "const":
236 | w = 1 - 4 * end * torch.ones(num_timesteps)
237 | elif schedule == "quad":
238 | w = 1 - torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
239 | elif schedule == "jsd":
240 | w = 1 - 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
241 | elif schedule == "sigmoid":
242 | betas = torch.linspace(-6, 6, num_timesteps)
243 | w = 1 - torch.sigmoid(betas) * (end - start) + start
244 | elif schedule == "cosine" or schedule == "cosine_reverse":
245 | max_beta = 0.999
246 | cosine_s = 0.008
247 | w = 1 - torch.tensor(
248 | [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / (
249 | math.cos((i / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2), max_beta) for i in
250 | range(num_timesteps)])
251 | elif schedule == "cosine_anneal":
252 | w = 1 - torch.tensor(
253 | [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in
254 | range(num_timesteps)])
255 |
256 | weight = w.cumprod(dim=0).to(t.device)
257 |
258 | return weight[t]
259 |
260 |
261 | def post_process(bbox, mask_generated, xy_only=False, w_o=1, w_a=1):
262 |
263 | # print('beautify')
264 | if torch.sum(mask_generated) == 1:
265 | return bbox, mask_generated
266 |
267 | if xy_only:
268 | wh = torch.abs(bbox[:, :, 2:].clone().detach())
269 | bbox_in = torch.cat([bbox[:, :, :2], wh], dim=2)
270 | else:
271 | bbox_in = bbox
272 |
273 |     bbox_in[:, :, [0, 2]] *= 10 / 4  # scale x coordinates up before optimization (undone below)
274 |     bbox_in[:, :, [1, 3]] *= 10 / 6  # scale y coordinates up before optimization (undone below)
275 |
276 | bbox_initial = bbox_in.clone().detach()
277 | mse_loss = nn.MSELoss()
278 |
279 | bbox_p = nn.Parameter(bbox_in)
280 | optimizer = optim.Adam([bbox_p], lr=1e-4, weight_decay=0.0, betas=(0.9, 0.999), amsgrad=False, eps=1e-08)
281 | with torch.enable_grad():
282 | for i in range(1000):
283 | bbox_1 = torch.relu(bbox_p)
284 | align_score_target = layout_alignment_matrix(bbox_1, mask_generated)
285 | align_mask = (align_score_target < 1/64).clone().detach()
286 | align_loss = torch.mean(align_score_target * align_mask)
287 |
288 | piou_m = PIoU_xywh(bbox_1, mask=mask_generated.to(torch.float32), xy_only=True)
289 | piou = torch.mean(piou_m)
290 |
291 | mse = mse_loss(bbox_1, bbox_initial)
292 | loss = 1 * mse + w_a * align_loss + w_o * piou
293 | optimizer.zero_grad()
294 | loss.backward()
295 | torch.nn.utils.clip_grad_norm_([bbox_p], 1.0)
296 | optimizer.step()
297 |
298 | a, _ = torch.min(bbox_1[:, :, [2, 3]], dim=2)
299 | mask_generated = mask_generated * (a > 0.01)
300 |
301 | bbox_out = torch.relu(bbox_p)
302 |     bbox_out[:, :, [0, 2]] *= 4 / 10  # undo the x scaling applied above
303 |     bbox_out[:, :, [1, 3]] *= 6 / 10  # undo the y scaling applied above
304 |
305 | return bbox_out, mask_generated
306 |
307 |
308 | if __name__ == "__main__":
309 |
310 | w = constraint_temporal_weight(torch.tensor([225]), schedule="const")
311 | print(w)
--------------------------------------------------------------------------------
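A self-contained sketch of the geometric scores defined above, on random layouts of 2 samples with 5 boxes each (sizes chosen only for illustration):
```
import torch
from util.constraint import rand_bbox_xywh, GIoU_xywh, PIoU_xywh, layout_alignment

# two random layouts in (xc, yc, w, h) format
bbox_pred = rand_bbox_xywh([2, 5])
bbox_true = rand_bbox_xywh([2, 5])
mask = torch.ones(2, 5, dtype=torch.bool)       # all 5 elements are valid

iou, giou = GIoU_xywh(bbox_pred, bbox_true)     # per-box IoU / GIoU, shape (2, 5)
piou = PIoU_xywh(bbox_pred, mask=mask.float())  # pairwise overlap matrix, shape (2, 5, 5)
score, score_norm = layout_alignment(bbox_pred, mask)  # alignment score per sample
print(giou.shape, piou.shape, score_norm.shape)
```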
/util/data_util.py:
--------------------------------------------------------------------------------
1 | import random
2 | from enum import IntEnum
3 | from itertools import combinations, product
4 | from typing import List, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as T
9 | from torch import BoolTensor, FloatTensor, LongTensor
10 | from torch_geometric.utils import to_dense_batch
11 | from .constraint import xywh_to_ltrb_split
12 |
13 |
14 | class RelSize(IntEnum):
15 | UNKNOWN = 0
16 | SMALLER = 1
17 | EQUAL = 2
18 | LARGER = 3
19 |
20 |
21 | class RelLoc(IntEnum):
22 | UNKNOWN = 4
23 | LEFT = 5
24 | TOP = 6
25 | RIGHT = 7
26 | BOTTOM = 8
27 | CENTER = 9
28 |
29 |
30 | REL_SIZE_ALPHA = 0.1
31 |
32 |
33 | def detect_size_relation(b1, b2):
34 | a1 = b1[2] * b1[3]
35 | a2 = b2[2] * b2[3]
36 | alpha = REL_SIZE_ALPHA
37 | if (1 - alpha) * a1 < a2 < (1 + alpha) * a1:
38 | return RelSize.EQUAL
39 | elif a1 < a2:
40 | return RelSize.LARGER
41 | else:
42 | return RelSize.SMALLER
43 |
44 |
45 | def detect_loc_relation(b1, b2, is_canvas=False):
46 | if is_canvas:
47 | yc = b2[1]
48 | if yc < 1.0 / 3:
49 | return RelLoc.TOP
50 | elif yc < 2.0 / 3:
51 | return RelLoc.CENTER
52 | else:
53 | return RelLoc.BOTTOM
54 |
55 | else:
56 | l1, t1, r1, b1 = xywh_to_ltrb_split(b1)
57 | l2, t2, r2, b2 = xywh_to_ltrb_split(b2)
58 |
59 | if b2 <= t1:
60 | return RelLoc.TOP
61 | elif b1 <= t2:
62 | return RelLoc.BOTTOM
63 | elif r2 <= l1:
64 | return RelLoc.LEFT
65 | elif r1 <= l2:
66 | return RelLoc.RIGHT
67 | else:
68 | # might not be necessary
69 | return RelLoc.CENTER
70 |
71 |
72 | def get_rel_text(rel, canvas=False):
73 | if type(rel) == RelSize:
74 | index = rel - RelSize.UNKNOWN - 1
75 | if canvas:
76 | return [
77 | "within canvas",
78 | "spread over canvas",
79 | "out of canvas",
80 | ][index]
81 |
82 | else:
83 | return [
84 | "larger than",
85 | "equal to",
86 | "smaller than",
87 | ][index]
88 |
89 | else:
90 | index = rel - RelLoc.UNKNOWN - 1
91 | if canvas:
92 | return [
93 | "",
94 | "at top",
95 | "",
96 | "at bottom",
97 | "at middle",
98 | ][index]
99 |
100 | else:
101 | return [
102 | "right to",
103 | "below",
104 | "left to",
105 | "above",
106 | "around",
107 | ][index]
108 |
109 |
110 | # transform
111 | class AddCanvasElement:
112 | x = torch.tensor([[0.5, 0.5, 1.0, 1.0]], dtype=torch.float)
113 | y = torch.tensor([0], dtype=torch.long)
114 |
115 | def __call__(self, data):
116 | flag = data.attr["has_canvas_element"].any().item()
117 | assert not flag
118 | if not flag:
119 | # device = data.x.device
120 | # x, y = self.x.to(device), self.y.to(device)
121 | data.x = torch.cat([self.x, data.x], dim=0)
122 | data.y = torch.cat([self.y, data.y + 1], dim=0)
123 | data.attr = data.attr.copy()
124 | data.attr["has_canvas_element"] = True
125 | return data
126 |
127 |
128 | class AddRelationConstraints:
129 | def __init__(self, seed=None, edge_ratio=0.1, use_v1=False):
130 | self.edge_ratio = edge_ratio
131 | self.use_v1 = use_v1
132 | self.generator = random.Random()
133 | if seed is not None:
134 | self.generator.seed(seed)
135 |
136 | def __call__(self, data):
137 | N = data.x.size(0)
138 | has_canvas = data.attr["has_canvas_element"]
139 |
140 | rel_all = list(product(range(2), combinations(range(N), 2)))
141 | size = int(len(rel_all) * self.edge_ratio)
142 | rel_sample = set(self.generator.sample(rel_all, size))
143 |
144 | edge_index, edge_attr = [], []
145 | rel_unk = 1 << RelSize.UNKNOWN | 1 << RelLoc.UNKNOWN
146 | for i, j in combinations(range(N), 2):
147 | bi, bj = data.x[i], data.x[j]
148 | canvas = data.y[i] == 0 and has_canvas
149 |
150 | if self.use_v1:
151 | if (0, (i, j)) in rel_sample:
152 | rel_size = 1 << detect_size_relation(bi, bj)
153 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas)
154 | else:
155 | rel_size = 1 << RelSize.UNKNOWN
156 | rel_loc = 1 << RelLoc.UNKNOWN
157 | else:
158 | if (0, (i, j)) in rel_sample:
159 | rel_size = 1 << detect_size_relation(bi, bj)
160 | else:
161 | rel_size = 1 << RelSize.UNKNOWN
162 |
163 | if (1, (i, j)) in rel_sample:
164 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas)
165 | else:
166 | rel_loc = 1 << RelLoc.UNKNOWN
167 |
168 | rel = rel_size | rel_loc
169 | if rel != rel_unk:
170 | edge_index.append((i, j))
171 | edge_attr.append(rel)
172 |
173 | data.edge_index = torch.as_tensor(edge_index).long()
174 | data.edge_index = data.edge_index.t().contiguous()
175 | data.edge_attr = torch.as_tensor(edge_attr).long()
176 |
177 | return data
178 |
179 |
180 | class RandomOrder:
181 | def __call__(self, data):
182 | assert not data.attr["has_canvas_element"]
183 | device = data.x.device
184 | N = data.x.size(0)
185 | idx = torch.randperm(N, device=device)
186 | data.x, data.y = data.x[idx], data.y[idx]
187 | return data
188 |
189 |
190 | class SortByLabel:
191 | def __call__(self, data):
192 | assert not data.attr["has_canvas_element"]
193 | idx = data.y.sort().indices
194 | data.x, data.y = data.x[idx], data.y[idx]
195 | return data
196 |
197 |
198 | class LexicographicOrder:
199 | def __call__(self, data):
200 | assert not data.attr["has_canvas_element"]
201 | x, y, _, _ = xywh_to_ltrb_split(data.x.t())
202 | _zip = zip(*sorted(enumerate(zip(y, x)), key=lambda c: c[1:]))
203 | idx = list(list(_zip)[0])
204 |
205 | data.x_orig, data.y_orig = data.x, data.y
206 | data.x, data.y = data.x[idx], data.y[idx]
207 | return data
208 |
209 |
210 | class AddNoiseToBBox:
211 | def __init__(self, std: float = 0.05):
212 | self.std = float(std)
213 |
214 | def __call__(self, data):
215 | noise = torch.normal(0, self.std, size=data.x.size(), device=data.x.device)
216 | data.x_orig = data.x.clone()
217 | data.x = data.x + noise
218 | data.attr = data.attr.copy()
219 | data.attr["NoiseAdded"][0] = True
220 | return data
221 |
222 |
223 | class HorizontalFlip:
224 | def __call__(self, data):
225 | data.x = data.x.clone()
226 | data.x[:, 0] = 1 - data.x[:, 0]
227 | return data
228 |
229 |
230 | # def compose_transform(transforms):
231 | # module = sys.modules[__name__]
232 | # transform_list = []
233 | # for t in transforms:
234 | # # parse args
235 | # if "(" in t and ")" in t:
236 | # args = t[t.index("(") + 1 : t.index(")")]
237 | # t = t[: t.index("(")]
238 | # regex = re.compile(r"\b(\w+)=(.*?)(?=\s\w+=\s*|$)")
239 | # args = dict(regex.findall(args))
240 | # for k in args:
241 | # try:
242 | # args[k] = float(args[k])
243 | # except:
244 | # pass
245 | # else:
246 | # args = {}
247 | # if isinstance(t, str):
248 | # if hasattr(module, t):
249 | # transform_list.append(getattr(module, t)(**args))
250 | # else:
251 | # raise NotImplementedError
252 | # else:
253 | # raise NotImplementedError
254 | # return T.Compose(transform_list)
255 |
256 |
257 | def compose_transform(transforms: List[str]) -> T.Compose:
258 | """
259 | Compose transforms, optionally with args (e.g., AddRelationConstraints(edge_ratio=0.1))
260 | """
261 | transform_list = []
262 | for t in transforms:
263 | if "(" in t and ")" in t:
264 | pass
265 | else:
266 | t += "()"
267 | transform_list.append(eval(t))
268 | return T.Compose(transform_list)
269 |
270 |
271 | def sparse_to_dense(
272 | batch,
273 | device: torch.device = torch.device("cpu"),
274 | remove_canvas: bool = False,
275 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]:
276 | batch = batch.to(device)
277 | bbox, _ = to_dense_batch(batch.x, batch.batch)
278 | label, mask = to_dense_batch(batch.y, batch.batch)
279 |
280 | if remove_canvas:
281 | bbox = bbox[:, 1:].contiguous()
282 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform
283 | label = label.clamp(min=0)
284 | mask = mask[:, 1:].contiguous()
285 |
286 | padding_mask = ~mask
287 | return bbox, label, padding_mask, mask
288 |
289 |
290 | def loader_to_list(
291 | loader: torch.utils.data.dataloader.DataLoader,
292 | ) -> List[Tuple[np.ndarray, np.ndarray]]:
293 | layouts = []
294 | for batch in loader:
295 | bbox, label, _, mask = sparse_to_dense(batch)
296 | for i in range(len(label)):
297 | valid = mask[i].numpy()
298 | layouts.append((bbox[i].numpy()[valid], label[i].numpy()[valid]))
299 | return layouts
300 |
301 |
302 | def split_num_samples(N: int, batch_size: int) -> List[int]:
303 |     quotient = N // batch_size
304 |     remainder = N % batch_size
305 |     batch_sizes = quotient * [batch_size]
306 |     if remainder > 0:
307 |         batch_sizes.append(remainder)
308 |     return batch_sizes
309 |
--------------------------------------------------------------------------------
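Short illustrative calls for the relation and batching helpers above; the boxes are in (xc, yc, w, h) format and the values are made up for the example:
```
from util.data_util import detect_size_relation, detect_loc_relation, split_num_samples

b1 = [0.3, 0.3, 0.2, 0.1]   # small box in the upper-left region
b2 = [0.7, 0.7, 0.4, 0.2]   # larger box below and to the right of b1

print(detect_size_relation(b1, b2))   # RelSize.LARGER  (b2 covers the larger area)
print(detect_loc_relation(b1, b2))    # RelLoc.BOTTOM   (b2 lies entirely below b1)
print(split_num_samples(10, 4))       # [4, 4, 2] -> per-batch sample counts
```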
/util/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .publaynet import PubLayNetDataset
2 | from .rico import Rico25Dataset
3 |
4 | _DATASETS = [
5 | Rico25Dataset,
6 | PubLayNetDataset,
7 | ]
8 | DATASETS = {d.name: d for d in _DATASETS}
9 |
--------------------------------------------------------------------------------
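The registry above maps each dataset's `name` attribute to its class. A lookup sketch follows; the instantiation is commented out because it needs the prepared data on disk, and the directory is only a placeholder:
```
from util.datasets import DATASETS

print(sorted(DATASETS))             # ['publaynet', 'rico25']
PubLayNet = DATASETS["publaynet"]
# needs <dir>/publaynet-max25/ with the prepared splits (or the raw COCO labels):
# dataset = PubLayNet(dir="./download/datasets", split="train", max_seq_length=25)
```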
/util/datasets/base.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import fsspec
4 | import seaborn as sns
5 | import torch
6 | from fsspec.core import url_to_fs
7 |
8 | from .dataset import InMemoryDataset
9 |
10 |
11 | class BaseDataset(InMemoryDataset):
12 | name = None
13 | labels = []
14 | _label2index = None
15 | _index2label = None
16 | _colors = None
17 |
18 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
19 | assert split in ["train", "val", "test"]
20 | name = f"{self.name}-max{max_seq_length}"
21 | self.max_seq_length = max_seq_length
22 | super().__init__(os.path.join(dir, name), transform)
23 | idx = self.processed_file_names.index("{}.pt".format(split))
24 |
25 | with fsspec.open(self.processed_paths[idx], "rb") as file_obj:
26 | self.data, self.slices = torch.load(file_obj)
27 |
28 | @property
29 | def label2index(self):
30 | if self._label2index is None:
31 | self._label2index = dict()
32 | for idx, label in enumerate(self.labels):
33 | self._label2index[label] = idx
34 | return self._label2index
35 |
36 | @property
37 | def index2label(self):
38 | if self._index2label is None:
39 | self._index2label = dict()
40 | for idx, label in enumerate(self.labels):
41 | self._index2label[idx] = label
42 | return self._index2label
43 |
44 | @property
45 | def colors(self):
46 | if self._colors is None:
47 | n_colors = self.num_classes
48 | colors = sns.color_palette("husl", n_colors=n_colors)
49 | self._colors = [tuple(map(lambda x: int(x * 255), c)) for c in colors]
50 | return self._colors
51 |
52 | @property
53 | def raw_file_names(self):
54 | fs, _ = url_to_fs(self.raw_dir)
55 | if not fs.exists(self.raw_dir):
56 | return []
57 | file_names = [f.split("/")[-1] for f in fs.ls(self.raw_dir)]
58 | return file_names
59 |
60 | @property
61 | def processed_file_names(self):
62 | return ["train.pt", "val.pt", "test.pt"]
63 |
64 | def download(self):
65 | raise FileNotFoundError("See dataset/README.md")
66 |
67 | def process(self):
68 | raise NotImplementedError
69 |
70 | def get_original_images(self):
71 | raise NotImplementedError
72 |
--------------------------------------------------------------------------------
/util/datasets/crello.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import random
4 | random.seed(1)
5 | from fsspec.core import url_to_fs
6 | import numpy as np
7 | import torch
8 | from torch_geometric.data import Data
9 | from .base import BaseDataset
10 | import datasets
11 |
12 | def ltrb_2_xywh(bbox_ltrb):
13 | bbox_xywh = torch.zeros(bbox_ltrb.shape)
14 | bbox_wh = torch.abs(bbox_ltrb[:, 2:] - bbox_ltrb[:, :2])
15 | bbox_xy = bbox_ltrb[:, :2] + 0.5 * bbox_wh
16 | bbox_xywh[:, :2] = bbox_xy
17 | bbox_xywh[:, 2:] = bbox_wh
18 | return bbox_xywh
19 |
20 | class CrelloDataset(BaseDataset):
21 | name = "crello"
22 | label_names = ['coloredBackground', 'imageElement', 'maskElement', 'svgElement', 'textElement']
23 | label_dict = {k: i for i, k in enumerate(label_names)}
24 |
25 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
26 | super().__init__(dir, split, max_seq_length, transform)
27 |
28 | def download(self):
29 | _ = datasets.load_dataset("cyberagent/crello")
30 |
31 | def process(self):
32 | fs, _ = url_to_fs(self.raw_dir)
33 | for split_i, split in enumerate(['train', 'validation', 'test']):
34 |
35 | dataset = datasets.load_dataset("cyberagent/crello", split=split)
36 | # label_names = dataset.features['type'].feature.names
37 | width_list = [int(x) for x in dataset.features['canvas_width'].names]
38 | height_list = [int(x) for x in dataset.features['canvas_height'].names]
39 |
40 | data_list = []
41 | for i, elements in tqdm(enumerate(dataset), total=len(dataset), desc=f'split: {split}', ncols=150):
42 | labels = torch.tensor(elements['type']).to(torch.long)
43 | wi = elements['canvas_width']
44 | hi = elements['canvas_height']
45 | W = width_list[wi]
46 | H = height_list[hi]
47 | left = torch.tensor(elements['left'])
48 | top = torch.tensor(elements['top'])
49 | width = torch.tensor(elements['width'])
50 | height = torch.tensor(elements['height'])
51 |
52 | def is_valid(x1, y1, width, height):
53 |
54 | if torch.min(width) <= 0 or torch.min(height) <= 0:
55 | return None, False
56 |
57 | x2, y2 = x1 + width, y1 + height
58 | bboxes = torch.stack([x1, y1, x2, y2], dim=1)
59 | if torch.max(bboxes) > 1 or torch.min(bboxes) < 0:
60 | bboxes_ltbr = torch.clamp(bboxes, min=0, max=1)
61 | bboxes = ltrb_2_xywh(bboxes_ltbr)
62 | return bboxes, False
63 |
64 | return bboxes, True
65 |
66 | bboxes, filtered = is_valid(left, top, width, height)
67 | if bboxes is None:
68 | continue
69 |
70 | if bboxes.shape[0] == 0 or bboxes.shape[0] > self.max_seq_length:
71 | continue
72 |
73 | data = Data(x=bboxes, y=labels)
74 | data.attr = {
75 | "name": elements['id'],
76 | "width": W,
77 | "height": H,
78 | "filtered": filtered,
79 | "has_canvas_element": False,
80 | "NoiseAdded": False,
81 | }
82 | data_list.append(data)
83 |
84 | with fs.open(self.processed_paths[split_i], "wb") as file_obj:
85 | if split_i == 0:
86 | print('duplicate training split (x10) for more batches per epoch')
87 | torch.save(self.collate(data_list * 10), file_obj)
88 | else:
89 | torch.save(self.collate(data_list), file_obj)
90 |
91 |
92 |
93 | if __name__ == '__main__':
94 |
95 | data_dir = '../../download/datasets'
96 | # get_train_labels(data_dir)
97 | # get_val_test_labels(data_dir)
98 |
99 | # train_dataset = MagazineDataset(dir=data_dir, split='train')
100 | # print(len(train_dataset))
101 | # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=4)
102 | # for bbox, label, mask in tqdm(train_loader):
103 | # print(bbox[0])
104 | # print(label[0])
105 | # print(mask[0])
106 | # break
107 |
108 |
--------------------------------------------------------------------------------
/util/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | # slightly modified to cope with remote folders/files starting with gs://
2 | import copy
3 | import os.path as osp
4 | import re
5 | import sys
6 | import warnings
7 | from collections.abc import Mapping, Sequence
8 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
9 | import numpy as np
10 | import torch
11 | import torch.utils.data
12 | from fsspec.core import url_to_fs
13 | from torch import Tensor
14 | from torch_geometric.data import Data
15 | from torch_geometric.data.collate import collate
16 | from torch_geometric.data.dataset import Dataset
17 | from torch_geometric.data.makedirs import makedirs
18 | from torch_geometric.data.separate import separate
19 |
20 | IndexType = Union[slice, Tensor, np.ndarray, Sequence]
21 |
22 |
23 | class Dataset(torch.utils.data.Dataset):
24 |     r"""Dataset base class for creating graph datasets.
25 |     See the torch_geometric documentation for the accompanying tutorial.
27 |
28 | Args:
29 | root (string, optional): Root directory where the dataset should be
30 | saved. (optional: :obj:`None`)
31 | transform (callable, optional): A function/transform that takes in an
32 | :obj:`torch_geometric.data.Data` object and returns a transformed
33 | version. The data object will be transformed before every access.
34 | (default: :obj:`None`)
35 | pre_transform (callable, optional): A function/transform that takes in
36 | an :obj:`torch_geometric.data.Data` object and returns a
37 | transformed version. The data object will be transformed before
38 | being saved to disk. (default: :obj:`None`)
39 | pre_filter (callable, optional): A function that takes in an
40 | :obj:`torch_geometric.data.Data` object and returns a boolean
41 | value, indicating whether the data object should be included in the
42 | final dataset. (default: :obj:`None`)
43 | """
44 |
45 | @property
46 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
47 | r"""The name of the files in the :obj:`self.raw_dir` folder that must
48 | be present in order to skip downloading."""
49 | raise NotImplementedError
50 |
51 | @property
52 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
53 | r"""The name of the files in the :obj:`self.processed_dir` folder that
54 | must be present in order to skip processing."""
55 | raise NotImplementedError
56 |
57 | def download(self):
58 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
59 | raise NotImplementedError
60 |
61 | def process(self):
62 | r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
63 | raise NotImplementedError
64 |
65 | def len(self) -> int:
66 | r"""Returns the number of graphs stored in the dataset."""
67 | raise NotImplementedError
68 |
69 | def get(self, idx: int) -> Data:
70 | r"""Gets the data object at index :obj:`idx`."""
71 | raise NotImplementedError
72 |
73 | def __init__(
74 | self,
75 | root: Optional[str] = None,
76 | transform: Optional[Callable] = None,
77 | pre_transform: Optional[Callable] = None,
78 | pre_filter: Optional[Callable] = None,
79 | ):
80 | super().__init__()
81 |
82 | if isinstance(root, str):
83 | if not root.startswith("gs://"):
84 | root = osp.expanduser(osp.normpath(root))
85 |
86 | self.root = root
87 | self.transform = transform
88 | self.pre_transform = pre_transform
89 | self.pre_filter = pre_filter
90 | self._indices: Optional[Sequence] = None
91 |
92 | if self.download.__qualname__.split(".")[0] != "Dataset":
93 | self._download()
94 |
95 | if self.process.__qualname__.split(".")[0] != "Dataset":
96 | fs, _ = url_to_fs(self.processed_dir)
97 | if not all(fs.exists(p) for p in self.processed_paths):
98 | self._process()
99 |
100 | def indices(self) -> Sequence:
101 | return range(self.len()) if self._indices is None else self._indices
102 |
103 | @property
104 | def raw_dir(self) -> str:
105 | return osp.join(self.root, "raw")
106 |
107 | @property
108 | def processed_dir(self) -> str:
109 | return osp.join(self.root, "processed")
110 |
111 | @property
112 | def num_node_features(self) -> int:
113 | r"""Returns the number of features per node in the dataset."""
114 | data = self[0]
115 | data = data[0] if isinstance(data, tuple) else data
116 | if hasattr(data, "num_node_features"):
117 | return data.num_node_features
118 | raise AttributeError(
119 | f"'{data.__class__.__name__}' object has no "
120 | f"attribute 'num_node_features'"
121 | )
122 |
123 | @property
124 | def num_features(self) -> int:
125 | r"""Returns the number of features per node in the dataset.
126 | Alias for :py:attr:`~num_node_features`."""
127 | return self.num_node_features
128 |
129 | @property
130 | def num_edge_features(self) -> int:
131 | r"""Returns the number of features per edge in the dataset."""
132 | data = self[0]
133 | data = data[0] if isinstance(data, tuple) else data
134 | if hasattr(data, "num_edge_features"):
135 | return data.num_edge_features
136 | raise AttributeError(
137 | f"'{data.__class__.__name__}' object has no "
138 | f"attribute 'num_edge_features'"
139 | )
140 |
141 | @property
142 | def raw_paths(self) -> List[str]:
143 | r"""The absolute filepaths that must be present in order to skip
144 | downloading."""
145 | files = to_list(self.raw_file_names)
146 | return [osp.join(self.raw_dir, f) for f in files]
147 |
148 | @property
149 | def processed_paths(self) -> List[str]:
150 | r"""The absolute filepaths that must be present in order to skip
151 | processing."""
152 | files = to_list(self.processed_file_names)
153 | return [osp.join(self.processed_dir, f) for f in files]
154 |
155 | def _download(self):
156 | if files_exist(self.raw_paths): # pragma: no cover
157 | return
158 |
159 | makedirs(self.raw_dir)
160 | self.download()
161 |
162 | def _process(self):
163 | f = osp.join(self.processed_dir, "pre_transform.pt")
164 | fs, _ = url_to_fs(self.processed_dir)
165 | if fs.exists(f):
166 | with fs.open(f, "rb") as file_obj:
167 | x = torch.load(file_obj)
168 | if x != _repr(self.pre_transform):
169 | warnings.warn(
170 | f"The `pre_transform` argument differs from the one used in "
171 | f"the pre-processed version of this dataset. If you want to "
172 | f"make use of another pre-processing technique, make sure to "
173 | f"delete '{self.processed_dir}' first"
174 | )
175 |
176 | f = osp.join(self.processed_dir, "pre_filter.pt")
177 | if fs.exists(f):
178 | with fs.open(f, "rb") as file_obj:
179 | x = torch.load(file_obj)
180 | if x != _repr(self.pre_filter):
181 | warnings.warn(
182 | "The `pre_filter` argument differs from the one used in "
183 | "the pre-processed version of this dataset. If you want to "
184 |                         "make use of another pre-filtering technique, make sure to "
185 |                         f"delete '{self.processed_dir}' first"
186 | )
187 |
188 | if files_exist(self.processed_paths): # pragma: no cover
189 | return
190 |
191 | print("Processing...", file=sys.stderr)
192 |
193 | makedirs(self.processed_dir)
194 | self.process()
195 |
196 | path = osp.join(self.processed_dir, "pre_transform.pt")
197 | with fs.open(path, "wb") as file_obj:
198 | torch.save(_repr(self.pre_transform), file_obj)
199 | path = osp.join(self.processed_dir, "pre_filter.pt")
200 | with fs.open(path, "wb") as file_obj:
201 | torch.save(_repr(self.pre_filter), file_obj)
202 |
203 | print("Done!", file=sys.stderr)
204 |
205 | def __len__(self) -> int:
206 | r"""The number of examples in the dataset."""
207 | return len(self.indices())
208 |
209 | def __getitem__(
210 | self,
211 | idx: Union[int, np.integer, IndexType],
212 | ) -> Union["Dataset", Data]:
213 | r"""In case :obj:`idx` is of type integer, will return the data object
214 | at index :obj:`idx` (and transforms it in case :obj:`transform` is
215 | present).
216 | In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
217 | tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or
218 | bool, will return a subset of the dataset at the specified indices."""
219 | if (
220 | isinstance(idx, (int, np.integer))
221 | or (isinstance(idx, Tensor) and idx.dim() == 0)
222 | or (isinstance(idx, np.ndarray) and np.isscalar(idx))
223 | ):
224 | data = self.get(self.indices()[idx])
225 | data = data if self.transform is None else self.transform(data)
226 | return data
227 |
228 | else:
229 | return self.index_select(idx)
230 |
231 | def index_select(self, idx: IndexType) -> "Dataset":
232 | r"""Creates a subset of the dataset from specified indices :obj:`idx`.
233 | Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
234 | list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
235 | long or bool."""
236 | indices = self.indices()
237 |
238 | if isinstance(idx, slice):
239 | indices = indices[idx]
240 |
241 | elif isinstance(idx, Tensor) and idx.dtype == torch.long:
242 | return self.index_select(idx.flatten().tolist())
243 |
244 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
245 | idx = idx.flatten().nonzero(as_tuple=False)
246 | return self.index_select(idx.flatten().tolist())
247 |
248 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
249 | return self.index_select(idx.flatten().tolist())
250 |
251 |         elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
252 | idx = idx.flatten().nonzero()[0]
253 | return self.index_select(idx.flatten().tolist())
254 |
255 | elif isinstance(idx, Sequence) and not isinstance(idx, str):
256 | indices = [indices[i] for i in idx]
257 |
258 | else:
259 | raise IndexError(
260 | f"Only slices (':'), list, tuples, torch.tensor and "
261 | f"np.ndarray of dtype long or bool are valid indices (got "
262 | f"'{type(idx).__name__}')"
263 | )
264 |
265 | dataset = copy.copy(self)
266 | dataset._indices = indices
267 | return dataset
268 |
269 | def shuffle(
270 | self,
271 | return_perm: bool = False,
272 | ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
273 | r"""Randomly shuffles the examples in the dataset.
274 |
275 | Args:
276 | return_perm (bool, optional): If set to :obj:`True`, will also
277 | return the random permutation used to shuffle the dataset.
278 | (default: :obj:`False`)
279 | """
280 | perm = torch.randperm(len(self))
281 | dataset = self.index_select(perm)
282 | return (dataset, perm) if return_perm is True else dataset
283 |
284 | def __repr__(self) -> str:
285 | arg_repr = str(len(self)) if len(self) > 1 else ""
286 | return f"{self.__class__.__name__}({arg_repr})"
287 |
288 |
289 | def to_list(value: Any) -> Sequence:
290 | if isinstance(value, Sequence) and not isinstance(value, str):
291 | return value
292 | else:
293 | return [value]
294 |
295 |
296 | def files_exist(files: List[str]) -> bool:
297 | # NOTE: We return `False` in case `files` is empty, leading to a
298 | # re-processing of files on every instantiation.
299 | if len(files) == 0:
300 | return False
301 | else:
302 | fs, _ = url_to_fs(files[0])
303 | return all([fs.exists(f) for f in files])
304 |
305 |
306 | def _repr(obj: Any) -> str:
307 | if obj is None:
308 | return "None"
309 | return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
310 |
311 |
312 | class InMemoryDataset(Dataset):
313 |     r"""Dataset base class for creating graph datasets which easily fit
314 | into CPU memory.
315 | Inherits from :class:`torch_geometric.data.Dataset`.
316 |     See the torch_geometric documentation for the accompanying
318 | tutorial.
319 |
320 | Args:
321 | root (string, optional): Root directory where the dataset should be
322 | saved. (default: :obj:`None`)
323 | transform (callable, optional): A function/transform that takes in an
324 | :obj:`torch_geometric.data.Data` object and returns a transformed
325 | version. The data object will be transformed before every access.
326 | (default: :obj:`None`)
327 | pre_transform (callable, optional): A function/transform that takes in
328 | an :obj:`torch_geometric.data.Data` object and returns a
329 | transformed version. The data object will be transformed before
330 | being saved to disk. (default: :obj:`None`)
331 | pre_filter (callable, optional): A function that takes in an
332 | :obj:`torch_geometric.data.Data` object and returns a boolean
333 | value, indicating whether the data object should be included in the
334 | final dataset. (default: :obj:`None`)
335 | """
336 |
337 | @property
338 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
339 | raise NotImplementedError
340 |
341 | @property
342 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
343 | raise NotImplementedError
344 |
345 | def __init__(
346 | self,
347 | root: Optional[str] = None,
348 | transform: Optional[Callable] = None,
349 | pre_transform: Optional[Callable] = None,
350 | pre_filter: Optional[Callable] = None,
351 | ):
352 | super().__init__(root, transform, pre_transform, pre_filter)
353 | self.data = None
354 | self.slices = None
355 | self._data_list: Optional[List[Data]] = None
356 |
357 | @property
358 | def num_classes(self) -> int:
359 | r"""Returns the number of classes in the dataset."""
360 | y = self.data.y
361 | if y is None:
362 | return 0
363 | elif y.numel() == y.size(0) and not torch.is_floating_point(y):
364 | return int(self.data.y.max()) + 1
365 | elif y.numel() == y.size(0) and torch.is_floating_point(y):
366 | return torch.unique(y).numel()
367 | else:
368 | return self.data.y.size(-1)
369 |
370 | def len(self) -> int:
371 | if self.slices is None:
372 | return 1
373 | for _, value in nested_iter(self.slices):
374 | return len(value) - 1
375 | return 0
376 |
377 | def get(self, idx: int) -> Data:
378 | if self.len() == 1:
379 | return copy.copy(self.data)
380 |
381 | if not hasattr(self, "_data_list") or self._data_list is None:
382 | self._data_list = self.len() * [None]
383 | elif self._data_list[idx] is not None:
384 | return copy.copy(self._data_list[idx])
385 |
386 | data = separate(
387 | cls=self.data.__class__,
388 | batch=self.data,
389 | idx=idx,
390 | slice_dict=self.slices,
391 | decrement=False,
392 | )
393 |
394 | self._data_list[idx] = copy.copy(data)
395 |
396 | return data
397 |
398 | @staticmethod
399 | def collate(data_list: List[Data]) -> Tuple[Data, Optional[Dict[str, Tensor]]]:
400 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
401 | to the internal storage format of
402 | :class:`~torch_geometric.data.InMemoryDataset`."""
403 | if len(data_list) == 1:
404 | return data_list[0], None
405 |
406 | data, slices, _ = collate(
407 | data_list[0].__class__,
408 | data_list=data_list,
409 | increment=False,
410 | add_batch=False,
411 | )
412 |
413 | return data, slices
414 |
415 | def copy(self, idx: Optional[IndexType] = None) -> "InMemoryDataset":
416 | r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given,
417 | will clone the full dataset. Otherwise, will only clone a subset of the
418 | dataset from indices :obj:`idx`.
419 | Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or
420 | :obj:`np.ndarray` of type long or bool.
421 | """
422 | if idx is None:
423 | data_list = [self.get(i) for i in self.indices()]
424 | else:
425 | data_list = [self.get(i) for i in self.index_select(idx).indices()]
426 |
427 | dataset = copy.copy(self)
428 | dataset._indices = None
429 | dataset._data_list = None
430 | dataset.data, dataset.slices = self.collate(data_list)
431 | return dataset
432 |
433 |
434 | def nested_iter(mapping: Mapping) -> Iterable:
435 | for key, value in mapping.items():
436 | if isinstance(value, Mapping):
437 | for inner_key, inner_value in nested_iter(value):
438 | yield inner_key, inner_value
439 | else:
440 | yield key, value
441 |
--------------------------------------------------------------------------------
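A sketch of the indexing utilities defined above, assuming a dataset has already been prepared on disk (the directory below is a placeholder); any dataset class in this repository can stand in for `PubLayNetDataset`:
```
from util.datasets.publaynet import PubLayNetDataset

# placeholder path: expects <dir>/publaynet-max25/processed/{train,val,test}.pt
dataset = PubLayNetDataset(dir="./download/datasets", split="val", max_seq_length=25)

subset = dataset[:100]                              # slicing returns a lightweight view
shuffled, perm = dataset.shuffle(return_perm=True)  # random permutation of the examples
print(len(dataset), len(subset), perm[:5])
```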
/util/datasets/load_data.py:
--------------------------------------------------------------------------------
1 | from util.datasets.publaynet import PubLayNetDataset
2 | from util.datasets.magazine import MagazineDataset
3 | from util.datasets.rico import Rico25Dataset, Rico5Dataset, Rico13Dataset
4 | from util.datasets.crello import CrelloDataset
5 | import torch
6 | import torch_geometric
7 | from util.seq_util import sparse_to_dense, pad_until
8 |
9 | dataset_dict = {'publaynet': PubLayNetDataset, 'rico13': Rico13Dataset,
10 | 'rico25': Rico25Dataset, 'magazine': MagazineDataset, 'crello': CrelloDataset}
11 |
12 |
13 | def init_dataset(dataset_name, data_dir, batch_size=128, split='train', transform=None, shuffle=False):
14 |
15 | main_dataset = dataset_dict[dataset_name](dir=data_dir, split=split, max_seq_length=25, transform=transform)
16 | main_loader = torch_geometric.loader.DataLoader(main_dataset, shuffle=shuffle, batch_size=batch_size,
17 | num_workers=4, pin_memory=True)
18 | return main_dataset, main_loader
19 |
20 |
21 | if __name__=="__main__":
22 |
23 | device = "cuda" if torch.cuda.is_available() else "cpu"
24 | data_dir = '/Users/chenjian/Documents/Projects/Layout/LayoutDiff/datasets'
25 |
26 | train_dataset, train_loader = init_dataset('crello', data_dir, batch_size=120, split='train')
27 | print(train_dataset.num_classes)
28 |
29 | for data in train_loader:
30 | bbox, label, _, mask = sparse_to_dense(data)
31 | label, bbox, mask = pad_until(label, bbox, mask, max_seq_length=25)
32 | bbox = bbox.to(device)
33 |
34 | print(bbox[0])
35 | break
36 |
37 | #
38 |
39 |
--------------------------------------------------------------------------------
/util/datasets/magazine.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import random
4 | random.seed(1)
5 | from fsspec.core import url_to_fs
6 | import numpy as np
7 | import torch
8 | from torch_geometric.data import Data
9 | import torch_geometric
10 | from .base import BaseDataset
11 | import torchvision.transforms as transforms
12 | import xml.etree.ElementTree as ET
13 |
14 | magazine_cmap = {
15 | "text": (254, 231, 44),
16 | "image": (27, 187, 146),
17 | "headline": (255, 0, 0),
18 | "text-over-image": (0, 102, 255),
19 | "headline-over-image": (204, 0, 255),
20 | "background": (200, 200, 200),
21 | }
22 |
23 | label_dict = {"text": 0, "image": 1, "headline": 2, "text-over-image": 3, "headline-over-image": 4, "background": 5}
24 |
25 |
26 | class MagazineDataset(BaseDataset):
27 | name = "magazine"
28 | labels = ["text", "image", "headline", "text-over-image", "headline-over-image", "background"]
29 |
30 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
31 | super().__init__(dir, split, max_seq_length, transform)
32 |
33 | def download(self):
34 | # super().download()
35 | pass
36 |
37 | def process(self):
38 |
39 | fs, _ = url_to_fs(self.raw_dir)
40 | # self.data_root
41 | self.colormap = magazine_cmap
42 | self.labels = ("text", "image", "headline", "text-over-image", "headline-over-image", "background")
43 |
44 | file_folder = os.path.join(self.raw_dir, 'annotations')
45 |
46 | file_cls_wise = {'fashion': [], 'food': [], 'science': [], 'news': [], 'travel': [], 'wedding': []}
47 | for file_name in sorted(os.listdir(file_folder)):
48 | cls = file_name.split('_')[0]
49 | file_cls_wise[cls].append(file_name)
50 |
51 | file_lists = dict()
52 | file_lists['train'] = sum([file_cls_wise[k][:int(0.9 * len(file_cls_wise[k]))] for k in file_cls_wise.keys()], [])
53 | file_lists['test'] = sum([file_cls_wise[k][int(0.9 * len(file_cls_wise[k])):] for k in file_cls_wise.keys()], [])
54 |
55 | for split in ['train', 'test']:
56 | file_list = file_lists[split]
57 | data_list = []
58 |
59 | for file_name in file_list:
60 | boxes = []
61 | labels = []
62 |
63 | tree = ET.parse(os.path.join(file_folder, file_name))
64 | root = tree.getroot()
65 | try:
66 | for layout in root.findall('layout'):
67 | if len(layout.findall('element')) > self.max_seq_length:
68 | continue
69 | for i, element in enumerate(layout.findall('element')):
70 | label = element.get('label')
71 | c = label_dict[label]
72 | px = [int(i) for i in element.get('polygon_x').split(" ")]
73 | py = [int(i) for i in element.get('polygon_y').split(" ")]
74 | # get center coordinate (x,y), width and height (w, h)
75 | x = (px[2] + px[0]) / 2 / 225
76 | y = (py[2] + py[0]) / 2 / 300
77 | w = (px[2] - px[0]) / 225
78 | h = (py[2] - py[0]) / 300
79 | boxes.append([x, y, w, h])
80 | labels.append(c)
81 |
82 |                 except Exception:  # skip files whose annotations fail to parse
83 | continue
84 |
85 | if len(labels) == 0:
86 | continue
87 |
88 | boxes = torch.tensor(boxes, dtype=torch.float)
89 | labels = torch.tensor(labels, dtype=torch.long)
90 |
91 | data = Data(x=boxes, y=labels)
92 | data.attr = {
93 | "name": file_name,
94 | "width": 225,
95 | "height": 300,
96 | "filtered": True,
97 | "has_canvas_element": False,
98 | "NoiseAdded": False,
99 | }
100 | data_list.append(data)
101 |
102 | if split == "train":
103 | train_list = data_list
104 | else:
105 | test_list = data_list
106 |
107 | # train 90% / test 10%
108 |         # self.processed_paths: [train, val, test]
109 | with fs.open(self.processed_paths[0], "wb") as file_obj:
110 | print('duplicate training split (x100) for more batches per epoch')
111 | torch.save(self.collate(train_list * 100), file_obj)
112 | with fs.open(self.processed_paths[1], "wb") as file_obj:
113 | torch.save([], file_obj)
114 | with fs.open(self.processed_paths[2], "wb") as file_obj:
115 | torch.save(self.collate(test_list), file_obj)
116 |
--------------------------------------------------------------------------------
/util/datasets/publaynet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from fsspec.core import url_to_fs
4 | from torch_geometric.data import Data
5 | from tqdm import tqdm
6 | from .base import BaseDataset
7 | from typing import List, Tuple, Union
8 | import numpy as np
9 | from torch import Tensor
10 | from collections.abc import Mapping, Sequence
11 | IndexType = Union[slice, Tensor, np.ndarray, Sequence]
12 | from torch import BoolTensor, FloatTensor, LongTensor
13 | from torch_geometric.utils import to_dense_batch
14 |
15 | def sparse_to_dense(
16 | batch,
17 | device: torch.device = torch.device("cpu"),
18 | remove_canvas: bool = False,
19 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]:
20 | batch = batch.to(device)
21 | bbox, _ = to_dense_batch(batch.x, batch.batch)
22 | label, mask = to_dense_batch(batch.y, batch.batch)
23 |
24 | if remove_canvas:
25 | bbox = bbox[:, 1:].contiguous()
26 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform
27 | label = label.clamp(min=0)
28 | mask = mask[:, 1:].contiguous()
29 |
30 | padding_mask = ~mask
31 | return bbox, label, padding_mask, mask
32 |
33 |
34 | class PubLayNetDataset(BaseDataset):
35 | name = "publaynet"
36 | labels = [
37 | "text",
38 | "title",
39 | "list",
40 | "table",
41 | "figure",
42 | ]
43 |
44 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
45 | super().__init__(dir, split, max_seq_length, transform)
46 |
47 | def download(self):
48 | # super().download()
49 | pass
50 |
51 | def process(self):
52 | from pycocotools.coco import COCO
53 |
54 | fs, _ = url_to_fs(self.raw_dir)
55 |
56 | # if self.raw_dir.startswith("gs://"):
57 | # raise NotImplementedError
58 |
59 | for split_publaynet in ["train", "val"]:
60 | data_list = []
61 | coco = COCO(
62 | os.path.join(self.raw_dir, "publaynet", f"{split_publaynet}.json")
63 | )
64 | for img_id in tqdm(sorted(coco.getImgIds())):
65 | ann_img = coco.loadImgs(img_id)
66 | W = float(ann_img[0]["width"])
67 | H = float(ann_img[0]["height"])
68 | name = ann_img[0]["file_name"]
69 | if H < W:
70 | continue
71 |
72 | def is_valid(element):
73 | x1, y1, width, height = element["bbox"]
74 | x2, y2 = x1 + width, y1 + height
75 | if x1 < 0 or y1 < 0 or W < x2 or H < y2:
76 | return False
77 |
78 | if x2 <= x1 or y2 <= y1:
79 | return False
80 |
81 | return True
82 |
83 | elements = coco.loadAnns(coco.getAnnIds(imgIds=[img_id]))
84 | _elements = list(filter(is_valid, elements))
85 | filtered = len(elements) != len(_elements)
86 | elements = _elements
87 |
88 | N = len(elements)
89 | if N == 0 or self.max_seq_length < N:
90 | continue
91 |
92 | boxes = []
93 | labels = []
94 |
95 | for element in elements:
96 | # bbox
97 | x1, y1, width, height = element["bbox"]
98 | xc = x1 + width / 2.0
99 | yc = y1 + height / 2.0
100 | b = [xc / W, yc / H, width / W, height / H]
101 | boxes.append(b)
102 |
103 | # label
104 | l = coco.cats[element["category_id"]]["name"]
105 | labels.append(self.label2index[l])
106 |
107 | boxes = torch.tensor(boxes, dtype=torch.float)
108 | labels = torch.tensor(labels, dtype=torch.long)
109 |
110 | data = Data(x=boxes, y=labels)
111 | data.attr = {
112 | "name": name,
113 | "width": W,
114 | "height": H,
115 | "filtered": filtered,
116 | "has_canvas_element": False,
117 | "NoiseAdded": False,
118 | }
119 | data_list.append(data)
120 |
121 | if split_publaynet == "train":
122 | train_list = data_list
123 | else:
124 | val_list = data_list
125 |
126 | # shuffle train with seed
127 | generator = torch.Generator().manual_seed(0)
128 | indices = torch.randperm(len(train_list), generator=generator)
129 | train_list = [train_list[i] for i in indices]
130 |
131 | # train_list -> train 95% / val 5%
132 | # val_list -> test 100%
133 | s = int(len(train_list) * 0.95)
134 | with fs.open(self.processed_paths[0], "wb") as file_obj:
135 | torch.save(self.collate(train_list[:s]), file_obj)
136 | with fs.open(self.processed_paths[1], "wb") as file_obj:
137 | torch.save(self.collate(train_list[s:]), file_obj)
138 | with fs.open(self.processed_paths[2], "wb") as file_obj:
139 | torch.save(self.collate(val_list), file_obj)
140 |
141 |
142 | if __name__ == "__main__":
143 |
144 | dataset = PubLayNetDataset(dir='./download/datasets', max_seq_length=25, split="test", transform=None)
145 |
146 |
--------------------------------------------------------------------------------
/util/datasets/rico.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import os
4 | from pathlib import Path
5 | from zipfile import ZipFile
6 |
7 | import numpy as np
8 | import torch
9 | from fsspec.core import url_to_fs
10 | from PIL import Image, ImageDraw
11 | from torch_geometric.data import Data
12 | from tqdm import tqdm
13 | from util.seq_util import sparse_to_dense
14 | from util.constraint import xywh_to_ltrb_split
15 | from .base import BaseDataset
16 |
17 | _rico5_labels = [
18 | "Text",
19 | "Text Button",
20 | "Toolbar",
21 | "Image",
22 | "Icon",
23 | ]
24 |
25 | _rico13_labels = [
26 | "Toolbar",
27 | "Image",
28 | "Text",
29 | "Icon",
30 | "Text Button",
31 | "Input",
32 | "List Item",
33 | "Advertisement",
34 | "Pager Indicator",
35 | "Web View",
36 | "Background Image",
37 | "Drawer",
38 | "Modal",
39 | ]
40 |
41 | _rico25_labels = [
42 | "Text",
43 | "Image",
44 | "Icon",
45 | "Text Button",
46 | "List Item",
47 | "Input",
48 | "Background Image",
49 | "Card",
50 | "Web View",
51 | "Radio Button",
52 | "Drawer",
53 | "Checkbox",
54 | "Advertisement",
55 | "Modal",
56 | "Pager Indicator",
57 | "Slider",
58 | "On/Off Switch",
59 | "Button Bar",
60 | "Toolbar",
61 | "Number Stepper",
62 | "Multi-Tab",
63 | "Date Picker",
64 | "Map View",
65 | "Video",
66 | "Bottom Navigation",
67 | ]
68 |
69 |
70 | def append_child(element, elements):
71 | if "children" in element.keys():
72 | for child in element["children"]:
73 | elements.append(child)
74 | elements = append_child(child, elements)
75 | return elements
76 |
77 |
78 | class _RicoDataset(BaseDataset):
79 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
80 | super().__init__(dir, split, max_seq_length, transform)
81 |
82 | def process(self):
83 | data_list = []
84 | raw_file = os.path.join(
85 | self.raw_dir, "rico_dataset_v0.1_semantic_annotations.zip"
86 | )
87 |
88 | fs, _ = url_to_fs(self.raw_dir)
89 | with fs.open(raw_file, "rb") as f, ZipFile(f) as z:
90 | names = sorted([n for n in z.namelist() if n.endswith(".json")])
91 | for name in tqdm(names, total=len(names), ncols=150, desc='prepare Rico'):
92 | ann = json.loads(z.open(name).read())
93 |
94 | B = ann["bounds"]
95 | W, H = float(B[2]), float(B[3])
96 |
97 | if B[0] != 0 or B[1] != 0 or H < W:
98 | continue
99 |
100 | def is_valid(element):
101 | if element["componentLabel"] not in set(self.labels):
102 | return False
103 |
104 | x1, y1, x2, y2 = element["bounds"]
105 | if x1 < 0 or y1 < 0 or W < x2 or H < y2:
106 | return False
107 |
108 | if x2 <= x1 or y2 <= y1:
109 | return False
110 |
111 | return True
112 |
113 | elements = append_child(ann, [])
114 | _elements = list(filter(is_valid, elements))
115 | filtered = len(elements) != len(_elements)
116 | elements = _elements
117 | N = len(elements)
118 | if N == 0 or self.max_seq_length < N:
119 | continue
120 |
121 | # only for debugging slice-based preprocessing
122 | # elements = append_child(ann, [])
123 | # filtered = False
124 | # if len(elements) == 0:
125 | # continue
126 | # elements = elements[:self.max_seq_length]
127 |
128 | boxes = []
129 | labels = []
130 |
131 | for element in elements:
132 | # bbox
133 | x1, y1, x2, y2 = element["bounds"]
134 | xc = (x1 + x2) / 2.0
135 | yc = (y1 + y2) / 2.0
136 | width = x2 - x1
137 | height = y2 - y1
138 | b = [xc / W, yc / H, width / W, height / H]
139 | boxes.append(b)
140 |
141 | # label
142 | l = element["componentLabel"]
143 | labels.append(self.label2index[l])
144 |
145 | boxes = torch.tensor(boxes, dtype=torch.float)
146 | labels = torch.tensor(labels, dtype=torch.long)
147 |
148 | data = Data(x=boxes, y=labels)
149 | data.attr = {
150 | "name": name,
151 | "width": W,
152 | "height": H,
153 | "filtered": filtered,
154 | "has_canvas_element": False,
155 | "NoiseAdded": False,
156 | }
157 | data_list.append(data)
158 |
159 | # shuffle with seed
160 | generator = torch.Generator().manual_seed(0)
161 | indices = torch.randperm(len(data_list), generator=generator)
162 | data_list = [data_list[i] for i in indices]
163 |
164 | # train 85% / val 5% / test 10%
165 | N = len(data_list)
166 | s = [int(N * 0.85), int(N * 0.90)]
167 |
168 | with fs.open(self.processed_paths[0], "wb") as file_obj:
169 | print('duplicate training split (x10) for more batches per epoch')
170 | torch.save(self.collate(data_list[: s[0]] * 10), file_obj)
171 | with fs.open(self.processed_paths[1], "wb") as file_obj:
172 | torch.save(self.collate(data_list[s[0]:s[1]]), file_obj)
173 | with fs.open(self.processed_paths[2], "wb") as file_obj:
174 | torch.save(self.collate(data_list[s[1]:]), file_obj)
175 |
176 | def download(self):
177 | pass
178 |
179 | def get_original_resource(self, batch) -> Image:
180 | assert not self.raw_dir.startswith("gs://")
181 | bbox, _, _, _ = sparse_to_dense(batch)
182 |
183 | img_bg, img_original, cropped_patches = [], [], []
184 | names = batch.attr["name"]
185 | if isinstance(names, str):
186 | names = [names]
187 |
188 | for i, name in enumerate(names):
189 | name = Path(name).name.replace(".json", ".jpg")
190 | img = Image.open(Path(self.raw_dir) / "combined" / name)
191 | img_original.append(copy.deepcopy(img))
192 |
193 | W, H = img.size
194 | ltrb = xywh_to_ltrb_split(bbox[i].T.numpy())
195 | left, right = (ltrb[0] * W).astype(np.uint32), (ltrb[2] * W).astype(
196 | np.uint32
197 | )
198 | top, bottom = (ltrb[1] * H).astype(np.uint32), (ltrb[3] * H).astype(
199 | np.uint32
200 | )
201 | draw = ImageDraw.Draw(img)
202 | patches = []
203 | for (l, r, t, b) in zip(left, right, top, bottom):
204 | patches.append(img.crop((l, t, r, b)))
205 | # draw.rectangle([(l, t), (r, b)], fill=(255, 0, 0))
206 | draw.rectangle([(l, t), (r, b)], fill=(255, 255, 255))
207 | img_bg.append(img)
208 | cropped_patches.append(patches)
209 | # if len(patches) < S:
210 | # for i in range(S - len(patches)):
211 | # patches.append(Image.new("RGB", (0, 0)))
212 |
213 | return {
214 | "img_bg": img_bg,
215 | "img_original": img_original,
216 | "cropped_patches": cropped_patches,
217 | }
218 |
219 |     # read from uncompressed data (the last line below takes prohibitively long, so this path is not used for now)
220 | # raw_file = os.path.join(self.raw_dir, "unique_uis.tar.gz")
221 | # with tarfile.open(raw_file) as f:
222 | # # return gzip.GzipFile(fileobj=f.extractfile(f"combined/{name}")).read()
223 | # return gzip.GzipFile(fileobj=f.extractfile(f"combined/hoge")).read()
224 |
225 |
226 | class Rico5Dataset(_RicoDataset):
227 | name = "rico5"
228 | labels = _rico5_labels
229 |
230 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
231 | super().__init__(dir, split, max_seq_length, transform)
232 |
233 |
234 | # Constrained Graphic Layout Generation via Latent Optimization (ACMMM2021)
235 | class Rico13Dataset(_RicoDataset):
236 | name = "rico13"
237 | labels = _rico13_labels
238 |
239 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
240 | super().__init__(dir, split, max_seq_length, transform)
241 |
242 |
243 | class Rico25Dataset(_RicoDataset):
244 | name = "rico25"
245 | labels = _rico25_labels
246 |
247 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None):
248 | super().__init__(dir, split, max_seq_length, transform)
249 |
--------------------------------------------------------------------------------
/util/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import einops
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch.nn.functional as F
8 | torch.manual_seed(3)
9 | import sys
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import matplotlib
13 | matplotlib.use('Agg')
14 | import matplotlib.pyplot as plt
15 |
16 | def softmax_prefix(input_tensor, n):
17 |
18 |     # Extract the first n entries along the last dimension
19 |     sub_tensor = input_tensor[:, :, :n]
20 |
21 |     # Apply softmax over those n entries (along dim=2)
22 |     softmax_sub_tensor = F.softmax(sub_tensor, dim=2)
23 |
24 |     # Write the softmax results back over the first n entries
25 | output_tensor = input_tensor.clone()
26 | output_tensor[:, :, :n] = softmax_sub_tensor
27 |
28 | return output_tensor
29 |
30 |
31 | def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2):
32 | if schedule == "linear":
33 | betas = torch.linspace(start, end, num_timesteps)
34 | elif schedule == "const":
35 | betas = end * torch.ones(num_timesteps)
36 | elif schedule == "quad":
37 | betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
38 | elif schedule == "jsd":
39 | betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
40 | elif schedule == "sigmoid":
41 | betas = torch.linspace(-6, 6, num_timesteps)
42 | betas = torch.sigmoid(betas) * (end - start) + start
43 | elif schedule == "cosine" or schedule == "cosine_reverse":
44 | max_beta = 0.999
45 | cosine_s = 0.008
46 | betas = torch.tensor(
47 | [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / (
48 | math.cos((i / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2), max_beta) for i in
49 | range(num_timesteps)])
50 | elif schedule == "cosine_anneal":
51 | betas = torch.tensor(
52 | [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in
53 | range(num_timesteps)])
54 | return betas
55 |
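# Illustrative sketch (added for documentation; not part of the original file).
# Shows how a beta schedule is typically turned into the cumulative-alpha terms
# consumed by q_sample()/p_sample() below; the helper name is an assumption.
def _demo_schedule_terms(num_timesteps: int = 1000):
    betas = make_beta_schedule("cosine", num_timesteps=num_timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
    one_minus_alphas_bar_sqrt = torch.sqrt(1.0 - alphas_cumprod)
    return alphas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt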
56 |
57 | def extract(input, t, x):
58 | shape = x.shape
59 | out = torch.gather(input, 0, t.to(input.device))
60 | reshape = [t.shape[0]] + [1] * (len(shape) - 1)
61 | return out.reshape(*reshape)
62 |
63 |
64 | # Forward functions
65 | def q_sample(y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise=None):
66 |
67 | if noise is None:
68 | noise = torch.randn_like(y).to(y.device)
69 | sqrt_alpha_bar_t = extract(alphas_bar_sqrt, t, y)
70 |
71 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y)
72 | # q(y_t | y_0, x)
73 | y_t = sqrt_alpha_bar_t * y + sqrt_one_minus_alpha_bar_t * noise
74 |
75 | return y_t
76 |
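# Illustrative sketch (added for documentation; not part of the original file).
# Forward-process usage: noise a batch of clean layouts y_0 at random timesteps
# using the schedule terms above. Argument names and shapes are assumptions.
def _demo_q_sample(y_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_timesteps=1000):
    t = torch.randint(0, num_timesteps, (y_0.shape[0],), device=y_0.device)
    return q_sample(y_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t)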
77 |
78 | # Reverse function -- sample y_{t-1} given y_t
79 | def p_sample(model, y_t, t, alphas, one_minus_alphas_bar_sqrt, stochastic=True):
80 | """
81 | Reverse diffusion process sampling -- one time step.
82 |     y_t: the current sample at time step t.
83 | """
84 | device = next(model.parameters()).device
85 | z = stochastic * torch.randn_like(y_t)
86 | t = torch.tensor([t]).to(device)
87 | alpha_t = extract(alphas, t, y_t)
88 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y_t)
89 | sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y_t)
90 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
91 | sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()
92 | # y_t_m_1 posterior mean component coefficients
93 | gamma_0 = (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())
94 | gamma_1 = (sqrt_one_minus_alpha_bar_t_m_1.square()) * (alpha_t.sqrt()) / (sqrt_one_minus_alpha_bar_t.square())
95 |
96 | eps_theta = model(y_t, timestep=t).to(device).detach()
97 |
98 | # y_0 reparameterization
99 | y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - eps_theta * sqrt_one_minus_alpha_bar_t).to(device)
100 |
101 | # posterior mean
102 | y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y_t
103 |
104 | # posterior variance
105 | beta_t_hat = (sqrt_one_minus_alpha_bar_t_m_1.square()) / (sqrt_one_minus_alpha_bar_t.square()) * (1 - alpha_t)
106 | y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device)
107 |
108 |
109 | return y_t_m_1
110 |
111 |
112 | # Reverse function -- sample y_0 given y_1
113 | def p_sample_t_1to0(model, y_t, one_minus_alphas_bar_sqrt):
114 | device = next(model.parameters()).device
115 | t = torch.tensor([0]).to(device) # corresponding to timestep 1 (i.e., t=1 in diffusion models)
116 | sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y_t)
117 | sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
118 | eps_theta = model(y_t, timestep=t).to(device).detach()
119 |
120 | # y_0 reparameterization
121 | y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - eps_theta * sqrt_one_minus_alpha_bar_t).to(device)
122 |
123 | y_t_m_1 = y_0_reparam.to(device)
124 |
125 | return y_t_m_1
126 |
127 |
128 | def p_sample_loop(model, batch_size, n_steps, alphas, one_minus_alphas_bar_sqrt,
129 | only_last_sample=True, stochastic=True):
130 | num_t, l_p_seq = None, None
131 |
132 | device = next(model.parameters()).device
133 |
134 | l_t = stochastic * torch.randn_like(torch.zeros([batch_size, 25, 10])).to(device)
135 | if only_last_sample:
136 | num_t = 1
137 | else:
138 | # y_p_seq = [y_t]
139 | l_p_seq = torch.zeros([batch_size, 25, 10, n_steps+1]).to(device)
140 | l_p_seq[:, :, :, n_steps] = l_t
141 |
142 | for t in reversed(range(1, n_steps-1)):
143 |
144 | l_t = p_sample(model, l_t, t, alphas, one_minus_alphas_bar_sqrt, stochastic=stochastic) # y_{t-1}
145 |
146 | if only_last_sample:
147 | num_t += 1
148 | else:
149 | # y_p_seq.append(y_t)
150 | l_p_seq[:, :, :, t] = l_t
151 |
152 |
153 | if only_last_sample:
154 | l_0 = p_sample_t_1to0(model, l_t, one_minus_alphas_bar_sqrt)
155 | return l_0
156 | else:
157 | # assert len(y_p_seq) == n_steps
158 | l_0 = p_sample_t_1to0(model, l_p_seq[:, :, :, 1], one_minus_alphas_bar_sqrt)
159 | # y_p_seq.append(y_0)
160 | l_p_seq[:, :, :, 0] = l_0
161 |
162 | return l_0, l_p_seq
163 |
164 |
165 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps):
166 | if ddim_discr_method == 'uniform':
167 | c = num_ddpm_timesteps // num_ddim_timesteps
168 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
169 | elif ddim_discr_method == 'quad':
170 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
171 | elif ddim_discr_method == 'new':
172 | c = (num_ddpm_timesteps - 50) // (num_ddim_timesteps - 50)
173 | ddim_timesteps = np.asarray(list(range(0, 50)) + list(range(50, num_ddpm_timesteps - 50, c)))
174 | else:
175 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
176 |
177 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
178 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
179 | steps_out = ddim_timesteps + 1
180 | # print(f'Selected timesteps for ddim sampler: {steps_out}')
181 |
182 | return steps_out
183 |
184 |
185 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
186 | # select alphas for computing the variance schedule
187 | device = alphacums.device
188 | alphas = alphacums[ddim_timesteps]
189 | alphas_prev = torch.tensor([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()).to(device)
190 |
191 | sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
192 | # print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
193 | # print(f'For the chosen value of eta, which is {eta}, '
194 | # f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
195 | return sigmas, alphas, alphas_prev
196 |
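# Illustrative sketch (added for documentation; not part of the original file).
# Typical wiring of the two helpers above before calling ddim_sample_loop();
# the step counts and eta value are assumptions for illustration.
def _demo_ddim_schedule(num_ddim_steps: int = 100, num_ddpm_steps: int = 1000, eta: float = 0.0):
    betas = make_beta_schedule("cosine", num_timesteps=num_ddpm_steps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    timesteps = make_ddim_timesteps("uniform", num_ddim_steps, num_ddpm_steps)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(alphas_cumprod, timesteps, eta)
    return timesteps, alphas, alphas_prev, sigmas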
197 |
198 | def ddim_sample_loop(model, batch_size, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas, stochastic=True,
199 | seq_len=25, seq_dim=10):
200 | device = next(model.parameters()).device
201 |
202 | b_t = 1 * stochastic * torch.randn_like(torch.zeros([batch_size, seq_len, seq_dim])).to(device)
203 |
204 | intermediates = {'y_inter': [b_t], 'pred_y0': [b_t]}
205 | time_range = np.flip(timesteps)
206 | total_steps = timesteps.shape[0]
207 | # print(f"Running DDIM Sampling with {total_steps} timesteps")
208 |
209 | for i, step in enumerate(time_range):
210 | index = total_steps - i - 1
211 | t = torch.full((batch_size,), step, device=device, dtype=torch.long)
212 |
213 | b_t, pred_y0 = ddim_sample_step(model, b_t, t, index, ddim_alphas,
214 | ddim_alphas_prev, ddim_sigmas)
215 |
216 | intermediates['y_inter'].append(b_t)
217 | intermediates['pred_y0'].append(pred_y0)
218 |
219 | return b_t, intermediates
220 |
221 |
222 | def rand_fix(batch_size, mask, ratio=0.2, n_elements=25, stochastic=True):
223 |
224 | if stochastic:
225 | indices = (torch.rand([batch_size, n_elements]) <= torch.rand([1]).item() * ratio).to(mask.device) * mask.to(torch.bool)
226 | else:
227 | a = torch.tensor([False, False, True, False, False, True, True, False, False, False,
228 | False, False, False, False, False, False, False, False, False, False,
229 | False, False, False, False, False])
230 | indices = einops.repeat(a, "l -> n l", n=batch_size) * mask.to(torch.bool)
231 |
232 | return indices
233 |
234 |
235 | def ddim_cond_sample_loop(model, real_layout, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas, stochastic=True, cond='c', ratio=0.2):
236 |
237 | device = next(model.parameters()).device
238 | batch_size, seq_len, seq_dim = real_layout.shape
239 | num_class = seq_dim - 4
240 |
241 |
242 | real_label = torch.argmax(real_layout[:, :, :num_class], dim=2)
243 | real_mask = (real_label != num_class-1).clone().detach()
244 |
245 | # cond mask
246 | if cond == 'complete':
247 | fix_mask = rand_fix(batch_size, real_mask, ratio=ratio, stochastic=True)
248 | elif cond == 'c':
249 | fix_mask = torch.zeros([batch_size, seq_len, seq_dim]).to(torch.bool)
250 | fix_mask[:, :, :num_class] = True
251 | elif cond == 'cwh':
252 | fix_mask = torch.zeros([batch_size, seq_len, seq_dim]).to(torch.bool)
253 | fix_ind = [x for x in range(num_class)] + [num_class + 2, num_class + 3]
254 | fix_mask[:, :, fix_ind] = True
255 | else:
256 | raise Exception('cond must be c, cwh, or complete')
257 |
258 | l_t = 1 * stochastic * torch.randn_like(torch.zeros([batch_size, seq_len, seq_dim])).to(device)
259 |
260 | intermediates = {'y_inter': [l_t], 'pred_y0': [l_t]}
261 | time_range = np.flip(timesteps)
262 | total_steps = timesteps.shape[0]
263 | # noise = 1 * torch.randn_like(real_layout).to(device)
264 |
265 | for i, step in enumerate(time_range):
266 | index = total_steps - i - 1
267 | t = torch.full((batch_size,), step, device=device, dtype=torch.long)
268 |
269 | l_t[fix_mask] = real_layout[fix_mask]
270 |
271 | # # plot inter
272 | # print('hi here')
273 | # l_t_label = torch.argmax(l_t[:, :, :6], dim=2)
274 | # l_t_mask = (l_t_label != 5).clone().detach()
275 | # l_bbox = torch.clamp(l_t[:1, :, 6:], min=-1, max=1) / 2 + 0.5
276 | # img = save_image(l_bbox[:4], l_t_label[:4], l_t_mask[:4], draw_label=False)
277 | # plt.figure(figsize=[12, 12])
278 | # plt.imshow(img)
279 | # plt.tight_layout()
280 | # plt.savefig(f'./plot/test/conditional_test_t_{t[0]}.png')
281 | # plt.close()
282 |
283 | l_t, pred_y0 = ddim_sample_step(model, l_t, t, index, ddim_alphas,
284 | ddim_alphas_prev, ddim_sigmas)
285 |
286 | l_t[fix_mask] = real_layout[fix_mask]
287 |
288 | intermediates['y_inter'].append(l_t)
289 | intermediates['pred_y0'].append(pred_y0)
290 |
291 | return l_t, intermediates
292 |
293 |
294 | def ddim_refine_sample_loop(model, noisy_layout, timesteps, ddim_alphas, ddim_alphas_prev, ddim_sigmas):
295 |
296 | device = next(model.parameters()).device
297 | batch_size, seq_len, seq_dim = noisy_layout.shape
298 | l_t = noisy_layout
299 |
300 | intermediates = {'y_inter': [l_t], 'pred_y0': [l_t]}
301 | total_steps = sum(timesteps <= 201)
302 | time_range = np.flip(timesteps[:total_steps])
303 |
304 | # noise = 1 * torch.randn_like(real_layout).to(device)
305 | for i, step in enumerate(time_range):
306 | index = total_steps - i - 1
307 | t = torch.full((batch_size,), step, device=device, dtype=torch.long)
308 |
309 | l_t, pred_y0 = ddim_sample_step(model, l_t, t, index, ddim_alphas,
310 | ddim_alphas_prev, ddim_sigmas)
311 |
312 | intermediates['y_inter'].append(l_t)
313 | intermediates['pred_y0'].append(pred_y0)
314 |
315 | return l_t, intermediates
316 |
317 | def ddim_sample_step(model, l_t, t, index, ddim_alphas, ddim_alphas_prev, ddim_sigmas):
318 |
319 | device = next(model.parameters()).device
320 | e_t = model(l_t, timestep=t).to(device).detach()
321 |
322 | sqrt_one_minus_alphas = torch.sqrt(1. - ddim_alphas)
323 | # select parameters corresponding to the currently considered timestep
324 | a_t = torch.full(e_t.shape, ddim_alphas[index], device=device)
325 | a_t_m_1 = torch.full(e_t.shape, ddim_alphas_prev[index], device=device)
326 | sigma_t = torch.full(e_t.shape, ddim_sigmas[index], device=device)
327 | sqrt_one_minus_at = torch.full(e_t.shape, sqrt_one_minus_alphas[index], device=device)
328 |
329 | # direction pointing to x_t
330 | dir_b_t = (1. - a_t_m_1 - sigma_t ** 2).sqrt() * e_t
331 | noise = sigma_t * torch.randn_like(l_t).to(device)
332 |
333 | # reparameterize x_0
334 | b_0_reparam = (l_t - sqrt_one_minus_at * e_t) / a_t.sqrt()
335 |
336 | # compute b_t_m_1
337 | b_t_m_1 = a_t_m_1.sqrt() * b_0_reparam + 1 * dir_b_t + noise
338 |
339 | return b_t_m_1, b_0_reparam
340 |
341 |
342 | # def compute_piou_grad(layout_in):
343 | #
344 | # layout = layout_in.clone().detach()
345 | #
346 | # # convert centralized layout bbox representation [-1, 1] to wh representation [0, 1]
347 | # layout[:, :, 6:] = torch.clamp(layout[:, :, 6:], min=-1, max=1) / 2 + 0.5
348 | # # print('layout generated:', layout_t_0[0, :10, :])
349 | # bbox = layout[:, :, 6:]
350 | # label = torch.argmax(layout[:, :, :6], dim=2)
351 | # mask = (label != 5).clone().detach()
352 | #
353 | # # compute grad iou
354 | # piou_gradient, mpiou = guidance_grad(bbox, mask=mask.to(torch.float), xy_only=True)
355 | #
356 | # piou_gradient = 2 * (piou_gradient - 0.5)
357 | #
358 | # return piou_gradient, mpiou
359 |
360 |
361 |
362 | if __name__ == "__main__":
363 |
364 | # t = make_ddim_timesteps('new', 200, 1000)
365 | # print(t)
366 |
367 | m = rand_fix(batch_size=2, mask=torch.tensor([1 for _ in range(25)]), n_elements=25)
368 | print(m)
369 |
--------------------------------------------------------------------------------
/util/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class EMA(object):
4 | def __init__(self, mu=0.999):
5 | self.mu = mu
6 | self.shadow = {}
7 |
8 | def register(self, module):
9 | for name, param in module.named_parameters():
10 | if param.requires_grad:
11 | self.shadow[name] = param.data.clone()
12 |
13 | def update(self, module):
14 | for name, param in module.named_parameters():
15 | if param.requires_grad:
16 | self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
17 |
18 | def ema(self, module):
19 | for name, param in module.named_parameters():
20 | if param.requires_grad:
21 | param.data.copy_(self.shadow[name].data)
22 |
23 | def ema_copy(self, module):
24 | module_copy = type(module)(module.config).to(module.config.device)
25 | module_copy.load_state_dict(module.state_dict())
26 | self.ema(module_copy)
27 | return module_copy
28 |
29 | def state_dict(self):
30 | return self.shadow
31 |
32 | def load_state_dict(self, state_dict):
33 | self.shadow = state_dict
34 |
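# Illustrative sketch (added for documentation; not part of the original file).
# Typical training-loop usage of the EMA helper above; the `model` argument is
# an assumption for illustration.
def _demo_ema_usage(model: nn.Module) -> None:
    ema = EMA(mu=0.999)
    ema.register(model)          # snapshot the initial weights
    # ... after every optimizer.step() during training:
    ema.update(model)            # move the shadow weights towards the model
    # before evaluation, load the smoothed weights into the model:
    ema.ema(model)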
--------------------------------------------------------------------------------
/util/metric.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from functools import partial
3 | from itertools import chain
4 | from typing import Dict, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | import torch.distributions as tdist
9 | from einops import rearrange, reduce, repeat
10 | from prdc import compute_prdc
11 | from pytorch_fid.fid_score import calculate_frechet_distance
12 | from scipy.optimize import linear_sum_assignment
13 | from scipy.stats import wasserstein_distance
14 | from torch import BoolTensor, FloatTensor
15 | from torch_geometric.utils import to_dense_adj
16 | from .data_util import RelLoc, RelSize, detect_loc_relation, detect_size_relation
17 | from .constraint import xywh_to_ltrb_split
18 |
19 | Feats = Union[FloatTensor, List[FloatTensor]]
20 | Layout = Tuple[np.ndarray, np.ndarray]
21 |
22 | # set to True to disable parallel computing via multiprocessing (typically for debugging)
23 | # DISABLED = False
24 | DISABLED = True
25 |
26 |
27 | def __to_numpy_array(feats: Feats) -> np.ndarray:
28 |     if isinstance(feats, list):
29 |         # flatten list of batch-processed features
30 |         if isinstance(feats[0], FloatTensor):
31 |             feats = [x.detach().cpu().numpy() for x in feats]
32 |         return np.concatenate(feats)
33 |     else:
34 |         return feats.detach().cpu().numpy()
35 |
36 |
37 | def compute_generative_model_scores(
38 | feats_real: Feats,
39 | feats_fake: Feats,
40 | ) -> Dict[str, float]:
41 | """
42 | Compute precision, recall, density, coverage, and FID.
43 | """
44 | feats_real = __to_numpy_array(feats_real)
45 | feats_fake = __to_numpy_array(feats_fake)
46 |
47 | mu_real = np.mean(feats_real, axis=0)
48 | sigma_real = np.cov(feats_real, rowvar=False)
49 | mu_fake = np.mean(feats_fake, axis=0)
50 | sigma_fake = np.cov(feats_fake, rowvar=False)
51 |
52 | results = compute_prdc(
53 | real_features=feats_real, fake_features=feats_fake, nearest_k=5
54 | )
55 | results["fid"] = calculate_frechet_distance(
56 | mu_real, sigma_real, mu_fake, sigma_fake
57 | )
58 |
59 | return results
60 |
61 |
62 | def compute_violation(bbox_flatten, data):
63 | """
64 | Compute relation violation accuracy as in LayoutGAN++ [Kikuchi+, ACMMM'21].
65 | """
66 | device = data.x.device
67 | failures, valid = [], []
68 |
69 | _zip = zip(data.edge_attr, data.edge_index.t())
70 | for gt, (i, j) in _zip:
71 | failure, _valid = 0, 0
72 | b1, b2 = bbox_flatten[i], bbox_flatten[j]
73 |
74 | # size relation
75 | if ~gt & 1 << RelSize.UNKNOWN:
76 | pred = detect_size_relation(b1, b2)
77 | failure += (gt & 1 << pred).eq(0).long()
78 | _valid += 1
79 |
80 | # loc relation
81 | if ~gt & 1 << RelLoc.UNKNOWN:
82 | canvas = data.y[i].eq(0)
83 | pred = detect_loc_relation(b1, b2, canvas)
84 | failure += (gt & 1 << pred).eq(0).long()
85 | _valid += 1
86 |
87 | failures.append(failure)
88 | valid.append(_valid)
89 |
90 | failures = torch.as_tensor(failures).to(device)
91 | failures = to_dense_adj(data.edge_index, data.batch, failures)
92 | valid = torch.as_tensor(valid).to(device)
93 | valid = to_dense_adj(data.edge_index, data.batch, valid)
94 |
95 | return failures.sum((1, 2)) / valid.sum((1, 2))
96 |
97 |
98 | def compute_alignment(bbox: FloatTensor, mask: BoolTensor) -> FloatTensor:
99 | """
100 |     Compute several alignment metrics, each defined differently in previous works.
101 | Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020)
102 | https://arxiv.org/abs/2009.05284
103 | """
104 | S = bbox.size(1)
105 |
106 | bbox = bbox.permute(2, 0, 1)
107 | xl, yt, xr, yb = xywh_to_ltrb_split(bbox)
108 | xc, yc = bbox[0], bbox[1]
109 | X = torch.stack([xl, xc, xr, yt, yc, yb], dim=1)
110 | X = X.unsqueeze(-1) - X.unsqueeze(-2)
111 | idx = torch.arange(X.size(2), device=X.device)
112 | X[:, :, idx, idx] = 1.0
113 | X = X.abs().permute(0, 2, 1, 3)
114 | X[~mask] = 1.0
115 | X = X.min(-1).values.min(-1).values
116 | X.masked_fill_(X.eq(1.0), 0.0)
117 | X = -torch.log(1 - X)
118 |
119 | # original
120 | # return X.sum(-1) / mask.float().sum(-1)
121 |
122 | score = reduce(X, "b s -> b", reduction="sum")
123 | score_normalized = score / reduce(mask, "b s -> b", reduction="sum")
124 | score_normalized[torch.isnan(score_normalized)] = 0.0
125 |
126 | Y = torch.stack([xl, xc, xr], dim=1)
127 | Y = rearrange(Y, "b x s -> b x 1 s") - rearrange(Y, "b x s -> b x s 1")
128 |
129 | batch_mask = rearrange(~mask, "b s -> b 1 s") | rearrange(~mask, "b s -> b s 1")
130 | idx = torch.arange(S, device=Y.device)
131 | batch_mask[:, idx, idx] = True
132 | batch_mask = repeat(batch_mask, "b s1 s2 -> b x s1 s2", x=3)
133 | Y[batch_mask] = 1.0
134 |
135 | # Y = rearrange(Y.abs(), "b x s1 s2 -> b s1 x s2")
136 | # Y = reduce(Y, "b x s1 s2 -> b x", "min")
137 | # Y = rearrange(Y.abs(), " -> b s1 x s2")
138 | Y = reduce(Y.abs(), "b x s1 s2 -> b s1", "min")
139 | Y[Y == 1.0] = 0.0
140 | score_Y = reduce(Y, "b s -> b", "sum")
141 |
142 | results = {
143 | "alignment-ACLayoutGAN": score,
144 | "alignment-LayoutGAN++": score_normalized,
145 | "alignment-NDN": score_Y,
146 | }
147 | return results["alignment-LayoutGAN++"]
148 |
149 |
150 | def compute_overlap(bbox: FloatTensor, mask: BoolTensor):
151 | """
152 | Based on
153 | (i) Attribute-conditioned Layout GAN for Automatic Graphic Design (TVCG2020)
154 | https://arxiv.org/abs/2009.05284
155 | (ii) LAYOUTGAN: GENERATING GRAPHIC LAYOUTS WITH WIREFRAME DISCRIMINATORS (ICLR2019)
156 | https://arxiv.org/abs/1901.06767
157 | "percentage of total overlapping area among any two bounding boxes inside the whole page."
158 |     At least the BLT authors appear to sum the ratios (in the MSCOCO case, the value can exceed 1.0).
159 | """
160 | B, S = mask.size()
161 | bbox = bbox.masked_fill(~mask.unsqueeze(-1), 0)
162 | bbox = bbox.permute(2, 0, 1)
163 |
164 | l1, t1, r1, b1 = xywh_to_ltrb_split(bbox.unsqueeze(-1))
165 | l2, t2, r2, b2 = xywh_to_ltrb_split(bbox.unsqueeze(-2))
166 | a1 = (r1 - l1) * (b1 - t1)
167 |
168 | # intersection
169 | l_max = torch.maximum(l1, l2)
170 | r_min = torch.minimum(r1, r2)
171 | t_max = torch.maximum(t1, t2)
172 | b_min = torch.minimum(b1, b2)
173 | cond = (l_max < r_min) & (t_max < b_min)
174 | ai = torch.where(cond, (r_min - l_max) * (b_min - t_max), torch.zeros_like(a1[0]))
175 |
176 | # diag_mask = torch.eye(a1.size(1), dtype=torch.bool, device=a1.device)
177 | # ai = ai.masked_fill(diag_mask, 0)
178 | batch_mask = rearrange(~mask, "b s -> b 1 s") | rearrange(~mask, "b s -> b s 1")
179 | idx = torch.arange(S, device=ai.device)
180 | batch_mask[:, idx, idx] = True
181 | ai = ai.masked_fill(batch_mask, 0)
182 |
183 | ar = torch.nan_to_num(ai / a1) # (B, S, S)
184 |
185 | # original
186 | # return ar.sum(dim=(1, 2)) / mask.float().sum(-1)
187 |
188 | # fixed to avoid the case with single bbox
189 | score = reduce(ar, "b s1 s2 -> b", reduction="sum")
190 | score_normalized = score / reduce(mask, "b s -> b", reduction="sum")
191 | score_normalized[torch.isnan(score_normalized)] = 0.0
192 |
193 | ids = torch.arange(S)
194 | ii, jj = torch.meshgrid(ids, ids, indexing="ij")
195 | ai[repeat(ii >= jj, "s1 s2 -> b s1 s2", b=B)] = 0.0
196 | overlap = reduce(ai, "b s1 s2 -> b", reduction="sum")
197 |
198 | results = {
199 | "overlap-ACLayoutGAN": score,
200 | "overlap-LayoutGAN++": score_normalized,
201 | "overlap-LayoutGAN": overlap,
202 | }
203 | return results["overlap-LayoutGAN++"]
204 |
205 |
206 | def compute_iou(
207 | box_1: Union[np.ndarray, FloatTensor],
208 | box_2: Union[np.ndarray, FloatTensor],
209 | generalized: bool = False,
210 | ) -> Union[np.ndarray, FloatTensor]:
211 | # box_1: [N, 4] box_2: [N, 4]
212 |
213 | if isinstance(box_1, np.ndarray):
214 | lib = np
215 | elif isinstance(box_1, FloatTensor):
216 | lib = torch
217 | else:
218 | raise NotImplementedError(type(box_1))
219 |
220 | l1, t1, r1, b1 = xywh_to_ltrb_split(box_1.T)
221 | l2, t2, r2, b2 = xywh_to_ltrb_split(box_2.T)
222 | a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2)
223 |
224 | # intersection
225 | l_max = lib.maximum(l1, l2)
226 | r_min = lib.minimum(r1, r2)
227 | t_max = lib.maximum(t1, t2)
228 | b_min = lib.minimum(b1, b2)
229 | cond = (l_max < r_min) & (t_max < b_min)
230 | ai = lib.where(cond, (r_min - l_max) * (b_min - t_max), lib.zeros_like(a1[0]))
231 |
232 | au = a1 + a2 - ai
233 | iou = ai / au
234 |
235 | if not generalized:
236 | return iou
237 |
238 | # outer region
239 | l_min = lib.minimum(l1, l2)
240 | r_max = lib.maximum(r1, r2)
241 | t_min = lib.minimum(t1, t2)
242 | b_max = lib.maximum(b1, b2)
243 | ac = (r_max - l_min) * (b_max - t_min)
244 |
245 | giou = iou - (ac - au) / ac
246 |
247 | return giou
248 |
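# Illustrative sketch (added for documentation; not part of the original file).
# compute_iou() expects boxes in (cx, cy, w, h) format; the two boxes below are
# assumptions chosen for a quick numeric check.
def _demo_compute_iou():
    box_a = np.array([[0.5, 0.5, 0.4, 0.4]])  # 0.4 x 0.4 box centered at (0.5, 0.5)
    box_b = np.array([[0.6, 0.5, 0.4, 0.4]])  # same box shifted right by 0.1
    return compute_iou(box_a, box_b)          # array([0.6]): intersection 0.12 / union 0.20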
249 |
250 | def compute_perceptual_iou(
251 | box_1: Union[np.ndarray, FloatTensor],
252 | box_2: Union[np.ndarray, FloatTensor],
253 | ) -> Union[np.ndarray, FloatTensor]:
254 | """
255 | Computes 'Perceptual' IoU [Kong+, BLT'22]
256 | """
257 | # box_1: [N, 4] box_2: [N, 4]
258 |
259 | if isinstance(box_1, np.ndarray):
260 | lib = np
261 | elif isinstance(box_1, FloatTensor):
262 | lib = torch
263 | else:
264 | raise NotImplementedError(type(box_1))
265 |
266 | l1, t1, r1, b1 = xywh_to_ltrb_split(box_1.T)
267 | l2, t2, r2, b2 = xywh_to_ltrb_split(box_2.T)
268 | a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2)
269 |
270 | # intersection
271 | l_max = lib.maximum(l1, l2)
272 | r_min = lib.minimum(r1, r2)
273 | t_max = lib.maximum(t1, t2)
274 | b_min = lib.minimum(b1, b2)
275 | cond = (l_max < r_min) & (t_max < b_min)
276 | ai = lib.where(cond, (r_min - l_max) * (b_min - t_max), lib.zeros_like(a1[0]))
277 |
278 | # numpy-only procedure in this part
279 | if isinstance(box_1, FloatTensor):
280 | unique_box_1 = np.unique(box_1.numpy(), axis=0)
281 | else:
282 | unique_box_1 = np.unique(box_1, axis=0)
283 | N = 32
284 | l1, t1, r1, b1 = [
285 | (x * N).round().astype(np.int32).clip(0, N)
286 | for x in xywh_to_ltrb_split(unique_box_1.T)
287 | ]
288 | canvas = np.zeros((N, N))
289 | for (l, t, r, b) in zip(l1, t1, r1, b1):
290 | canvas[t:b, l:r] = 1
291 | global_area_union = canvas.sum() / (N**2)
292 |
293 | if global_area_union > 0.0:
294 | iou = ai / global_area_union
295 | return iou
296 | else:
297 | return lib.zeros((1,))
298 |
299 |
300 | def __compute_maximum_iou_for_layout(layout_1: Layout, layout_2: Layout) -> float:
301 | score = 0.0
302 | (bi, li), (bj, lj) = layout_1, layout_2
303 | N = len(bi)
304 | for l in list(set(li.tolist())):
305 | _bi = bi[np.where(li == l)]
306 | _bj = bj[np.where(lj == l)]
307 | n = len(_bi)
308 | ii, jj = np.meshgrid(range(n), range(n))
309 | ii, jj = ii.flatten(), jj.flatten()
310 | iou = compute_iou(_bi[ii], _bj[jj]).reshape(n, n)
311 | # note: maximize is supported only when scipy >= 1.4
312 | ii, jj = linear_sum_assignment(iou, maximize=True)
313 | score += iou[ii, jj].sum().item()
314 | return score / N
315 |
316 |
317 | def __compute_maximum_iou(layouts_1_and_2: Tuple[List[Layout]]) -> np.ndarray:
318 | layouts_1, layouts_2 = layouts_1_and_2
319 | N, M = len(layouts_1), len(layouts_2)
320 | ii, jj = np.meshgrid(range(N), range(M))
321 | ii, jj = ii.flatten(), jj.flatten()
322 | scores = np.asarray(
323 | [
324 | __compute_maximum_iou_for_layout(layouts_1[i], layouts_2[j])
325 | for i, j in zip(ii, jj)
326 | ]
327 | ).reshape(N, M)
328 | ii, jj = linear_sum_assignment(scores, maximize=True)
329 | return scores[ii, jj]
330 |
331 |
332 | def __get_cond2layouts(layout_list: List[Layout]) -> Dict[str, List[Layout]]:
333 | out = dict()
334 | for bs, ls in layout_list:
335 | cond_key = str(sorted(ls.tolist()))
336 | if cond_key not in out.keys():
337 | out[cond_key] = [(bs, ls)]
338 | else:
339 | out[cond_key].append((bs, ls))
340 | return out
341 |
342 |
343 | def compute_maximum_iou(
344 | layouts_1: List[Layout],
345 | layouts_2: List[Layout],
346 | disable_parallel: bool = DISABLED,
347 | n_jobs: Optional[int] = None,
348 | ):
349 | """
350 | Computes Maximum IoU [Kikuchi+, ACMMM'21]
351 | """
352 | c2bl_1 = __get_cond2layouts(layouts_1)
353 | keys_1 = set(c2bl_1.keys())
354 | c2bl_2 = __get_cond2layouts(layouts_2)
355 | keys_2 = set(c2bl_2.keys())
356 | keys = list(keys_1.intersection(keys_2))
357 | args = [(c2bl_1[key], c2bl_2[key]) for key in keys]
358 | # to check actual number of layouts for evaluation
359 | # ans = 0
360 | # for x in args:
361 | # ans += len(x[0])
362 | if disable_parallel:
363 | scores = [__compute_maximum_iou(a) for a in args]
364 | else:
365 | with multiprocessing.Pool(n_jobs) as p:
366 | scores = p.map(__compute_maximum_iou, args)
367 | scores = np.asarray(list(chain.from_iterable(scores)))
368 | if len(scores) == 0:
369 | return 0.0
370 | else:
371 | return scores.mean().item()
372 |
373 |
374 | def __compute_average_iou(layout: Layout, perceptual: bool = False) -> float:
375 | bbox, _ = layout
376 | N = bbox.shape[0]
377 | if N in [0, 1]:
378 | return 0.0 # no overlap in principle
379 |
380 | ii, jj = np.meshgrid(range(N), range(N))
381 | ii, jj = ii.flatten(), jj.flatten()
382 | is_non_diag = ii != jj # IoU for diag is always 1.0
383 | ii, jj = ii[is_non_diag], jj[is_non_diag]
384 |
385 | if perceptual:
386 | iou = compute_perceptual_iou(bbox[ii], bbox[jj])
387 | else:
388 | iou = compute_iou(bbox[ii], bbox[jj])
389 |
390 | # pick all pairs of overlapped objects
391 | cond = iou > np.finfo(np.float32).eps # to avoid very-small nonzero
392 | # return iou.mean().item()
393 | if len(iou[cond]) > 0:
394 | return iou[cond].mean().item()
395 | else:
396 | return 0.0
397 |
398 |
399 | def compute_average_iou(
400 | layouts: List[Layout],
401 | disable_parallel: bool = DISABLED,
402 | n_jobs: Optional[int] = None,
403 | ) -> Dict[str, float]:
404 | """
405 | Compute IoU between overlapping objects for each layout.
406 |     Note that lower is better here, unlike the usual IoU.
407 |
408 | Reference:
409 | Variational Transformer Networks for Layout Generation (CVPR2021)
410 | https://arxiv.org/abs/2104.02416
411 | Reference: (perceptual version)
412 | BLT: Bidirectional Layout Transformer for Controllable Layout Generation (ECCV2022)
413 | https://arxiv.org/abs/2112.05112
414 | """
415 | func1 = partial(__compute_average_iou, perceptual=True)
416 | func2 = partial(__compute_average_iou, perceptual=False)
417 |
418 | # single-thread process for debugging
419 | if disable_parallel:
420 | scores1 = [func1(l) for l in layouts]
421 | scores2 = [func2(l) for l in layouts]
422 | else:
423 | with multiprocessing.Pool(n_jobs) as p1:
424 | scores1 = p1.map(func1, layouts)
425 | with multiprocessing.Pool(n_jobs) as p2:
426 | scores2 = p2.map(func2, layouts)
427 | results = {
428 | "average_iou-BLT": np.array(scores1).mean().item(),
429 | "average_iou-VTN": np.array(scores2).mean().item(),
430 | }
431 | return results
432 |
433 |
434 | def __compute_bbox_sim(
435 | bboxes_1: np.ndarray,
436 | category_1: np.int64,
437 | bboxes_2: np.ndarray,
438 | category_2: np.int64,
439 | C_S: float = 2.0,
440 | C: float = 0.5,
441 | ) -> float:
442 |     # bboxes from different categories never match
443 | if category_1 != category_2:
444 | return 0.0
445 |
446 | cx1, cy1, w1, h1 = bboxes_1
447 | cx2, cy2, w2, h2 = bboxes_2
448 |
449 | delta_c = np.sqrt(np.power(cx1 - cx2, 2) + np.power(cy1 - cy2, 2))
450 | delta_s = np.abs(w1 - w2) + np.abs(h1 - h2)
451 | area = np.minimum(w1 * h1, w2 * h2)
452 | alpha = np.power(np.clip(area, 0.0, None), C)
453 |
454 | weight = alpha * np.power(2.0, -1.0 * delta_c - C_S * delta_s)
455 | return weight
456 |
457 |
458 | def __compute_docsim_between_two_layouts(
459 | layouts_1_layouts_2: Tuple[List[Layout]],
460 | max_diff_thresh: int = 3,
461 | ) -> float:
462 | layouts_1, layouts_2 = layouts_1_layouts_2
463 | bboxes_1, categories_1 = layouts_1
464 | bboxes_2, categories_2 = layouts_2
465 |
466 | N, M = len(bboxes_1), len(bboxes_2)
467 | if N >= M + max_diff_thresh or N <= M - max_diff_thresh:
468 | return 0.0
469 |
470 | ii, jj = np.meshgrid(range(N), range(M))
471 | ii, jj = ii.flatten(), jj.flatten()
472 | scores = np.asarray(
473 | [
474 | __compute_bbox_sim(
475 | bboxes_1[i], categories_1[i], bboxes_2[j], categories_2[j]
476 | )
477 | for i, j in zip(ii, jj)
478 | ]
479 | ).reshape(N, M)
480 | ii, jj = linear_sum_assignment(scores, maximize=True)
481 |
482 | if len(scores[ii, jj]) == 0:
483 | # sometimes, predicted bboxes are somehow filtered.
484 | return 0.0
485 | else:
486 | return scores[ii, jj].mean()
487 |
488 |
489 | def compute_docsim(
490 | layouts_gt: List[Layout],
491 | layouts_generated: List[Layout],
492 | disable_parallel: bool = DISABLED,
493 | n_jobs: Optional[int] = None,
494 | ) -> float:
495 | """
496 | Compute layout-to-layout similarity and average over layout pairs.
497 | Note that this is different from layouts-to-layouts similarity.
498 | """
499 | args = list(zip(layouts_gt, layouts_generated))
500 | if disable_parallel:
501 | scores = []
502 | for arg in args:
503 | scores.append(__compute_docsim_between_two_layouts(arg))
504 | else:
505 | with multiprocessing.Pool(n_jobs) as p:
506 | scores = p.map(__compute_docsim_between_two_layouts, args)
507 | return np.array(scores).mean()
508 |
509 |
510 | def _compute_wasserstein_distance_class(
511 | layouts_1: List[Layout],
512 | layouts_2: List[Layout],
513 | n_categories: int = 25,
514 | ) -> float:
515 | categories_1 = np.concatenate([l[1] for l in layouts_1])
516 | counts = np.array(
517 | [np.count_nonzero(categories_1 == i) for i in range(n_categories)]
518 | )
519 | prob_1 = counts / np.sum(counts)
520 |
521 | categories_2 = np.concatenate([l[1] for l in layouts_2])
522 | counts = np.array(
523 | [np.count_nonzero(categories_2 == i) for i in range(n_categories)]
524 | )
525 | prob_2 = counts / np.sum(counts)
526 | return np.absolute(prob_1 - prob_2).sum()
527 |
528 |
529 | def _compute_wasserstein_distance_bbox(
530 | layouts_1: List[Layout],
531 | layouts_2: List[Layout],
532 | ) -> float:
533 | bboxes_1 = np.concatenate([l[0] for l in layouts_1]).T
534 | bboxes_2 = np.concatenate([l[0] for l in layouts_2]).T
535 |
536 | # simple 1-dimensional wasserstein for (cx, cy, w, h) independently
537 | N = 4
538 | ans = 0.0
539 | for i in range(N):
540 | ans += wasserstein_distance(bboxes_1[i], bboxes_2[i])
541 | ans /= N
542 |
543 | return ans
544 |
545 |
546 | def compute_wasserstein_distance(
547 | layouts_1: List[Layout],
548 | layouts_2: List[Layout],
549 | n_classes: int = 25,
550 | ) -> Dict[str, float]:
551 | w_class = _compute_wasserstein_distance_class(layouts_1, layouts_2, n_classes)
552 | w_bbox = _compute_wasserstein_distance_bbox(layouts_1, layouts_2)
553 | return {
554 | "wdist_class": w_class,
555 | "wdist_bbox": w_bbox,
556 | }
557 |
558 |
559 | if __name__ == "__main__":
560 | layouts = [
561 | (
562 | np.array(
563 | [
564 | [0.2, 0.2, 0.4, 0.4],
565 | ]
566 | ),
567 | np.zeros((1,)),
568 | )
569 | ]
570 | print(compute_average_iou(layouts))
571 |
--------------------------------------------------------------------------------
/util/seq_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List, Tuple
4 | from torch import BoolTensor, FloatTensor, LongTensor
5 | from torch_geometric.utils import to_dense_batch
6 |
7 | def sparse_to_dense(
8 | batch,
9 | device: torch.device = torch.device("cpu"),
10 | remove_canvas: bool = False,
11 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]:
12 | batch = batch.to(device)
13 | bbox, _ = to_dense_batch(batch.x, batch.batch)
14 | label, mask = to_dense_batch(batch.y, batch.batch)
15 |
16 | if remove_canvas:
17 | bbox = bbox[:, 1:].contiguous()
18 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform
19 | label = label.clamp(min=0)
20 | mask = mask[:, 1:].contiguous()
21 |
22 | padding_mask = ~mask
23 | return bbox, label, padding_mask, mask
24 |
25 |
26 | def pad_sequence(seq: LongTensor, max_seq_length: int, value) -> LongTensor:
27 | S = seq.shape[1]
28 | new_shape = list(seq.shape)
29 | s = max_seq_length - S
30 | if s > 0:
31 | new_shape[1] = s
32 | pad = torch.full(new_shape, value, dtype=seq.dtype)
33 | new_seq = torch.cat([seq, pad], dim=1)
34 | else:
35 | new_seq = seq
36 |
37 | return new_seq
38 |
39 |
40 | def pad_until(label: LongTensor, bbox: FloatTensor, mask: BoolTensor, max_seq_length: int
41 | ) -> Tuple[LongTensor, FloatTensor, BoolTensor]:
42 | label = pad_sequence(label, max_seq_length, 0)
43 | bbox = pad_sequence(bbox, max_seq_length, 0)
44 | mask = pad_sequence(mask, max_seq_length, False)
45 | return label, bbox, mask
46 |
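# Illustrative sketch (added for documentation; not part of the original file).
# Typical use of the helpers above on a torch_geometric batch: densify the
# sparse batch, then pad every sequence to a fixed length. The `batch` argument
# and max_seq_length=25 are assumptions for illustration.
def _demo_densify_and_pad(batch, max_seq_length: int = 25):
    bbox, label, _, mask = sparse_to_dense(batch)
    label, bbox, mask = pad_until(label, bbox, mask, max_seq_length)
    return bbox, label, mask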
47 |
48 | def _to(inputs, device):
49 | """
50 | recursively send tensor to the specified device
51 | """
52 | outputs = {}
53 | for k, v in inputs.items():
54 | if isinstance(v, dict):
55 | outputs[k] = _to(v, device)
56 | elif isinstance(v, torch.Tensor):
57 | outputs[k] = v.to(device)
58 | return outputs
59 |
60 | def loader_to_list(
61 | loader: torch.utils.data.dataloader.DataLoader,
62 | ) -> List[Tuple[np.ndarray, np.ndarray]]:
63 | layouts = []
64 | for batch in loader:
65 | bbox, label, _, mask = sparse_to_dense(batch)
66 | for i in range(len(label)):
67 | valid = mask[i].numpy()
68 | layouts.append((bbox[i].numpy()[valid], label[i].numpy()[valid]))
69 | return layouts
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from einops import rearrange
7 | from torch import BoolTensor, FloatTensor, LongTensor
8 |
9 |
10 | def set_seed(seed: int):
11 | random.seed(seed)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 |
15 |
16 | def convert_xywh_to_ltrb(bbox: Union[np.ndarray, FloatTensor]):
17 | xc, yc, w, h = bbox
18 | x1 = xc - w / 2
19 | y1 = yc - h / 2
20 | x2 = xc + w / 2
21 | y2 = yc + h / 2
22 | return [x1, y1, x2, y2]
23 |
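# Illustrative sketch (added for documentation; not part of the original file).
# A quick numeric check of convert_xywh_to_ltrb(); the box values are
# assumptions for illustration.
def _demo_convert_xywh_to_ltrb():
    bbox = np.array([0.5, 0.5, 0.2, 0.4])  # (cx, cy, w, h)
    return convert_xywh_to_ltrb(bbox)      # [0.4, 0.3, 0.6, 0.7]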
24 |
25 | def batch_topk_mask(
26 | scores: FloatTensor,
27 | topk: LongTensor,
28 | mask: Optional[BoolTensor] = None,
29 | ) -> Tuple[BoolTensor, FloatTensor]:
30 | assert scores.ndim == 2 and topk.ndim == 1 and scores.size(0) == topk.size(0)
31 | if mask is not None:
32 | assert mask.size() == scores.size()
33 | assert (scores.size(1) >= topk).all()
34 |
35 | # ignore scores where mask = False by setting extreme values
36 | if mask is not None:
37 | const = -1.0 * float("Inf")
38 | const = torch.full_like(scores, fill_value=const)
39 | scores = torch.where(mask, scores, const)
40 |
41 | sorted_values, _ = torch.sort(scores, dim=-1, descending=True)
42 | topk = rearrange(topk, "b -> b 1")
43 |
44 | k_th_scores = torch.gather(sorted_values, dim=1, index=topk)
45 |
46 | topk_mask = scores > k_th_scores
47 | return topk_mask, k_th_scores
48 |
49 |
50 | def batch_shuffle_index(
51 | batch_size: int,
52 | feature_length: int,
53 | mask: Optional[BoolTensor] = None,
54 | ) -> LongTensor:
55 | """
56 |     Note: the masked part may still be shuffled, since sorting [inf, ..., inf] gives no stable order
57 |     """
58 |     if mask is not None:
59 |         assert mask.size() == (batch_size, feature_length)
60 |     scores = torch.rand((batch_size, feature_length))
61 |     if mask is not None:
62 | scores[~mask] = float("Inf")
63 | _, indices = torch.sort(scores, dim=1)
64 | return indices
65 |
66 |
67 | if __name__ == "__main__":
68 | scores = torch.arange(6).view(2, 3).float()
69 | # topk = torch.arange(2) + 1
70 | topk = torch.full((2,), 3)
71 | mask = torch.full((2, 3), False)
72 | # mask[1, 2] = False
73 | print(batch_topk_mask(scores, topk, mask=mask))
74 |
--------------------------------------------------------------------------------
/util/visualization.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from collections import Counter
3 | from typing import Dict, List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as T
8 | import torchvision.utils as vutils
9 | from einops import rearrange
10 | from numpy.typing import NDArray
11 | from PIL import Image, ImageDraw, ImageFont
12 | from torch import BoolTensor, FloatTensor, LongTensor
13 |
14 | from .util import convert_xywh_to_ltrb
15 |
16 | label_names = dict()
17 | label_names['rico25'] = ("Text", "Image", "Icon", "Text Button", "List Item", "Input", "Background Image", "Card",
18 | "Web View", "Radio Button", "Drawer", "Checkbox", "Advertisement", "Modal", "Pager Indicator",
19 | "Slider", "On/Off Switch", "Button Bar", "Toolbar", "Number Stepper", "Multi-Tab", "Date Picker",
20 | "Map View", "Video", "Bottom Navigation")
21 |
22 | label_names['rico13'] = ("Toolbar", "Image", "Text", "Icon", "Text Button", "Input", "List Item", "Advertisement",
23 | "Pager Indicator", "Web View", "Background Image", "Drawer", "Modal")
24 |
25 | label_names['magazine'] = ('text', 'image', 'headline', 'text-over-image', 'headline-over-image', 'background')
26 | label_names['publaynet'] = ('text', 'title', 'list', 'table', 'figure')
27 | label_names['crello'] = ['coloredBackground', 'imageElement', 'maskElement', 'svgElement', 'textElement']
28 | color_6 = ((246, 112, 136), (173, 156, 49), (51, 176, 122), (56, 168, 197), (204, 121, 244), (204, 50, 144))
29 |
30 | def color_extend(n_color):
31 |
32 | colors = ((246, 112, 136), (173, 156, 49), (51, 176, 122), (56, 168, 197), (204, 121, 244))
33 | batch_size = np.ceil(n_color / 4).astype(int)
34 |
35 | color_batch_1 = np.linspace(colors[0], colors[1], batch_size)[:-1].astype(int)
36 | color_batch_2 = np.linspace(colors[1], colors[2], batch_size)[:-1].astype(int)
37 | color_batch_3 = np.linspace(colors[2], colors[3], batch_size)[:-1].astype(int)
38 | color_batch_4 = np.linspace(colors[3], colors[4], batch_size).astype(int)
39 | colors_long = np.concatenate([color_batch_1, color_batch_2, color_batch_3, color_batch_4])
40 | colors_long = tuple(map(tuple, colors_long))
41 | return colors_long
42 |
43 | def convert_layout_to_image(
44 | boxes: FloatTensor,
45 | labels: LongTensor,
46 | colors: List[Tuple[int]],
47 | canvas_size: Optional[Tuple[int]] = (60, 40),
48 | resources: Optional[Dict] = None,
49 | names: Optional[Tuple[str]] = None,
50 | index: bool = False,
51 | **kwargs,
52 | ):
53 | H, W = canvas_size
54 | if names or index:
55 | # font = ImageFont.truetype("LiberationSerif-Regular", W // 10)
56 | font = ImageFont.load_default()
57 |
58 | if resources:
59 | img = resources["img_bg"].resize((W, H))
60 | else:
61 | img = Image.new("RGB", (int(W), int(H)), color=(255, 255, 255))
62 | draw = ImageDraw.Draw(img, "RGBA")
63 |
64 | # draw from larger boxes
65 | a = [b[2] * b[3] for b in boxes]
66 | indices = sorted(range(len(a)), key=lambda i: a[i], reverse=True)
67 |
68 | for i in indices:
69 | bbox, label = boxes[i], labels[i]
70 | if isinstance(label, LongTensor):
71 | label = label.item()
72 |
73 | c_fill = colors[label] + (100,)
74 | x1, y1, x2, y2 = convert_xywh_to_ltrb(bbox)
75 | x1, x2 = x1 * (W - 1), x2 * (W - 1)
76 | y1, y2 = y1 * (H - 1), y2 * (H - 1)
77 |
78 | if resources:
79 | patch = resources["cropped_patches"][i]
80 | # round coordinates for exact size match for rendering images
81 | x1, x2 = int(x1), int(x2)
82 | y1, y2 = int(y1), int(y2)
83 | w, h = x2 - x1, y2 - y1
84 | patch = patch.resize((w, h))
85 | img.paste(patch, (x1, y1))
86 | else:
87 | draw.rectangle([x1, y1, x2, y2], outline=colors[label], fill=c_fill)
88 | if names:
89 | # draw.text((x1, y1), names[label], colors[label], font=font)
90 | draw.text((x1, y1), names[label], "black", font=font)
91 | elif index:
92 | draw.text((x1, y1), str(int(i % (len(labels)/2))), "black", font=font)
93 |
94 | return img
95 |
96 |
97 | def save_image(
98 | batch_boxes: FloatTensor,
99 | batch_labels: LongTensor,
100 | batch_mask: BoolTensor,
101 | out_path: Optional[Union[pathlib.PosixPath, str]] = None,
102 | canvas_size: Optional[Tuple[int]] = (360, 240),
103 | nrow: Optional[int] = None,
104 | batch_resources: Optional[Dict] = None,
105 | use_grid: bool = True,
106 | draw_label: bool = False,
107 | draw_index: bool = False,
108 | dataset: str = 'publaynet'
109 | ):
110 | # batch_boxes: [B, N, 4]
111 | # batch_labels: [B, N]
112 | # batch_mask: [B, N]
113 |
114 | assert dataset in ['rico13', 'rico25', 'publaynet', 'magazine', 'crello']
115 |
116 | if isinstance(out_path, pathlib.PosixPath):
117 | out_path = str(out_path)
118 |
119 | if dataset == 'rico13':
120 | colors = color_extend(13)
121 | elif dataset == 'rico25':
122 | colors = color_extend(25)
123 | elif dataset in ['publaynet', 'magazine', 'crello']:
124 | colors = color_6
125 |
126 | if not draw_label:
127 | names = None
128 | else:
129 | names = label_names[dataset]
130 |
131 | # raise Exception('dataset must be rico or publaynet')
132 |
133 | imgs = []
134 | B = batch_boxes.size(0)
135 | to_tensor = T.ToTensor()
136 | for i in range(B):
137 | mask_i = batch_mask[i]
138 | boxes = batch_boxes[i][mask_i]
139 | labels = batch_labels[i][mask_i]
140 | if batch_resources:
141 | resources = {k: v[i] for (k, v) in batch_resources.items()}
142 | img = convert_layout_to_image(boxes, labels, colors, canvas_size, resources, names=names, index=draw_index)
143 | else:
144 | img = convert_layout_to_image(boxes, labels, colors, canvas_size, names=names, index=draw_index)
145 | imgs.append(to_tensor(img))
146 | image = torch.stack(imgs)
147 |
148 | if nrow is None:
149 | nrow = int(np.ceil(np.sqrt(B)))
150 |
151 | if out_path:
152 | vutils.save_image(image, out_path, normalize=False, nrow=nrow)
153 | else:
154 | if use_grid:
155 | return torch_to_numpy_image(
156 | vutils.make_grid(image, normalize=False, nrow=nrow)
157 | )
158 | else:
159 | return image
160 |
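# Illustrative sketch (added for documentation; not part of the original file).
# Minimal call of save_image() above with dummy tensors; the shapes and the
# PubLayNet label ids (0=text, 1=title) are assumptions for illustration.
def _demo_save_image():
    boxes = torch.tensor([[[0.5, 0.10, 0.8, 0.1],
                           [0.5, 0.55, 0.8, 0.6]]])      # [B=1, N=2, 4] in (cx, cy, w, h)
    labels = torch.tensor([[1, 0]])                      # [B=1, N=2]
    mask = torch.ones(1, 2, dtype=torch.bool)            # all elements valid
    return save_image(boxes, labels, mask, draw_label=True, dataset='publaynet')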
161 |
162 | def save_label(
163 | labels: Union[LongTensor, list],
164 | names: List[str],
165 | colors: List[Tuple[int]],
166 | out_path: Optional[Union[pathlib.PosixPath, str]] = None,
167 | **kwargs
168 | # canvas_size: Optional[Tuple[int]] = (60, 40),
169 | ):
170 | space, pad = 12, 12
171 | x_offset, y_offset = 500, 100
172 |
173 | img = Image.new("RGBA", (1000, 1000))
174 | # fnt = ImageFont.truetype("LiberationSerif-Regular", 40)
175 | fnt = ImageFont.load_default()
176 | # fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 32)
177 | fnt_sm = ImageFont.load_default()
178 | d = ImageDraw.Draw(img)
179 |
180 | if isinstance(labels, LongTensor):
181 | labels = labels.tolist()
182 |
183 | cnt = Counter(labels)
184 | for l in range(len(colors)):
185 | if l not in cnt.keys():
186 | continue
187 |
188 | text = names[l]
189 | use_multiline = False
190 |
191 | if cnt[l] > 1:
192 | add_width = d.textsize(f" × {cnt[l]}", font=fnt)[0] + pad
193 | else:
194 | add_width = 0
195 |
196 | width = d.textsize(text, font=fnt)[0]
197 | bbox = d.textbbox(
198 | (x_offset - (width + add_width) / 2, y_offset), text, font=fnt
199 | )
200 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
201 | d.rectangle(bbox, fill=colors[l])
202 |
203 | _x_offset = x_offset - (width + add_width) / 2
204 | if cnt[l] > 1:
205 | d.text(
206 | (_x_offset + width + pad, y_offset),
207 | f" × {cnt[l]}",
208 | font=fnt,
209 | fill=(0, 0, 0),
210 | )
211 |
212 | d.text((_x_offset, y_offset), text, font=fnt, fill=(255, 255, 255))
213 |
214 | y_offset = y_offset + bbox[3] - bbox[1] + space
215 |
216 | # crop
217 | bbox = img.getbbox()
218 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
219 | img = img.crop(bbox)
220 |
221 | # add white background
222 | out = Image.new("RGB", img.size, color=(255, 255, 255))
223 | out.paste(img, mask=img)
224 | # pil_size = canvas_size[::-1] # (H, W) -> (W, H)
225 | # out = out.resize(pil_size)
226 | if out_path:
227 | out.save(out_path)
228 | else:
229 | return np.array(out)
230 |
231 |
232 | def save_label_with_size(
233 | labels: LongTensor,
234 | boxes: FloatTensor,
235 | names: List[str],
236 | colors: List[Tuple[int]],
237 | out_path: Optional[Union[pathlib.PosixPath, str]] = None,
238 | # canvas_size: Optional[Tuple[int]] = (60, 40),
239 | **kwargs,
240 | ):
241 | space, pad = 12, 12
242 | x_offset, y_offset = 500, 100
243 | B = 32
244 |
245 | img = Image.new("RGBA", (1000, 1000))
246 | # fnt = ImageFont.truetype("LiberationSerif-Regular", 40)
247 | fnt = ImageFont.load_default()
248 | # fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 32)
249 | fnt_sm = ImageFont.load_default()
250 | d = ImageDraw.Draw(img)
251 |
252 | for i, l in enumerate(labels):
253 | w, h = [int(x) for x in (boxes[i].clip(1 / B, 1.0) * B).long()[2:]]
254 | text = f"{names[l]} ({w},{h})"
255 | add_width = 0
256 |
257 | width = d.textsize(text, font=fnt)[0]
258 | bbox = d.textbbox(
259 | (x_offset - (width + add_width) / 2, y_offset), text, font=fnt
260 | )
261 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
262 | d.rectangle(bbox, fill=colors[l])
263 |
264 | _x_offset = x_offset - (width + add_width) / 2
265 | d.text((_x_offset, y_offset), text, font=fnt, fill=(255, 255, 255))
266 | y_offset = y_offset + bbox[3] - bbox[1] + space
267 |
268 | # crop
269 | bbox = img.getbbox()
270 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
271 | img = img.crop(bbox)
272 |
273 | # add white background
274 | out = Image.new("RGB", img.size, color=(255, 255, 255))
275 | out.paste(img, mask=img)
276 | # pil_size = canvas_size[::-1] # (H, W) -> (W, H)
277 | # out = out.resize(pil_size)
278 | if out_path:
279 | out.save(out_path)
280 | else:
281 | return np.array(out)
282 |
283 |
284 | def torch_to_numpy_image(input_th: FloatTensor) -> np.ndarray:
285 | """
286 | Args
287 |         input_th: (C, H, W), in [0.0, 1.0], torch image
288 | Returns
289 | output_npy: (H, W, C), in {0, 1, ..., 255}, numpy image
290 | """
291 | x = (input_th * 255.0).clamp(0, 255)
292 | x = rearrange(x, "c h w -> h w c")
293 | output_npy = x.numpy().astype(np.uint8)
294 | return output_npy
295 |
296 |
297 | def save_relation(
298 | label_with_canvas: LongTensor,
299 | edge_attr: LongTensor,
300 | names: List[str],
301 | colors: List[Tuple[int]],
302 | out_path: Optional[Union[pathlib.PosixPath, str]] = None,
303 | **kwargs,
304 | ):
305 | from trainer.data.util import ( # lazy load to avoid circular import
306 | RelLoc,
307 | RelSize,
308 | get_rel_text,
309 | )
310 |
311 | pairs, triplets = [], []
312 | relations = list(RelSize) + list(RelLoc)
313 | for rel_value in relations:
314 | if rel_value in [RelSize.UNKNOWN, RelLoc.UNKNOWN]:
315 | continue
316 | cond = edge_attr & 1 << rel_value
317 | ii, jj = np.where(cond.numpy() > 0)
318 | for i, j in zip(ii, jj):
319 | li = label_with_canvas[i] - 1
320 | lj = label_with_canvas[j] - 1
321 |
322 | if i == 0:
323 | rel = get_rel_text(rel_value, canvas=True)
324 | pairs.append((lj, rel, None))
325 | else:
326 | rel = get_rel_text(rel_value, canvas=False)
327 | triplets.append((li, rel, lj))
328 |
329 | triplets = pairs + triplets
330 |
331 | space, pad = 6, 6
332 | img = Image.new("RGBA", (1000, 1000))
333 | fnt = ImageFont.truetype("LiberationSerif-Regular", 20)
334 | fnt_sm = ImageFont.truetype("LiberationSerif-Regular", 16)
335 | d = ImageDraw.Draw(img)
336 |
337 | def draw_text(x_offset, y_offset, text, color=None, first=False):
338 | if color is None:
339 | d.text((x_offset, y_offset), text, font=fnt, fill=(0, 0, 0))
340 | x_offset = x_offset + d.textsize(text, font=fnt)[0] + space
341 | else:
342 | x_offset = x_offset + pad
343 |
344 | use_multiline = False
345 | bbox = d.textbbox((x_offset, y_offset), text, font=fnt)
346 | if bbox[2] - bbox[0] > 120 and " " in text:
347 | use_multiline = True
348 | h_old = d.textsize(text, font=fnt)[1]
349 | text = text.replace(" ", "\n")
350 | h_new = d.multiline_textsize(text, font=fnt_sm)[1]
351 | h_diff = h_new - h_old
352 | if first:
353 | y_offset = y_offset + h_diff / 2
354 | bbox = d.multiline_textbbox(
355 | (x_offset, y_offset - h_diff / 2), text, font=fnt_sm
356 | )
357 |
358 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
359 | d.rectangle(bbox, fill=color)
360 |
361 | if use_multiline:
362 | d.multiline_text(
363 | (x_offset, y_offset - h_diff / 2),
364 | text,
365 | align="center",
366 | font=fnt_sm,
367 | fill=(255, 255, 255),
368 | )
369 | text_width = d.multiline_textsize(text, font=fnt_sm)[0]
370 | else:
371 | d.text((x_offset, y_offset), text, font=fnt, fill=(255, 255, 255))
372 | text_width = d.textsize(text, font=fnt)[0]
373 |
374 | x_offset = x_offset + text_width + space + pad
375 | return x_offset, y_offset
376 |
377 | for i, (l1, rel, l2) in enumerate(triplets):
378 | x_offset, y_offset = 20, 40 * (i + 1)
379 | x_offset, y_offset = draw_text(
380 | x_offset, y_offset, names[l1], colors[l1], first=True
381 | )
382 | x_offset, y_offset = draw_text(x_offset, y_offset, rel)
383 | if l2 is not None:
384 | draw_text(x_offset, y_offset, names[l2], colors[l2])
385 |
386 | # crop
387 | bbox = img.getbbox()
388 | if bbox is not None:
389 | bbox = (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad)
390 | img = img.crop(bbox)
391 |
392 | # add white background
393 | out = Image.new("RGB", img.size, color=(255, 255, 255))
394 | out.paste(img, mask=img)
395 |
396 | if out_path:
397 | out.save(out_path)
398 | else:
399 | return np.array(out)
400 |
401 |
402 | def save_gif(
403 | images: List[NDArray],
404 | out_path: str,
405 | **kwargs,
406 | ):
407 | assert images[0].ndim == 4
408 | to_pil = T.ToPILImage()
409 | for i in range(len(images[0])):
410 | tmp_images = [to_pil(image[i]) for image in images]
411 | tmp_images[0].save(
412 | # f"tmp/animation/{i}.gif",
413 | out_path.format(i),
414 | save_all=True,
415 | append_images=tmp_images[1:],
416 | optimize=False,
417 | duration=200,
418 | loop=0,
419 | )
420 |
421 |
422 |
423 | if __name__ == "__main__":
424 |
425 |     print(len(label_names['rico25']))
426 |
--------------------------------------------------------------------------------