├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── bin ├── clustering_coordinates.py ├── test.sh └── train.sh ├── docs └── custom_dataset.md ├── eval.py ├── notebooks └── demo.ipynb ├── poetry.lock ├── pyproject.toml ├── render.py └── src └── trainer └── trainer ├── __init__.py ├── config ├── __init__.py ├── backbone │ └── medium.yaml ├── data │ └── default.yaml ├── dataset │ ├── publaynet.yaml │ └── rico25.yaml ├── experiment │ ├── bart.yaml │ ├── blt_eccv2022.yaml │ ├── blt_eccv2022_ordered.yaml │ ├── diffusionlm_neurips2022.yaml │ ├── layout_transformer_iccv2021.yaml │ ├── layout_transformer_iccv2021_ordered.yaml │ ├── layoutdm.yaml │ ├── maskgit_cvpr2022.yaml │ ├── maskgit_cvpr2022_ordered.yaml │ ├── ruite.yaml │ └── vqdiffusion.yaml ├── main.yaml ├── model │ ├── bart.yaml │ ├── blt.yaml │ ├── elem_wise_autoreg.yaml │ ├── layout_continuous_diffusion.yaml │ ├── layoutdm.yaml │ ├── maskgit.yaml │ └── ruite.yaml ├── optimizer │ └── adamw.yaml ├── scheduler │ ├── inverse_sqrt_decay_with_warmup.yaml │ ├── reduce_lr_on_plateau.yaml │ ├── reduce_lr_on_plateau_with_warmup.yaml │ └── void.yaml └── training │ └── default.yaml ├── crossplatform_util.py ├── data ├── __init__.py └── util.py ├── datasets ├── __init__.py ├── base.py ├── dataset.py ├── publaynet.py └── rico.py ├── fid ├── __init__.py ├── model.py └── train.py ├── global_configs.py ├── helpers ├── __init__.py ├── bbox_tokenizer.py ├── clustering.py ├── layout_tokenizer.py ├── mask.py ├── metric.py ├── sampling.py ├── scheduler.py ├── task.py ├── util.py └── visualization.py ├── hydra_configs.py ├── main.py ├── models ├── __init__.py ├── bart.py ├── base_model.py ├── blt.py ├── categorical_diffusion │ ├── __init__.py │ ├── base.py │ ├── constrained.py │ ├── logit_adjustment.py │ ├── util.py │ └── vanilla.py ├── clg │ ├── __init__.py │ └── const.py ├── common │ ├── __init__.py │ ├── layout.py │ ├── nn_lib.py │ └── util.py ├── continuous_diffusion │ ├── __init__.py │ ├── base.py │ ├── bitdiffusion.py │ └── diffusion_lm.py ├── elem_wise_autoreg.py ├── layout_continuous_diffusion.py ├── layoutdm.py ├── maskgit.py ├── ruite.py └── transformer_utils.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | multirun 2 | download 3 | *.csv 4 | *.pdf 5 | *.png 6 | *.zip 7 | *.pt 8 | *.pth.tar 9 | 10 | ### Generated by gibo (https://github.com/simonwhitaker/gibo) 11 | ### https://raw.github.com/github/gitignore/b2ccc4644b997fa2e86da5ae37f3b053c39f3d7b/Python.gitignore 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | # .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # pytype static type analyzer 147 | .pytype/ 148 | 149 | # Cython debug symbols 150 | cython_debug/ 151 | 152 | 153 | ### Generated by gibo (https://github.com/simonwhitaker/gibo) 154 | ### https://raw.github.com/github/gitignore/b2ccc4644b997fa2e86da5ae37f3b053c39f3d7b/Global/macOS.gitignore 155 | 156 | # General 157 | .DS_Store 158 | .AppleDouble 159 | .LSOverride 160 | 161 | # Icon must end with two \r 162 | Icon 163 | 164 | 165 | # Thumbnails 166 | ._* 167 | 168 | # Files that might appear in the root of a volume 169 | .DocumentRevisions-V100 170 | .fseventsd 171 | .Spotlight-V100 172 | .TemporaryItems 173 | .Trashes 174 | .VolumeIcon.icns 175 | .com.apple.timemachine.donotpresent 176 | 177 | # Directories potentially created on remote AFP share 178 | .AppleDB 179 | .AppleDesktop 180 | Network Trash Folder 181 | Temporary Items 182 | .apdisk 183 | 184 | 185 | tmp 186 | .hydra 187 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/trainer/trainer/helpers/--recursive"] 2 | path = src/trainer/trainer/helpers/--recursive 3 | url = https://github.com/cocodataset/cocoapi.git 4 | [submodule "src/trainer/trainer/helpers/cocoapi"] 5 | path = src/trainer/trainer/helpers/cocoapi 6 | url = https://github.com/cocodataset/cocoapi.git 7 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # LayoutDM: Discrete Diffusion Model for Controllable Layout Generation (CVPR2023) 2 | This repository is an official implementation of the paper titled above. 3 | Please refer to [project page](https://cyberagentailab.github.io/layout-dm/) or [paper](https://arxiv.org/abs/2303.08137) for more details. 4 | 5 | ## Setup 6 | Here we describe the setup required for model training and evaluation. 7 | 8 | ### Requirements 9 | We have checked reproducibility under the following environment. 10 | - Python3.7 11 | - CUDA 11.3 12 | - [PyTorch](https://pytorch.org/get-started/locally/) 1.10 13 | 14 | We recommend using Poetry (all settings and dependencies are in [pyproject.toml](pyproject.toml)). 15 | PyTorch Geometric provides pre-built wheels for each *combination* of PyTorch and CUDA versions (see [PyG: Installation](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html 16 | ) for details). If your environment does not match the one above, please update the dependencies. 17 | 18 | 19 | ### How to install 20 | 1. Install Poetry (see [official docs](https://python-poetry.org/docs/)). We recommend creating a virtualenv and installing Poetry inside it. 21 | 22 | ```bash 23 | curl -sSL https://install.python-poetry.org | python3 - 24 | ``` 25 | 26 | 2. Install dependencies (this may take a while). 27 | 28 | ```bash 29 | poetry install 30 | ``` 31 | 32 | 3. Download the resources and unzip them. 33 | 34 | ```bash 35 | wget https://github.com/CyberAgentAILab/layout-dm/releases/download/v1.0.0/layoutdm_starter.zip 36 | unzip layoutdm_starter.zip 37 | ``` 38 | 39 | The data is decompressed into the following structure: 40 | ``` 41 | download 42 | - clustering_weights 43 | - datasets 44 | - fid_weights 45 | - pretrained_weights 46 | ``` 47 | 48 | ## Experiment 49 | **Important**: we found some critical errors related to multi-GPU execution that cannot be fixed quickly. Please set `CUDA_VISIBLE_DEVICES=<GPU_ID>` to force the model to use a single GPU. 50 | 51 | Note: our main framework is based on [hydra](https://hydra.cc/). It is convenient for handling dozens of arguments hierarchically, but may require some additional effort if one is new to Hydra. 52 | 53 | ### Demo 54 | Please run the Jupyter notebook [notebooks/demo.ipynb](notebooks/demo.ipynb). You can generate and render the results of six layout generation tasks on two datasets (Rico and PubLayNet). 55 | 56 | ### Training 57 | You can also train your own model from scratch, for example by 58 | 59 | ```bash 60 | bash bin/train.sh rico25 layoutdm 61 | ``` 62 | 63 | where the first and second arguments specify the dataset ([choices](src/trainer/trainer/config/dataset)) and the type of experiment ([choices](src/trainer/trainer/config/experiment)), respectively. 64 | Note that for training/testing, arguments are given in `key=value` style because we use Hydra, unlike the popular `--key value` style (e.g., [argparse](https://docs.python.org/3/library/argparse.html)).
65 | 66 | ### Testing 67 | 68 | ```bash 69 | poetry run python3 -m src.trainer.trainer.test \ 70 | cond=<COND> \ 71 | job_dir=<JOB_DIR> \ 72 | result_dir=<RESULT_DIR> \ 73 | <ADDITIONAL_ARGS> 74 | ``` 75 | `<COND>` can be: (unconditional, c, cwh, partial, refinement, relation) 76 | 77 | For example, if you want to test the provided LayoutDM model on `C->S+P`, the command is as follows: 78 | ``` 79 | poetry run python3 -m src.trainer.trainer.test cond=c dataset_dir=./download/datasets job_dir=./download/pretrained_weights/layoutdm_rico result_dir=tmp/dummy_results 80 | ``` 81 | 82 | Please refer to [TestConfig](src/trainer/trainer/hydra_configs.py#L12) for the available options. 83 | Below are some popular options: 84 | - `is_validation=true`: evaluate the generation performance on the validation set instead of the test set. This must be used when tuning the hyper-parameters. 85 | - `sampling=top_p top_p=<P>`: use top-p sampling with p=`<P>` instead of the default sampling. 86 | 87 | ### Evaluation 88 | ```bash 89 | poetry run python3 eval.py <RESULT_DIR> 90 | ``` 91 | 92 | ## Citation 93 | 94 | If you find this code useful for your research, please cite our paper: 95 | 96 | ``` 97 | @inproceedings{inoue2023layout, 98 | title={{LayoutDM: Discrete Diffusion Model for Controllable Layout Generation}}, 99 | author={Naoto Inoue and Kotaro Kikuchi and Edgar Simo-Serra and Mayu Otani and Kota Yamaguchi}, 100 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 101 | year={2023}, 102 | pages={10167-10176}, 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /bin/clustering_coordinates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | import time 5 | from pathlib import Path 6 | 7 | import omegaconf 8 | import torch 9 | from fsspec.core import url_to_fs 10 | from hydra.utils import instantiate 11 | from sklearn.cluster import KMeans 12 | from trainer.global_configs import DATASET_DIR 13 | from trainer.helpers.clustering import Percentile 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | KEYS = ["x", "y", "w", "h"] 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("dataset_yaml", type=str) 21 | parser.add_argument("algorithm", type=str, choices=["kmeans", "percentile"]) 22 | parser.add_argument("--result_dir", type=str, default="tmp/clustering_weights") 23 | parser.add_argument("--random_state", type=int, default=0) 24 | parser.add_argument( 25 | "--max_bbox_num", 26 | type=int, 27 | default=int(1e5), 28 | help="filter number of bboxes to avoid too much time consumption in kmeans", 29 | ) 30 | 31 | args = parser.parse_args() 32 | fs, _ = url_to_fs(args.dataset_yaml) 33 | n_clusters_list = [2**i for i in range(1, 9)] 34 | 35 | dataset_cfg = omegaconf.OmegaConf.load(args.dataset_yaml) 36 | dataset_cfg["dir"] = DATASET_DIR 37 | dataset = instantiate(dataset_cfg)(split="train", transform=None) 38 | bboxes = torch.cat([e.x for e in dataset], axis=0) 39 | 40 | models = {} 41 | name = Path(args.dataset_yaml).stem 42 | weight_path = f"{args.result_dir}/{name}_max{dataset_cfg.max_seq_length}_{args.algorithm}_train_clusters.pkl" 43 | 44 | if bboxes.size(0) > args.max_bbox_num and args.algorithm == "kmeans": 45 | text = f"{bboxes.size(0)} -> {args.max_bbox_num}" 46 | logger.warning( 47 | f"Subsampling bboxes because there are too many for kmeans: ({text})" 48 | ) 49 | generator = torch.Generator().manual_seed(args.random_state) 50 | indices =
torch.randperm(bboxes.size(0), generator=generator) 51 | bboxes = bboxes[indices[: args.max_bbox_num]] 52 | 53 | for n_clusters in n_clusters_list: 54 | start_time = time.time() 55 | if args.algorithm == "kmeans": 56 | kwargs = {"n_clusters": n_clusters, "random_state": args.random_state} 57 | # one variable 58 | for i, key in enumerate(KEYS): 59 | key = f"{key}-{n_clusters}" 60 | models[key] = KMeans(**kwargs).fit(bboxes[..., i : i + 1].numpy()) 61 | elif args.algorithm == "percentile": 62 | kwargs = {"n_clusters": n_clusters} 63 | for i, key in enumerate(KEYS): 64 | key = f"{key}-{n_clusters}" 65 | models[key] = Percentile(**kwargs).fit(bboxes[..., i : i + 1].numpy()) 66 | print( 67 | f"{name} ({args.algorithm} {n_clusters} clusters): {time.time() - start_time}s" 68 | ) 69 | 70 | with fs.open(weight_path, "wb") as f: 71 | pickle.dump(models, f, protocol=pickle.HIGHEST_PROTOCOL) 72 | -------------------------------------------------------------------------------- /bin/test.sh: -------------------------------------------------------------------------------- 1 | poetry run python3 -m src.trainer.trainer.test \ 2 | cond= \ # choices: (unconditional, c, cwh, partial, refinement, relation) 3 | dataset_dir=./download/datasets \ 4 | job_dir= \ 5 | result_dir= 6 | -------------------------------------------------------------------------------- /bin/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATASET=$1 3 | EXPERIMENT=$2 4 | 5 | if [ "${DATASET}" = "" ]; then 6 | echo "Please specify DATASET as the first args" 7 | exit; 8 | fi 9 | 10 | if [ "${EXPERIMENT}" = "" ]; then 11 | echo "Please specify EXPERIMENT as the second args" 12 | exit; 13 | fi 14 | 15 | ADDITIONAL_ARGS="" 16 | if [ "${3}" != "" ]; then 17 | ADDITIONAL_ARGS="${ADDITIONAL_ARGS} ${@:3}" 18 | fi 19 | 20 | NOW=$(date "+%Y%m%d%H%M%S") 21 | VCPU=16 22 | 23 | DATA_DIR="./download/datasets" 24 | JOB_DIR="tmp/jobs/${DATASET}/${EXPERIMENT}${OPTION}_${NOW}" 25 | FID_WEIGHTS_DIR="./download/fid_weights/FIDNetV3" 26 | SEEDS=0 27 | 28 | echo "DATA_DIR=${DATA_DIR}" 29 | echo "JOB_DIR=${JOB_DIR}" 30 | echo "ADDITIONAL_ARGS=${ADDITIONAL_ARGS}" 31 | 32 | SHARED_DEFAULT_ARGS="--multirun +experiment=${EXPERIMENT} fid_weight_dir=${FID_WEIGHTS_DIR} job_dir=${JOB_DIR} dataset=${DATASET} dataset.dir=${DATA_DIR} data.num_workers=${VCPU} seed=${SEEDS}" 33 | 34 | poetry run python -m trainer.main ${SHARED_DEFAULT_ARGS} ${ADDITIONAL_ARGS} 35 | -------------------------------------------------------------------------------- /docs/custom_dataset.md: -------------------------------------------------------------------------------- 1 | # Training on custom dataset 2 | 3 | Note: Please run all the scripts at the root of this project to make sure to invoke `poetry` command. 4 | 5 | ### 1. Make a config file for the dataset 6 | 7 | Make a yaml file describing the dataset and put it under [this directory](../src/trainer/trainer/config/dataset/). 8 | This yaml is parsed and used as the input for [hydra.utils.instantiate](https://hydra.cc/docs/advanced/instantiate_objects/overview/) to initialize the dataset class. 9 | For example, the config for Rico dataset ([rico25.yaml](../src/trainer/trainer/config/dataset/rico25.yaml)) is currently as follows: 10 | 11 | ```yaml 12 | _target_: trainer.datasets.rico.Rico25Dataset 13 | _partial_: true 14 | dir: ??? 15 | max_seq_length: 25 16 | ``` 17 | 18 | ### 2. Implement a dataset class 19 | 20 | Implement a dataset class that is defined above. 
It should inherit `BaseDataset` in [base.py](../src/trainer/trainer/datasets/base.py) and override the `preprocess` function to conduct dataset-specific preprocessing and the train-val-test split. For example, please refer to `Rico25Dataset` in [rico.py](../src/trainer/trainer/datasets/rico.py). 21 | 22 | Modify `DATASET_DIR` in [global_configs.py](../src/trainer/trainer/global_configs.py), so that your dataset is used. 23 | `DATASET_DIR` should have the following structure. 24 | ``` 25 | DATASET_DIR 26 | - raw 27 | - processed 28 | ``` 29 | `DATASET_DIR/raw` contains the raw dataset files. 30 | `DATASET_DIR/processed` contains the auto-generated preprocessed splits and meta information. 31 | 32 | ### 3. Train a layout classifier for FID computation 33 | 34 | Following [LayoutGAN++](https://arxiv.org/abs/2108.00871), we compute FID during and after the training of LayoutDM for validation and testing, respectively. In order to do so, we first train a Transformer-based model that can extract discriminative layout features, which is used to compute the FID. This is done by: 35 | 36 | ```bash 37 | poetry run python3 src/trainer/trainer/fid/train.py --out_dir <OUT_DIR> 38 | ``` 39 | 40 | After the training, modify `FID_WEIGHT_DIR` in [global_configs.py](../src/trainer/trainer/global_configs.py), so that the trained weights are used for FID computation later. 41 | 42 | ### (4. Clustering coordinates of layouts) 43 | If one wants to apply adaptive quantization for position and size tokens, please first conduct clustering: 44 | ```bash 45 | poetry run python3 bin/clustering_coordinates.py <DATASET_YAML> <ALGORITHM> --result_dir <RESULT_DIR> 46 | ``` 47 | 48 | After the clustering, modify `KMEANS_WEIGHT_ROOT` in [global_configs.py](../src/trainer/trainer/global_configs.py), so that the cluster centroids are loaded later. 49 | 50 | ### 5. Train your own model 51 | ```bash 52 | bash bin/train.sh rico25 layoutdm 53 | ``` 54 | 55 | # Testing on custom dataset 56 | If you want to feed a hand-made layout to LayoutDM, the quickest way is to instantiate [Data](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data).
57 | 58 | ```python 59 | from torch_geometric.data import Data 60 | 61 | # [xc, yc, w, h] format in 0~1 normalized coordinates 62 | bboxes = torch.FloatTensor([ 63 | [0.4985, 0.0968, 0.4990, 0.0153], 64 | [0.4986, 0.5134, 0.8288, 0.0285], 65 | [0.4986, 0.2918, 0.8289, 0.3573], 66 | ]) 67 | # see .labels of each dataset class for name-index correspondense 68 | labels = torch.LongTensor([0, 0, 3]) 69 | assert bboxes.size(0) == labels.size(0) and bboxes.size(1) == 4 70 | 71 | # set some optional attributes by a dummy value (False) 72 | attr = {k: torch.full((1,), fill_value=False) for k in ["filtered", "has_canvas_element", "NoiseAdded"]} 73 | 74 | data = Data(x=bboxes, y=labels, attr=attr) # can be used as an alternative for `dataset[target_index]` in demo.ipynb 75 | ``` 76 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | from collections import defaultdict 5 | from typing import Dict, List 6 | 7 | import numpy as np 8 | import torch 9 | from fsspec.core import url_to_fs 10 | from hydra.utils import instantiate 11 | from omegaconf import DictConfig 12 | from torch_geometric.loader import DataLoader 13 | from trainer.data.util import loader_to_list, sparse_to_dense 14 | from trainer.fid.model import load_fidnet_v3 15 | from trainer.global_configs import FID_WEIGHT_DIR 16 | from trainer.helpers.metric import ( 17 | Layout, 18 | compute_alignment, 19 | compute_average_iou, 20 | compute_docsim, 21 | compute_generative_model_scores, 22 | compute_maximum_iou, 23 | compute_overlap, 24 | ) 25 | from trainer.helpers.util import set_seed 26 | 27 | 28 | def preprocess(layouts: List[Layout], max_len: int, device: torch.device): 29 | layout = defaultdict(list) 30 | for (b, l) in layouts: 31 | pad_len = max_len - l.shape[0] 32 | bbox = torch.tensor( 33 | np.concatenate([b, np.full((pad_len, 4), 0.0)], axis=0), 34 | dtype=torch.float, 35 | ) 36 | layout["bbox"].append(bbox) 37 | label = torch.tensor( 38 | np.concatenate([l, np.full((pad_len,), 0)], axis=0), 39 | dtype=torch.long, 40 | ) 41 | layout["label"].append(label) 42 | mask = torch.tensor( 43 | [True for _ in range(l.shape[0])] + [False for _ in range(pad_len)] 44 | ) 45 | layout["mask"].append(mask) 46 | bbox = torch.stack(layout["bbox"], dim=0).to(device) 47 | label = torch.stack(layout["label"], dim=0).to(device) 48 | mask = torch.stack(layout["mask"], dim=0).to(device) 49 | padding_mask = ~mask 50 | return bbox, label, padding_mask, mask 51 | 52 | 53 | def print_scores(scores: Dict, test_cfg: argparse.Namespace, train_cfg: DictConfig): 54 | scores = {k: scores[k] for k in sorted(scores)} 55 | job_name = train_cfg.job_dir.split("/")[-1] 56 | model_name = train_cfg.model._target_.split(".")[-1] 57 | cond = test_cfg.cond 58 | 59 | if "num_timesteps" in test_cfg: 60 | step = test_cfg.num_timesteps 61 | else: 62 | step = train_cfg.sampling.get("num_timesteps", None) 63 | 64 | option = "" 65 | header = ["job_name", "model_name", "cond", "step", "option"] 66 | data = [job_name, model_name, cond, step, option] 67 | 68 | tex = "" 69 | for k, v in scores.items(): 70 | # if k == "Alignment" or k == "Overlap" or "Violation" in k: 71 | # v = [_v * 100 for _v in v] 72 | mean, std = np.mean(v), np.std(v) 73 | stdp = std * 100.0 / mean 74 | print(f"\t{k}: {mean:.4f} ({stdp:.4f}%)") 75 | tex += f"& {mean:.4f}\\std{{{stdp:.1f}}}\% " 76 | 77 | header.extend([f"{k}-mean", f"{k}-std"]) 78 
| data.extend([mean, std]) 79 | 80 | print(tex + "\\\\") 81 | 82 | print(",".join(header)) 83 | print(",".join([str(d) for d in data])) 84 | 85 | 86 | def main(): 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("result_dir", type=str, default="tmp/results") 89 | parser.add_argument( 90 | "--compute_real", 91 | action="store_true", 92 | help="compute some metric between validation and test subset", 93 | ) 94 | parser.add_argument( 95 | "--num_samples", 96 | type=int, 97 | default=1000, 98 | help="number of samples used for evaluating unconditional generation", 99 | ) 100 | parser.add_argument("--batch_size", type=int, default=512) 101 | args = parser.parse_args() 102 | set_seed(0) 103 | 104 | fs, _ = url_to_fs(args.result_dir) 105 | pkl_paths = [p for p in fs.ls(args.result_dir) if p.split(".")[-1] == "pkl"] 106 | with fs.open(pkl_paths[0], "rb") as file_obj: 107 | meta = pickle.load(file_obj) 108 | train_cfg, test_cfg = meta["train_cfg"], meta["test_cfg"] 109 | assert test_cfg.num_run == 1 110 | 111 | train_cfg.data.num_workers = os.cpu_count() 112 | 113 | kwargs = { 114 | "batch_size": args.batch_size, 115 | "num_workers": train_cfg.data.num_workers, 116 | "pin_memory": True, 117 | "shuffle": False, 118 | } 119 | 120 | if test_cfg.get("is_validation", False): 121 | split_main, split_sub = "val", "test" 122 | else: 123 | split_main, split_sub = "test", "val" 124 | 125 | main_dataset = instantiate(train_cfg.dataset)(split=split_main, transform=None) 126 | if test_cfg.get("debug_num_samples", -1) > 0: 127 | main_dataset = main_dataset[: test_cfg.debug_num_samples] 128 | main_dataloader = DataLoader(main_dataset, **kwargs) 129 | layouts_main = loader_to_list(main_dataloader) 130 | 131 | if args.compute_real: 132 | sub_dataset = instantiate(train_cfg.dataset)(split=split_sub, transform=None) 133 | if test_cfg.cond == "unconditional": 134 | sub_dataset = sub_dataset[: args.num_samples] 135 | sub_dataloader = DataLoader(sub_dataset, **kwargs) 136 | layouts_sub = loader_to_list(sub_dataloader) 137 | 138 | num_classes = len(main_dataset.labels) 139 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 140 | fid_model = load_fidnet_v3(main_dataset, FID_WEIGHT_DIR, device) 141 | 142 | scores_all = defaultdict(list) 143 | feats_1 = [] 144 | batch_metrics = defaultdict(float) 145 | for i, batch in enumerate(main_dataloader): 146 | bbox, label, padding_mask, mask = sparse_to_dense(batch, device) 147 | with torch.set_grad_enabled(False): 148 | feat = fid_model.extract_features(bbox, label, padding_mask) 149 | feats_1.append(feat.cpu()) 150 | # save_image(bbox, label, mask, main_dataset.colors, f"dummy.png") 151 | 152 | if args.compute_real: 153 | for k, v in compute_alignment(bbox.cpu(), mask.cpu()).items(): 154 | batch_metrics[k] += v.sum().item() 155 | for k, v in compute_overlap(bbox.cpu(), mask.cpu()).items(): 156 | batch_metrics[k] += v.sum().item() 157 | 158 | if args.compute_real: 159 | scores_real = defaultdict(list) 160 | for k, v in batch_metrics.items(): 161 | scores_real.update({k: v / len(main_dataset)}) 162 | 163 | # compute metrics between real val and test dataset 164 | if args.compute_real: 165 | feats_1_another = [] 166 | for batch in sub_dataloader: 167 | bbox, label, padding_mask, mask = sparse_to_dense(batch, device) 168 | with torch.set_grad_enabled(False): 169 | feat = fid_model.extract_features(bbox, label, padding_mask) 170 | feats_1_another.append(feat.cpu()) 171 | 172 | scores_real.update(compute_generative_model_scores(feats_1, 
feats_1_another)) 173 | scores_real.update(compute_average_iou(layouts_sub)) 174 | if test_cfg.cond != "unconditional": 175 | scores_real["maximum_iou"] = compute_maximum_iou(layouts_main, layouts_sub) 176 | scores_real["DocSim"] = compute_docsim(layouts_main, layouts_main) 177 | 178 | # regard as the result of single run 179 | scores_real = {k: [v] for (k, v) in scores_real.items()} 180 | print() 181 | print("\nReal data:") 182 | print_scores(scores_real, test_cfg, train_cfg) 183 | 184 | # compute scores for each run 185 | for pkl_path in pkl_paths: 186 | feats_2 = [] 187 | batch_metrics = defaultdict(float) 188 | 189 | with fs.open(pkl_path, "rb") as file_obj: 190 | x = pickle.load(file_obj) 191 | generated = x["results"] 192 | 193 | for i in range(0, len(generated), args.batch_size): 194 | i_end = min(i + args.batch_size, len(generated)) 195 | batch = generated[i:i_end] 196 | max_len = max(len(g[-1]) for g in batch) 197 | 198 | bbox, label, padding_mask, mask = preprocess(batch, max_len, device) 199 | with torch.set_grad_enabled(False): 200 | feat = fid_model.extract_features(bbox, label, padding_mask) 201 | feats_2.append(feat.cpu()) 202 | 203 | for k, v in compute_alignment(bbox, mask).items(): 204 | batch_metrics[k] += v.sum().item() 205 | for k, v in compute_overlap(bbox, mask).items(): 206 | batch_metrics[k] += v.sum().item() 207 | 208 | scores = {} 209 | for k, v in batch_metrics.items(): 210 | scores[k] = v / len(generated) 211 | scores.update(compute_average_iou(generated)) 212 | scores.update(compute_generative_model_scores(feats_1, feats_2)) 213 | if test_cfg.cond != "unconditional": 214 | scores["maximum_iou"] = compute_maximum_iou(layouts_main, generated) 215 | scores["DocSim"] = compute_docsim(layouts_main, generated) 216 | 217 | for k, v in scores.items(): 218 | scores_all[k].append(v) 219 | 220 | print_scores(scores_all, test_cfg, train_cfg) 221 | print() 222 | 223 | 224 | if __name__ == "__main__": 225 | main() 226 | -------------------------------------------------------------------------------- /notebooks/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import copy\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import math\n", 22 | "import os\n", 23 | "from omegaconf import OmegaConf\n", 24 | "\n", 25 | "import torch\n", 26 | "from torch_geometric.utils import to_dense_adj\n", 27 | "import torchvision.transforms as T\n", 28 | "from fsspec.core import url_to_fs\n", 29 | "from hydra.utils import instantiate\n", 30 | "from trainer.data.util import AddCanvasElement, AddRelationConstraints, sparse_to_dense\n", 31 | "from trainer.global_configs import DATASET_DIR, JOB_DIR\n", 32 | "from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer\n", 33 | "from trainer.helpers.sampling import SAMPLING_CONFIG_DICT\n", 34 | "from trainer.helpers.task import get_cond, filter_canvas\n", 35 | "from trainer.helpers.visualization import save_gif, save_image, save_label, save_label_with_size, save_relation\n", 36 | "from trainer.hydra_configs import TestConfig\n", 37 | "\n", 38 | "SIZE = (360, 240)\n", 39 | "\n", 40 | "# user tunable parameters\n", 41 | "# cond_type, W_CANVAS = \"relation\", True # uncomment 
this line if you want to try relation task\n", 42 | "cond_type, W_CANVAS = \"cwh\", False # choices: unconditional, c, cwh, partial, refinement\n", 43 | "n_samples = 4 # num. of samples to generate at once\n", 44 | "target_index = 0 # index of real data, partial fields in it are used for conditional generation\n", 45 | "\n", 46 | "job_dir = os.path.join(JOB_DIR, \"layoutdm_publaynet/0\")\n", 47 | "\n", 48 | "config_path = os.path.join(job_dir, \"config.yaml\")\n", 49 | "fs, _ = url_to_fs(config_path)\n", 50 | "if fs.exists(config_path):\n", 51 | " with fs.open(config_path, \"rb\") as file_obj:\n", 52 | " train_cfg = OmegaConf.load(file_obj)\n", 53 | "else:\n", 54 | " raise FileNotFoundError\n", 55 | "train_cfg.dataset.dir = DATASET_DIR\n", 56 | "\n", 57 | "test_cfg = OmegaConf.structured(TestConfig)\n", 58 | "test_cfg.cond = cond_type\n", 59 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 60 | "\n", 61 | "sampling_cfg = OmegaConf.structured(SAMPLING_CONFIG_DICT[test_cfg.sampling]) # NOTE: you may change sampling algorithm\n", 62 | "OmegaConf.set_struct(sampling_cfg, False)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# initialize data and model\n", 72 | "tokenizer = LayoutSequenceTokenizer(\n", 73 | " data_cfg=train_cfg.data, dataset_cfg=train_cfg.dataset\n", 74 | ")\n", 75 | "model = instantiate(train_cfg.model)(\n", 76 | " backbone_cfg=train_cfg.backbone, tokenizer=tokenizer\n", 77 | ").to(device)\n", 78 | "model_path = os.path.join(job_dir, \"best_model.pt\")\n", 79 | "with fs.open(model_path, \"rb\") as file_obj:\n", 80 | " model.load_state_dict(torch.load(file_obj))\n", 81 | "model = model.to(device)\n", 82 | "model.eval()\n", 83 | "sampling_cfg = model.aggregate_sampling_settings(sampling_cfg, test_cfg)\n", 84 | "\n", 85 | "if W_CANVAS:\n", 86 | " # add canvas and shift label id to load relation gts\n", 87 | " assert cond_type == \"relation\"\n", 88 | " transform = T.Compose([\n", 89 | " AddCanvasElement(),\n", 90 | " AddRelationConstraints(edge_ratio=0.1),\n", 91 | " ])\n", 92 | "else:\n", 93 | " assert cond_type != \"relation\"\n", 94 | " transform = None\n", 95 | "dataset = instantiate(train_cfg.dataset)(split=\"test\", transform=transform)\n", 96 | "save_kwargs = {\n", 97 | " \"colors\": dataset.colors, \"names\": dataset.labels,\n", 98 | " \"canvas_size\": SIZE, \"use_grid\": True,\n", 99 | " # \"draw_label\": True,\n", 100 | "}\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "### Real data visualization" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# load target data and visualize GT\n", 117 | "bbox, label, _, mask = sparse_to_dense(dataset[target_index])\n", 118 | "gt_cond = model.tokenizer.encode(\n", 119 | " {\"label\": label, \"mask\": mask, \"bbox\": bbox}\n", 120 | ")\n", 121 | "if \"bos\" in tokenizer.special_tokens:\n", 122 | " gt = model.tokenizer.decode(gt_cond[\"seq\"][:, 1:])\n", 123 | "else:\n", 124 | " gt = model.tokenizer.decode(gt_cond[\"seq\"])\n", 125 | "if W_CANVAS:\n", 126 | " gt = filter_canvas(gt) # remove canvas attributes before visualization\n", 127 | "plt.axis(\"off\")\n", 128 | "plt.imshow(save_image(gt[\"bbox\"], gt[\"label\"], gt[\"mask\"], **save_kwargs))\n" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### 
Unconditional Generation" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "assert cond_type == \"unconditional\"\n", 145 | "pred = model.sample(batch_size=n_samples, cond=None, sampling_cfg=sampling_cfg)\n", 146 | "plt.axis(\"off\")\n", 147 | "plt.imshow(save_image(pred[\"bbox\"], pred[\"label\"], pred[\"mask\"], **save_kwargs))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Conditional Generation" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "#### Prediction" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "cond = get_cond(\n", 171 | " batch=dataset[target_index],\n", 172 | " tokenizer=model.tokenizer,\n", 173 | " cond_type=cond_type,\n", 174 | " model_type=type(model).__name__,\n", 175 | ")\n", 176 | "pred = model.sample(batch_size=n_samples, cond=cond, sampling_cfg=sampling_cfg)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "#### Visualization of conditional inputs" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "plt.axis(\"off\")\n", 193 | "input_ = model.tokenizer.decode(cond[\"seq\"].cpu())\n", 194 | "mask = pred[\"mask\"][0]\n", 195 | "label, bbox = pred[\"label\"][0][mask], pred[\"bbox\"][0][mask]\n", 196 | "if cond_type == \"c\":\n", 197 | " plt.imshow(save_label(label, **save_kwargs))\n", 198 | "elif cond_type == \"cwh\":\n", 199 | " plt.imshow(save_label_with_size(label, bbox, **save_kwargs))\n", 200 | "elif cond_type == \"relation\":\n", 201 | " data = cond[\"batch_w_canvas\"]\n", 202 | " edge_attr = to_dense_adj(data.edge_index, data.batch, data.edge_attr)\n", 203 | " plt.imshow(save_relation(label_with_canvas=data.y.cpu(), edge_attr=edge_attr.cpu()[0], **save_kwargs))\n", 204 | "elif cond_type == \"partial\":\n", 205 | " plt.imshow(save_image(input_[\"bbox\"], input_[\"label\"], input_[\"mask\"], **save_kwargs))\n", 206 | "elif cond_type == \"refinement\":\n", 207 | " noisy_input = model.tokenizer.decode(cond[\"seq_orig\"].cpu())\n", 208 | " plt.imshow(save_image(noisy_input[\"bbox\"][0:1], noisy_input[\"label\"][0:1], noisy_input[\"mask\"][0:1], **save_kwargs))\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "#### Visualization of outputs" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "fig, ax = plt.subplots(figsize=(15, 5))\n", 225 | "ax.set_axis_off()\n", 226 | "ax.imshow(save_image(pred[\"bbox\"], pred[\"label\"], pred[\"mask\"], **save_kwargs, nrow=int(math.sqrt(n_samples) * 2)))" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "#### Make GIF for Unconditional Generation" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "new_save_kwargs = copy.deepcopy(save_kwargs)\n", 243 | "new_save_kwargs.pop(\"use_grid\")\n", 244 | "ids_list = model.model.sample(\n", 245 | " batch_size=4,\n", 246 | " sampling_cfg=sampling_cfg,\n", 247 | " get_intermediate_results=True,\n", 248 | ")\n", 249 | "images = 
[]\n", 250 | "for ids in ids_list:\n", 251 | " layouts = model.tokenizer.decode(ids)\n", 252 | " image = save_image(\n", 253 | " layouts[\"bbox\"],\n", 254 | " layouts[\"label\"],\n", 255 | " layouts[\"mask\"],\n", 256 | " **new_save_kwargs\n", 257 | " )\n", 258 | " images.append(image)\n", 259 | "N_step = len(images)\n", 260 | "images = images[int(0.5*N_step):]\n", 261 | "save_gif(images, \"../tmp/animation/{}.gif\")\n" 262 | ] 263 | }, 264 | { 265 | "attachments": {}, 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "#### Dump colors of all labels" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "labels = []\n", 279 | "for i, name in enumerate(save_kwargs[\"names\"]):\n", 280 | " if \"_cutout\" in name:\n", 281 | " continue\n", 282 | " else:\n", 283 | " labels.append(i)\n", 284 | "plt.axis(\"off\")\n", 285 | "plt.imshow(save_label(labels, **save_kwargs))" 286 | ] 287 | } 288 | ], 289 | "metadata": { 290 | "kernelspec": { 291 | "display_name": "Python 3.7.12 ('trainer-GSN1huuF-py3.7')", 292 | "language": "python", 293 | "name": "python3" 294 | }, 295 | "language_info": { 296 | "codemirror_mode": { 297 | "name": "ipython", 298 | "version": 3 299 | }, 300 | "file_extension": ".py", 301 | "mimetype": "text/x-python", 302 | "name": "python", 303 | "nbconvert_exporter": "python", 304 | "pygments_lexer": "ipython3", 305 | "version": "3.7.12" 306 | }, 307 | "orig_nbformat": 4, 308 | "vscode": { 309 | "interpreter": { 310 | "hash": "df1bfdd73842a1319c146cb5c112d3818824eaeb9e2048caf661736e74000887" 311 | } 312 | } 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 2 316 | } 317 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "trainer" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Dummy User "] 6 | packages = [ 7 | { include = "trainer", from = "src/trainer", format="sdist" }, 8 | ] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.7.1,<3.8" 12 | Pillow = "^9.1.0" 13 | google-cloud-storage = "^2.2.1" 14 | torch = [ 15 | {url = "https://download.pytorch.org/whl/cu113/torch-1.10.2%2Bcu113-cp37-cp37m-linux_x86_64.whl", platform = "linux"} 16 | ] 17 | tensorflow = "^2.8.0" 18 | hydra-core = "^1.1.2" 19 | einops = "^0.4.1" 20 | tqdm = "^4.64.0" 21 | torchvision = {url = "https://download.pytorch.org/whl/cu113/torchvision-0.11.3%2Bcu113-cp37-cp37m-linux_x86_64.whl", platform = "linux"} 22 | torch-scatter = {url = "https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl", platform = "linux"} 23 | torch-sparse = {url = "https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_sparse-0.6.13-cp37-cp37m-linux_x86_64.whl", platform = "linux"} 24 | pandas = "1.3.5" 25 | torch-geometric = "^2.0.4" 26 | seaborn = "^0.11.2" 27 | ipython = "<7.32.0" 28 | ipykernel = "^6.13.0" 29 | torch-tb-profiler = "^0.4.0" 30 | setuptools = "59.5.0" 31 | pytorch-fid = "^0.2.1" 32 | Cython = "^0.29.30" 33 | pycocotools = "^2.0.4" 34 | prdc = "^0.2" 35 | scikit-learn = "1.0.2" 36 | numpy = "1.21.6" 37 | matplotlib = "3.5.2" 38 | fsspec = "2023.1.0" 39 | gcsfs = "2023.1.0" 40 | scipy = ">=1.4,<1.10" 41 | 42 | [tool.poetry.dev-dependencies] 43 | pytest = "^7.0" 44 | pysen = {extras = ["lint"], version = "^0.10.1"} 45 | 46 | [build-system] 47 | requires = ["poetry-core>=1.0.0"] 48 | 
build-backend = "poetry.core.masonry.api" 49 | 50 | [tool.pysen] 51 | version = "0.10" 52 | 53 | [tool.pysen.lint] 54 | enable_black = true 55 | enable_flake8 = true 56 | enable_isort = true 57 | enable_mypy = true 58 | mypy_preset = "strict" 59 | line_length = 88 60 | py_version = "py37" 61 | 62 | [[tool.pysen.lint.mypy_targets]] 63 | paths = [".", "src/preprocess/preprocess", "src/trainer/trainer"] -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # Render generation results and conditional inputs if available 2 | import argparse 3 | import csv 4 | import os 5 | import pickle 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 9 | import torch 10 | from einops import rearrange 11 | from fsspec.core import url_to_fs 12 | from hydra.utils import instantiate 13 | from torch_geometric.loader import DataLoader 14 | from torch_geometric.utils import to_dense_adj 15 | from tqdm import tqdm 16 | from trainer.data.util import loader_to_list, sparse_to_dense 17 | from trainer.helpers.metric import compute_alignment, compute_docsim, compute_overlap 18 | from trainer.helpers.util import set_seed 19 | from trainer.helpers.visualization import ( 20 | save_image, 21 | save_label, 22 | save_label_with_size, 23 | save_relation, 24 | ) 25 | 26 | CANVAS_SIZE = (120, 80) 27 | 28 | 29 | def _repeat(inputs, n: int): 30 | outputs = [] 31 | for x in inputs: 32 | for i in range(n): 33 | outputs.append(x) 34 | return outputs 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("result_dir", type=str) 40 | parser.add_argument("--output_dir", type=str, default="tmp/visualization") 41 | parser.add_argument("--eval_batch_size", type=int, default=512) 42 | parser.add_argument("--dump_num_samples", type=int, default=100) 43 | parser.add_argument( 44 | "--debug", 45 | action="store_true", 46 | help="disable parallel computation for debugging", 47 | ) 48 | args = parser.parse_args() 49 | set_seed(0) 50 | 51 | fs, _ = url_to_fs(args.result_dir) 52 | 53 | pkl_paths = [p for p in fs.ls(args.result_dir) if p.split(".")[-1] == "pkl"] 54 | pkl_paths = pkl_paths[:1] 55 | with fs.open(pkl_paths[0], "rb") as file_obj: 56 | meta = pickle.load(file_obj) 57 | train_cfg = meta["train_cfg"] 58 | test_cfg = meta["test_cfg"] 59 | 60 | train_cfg.data.num_workers = os.cpu_count() 61 | batch_size = args.eval_batch_size # Note: arbitrary number is OK unless OOM 62 | 63 | kwargs = { 64 | "batch_size": batch_size, 65 | "num_workers": train_cfg.data.num_workers, 66 | "pin_memory": True, 67 | "shuffle": False, 68 | } 69 | 70 | split_main = "val" if test_cfg.get("is_validation", False) else "test" 71 | main_dataset = instantiate(train_cfg.dataset)(split=split_main, transform=None) 72 | if test_cfg.get("debug_num_samples", -1) > 0: 73 | num_samples = test_cfg.get("debug_num_samples") 74 | else: 75 | num_samples = args.dump_num_samples 76 | main_dataset = main_dataset[:num_samples] 77 | main_dataloader = DataLoader(main_dataset, **kwargs) 78 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 79 | 80 | dataset_name, ckpt_name, cond_name = args.result_dir.split("/")[-3:] 81 | pred_dir = Path(args.output_dir) / dataset_name / ckpt_name / cond_name 82 | pred_dir.mkdir(parents=True, exist_ok=True) 83 | 84 | gt_dir = Path(args.output_dir) / dataset_name / "gt" 85 | gt_dir.mkdir(parents=True, exist_ok=True) 86 | 87 | names_all = [] 88 | for i, 
batch in enumerate(main_dataloader): 89 | bbox, label, padding_mask, mask = sparse_to_dense(batch, device) 90 | names = [Path(n).with_suffix("").name for n in batch.attr["name"]] 91 | names_all.extend(names) 92 | for j, name in enumerate(names): 93 | out_path = gt_dir / f"{name}.png" 94 | if out_path.exists(): 95 | continue 96 | save_image( 97 | bbox[j : j + 1], 98 | label[j : j + 1], 99 | mask[j : j + 1], 100 | main_dataset.colors, 101 | out_path, 102 | canvas_size=CANVAS_SIZE, 103 | ) 104 | layouts_main = _repeat(loader_to_list(main_dataloader), test_cfg.num_run) 105 | names_all = _repeat(names_all, test_cfg.num_run) 106 | 107 | for t, pkl_path in enumerate(pkl_paths): 108 | headers = ["name", "alignment-LayoutGAN++", "overlap-LayoutGAN++", "docsim"] 109 | if test_cfg.cond in [ 110 | "relation", 111 | ]: 112 | headers.append("violation") 113 | scores = [] 114 | 115 | with fs.open(pkl_path, "rb") as file_obj: 116 | x = pickle.load(file_obj) 117 | 118 | N = num_samples * test_cfg.num_run 119 | generated = x["results"][:N] 120 | if "inputs" in x: 121 | inputs = x["inputs"][:N] 122 | if "relations" in x: 123 | relations = x["relations"][:N] 124 | relation_scores = x["relation_scores"][:N] 125 | 126 | for i, (b, l) in enumerate(tqdm(generated)): 127 | if test_cfg.cond in [ 128 | "relation", 129 | ]: 130 | data = relations[i] 131 | id_ = data.attr["name"][0].split("/")[-1].split(".")[0] 132 | else: 133 | id_ = names_all[i] 134 | 135 | b, l = torch.from_numpy(b), torch.from_numpy(l) 136 | batched_b = rearrange(b, "s x -> 1 s x") 137 | batched_l = rearrange(l, "s -> 1 s") 138 | batched_m = torch.full(batched_l.size(), True) 139 | alignment = compute_alignment(batched_b, batched_m) 140 | overlap = compute_overlap(batched_b, batched_m) 141 | docsim = compute_docsim( 142 | [(b, l)], 143 | [list(torch.from_numpy(l) for l in layouts_main[i])], 144 | ) 145 | 146 | score = [ 147 | id_, 148 | alignment["alignment-LayoutGAN++"].item(), 149 | overlap["overlap-LayoutGAN++"].item(), 150 | docsim, 151 | ] 152 | 153 | is_first = i % test_cfg.num_run == 0 154 | 155 | pred_image_path = pred_dir / f"{id_}_{i % test_cfg.num_run}.png" 156 | if not pred_image_path.exists(): 157 | save_image( 158 | batched_b, 159 | batched_l, 160 | batched_m, 161 | main_dataset.colors, 162 | pred_image_path, 163 | canvas_size=CANVAS_SIZE, 164 | ) 165 | 166 | if test_cfg.cond in ["c", "cwh", "relation"] and is_first: 167 | pred_label_path = pred_dir / f"{id_}_label.png" 168 | if not pred_label_path.exists(): 169 | save_label( 170 | l, 171 | main_dataset.labels, 172 | main_dataset.colors, 173 | pred_label_path, 174 | canvas_size=tuple([x * 3 for x in CANVAS_SIZE]), 175 | ) 176 | if test_cfg.cond == "cwh" and is_first: 177 | pred_label_w_size_path = pred_dir / f"{id_}_label_w_size.png" 178 | if not pred_label_w_size_path.exists(): 179 | save_label_with_size( 180 | l, 181 | main_dataset.labels, 182 | main_dataset.colors, 183 | pred_label_w_size_path, 184 | canvas_size=tuple([x * 3 for x in CANVAS_SIZE]), 185 | ) 186 | 187 | if test_cfg.cond in ["partial", "refinement"] and is_first: 188 | input_b, input_l = inputs[i] 189 | batched_b = rearrange(torch.from_numpy(input_b), "s x -> 1 s x") 190 | batched_l = rearrange(torch.from_numpy(input_l), "s -> 1 s") 191 | batched_m = torch.full(batched_l.size(), True) 192 | input_path = pred_dir / f"{id_}_input.png" 193 | if not input_path.exists(): 194 | save_image( 195 | batched_b, 196 | batched_l, 197 | batched_m, 198 | main_dataset.colors, 199 | input_path, 200 | canvas_size=CANVAS_SIZE, 201 | ) 
202 | 203 | if test_cfg.cond == "relation": 204 | if len(data.edge_index) == 0: 205 | continue 206 | if is_first: 207 | relation_path = pred_dir / f"{id_}_relation.png" 208 | edge_attr = to_dense_adj( 209 | data.edge_index, data.batch, data.edge_attr 210 | ) 211 | save_relation( 212 | data.y.cpu().numpy(), 213 | edge_attr.cpu()[0], 214 | main_dataset.labels, 215 | main_dataset.colors, 216 | relation_path, 217 | ) 218 | score.append(relation_scores[i]) 219 | 220 | scores.append(score) 221 | 222 | with (pred_dir / "stats.csv").open("w") as f: 223 | writer = csv.writer(f) 224 | writer.writerow(headers) 225 | writer.writerows(scores) 226 | exit() 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /src/trainer/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/config/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/config/backbone/medium.yaml: -------------------------------------------------------------------------------- 1 | # common setup for layout generation 2 | # used in VTN (CVPR'21) and BLT (ECCV'22) 3 | _target_: trainer.models.transformer_utils.TransformerEncoder 4 | encoder_layer: 5 | _target_: trainer.models.transformer_utils.Block 6 | d_model: 512 7 | nhead: 8 8 | dim_feedforward: 2048 9 | dropout: 0.1 10 | batch_first: true 11 | norm_first: true 12 | num_layers: 4 -------------------------------------------------------------------------------- /src/trainer/trainer/config/data/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_data_default 3 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/dataset/publaynet.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.datasets.publaynet.PubLayNetDataset 2 | _partial_: true 3 | dir: ??? 4 | max_seq_length: 25 -------------------------------------------------------------------------------- /src/trainer/trainer/config/dataset/rico25.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.datasets.rico.Rico25Dataset 2 | _partial_: true 3 | dir: ??? 
4 | max_seq_length: 25 -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/bart.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: bart 4 | 5 | data: 6 | pad_until_max: true 7 | special_tokens: [pad, bos, eos, mask] 8 | var_order: c-w-h-x-y 9 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/blt_eccv2022.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: blt 4 | 5 | # Note: authors performed HP search (details in supplementary), but not sure on other hyper parameters 6 | # lr: [1e-3, 3e-3, 5e-3] 7 | # dropout: [0.1, 0.3] 8 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/blt_eccv2022_ordered.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: blt 4 | 5 | data: 6 | transforms: ["SortByLabel", "LexicographicOrder"] 7 | 8 | # Note: authors performed HP search (details in supplementary), but not sure on other hyper parameters 9 | # lr: [1e-3, 3e-3, 5e-3] 10 | # dropout: [0.1, 0.3] 11 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/diffusionlm_neurips2022.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: layout_continuous_diffusion 4 | - override /scheduler: reduce_lr_on_plateau 5 | 6 | data: 7 | pad_until_max: true 8 | shared_bbox_vocab: x-y-w-h 9 | optimizer: 10 | lr: 5.0e-4 11 | backbone: 12 | encoder_layer: 13 | timestep_type: adalayernorm_mlp 14 | diffusion_step: 100 15 | dropout: 0.0 16 | model: 17 | model_type: diffusion_lm 18 | learnable_token_emb: true 19 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/layout_transformer_iccv2021.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: elem_wise_autoreg 4 | 5 | data: 6 | special_tokens: [pad, bos, eos] 7 | var_order: c-w-h-x-y -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/layout_transformer_iccv2021_ordered.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: elem_wise_autoreg 4 | 5 | data: 6 | special_tokens: [pad, bos, eos] 7 | var_order: c-w-h-x-y 8 | transforms: ["SortByLabel", "LexicographicOrder"] 9 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/layoutdm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: layoutdm 4 | - override /scheduler: reduce_lr_on_plateau 5 | 6 | data: 7 | pad_until_max: true 8 | shared_bbox_vocab: x-y-w-h 9 | bbox_quantization: kmeans 10 | optimizer: 11 | lr: 5.0e-4 12 | backbone: 13 | encoder_layer: 14 | timestep_type: adalayernorm 15 | diffusion_step: 100 16 | dropout: 0.0 17 | model: 18 | q_type: constrained 19 | 
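Note on how these configs are consumed: the dataset/model YAMLs use `_partial_: true`, so `hydra.utils.instantiate` returns a callable rather than a constructed object, while experiment YAMLs such as `layoutdm.yaml` above only select config groups and override a few hyper-parameters. A minimal sketch of this pattern, mirroring the usage in `bin/clustering_coordinates.py` and `notebooks/demo.ipynb`; `train_cfg` is assumed to be an already-composed config, `tokenizer` an existing `LayoutSequenceTokenizer`, and the optimizer line is an illustrative assumption rather than a copy of `main.py`:

```python
from hydra.utils import instantiate

# `_partial_: true` configs are instantiated into callables first, then called
# with the remaining runtime-only arguments.
dataset = instantiate(train_cfg.dataset)(split="train", transform=None)
model = instantiate(train_cfg.model)(
    backbone_cfg=train_cfg.backbone, tokenizer=tokenizer
)
optimizer = instantiate(train_cfg.optimizer)(model.parameters())  # AdamW partial (assumed)
```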
-------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/maskgit_cvpr2022.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: maskgit 4 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/maskgit_cvpr2022_ordered.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: maskgit 4 | 5 | data: 6 | transforms: ["SortByLabel", "LexicographicOrder"] 7 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/ruite.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: ruite 4 | 5 | data: 6 | special_tokens: [pad, ] 7 | transforms: ["RandomOrder", "AddNoiseToBBox(std=0.1)"] 8 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/experiment/vqdiffusion.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: layoutdm 4 | - override /scheduler: reduce_lr_on_plateau 5 | 6 | data: 7 | pad_until_max: true 8 | shared_bbox_vocab: x-y-w-h 9 | bbox_quantization: linear 10 | optimizer: 11 | lr: 5.0e-4 12 | backbone: 13 | encoder_layer: 14 | timestep_type: adalayernorm 15 | diffusion_step: 100 16 | dropout: 0.0 17 | model: 18 | q_type: vanilla 19 | pos_emb: default 20 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - backbone: medium 3 | - dataset: rico25 4 | - data: default 5 | - model: elem_wise_autoreg 6 | - optimizer: adamw 7 | - sampling: random 8 | - scheduler: void 9 | - training: default 10 | # https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order/ 11 | - _self_ 12 | 13 | job_dir: ??? 14 | fid_weight_dir: ??? 15 | seed: ??? 16 | device: cuda 17 | debug: false 18 | hydra: 19 | run: 20 | dir: . 
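In `main.yaml` above, `job_dir`, `fid_weight_dir`, and `seed` are marked `???` (mandatory but missing); they must be supplied as `key=value` overrides on the command line, which is what `bin/train.sh` does. Below is a small self-contained illustration of that behavior using OmegaConf alone (hypothetical values, not part of the repository):

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

# "???" marks a mandatory value that has not been provided yet.
cfg = OmegaConf.create({"job_dir": "???", "seed": "???", "device": "cuda"})
try:
    _ = cfg.job_dir  # accessing a missing mandatory value raises an error
except MissingMandatoryValue:
    pass

# Hydra fills such values from `key=value` command-line overrides;
# the same merge can be reproduced with a dotlist.
cfg.merge_with_dotlist(["job_dir=tmp/jobs/rico25/layoutdm_demo", "seed=0"])
assert cfg.seed == 0 and cfg.device == "cuda"
```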
-------------------------------------------------------------------------------- /src/trainer/trainer/config/model/bart.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.bart.BART 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/blt.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.blt.BLT 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/elem_wise_autoreg.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.elem_wise_autoreg.ElemWiseAutoreg 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/layout_continuous_diffusion.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.layout_continuous_diffusion.LayoutContinuousDiffusion 2 | _partial_: true 3 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/layoutdm.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.layoutdm.LayoutDM 2 | _partial_: true 3 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/maskgit.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.maskgit.MaskGIT 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/model/ruite.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.models.ruite.RUITE 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | _partial_: true 3 | lr: 1e-4 4 | betas: [0.9, 0.98] 5 | -------------------------------------------------------------------------------- /src/trainer/trainer/config/scheduler/inverse_sqrt_decay_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.helpers.scheduler.D3PMScheduler 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/scheduler/reduce_lr_on_plateau.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 2 | _partial_: true 3 | mode: min 4 | factor: 0.5 5 | patience: 2 6 | threshold: 1e-2 -------------------------------------------------------------------------------- /src/trainer/trainer/config/scheduler/reduce_lr_on_plateau_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: trainer.helpers.scheduler.ReduceLROnPlateauWithWarmup 2 | _partial_: true 3 | mode: min 4 | factor: 0.5 5 | patience: 2 6 | threshold: 1e-2 7 | warmup_lr: 5.0e-4 8 | warmup: 5 -------------------------------------------------------------------------------- /src/trainer/trainer/config/scheduler/void.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: trainer.helpers.scheduler.VoidScheduler 2 | _partial_: true -------------------------------------------------------------------------------- /src/trainer/trainer/config/training/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_training_default 3 | -------------------------------------------------------------------------------- /src/trainer/trainer/crossplatform_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def filter_args_for_ai_platform(): 8 | """ 9 | This is to filter out "--job-dir " which is passed from AI Platform training command, 10 | """ 11 | key = "--job_dir" 12 | if key in sys.argv: 13 | logger.warning(f"{key} removed") 14 | arguments = sys.argv 15 | ind = arguments.index(key) 16 | sys.argv = [a for (i, a) in enumerate(arguments) if i not in [ind, ind + 1]] 17 | 18 | key = "--job-dir" 19 | for i, arg in enumerate(sys.argv): 20 | if len(arg) >= len(key) and arg[: len(key)] == key: 21 | sys.argv = [a for (j, a) in enumerate(sys.argv) if i != j] 22 | -------------------------------------------------------------------------------- /src/trainer/trainer/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/data/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/data/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from enum import IntEnum 3 | from itertools import combinations, product 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as T 9 | from torch import BoolTensor, FloatTensor, LongTensor 10 | from torch_geometric.utils import to_dense_batch 11 | from trainer.helpers.util import convert_xywh_to_ltrb 12 | 13 | 14 | class RelSize(IntEnum): 15 | UNKNOWN = 0 16 | SMALLER = 1 17 | EQUAL = 2 18 | LARGER = 3 19 | 20 | 21 | class RelLoc(IntEnum): 22 | UNKNOWN = 4 23 | LEFT = 5 24 | TOP = 6 25 | RIGHT = 7 26 | BOTTOM = 8 27 | CENTER = 9 28 | 29 | 30 | REL_SIZE_ALPHA = 0.1 31 | 32 | 33 | def detect_size_relation(b1, b2): 34 | a1 = b1[2] * b1[3] 35 | a2 = b2[2] * b2[3] 36 | alpha = REL_SIZE_ALPHA 37 | if (1 - alpha) * a1 < a2 < (1 + alpha) * a1: 38 | return RelSize.EQUAL 39 | elif a1 < a2: 40 | return RelSize.LARGER 41 | else: 42 | return RelSize.SMALLER 43 | 44 | 45 | def detect_loc_relation(b1, b2, is_canvas=False): 46 | if is_canvas: 47 | yc = b2[1] 48 | if yc < 1.0 / 3: 49 | return RelLoc.TOP 50 | elif yc < 2.0 / 3: 51 | return RelLoc.CENTER 52 | else: 53 | return RelLoc.BOTTOM 54 | 55 | else: 56 | l1, t1, r1, b1 = convert_xywh_to_ltrb(b1) 57 | l2, t2, r2, b2 = convert_xywh_to_ltrb(b2) 58 | 59 | if b2 <= t1: 60 | return RelLoc.TOP 61 | elif b1 <= t2: 62 | return RelLoc.BOTTOM 63 | elif r2 <= l1: 64 | return RelLoc.LEFT 65 | elif r1 <= l2: 66 | return RelLoc.RIGHT 67 | else: 68 | # might not be necessary 69 | return RelLoc.CENTER 70 | 71 | 72 | def get_rel_text(rel, canvas=False): 73 | if type(rel) == RelSize: 74 | index = rel - RelSize.UNKNOWN - 1 75 | if canvas: 76 | return [ 77 | "within canvas", 78 | "spread over canvas", 79 | "out of 
canvas", 80 | ][index] 81 | 82 | else: 83 | return [ 84 | "larger than", 85 | "equal to", 86 | "smaller than", 87 | ][index] 88 | 89 | else: 90 | index = rel - RelLoc.UNKNOWN - 1 91 | if canvas: 92 | return [ 93 | "", 94 | "at top", 95 | "", 96 | "at bottom", 97 | "at middle", 98 | ][index] 99 | 100 | else: 101 | return [ 102 | "right to", 103 | "below", 104 | "left to", 105 | "above", 106 | "around", 107 | ][index] 108 | 109 | 110 | # transform 111 | class AddCanvasElement: 112 | x = torch.tensor([[0.5, 0.5, 1.0, 1.0]], dtype=torch.float) 113 | y = torch.tensor([0], dtype=torch.long) 114 | 115 | def __call__(self, data): 116 | flag = data.attr["has_canvas_element"].any().item() 117 | assert not flag 118 | if not flag: 119 | # device = data.x.device 120 | # x, y = self.x.to(device), self.y.to(device) 121 | data.x = torch.cat([self.x, data.x], dim=0) 122 | data.y = torch.cat([self.y, data.y + 1], dim=0) 123 | data.attr = data.attr.copy() 124 | data.attr["has_canvas_element"] = True 125 | return data 126 | 127 | 128 | class AddRelationConstraints: 129 | def __init__(self, seed=None, edge_ratio=0.1, use_v1=False): 130 | self.edge_ratio = edge_ratio 131 | self.use_v1 = use_v1 132 | self.generator = random.Random() 133 | if seed is not None: 134 | self.generator.seed(seed) 135 | 136 | def __call__(self, data): 137 | N = data.x.size(0) 138 | has_canvas = data.attr["has_canvas_element"] 139 | 140 | rel_all = list(product(range(2), combinations(range(N), 2))) 141 | size = int(len(rel_all) * self.edge_ratio) 142 | rel_sample = set(self.generator.sample(rel_all, size)) 143 | 144 | edge_index, edge_attr = [], [] 145 | rel_unk = 1 << RelSize.UNKNOWN | 1 << RelLoc.UNKNOWN 146 | for i, j in combinations(range(N), 2): 147 | bi, bj = data.x[i], data.x[j] 148 | canvas = data.y[i] == 0 and has_canvas 149 | 150 | if self.use_v1: 151 | if (0, (i, j)) in rel_sample: 152 | rel_size = 1 << detect_size_relation(bi, bj) 153 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas) 154 | else: 155 | rel_size = 1 << RelSize.UNKNOWN 156 | rel_loc = 1 << RelLoc.UNKNOWN 157 | else: 158 | if (0, (i, j)) in rel_sample: 159 | rel_size = 1 << detect_size_relation(bi, bj) 160 | else: 161 | rel_size = 1 << RelSize.UNKNOWN 162 | 163 | if (1, (i, j)) in rel_sample: 164 | rel_loc = 1 << detect_loc_relation(bi, bj, canvas) 165 | else: 166 | rel_loc = 1 << RelLoc.UNKNOWN 167 | 168 | rel = rel_size | rel_loc 169 | if rel != rel_unk: 170 | edge_index.append((i, j)) 171 | edge_attr.append(rel) 172 | 173 | data.edge_index = torch.as_tensor(edge_index).long() 174 | data.edge_index = data.edge_index.t().contiguous() 175 | data.edge_attr = torch.as_tensor(edge_attr).long() 176 | 177 | return data 178 | 179 | 180 | class RandomOrder: 181 | def __call__(self, data): 182 | assert not data.attr["has_canvas_element"] 183 | device = data.x.device 184 | N = data.x.size(0) 185 | idx = torch.randperm(N, device=device) 186 | data.x, data.y = data.x[idx], data.y[idx] 187 | return data 188 | 189 | 190 | class SortByLabel: 191 | def __call__(self, data): 192 | assert not data.attr["has_canvas_element"] 193 | idx = data.y.sort().indices 194 | data.x, data.y = data.x[idx], data.y[idx] 195 | return data 196 | 197 | 198 | class LexicographicOrder: 199 | def __call__(self, data): 200 | assert not data.attr["has_canvas_element"] 201 | x, y, _, _ = convert_xywh_to_ltrb(data.x.t()) 202 | _zip = zip(*sorted(enumerate(zip(y, x)), key=lambda c: c[1:])) 203 | idx = list(list(_zip)[0]) 204 | data.x_orig, data.y_orig = data.x, data.y 205 | data.x, data.y = 
data.x[idx], data.y[idx] 206 | return data 207 | 208 | 209 | class AddNoiseToBBox: 210 | def __init__(self, std: float = 0.05): 211 | self.std = float(std) 212 | 213 | def __call__(self, data): 214 | noise = torch.normal(0, self.std, size=data.x.size(), device=data.x.device) 215 | data.x_orig = data.x.clone() 216 | data.x = data.x + noise 217 | data.attr = data.attr.copy() 218 | data.attr["NoiseAdded"][0] = True 219 | return data 220 | 221 | 222 | class HorizontalFlip: 223 | def __call__(self, data): 224 | data.x = data.x.clone() 225 | data.x[:, 0] = 1 - data.x[:, 0] 226 | return data 227 | 228 | 229 | # def compose_transform(transforms): 230 | # module = sys.modules[__name__] 231 | # transform_list = [] 232 | # for t in transforms: 233 | # # parse args 234 | # if "(" in t and ")" in t: 235 | # args = t[t.index("(") + 1 : t.index(")")] 236 | # t = t[: t.index("(")] 237 | # regex = re.compile(r"\b(\w+)=(.*?)(?=\s\w+=\s*|$)") 238 | # args = dict(regex.findall(args)) 239 | # for k in args: 240 | # try: 241 | # args[k] = float(args[k]) 242 | # except: 243 | # pass 244 | # else: 245 | # args = {} 246 | # if isinstance(t, str): 247 | # if hasattr(module, t): 248 | # transform_list.append(getattr(module, t)(**args)) 249 | # else: 250 | # raise NotImplementedError 251 | # else: 252 | # raise NotImplementedError 253 | # return T.Compose(transform_list) 254 | 255 | 256 | def compose_transform(transforms: List[str]) -> T.Compose: 257 | """ 258 | Compose transforms, optionally with args (e.g., AddRelationConstraints(edge_ratio=0.1)) 259 | """ 260 | transform_list = [] 261 | for t in transforms: 262 | if "(" in t and ")" in t: 263 | pass 264 | else: 265 | t += "()" 266 | transform_list.append(eval(t)) 267 | return T.Compose(transform_list) 268 | 269 | 270 | def sparse_to_dense( 271 | batch, 272 | device: torch.device = torch.device("cpu"), 273 | remove_canvas: bool = False, 274 | ) -> Tuple[FloatTensor, LongTensor, BoolTensor, BoolTensor]: 275 | batch = batch.to(device) 276 | bbox, _ = to_dense_batch(batch.x, batch.batch) 277 | label, mask = to_dense_batch(batch.y, batch.batch) 278 | 279 | if remove_canvas: 280 | bbox = bbox[:, 1:].contiguous() 281 | label = label[:, 1:].contiguous() - 1 # cancel +1 effect in transform 282 | label = label.clamp(min=0) 283 | mask = mask[:, 1:].contiguous() 284 | 285 | padding_mask = ~mask 286 | return bbox, label, padding_mask, mask 287 | 288 | 289 | def loader_to_list( 290 | loader: torch.utils.data.dataloader.DataLoader, 291 | ) -> List[Tuple[np.ndarray, np.ndarray]]: 292 | layouts = [] 293 | for batch in loader: 294 | bbox, label, _, mask = sparse_to_dense(batch) 295 | for i in range(len(label)): 296 | valid = mask[i].numpy() 297 | layouts.append((bbox[i].numpy()[valid], label[i].numpy()[valid])) 298 | return layouts 299 | 300 | 301 | def split_num_samples(N: int, batch_size: int) -> List[int]: 302 | quontinent = N // batch_size 303 | remainder = N % batch_size 304 | dataloader = quontinent * [batch_size] 305 | if remainder > 0: 306 | dataloader.append(remainder) 307 | return dataloader 308 | -------------------------------------------------------------------------------- /src/trainer/trainer/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .publaynet import PubLayNetDataset 2 | from .rico import Rico25Dataset 3 | 4 | _DATASETS = [ 5 | Rico25Dataset, 6 | PubLayNetDataset, 7 | ] 8 | DATASETS = {d.name: d for d in _DATASETS} 9 | 
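compose_transform above turns the transform strings used in the experiment configs (e.g. "AddNoiseToBBox(std=0.1)" in ruite.yaml) into an instantiated torchvision Compose via eval, and sparse_to_dense converts a torch_geometric batch into padded (bbox, label, padding_mask, mask) tensors. A minimal usage sketch, assuming the dataset has already been preprocessed under download/datasets and that max_seq_length=25 matches the rico25 setup:

# Illustrative sketch only; the dataset directory and max_seq_length are assumptions.
from torch_geometric.loader import DataLoader
from trainer.data.util import compose_transform, sparse_to_dense
from trainer.datasets import DATASETS

transform = compose_transform(["RandomOrder", "AddNoiseToBBox(std=0.1)"])
dataset = DATASETS["rico25"](
    dir="download/datasets", split="train", max_seq_length=25, transform=transform
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
bbox, label, padding_mask, mask = sparse_to_dense(batch)
# bbox: (B, S, 4) in normalized xywh, label: (B, S), mask is True for valid elements
print(bbox.shape, label.shape, mask.sum(dim=1))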
-------------------------------------------------------------------------------- /src/trainer/trainer/datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fsspec 4 | import seaborn as sns 5 | import torch 6 | from fsspec.core import url_to_fs 7 | 8 | from .dataset import InMemoryDataset 9 | 10 | 11 | class BaseDataset(InMemoryDataset): 12 | name = None 13 | labels = [] 14 | _label2index = None 15 | _index2label = None 16 | _colors = None 17 | 18 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 19 | assert split in ["train", "val", "test"] 20 | name = f"{self.name}-max{max_seq_length}" 21 | self.max_seq_length = max_seq_length 22 | super().__init__(os.path.join(dir, name), transform) 23 | idx = self.processed_file_names.index("{}.pt".format(split)) 24 | 25 | with fsspec.open(self.processed_paths[idx], "rb") as file_obj: 26 | self.data, self.slices = torch.load(file_obj) 27 | 28 | @property 29 | def label2index(self): 30 | if self._label2index is None: 31 | self._label2index = dict() 32 | for idx, label in enumerate(self.labels): 33 | self._label2index[label] = idx 34 | return self._label2index 35 | 36 | @property 37 | def index2label(self): 38 | if self._index2label is None: 39 | self._index2label = dict() 40 | for idx, label in enumerate(self.labels): 41 | self._index2label[idx] = label 42 | return self._index2label 43 | 44 | @property 45 | def colors(self): 46 | if self._colors is None: 47 | n_colors = self.num_classes 48 | colors = sns.color_palette("husl", n_colors=n_colors) 49 | self._colors = [tuple(map(lambda x: int(x * 255), c)) for c in colors] 50 | return self._colors 51 | 52 | @property 53 | def raw_file_names(self): 54 | fs, _ = url_to_fs(self.raw_dir) 55 | if not fs.exists(self.raw_dir): 56 | return [] 57 | file_names = [f.split("/")[-1] for f in fs.ls(self.raw_dir)] 58 | return file_names 59 | 60 | @property 61 | def processed_file_names(self): 62 | return ["train.pt", "val.pt", "test.pt"] 63 | 64 | def download(self): 65 | raise FileNotFoundError("See dataset/README.md") 66 | 67 | def process(self): 68 | raise NotImplementedError 69 | 70 | def get_original_images(self): 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /src/trainer/trainer/datasets/publaynet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from fsspec.core import url_to_fs 5 | from torch_geometric.data import Data 6 | from tqdm import tqdm 7 | 8 | from .base import BaseDataset 9 | 10 | 11 | class PubLayNetDataset(BaseDataset): 12 | name = "publaynet" 13 | labels = [ 14 | "text", 15 | "title", 16 | "list", 17 | "table", 18 | "figure", 19 | ] 20 | 21 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 22 | super().__init__(dir, split, max_seq_length, transform) 23 | 24 | def download(self): 25 | # super().download() 26 | pass 27 | 28 | def process(self): 29 | from pycocotools.coco import COCO 30 | 31 | fs, _ = url_to_fs(self.raw_dir) 32 | 33 | # if self.raw_dir.startswith("gs://"): 34 | # raise NotImplementedError 35 | 36 | for split_publaynet in ["train", "val"]: 37 | data_list = [] 38 | coco = COCO( 39 | os.path.join(self.raw_dir, "publaynet", f"{split_publaynet}.json") 40 | ) 41 | for img_id in tqdm(sorted(coco.getImgIds())): 42 | ann_img = coco.loadImgs(img_id) 43 | W = float(ann_img[0]["width"]) 44 | H = float(ann_img[0]["height"]) 
45 | name = ann_img[0]["file_name"] 46 | if H < W: 47 | continue 48 | 49 | def is_valid(element): 50 | x1, y1, width, height = element["bbox"] 51 | x2, y2 = x1 + width, y1 + height 52 | if x1 < 0 or y1 < 0 or W < x2 or H < y2: 53 | return False 54 | 55 | if x2 <= x1 or y2 <= y1: 56 | return False 57 | 58 | return True 59 | 60 | elements = coco.loadAnns(coco.getAnnIds(imgIds=[img_id])) 61 | _elements = list(filter(is_valid, elements)) 62 | filtered = len(elements) != len(_elements) 63 | elements = _elements 64 | 65 | N = len(elements) 66 | if N == 0 or self.max_seq_length < N: 67 | continue 68 | 69 | boxes = [] 70 | labels = [] 71 | 72 | for element in elements: 73 | # bbox 74 | x1, y1, width, height = element["bbox"] 75 | xc = x1 + width / 2.0 76 | yc = y1 + height / 2.0 77 | b = [xc / W, yc / H, width / W, height / H] 78 | boxes.append(b) 79 | 80 | # label 81 | l = coco.cats[element["category_id"]]["name"] 82 | labels.append(self.label2index[l]) 83 | 84 | boxes = torch.tensor(boxes, dtype=torch.float) 85 | labels = torch.tensor(labels, dtype=torch.long) 86 | 87 | data = Data(x=boxes, y=labels) 88 | data.attr = { 89 | "name": name, 90 | "width": W, 91 | "height": H, 92 | "filtered": filtered, 93 | "has_canvas_element": False, 94 | "NoiseAdded": False, 95 | } 96 | data_list.append(data) 97 | 98 | if split_publaynet == "train": 99 | train_list = data_list 100 | else: 101 | val_list = data_list 102 | 103 | # shuffle train with seed 104 | generator = torch.Generator().manual_seed(0) 105 | indices = torch.randperm(len(train_list), generator=generator) 106 | train_list = [train_list[i] for i in indices] 107 | 108 | # train_list -> train 95% / val 5% 109 | # val_list -> test 100% 110 | s = int(len(train_list) * 0.95) 111 | with fs.open(self.processed_paths[0], "wb") as file_obj: 112 | torch.save(self.collate(train_list[:s]), file_obj) 113 | with fs.open(self.processed_paths[1], "wb") as file_obj: 114 | torch.save(self.collate(train_list[s:]), file_obj) 115 | with fs.open(self.processed_paths[2], "wb") as file_obj: 116 | torch.save(self.collate(val_list), file_obj) 117 | -------------------------------------------------------------------------------- /src/trainer/trainer/datasets/rico.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from pathlib import Path 5 | from zipfile import ZipFile 6 | 7 | import numpy as np 8 | import torch 9 | from fsspec.core import url_to_fs 10 | from PIL import Image, ImageDraw 11 | from torch_geometric.data import Data 12 | from tqdm import tqdm 13 | from trainer.data.util import sparse_to_dense 14 | from trainer.helpers.util import convert_xywh_to_ltrb 15 | 16 | from .base import BaseDataset 17 | 18 | _rico5_labels = [ 19 | "Text", 20 | "Text Button", 21 | "Toolbar", 22 | "Image", 23 | "Icon", 24 | ] 25 | 26 | _rico13_labels = [ 27 | "Toolbar", 28 | "Image", 29 | "Text", 30 | "Icon", 31 | "Text Button", 32 | "Input", 33 | "List Item", 34 | "Advertisement", 35 | "Pager Indicator", 36 | "Web View", 37 | "Background Image", 38 | "Drawer", 39 | "Modal", 40 | ] 41 | 42 | _rico25_labels = [ 43 | "Text", 44 | "Image", 45 | "Icon", 46 | "Text Button", 47 | "List Item", 48 | "Input", 49 | "Background Image", 50 | "Card", 51 | "Web View", 52 | "Radio Button", 53 | "Drawer", 54 | "Checkbox", 55 | "Advertisement", 56 | "Modal", 57 | "Pager Indicator", 58 | "Slider", 59 | "On/Off Switch", 60 | "Button Bar", 61 | "Toolbar", 62 | "Number Stepper", 63 | "Multi-Tab", 64 | "Date Picker", 65 | "Map 
View", 66 | "Video", 67 | "Bottom Navigation", 68 | ] 69 | 70 | 71 | def append_child(element, elements): 72 | if "children" in element.keys(): 73 | for child in element["children"]: 74 | elements.append(child) 75 | elements = append_child(child, elements) 76 | return elements 77 | 78 | 79 | class _RicoDataset(BaseDataset): 80 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 81 | super().__init__(dir, split, max_seq_length, transform) 82 | 83 | def process(self): 84 | data_list = [] 85 | raw_file = os.path.join( 86 | self.raw_dir, "rico_dataset_v0.1_semantic_annotations.zip" 87 | ) 88 | fs, _ = url_to_fs(self.raw_dir) 89 | with fs.open(raw_file, "rb") as f, ZipFile(f) as z: 90 | names = sorted([n for n in z.namelist() if n.endswith(".json")]) 91 | for name in tqdm(names): 92 | ann = json.loads(z.open(name).read()) 93 | 94 | B = ann["bounds"] 95 | W, H = float(B[2]), float(B[3]) 96 | if B[0] != 0 or B[1] != 0 or H < W: 97 | continue 98 | 99 | def is_valid(element): 100 | if element["componentLabel"] not in set(self.labels): 101 | print(element["componentLabel"]) 102 | return False 103 | 104 | x1, y1, x2, y2 = element["bounds"] 105 | if x1 < 0 or y1 < 0 or W < x2 or H < y2: 106 | return False 107 | 108 | if x2 <= x1 or y2 <= y1: 109 | return False 110 | 111 | return True 112 | 113 | elements = append_child(ann, []) 114 | _elements = list(filter(is_valid, elements)) 115 | filtered = len(elements) != len(_elements) 116 | elements = _elements 117 | N = len(elements) 118 | if N == 0 or self.max_seq_length < N: 119 | continue 120 | 121 | # only for debugging slice-based preprocessing 122 | # elements = append_child(ann, []) 123 | # filtered = False 124 | # if len(elements) == 0: 125 | # continue 126 | # elements = elements[:self.max_seq_length] 127 | 128 | boxes = [] 129 | labels = [] 130 | 131 | for element in elements: 132 | # bbox 133 | x1, y1, x2, y2 = element["bounds"] 134 | xc = (x1 + x2) / 2.0 135 | yc = (y1 + y2) / 2.0 136 | width = x2 - x1 137 | height = y2 - y1 138 | b = [xc / W, yc / H, width / W, height / H] 139 | boxes.append(b) 140 | 141 | # label 142 | l = element["componentLabel"] 143 | labels.append(self.label2index[l]) 144 | 145 | boxes = torch.tensor(boxes, dtype=torch.float) 146 | labels = torch.tensor(labels, dtype=torch.long) 147 | 148 | data = Data(x=boxes, y=labels) 149 | data.attr = { 150 | "name": name, 151 | "width": W, 152 | "height": H, 153 | "filtered": filtered, 154 | "has_canvas_element": False, 155 | "NoiseAdded": False, 156 | } 157 | data_list.append(data) 158 | 159 | # shuffle with seed 160 | generator = torch.Generator().manual_seed(0) 161 | indices = torch.randperm(len(data_list), generator=generator) 162 | data_list = [data_list[i] for i in indices] 163 | 164 | # train 85% / val 5% / test 10% 165 | N = len(data_list) 166 | s = [int(N * 0.85), int(N * 0.90)] 167 | 168 | with fs.open(self.processed_paths[0], "wb") as file_obj: 169 | torch.save(self.collate(data_list[: s[0]]), file_obj) 170 | with fs.open(self.processed_paths[1], "wb") as file_obj: 171 | torch.save(self.collate(data_list[s[0] : s[1]]), file_obj) 172 | with fs.open(self.processed_paths[2], "wb") as file_obj: 173 | torch.save(self.collate(data_list[s[1] :]), file_obj) 174 | 175 | def download(self): 176 | pass 177 | 178 | def get_original_resource(self, batch) -> Image: 179 | assert not self.raw_dir.startswith("gs://") 180 | bbox, _, _, _ = sparse_to_dense(batch) 181 | 182 | img_bg, img_original, cropped_patches = [], [], [] 183 | names = batch.attr["name"] 184 | if 
isinstance(names, str): 185 | names = [names] 186 | 187 | for i, name in enumerate(names): 188 | name = Path(name).name.replace(".json", ".jpg") 189 | img = Image.open(Path(self.raw_dir) / "combined" / name) 190 | img_original.append(copy.deepcopy(img)) 191 | 192 | W, H = img.size 193 | ltrb = convert_xywh_to_ltrb(bbox[i].T.numpy()) 194 | left, right = (ltrb[0] * W).astype(np.uint32), (ltrb[2] * W).astype( 195 | np.uint32 196 | ) 197 | top, bottom = (ltrb[1] * H).astype(np.uint32), (ltrb[3] * H).astype( 198 | np.uint32 199 | ) 200 | draw = ImageDraw.Draw(img) 201 | patches = [] 202 | for (l, r, t, b) in zip(left, right, top, bottom): 203 | patches.append(img.crop((l, t, r, b))) 204 | # draw.rectangle([(l, t), (r, b)], fill=(255, 0, 0)) 205 | draw.rectangle([(l, t), (r, b)], fill=(255, 255, 255)) 206 | img_bg.append(img) 207 | cropped_patches.append(patches) 208 | # if len(patches) < S: 209 | # for i in range(S - len(patches)): 210 | # patches.append(Image.new("RGB", (0, 0))) 211 | 212 | return { 213 | "img_bg": img_bg, 214 | "img_original": img_original, 215 | "cropped_patches": cropped_patches, 216 | } 217 | 218 | # read from uncompressed data (the last line takes infinite time, so not used now..) 219 | # raw_file = os.path.join(self.raw_dir, "unique_uis.tar.gz") 220 | # with tarfile.open(raw_file) as f: 221 | # # return gzip.GzipFile(fileobj=f.extractfile(f"combined/{name}")).read() 222 | # return gzip.GzipFile(fileobj=f.extractfile(f"combined/hoge")).read() 223 | 224 | 225 | class Rico5Dataset(_RicoDataset): 226 | name = "rico5" 227 | labels = _rico5_labels 228 | 229 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 230 | super().__init__(dir, split, max_seq_length, transform) 231 | 232 | 233 | # Constrained Graphic Layout Generation via Latent Optimization (ACMMM2021) 234 | class Rico13Dataset(_RicoDataset): 235 | name = "rico13" 236 | labels = _rico13_labels 237 | 238 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 239 | super().__init__(dir, split, max_seq_length, transform) 240 | 241 | 242 | class Rico25Dataset(_RicoDataset): 243 | name = "rico25" 244 | labels = _rico25_labels 245 | 246 | def __init__(self, dir: str, split: str, max_seq_length: int, transform=None): 247 | super().__init__(dir, split, max_seq_length, transform) 248 | -------------------------------------------------------------------------------- /src/trainer/trainer/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/fid/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/fid/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fsspec 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TransformerWithToken(nn.Module): 9 | def __init__(self, d_model, nhead, dim_feedforward, num_layers): 10 | super().__init__() 11 | 12 | self.token = nn.Parameter(torch.randn(1, 1, d_model)) 13 | token_mask = torch.zeros(1, 1, dtype=torch.bool) 14 | self.register_buffer("token_mask", token_mask) 15 | 16 | self.core = nn.TransformerEncoder( 17 | nn.TransformerEncoderLayer( 18 | d_model=d_model, 19 | nhead=nhead, 20 | dim_feedforward=dim_feedforward, 21 | ), 22 | num_layers=num_layers, 23 | ) 24 | 25 | def forward(self, x, src_key_padding_mask): 26 | # x: [N, B, E] 27 | # 
padding_mask: [B, N] 28 | # `False` for valid values 29 | # `True` for padded values 30 | 31 | B = x.size(1) 32 | 33 | token = self.token.expand(-1, B, -1) 34 | x = torch.cat([token, x], dim=0) 35 | 36 | token_mask = self.token_mask.expand(B, -1) 37 | padding_mask = torch.cat([token_mask, src_key_padding_mask], dim=1) 38 | 39 | x = self.core(x, src_key_padding_mask=padding_mask) 40 | 41 | return x 42 | 43 | 44 | class FIDNet(nn.Module): 45 | def __init__(self, num_label): 46 | super().__init__() 47 | 48 | self.emb_label = nn.Embedding(num_label, 32) 49 | self.fc_bbox = nn.Linear(4, 32) 50 | self.transformer = TransformerWithToken( 51 | d_model=64, nhead=4, dim_feedforward=32, num_layers=4 52 | ) 53 | self.fc_out = nn.Linear(64, 1) 54 | 55 | def extract_features(self, bbox, label, padding_mask): 56 | l = self.emb_label(label) 57 | b = self.fc_bbox(bbox) 58 | x = torch.cat([l, b], dim=-1).permute(1, 0, 2) 59 | x = self.transformer(x, padding_mask) 60 | return x[0] 61 | 62 | def forward(self, bbox, label, padding_mask): 63 | x = self.extract_features(bbox, label, padding_mask) 64 | x = self.fc_out(x) 65 | return x.squeeze(-1) 66 | 67 | 68 | class FIDNetV2(nn.Module): 69 | def __init__(self, num_label, max_bbox=50): 70 | super().__init__() 71 | 72 | self.emb_label = nn.Embedding(num_label, 128) 73 | self.fc_bbox = nn.Linear(4, 128) 74 | self.encoder = TransformerWithToken( 75 | d_model=256, nhead=4, dim_feedforward=128, num_layers=8 76 | ) 77 | 78 | self.fc_out = nn.Sequential( 79 | nn.Linear(256, 128), 80 | nn.BatchNorm1d(128), 81 | nn.ReLU(), 82 | nn.Linear(128, 64), 83 | nn.BatchNorm1d(64), 84 | nn.ReLU(), 85 | nn.Linear(64, 1), 86 | ) 87 | 88 | self.token = nn.Parameter(torch.rand(max_bbox, 1, 256)) 89 | te = nn.TransformerEncoderLayer(d_model=256, dim_feedforward=128, nhead=4) 90 | self.decoder = nn.TransformerEncoder(te, num_layers=8) 91 | self.fc_out_cls = nn.Linear(256, num_label) 92 | self.fc_out_bbox = nn.Linear(256, 4) 93 | 94 | def extract_features(self, bbox, label, padding_mask): 95 | l = self.emb_label(label) 96 | b = self.fc_bbox(bbox) 97 | x = torch.cat([l, b], dim=-1).permute(1, 0, 2) 98 | x = self.encoder(x, padding_mask) 99 | return x[0] 100 | 101 | def forward(self, bbox, label, padding_mask): 102 | B, N, _ = bbox.size() 103 | x = self.extract_features(bbox, label, padding_mask) 104 | 105 | logit = self.fc_out(x).squeeze(-1) 106 | 107 | t = self.token[:N].expand(-1, B, -1) 108 | x = torch.cat([x.unsqueeze(0), t], dim=0) 109 | 110 | token_mask = self.encoder.token_mask.expand(B, -1) 111 | _padding_mask = torch.cat([token_mask, padding_mask], dim=1) 112 | 113 | x = self.decoder(x, src_key_padding_mask=_padding_mask) 114 | # x = x[1:].permute(1, 0, 2)[~padding_mask] 115 | x = x[1:].permute(1, 0, 2) 116 | 117 | logit_cls = self.fc_out_cls(x) 118 | bbox = torch.sigmoid(self.fc_out_bbox(x)) 119 | 120 | return logit, logit_cls, bbox 121 | 122 | 123 | class FIDNetV3(nn.Module): 124 | def __init__(self, num_label, d_model=256, nhead=4, num_layers=4, max_bbox=50): 125 | super().__init__() 126 | 127 | # encoder 128 | self.emb_label = nn.Embedding(num_label, d_model) 129 | self.fc_bbox = nn.Linear(4, d_model) 130 | self.enc_fc_in = nn.Linear(d_model * 2, d_model) 131 | 132 | self.enc_transformer = TransformerWithToken( 133 | d_model=d_model, 134 | dim_feedforward=d_model // 2, 135 | nhead=nhead, 136 | num_layers=num_layers, 137 | ) 138 | 139 | self.fc_out_disc = nn.Linear(d_model, 1) 140 | 141 | # decoder 142 | self.pos_token = nn.Parameter(torch.rand(max_bbox, 1, d_model)) 143 | 
self.dec_fc_in = nn.Linear(d_model * 2, d_model) 144 | 145 | te = nn.TransformerEncoderLayer( 146 | d_model=d_model, nhead=nhead, dim_feedforward=d_model // 2 147 | ) 148 | self.dec_transformer = nn.TransformerEncoder(te, num_layers=num_layers) 149 | 150 | self.fc_out_cls = nn.Linear(d_model, num_label) 151 | self.fc_out_bbox = nn.Linear(d_model, 4) 152 | 153 | def extract_features(self, bbox, label, padding_mask): 154 | b = self.fc_bbox(bbox) 155 | l = self.emb_label(label) 156 | x = self.enc_fc_in(torch.cat([b, l], dim=-1)) 157 | x = torch.relu(x).permute(1, 0, 2) 158 | x = self.enc_transformer(x, padding_mask) 159 | return x[0] 160 | 161 | def forward(self, bbox, label, padding_mask): 162 | B, N, _ = bbox.size() 163 | x = self.extract_features(bbox, label, padding_mask) 164 | 165 | logit_disc = self.fc_out_disc(x).squeeze(-1) 166 | 167 | x = x.unsqueeze(0).expand(N, -1, -1) 168 | t = self.pos_token[:N].expand(-1, B, -1) 169 | x = torch.cat([x, t], dim=-1) 170 | x = torch.relu(self.dec_fc_in(x)) 171 | 172 | x = self.dec_transformer(x, src_key_padding_mask=padding_mask) 173 | # x = x.permute(1, 0, 2)[~padding_mask] 174 | x = x.permute(1, 0, 2) 175 | 176 | # logit_cls: [B, N, L] bbox_pred: [B, N, 4] 177 | logit_cls = self.fc_out_cls(x) 178 | bbox_pred = torch.sigmoid(self.fc_out_bbox(x)) 179 | 180 | return logit_disc, logit_cls, bbox_pred 181 | 182 | 183 | def load_fidnet_v3(dataset, weight_dir: str, device: torch.device) -> nn.Module: 184 | prefix = f"{dataset.name}-max{dataset.max_seq_length}" 185 | ckpt_path = os.path.join(weight_dir, prefix, "model_best.pth.tar") 186 | fid_model = FIDNetV3( 187 | num_label=dataset.num_classes, max_bbox=dataset.max_seq_length 188 | ).to(device) 189 | with fsspec.open(ckpt_path, "rb") as file_obj: 190 | x = torch.load(file_obj, map_location=device) 191 | fid_model.load_state_dict(x["state_dict"]) 192 | fid_model.eval() 193 | return fid_model 194 | -------------------------------------------------------------------------------- /src/trainer/trainer/fid/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | 6 | import omegaconf 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torchvision.transforms as T 11 | from hydra.utils import instantiate 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch_geometric.loader import DataLoader 14 | from torch_geometric.utils import to_dense_batch 15 | from trainer.data.util import AddNoiseToBBox, LexicographicOrder 16 | from trainer.fid.model import FIDNetV3 17 | from trainer.global_configs import DATASET_DIR 18 | from trainer.helpers.visualization import save_image 19 | 20 | 21 | def save_checkpoint(state, is_best, out_dir): 22 | out_path = Path(out_dir) / "checkpoint.pth.tar" 23 | torch.save(state, out_path) 24 | 25 | if is_best: 26 | best_path = Path(out_dir) / "model_best.pth.tar" 27 | shutil.copyfile(out_path, best_path) 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("dataset_yaml", type=str) 33 | parser.add_argument("--out_dir", type=str, default="tmp/fid_weights") 34 | parser.add_argument("--batch_size", type=int, default=64, help="input batch size") 35 | parser.add_argument( 36 | "--iteration", 37 | type=int, 38 | default=int(2e5), 39 | help="number of iterations to train for", 40 | ) 41 | parser.add_argument( 42 | "--lr", type=float, default=3e-4, help="learning rate, default=3e-4" 43 | ) 44 | 
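# Sketch note: load_fidnet_v3 above restores the pretrained FIDNetV3 weights, and
# extract_features(bbox, label, padding_mask) returns one d_model-dimensional vector per layout.
# FID is then the Frechet distance between Gaussians fitted to the real and the generated
# feature sets,
#     FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}),
# with (mu, S) the mean and covariance of each set; the exact evaluation code that consumes
# these features is not shown in this excerpt.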
parser.add_argument("--seed", type=int, help="manual seed") 45 | args = parser.parse_args() 46 | print(args) 47 | 48 | dataset_cfg = omegaconf.OmegaConf.load(args.dataset_yaml) 49 | dataset_cfg["dir"] = DATASET_DIR 50 | 51 | prefix = "FIDNetV3" 52 | out_dir = Path(args.out_dir) 53 | out_dir.mkdir(parents=True, exist_ok=True) 54 | writer = SummaryWriter(str(out_dir / "logs")) 55 | 56 | transform = T.Compose( 57 | [ 58 | T.RandomApply([AddNoiseToBBox()], 0.5), 59 | LexicographicOrder(), 60 | ] 61 | ) 62 | train_dataset = instantiate(dataset_cfg)(split="train", transform=transform) 63 | val_dataset = instantiate(dataset_cfg)(split="test", transform=transform) 64 | categories = train_dataset.labels 65 | 66 | kwargs = { 67 | "batch_size": args.batch_size, 68 | "num_workers": os.cpu_count(), 69 | "pin_memory": True, 70 | } 71 | 72 | train_dataloader = DataLoader(train_dataset, shuffle=True, **kwargs) 73 | val_dataloader = DataLoader(val_dataset, shuffle=False, **kwargs) 74 | 75 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 76 | model = FIDNetV3(num_label=len(categories), max_bbox=dataset_cfg.max_seq_length).to( 77 | device 78 | ) 79 | 80 | # setup optimizer 81 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 82 | 83 | criterion_bce = nn.BCEWithLogitsLoss(reduction="none") 84 | criterion_label = nn.CrossEntropyLoss(reduction="none") 85 | criterion_bbox = nn.MSELoss(reduction="none") 86 | 87 | def proc_batch(batch): 88 | batch = batch.to(device) 89 | bbox, _ = to_dense_batch(batch.x, batch.batch) 90 | label, mask = to_dense_batch(batch.y, batch.batch) 91 | padding_mask = ~mask 92 | 93 | is_real = batch.attr["NoiseAdded"].float() 94 | return bbox, label, padding_mask, mask, is_real 95 | 96 | iteration = 0 97 | best_loss = 1e8 98 | max_epoch = args.iteration * args.batch_size / len(train_dataset) 99 | max_epoch = torch.ceil(torch.tensor(max_epoch)).int().item() 100 | for epoch in range(max_epoch): 101 | model.train() 102 | train_loss = { 103 | "Loss_BCE": 0, 104 | "Loss_Label": 0, 105 | "Loss_BBox": 0, 106 | } 107 | 108 | for i, batch in enumerate(train_dataloader): 109 | bbox, label, padding_mask, mask, is_real = proc_batch(batch) 110 | model.zero_grad() 111 | 112 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask) 113 | 114 | loss_bce = criterion_bce(logit, is_real) 115 | loss_label = criterion_label(logit_cls[mask], label[mask]) 116 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1) 117 | loss = loss_bce.mean() + loss_label.mean() + 10 * loss_bbox.mean() 118 | loss.backward() 119 | 120 | optimizer.step() 121 | 122 | loss_bce_mean = loss_bce.mean().item() 123 | train_loss["Loss_BCE"] += loss_bce.sum().item() 124 | loss_label_mean = loss_label.mean().item() 125 | train_loss["Loss_Label"] += loss_label.sum().item() 126 | loss_bbox_mean = loss_bbox.mean().item() 127 | train_loss["Loss_BBox"] += loss_bbox.sum().item() 128 | 129 | # add data to tensorboard 130 | writer.add_scalar(prefix + "/Loss", loss.item(), iteration) 131 | writer.add_scalar(prefix + "/Loss_BCE", loss_bce_mean, iteration) 132 | writer.add_scalar(prefix + "/Loss_Label", loss_label_mean, iteration) 133 | writer.add_scalar(prefix + "/Loss_BBox", loss_bbox_mean, iteration) 134 | 135 | if i % 50 == 0: 136 | log_prefix = f"[{epoch}/{max_epoch}][{i}/{len(train_dataset) // args.batch_size}]" 137 | log = f"Loss: {loss.item():E}\tBCE: {loss_bce_mean:E}\tLabel: {loss_label_mean:E}\tBBox: {loss_bbox_mean:E}" 138 | print(f"{log_prefix}\t{log}") 139 | 140 | iteration += 1 141 | 142 | 
for key in train_loss.keys(): 143 | train_loss[key] /= len(train_dataset) 144 | 145 | model.eval() 146 | with torch.no_grad(): 147 | val_loss = { 148 | "Loss_BCE": 0, 149 | "Loss_Label": 0, 150 | "Loss_BBox": 0, 151 | } 152 | 153 | for i, batch in enumerate(val_dataloader): 154 | bbox, label, padding_mask, mask, is_real = proc_batch(batch) 155 | 156 | logit, logit_cls, bbox_pred = model(bbox, label, padding_mask) 157 | 158 | loss_bce = criterion_bce(logit, is_real) 159 | loss_label = criterion_label(logit_cls[mask], label[mask]) 160 | loss_bbox = criterion_bbox(bbox_pred[mask], bbox[mask]).sum(-1) 161 | 162 | val_loss["Loss_BCE"] += loss_bce.sum().item() 163 | val_loss["Loss_Label"] += loss_label.sum().item() 164 | val_loss["Loss_BBox"] += loss_bbox.sum().item() 165 | 166 | if i == 0 and epoch % 10 == 0: 167 | save_image( 168 | bbox, 169 | label, 170 | mask, 171 | val_dataset.colors, 172 | out_dir / f"samples_{epoch}.png", 173 | ) 174 | cls_pred = logit_cls.argmax(dim=-1) 175 | save_image( 176 | bbox_pred, 177 | cls_pred, 178 | mask, 179 | val_dataset.colors, 180 | out_dir / f"recon_samples_{epoch}.png", 181 | ) 182 | 183 | for key in val_loss.keys(): 184 | val_loss[key] /= len(val_dataset) 185 | 186 | writer.add_scalar(prefix + "/Epoch", epoch, iteration) 187 | tag_scalar_dict = { 188 | "train": sum(train_loss.values()), 189 | "val": sum(val_loss.values()), 190 | } 191 | writer.add_scalars(prefix + "/Loss_Epoch", tag_scalar_dict, iteration) 192 | for key in train_loss.keys(): 193 | tag_scalar_dict = {"train": train_loss[key], "val": val_loss[key]} 194 | writer.add_scalars(prefix + f"/{key}_Epoch", tag_scalar_dict, iteration) 195 | 196 | # do checkpointing 197 | val_loss = sum(val_loss.values()) 198 | is_best = val_loss < best_loss 199 | best_loss = min(val_loss, best_loss) 200 | 201 | save_checkpoint( 202 | { 203 | "epoch": epoch + 1, 204 | "state_dict": model.state_dict(), 205 | "best_loss": best_loss, 206 | "optimizer": optimizer.state_dict(), 207 | }, 208 | is_best, 209 | out_dir, 210 | ) 211 | 212 | 213 | if __name__ == "__main__": 214 | main() 215 | -------------------------------------------------------------------------------- /src/trainer/trainer/global_configs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | ROOT = f"{str(Path(__file__).parent)}/../../../download" 4 | KMEANS_WEIGHT_ROOT = f"{ROOT}/clustering_weights" 5 | DATASET_DIR = f"{ROOT}/datasets" 6 | FID_WEIGHT_DIR = f"{ROOT}/fid_weights/FIDNetV3" 7 | JOB_DIR = f"{ROOT}/pretrained_weights" 8 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/helpers/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/bbox_tokenizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import pickle 4 | from typing import Dict, List, Union 5 | 6 | import fsspec 7 | import numpy as np 8 | import torch 9 | from einops import rearrange 10 | from sklearn.cluster import KMeans 11 | from torch import BoolTensor, FloatTensor, LongTensor 12 | from trainer.global_configs import KMEANS_WEIGHT_ROOT 13 | from trainer.helpers.clustering import Percentile 14 | 15 | logger = 
logging.getLogger(__name__) 16 | 17 | KEY_MULT_DICT = { 18 | "x-y-w-h": {"y": 1, "w": 2, "h": 3}, 19 | "xywh": {}, 20 | } 21 | 22 | 23 | class DummyClusteringModel: 24 | def __init__(self, cluster_centers: np.ndarray): 25 | self.cluster_centers_ = cluster_centers 26 | 27 | 28 | class BboxTokenizer: 29 | """ 30 | If N is number of bins, 0 <= x, y <= (N - 1) / N and 1 / N <= w, h <= 1 31 | 'bbox' variable is assumed to have "xywh" order 32 | """ 33 | 34 | def __init__( 35 | self, 36 | num_bin_bboxes: int, 37 | var_order: str = "c-x-y-w-h", 38 | shared_bbox_vocab: str = "xywh", 39 | bbox_quantization: str = "linear", 40 | dataset_name: str = "rico25_max25", 41 | ): 42 | # if bbox_quantization == "kmeans": 43 | # assert shared_bbox_vocab == "x-y-w-h" 44 | 45 | self._num_bin_bboxes = num_bin_bboxes 46 | self._var_order = var_order.lstrip("c-").split("-") 47 | self._shared_bbox_vocab = shared_bbox_vocab 48 | self._bbox_quantization = bbox_quantization 49 | self._dataset_name = dataset_name 50 | self._var_names = ["x", "y", "w", "h"] 51 | 52 | self._clustering_models = {} 53 | if self.bbox_quantization in ["kmeans", "percentile"]: 54 | name = f"{dataset_name}_{self.bbox_quantization}_train_clusters.pkl" 55 | path = f"{KMEANS_WEIGHT_ROOT}/{name}" 56 | with fsspec.open(path, "rb") as f: 57 | valid_keys = [f"{k}-{self.num_bin_bboxes}" for k in self._var_names] 58 | for key, model in pickle.load(f).items(): 59 | if key not in valid_keys: 60 | continue 61 | 62 | # sort cluster center in 1d case 63 | var_name = key.split("-")[0] 64 | if len(var_name) == 1: 65 | cluster_centers = np.sort( 66 | model.cluster_centers_, axis=0 67 | ) # (N, 1) 68 | model.cluster_centers_ = cluster_centers 69 | 70 | self._clustering_models[key] = model 71 | else: 72 | for n in self.var_names: 73 | d = 1 / self.num_bin_bboxes 74 | if n in ["x", "y", "z"]: 75 | centers = np.linspace( 76 | start=0.0, stop=1.0 - d, num=self.num_bin_bboxes 77 | ) 78 | else: 79 | centers = np.linspace(start=d, stop=1.0, num=self.num_bin_bboxes) 80 | centers = rearrange(centers, "c -> c 1") 81 | key = f"{n}-{self.num_bin_bboxes}" 82 | self._clustering_models[key] = DummyClusteringModel(centers) 83 | 84 | def encode(self, bbox: FloatTensor) -> LongTensor: 85 | d = 1 / self.num_bin_bboxes # delta 86 | bbox_q = torch.zeros_like(bbox) 87 | 88 | if self.bbox_quantization == "linear": 89 | bbox_q[..., :2] = torch.clamp(bbox[..., :2], 0.0, 1.0 - d) # ["x", "y"] 90 | bbox_q[..., 2:] = torch.clamp(bbox[..., 2:], d, 1.0) - d # ["w", "h"] 91 | indices = (self.num_bin_bboxes * bbox_q).round().long() 92 | 93 | elif self.bbox_quantization in ["kmeans", "percentile"]: 94 | B, S = bbox.size()[:2] 95 | indices = [] 96 | if len(self._var_order) == 4: 97 | for i, key in enumerate(self._var_names): 98 | model = self.clustering_models[f"{key}-{self.num_bin_bboxes}"] 99 | input_ = bbox[..., i : i + 1].view(-1, 1).numpy().astype(np.float32) 100 | output = torch.from_numpy(model.predict(input_)).view(B, S, 1) 101 | indices.append(output) 102 | 103 | indices = torch.cat(indices, dim=-1) 104 | 105 | # add offset if vocabularies are not fully shared among xywh 106 | if len(self._var_order) == 4: 107 | for (key, mult) in KEY_MULT_DICT[self.shared_bbox_vocab].items(): 108 | indices[..., self._var_names.index(key)] += self.num_bin_bboxes * mult 109 | 110 | # change var order if necessary 111 | if len(self._var_order) == 4: 112 | order_indices = [self._var_names.index(k) for k in self._var_order] 113 | indices = indices[..., order_indices] 114 | 115 | return indices 116 | 117 | 
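# Worked example of the linear branch of encode/decode (a sketch; num_bin_bboxes=32 is assumed).
# With d = 1/32: x = 0.5 is clamped to [0, 1 - d] and rounded to 32 * 0.5 = bin 16, which
# decodes back to 16 * d = 0.5; w = 0.25 is clamped to [d, 1], shifted by -d to 0.21875,
# rounded to bin 7, and decodes to (7 + 1) * d = 0.25. So, with the default shared "xywh" vocab:
#     tokenizer = BboxTokenizer(num_bin_bboxes=32)      # linear quantization by default
#     bbox = torch.tensor([[[0.5, 0.5, 0.25, 0.25]]])   # (B=1, S=1, 4) in xywh
#     ids = tokenizer.encode(bbox)                      # tensor([[[16, 16, 7, 7]]])
#     assert torch.allclose(tokenizer.decode(ids), bbox)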
def decode(self, bbox_indices: LongTensor) -> FloatTensor: 118 | arr = torch.clone(bbox_indices) # avoid overriding 119 | 120 | # restore var order back to "xywh" if necessary 121 | if len(self._var_order) == 4: 122 | order_indices = [self._var_order.index(k) for k in self._var_names] 123 | arr = arr[..., order_indices] 124 | 125 | # subtract offset if vocabularies are not fully shared among xywh 126 | if len(self._var_order) == 4: 127 | for (key, mult) in KEY_MULT_DICT[self.shared_bbox_vocab].items(): 128 | arr[..., self._var_names.index(key)] -= self.num_bin_bboxes * mult 129 | 130 | if self.bbox_quantization == "linear": 131 | # if len(self._var_order) == 2: 132 | # # decode from product space 133 | # tmp = {} 134 | # for i, vs in enumerate(self._var_order): 135 | # x = torch.clone(arr[..., i]) 136 | # for v in reversed(vs): 137 | # tmp[v] = x % self.num_bin_bboxes 138 | # x = torch.div(x, self.num_bin_bboxes, rounding_mode="floor") 139 | # arr = torch.stack([tmp[k] for k in self.keys], dim=-1) 140 | 141 | arr = torch.clamp(arr, 0, self.num_bin_bboxes - 1) # avoid OOV 142 | 143 | bbox = torch.zeros(arr.size()).float() 144 | d = 1 / self.num_bin_bboxes 145 | bbox[..., :2] = arr[..., :2].float() * d 146 | bbox[..., 2:] = (arr[..., 2:] + 1).float() * d 147 | 148 | elif self.bbox_quantization in ["kmeans", "percentile"]: 149 | B, S = arr.size()[:2] 150 | arr = torch.clamp(arr, 0, self.num_bin_bboxes - 1) # avoid OOV 151 | bbox = [] 152 | if len(self._var_order) == 4: 153 | for i, key in enumerate(self._var_names): 154 | model = self.clustering_models[f"{key}-{self.num_bin_bboxes}"] 155 | inds = arr[..., i : i + 1].view(-1).numpy() 156 | loc = torch.from_numpy(model.cluster_centers_[inds]).view(B, S, -1) 157 | bbox.append(loc) 158 | elif len(self._var_order) == 2: 159 | tmp = {} 160 | for i, vs in enumerate(self._var_order): 161 | model = self.clustering_models[f"{vs}-{self.num_bin_bboxes}"] 162 | inds = arr[..., i : i + 1].view(-1).numpy() 163 | inds = model.cluster_centers_[inds] 164 | loc = torch.from_numpy(inds).view(B, S, -1) 165 | for j, key in enumerate(vs): 166 | tmp[key] = loc[..., j : j + 1] 167 | 168 | # reorganize var order to xywh 169 | bbox = [tmp[key] for key in self._var_names] 170 | 171 | bbox = torch.cat(bbox, dim=-1) 172 | bbox = torch.clamp(bbox, 0.0, 1.0) 173 | 174 | return bbox 175 | 176 | @property 177 | def bbox_vocab_len(self) -> int: 178 | return self.num_bin_bboxes * len(self.shared_bbox_vocab.split("-")) 179 | 180 | @property 181 | def bbox_quantization(self) -> str: 182 | return self._bbox_quantization 183 | 184 | @property 185 | def clustering_models( 186 | self, 187 | ) -> Dict[str, Union[KMeans, Percentile, DummyClusteringModel]]: 188 | return self._clustering_models 189 | 190 | @property 191 | def num_bin_bboxes(self) -> int: 192 | return self._num_bin_bboxes 193 | 194 | @property 195 | def shared_bbox_vocab(self) -> str: 196 | return self._shared_bbox_vocab 197 | 198 | @property 199 | def token_mask(self) -> Dict[str, BoolTensor]: 200 | masks = {} 201 | if self.shared_bbox_vocab == "xywh": 202 | for key in self._var_order: 203 | masks[key] = torch.full((self.num_bin_bboxes,), True) 204 | elif self.shared_bbox_vocab == "x-y-w-h": 205 | key_mult = KEY_MULT_DICT["x-y-w-h"] 206 | S = self.num_bin_bboxes * 4 207 | false_tensor = torch.full((S,), False) 208 | for key in self._var_order: 209 | masks[key] = copy.deepcopy(false_tensor) 210 | i = key_mult.get(key, 0) 211 | start, stop = i * self.num_bin_bboxes, (i + 1) * self.num_bin_bboxes 212 | 
masks[key][start:stop] = True 213 | else: 214 | raise NotImplementedError 215 | 216 | return masks 217 | 218 | @property 219 | def var_names(self) -> List[str]: 220 | return self._var_names 221 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/clustering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from einops import rearrange 3 | 4 | EPS = 1e-12 5 | 6 | 7 | class Percentile: 8 | """ 9 | It resembles KMeans interface in scikit-learn 10 | https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html 11 | """ 12 | 13 | def __init__(self, n_clusters: int = 32, v_min: float = 0.0, v_max: float = 1.0): 14 | self.n_clusters = n_clusters 15 | self.v_min = v_min 16 | self.v_max = v_max 17 | 18 | def fit(self, X: np.ndarray): 19 | assert X.ndim == 2 # (B, 1) 20 | X = X[:, 0].clip(self.v_min, self.v_max) 21 | # n_points_per_bin = int(np.ceil(len(data_sorted) / self.n_clusters)) 22 | # thresholds = data_sorted[::n_points_per_bin] 23 | # self._thresholds = np.concatenate([thresholds[1:], [self.v_max,]]).reshape( 24 | # 1, -1 25 | # ) # (1, T) 26 | # cond = self._thresholds <= X.reshape(-1, 1) # (B, T) 27 | # ids = cond.argmax(axis=1) # get smallest index of non-zero (false) item 28 | 29 | X = np.sort(np.unique(X)) 30 | thresholds = np.linspace(0.0, 1.0, self.n_clusters + 1)[:-1] 31 | thresholds = [X[int(t * len(X))] for t in thresholds] 32 | ids = (thresholds <= X.reshape(-1, 1)).sum(axis=1).astype(np.uint64) - 1 33 | 34 | self.cluster_centers_ = np.full( 35 | (self.n_clusters, 1), -1.0, dtype=np.float32 36 | ) # -1 will not be queried 37 | for i in range(self.n_clusters): 38 | values = X[ids == i] 39 | if len(values) > 0: 40 | self.cluster_centers_[i, 0] = values.mean().astype(np.float32) 41 | return self 42 | 43 | def predict(self, X): 44 | if not hasattr(self, "cluster_centers_"): 45 | raise NotImplementedError 46 | 47 | assert X.ndim == 2 48 | X = X.clip(self.v_min, self.v_max) 49 | # cond = self._thresholds >= X.reshape(-1, 1) # (B, T) 50 | # ids = cond.argmax(axis=1) 51 | # https://github.com/jannerm/trajectory-transformer/blob/c77076d1c39e8c8edc3d1e5032b55499de556d73/trajectory/utils/discretization.py#L196-L213 52 | dist = np.fabs(self.cluster_centers_ - rearrange(X, "s 1 -> 1 s")) 53 | ids = np.argmin(dist, axis=0) 54 | 55 | return ids 56 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/mask.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | from einops import rearrange, reduce, repeat 5 | from torch import BoolTensor, FloatTensor, LongTensor 6 | 7 | from .util import batch_topk_mask 8 | 9 | 10 | def sequence_mask(length: LongTensor, maxlen: Optional[int] = None) -> BoolTensor: 11 | """ 12 | Similar to https://www.tensorflow.org/api_docs/python/tf/sequence_mask 13 | """ 14 | B = length.size(0) 15 | maxlen = maxlen if maxlen else length.max() 16 | indices = repeat(torch.arange(maxlen), "s -> b s", b=B) 17 | mask = indices < rearrange(length, "b -> b 1") 18 | return mask 19 | 20 | 21 | def sample_mask(mask: BoolTensor, ratio: Union[float, FloatTensor]) -> BoolTensor: 22 | """ 23 | Generate sampled_mask (B, S) given mask (B, S) according to the specified ratio 24 | If mask[b, s] is False, sampled_mask[b, s] should be False. 
25 | """ 26 | if isinstance(ratio, float): 27 | ratio = torch.full((mask.size(0),), fill_value=ratio) 28 | 29 | scores = torch.rand(mask.size()) 30 | n_elem = reduce(mask, "b s -> b", reduction="sum") 31 | topk = (ratio * n_elem).long() 32 | sampled_mask, _ = batch_topk_mask(scores, topk, mask=mask) 33 | return sampled_mask 34 | 35 | 36 | if __name__ == "__main__": 37 | sample_mask(torch.full((2, 3), fill_value=False), 0.5) 38 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/sampling.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from hydra.core.config_store import ConfigStore 7 | from omegaconf import DictConfig 8 | from torch import FloatTensor, LongTensor 9 | 10 | FILTER_VALUE = -float("Inf") 11 | 12 | 13 | @dataclass 14 | class DeterministicSamplingConfig: 15 | name: str = "deterministic" 16 | 17 | 18 | @dataclass 19 | class _StochasticSamplingConfig: 20 | temperature: float = 1.0 21 | 22 | 23 | @dataclass 24 | class RandomSamplingConfig(_StochasticSamplingConfig): 25 | name: str = "random" 26 | 27 | 28 | @dataclass 29 | class GumbelSamplingConfig(_StochasticSamplingConfig): 30 | name: str = "gumbel" 31 | 32 | 33 | @dataclass 34 | class TopKSamplingConfig(_StochasticSamplingConfig): 35 | name: str = "top_k" 36 | top_k: int = 5 37 | 38 | 39 | @dataclass 40 | class TopPSamplingConfig(_StochasticSamplingConfig): 41 | name: str = "top_p" 42 | top_p: float = 0.9 43 | 44 | 45 | @dataclass 46 | class TopKTopPSamplingConfig(_StochasticSamplingConfig): 47 | name: str = "top_k_top_p" 48 | top_k: int = 5 49 | top_p: float = 0.9 50 | 51 | 52 | SAMPLING_CONFIG_DICT = { 53 | "top_k": TopKSamplingConfig, 54 | "top_k": TopKTopPSamplingConfig, 55 | "top_p": TopPSamplingConfig, 56 | "deterministic": DeterministicSamplingConfig, 57 | "random": RandomSamplingConfig, 58 | "gumbel": GumbelSamplingConfig, 59 | } 60 | 61 | 62 | def register_sampling_config(cs: ConfigStore): 63 | """ 64 | Helper to register all sampling configurations defined above 65 | """ 66 | cs.store(group="sampling", name="top_k", node=TopKSamplingConfig) 67 | cs.store(group="sampling", name="top_k_top_p", node=TopKTopPSamplingConfig) 68 | cs.store(group="sampling", name="top_p", node=TopPSamplingConfig) 69 | cs.store(group="sampling", name="deterministic", node=DeterministicSamplingConfig) 70 | cs.store(group="sampling", name="random", node=RandomSamplingConfig) 71 | 72 | 73 | def top_k_logits(logits: FloatTensor, k: int, dim: int = -1): 74 | # logits: (B, C) 75 | v, _ = torch.topk(logits, k, dim) 76 | out = logits.clone() 77 | out[out < v[:, [-1]]] = FILTER_VALUE 78 | return out 79 | 80 | 81 | def sample(logits: FloatTensor, sampling_cfg: DictConfig) -> LongTensor: 82 | """ 83 | Input: logits (B, C, *N) 84 | Output: (B, 1, *N) 85 | """ 86 | assert logits.ndim in [2, 3] 87 | if sampling_cfg.name == "deterministic": 88 | output = torch.argmax(logits, dim=1, keepdim=True) 89 | else: 90 | logits_ = logits / sampling_cfg.temperature 91 | 92 | if sampling_cfg.name == "top_k": 93 | logits = top_k_logits(logits_, k=sampling_cfg.top_k, dim=1) 94 | elif sampling_cfg.name == "top_p": 95 | top_p = sampling_cfg.top_p 96 | assert 0.0 < top_p <= 1.0 97 | 98 | S = logits.size(1) 99 | # https://stackoverflow.com/questions/52127723/pytorch-better-way-to-get-back-original-tensor-order-after-torch-sort 100 | sorted_logits, 
sorted_indices = torch.sort(logits_, descending=True, dim=1) 101 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=1), dim=1) 102 | 103 | indices = torch.arange(S).view(1, S).to(logits.device) 104 | if logits.ndim == 3: 105 | indices = indices.unsqueeze(dim=-1) 106 | 107 | # make sure to keep the first logit (most likely one) 108 | sorted_logits[(cumulative_probs > top_p) & (indices > 0)] = FILTER_VALUE 109 | logits = sorted_logits.gather(dim=1, index=sorted_indices.argsort(dim=1)) 110 | elif sampling_cfg.name == "random": 111 | logits = logits_ 112 | elif sampling_cfg.name == "gumbel": 113 | uniform = torch.rand_like(logits_) 114 | const = 1e-30 115 | gumbel_noise = -torch.log(-torch.log(uniform + const) + const) 116 | logits = logits_ + gumbel_noise 117 | else: 118 | raise NotImplementedError 119 | 120 | probs = F.softmax(logits, dim=1) 121 | if probs.ndim == 2: 122 | output = torch.multinomial(probs, num_samples=1) # (B, 1) 123 | elif probs.ndim == 3: 124 | S = probs.shape[2] 125 | probs = rearrange(probs, "b c s -> (b s) c") 126 | output = torch.multinomial(probs, num_samples=1) 127 | output = rearrange(output, "(b s) 1 -> b 1 s", s=S) 128 | else: 129 | raise NotImplementedError 130 | return output 131 | 132 | 133 | if __name__ == "__main__": 134 | from einops import repeat 135 | from omegaconf import OmegaConf 136 | 137 | sampling_cfg = OmegaConf.create({"name": "top_p", "top_p": 0.9, "temperature": 1.0}) 138 | logits = repeat(torch.arange(5), "c -> b c 1", b=2) 139 | x = sample(logits, sampling_cfg, return_confidence=True) 140 | print(x) 141 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, reduce, repeat 7 | from omegaconf import DictConfig 8 | from torch import LongTensor, Tensor 9 | from torch_geometric.utils import to_dense_batch 10 | from trainer.data.util import sparse_to_dense 11 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 12 | from trainer.helpers.mask import sample_mask 13 | from trainer.helpers.util import batch_topk_mask 14 | 15 | MAX_PARTIAL_RATIO = 0.3 16 | COND_TYPES = [ 17 | "c", # given category, predict position/size (C->P+S) 18 | "cwh", # given category/size, predict position (C+S->P) 19 | "partial", # given a partial layout (i.e., only a few elements), predict to generate a complete layout 20 | "gt", # just copy 21 | "random", # random masking 22 | "refinement", # given category and noisy position/size, predict accurate position/size 23 | "relation", # given category and some relationships between elements, try to fulfill the relationships as much as possible 24 | ] 25 | 26 | 27 | def get_cond( 28 | batch, # torch_geometric.data.batch.DataBatch 29 | tokenizer: LayoutSequenceTokenizer, 30 | cond_type: str = "c", 31 | model_type: str = "", 32 | get_real_images: bool = False, 33 | ) -> Dict[str, Any]: 34 | assert cond_type in COND_TYPES 35 | 36 | if get_real_images: 37 | assert cond_type in ["cwh", "gt"] 38 | 39 | special_keys = tokenizer.special_tokens 40 | pad_id = tokenizer.name_to_id("pad") 41 | mask_id = tokenizer.name_to_id("mask") if "mask" in special_keys else -1 42 | 43 | if cond_type == "relation": 44 | # extract non-canvas variables 45 | flag = batch.attr["has_canvas_element"] 46 | if isinstance(flag, bool): 47 | assert flag 48 | 
elif isinstance(flag, Tensor): 49 | assert flag.all() 50 | remove_canvas = True 51 | else: 52 | remove_canvas = False 53 | 54 | # load real layouts 55 | bbox, label, _, mask = sparse_to_dense(batch, remove_canvas=remove_canvas) 56 | cond = tokenizer.encode({"label": label, "mask": mask, "bbox": bbox}) 57 | B = bbox.shape[0] 58 | S = cond["seq"].shape[1] 59 | C = tokenizer.N_var_per_element 60 | 61 | # modify some elements to simulate various conditional generation settings 62 | if cond_type == "partial": 63 | start = 1 if "bos" in special_keys else 0 64 | n_elem = (S - start) // C 65 | scores = torch.rand(B, n_elem) 66 | mask = cond["mask"][:, start::C] 67 | 68 | n_valid_elem = reduce(mask, "b s -> b", reduction="sum") 69 | topk = [] 70 | for k in n_valid_elem: 71 | vmax = int((k - 1) * MAX_PARTIAL_RATIO) 72 | val = random.randint(1, vmax) if vmax > 1 else 1 73 | topk.append(val) 74 | topk = torch.LongTensor(topk) 75 | keep, _ = batch_topk_mask(scores, topk, mask=mask) 76 | 77 | keep = repeat(keep, "b s -> b (s c)", c=C) 78 | if "bos" in special_keys: 79 | # for order-sensitive methods, shift valid condition at the beginning of seq. 80 | keep = torch.cat([torch.full((B, 1), fill_value=True), keep], dim=-1) 81 | new_seq = torch.full_like(cond["seq"], mask_id) 82 | new_mask = torch.full_like(cond["mask"], False) 83 | for i in range(B): 84 | s = cond["seq"][i] 85 | ind_end = keep[i].sum().item() 86 | new_seq[i][:ind_end] = s[keep[i]] 87 | new_mask[i][:ind_end] = True 88 | cond["seq"] = new_seq 89 | cond["mask"] = new_mask 90 | else: 91 | cond["seq"][~keep] = mask_id 92 | cond["mask"] = keep 93 | 94 | elif cond_type in ["c", "cwh", "relation"]: 95 | vars = {"c": "c", "cwh": "cwh", "relation": "c"} 96 | keep = torch.full((B, S), False) 97 | if "bos" in special_keys: 98 | attr_ind = (torch.arange(S).view(1, S) - 1) % C 99 | attr_ind[:, 0] = -1 # dummy id for BOS 100 | keep[:, 0] = True 101 | else: 102 | attr_ind = torch.arange(S).view(1, S) % C 103 | for s in vars[cond_type]: 104 | ind = tokenizer.var_names.index(s) 105 | keep |= attr_ind == ind 106 | cond["seq"][~keep] = mask_id 107 | 108 | # specifying number of elements since it is known in these settings 109 | cond["seq"][~cond["mask"]] = pad_id 110 | cond["mask"] = (cond["mask"] & keep) | ~cond["mask"] 111 | 112 | # load edge attributes for imposing relational constraints 113 | if cond_type == "relation": 114 | cond["batch_w_canvas"] = batch 115 | 116 | elif cond_type == "gt": 117 | pass 118 | 119 | elif cond_type == "random": 120 | ratio = torch.rand((B,)) 121 | loss_mask = sample_mask(torch.full(cond["mask"].size(), True), ratio) 122 | # pass 123 | cond["seq"][loss_mask] = mask_id 124 | cond["mask"] = ~loss_mask 125 | 126 | elif cond_type == "refinement": 127 | new_bbox = bbox + torch.normal(0, std=0.1, size=bbox.size()) 128 | new_cond = tokenizer.encode({"label": label, "mask": mask, "bbox": new_bbox}) 129 | index = repeat(torch.arange(S), "s -> b s", b=B) 130 | cond = {} 131 | if "bos" in special_keys: 132 | cond["mask"] = new_cond["mask"] & ((index - 1) % C == 0) | ~new_cond["mask"] 133 | else: 134 | cond["mask"] = new_cond["mask"] & (index % C == 0) | ~new_cond["mask"] 135 | if model_type in ["LayoutDM", "ElemWiseAutoreg"]: 136 | cond["seq"] = torch.where(cond["mask"], new_cond["seq"], mask_id) 137 | cond["seq"] = torch.where(new_cond["mask"], cond["seq"], pad_id) 138 | cond["seq_orig"] = new_cond["seq"] 139 | else: 140 | cond["seq"] = new_cond["seq"] 141 | else: 142 | raise NotImplementedError 143 | 144 | if get_real_images: 145 | 
pass 146 | 147 | cond["type"] = cond_type 148 | if cond_type in ["c", "cwh", "refinement", "relation"]: 149 | cond["num_element"] = mask.sum(dim=1) 150 | 151 | return cond 152 | 153 | 154 | def _index_to_smoothed_log_onehot( 155 | seq: LongTensor, 156 | tokenizer: LayoutSequenceTokenizer, 157 | mode: str = "uniform", 158 | offset_ratio: float = 0.2, 159 | ): 160 | # for ease of hp-tuning, the range is limited to [0.0, 1.0] 161 | assert tokenizer.N_var_per_element == 5 162 | assert mode in ["uniform", "gaussian", "negative"] 163 | 164 | bbt = tokenizer.bbox_tokenizer 165 | V = len(bbt.var_names) 166 | N = tokenizer.N_bbox_per_var 167 | 168 | if tokenizer.bbox_tokenizer.shared_bbox_vocab == "xywh": 169 | slices = [ 170 | slice(tokenizer.N_category, tokenizer.N_category + N) for i in range(V) 171 | ] 172 | else: 173 | slices = [ 174 | slice(tokenizer.N_category + i * N, tokenizer.N_category + (i + 1) * N) 175 | for i in range(V) 176 | ] 177 | 178 | logits = torch.zeros( 179 | (tokenizer.N_total, tokenizer.N_total), 180 | ) 181 | logits.fill_diagonal_(1.0) 182 | 183 | for i, key in enumerate(bbt.var_names): 184 | name = f"{key}-{N}" 185 | cluster_model = tokenizer.bbox_tokenizer.clustering_models[name] 186 | cluster_centers = torch.from_numpy(cluster_model.cluster_centers_).view(-1) 187 | ii, jj = torch.meshgrid(cluster_centers, cluster_centers, indexing="ij") 188 | if mode == "uniform": 189 | logits[slices[i], slices[i]] = (torch.abs(ii - jj) < offset_ratio).float() 190 | elif mode == "negative": 191 | logits[slices[i], slices[i]] = (torch.abs(ii - jj) >= offset_ratio).float() 192 | elif mode == "gaussian": 193 | # p(x) = a * exp( -(x-b)^2 / (2 * c^2)) 194 | # -> log p(x) = log(a) - (x-b)^2 / (2 * c^2) 195 | # thus, a strength of adjustment is proportional to -(ii - jj)^2 196 | logits[slices[i], slices[i]] = -1.0 * (ii - jj) ** 2 197 | else: 198 | raise NotImplementedError 199 | 200 | logits = rearrange(F.embedding(seq, logits), "b s c -> b c s") 201 | return logits 202 | 203 | 204 | def set_additional_conditions_for_refinement( 205 | cond: Dict[str, Any], 206 | tokenizer: LayoutSequenceTokenizer, 207 | sampling_cfg: DictConfig, 208 | ) -> Dict[str, Any]: 209 | """ 210 | Set hand-crafted prior for the position/size of each element (Eq. 
8) 211 | """ 212 | w = sampling_cfg.refine_lambda 213 | if sampling_cfg.refine_mode == "negative": 214 | w *= -1.0 215 | 216 | cond["weak_mask"] = repeat(~cond["mask"], "b s -> b c s", c=tokenizer.N_total) 217 | cond["weak_logits"] = _index_to_smoothed_log_onehot( 218 | cond["seq_orig"], 219 | tokenizer, 220 | mode=sampling_cfg.refine_mode, 221 | offset_ratio=sampling_cfg.refine_offset_ratio, 222 | ) 223 | cond["weak_logits"] *= w 224 | return cond 225 | 226 | 227 | def filter_canvas(layout: Dict): 228 | new_layout = {} 229 | new_layout["bbox"] = layout["bbox"][:, 1:] 230 | new_layout["label"] = layout["label"][:, 1:] - 1 231 | new_layout["mask"] = layout["mask"][:, 1:] 232 | return new_layout 233 | 234 | 235 | def duplicate_cond(cond: Dict, batch_size: int) -> Dict: 236 | # this is used in demo to see the variety 237 | # if there's single example but batch_size > 1, copy conditions 238 | flag = cond["seq"].size(0) == 1 239 | flag &= batch_size > 1 240 | if flag: 241 | for k in cond: 242 | if isinstance(cond[k], Tensor): 243 | sizes = [ 244 | batch_size, 245 | ] 246 | sizes += [1 for _ in range(cond[k].dim() - 1)] 247 | cond[k] = cond[k].repeat(sizes) 248 | return cond 249 | -------------------------------------------------------------------------------- /src/trainer/trainer/helpers/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from torch import BoolTensor, FloatTensor, LongTensor 8 | 9 | 10 | def set_seed(seed: int): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | 15 | 16 | def convert_xywh_to_ltrb(bbox: Union[np.ndarray, FloatTensor]): 17 | xc, yc, w, h = bbox 18 | x1 = xc - w / 2 19 | y1 = yc - h / 2 20 | x2 = xc + w / 2 21 | y2 = yc + h / 2 22 | return [x1, y1, x2, y2] 23 | 24 | 25 | def batch_topk_mask( 26 | scores: FloatTensor, 27 | topk: LongTensor, 28 | mask: Optional[BoolTensor] = None, 29 | ) -> Tuple[BoolTensor, FloatTensor]: 30 | assert scores.ndim == 2 and topk.ndim == 1 and scores.size(0) == topk.size(0) 31 | if mask is not None: 32 | assert mask.size() == scores.size() 33 | assert (scores.size(1) >= topk).all() 34 | 35 | # ignore scores where mask = False by setting extreme values 36 | if mask is not None: 37 | const = -1.0 * float("Inf") 38 | const = torch.full_like(scores, fill_value=const) 39 | scores = torch.where(mask, scores, const) 40 | 41 | sorted_values, _ = torch.sort(scores, dim=-1, descending=True) 42 | topk = rearrange(topk, "b -> b 1") 43 | 44 | k_th_scores = torch.gather(sorted_values, dim=1, index=topk) 45 | 46 | topk_mask = scores > k_th_scores 47 | return topk_mask, k_th_scores 48 | 49 | 50 | def batch_shuffle_index( 51 | batch_size: int, 52 | feature_length: int, 53 | mask: Optional[BoolTensor] = None, 54 | ) -> LongTensor: 55 | """ 56 | Note: masked part may be shuffled because of unpredictable behaviour of sorting [inf, ..., inf] 57 | """ 58 | if mask: 59 | assert mask.size() == [batch_size, feature_length] 60 | scores = torch.rand((batch_size, feature_length)) 61 | if mask: 62 | scores[~mask] = float("Inf") 63 | _, indices = torch.sort(scores, dim=1) 64 | return indices 65 | 66 | 67 | if __name__ == "__main__": 68 | scores = torch.arange(6).view(2, 3).float() 69 | # topk = torch.arange(2) + 1 70 | topk = torch.full((2,), 3) 71 | mask = torch.full((2, 3), False) 72 | # mask[1, 2] = False 73 | print(batch_topk_mask(scores, topk, 
mask=mask)) 74 | -------------------------------------------------------------------------------- /src/trainer/trainer/hydra_configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | A file to declare dataclass instances used for hydra configs at ./config/* 3 | """ 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | from trainer.helpers.layout_tokenizer import CHOICES 9 | 10 | 11 | @dataclass 12 | class TestConfig: 13 | job_dir: str 14 | result_dir: str 15 | dataset_dir: Optional[str] = None # set if it is different for train/test 16 | max_batch_size: int = 512 17 | num_run: int = 1 # number of outputs per input 18 | cond: str = "unconditional" 19 | num_timesteps: int = 100 20 | is_validation: bool = False # for evaluation in validation set (e.g. HP search) 21 | debug: bool = False # disable some features to enable fast runtime 22 | debug_num_samples: int = -1 # in debug mode, reduce the number of samples when > 0 23 | 24 | # for sampling 25 | sampling: str = "random" # see ./helpers/sampling.py for options 26 | # below are additional parameters for sampling modes 27 | temperature: float = 1.0 28 | top_p: float = 0.9 29 | top_k: float = 5 30 | 31 | # for unconditional models 32 | num_uncond_samples: int = 1000 33 | 34 | # for diffusion models 35 | # assymetric time-difference (https://arxiv.org/abs/2208.04202) 36 | time_difference: float = 0.0 37 | 38 | # for diffusion models, refinement only 39 | refine_lambda: float = 3.0 # if > 0.0, trigger refinement mode 40 | refine_mode: str = "uniform" 41 | refine_offset_ratio: float = 0.1 # 0.2 42 | 43 | # for diffusion models, relation only 44 | relation_lambda: float = 3e6 # if > 0.0, trigger relation mode 45 | relation_mode: str = "average" 46 | relation_tau: float = 1.0 47 | relation_num_update: int = 3 48 | 49 | # for continuous diffusion models 50 | use_ddim: bool = False 51 | 52 | 53 | @dataclass 54 | class TrainConfig: 55 | epochs: int = 50 56 | grad_norm_clip: float = 1.0 57 | weight_decay: float = 1e-1 58 | loss_plot_iter_interval: int = 50 59 | sample_plot_epoch_interval: int = 1 60 | fid_plot_num_samples: int = 1000 61 | fid_plot_batch_size: int = 512 62 | 63 | 64 | @dataclass 65 | class DataConfig: 66 | batch_size: int = 64 67 | bbox_quantization: str = "linear" 68 | num_bin_bboxes: int = 32 69 | num_workers: int = os.cpu_count() 70 | pad_until_max: bool = ( 71 | False # True for diffusion models, False for others for efficient batching 72 | ) 73 | shared_bbox_vocab: str = "xywh" 74 | special_tokens: Tuple[str] = ("pad", "mask") 75 | # special_tokens: Tuple[str] = ("pad",) 76 | # transforms: Tuple[str] = ("SortByLabel", "LexicographicOrder") 77 | transforms: Tuple[str] = ("RandomOrder",) 78 | var_order: str = "c-x-y-w-h" 79 | 80 | def __post_init__(self) -> None: 81 | # advanced validation like choices in argparse 82 | for key in ["shared_bbox_vocab", "bbox_quantization", "var_order"]: 83 | assert getattr(self, key) in CHOICES[key] 84 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/models/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/models/base_model.py: 
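# base_model.py (below) defines the shared parent class for the models in this
# repo (e.g., BLT and ElemWiseAutoreg further down): weight initialization plus
# AdamW-style parameter grouping. optim_groups() returns two parameter groups --
# Linear / MultiheadAttention weights with weight decay, and biases plus
# LayerNorm / Embedding weights without it. A minimal usage sketch (the actual
# wiring presumably lives in main.py / the optimizer configs, which are outside
# this excerpt; the model name and learning rate are placeholders):
#
#     model = SomeModel(...)  # any BaseModel subclass
#     optimizer = torch.optim.AdamW(model.optim_groups(weight_decay=0.1), lr=5e-4)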
-------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Iterable, List, Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import DictConfig 7 | from torch import Tensor 8 | from trainer.helpers.layout_tokenizer import LayoutTokenizer 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class BaseModel(torch.nn.Module): 14 | def __init__(self) -> None: 15 | super().__init__() 16 | 17 | def forward(self): 18 | raise NotImplementedError 19 | 20 | # def sample(self, z: Optional[Tensor]): 21 | def sample(self): 22 | """ 23 | Generate sample based on z. 24 | z can be either given or sampled from some distribution (e.g., normal) 25 | """ 26 | raise NotImplementedError 27 | 28 | def preprocess(self): 29 | raise NotImplementedError 30 | 31 | def postprocess(self): 32 | raise NotImplementedError 33 | 34 | @property 35 | def device(self) -> torch.device: 36 | if hasattr(self, "model"): 37 | return next(self.model.parameters()).device 38 | else: 39 | raise NotImplementedError 40 | 41 | @property 42 | def tokenizer(self): 43 | return self._tokenizer 44 | 45 | @tokenizer.setter 46 | def tokenizer(self, value: LayoutTokenizer): 47 | self._tokenizer = value 48 | 49 | def compute_stats(self): 50 | logger.info( 51 | "number of parameters: %e", sum(p.numel() for p in self.parameters()) / 1e6 52 | ) 53 | 54 | def optim_groups( 55 | self, weight_decay: float = 0.0, additional_no_decay: Optional[List[str]] = None 56 | ) -> Union[Iterable[Tensor], Dict[str, Tensor]]: 57 | # see https://github.com/kampta/DeepLayout/blob/main/layout_transformer/model.py#L139 58 | decay = set() 59 | no_decay = set() 60 | whitelist_weight_modules = ( 61 | torch.nn.Linear, 62 | torch.nn.modules.activation.MultiheadAttention, 63 | ) 64 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 65 | for mn, m in self.named_modules(): 66 | for pn, p in m.named_parameters(): 67 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 68 | if pn.endswith("bias"): 69 | # all biases will not be decayed 70 | no_decay.add(fpn) 71 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 72 | # weights of whitelist modules will be weight decayed 73 | decay.add(fpn) 74 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 75 | # weights of blacklist modules will NOT be weight decayed 76 | no_decay.add(fpn) 77 | 78 | if additional_no_decay: 79 | for k in additional_no_decay: 80 | no_decay.add(k) 81 | 82 | # validate that we considered every parameter 83 | param_dict = {pn: p for pn, p in self.named_parameters()} 84 | inter_params = decay & no_decay 85 | union_params = decay | no_decay 86 | assert ( 87 | len(inter_params) == 0 88 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 89 | assert ( 90 | len(param_dict.keys() - union_params) == 0 91 | ), "parameters %s were not separated into either decay/no_decay set!" 
% ( 92 | str(param_dict.keys() - union_params), 93 | ) 94 | 95 | # create the pytorch optimizer object 96 | optim_groups = [ 97 | { 98 | "params": [param_dict[pn] for pn in sorted(list(decay))], 99 | "weight_decay": weight_decay, 100 | }, 101 | { 102 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 103 | "weight_decay": 0.0, 104 | }, 105 | ] 106 | return optim_groups 107 | 108 | def _init_weights(self, module): 109 | if isinstance(module, (nn.Linear, nn.Embedding)): 110 | module.weight.data.normal_(mean=0.0, std=0.02) 111 | if isinstance(module, nn.Linear) and module.bias is not None: 112 | module.bias.data.zero_() 113 | elif isinstance(module, nn.LayerNorm): 114 | if module.elementwise_affine == True: 115 | module.bias.data.zero_() 116 | module.weight.data.fill_(1.0) 117 | 118 | def update_per_epoch(self, epoch: int, max_epoch: int): 119 | """ 120 | Update some non-trainable parameters during training (e.g., warmup) 121 | """ 122 | pass 123 | 124 | def aggregate_sampling_settings( 125 | self, sampling_cfg: DictConfig, args: DictConfig 126 | ) -> DictConfig: 127 | """ 128 | Set user-specified args for sampling cfg 129 | """ 130 | # Aggregate refinement-related parameters 131 | is_ruite = type(self).__name__ == "RUITE" 132 | if args.cond == "refinement" and args.refine_lambda > 0.0 and not is_ruite: 133 | sampling_cfg.refine_mode = args.refine_mode 134 | sampling_cfg.refine_offset_ratio = args.refine_offset_ratio 135 | sampling_cfg.refine_lambda = args.refine_lambda 136 | 137 | if args.cond == "relation" and args.relation_lambda > 0.0: 138 | sampling_cfg.relation_mode = args.relation_mode 139 | sampling_cfg.relation_lambda = args.relation_lambda 140 | sampling_cfg.relation_tau = args.relation_tau 141 | sampling_cfg.relation_num_update = args.relation_num_update 142 | 143 | if "num_timesteps" not in sampling_cfg: 144 | # for dec or enc-dec 145 | if "eos" in self.tokenizer.special_tokens: 146 | sampling_cfg.num_timesteps = self.tokenizer.max_token_length 147 | else: 148 | sampling_cfg.num_timesteps = args.num_timesteps 149 | 150 | return sampling_cfg 151 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/blt.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import random 4 | from typing import Dict, Iterable, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange, reduce, repeat 9 | from hydra.utils import instantiate 10 | from omegaconf import DictConfig 11 | from torch import Tensor 12 | from trainer.data.util import sparse_to_dense 13 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 14 | from trainer.helpers.sampling import sample 15 | from trainer.helpers.task import duplicate_cond 16 | from trainer.helpers.util import batch_topk_mask 17 | from trainer.models.base_model import BaseModel 18 | from trainer.models.common.nn_lib import ( 19 | CategoricalTransformer, 20 | CustomDataParallel, 21 | SeqLengthDistribution, 22 | ) 23 | from trainer.models.common.util import get_dim_model 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | TARGET_ATTRS = [["c"], ["w", "h"], ["x", "y"]] # (category, size, position) 28 | 29 | 30 | def sample_mask(mask: Tensor, n_attr: int = 1): 31 | num_true = mask.sum().item() * n_attr 32 | n = random.randint(1, num_true) 33 | x = [True for _ in range(n)] + [False for _ in range(num_true - n)] 34 | random.shuffle(x) 35 | x += [False for _ in 
range(mask.size(0) * n_attr - len(x))] 36 | return rearrange(Tensor(x).to(mask), "(s c) -> s c", c=n_attr) 37 | 38 | 39 | class BLT(BaseModel): 40 | """ 41 | To reproduce 42 | BLT: Bidirectional Layout Transformer for Controllable Layout Generation (ECCV2022) 43 | https://arxiv.org/abs/2112.05112 44 | """ 45 | 46 | def __init__( 47 | self, 48 | backbone_cfg: DictConfig, 49 | tokenizer: LayoutSequenceTokenizer, 50 | use_padding_as_vocab: bool = False, 51 | ) -> None: 52 | super().__init__() 53 | # check conditions 54 | if use_padding_as_vocab: 55 | assert tokenizer.pad_until_max 56 | assert tokenizer.var_order == "c-x-y-w-h" 57 | 58 | self.tokenizer = tokenizer 59 | self.use_padding_as_vocab = use_padding_as_vocab 60 | 61 | # Note: make sure learnable parameters are inside self.model 62 | backbone = instantiate(backbone_cfg) 63 | self.model = CustomDataParallel( 64 | CategoricalTransformer( 65 | backbone=backbone, 66 | dim_model=get_dim_model(backbone_cfg), 67 | num_classes=tokenizer.N_total, 68 | max_token_length=tokenizer.max_token_length, 69 | ) 70 | ) 71 | self.apply(self._init_weights) 72 | self.compute_stats() 73 | self.seq_dist = SeqLengthDistribution(tokenizer.max_seq_length) 74 | self.loss_fn_ce = nn.CrossEntropyLoss() 75 | 76 | def forward(self, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor]]: 77 | loss_mask = inputs["loss_mask"] 78 | 79 | if self.use_padding_as_vocab: 80 | outputs = self.model(inputs["input"]) 81 | else: 82 | outputs = self.model( 83 | inputs["input"], src_key_padding_mask=inputs["padding_mask"] 84 | ) 85 | nll_loss = self.loss_fn_ce( 86 | outputs["logits"][loss_mask], 87 | inputs["target"][loss_mask], 88 | ) 89 | losses = {"nll_loss": nll_loss} 90 | 91 | # replace masked tokens with predicted tokens 92 | outputs["outputs"] = copy.deepcopy(inputs["input"]) 93 | ids = torch.argmax(outputs["logits"], dim=-1) 94 | outputs["outputs"][loss_mask] = ids[loss_mask] 95 | 96 | return outputs, losses 97 | 98 | def sample( 99 | self, 100 | batch_size: Optional[int], 101 | cond: Optional[Tensor] = None, 102 | sampling_cfg: Optional[DictConfig] = None, 103 | device: Optional[torch.device] = None, 104 | **kwargs, 105 | ) -> Dict[str, Tensor]: 106 | """ 107 | Generate sample based on z. 
108 | z can be either given or sampled from some distribution (e.g., normal) 109 | """ 110 | 111 | mask_id = self.tokenizer.name_to_id("mask") 112 | pad_id = self.tokenizer.name_to_id("pad") 113 | 114 | B, S = (batch_size, self.tokenizer.max_token_length) 115 | 116 | if "num_timesteps" in sampling_cfg: 117 | self.num_timesteps = sampling_cfg.num_timesteps 118 | else: 119 | self.num_timesteps = 9 120 | assert self.num_timesteps % 3 == 0 121 | 122 | if cond: 123 | cond = duplicate_cond(cond, batch_size) 124 | seq = cond["seq"].clone() 125 | # **_user will not be updated (kept as reference) 126 | seq_user = cond["seq"].clone() 127 | mask_user = cond["mask"].clone() 128 | if not self.use_padding_as_vocab: 129 | src_key_padding_mask_user = seq == pad_id 130 | else: 131 | n_elements = self.seq_dist.sample(B) * self.tokenizer.N_var_per_element 132 | indices = rearrange(torch.arange(S), "s -> 1 s") 133 | mask = indices < rearrange(n_elements, "b -> b 1") 134 | seq = torch.full((B, S), fill_value=pad_id) 135 | seq[mask] = mask_id 136 | seq_user = seq.clone() 137 | mask_user = ~mask.clone() 138 | src_key_padding_mask_user = ~mask.clone() 139 | 140 | T = self.num_timesteps // 3 141 | n_attr = self.tokenizer.N_var_per_element 142 | indices = [ 143 | [self.tokenizer.var_names.index(a) for a in attrs] for attrs in TARGET_ATTRS 144 | ] 145 | for target_attr_indices in indices: 146 | # ignore already filled region or non-target attributes 147 | attr_indices = repeat(torch.arange(S), "s -> b s", b=B) % n_attr 148 | keep_attr = torch.full((B, S), fill_value=True) 149 | for ind in target_attr_indices: 150 | keep_attr[attr_indices == ind] = False 151 | 152 | for t in range(T): 153 | ratio = (T - (t + 1)) / T 154 | if self.use_padding_as_vocab: 155 | logits = self.model(seq.to(device))["logits"].cpu() 156 | else: 157 | logits = self.model( 158 | seq.to(device), src_key_padding_mask=src_key_padding_mask_user 159 | )["logits"].cpu() 160 | 161 | invalid = repeat(~self.tokenizer.token_mask, "s c -> b s c", b=B) 162 | logits[invalid] = -float("Inf") 163 | 164 | seq_pred = sample(rearrange(logits, "b s c -> b c s"), sampling_cfg) 165 | confidence = torch.gather( 166 | logits, -1, rearrange(seq_pred, "b 1 s -> b s 1") 167 | ) 168 | confidence = rearrange(confidence, "b s 1 -> b s") 169 | seq_pred = rearrange(seq_pred, "b 1 s -> b s") 170 | 171 | # update by predicted tokens 172 | mask = (seq == mask_id) & (~keep_attr) 173 | seq = torch.where(mask, seq_pred, seq) 174 | 175 | if t < T - 1: 176 | # re-fill [MASK] for unconfident predictions 177 | n_elem = reduce( 178 | ~(mask_user | keep_attr), "b s -> b", reduction="sum" 179 | ) 180 | topk = (n_elem * ratio).long() 181 | is_unconfident, _ = batch_topk_mask( 182 | -1.0 * confidence, topk, mask=mask 183 | ) 184 | seq[is_unconfident] = mask_id 185 | 186 | # make sure to use user-defined inputs 187 | seq = torch.where(mask_user, seq_user, seq) 188 | 189 | layouts = self.tokenizer.decode(seq) 190 | return layouts 191 | 192 | def preprocess(self, batch): 193 | bbox, label, _, mask = sparse_to_dense(batch) 194 | self.seq_dist(mask) 195 | 196 | inputs = self.tokenizer.encode({"label": label, "mask": mask, "bbox": bbox}) 197 | B = inputs["mask"].size(0) 198 | C = self.tokenizer.N_var_per_element 199 | S = inputs["mask"].size(1) // C 200 | mask_id = self.tokenizer._special_token_name_to_id["mask"] 201 | 202 | sampled_indices = torch.randint(0, len(TARGET_ATTRS), size=(B,)) 203 | loss_mask = torch.full((B, S, C), False) 204 | for i, ind in enumerate(sampled_indices): 205 | if 
self.use_padding_as_vocab: 206 | # no constraint on mask location 207 | tmp_mask = torch.full((S,), True) 208 | else: 209 | tmp_mask = inputs["mask"][i, 0::C] 210 | if ind == 0: # C(ategory) 211 | loss_mask[i, :, 0:1] = sample_mask(tmp_mask, n_attr=1) 212 | elif ind == 1: # P(osition) 213 | loss_mask[i, :, 1:3] = sample_mask(tmp_mask, n_attr=2) 214 | elif ind == 2: # S(ize) 215 | loss_mask[i, :, 3:] = sample_mask(tmp_mask, n_attr=2) 216 | loss_mask = rearrange(loss_mask, "b s c -> b (s c)") 217 | 218 | masked_seq = copy.deepcopy(inputs["seq"]) 219 | masked_seq[loss_mask] = mask_id 220 | 221 | return { 222 | "target": inputs["seq"], 223 | "padding_mask": ~inputs["mask"], 224 | "loss_mask": loss_mask, 225 | "input": masked_seq, 226 | } 227 | 228 | def optim_groups( 229 | self, weight_decay: float = 0.0 230 | ) -> Union[Iterable[Tensor], Dict[str, Tensor]]: 231 | return super().optim_groups( 232 | weight_decay=weight_decay, 233 | additional_no_decay=[ 234 | "model.module.pos_emb.pos_emb", 235 | ], 236 | ) 237 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/categorical_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/models/categorical_diffusion/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/models/categorical_diffusion/logit_adjustment.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, reduce, repeat 7 | from omegaconf import DictConfig 8 | from torch import FloatTensor, Tensor 9 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 10 | from trainer.models.categorical_diffusion.util import index_to_log_onehot 11 | from trainer.models.clg.const import relation as relational_constraints 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def _stochastic_convert( 17 | cond: Dict, 18 | model_log_prob: Tensor, 19 | tokenizer: LayoutSequenceTokenizer, 20 | tau: float = 1.0, 21 | mode: str = "average", 22 | ) -> Tensor: 23 | """ 24 | Convert model_log_prob (B, C, S) to average bbox location (E, X) 25 | , where E is number of valid layout components and X is number of fields in each component. 26 | Use mode='average' by default because 'gumbel' did not work at all. 
27 | """ 28 | assert mode in ["gumbel", "average"] 29 | B = model_log_prob.size(0) 30 | N = tokenizer.N_bbox_per_var 31 | device = model_log_prob.device 32 | step = len(tokenizer.var_names) 33 | bt = tokenizer.bbox_tokenizer 34 | 35 | # get bbox logits for canvas (B, C, X) 36 | canvas_ids = bt.encode(FloatTensor([[[0.5, 0.5, 1.0, 1.0]]])).long() 37 | canvas_ids += tokenizer.N_category 38 | canvas_logits = index_to_log_onehot( 39 | repeat(canvas_ids, "1 1 x -> b x", b=B), tokenizer.N_total 40 | ).to(model_log_prob) 41 | 42 | # get element-wise mask (B, S+1) 43 | mask = cond["seq"][..., ::step] != tokenizer.name_to_id("pad") 44 | mask = torch.cat([torch.full((B, 1), fill_value=True).to(mask), mask], dim=1).to( 45 | device 46 | ) 47 | 48 | if bt.shared_bbox_vocab == "xywh": 49 | slices = [ 50 | slice(tokenizer.N_category, tokenizer.N_category + N) 51 | for _ in range(step - 1) 52 | ] 53 | else: 54 | slices = [ 55 | slice(tokenizer.N_category + i * N, tokenizer.N_category + (i + 1) * N) 56 | for i in range(step - 1) 57 | ] 58 | 59 | bbox_logits = [] 60 | for i in range(step - 1): 61 | bbox_logit = torch.cat( 62 | [ 63 | canvas_logits[:, slices[i], i : i + 1], 64 | model_log_prob[:, slices[i], (i + 1) :: step], 65 | ], 66 | dim=2, 67 | ) 68 | # why requires_grad diminishes in maskgit? 69 | bbox_logits.append(bbox_logit) 70 | 71 | bbox_logits = rearrange(torch.stack(bbox_logits, dim=-1), "b n s x -> b s n x") 72 | bbox_logits = bbox_logits[mask] 73 | 74 | if mode == "gumbel": 75 | bbox_prob = F.gumbel_softmax(bbox_logits, tau=tau, hard=True, dim=1) 76 | elif mode == "average": 77 | bbox_prob = F.softmax(bbox_logits, dim=1) 78 | 79 | centers = [] 80 | for name in bt.var_names: 81 | centers.append(bt.clustering_models[f"{name}-{N}"].cluster_centers_) 82 | centers = torch.cat([torch.from_numpy(arr) for arr in centers], dim=1) 83 | centers = rearrange(centers, "n x -> 1 n x") 84 | bbox = reduce(bbox_prob * centers.to(bbox_prob), "e n x -> e x", reduction="sum") 85 | return bbox 86 | 87 | 88 | def update( 89 | t: int, 90 | cond: Dict, 91 | model_log_prob: FloatTensor, # (B, C, S) 92 | tokenizer: LayoutSequenceTokenizer, 93 | sampling_cfg: Optional[DictConfig] = None, 94 | ): 95 | """ 96 | Update model_log_prob multiple times following Eq. 7. 97 | model_log_prob corresponds to p_{\theta}(\bm{z}_{t-1}|\bm{z}_{t}). 98 | """ 99 | # detach var. in order not to backpropagate thrhough diffusion model p_{\theta}. 100 | optim_target_log_prob = torch.nn.Parameter(model_log_prob.detach()) 101 | 102 | # we found that adaptive optimizer does not work. 
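    # That is, the detached log-probabilities are refined by a few plain
    # gradient-descent steps on the relational-constraint losses from
    # models/clg/const.py: sampling_cfg.relation_lambda acts as the step size,
    # sampling_cfg.relation_num_update sets the number of steps, and the
    # update is skipped entirely when t < 10.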
103 | optimizer = torch.optim.SGD( 104 | [optim_target_log_prob], lr=sampling_cfg.relation_lambda 105 | ) 106 | batch = cond["batch_w_canvas"].to(model_log_prob.device) 107 | T = 0 if t < 10 else sampling_cfg.relation_num_update 108 | for _ in range(T): 109 | optimizer.zero_grad() 110 | bbox_flatten = _stochastic_convert( 111 | cond=cond, 112 | model_log_prob=optim_target_log_prob, 113 | tokenizer=tokenizer, 114 | tau=sampling_cfg.relation_tau, 115 | mode=sampling_cfg.relation_mode, 116 | ) 117 | if len(batch.edge_index) == 0: 118 | # sometimes there are no edge in batch_size = 1 119 | continue 120 | loss = [f(bbox_flatten, batch) for f in relational_constraints] 121 | loss = torch.stack(loss, dim=-1) 122 | loss = loss.mean() 123 | loss.backward() 124 | optimizer.step() 125 | 126 | return optim_target_log_prob.detach() 127 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/categorical_diffusion/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | EPS = 1e-30 8 | LOG_EPS = math.log(1e-30) 9 | 10 | 11 | def mean_except_batch(x, num_dims=1): 12 | return x.reshape(*x.shape[:num_dims], -1).mean(-1) 13 | 14 | 15 | def log_1_min_a(a): 16 | return torch.log(1 - a.exp() + 1e-40) 17 | 18 | 19 | def log_add_exp(a, b): 20 | maximum = torch.max(a, b) 21 | return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) 22 | 23 | 24 | def extract(a, t, x_shape): 25 | b, *_ = t.shape 26 | out = a.gather(-1, t) 27 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 28 | 29 | 30 | def log_categorical(log_x_start, log_prob): 31 | return (log_x_start.exp() * log_prob).sum(dim=1) 32 | 33 | 34 | def index_to_log_onehot(x, num_classes): 35 | assert x.max().item() < num_classes, f"Error: {x.max().item()} >= {num_classes}" 36 | x_onehot = F.one_hot(x, num_classes) 37 | permute_order = (0, -1) + tuple(range(1, len(x.size()))) 38 | x_onehot = x_onehot.permute(permute_order) 39 | log_x = torch.log(x_onehot.float().clamp(min=1e-30)) 40 | return log_x 41 | 42 | 43 | def log_onehot_to_index(log_x): 44 | return log_x.argmax(1) 45 | 46 | 47 | def alpha_schedule( 48 | num_timesteps, N=100, att_1=0.99999, att_T=0.000009, ctt_1=0.000009, ctt_T=0.99999 49 | ): 50 | # note: 0.0 will tends to raise unexpected behaviour (e.g., log(0.0)), thus avoid 0.0 51 | assert att_1 > 0.0 and att_T > 0.0 and ctt_1 > 0.0 and ctt_T > 0.0 52 | assert att_1 + ctt_1 <= 1.0 and att_T + ctt_T <= 1.0 53 | 54 | att = np.arange(0, num_timesteps) / (num_timesteps - 1) * (att_T - att_1) + att_1 55 | att = np.concatenate(([1], att)) 56 | at = att[1:] / att[:-1] 57 | ctt = np.arange(0, num_timesteps) / (num_timesteps - 1) * (ctt_T - ctt_1) + ctt_1 58 | ctt = np.concatenate(([0], ctt)) 59 | one_minus_ctt = 1 - ctt 60 | one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] 61 | ct = 1 - one_minus_ct 62 | bt = (1 - at - ct) / N 63 | att = np.concatenate((att[1:], [1])) 64 | ctt = np.concatenate((ctt[1:], [0])) 65 | btt = (1 - att - ctt) / N 66 | 67 | def _f(x): 68 | return torch.tensor(x.astype("float64")) 69 | 70 | return _f(at), _f(bt), _f(ct), _f(att), _f(btt), _f(ctt) 71 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/categorical_diffusion/vanilla.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from 
omegaconf import DictConfig 5 | 6 | from .base import BaseMaskAndReplaceDiffusion 7 | from .util import ( 8 | extract, 9 | index_to_log_onehot, 10 | log_1_min_a, 11 | log_add_exp, 12 | log_categorical, 13 | log_onehot_to_index, 14 | mean_except_batch, 15 | ) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class VanillaMaskAndReplaceDiffusion(BaseMaskAndReplaceDiffusion): 21 | """ 22 | Reference: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/4d4cbefe3ed917ec2953af5879aa7608a171b91f/labml_nn/diffusion/ddpm 23 | Notation is strictly following DDPM paper to avoid confusion 24 | """ 25 | 26 | def __init__( 27 | self, 28 | backbone_cfg: DictConfig, 29 | num_classes: int, 30 | max_token_length: int, 31 | num_timesteps: int = 100, 32 | **kwargs, 33 | ) -> None: 34 | super().__init__( 35 | backbone_cfg=backbone_cfg, 36 | num_classes=num_classes, 37 | max_token_length=max_token_length, 38 | num_timesteps=num_timesteps, 39 | **kwargs, 40 | ) 41 | 42 | if self.alpha_init_type == "alpha1": 43 | N = self.num_classes - 1 44 | at, bt, ct, att, btt, ctt = self.alpha_schedule_partial_func(N=N) 45 | else: 46 | print("alpha_init_type is Wrong !! ") 47 | 48 | log_at, log_bt, log_ct = torch.log(at), torch.log(bt), torch.log(ct) 49 | log_cumprod_at, log_cumprod_bt, log_cumprod_ct = ( 50 | torch.log(att), 51 | torch.log(btt), 52 | torch.log(ctt), 53 | ) 54 | 55 | log_1_min_ct = log_1_min_a(log_ct) 56 | log_1_min_cumprod_ct = log_1_min_a(log_cumprod_ct) 57 | 58 | assert log_add_exp(log_ct, log_1_min_ct).abs().sum().item() < 1.0e-5 59 | assert ( 60 | log_add_exp(log_cumprod_ct, log_1_min_cumprod_ct).abs().sum().item() 61 | < 1.0e-5 62 | ) 63 | 64 | # Convert to float32 and register buffers. 65 | self.register_buffer("log_at", log_at.float()) 66 | self.register_buffer("log_bt", log_bt.float()) 67 | self.register_buffer("log_ct", log_ct.float()) 68 | self.register_buffer("log_cumprod_at", log_cumprod_at.float()) 69 | self.register_buffer("log_cumprod_bt", log_cumprod_bt.float()) 70 | self.register_buffer("log_cumprod_ct", log_cumprod_ct.float()) 71 | self.register_buffer("log_1_min_ct", log_1_min_ct.float()) 72 | self.register_buffer("log_1_min_cumprod_ct", log_1_min_cumprod_ct.float()) 73 | 74 | def q_pred_one_timestep(self, log_x_t, t): # q(xt|xt_1) 75 | log_at = extract(self.log_at, t, log_x_t.shape) # at 76 | log_bt = extract(self.log_bt, t, log_x_t.shape) # bt 77 | log_ct = extract(self.log_ct, t, log_x_t.shape) # ct 78 | log_1_min_ct = extract(self.log_1_min_ct, t, log_x_t.shape) # 1-ct 79 | 80 | log_probs = torch.cat( 81 | [ 82 | log_add_exp(log_x_t[:, :-1, :] + log_at, log_bt), 83 | log_add_exp(log_x_t[:, -1:, :] + log_1_min_ct, log_ct), 84 | ], 85 | dim=1, 86 | ) 87 | 88 | return log_probs 89 | 90 | def q_pred(self, log_x_start, t): # q(xt|x0) 91 | # log_x_start can be onehot or not 92 | t = (t + (self.num_timesteps + 1)) % (self.num_timesteps + 1) 93 | log_cumprod_at = extract(self.log_cumprod_at, t, log_x_start.shape) # at~ 94 | log_cumprod_bt = extract(self.log_cumprod_bt, t, log_x_start.shape) # bt~ 95 | log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~ 96 | log_1_min_cumprod_ct = extract( 97 | self.log_1_min_cumprod_ct, t, log_x_start.shape 98 | ) # 1-ct~ 99 | 100 | log_probs = torch.cat( 101 | [ 102 | log_add_exp(log_x_start[:, :-1, :] + log_cumprod_at, log_cumprod_bt), 103 | log_add_exp( 104 | log_x_start[:, -1:, :] + log_1_min_cumprod_ct, log_cumprod_ct 105 | ), 106 | ], 107 | dim=1, 108 | ) 109 | 110 | return log_probs 111 | 112 | 
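    # The posterior below follows the usual discrete-diffusion factorization
    # (cf. the inline comment): the network's estimate of x_0 is combined with
    # the closed-form
    #   q(x_{t-1} | x_t, x_0) = q(x_t | x_{t-1}) * q(x_{t-1} | x_0) / q(x_t | x_0),
    # computed entirely in log-space via q_pred / q_pred_one_timestep above.
    # The last class index (num_classes - 1) is reserved for the [MASK] token.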
def q_posterior( 113 | self, log_x_start, log_x_t, t 114 | ): # p_theta(xt_1|xt) = sum(q(xt-1|xt,x0')*p(x0')) 115 | # notice that log_x_t is onehot 116 | assert t.min().item() >= 0 and t.max().item() < self.num_timesteps 117 | batch_size = log_x_start.size()[0] 118 | onehot_x_t = log_onehot_to_index(log_x_t) 119 | mask = (onehot_x_t == self.num_classes - 1).unsqueeze(1) 120 | log_one_vector = torch.zeros(batch_size, 1, 1).type_as(log_x_t) 121 | log_zero_vector = torch.log(log_one_vector + 1.0e-30).expand( 122 | -1, -1, self.max_token_length 123 | ) 124 | 125 | log_qt = self.q_pred(log_x_t, t) # q(xt|x0) 126 | # log_qt = torch.cat((log_qt[:,:-1,:], log_zero_vector), dim=1) 127 | log_qt = log_qt[:, :-1, :] 128 | log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~ 129 | ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes - 1, -1) 130 | # ct_cumprod_vector = torch.cat((ct_cumprod_vector, log_one_vector), dim=1) 131 | log_qt = (~mask) * log_qt + mask * ct_cumprod_vector 132 | 133 | log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t) # q(xt|xt_1) 134 | log_qt_one_timestep = torch.cat( 135 | (log_qt_one_timestep[:, :-1, :], log_zero_vector), dim=1 136 | ) 137 | log_ct = extract(self.log_ct, t, log_x_start.shape) # ct 138 | ct_vector = log_ct.expand(-1, self.num_classes - 1, -1) 139 | ct_vector = torch.cat((ct_vector, log_one_vector), dim=1) 140 | log_qt_one_timestep = (~mask) * log_qt_one_timestep + mask * ct_vector 141 | 142 | # log_x_start = torch.cat((log_x_start, log_zero_vector), dim=1) 143 | # q = log_x_start - log_qt 144 | q = log_x_start[:, :-1, :] - log_qt 145 | q = torch.cat((q, log_zero_vector), dim=1) 146 | q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) 147 | q = q - q_log_sum_exp 148 | log_EV_xtmin_given_xt_given_xstart = ( 149 | self.q_pred(q, t - 1) + log_qt_one_timestep + q_log_sum_exp 150 | ) 151 | return torch.clamp(log_EV_xtmin_given_xt_given_xstart, -70, 0) 152 | 153 | def q_sample(self, log_x_start, t): # diffusion step, q(xt|x0) and sample xt 154 | log_EV_qxt_x0 = self.q_pred(log_x_start, t) 155 | 156 | log_sample = self.log_sample_categorical(log_EV_qxt_x0) 157 | 158 | return log_sample 159 | 160 | def forward(self, x, is_train=True): # get the KL loss 161 | b, device = x.size(0), x.device 162 | 163 | assert self.loss_type == "vb_stochastic" 164 | x_start = x 165 | t, pt = self.sample_time(b, device, "importance") 166 | 167 | log_x_start = index_to_log_onehot(x_start, self.num_classes) 168 | log_xt = self.q_sample(log_x_start=log_x_start, t=t) 169 | xt = log_onehot_to_index(log_xt) 170 | 171 | ############### go to p_theta function ############### 172 | log_x0_recon = self.predict_start(log_xt, t=t) # P_theta(x0|xt) 173 | log_model_prob = self.q_posterior( 174 | log_x_start=log_x0_recon, log_x_t=log_xt, t=t 175 | ) # go through q(xt_1|xt,x0) 176 | 177 | ################## compute acc list ################ 178 | x0_recon = log_onehot_to_index(log_x0_recon) 179 | x0_real = x_start 180 | xt_1_recon = log_onehot_to_index(log_model_prob) 181 | xt_recon = log_onehot_to_index(log_xt) 182 | for index in range(t.size()[0]): 183 | this_t = t[index].item() 184 | same_rate = ( 185 | x0_recon[index] == x0_real[index] 186 | ).sum().cpu() / x0_real.size()[1] 187 | self.diffusion_acc_list[this_t] = ( 188 | same_rate.item() * 0.1 + self.diffusion_acc_list[this_t] * 0.9 189 | ) 190 | same_rate = ( 191 | xt_1_recon[index] == xt_recon[index] 192 | ).sum().cpu() / xt_recon.size()[1] 193 | self.diffusion_keep_list[this_t] = ( 194 | 
same_rate.item() * 0.1 + self.diffusion_keep_list[this_t] * 0.9 195 | ) 196 | 197 | # compute log_true_prob now 198 | log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_xt, t=t) 199 | kl = self.multinomial_kl(log_true_prob, log_model_prob) 200 | mask_region = (xt == self.num_classes - 1).float() 201 | mask_weight = ( 202 | mask_region * self.mask_weight[0] 203 | + (1.0 - mask_region) * self.mask_weight[1] 204 | ) 205 | kl = kl * mask_weight 206 | kl = mean_except_batch(kl) 207 | 208 | decoder_nll = -log_categorical(log_x_start, log_model_prob) 209 | decoder_nll = mean_except_batch(decoder_nll) 210 | 211 | mask = (t == torch.zeros_like(t)).float() 212 | kl_loss = mask * decoder_nll + (1.0 - mask) * kl 213 | 214 | Lt2 = kl_loss.pow(2) 215 | Lt2_prev = self.Lt_history.gather(dim=0, index=t) 216 | new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach() 217 | self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history) 218 | self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2)) 219 | 220 | # Upweigh loss term of the kl 221 | loss1 = kl_loss / pt 222 | losses = {"kl_loss": loss1.mean()} 223 | if self.auxiliary_loss_weight != 0 and is_train == True: 224 | kl_aux = self.multinomial_kl( 225 | log_x_start[:, :-1, :], log_x0_recon[:, :-1, :] 226 | ) 227 | kl_aux = kl_aux * mask_weight 228 | kl_aux = mean_except_batch(kl_aux) 229 | kl_aux_loss = mask * decoder_nll + (1.0 - mask) * kl_aux 230 | if self.adaptive_auxiliary_loss == True: 231 | addition_loss_weight = (1 - t / self.num_timesteps) + 1.0 232 | else: 233 | addition_loss_weight = 1.0 234 | 235 | loss2 = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt 236 | 237 | losses["aux_loss"] = loss2.mean() 238 | 239 | outputs = {"probs": log_model_prob.exp()} 240 | return outputs, losses 241 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/clg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/models/clg/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/models/clg/const.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.utils import to_dense_adj, to_dense_batch 6 | from trainer.data.util import REL_SIZE_ALPHA, RelLoc, RelSize 7 | from trainer.helpers.metric import compute_alignment, compute_overlap 8 | from trainer.helpers.util import convert_xywh_to_ltrb 9 | 10 | 11 | def beautify_alignment(bbox_flatten, data, threshold=0.004, **kwargs): 12 | bbox, mask = to_dense_batch(bbox_flatten, data.batch) 13 | bbox, mask = bbox[:, 1:], mask[:, 1:] 14 | 15 | if len(bbox_flatten.size()) == 3: 16 | bbox = bbox.transpose(1, 2) 17 | B, P, N, D = bbox.size() 18 | bbox = bbox.reshape(-1, N, D) 19 | mask = mask.unsqueeze(1).expand(-1, P, -1).reshape(-1, N) 20 | 21 | cost = compute_alignment(bbox, mask) 22 | cost = cost.masked_fill(cost.le(threshold), 0) 23 | 24 | if len(bbox_flatten.size()) == 3: 25 | cost = cost.view(B, P) 26 | 27 | return cost 28 | 29 | 30 | def beautify_non_overlap(bbox_flatten, data, **kwargs): 31 | bbox, mask = to_dense_batch(bbox_flatten, data.batch) 32 | bbox, mask = bbox[:, 1:], mask[:, 1:] 33 | 34 | if len(bbox_flatten.size()) == 3: 35 | bbox = bbox.transpose(1, 
2) 36 | B, P, N, D = bbox.size() 37 | bbox = bbox.reshape(-1, N, D) 38 | mask = mask.unsqueeze(1).expand(-1, P, -1).reshape(-1, N) 39 | 40 | cost = compute_overlap(bbox, mask) 41 | 42 | if len(bbox_flatten.size()) == 3: 43 | cost = cost.view(B, P) 44 | 45 | return cost 46 | 47 | 48 | beautify = [beautify_alignment, beautify_non_overlap] 49 | 50 | 51 | def less_equal(a, b): 52 | return torch.relu(a - b) 53 | 54 | 55 | def less(a, b, eps=1e-8): 56 | return torch.relu(a - b + eps) 57 | 58 | 59 | def _relation_size(rel_value, cost_func, bbox_flatten, data, canvas): 60 | cond = data.y[data.edge_index[0]].eq(0).eq(canvas) 61 | cond &= (data.edge_attr & 1 << rel_value).ne(0) 62 | 63 | if len(bbox_flatten.size()) == 3: 64 | cond = cond.unsqueeze(-1) 65 | a = bbox_flatten[:, :, 2] * bbox_flatten[:, :, 3] 66 | else: 67 | a = bbox_flatten[:, 2] * bbox_flatten[:, 3] 68 | 69 | ai, aj = a[data.edge_index[0]], a[data.edge_index[1]] 70 | 71 | cost = cost_func(ai, aj).masked_fill(~cond, 0) 72 | cost = to_dense_adj(data.edge_index, data.batch, cost) 73 | cost = cost.sum(dim=(1, 2)) 74 | 75 | return cost 76 | 77 | 78 | def relation_size_sm(bbox_flatten, data, canvas=False): 79 | def cost_func(a1, a2): 80 | # a2 <= a1_sm 81 | a1_sm = (1 - REL_SIZE_ALPHA) * a1 82 | return less_equal(a2, a1_sm) 83 | 84 | return _relation_size(RelSize.SMALLER, cost_func, bbox_flatten, data, canvas) 85 | 86 | 87 | def relation_size_eq(bbox_flatten, data, canvas=False): 88 | def cost_func(a1, a2): 89 | # a1_sm < a2 and a2 < a1_lg 90 | a1_sm = (1 - REL_SIZE_ALPHA) * a1 91 | a1_lg = (1 + REL_SIZE_ALPHA) * a1 92 | return less(a1_sm, a2) + less(a2, a1_lg) 93 | 94 | return _relation_size(RelSize.EQUAL, cost_func, bbox_flatten, data, canvas) 95 | 96 | 97 | def relation_size_lg(bbox_flatten, data, canvas=False): 98 | def cost_func(a1, a2): 99 | # a1_lg <= a2 100 | a1_lg = (1 + REL_SIZE_ALPHA) * a1 101 | return less_equal(a1_lg, a2) 102 | 103 | return _relation_size(RelSize.LARGER, cost_func, bbox_flatten, data, canvas) 104 | 105 | 106 | def _relation_loc_canvas(rel_value, cost_func, bbox_flatten, data): 107 | cond = data.y[data.edge_index[0]].eq(0) 108 | cond &= (data.edge_attr & 1 << rel_value).ne(0) 109 | 110 | if len(bbox_flatten.size()) == 3: 111 | cond = cond.unsqueeze(-1) 112 | yc = bbox_flatten[:, :, 1] 113 | else: 114 | yc = bbox_flatten[:, 1] 115 | 116 | yc = yc[data.edge_index[1]] 117 | 118 | cost = cost_func(yc).masked_fill(~cond, 0) 119 | cost = to_dense_adj(data.edge_index, data.batch, cost) 120 | cost = cost.sum(dim=(1, 2)) 121 | 122 | return cost 123 | 124 | 125 | def relation_loc_canvas_t(bbox_flatten, data): 126 | def cost_func(yc): 127 | # yc <= y_sm 128 | y_sm = 1.0 / 3 129 | return less_equal(yc, y_sm) 130 | 131 | return _relation_loc_canvas(RelLoc.TOP, cost_func, bbox_flatten, data) 132 | 133 | 134 | def relation_loc_canvas_c(bbox_flatten, data): 135 | def cost_func(yc): 136 | # y_sm < yc and yc < y_lg 137 | y_sm, y_lg = 1.0 / 3, 2.0 / 3 138 | return less(y_sm, yc) + less(yc, y_lg) 139 | 140 | return _relation_loc_canvas(RelLoc.CENTER, cost_func, bbox_flatten, data) 141 | 142 | 143 | def relation_loc_canvas_b(bbox_flatten, data): 144 | def cost_func(yc): 145 | # y_lg <= yc 146 | y_lg = 2.0 / 3 147 | return less_equal(y_lg, yc) 148 | 149 | return _relation_loc_canvas(RelLoc.BOTTOM, cost_func, bbox_flatten, data) 150 | 151 | 152 | def _relation_loc(rel_value, cost_func, bbox_flatten, data): 153 | cond = data.y[data.edge_index[0]].ne(0) 154 | cond &= (data.edge_attr & 1 << rel_value).ne(0) 155 | 156 | if 
len(bbox_flatten.size()) == 3: 157 | cond = cond.unsqueeze(-1) 158 | l, t, r, b = convert_xywh_to_ltrb(bbox_flatten.permute(2, 0, 1)) 159 | else: 160 | l, t, r, b = convert_xywh_to_ltrb(bbox_flatten.t()) 161 | 162 | li, lj = l[data.edge_index[0]], l[data.edge_index[1]] 163 | ti, tj = t[data.edge_index[0]], t[data.edge_index[1]] 164 | ri, rj = r[data.edge_index[0]], r[data.edge_index[1]] 165 | bi, bj = b[data.edge_index[0]], b[data.edge_index[1]] 166 | 167 | cost = cost_func(l1=li, t1=ti, r1=ri, b1=bi, l2=lj, t2=tj, r2=rj, b2=bj) 168 | 169 | if rel_value in [RelLoc.LEFT, RelLoc.RIGHT, RelLoc.CENTER]: 170 | # t1 < b2 and t2 < b1 171 | cost = cost + less(ti, bj) + less(tj, bi) 172 | 173 | cost = cost.masked_fill(~cond, 0) 174 | cost = to_dense_adj(data.edge_index, data.batch, cost) 175 | cost = cost.sum(dim=(1, 2)) 176 | 177 | return cost 178 | 179 | 180 | def relation_loc_t(bbox_flatten, data): 181 | def cost_func(b2, t1, **kwargs): 182 | # b2 <= t1 183 | return less_equal(b2, t1) 184 | 185 | return _relation_loc(RelLoc.TOP, cost_func, bbox_flatten, data) 186 | 187 | 188 | def relation_loc_b(bbox_flatten, data): 189 | def cost_func(b1, t2, **kwargs): 190 | # b1 <= t2 191 | return less_equal(b1, t2) 192 | 193 | return _relation_loc(RelLoc.BOTTOM, cost_func, bbox_flatten, data) 194 | 195 | 196 | def relation_loc_l(bbox_flatten, data): 197 | def cost_func(r2, l1, **kwargs): 198 | # r2 <= l1 199 | return less_equal(r2, l1) 200 | 201 | return _relation_loc(RelLoc.LEFT, cost_func, bbox_flatten, data) 202 | 203 | 204 | def relation_loc_r(bbox_flatten, data): 205 | def cost_func(r1, l2, **kwargs): 206 | # r1 <= l2 207 | return less_equal(r1, l2) 208 | 209 | return _relation_loc(RelLoc.RIGHT, cost_func, bbox_flatten, data) 210 | 211 | 212 | def relation_loc_c(bbox_flatten, data): 213 | def cost_func(l1, r2, l2, r1, **kwargs): 214 | # l1 < r2 and l2 < r1 215 | return less(l1, r2) + less(l2, r1) 216 | 217 | return _relation_loc(RelLoc.CENTER, cost_func, bbox_flatten, data) 218 | 219 | 220 | relation = [ 221 | partial(relation_size_sm, canvas=False), 222 | partial(relation_size_sm, canvas=True), 223 | partial(relation_size_eq, canvas=False), 224 | partial(relation_size_eq, canvas=True), 225 | partial(relation_size_lg, canvas=False), 226 | partial(relation_size_lg, canvas=True), 227 | relation_loc_canvas_t, 228 | relation_loc_canvas_c, 229 | relation_loc_canvas_b, 230 | relation_loc_t, 231 | relation_loc_b, 232 | relation_loc_l, 233 | relation_loc_r, 234 | relation_loc_c, 235 | ] 236 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/models/common/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/models/common/layout.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from torch import FloatTensor, LongTensor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BboxEncoder(torch.nn.Module): 13 | def __init__( 14 | self, 15 | num_bin_bboxes: int, 16 | output_dim: int, 17 | fusion: str = "emb_concat", 18 | ) -> None: 19 | super().__init__() 20 | self.fusion = fusion 21 | if fusion == "linear": 
22 | # self.emb = nn.Linear(4, output_dim) 23 | raise NotImplementedError 24 | elif fusion in ["emb_concat", "emb_add"]: 25 | self.x_emb = nn.Embedding(num_bin_bboxes, output_dim) 26 | self.y_emb = nn.Embedding(num_bin_bboxes, output_dim) 27 | self.w_emb = nn.Embedding(num_bin_bboxes, output_dim) 28 | self.h_emb = nn.Embedding(num_bin_bboxes, output_dim) 29 | else: 30 | raise NotImplementedError 31 | 32 | def forward(self, bbox: LongTensor) -> FloatTensor: 33 | if self.fusion == "linear": 34 | emb = self.emb(bbox.float()) 35 | elif self.fusion in ["emb_concat", "emb_add"]: 36 | embs = [] 37 | for (key, value) in zip( 38 | ["x", "y", "w", "h"], 39 | torch.split(bbox, split_size_or_sections=1, dim=-1), 40 | ): 41 | embs.append( 42 | getattr(self, f"{key}_emb")(rearrange(value, "b s 1 -> b s")) 43 | ) 44 | if self.fusion == "emb_add": 45 | emb = sum(embs) 46 | else: 47 | emb = torch.cat(embs, dim=-1) 48 | else: 49 | raise NotImplementedError 50 | return emb 51 | 52 | 53 | class LayoutEncoder(torch.nn.Module): 54 | def __init__( 55 | self, 56 | output_dim: int, 57 | num_classes: int, 58 | lb_fusion: str = "concat_fc", 59 | ) -> None: 60 | super().__init__() 61 | assert lb_fusion in ["add", "concat_fc"] 62 | self.lb_fusion = lb_fusion 63 | self.bbox_fusion = "emb_concat" 64 | 65 | if self.lb_fusion == "concat_fc": 66 | self.label_emb = nn.Embedding(num_classes, output_dim) 67 | self.bbox_emb = BboxEncoder( 68 | num_classes, output_dim, fusion=self.bbox_fusion 69 | ) 70 | if self.bbox_fusion == "emb_concat": 71 | self.fc = nn.Linear(output_dim * 5, output_dim) 72 | elif self.bbox_fusion == "emb_add": 73 | self.fc = nn.Linear(output_dim * 2, output_dim) 74 | 75 | elif self.lb_fusion == "add": 76 | assert self.bbox_fusion == "emb_add" 77 | self.label_emb = nn.Embedding(num_classes, output_dim) 78 | self.bbox_emb = BboxEncoder( 79 | num_classes, output_dim, fusion=self.bbox_fusion 80 | ) 81 | else: 82 | raise NotImplementedError 83 | 84 | def forward(self, inputs: Dict[str, LongTensor]) -> FloatTensor: 85 | h_label = self.label_emb(inputs["label"]) 86 | h_bbox = self.bbox_emb(inputs["bbox"]) 87 | if self.lb_fusion == "concat_fc": 88 | h = torch.cat([h_label, h_bbox], dim=-1) 89 | h = self.fc(h) 90 | elif self.lb_fusion == "add": 91 | h = h_label + h_bbox 92 | else: 93 | raise NotImplementedError 94 | if "mask" in inputs: 95 | mask_float = inputs["mask"].float() 96 | mask_float = rearrange(mask_float, "b s -> b s 1") 97 | h *= mask_float 98 | return h 99 | 100 | 101 | class LayoutDecoder(torch.nn.Module): 102 | def __init__( 103 | self, 104 | input_dim: int, 105 | num_classes: int, 106 | ) -> None: 107 | super().__init__() 108 | self.linear_label = nn.Linear(input_dim, num_classes, bias=False) 109 | self.linear_bbox = nn.Linear(input_dim, 4 * num_classes, bias=False) 110 | 111 | def forward(self, h: FloatTensor) -> Dict[str, FloatTensor]: 112 | outputs = {} 113 | outputs["logit_label"] = self.linear_label(h) # (B, S, C) 114 | logit_bbox = self.linear_bbox(h) 115 | outputs["logit_bbox"] = rearrange(logit_bbox, "b s (c x) -> b s c x", x=4) 116 | return outputs 117 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/common/util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | 5 | import fsspec 6 | import torch 7 | import torch.nn as nn 8 | from omegaconf import DictConfig 9 | from torch import BoolTensor 10 | 11 | logger = logging.getLogger(__name__) 12 | 
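# Rough usage of the helpers defined below (illustrative only; `backbone_cfg`
# stands for any hydra backbone config with an `encoder_layer` section):
#
#     attn_mask = generate_causal_mask(16)        # additive mask: 0.0 keeps, -inf blocks
#     d_model = get_dim_model(backbone_cfg)       # recursively looks up `d_model`
#     half_cfg = shrink(backbone_cfg, mult=0.5)   # scales d_model / dim_feedforward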
13 | 14 | def generate_causal_mask(sz: int) -> BoolTensor: 15 | """Generates an upper-triangular matrix of -inf, with zeros on diag.""" 16 | return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) 17 | 18 | 19 | def get_dim_model(backbone_cfg: DictConfig) -> int: 20 | """ 21 | It takes hierarchical config for 22 | - trainer.models.transformer_utils import TransformerEncoder 23 | and get a number of dimension inside the Transformer 24 | """ 25 | result = None 26 | for key, value in backbone_cfg.items(): 27 | if key == "d_model": 28 | result = value 29 | elif isinstance(value, DictConfig): 30 | x = get_dim_model(value) 31 | if x: 32 | result = x 33 | return result 34 | 35 | 36 | def shrink(backbone_cfg: DictConfig, mult: float) -> DictConfig: 37 | """ 38 | Rescale dimension of a model linearly 39 | """ 40 | new_backbone_cfg = copy.deepcopy(backbone_cfg) 41 | l = new_backbone_cfg.encoder_layer 42 | for key in ["d_model", "dim_feedforward"]: 43 | new_backbone_cfg.encoder_layer[key] = int(mult * l[key]) 44 | return new_backbone_cfg 45 | 46 | 47 | def load_model( 48 | model: nn.Module, 49 | ckpt_dir: str, 50 | device: torch.device, 51 | best_or_final: str = "best", 52 | extension: str = ".pt", 53 | ): 54 | model_path = os.path.join(ckpt_dir, f"{best_or_final}_model{extension}") 55 | with fsspec.open(str(model_path), "rb") as file_obj: 56 | model.load_state_dict(torch.load(file_obj, map_location=device)) 57 | return model 58 | 59 | 60 | def save_model(model: nn.Module, ckpt_dir: str, best_or_final: str = "best"): 61 | model_path = os.path.join(ckpt_dir, f"{best_or_final}_model.pt") 62 | with fsspec.open(str(model_path), "wb") as file_obj: 63 | torch.save(model.state_dict(), file_obj) 64 | return model 65 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/continuous_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberAgentAILab/layout-dm/873b5eebe4c61862e5c08a10859accf65a168dfd/src/trainer/trainer/models/continuous_diffusion/__init__.py -------------------------------------------------------------------------------- /src/trainer/trainer/models/continuous_diffusion/bitdiffusion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange, reduce, repeat 6 | from torch import FloatTensor, LongTensor 7 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 8 | 9 | from .base import ContinuousDiffusionBase 10 | 11 | 12 | def ids_to_bits(x: LongTensor, num_bits: int) -> FloatTensor: 13 | """ 14 | Given ids with shape (B, S), returns bits (in -1. or 1. 
form) with shape (B, S, bits) 15 | """ 16 | assert x.max().item() < 2**num_bits 17 | mask = 2 ** torch.arange(num_bits - 1, -1, -1).to(x.device) 18 | mask = rearrange(mask, "d -> 1 d") 19 | x = rearrange(x, "b s -> b s 1") 20 | 21 | bits = ((x & mask) != 0).float() 22 | bits = bits * 2 - 1.0 23 | return bits 24 | 25 | 26 | def bits_to_ids( 27 | x: FloatTensor, num_bits: int, tokenizer: Optional[LayoutSequenceTokenizer] = None 28 | ) -> LongTensor: 29 | B, S, _ = x.size() 30 | mask = 2 ** torch.arange(num_bits - 1, -1, -1, dtype=torch.int32).to(x.device) 31 | mask = rearrange(mask, "d -> 1 d") 32 | 33 | bits = (x > 0).int() 34 | if tokenizer: 35 | base_ids = rearrange(torch.arange(2**num_bits), "d -> 1 d") 36 | base_bits = rearrange(ids_to_bits(base_ids, num_bits), "1 n c -> 1 1 n c") 37 | dist = torch.abs(rearrange(x, "b s c -> b s 1 c") - base_bits.to(x.device)) 38 | dist = reduce(dist, "b s n c -> b s n", reduction="sum") 39 | 40 | pad = torch.full((S, 2**num_bits - tokenizer.N_total), False) 41 | valid = torch.cat([tokenizer.token_mask, pad], dim=1) 42 | valid = repeat(valid, "s x -> b s x", b=B).to(x.device) 43 | dist = torch.where(valid, dist, torch.full_like(dist, fill_value=float("Inf"))) 44 | ids = torch.argmax(-dist, dim=-1) 45 | else: 46 | ids = reduce(bits * mask, "b s d -> b s", "sum").long() 47 | return ids 48 | 49 | 50 | class BitDiffusion(ContinuousDiffusionBase): 51 | def __init__( 52 | self, 53 | **kwargs, 54 | ) -> None: 55 | super().__init__(**kwargs) 56 | self.scale = 1.0 57 | self.con2logits = None 58 | 59 | def dis2con( 60 | self, seq: LongTensor, reparametrize: bool = False, normalize: bool = False 61 | ) -> Tuple[FloatTensor, FloatTensor]: 62 | assert seq.dim() == 2 63 | # return ids_to_bits(seq, self.num_channel) * self.scale, None 64 | x = ids_to_bits(seq, self.num_channel) * self.scale 65 | return x, x 66 | 67 | def con2dis(self, arr: FloatTensor) -> LongTensor: 68 | assert arr.dim() == 3 69 | return bits_to_ids(arr, self.num_channel, self.tokenizer) 70 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/continuous_diffusion/diffusion_lm.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import FloatTensor, LongTensor 7 | 8 | from .base import ContinuousDiffusionBase 9 | 10 | 11 | class DiffusionLM(ContinuousDiffusionBase): 12 | """ 13 | Diffusion-LM Improves Controllable Text Generation (NeurIPS'22) 14 | https://arxiv.org/abs/2205.14217 15 | """ 16 | 17 | def __init__( 18 | self, 19 | **kwargs, 20 | ) -> None: 21 | super().__init__(**kwargs) 22 | self.rounder = nn.Linear(self.num_channel, self.tokenizer.N_total) 23 | 24 | def dis2con( 25 | self, seq: LongTensor, reparametrize: bool = False, normalize: bool = False 26 | ) -> Union[FloatTensor, Tuple[FloatTensor, FloatTensor]]: 27 | """ 28 | Args: 29 | seq: LongTensor with shape (B, S) indicating id of each token 30 | Returns: 31 | arr: FloatTensor with shape (B, S, D) indicating continuous vector 32 | """ 33 | assert seq.dim() == 2 34 | emb = self.token_emb(seq) 35 | if normalize: 36 | emb = F.normalize(emb, dim=-1) 37 | if reparametrize and hasattr(self, "con2logits"): 38 | if hasattr(self, "scheduler"): # w/ diffuser 39 | timestep = torch.zeros((1,), device=self.device).long() 40 | # get mean of the final distribution by setting zero input 41 | noise = self.scheduler.add_noise( 42 | 
torch.zeros_like(emb), torch.randn_like(emb), timestep 43 | ) 44 | emb_reparametrized = emb + noise 45 | else: 46 | from .base import log_snr_to_alpha_sigma 47 | 48 | rep_times = torch.zeros((1,), device=self.device).float() 49 | _, rep_sigma = log_snr_to_alpha_sigma(self.log_snr(rep_times)) 50 | emb_reparametrized = emb + rep_sigma * torch.randn_like(emb) 51 | return emb_reparametrized, emb 52 | else: 53 | return emb 54 | 55 | def con2dis(self, arr: FloatTensor) -> LongTensor: 56 | """ 57 | Args: 58 | arr: FloatTensor with shape (B, S, D) indicating continuous vector 59 | Returns: 60 | seq: LongTensor with shape (B, S) indicating id of each token 61 | """ 62 | assert arr.dim() == 3 63 | seq = torch.argmax(self.con2logits(arr), dim=-1) 64 | return seq 65 | 66 | def con2logits(self, arr: FloatTensor) -> LongTensor: 67 | """ 68 | Args: 69 | arr: FloatTensor with shape (B, S, D) indicating continuous vector 70 | Returns: 71 | logits: FloatTensor with shape (B, S, C) indicating logit for each discrete token 72 | """ 73 | assert arr.dim() == 3 74 | logits = self.rounder(arr) 75 | return logits 76 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/elem_wise_autoreg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange, repeat 7 | from hydra.utils import instantiate 8 | from omegaconf import DictConfig 9 | from torch import Tensor 10 | from trainer.data.util import sparse_to_dense 11 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 12 | from trainer.helpers.sampling import sample 13 | from trainer.helpers.task import ( 14 | duplicate_cond, 15 | set_additional_conditions_for_refinement, 16 | ) 17 | from trainer.models.base_model import BaseModel 18 | from trainer.models.common.nn_lib import CategoricalTransformer, CustomDataParallel 19 | from trainer.models.common.util import get_dim_model 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class ElemWiseAutoreg(BaseModel): 25 | """ 26 | To reproduce 27 | LayoutTransformer: Layout Generation and Completion with Self-attention (ICCV2021) 28 | https://arxiv.org/abs/2006.14615 29 | """ 30 | 31 | def __init__( 32 | self, 33 | # cfg: DictConfig, 34 | backbone_cfg: DictConfig, 35 | tokenizer: LayoutSequenceTokenizer, 36 | pos_emb: str = "default", 37 | ) -> None: 38 | super().__init__() 39 | self.tokenizer = tokenizer 40 | 41 | kwargs = {} 42 | if pos_emb == "elem_attr": 43 | kwargs["n_attr_per_elem"] = tokenizer.N_var_per_element 44 | 45 | # Note: make sure learnable parameters are inside self.model 46 | backbone = instantiate(backbone_cfg) 47 | self.model = CustomDataParallel( 48 | CategoricalTransformer( 49 | backbone=backbone, 50 | dim_model=get_dim_model(backbone_cfg), 51 | num_classes=self.tokenizer.N_total, 52 | max_token_length=tokenizer.max_token_length + 1, # +1 for BOS 53 | pos_emb=pos_emb, 54 | lookahead=False, 55 | **kwargs, 56 | ) 57 | ) 58 | self.apply(self._init_weights) 59 | self.compute_stats() 60 | 61 | self.loss_fn_ce = nn.CrossEntropyLoss( 62 | label_smoothing=0.1, ignore_index=self.tokenizer.name_to_id("pad") 63 | ) 64 | 65 | def forward(self, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor]]: 66 | outputs = self.model(inputs["input"]) 67 | nll_loss = self.loss_fn_ce( 68 | rearrange(outputs["logits"], "b s c -> b c s"), 69 | inputs["target"], 70 | 
) 71 | losses = {"nll_loss": nll_loss} 72 | return outputs, losses 73 | 74 | def sample( 75 | self, 76 | batch_size: Optional[int], 77 | cond: Optional[Tensor] = None, 78 | sampling_cfg: Optional[DictConfig] = None, 79 | **kwargs, 80 | ) -> Dict[str, Tensor]: 81 | """ 82 | Generate sample based on z. 83 | z can be either given or sampled from some distribution (e.g., normal) 84 | """ 85 | 86 | if cond and cond["type"] == "refinement": 87 | # additional weak constraints 88 | cond = set_additional_conditions_for_refinement( 89 | cond, self.tokenizer, sampling_cfg 90 | ) 91 | 92 | if cond: 93 | cond = duplicate_cond(cond, batch_size) 94 | for k, v in cond.items(): 95 | if isinstance(v, Tensor): 96 | cond[k] = v.to(self.device) 97 | 98 | special_keys = self.tokenizer._special_token_name_to_id 99 | mask_id = self.tokenizer.name_to_id("mask") if "mask" in special_keys else -1 100 | input_ = torch.full( 101 | (batch_size, 1), fill_value=self.tokenizer.name_to_id("bos") 102 | ) 103 | input_ = input_.to(self.device) 104 | 105 | for i in range(self.tokenizer.max_token_length): 106 | with torch.no_grad(): 107 | logits = self.model(input_)["logits"] 108 | logits = rearrange(logits[:, i : i + 1], "b 1 c -> b c") 109 | 110 | if cond: 111 | if cond.get("type", None) == "relation": 112 | raise NotImplementedError 113 | # using sparse DataBatch and partial dense array to compute relation is "practically" very complex, 114 | # (i) add dummy logits to match the dimension (easy) 115 | # (ii) add "causal" loss mask because interaction between i-th and i'-th (i' < i) is valid (hard) 116 | # impose weak user-specified constraints by addition 117 | if cond.get("type", None) == "refinement": 118 | weak_mask = cond["weak_mask"][..., i + 1] 119 | weak_logits = cond["weak_logits"][..., i + 1] 120 | logits[weak_mask] += weak_logits[weak_mask] 121 | 122 | invalid = repeat( 123 | ~self.tokenizer.token_mask[i : i + 1], "1 c -> b c", b=input_.size(0) 124 | ) 125 | logits[invalid] = -float("Inf") 126 | 127 | predicted = sample(logits, sampling_cfg) 128 | if cond: 129 | id_ = cond["seq"][:, i + 1 : i + 2] 130 | if id_.size(1) == 1: 131 | # If condition exists and is valid, use it 132 | flag = id_ == mask_id 133 | predicted = torch.where(flag, predicted, id_) 134 | input_ = torch.cat([input_, predicted], dim=1) 135 | 136 | ids = input_[:, 1:].cpu() # pop BOS 137 | layouts = self.tokenizer.decode(ids) 138 | return layouts 139 | 140 | def preprocess(self, batch): 141 | bbox, label, _, mask = sparse_to_dense(batch) 142 | x = self.tokenizer.encode({"label": label, "mask": mask, "bbox": bbox}) 143 | input_ = x["seq"][:, :-1] 144 | target = x["seq"][:, 1:] 145 | return {"input": input_, "target": target} 146 | 147 | def optim_groups( 148 | self, weight_decay: float = 0.0 149 | ) -> Union[Iterable[Tensor], Dict[str, Tensor]]: 150 | base = "model.module.pos_emb" 151 | additional_no_decay = [ 152 | f"{base}.{name}" for name in self.model.module.pos_emb.no_decay_param_names 153 | ] 154 | return super().optim_groups( 155 | weight_decay=weight_decay, additional_no_decay=additional_no_decay 156 | ) 157 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/layout_continuous_diffusion.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | from hydra.utils import instantiate 5 | from omegaconf import DictConfig 6 | from torch import Tensor 7 | from trainer.data.util 
import sparse_to_dense 8 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 9 | from trainer.models.base_model import BaseModel 10 | from trainer.models.common.nn_lib import CustomDataParallel 11 | from trainer.models.common.util import get_dim_model, shrink 12 | from trainer.models.continuous_diffusion.base import init_token_embedding 13 | from trainer.models.continuous_diffusion.bitdiffusion import BitDiffusion 14 | from trainer.models.continuous_diffusion.diffusion_lm import DiffusionLM 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | BITS = 8 19 | MODELS = {"bit_diffusion": BitDiffusion, "diffusion_lm": DiffusionLM} 20 | 21 | 22 | class LayoutContinuousDiffusion(BaseModel): 23 | def __init__( 24 | self, 25 | # cfg: DictConfig, 26 | backbone_cfg: DictConfig, 27 | tokenizer: LayoutSequenceTokenizer, 28 | model_type: str, 29 | num_channel: int = 16, 30 | **kwargs, 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.num_channel = num_channel 35 | 36 | self.max_len = tokenizer.max_token_length 37 | 38 | # make sure MASK is the last vocabulary 39 | assert tokenizer.id_to_name(tokenizer.N_total - 1) == "mask" 40 | self.tokenizer = tokenizer 41 | 42 | model = MODELS[model_type] 43 | # Note: make sure learnable parameters are inside self.model 44 | backbone_shrink_cfg = shrink(backbone_cfg, 29 / 32) 45 | backbone = instantiate(backbone_shrink_cfg) # for fair comparison 46 | self.model = CustomDataParallel( 47 | model( 48 | backbone=backbone, 49 | tokenizer=tokenizer, 50 | dim_model=get_dim_model(backbone_shrink_cfg), 51 | max_len=self.max_len, 52 | num_channel=self.num_channel, 53 | **kwargs, 54 | ) 55 | ) 56 | self.apply(self._init_weights) 57 | 58 | if model_type == "diffusion_lm": 59 | # re-initialize to avoid weight range specification 60 | self.model.module.token_emb = init_token_embedding( 61 | num_embeddings=tokenizer.N_total, 62 | embedding_dim=num_channel, 63 | is_learnable=self.model.module.learnable_token_emb, 64 | ) 65 | # initialize rounder by an inverse of token_emb 66 | self.model.module.rounder.weight.data = ( 67 | self.model.module.token_emb.weight.data.clone() 68 | ) 69 | 70 | self.compute_stats() 71 | 72 | @property 73 | def device(self): 74 | return next(self.model.parameters()).device 75 | 76 | def forward(self, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor]]: 77 | outputs, losses = self.model(inputs) 78 | # aggregate losses for multi-GPU mode (no change in single GPU mode) 79 | new_losses = {k: v.mean() for (k, v) in losses.items()} 80 | return outputs, new_losses 81 | 82 | def sample( 83 | self, 84 | batch_size: Optional[int] = 1, 85 | cond: Optional[Dict] = None, 86 | sampling_cfg: Optional[DictConfig] = None, 87 | **kwargs, 88 | ) -> Dict[str, Tensor]: 89 | seq = self.model.sample( 90 | batch_size=batch_size, cond=cond, sampling_cfg=sampling_cfg, **kwargs 91 | ).cpu() 92 | return self.tokenizer.decode(seq) 93 | 94 | def aggregate_sampling_settings( 95 | self, sampling_cfg: DictConfig, args: DictConfig 96 | ) -> DictConfig: 97 | sampling_cfg = super().aggregate_sampling_settings(sampling_cfg, args) 98 | 99 | sampling_cfg.use_ddim = args.use_ddim 100 | sampling_cfg.time_difference = args.time_difference 101 | 102 | return sampling_cfg 103 | 104 | def preprocess(self, batch): 105 | bbox, label, _, mask = sparse_to_dense(batch) 106 | inputs = {"label": label, "mask": mask, "bbox": bbox} 107 | return self.tokenizer.encode(inputs) 108 | 109 | def optim_groups( 110 | self, weight_decay: float = 0.0 111 | ) -> Union[Iterable[Tensor], 
Dict[str, Tensor]]: 112 | base = "model.module.transformer.pos_emb" 113 | additional_no_decay = [ 114 | f"{base}.{name}" 115 | for name in self.model.module.transformer.pos_emb.no_decay_param_names 116 | ] 117 | return super().optim_groups( 118 | weight_decay=weight_decay, additional_no_decay=additional_no_decay 119 | ) 120 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/layoutdm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | import torch 5 | from einops import repeat 6 | from omegaconf import DictConfig 7 | from torch import Tensor 8 | from trainer.data.util import sparse_to_dense 9 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 10 | from trainer.models.base_model import BaseModel 11 | from trainer.models.categorical_diffusion.constrained import ( 12 | ConstrainedMaskAndReplaceDiffusion, 13 | ) 14 | from trainer.models.categorical_diffusion.vanilla import VanillaMaskAndReplaceDiffusion 15 | from trainer.models.common.nn_lib import CustomDataParallel 16 | from trainer.models.common.util import shrink 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | Q_TYPES = { 21 | "vanilla": VanillaMaskAndReplaceDiffusion, 22 | "constrained": ConstrainedMaskAndReplaceDiffusion, 23 | } 24 | 25 | 26 | class LayoutDM(BaseModel): 27 | def __init__( 28 | self, 29 | backbone_cfg: DictConfig, 30 | tokenizer: LayoutSequenceTokenizer, 31 | transformer_type: str = "flattened", 32 | pos_emb: str = "elem_attr", 33 | num_timesteps: int = 100, 34 | auxiliary_loss_weight: float = 1e-1, 35 | q_type: str = "single", 36 | seq_type: str = "poset", 37 | **kwargs, 38 | ) -> None: 39 | super().__init__() 40 | assert q_type in Q_TYPES 41 | assert seq_type in ["set", "poset"] 42 | 43 | self.pos_emb = pos_emb 44 | self.seq_type = seq_type 45 | # make sure MASK is the last vocabulary 46 | assert tokenizer.id_to_name(tokenizer.N_total - 1) == "mask" 47 | 48 | # Note: make sure learnable parameters are inside self.model 49 | self.tokenizer = tokenizer 50 | model = Q_TYPES[q_type] 51 | 52 | self.model = CustomDataParallel( 53 | model( 54 | backbone_cfg=shrink(backbone_cfg, 29 / 32), # for fair comparison 55 | num_classes=tokenizer.N_total, 56 | max_token_length=tokenizer.max_token_length, 57 | num_timesteps=num_timesteps, 58 | pos_emb=pos_emb, 59 | transformer_type=transformer_type, 60 | auxiliary_loss_weight=auxiliary_loss_weight, 61 | tokenizer=tokenizer, 62 | **kwargs, 63 | ) 64 | ) 65 | 66 | self.apply(self._init_weights) 67 | self.compute_stats() 68 | 69 | def forward(self, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor]]: 70 | outputs, losses = self.model(inputs["seq"]) 71 | 72 | # aggregate losses for multi-GPU mode (no change in single GPU mode) 73 | new_losses = {k: v.mean() for (k, v) in losses.items()} 74 | 75 | return outputs, new_losses 76 | 77 | def sample( 78 | self, 79 | batch_size: Optional[int] = 1, 80 | cond: Optional[Dict] = None, 81 | sampling_cfg: Optional[DictConfig] = None, 82 | **kwargs, 83 | ) -> Dict[str, Tensor]: 84 | ids = self.model.sample( 85 | batch_size=batch_size, cond=cond, sampling_cfg=sampling_cfg, **kwargs 86 | ).cpu() 87 | layouts = self.tokenizer.decode(ids) 88 | return layouts 89 | 90 | def aggregate_sampling_settings( 91 | self, sampling_cfg: DictConfig, args: DictConfig 92 | ) -> DictConfig: 93 | sampling_cfg = super().aggregate_sampling_settings(sampling_cfg, args) 
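        # (note: comment not in the original source) time_difference is presumably the
        # asymmetric time-interval trick from Bit Diffusion / "Analog Bits" (Chen et al., 2022);
        # it is only propagated when the caller passes a positive value, otherwise
        # sampling_cfg is left unchanged.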
94 | if args.time_difference > 0: 95 | sampling_cfg.time_difference = args.time_difference 96 | 97 | return sampling_cfg 98 | 99 | def preprocess(self, batch): 100 | bbox, label, _, mask = sparse_to_dense(batch) 101 | inputs = {"label": label, "mask": mask, "bbox": bbox} 102 | 103 | ids = self.tokenizer.encode(inputs) 104 | if self.seq_type == "set": 105 | # randomly shuffle [PAD]'s location 106 | B, S = ids["mask"].size() 107 | C = self.tokenizer.N_var_per_element 108 | for i in range(B): 109 | indices = torch.randperm(S // C) 110 | indices = repeat(indices * C, "b -> (b c)", c=C) 111 | indices += torch.arange(S) % C 112 | for k in ids: 113 | ids[k][i, :] = ids[k][i, indices] 114 | return ids 115 | 116 | def optim_groups( 117 | self, weight_decay: float = 0.0 118 | ) -> Union[Iterable[Tensor], Dict[str, Tensor]]: 119 | base = "model.module.transformer.pos_emb" 120 | additional_no_decay = [ 121 | f"{base}.{name}" 122 | for name in self.model.module.transformer.pos_emb.no_decay_param_names 123 | ] 124 | return super().optim_groups( 125 | weight_decay=weight_decay, additional_no_decay=additional_no_decay 126 | ) 127 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/ruite.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from hydra.utils import instantiate 8 | from omegaconf import DictConfig 9 | from torch import Tensor 10 | from torch_geometric.utils import to_dense_batch 11 | from trainer.data.util import sparse_to_dense 12 | from trainer.helpers.layout_tokenizer import LayoutSequenceTokenizer 13 | from trainer.helpers.sampling import sample 14 | from trainer.helpers.task import duplicate_cond 15 | from trainer.models.base_model import BaseModel 16 | from trainer.models.common.nn_lib import CategoricalTransformer, CustomDataParallel 17 | from trainer.models.common.util import get_dim_model 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class RUITE(BaseModel): 23 | """ 24 | To reproduce 25 | Refining ui layout aesthetics using transformer encoder 26 | https://dl.acm.org/doi/10.1145/3397482.3450716 27 | 28 | """ 29 | 30 | def __init__( 31 | self, 32 | backbone_cfg: DictConfig, 33 | tokenizer: LayoutSequenceTokenizer, 34 | ) -> None: 35 | super().__init__() 36 | self.tokenizer = tokenizer 37 | 38 | backbone = instantiate(backbone_cfg) 39 | self.model = CustomDataParallel( 40 | CategoricalTransformer( 41 | backbone=backbone, 42 | dim_model=get_dim_model(backbone_cfg), 43 | num_classes=self.tokenizer.N_total, 44 | max_token_length=tokenizer.max_token_length, 45 | lookahead=True, 46 | ) 47 | ) 48 | 49 | # Note: make sure learnable parameters are inside self.model 50 | self.apply(self._init_weights) 51 | self.compute_stats() 52 | self.loss_fn_ce = nn.CrossEntropyLoss( 53 | ignore_index=self.tokenizer.name_to_id("pad") 54 | ) 55 | 56 | def forward(self, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor]]: 57 | outputs = self.model( 58 | inputs["input"], src_key_padding_mask=inputs["padding_mask"] 59 | ) 60 | nll_loss = self.loss_fn_ce( 61 | rearrange(outputs["logits"], "b s c -> b c s"), inputs["target"] 62 | ) 63 | 64 | losses = {"nll_loss": nll_loss} 65 | outputs["outputs"] = torch.argmax(outputs["logits"], dim=-1) 66 | return outputs, losses 67 | 68 | def sample( 69 | self, 70 | batch_size: Optional[int], 71 | cond: 
Optional[Tensor] = None, 72 | sampling_cfg: Optional[DictConfig] = None, 73 | **kwargs, 74 | ) -> Dict[str, Tensor]: 75 | """ 76 | Generate sample based on z. 77 | z can be either given or sampled from some distribution (e.g., normal) 78 | """ 79 | pad_id = self.tokenizer.name_to_id("pad") 80 | 81 | if cond: 82 | cond = duplicate_cond(cond, batch_size) 83 | padding_mask = cond["seq"] == pad_id 84 | outputs = self.model( 85 | cond["seq"].to(self.device), 86 | src_key_padding_mask=padding_mask.to(self.device), 87 | ) 88 | logits = rearrange(outputs["logits"].cpu(), "b s c -> b c s") 89 | seq = rearrange(sample(logits, sampling_cfg), "b 1 s -> b s") 90 | seq[cond["mask"]] = cond["seq"][cond["mask"]] 91 | else: 92 | # since RUITE cannot generate without inputs, just generate dummy 93 | seq = torch.full( 94 | (batch_size, self.tokenizer.max_token_length), fill_value=pad_id 95 | ) 96 | seq[:, 0] = 0 97 | seq[:, 1:5] = self.tokenizer.N_category 98 | 99 | layouts = self.tokenizer.decode(seq) 100 | return layouts 101 | 102 | def preprocess(self, batch): 103 | bbox_w_noise, label, _, mask = sparse_to_dense(batch) 104 | bbox, _ = to_dense_batch(batch.x_orig, batch.batch) 105 | inputs = self.tokenizer.encode( 106 | {"label": label, "mask": mask, "bbox": bbox_w_noise} 107 | ) 108 | targets = self.tokenizer.encode({"label": label, "mask": mask, "bbox": bbox}) 109 | 110 | return { 111 | "target": targets["seq"], 112 | "padding_mask": ~inputs["mask"], 113 | "input": inputs["seq"], 114 | } 115 | 116 | def optim_groups( 117 | self, weight_decay: float = 0.0 118 | ) -> Union[Iterable[Tensor], Dict[str, Tensor]]: 119 | return super().optim_groups( 120 | weight_decay=weight_decay, 121 | additional_no_decay=[ 122 | "model.module.pos_emb.pos_emb", 123 | ], 124 | ) 125 | -------------------------------------------------------------------------------- /src/trainer/trainer/models/transformer_utils.py: -------------------------------------------------------------------------------- 1 | # Implement TransformerEncoder that can consider timesteps as optional args for Diffusion. 
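# (note: sketch not in the original file) A minimal usage sketch of the classes defined
# below; the concrete hyper-parameters are illustrative assumptions, not the repository's
# actual configs:
#
#     layer = Block(
#         d_model=512, nhead=8, dim_feedforward=2048,
#         batch_first=True, norm_first=True,            # Block asserts norm_first
#         diffusion_step=100, timestep_type="adalayernorm",
#     )
#     encoder = TransformerEncoder(layer, num_layers=4)
#     x = torch.randn(2, 25, 512)                       # (batch, seq, d_model)
#     t = torch.randint(0, 100, (2,))                   # one diffusion timestep per sample
#     h = encoder(x, timestep=t)                        # -> (2, 25, 512)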
2 | 3 | import copy 4 | import math 5 | from typing import Callable, Optional, Union 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from einops.layers.torch import Rearrange 10 | from torch import Tensor, nn 11 | 12 | 13 | def _get_clones(module, N): 14 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 15 | 16 | 17 | def _gelu2(x): 18 | return x * F.sigmoid(1.702 * x) 19 | 20 | 21 | def _get_activation_fn(activation): 22 | if activation == "relu": 23 | return F.relu 24 | elif activation == "gelu": 25 | return F.gelu 26 | elif activation == "gelu2": 27 | return _gelu2 28 | else: 29 | raise RuntimeError( 30 | "activation should be relu/gelu/gelu2, not {}".format(activation) 31 | ) 32 | 33 | 34 | class SinusoidalPosEmb(nn.Module): 35 | def __init__(self, num_steps: int, dim: int, rescale_steps: int = 4000): 36 | super().__init__() 37 | self.dim = dim 38 | self.num_steps = float(num_steps) 39 | self.rescale_steps = float(rescale_steps) 40 | 41 | def forward(self, x: Tensor): 42 | x = x / self.num_steps * self.rescale_steps 43 | device = x.device 44 | half_dim = self.dim // 2 45 | emb = math.log(10000) / (half_dim - 1) 46 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 47 | emb = x[:, None] * emb[None, :] 48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 49 | return emb 50 | 51 | 52 | class _AdaNorm(nn.Module): 53 | def __init__( 54 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 55 | ): 56 | super().__init__() 57 | if "abs" in emb_type: 58 | self.emb = SinusoidalPosEmb(max_timestep, n_embd) 59 | elif "mlp" in emb_type: 60 | self.emb = nn.Sequential( 61 | Rearrange("b -> b 1"), 62 | nn.Linear(1, n_embd // 2), 63 | nn.ReLU(), 64 | nn.Linear(n_embd // 2, n_embd), 65 | ) 66 | else: 67 | self.emb = nn.Embedding(max_timestep, n_embd) 68 | self.silu = nn.SiLU() 69 | self.linear = nn.Linear(n_embd, n_embd * 2) 70 | 71 | 72 | class AdaLayerNorm(_AdaNorm): 73 | def __init__( 74 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 75 | ): 76 | super().__init__(n_embd, max_timestep, emb_type) 77 | self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) 78 | 79 | def forward(self, x: Tensor, timestep: int): 80 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) 81 | scale, shift = torch.chunk(emb, 2, dim=2) 82 | x = self.layernorm(x) * (1 + scale) + shift 83 | return x 84 | 85 | 86 | class AdaInsNorm(_AdaNorm): 87 | def __init__( 88 | self, n_embd: int, max_timestep: int, emb_type: str = "adalayernorm_abs" 89 | ): 90 | super().__init__(n_embd, max_timestep, emb_type) 91 | self.instancenorm = nn.InstanceNorm1d(n_embd) 92 | 93 | def forward(self, x, timestep): 94 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) 95 | scale, shift = torch.chunk(emb, 2, dim=2) 96 | x = ( 97 | self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale) 98 | + shift 99 | ) 100 | return x 101 | 102 | 103 | class Block(nn.Module): 104 | """an unassuming Transformer block""" 105 | 106 | def __init__( 107 | self, 108 | d_model=1024, 109 | nhead=16, 110 | dim_feedforward: int = 2048, 111 | dropout: float = 0.0, 112 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 113 | batch_first: bool = False, 114 | norm_first: bool = False, 115 | device=None, 116 | dtype=None, 117 | # extension for diffusion 118 | diffusion_step: int = 100, 119 | timestep_type: str = None, 120 | ) -> None: 121 | super().__init__() 122 | 123 | assert norm_first # minGPT-based implementations are designed for 
prenorm only 124 | assert timestep_type in [ 125 | None, 126 | "adalayernorm", 127 | "adainnorm", 128 | "adalayernorm_abs", 129 | "adainnorm_abs", 130 | "adalayernorm_mlp", 131 | "adainnorm_mlp", 132 | ] 133 | layer_norm_eps = 1e-5 # fixed 134 | 135 | self.norm_first = norm_first 136 | self.diffusion_step = diffusion_step 137 | self.timestep_type = timestep_type 138 | 139 | factory_kwargs = {"device": device, "dtype": dtype} 140 | self.self_attn = torch.nn.MultiheadAttention( 141 | d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs 142 | ) 143 | 144 | # Implementation of Feedforward model 145 | self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) 146 | self.dropout = nn.Dropout(dropout) 147 | self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) 148 | 149 | if timestep_type is not None: 150 | if "adalayernorm" in timestep_type: 151 | self.norm1 = AdaLayerNorm(d_model, diffusion_step, timestep_type) 152 | elif "adainnorm" in timestep_type: 153 | self.norm1 = AdaInsNorm(d_model, diffusion_step, timestep_type) 154 | else: 155 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 156 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 157 | self.dropout1 = nn.Dropout(dropout) 158 | self.dropout2 = nn.Dropout(dropout) 159 | 160 | if isinstance(activation, str): 161 | self.activation = _get_activation_fn(activation) 162 | else: 163 | self.activation = activation 164 | 165 | def forward( 166 | self, 167 | src: Tensor, 168 | src_mask: Optional[Tensor] = None, 169 | src_key_padding_mask: Optional[Tensor] = None, 170 | timestep: Tensor = None, 171 | ) -> Tensor: 172 | x = src 173 | if self.norm_first: 174 | if self.timestep_type is not None: 175 | x = self.norm1(x, timestep) 176 | else: 177 | x = self.norm1(x) 178 | x = x + self._sa_block(x, src_mask, src_key_padding_mask) 179 | x = x + self._ff_block(self.norm2(x)) 180 | else: 181 | x = x + self._sa_block(x, src_mask, src_key_padding_mask) 182 | if self.timestep_type is not None: 183 | x = self.norm1(x, timestep) 184 | else: 185 | x = self.norm1(x) 186 | x = self.norm2(x + self._ff_block(x)) 187 | 188 | return x 189 | 190 | # self-attention block 191 | def _sa_block( 192 | self, 193 | x: Tensor, 194 | attn_mask: Optional[Tensor], 195 | key_padding_mask: Optional[Tensor], 196 | ) -> Tensor: 197 | x = self.self_attn( 198 | x, 199 | x, 200 | x, 201 | attn_mask=attn_mask, 202 | key_padding_mask=key_padding_mask, 203 | need_weights=False, 204 | )[0] 205 | return self.dropout1(x) 206 | 207 | # feed forward block 208 | def _ff_block(self, x: Tensor) -> Tensor: 209 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 210 | return self.dropout2(x) 211 | 212 | 213 | class TransformerEncoder(nn.Module): 214 | """ 215 | Close to torch.nn.TransformerEncoder, but with timestep support for diffusion 216 | """ 217 | 218 | __constants__ = ["norm"] 219 | 220 | def __init__(self, encoder_layer, num_layers, norm=None): 221 | super(TransformerEncoder, self).__init__() 222 | self.layers = _get_clones(encoder_layer, num_layers) 223 | self.num_layers = num_layers 224 | self.norm = norm 225 | 226 | def forward( 227 | self, 228 | src: Tensor, 229 | mask: Optional[Tensor] = None, 230 | src_key_padding_mask: Optional[Tensor] = None, 231 | timestep: Tensor = None, 232 | ) -> Tensor: 233 | output = src 234 | 235 | for mod in self.layers: 236 | output = mod( 237 | output, 238 | src_mask=mask, 239 | src_key_padding_mask=src_key_padding_mask, 240 | timestep=timestep, 
241 | ) 242 | 243 | if self.norm is not None: 244 | output = self.norm(output) 245 | 246 | return output 247 | --------------------------------------------------------------------------------
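Note: a tiny worked example (not part of the repository sources) of the analog-bits encoding
implemented by ids_to_bits / bits_to_ids in
src/trainer/trainer/models/continuous_diffusion/bitdiffusion.py above. With num_bits = 3, the
token id 5 (binary 101) maps to the vector [1., -1., 1.], and thresholding at zero recovers
the id when no tokenizer is supplied:

    import torch
    ids = torch.tensor([[5, 2]])             # (B=1, S=2); ids must be < 2**num_bits
    bits = ids_to_bits(ids, num_bits=3)      # [[[ 1., -1.,  1.], [-1.,  1., -1.]]]
    back = bits_to_ids(bits, num_bits=3)     # tensor([[5, 2]])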