├── LICENSE ├── README.md ├── environment.yml ├── imgs └── layout_teasor.jpg ├── layout_transformer ├── dataset.py ├── main.py ├── model.py ├── trainer.py └── utils.py └── layout_vae ├── box.py ├── count.py ├── layout.py ├── train_counts.py └── train_layouts.py /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LayoutTransformer 2 | 3 | [arXiv](https://arxiv.org/abs/2006.14615) | [BibTeX](#bibtex) | [Project Page](https://kampta.github.io/layout) 4 | 5 | This repo contains code for single GPU training of LayoutTransformer from 6 | [LayoutTransformer: Layout Generation and Completion with Self-attention](https://arxiv.org/abs/2006.14615). 
7 | This code was rewritten from scratch using a cleaner GPT [codebase](https://github.com/karpathy/minGPT).
8 | Some of the details such as training hyperparameters might differ from the arxiv version of the paper.
9 |
10 |
11 |
12 |
13 | ## How To Use This Code
14 |
15 | Start a new conda environment
16 | ```
17 | conda env create -f environment.yml
18 | conda activate layout
19 | ```
20 | or update an existing environment
21 |
22 | ```
23 | conda env update -f environment.yml --prune
24 | ```
25 |
26 | ### Logging with `wandb`
27 |
28 | To log experiments to wandb,
29 | you need your wandb API key, which can be found at https://wandb.ai/settings.
30 | Copy your key and store it in an environment variable using
31 |
32 | ```
33 | export WANDB_API_KEY=
34 | ```
35 |
36 | Alternatively, you can log in using `wandb login`.
37 |
38 | ## Datasets
39 |
40 | ### COCO Bounding Boxes
41 |
42 | See the instructions to obtain the dataset [here](https://cocodataset.org/).
43 |
44 | ### PubLayNet Document Layouts
45 |
46 | See the instructions to obtain the dataset [here](https://github.com/ibm-aur-nlp/PubLayNet).
47 |
48 |
49 | ## LayoutVAE
50 |
51 | A reimplementation of [LayoutVAE](https://arxiv.org/abs/1907.10719) is available [here](layout_vae).
52 | Code contributed primarily by Justin.
53 |
54 | ```
55 | cd layout_vae
56 |
57 | # Train the CountVAE model
58 | python train_counts.py \
59 |     --exp count_coco_instances \
60 |     --train_json /path/to/coco/annotations/instances_train2017.json \
61 |     --val_json /path/to/coco/annotations/instances_val2017.json \
62 |     --epochs 50
63 |
64 | # Train the BoxVAE model
65 | python train_layouts.py \
66 |     --exp box_coco_instances \
67 |     --train_json /path/to/coco/annotations/instances_train2017.json \
68 |     --val_json /path/to/coco/annotations/instances_val2017.json \
69 |     --epochs 50
70 | ```
71 |
72 | ## LayoutTransformer
73 |
74 | Rewritten from scratch using a cleaner GPT [codebase](https://github.com/karpathy/minGPT).
75 | Some of the details such as training hyperparameters might differ from the arxiv version.
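
Internally, each layout is flattened into a single sequence of discrete tokens before it is fed to the
model: every element contributes five tokens (a category token followed by the discretized `x, y, w, h`),
category ids are offset by the number of coordinate bins (`2^precision`, 256 by default), and the
sequence is wrapped with `bos`/`eos` tokens and padded to a fixed length (see
`layout_transformer/dataset.py`; for MNIST, each foreground pixel index becomes a token instead).
Below is a minimal sketch of this encoding for a hypothetical two-element layout; the element values
are made up for illustration, and the actual conversion lives in `JSONLayout`.

```
import numpy as np

size = 2 ** 8                   # number of coordinate bins (precision=8)
num_categories = 5              # e.g. PubLayNet has 5 element categories
bos = size + num_categories     # bos/eos/pad are the last three ids of the vocabulary
eos = bos + 1
pad = bos + 2

# two hypothetical elements: (category_index, x, y, w, h) with coordinates normalized to [0, 1]
elements = [(0, 0.1, 0.1, 0.8, 0.1),
            (1, 0.1, 0.3, 0.8, 0.6)]

tokens = [bos]
for cat, x, y, w, h in elements:
    tokens.append(size + cat)   # category token, offset by the number of coordinate bins
    tokens += ((np.array([x, y, w, h]) * (size - 1)).round().astype(int)).tolist()
tokens.append(eos)              # shorter layouts are padded with `pad` up to max_length
```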
76 |
77 | ```
78 | # Training on MNIST layouts
79 | python main.py \
80 |     --data_dir /path/to/mnist \
81 |     --threshold 1 --exp mnist_threshold_1
82 | ```
83 |
84 | On your wandb dashboard, you can see some generated samples
85 |
86 | ![media_images_sample_random_layouts_18750_0](https://user-images.githubusercontent.com/1719140/137636972-4030c68e-b1c1-4234-b420-cf3068a5a9c6.png)
87 | ![media_images_sample_random_layouts_18750_1](https://user-images.githubusercontent.com/1719140/137636974-0f40c6ce-ea3c-445f-9610-b660f8b60d38.png)
88 | ![media_images_sample_random_layouts_18750_2](https://user-images.githubusercontent.com/1719140/137636975-8365f231-246d-4aae-a2a2-a339dd27e8b5.png)
89 | ![media_images_sample_random_layouts_18750_3](https://user-images.githubusercontent.com/1719140/137636976-6c8b88c0-41c0-43e1-a492-17dc718138be.png)
90 |
91 |
92 | ```
93 | # Training on COCO bounding boxes or PubLayNet
94 | python main.py \
95 |     --train_json /path/to/annotations/train.json \
96 |     --val_json /path/to/annotations/val.json \
97 |     --exp publaynet
98 | ```
99 |
100 | For the PubLayNet dataset, generated samples might look like this
101 |
102 |
103 | ![media_images_sample_random_layouts_15738_3](https://user-images.githubusercontent.com/1719140/137637046-e2181cda-904e-4ea3-868b-39a7bf64a236.png)
104 | ![media_images_sample_random_layouts_26230_2](https://user-images.githubusercontent.com/1719140/137637047-43fd285f-afec-42ba-a4f7-04ddf66d4d86.png)
105 | ![media_images_sample_random_layouts_26230_3](https://user-images.githubusercontent.com/1719140/137637048-7263f9ab-1d19-4826-a6c2-d7ce152d9e0d.png)
106 |
107 |
108 | ## BibTeX
109 |
110 | If you use this code, please cite
111 | ```text
112 | @inproceedings{gupta2021layouttransformer,
113 |   title={LayoutTransformer: Layout Generation and Completion with Self-attention},
114 |   author={Gupta, Kamal and Lazarow, Justin and Achille, Alessandro and Davis, Larry S and Mahadevan, Vijay and Shrivastava, Abhinav},
115 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
116 |   pages={1004--1014},
117 |   year={2021}
118 | }
119 |
120 | ```
121 |
122 | ## Acknowledgments
123 |
124 | We would like to thank the authors of several public repos
125 |
126 | * https://github.com/JiananLi2016/LayoutGAN-Tensorflow
127 | * https://github.com/Layout-Generation/layout-generation
128 | * https://github.com/karpathy/minGPT
129 | * https://github.com/ChrisWu1997/PQ-NET
130 |
131 |
132 | ## License
133 |
134 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
135 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: layout 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - scikit-learn 13 | - scipy 14 | - pip: 15 | - tqdm 16 | - seaborn 17 | - wandb 18 | - tabulate 19 | - ipdb -------------------------------------------------------------------------------- /imgs/layout_teasor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kampta/DeepLayout/60ebaca7f29a14d92c455f83e681c0bd8e2962fe/imgs/layout_teasor.jpg -------------------------------------------------------------------------------- /layout_transformer/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.datasets.mnist import MNIST 4 | from torch.utils.data.dataset import Dataset 5 | from PIL import Image, ImageDraw, ImageOps 6 | import json 7 | 8 | from utils import trim_tokens, gen_colors 9 | 10 | 11 | class Padding(object): 12 | def __init__(self, max_length, vocab_size): 13 | self.max_length = max_length 14 | self.bos_token = vocab_size - 3 15 | self.eos_token = vocab_size - 2 16 | self.pad_token = vocab_size - 1 17 | 18 | def __call__(self, layout): 19 | # grab a chunk of (max_length + 1) from the layout 20 | 21 | chunk = torch.zeros(self.max_length+1, dtype=torch.long) + self.pad_token 22 | # Assume len(item) will always be <= self.max_length: 23 | chunk[0] = self.bos_token 24 | chunk[1:len(layout)+1] = layout 25 | chunk[len(layout)+1] = self.eos_token 26 | 27 | x = chunk[:-1] 28 | y = chunk[1:] 29 | return {'x': x, 'y': y} 30 | 31 | 32 | class MNISTLayout(MNIST): 33 | 34 | def __init__(self, root, train=True, download=True, threshold=32, max_length=None): 35 | super().__init__(root, train=train, download=download) 36 | self.vocab_size = 784 + 3 # bos, eos, pad tokens 37 | self.bos_token = self.vocab_size - 3 38 | self.eos_token = self.vocab_size - 2 39 | self.pad_token = self.vocab_size - 1 40 | 41 | self.threshold = threshold 42 | self.data = [self.img_to_set(img) for img in self.data] 43 | self.max_length = max_length 44 | if self.max_length is None: 45 | self.max_length = max([len(x) for x in self.data]) + 2 # bos, eos tokens 46 | self.transform = Padding(self.max_length, self.vocab_size) 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def img_to_set(self, img): 52 | fg_mask = img >= self.threshold 53 | fg_idx = fg_mask.nonzero(as_tuple=False) 54 | fg_idx = fg_idx[:, 0] * 28 + fg_idx[:, 1] 55 | return fg_idx 56 | 57 | def render(self, layout): 58 | layout = trim_tokens(layout, self.bos_token, self.eos_token, self.pad_token) 59 | x_coords = layout % 28 60 | y_coords = layout // 28 61 | # valid_idx = torch.where((y_coords < 28) & (y_coords >= 0))[0] 62 | img = np.zeros((28, 28, 3)).astype(np.uint8) 63 | img[y_coords, x_coords] = 255 64 | return Image.fromarray(img, 'RGB') 65 | 66 | def __getitem__(self, idx): 67 | # grab a chunk of (block_size + 1) tokens from the data 68 | layout = self.transform(self.data[idx]) 69 | return layout['x'], layout['y'] 70 | 71 | 72 | class JSONLayout(Dataset): 73 | def __init__(self, json_path, max_length=None, precision=8): 74 | with open(json_path, "r") as f: 75 | data = json.loads(f.read()) 
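        # Note: the JSON is expected to follow the COCO detection format, i.e. top-level
        # 'images', 'annotations' (each with 'image_id', 'category_id' and 'bbox' = [x, y, w, h]
        # in pixels) and 'categories' keys; only these fields are used below.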
76 | 77 | images, annotations, categories = data['images'], data['annotations'], data['categories'] 78 | self.size = pow(2, precision) 79 | 80 | self.categories = {c["id"]: c for c in categories} 81 | self.colors = gen_colors(len(self.categories)) 82 | 83 | self.json_category_id_to_contiguous_id = { 84 | v: i + self.size for i, v in enumerate([c["id"] for c in self.categories.values()]) 85 | } 86 | 87 | self.contiguous_category_id_to_json_id = { 88 | v: k for k, v in self.json_category_id_to_contiguous_id.items() 89 | } 90 | 91 | self.vocab_size = self.size + len(self.categories) + 3 # bos, eos, pad tokens 92 | self.bos_token = self.vocab_size - 3 93 | self.eos_token = self.vocab_size - 2 94 | self.pad_token = self.vocab_size - 1 95 | 96 | image_to_annotations = {} 97 | for annotation in annotations: 98 | image_id = annotation["image_id"] 99 | 100 | if not (image_id in image_to_annotations): 101 | image_to_annotations[image_id] = [] 102 | 103 | image_to_annotations[image_id].append(annotation) 104 | 105 | self.data = [] 106 | for image in images: 107 | image_id = image["id"] 108 | height, width = float(image["height"]), float(image["width"]) 109 | 110 | if image_id not in image_to_annotations: 111 | continue 112 | 113 | ann_box = [] 114 | ann_cat = [] 115 | for ann in image_to_annotations[image_id]: 116 | x, y, w, h = ann["bbox"] 117 | ann_box.append([x, y, w, h]) 118 | ann_cat.append(self.json_category_id_to_contiguous_id[ann["category_id"]]) 119 | 120 | # Sort boxes 121 | ann_box = np.array(ann_box) 122 | ind = np.lexsort((ann_box[:, 0], ann_box[:, 1])) 123 | ann_box = ann_box[ind] 124 | 125 | ann_cat = np.array(ann_cat) 126 | ann_cat = ann_cat[ind] 127 | 128 | # Discretize boxes 129 | ann_box = self.quantize_box(ann_box, width, height) 130 | 131 | # Append the categories 132 | layout = np.concatenate([ann_cat.reshape(-1, 1), ann_box], axis=1) 133 | 134 | # Flatten and add to the dataset 135 | self.data.append(layout.reshape(-1)) 136 | 137 | self.max_length = max_length 138 | if self.max_length is None: 139 | self.max_length = max([len(x) for x in self.data]) + 2 # bos, eos tokens 140 | self.transform = Padding(self.max_length, self.vocab_size) 141 | 142 | def quantize_box(self, boxes, width, height): 143 | 144 | # range of xy is [0, large_side-1] 145 | # range of wh is [1, large_side] 146 | # bring xywh to [0, 1] 147 | boxes[:, [2, 3]] = boxes[:, [2, 3]] - 1 148 | boxes[:, [0, 2]] = boxes[:, [0, 2]] / (width - 1) 149 | boxes[:, [1, 3]] = boxes[:, [1, 3]] / (height - 1) 150 | boxes = np.clip(boxes, 0, 1) 151 | 152 | # next take xywh to [0, size-1] 153 | boxes = (boxes * (self.size - 1)).round() 154 | 155 | return boxes.astype(np.int32) 156 | 157 | def __len__(self): 158 | return len(self.data) 159 | 160 | def render(self, layout): 161 | img = Image.new('RGB', (256, 256), color=(255, 255, 255)) 162 | draw = ImageDraw.Draw(img, 'RGBA') 163 | layout = layout.reshape(-1) 164 | layout = trim_tokens(layout, self.bos_token, self.eos_token, self.pad_token) 165 | layout = layout[: len(layout) // 5 * 5].reshape(-1, 5) 166 | box = layout[:, 1:].astype(np.float32) 167 | box[:, [0, 1]] = box[:, [0, 1]] / (self.size - 1) * 255 168 | box[:, [2, 3]] = box[:, [2, 3]] / self.size * 256 169 | box[:, [2, 3]] = box[:, [0, 1]] + box[:, [2, 3]] 170 | 171 | for i in range(len(layout)): 172 | x1, y1, x2, y2 = box[i] 173 | cat = layout[i][0] 174 | col = self.colors[cat-self.size] if 0 <= cat-self.size < len(self.colors) else [0, 0, 0] 175 | draw.rectangle([x1, y1, x2, y2], 176 | outline=tuple(col) + (200,), 177 | 
fill=tuple(col) + (64,), 178 | width=2) 179 | 180 | # Add border around image 181 | img = ImageOps.expand(img, border=2) 182 | return img 183 | 184 | def __getitem__(self, idx): 185 | # grab a chunk of (block_size + 1) tokens from the data 186 | layout = torch.tensor(self.data[idx], dtype=torch.long) 187 | layout = self.transform(layout) 188 | return layout['x'], layout['y'] 189 | -------------------------------------------------------------------------------- /layout_transformer/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from dataset import MNISTLayout, JSONLayout 5 | from model import GPT, GPTConfig 6 | from trainer import Trainer, TrainerConfig 7 | from utils import set_seed 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser('Layout Transformer') 12 | parser.add_argument("--exp", default="layout", help="experiment name") 13 | parser.add_argument("--log_dir", default="./logs", help="/path/to/logs/dir") 14 | 15 | # MNIST options 16 | parser.add_argument("--data_dir", default=None, help="/path/to/mnist/data") 17 | parser.add_argument("--threshold", type=int, default=16, help="threshold for grayscale values") 18 | 19 | # COCO/PubLayNet options 20 | parser.add_argument("--train_json", default="./instances_train.json", help="/path/to/train/json") 21 | parser.add_argument("--val_json", default="./instances_val.json", help="/path/to/val/json") 22 | 23 | # Layout options 24 | parser.add_argument("--max_length", type=int, default=128, help="batch size") 25 | parser.add_argument('--precision', default=8, type=int) 26 | parser.add_argument('--element_order', default='raster') 27 | parser.add_argument('--attribute_order', default='cxywh') 28 | 29 | # Architecture/training options 30 | parser.add_argument("--seed", type=int, default=42, help="random seed") 31 | parser.add_argument("--epochs", type=int, default=10, help="number of epochs") 32 | parser.add_argument("--batch_size", type=int, default=64, help="batch size") 33 | parser.add_argument("--lr", type=float, default=4.5e-06, help="learning rate") 34 | parser.add_argument('--n_layer', default=6, type=int) 35 | parser.add_argument('--n_embd', default=512, type=int) 36 | parser.add_argument('--n_head', default=8, type=int) 37 | # parser.add_argument('--evaluate', action='store_true', help="evaluate only") 38 | parser.add_argument('--lr_decay', action='store_true', help="use learning rate decay") 39 | parser.add_argument('--warmup_iters', type=int, default=0, help="linear lr warmup iters") 40 | parser.add_argument('--final_iters', type=int, default=0, help="cosine lr final iters") 41 | parser.add_argument('--sample_every', type=int, default=1, help="sample every epoch") 42 | 43 | args = parser.parse_args() 44 | 45 | log_dir = os.path.join(args.log_dir, args.exp) 46 | samples_dir = os.path.join(log_dir, "samples") 47 | ckpt_dir = os.path.join(log_dir, "checkpoints") 48 | os.makedirs(samples_dir, exist_ok=True) 49 | os.makedirs(ckpt_dir, exist_ok=True) 50 | 51 | set_seed(args.seed) 52 | 53 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 | print(f"using device: {device}") 55 | 56 | # MNIST Testing 57 | if args.data_dir is not None: 58 | train_dataset = MNISTLayout(args.log_dir, train=True, threshold=args.threshold) 59 | valid_dataset = MNISTLayout(args.log_dir, train=False, threshold=args.threshold, 60 | max_length=train_dataset.max_length) 61 | # COCO and PubLayNet 62 | else: 63 | train_dataset = 
JSONLayout(args.train_json) 64 | valid_dataset = JSONLayout(args.val_json, max_length=train_dataset.max_length) 65 | 66 | mconf = GPTConfig(train_dataset.vocab_size, train_dataset.max_length, 67 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd) # a GPT-1 68 | model = GPT(mconf) 69 | tconf = TrainerConfig(max_epochs=args.epochs, 70 | batch_size=args.batch_size, 71 | lr_decay=args.lr_decay, 72 | learning_rate=args.lr * args.batch_size, 73 | warmup_iters=args.warmup_iters, 74 | final_iters=args.final_iters, 75 | ckpt_dir=ckpt_dir, 76 | samples_dir=samples_dir, 77 | sample_every=args.sample_every) 78 | trainer = Trainer(model, train_dataset, valid_dataset, tconf, args) 79 | trainer.train() 80 | -------------------------------------------------------------------------------- /layout_transformer/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT model: 3 | - the initial stem consists of a combination of token encoding and a positional encoding 4 | - the meat of it is a uniform sequence of Transformer blocks 5 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 6 | - all blocks feed into a central residual pathway similar to resnets 7 | - the final decoder is a linear projection into a vanilla Softmax classifier 8 | """ 9 | 10 | import math 11 | import logging 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class GPTConfig: 21 | """ base GPT config, params common to all GPT versions """ 22 | embd_pdrop = 0.1 23 | resid_pdrop = 0.1 24 | attn_pdrop = 0.1 25 | 26 | def __init__(self, vocab_size, block_size, **kwargs): 27 | self.vocab_size = vocab_size 28 | self.block_size = block_size 29 | for k,v in kwargs.items(): 30 | setattr(self, k, v) 31 | 32 | 33 | class GPT1Config(GPTConfig): 34 | """ GPT-1 like network roughly 125M params """ 35 | n_layer = 12 36 | n_head = 12 37 | n_embd = 768 38 | 39 | 40 | class CausalSelfAttention(nn.Module): 41 | """ 42 | A vanilla multi-head masked self-attention layer with a projection at the end. 43 | It is possible to use torch.nn.MultiheadAttention here but I am including an 44 | explicit implementation here to show that there is nothing too scary here. 
45 | """ 46 | 47 | def __init__(self, config): 48 | super().__init__() 49 | assert config.n_embd % config.n_head == 0 50 | # key, query, value projections for all heads 51 | self.key = nn.Linear(config.n_embd, config.n_embd) 52 | self.query = nn.Linear(config.n_embd, config.n_embd) 53 | self.value = nn.Linear(config.n_embd, config.n_embd) 54 | # regularization 55 | self.attn_drop = nn.Dropout(config.attn_pdrop) 56 | self.resid_drop = nn.Dropout(config.resid_pdrop) 57 | # output projection 58 | self.proj = nn.Linear(config.n_embd, config.n_embd) 59 | # causal mask to ensure that attention is only applied to the left in the input sequence 60 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 61 | .view(1, 1, config.block_size, config.block_size)) 62 | self.n_head = config.n_head 63 | 64 | def forward(self, x, layer_past=None): 65 | B, T, C = x.size() 66 | 67 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 68 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 69 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 70 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 71 | 72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 74 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 75 | att = F.softmax(att, dim=-1) 76 | att = self.attn_drop(att) 77 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 78 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 79 | 80 | # output projection 81 | y = self.resid_drop(self.proj(y)) 82 | return y 83 | 84 | 85 | class Block(nn.Module): 86 | """ an unassuming Transformer block """ 87 | 88 | def __init__(self, config): 89 | super().__init__() 90 | self.ln1 = nn.LayerNorm(config.n_embd) 91 | self.ln2 = nn.LayerNorm(config.n_embd) 92 | self.attn = CausalSelfAttention(config) 93 | self.mlp = nn.Sequential( 94 | nn.Linear(config.n_embd, 4 * config.n_embd), 95 | nn.GELU(), 96 | nn.Linear(4 * config.n_embd, config.n_embd), 97 | nn.Dropout(config.resid_pdrop), 98 | ) 99 | 100 | def forward(self, x): 101 | x = x + self.attn(self.ln1(x)) 102 | x = x + self.mlp(self.ln2(x)) 103 | return x 104 | 105 | 106 | class GPT(nn.Module): 107 | """ the full GPT language model, with a context size of block_size """ 108 | 109 | def __init__(self, config): 110 | super().__init__() 111 | 112 | # input embedding stem 113 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 114 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 115 | self.drop = nn.Dropout(config.embd_pdrop) 116 | # transformer 117 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 118 | # decoder head 119 | self.ln_f = nn.LayerNorm(config.n_embd) 120 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 121 | 122 | self.block_size = config.block_size 123 | self.apply(self._init_weights) 124 | 125 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 126 | 127 | def get_block_size(self): 128 | return self.block_size 129 | 130 | def _init_weights(self, module): 131 | if isinstance(module, (nn.Linear, nn.Embedding)): 132 | module.weight.data.normal_(mean=0.0, std=0.02) 133 | if isinstance(module, nn.Linear) and 
module.bias is not None: 134 | module.bias.data.zero_() 135 | elif isinstance(module, nn.LayerNorm): 136 | module.bias.data.zero_() 137 | module.weight.data.fill_(1.0) 138 | 139 | def configure_optimizers(self, train_config): 140 | """ 141 | This long function is unfortunately doing something very simple and is being very defensive: 142 | We are separating out all parameters of the model into two buckets: those that will experience 143 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 144 | We are then returning the PyTorch optimizer object. 145 | """ 146 | 147 | # separate out all parameters to those that will and won't experience regularizing weight decay 148 | decay = set() 149 | no_decay = set() 150 | whitelist_weight_modules = (torch.nn.Linear, ) 151 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 152 | for mn, m in self.named_modules(): 153 | for pn, p in m.named_parameters(): 154 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 155 | 156 | if pn.endswith('bias'): 157 | # all biases will not be decayed 158 | no_decay.add(fpn) 159 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 160 | # weights of whitelist modules will be weight decayed 161 | decay.add(fpn) 162 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 163 | # weights of blacklist modules will NOT be weight decayed 164 | no_decay.add(fpn) 165 | 166 | # special case the position embedding parameter in the root GPT module as not decayed 167 | no_decay.add('pos_emb') 168 | 169 | # validate that we considered every parameter 170 | param_dict = {pn: p for pn, p in self.named_parameters()} 171 | inter_params = decay & no_decay 172 | union_params = decay | no_decay 173 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 174 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 175 | % (str(param_dict.keys() - union_params), ) 176 | 177 | # create the pytorch optimizer object 178 | optim_groups = [ 179 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 180 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 181 | ] 182 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 183 | return optimizer 184 | 185 | def forward(self, idx, targets=None, pad_token=-100): 186 | b, t = idx.size() 187 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
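        # note: target positions equal to `pad_token` are ignored by the cross-entropy below
        # (via ignore_index), so padded positions do not contribute to the loss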
188 | 189 | # forward the GPT model 190 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 191 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 192 | x = self.drop(token_embeddings + position_embeddings) 193 | x = self.blocks(x) 194 | x = self.ln_f(x) 195 | logits = self.head(x) 196 | 197 | # if we are given some desired targets also calculate the loss 198 | loss = None 199 | if targets is not None: 200 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=pad_token) 201 | 202 | return logits, loss 203 | -------------------------------------------------------------------------------- /layout_transformer/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network, 3 | so nothing in this file really has anything to do with GPT specifically. 4 | """ 5 | import os 6 | import math 7 | import logging 8 | import wandb 9 | 10 | from tqdm import tqdm 11 | import numpy as np 12 | 13 | import torch 14 | from torch.nn import functional as F 15 | 16 | from torch.utils.data.dataloader import DataLoader 17 | 18 | from utils import sample 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TrainerConfig: 24 | # optimization parameters 25 | max_epochs = 10 26 | batch_size = 64 27 | learning_rate = 3e-4 28 | betas = (0.9, 0.95) 29 | grad_norm_clip = 1.0 30 | weight_decay = 0.1 # only applied on matmul weights 31 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 32 | lr_decay = False 33 | warmup_iters = 0 34 | final_iters = 0 # (at what point we reach 10% of original LR) 35 | # checkpoint settings 36 | ckpt_dir = None 37 | samples_dir = None 38 | sample_every = 1 39 | num_workers = 0 # for DataLoader 40 | 41 | def __init__(self, **kwargs): 42 | for k, v in kwargs.items(): 43 | setattr(self, k, v) 44 | 45 | 46 | class Trainer: 47 | 48 | def __init__(self, model, train_dataset, test_dataset, config, args): 49 | self.model = model 50 | self.train_dataset = train_dataset 51 | self.test_dataset = test_dataset 52 | self.config = config 53 | self.iters = 0 54 | self.fixed_x = None 55 | self.fixed_y = None 56 | print("Using wandb") 57 | wandb.init(project='LayoutTransformer', name=args.exp) 58 | wandb.config.update(args) 59 | 60 | # take over whatever gpus are on the system 61 | self.device = 'cpu' 62 | if torch.cuda.is_available(): 63 | self.device = torch.cuda.current_device() 64 | self.model = torch.nn.DataParallel(self.model).to(self.device) 65 | 66 | def save_checkpoint(self): 67 | # DataParallel wrappers keep raw model object in .module attribute 68 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 69 | ckpt_path = os.path.join(self.config.ckpt_dir, 'checkpoint.pth') 70 | logger.info("saving %s", ckpt_path) 71 | torch.save(raw_model.state_dict(), ckpt_path) 72 | 73 | def train(self): 74 | model, config = self.model, self.config 75 | raw_model = model.module if hasattr(self.model, "module") else model 76 | optimizer = raw_model.configure_optimizers(config) 77 | pad_token = self.train_dataset.vocab_size - 1 78 | 79 | def run_epoch(split): 80 | is_train = split == 'train' 81 | model.train(is_train) 82 | data = self.train_dataset if is_train else self.test_dataset 83 | loader = DataLoader(data, shuffle=True, pin_memory=True, 84 | batch_size=config.batch_size, 85 | num_workers=config.num_workers) 
86 | 87 | losses = [] 88 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader) 89 | for it, (x, y) in pbar: 90 | 91 | if epoch == 0 and not is_train: 92 | self.fixed_x = x[:min(4, len(x))] 93 | self.fixed_y = y[:min(4, len(y))] 94 | 95 | # place data on the correct device 96 | x = x.to(self.device) 97 | y = y.to(self.device) 98 | 99 | # forward the model 100 | with torch.set_grad_enabled(is_train): 101 | # import ipdb; ipdb.set_trace() 102 | logits, loss = model(x, y, pad_token=pad_token) 103 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 104 | losses.append(loss.item()) 105 | 106 | if is_train: 107 | 108 | # backprop and update the parameters 109 | model.zero_grad() 110 | loss.backward() 111 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 112 | optimizer.step() 113 | self.iters += 1 114 | # decay the learning rate based on our progress 115 | if config.lr_decay: 116 | # self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) 117 | if self.iters < config.warmup_iters: 118 | # linear warmup 119 | lr_mult = float(self.iters) / float(max(1, config.warmup_iters)) 120 | else: 121 | # cosine learning rate decay 122 | progress = float(self.iters - config.warmup_iters) / float(max(1, config.final_iters - config.warmup_iters)) 123 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 124 | lr = config.learning_rate * lr_mult 125 | for param_group in optimizer.param_groups: 126 | param_group['lr'] = lr 127 | else: 128 | lr = config.learning_rate 129 | 130 | # report progress 131 | wandb.log({ 132 | 'train loss': loss.item(), 133 | 'lr': lr, 'epoch': epoch+1 134 | }, step=self.iters) 135 | pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. 
lr {lr:e}") 136 | 137 | if not is_train: 138 | test_loss = float(np.mean(losses)) 139 | logger.info("test loss: %f", test_loss) 140 | wandb.log({'test loss': test_loss}, step=self.iters) 141 | return test_loss 142 | 143 | best_loss = float('inf') 144 | for epoch in range(config.max_epochs): 145 | run_epoch('train') 146 | if self.test_dataset is not None: 147 | with torch.no_grad(): 148 | test_loss = run_epoch('test') 149 | 150 | # supports early stopping based on the test loss, or just save always if no test set is provided 151 | good_model = self.test_dataset is None or test_loss < best_loss 152 | if self.config.ckpt_dir is not None and good_model: 153 | best_loss = test_loss 154 | self.save_checkpoint() 155 | 156 | # sample from the model 157 | if self.config.samples_dir is not None and (epoch+1) % self.config.sample_every == 0: 158 | # import ipdb; ipdb.set_trace() 159 | # inputs 160 | layouts = self.fixed_x.detach().cpu().numpy() 161 | input_layouts = [self.train_dataset.render(layout) for layout in layouts] 162 | # for i, layout in enumerate(layouts): 163 | # layout = self.train_dataset.render(layout) 164 | # layout.save(os.path.join(self.config.samples_dir, f'input_{epoch:02d}_{i:02d}.png')) 165 | 166 | # reconstruction 167 | x_cond = self.fixed_x.to(self.device) 168 | logits, _ = model(x_cond) 169 | probs = F.softmax(logits, dim=-1) 170 | _, y = torch.topk(probs, k=1, dim=-1) 171 | layouts = torch.cat((x_cond[:, :1], y[:, :, 0]), dim=1).detach().cpu().numpy() 172 | recon_layouts = [self.train_dataset.render(layout) for layout in layouts] 173 | # for i, layout in enumerate(layouts): 174 | # layout = self.train_dataset.render(layout) 175 | # layout.save(os.path.join(self.config.samples_dir, f'recon_{epoch:02d}_{i:02d}.png')) 176 | 177 | # samples - random 178 | layouts = sample(model, x_cond[:, :6], steps=self.train_dataset.max_length, 179 | temperature=1.0, sample=True, top_k=5).detach().cpu().numpy() 180 | sample_random_layouts = [self.train_dataset.render(layout) for layout in layouts] 181 | # for i, layout in enumerate(layouts): 182 | # layout = self.train_dataset.render(layout) 183 | # layout.save(os.path.join(self.config.samples_dir, f'sample_random_{epoch:02d}_{i:02d}.png')) 184 | 185 | # samples - deterministic 186 | layouts = sample(model, x_cond[:, :6], steps=self.train_dataset.max_length, 187 | temperature=1.0, sample=False, top_k=None).detach().cpu().numpy() 188 | sample_det_layouts = [self.train_dataset.render(layout) for layout in layouts] 189 | # for i, layout in enumerate(layouts): 190 | # layout = self.train_dataset.render(layout) 191 | # layout.save(os.path.join(self.config.samples_dir, f'sample_det_{epoch:02d}_{i:02d}.png')) 192 | 193 | wandb.log({ 194 | "input_layouts": [wandb.Image(pil, caption=f'input_{epoch:02d}_{i:02d}.png') 195 | for i, pil in enumerate(input_layouts)], 196 | "recon_layouts": [wandb.Image(pil, caption=f'recon_{epoch:02d}_{i:02d}.png') 197 | for i, pil in enumerate(recon_layouts)], 198 | "sample_random_layouts": [wandb.Image(pil, caption=f'sample_random_{epoch:02d}_{i:02d}.png') 199 | for i, pil in enumerate(sample_random_layouts)], 200 | "sample_det_layouts": [wandb.Image(pil, caption=f'sample_det_{epoch:02d}_{i:02d}.png') 201 | for i, pil in enumerate(sample_det_layouts)], 202 | }, step=self.iters) 203 | -------------------------------------------------------------------------------- /layout_transformer/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 
import torch
4 | from torch.nn import functional as F
5 | import seaborn as sns
6 |
7 |
8 | def set_seed(seed):
9 |     random.seed(seed)
10 |     np.random.seed(seed)
11 |     torch.manual_seed(seed)
12 |     torch.cuda.manual_seed_all(seed)
13 |
14 |
15 | def top_k_logits(logits, k):
16 |     v, ix = torch.topk(logits, k)
17 |     out = logits.clone()
18 |     out[out < v[:, [-1]]] = -float('Inf')
19 |     return out
20 |
21 |
22 | def gen_colors(num_colors):
23 |     """
24 |     Generate uniformly distributed `num_colors` colors
25 |     :param num_colors:
26 |     :return:
27 |     """
28 |     palette = sns.color_palette(None, num_colors)
29 |     rgb_triples = [[int(x[0]*255), int(x[1]*255), int(x[2]*255)] for x in palette]
30 |     return rgb_triples
31 |
32 |
33 | @torch.no_grad()
34 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
35 |     """
36 |     take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
37 |     the sequence, feeding the predictions back into the model each time. Clearly the sampling
38 |     has quadratic complexity unlike an RNN that is only linear, and has a finite context window
39 |     of block_size, unlike an RNN that has an infinite context window.
40 |     """
41 |     block_size = model.module.get_block_size() if hasattr(model, "module") else model.get_block_size()
42 |     model.eval()
43 |     for k in range(steps):
44 |         x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
45 |         logits, _ = model(x_cond)
46 |         # pluck the logits at the final step and scale by temperature
47 |         logits = logits[:, -1, :] / temperature
48 |         # optionally crop probabilities to only the top k options
49 |         if top_k is not None:
50 |             logits = top_k_logits(logits, top_k)
51 |         # apply softmax to convert to probabilities
52 |         probs = F.softmax(logits, dim=-1)
53 |         # sample from the distribution or take the most likely
54 |         if sample:
55 |             ix = torch.multinomial(probs, num_samples=1)
56 |         else:
57 |             _, ix = torch.topk(probs, k=1, dim=-1)
58 |         # append to the sequence and continue
59 |         x = torch.cat((x, ix), dim=1)
60 |
61 |     return x
62 |
63 |
64 | def trim_tokens(tokens, bos, eos, pad=None):
65 |     bos_idx = np.where(tokens == bos)[0]
66 |     tokens = tokens[bos_idx[0]+1:] if len(bos_idx) > 0 else tokens
67 |     eos_idx = np.where(tokens == eos)[0]
68 |     tokens = tokens[:eos_idx[0]] if len(eos_idx) > 0 else tokens
69 |     # tokens = tokens[tokens != bos]
70 |     # tokens = tokens[tokens != eos]
71 |     if pad is not None:
72 |         tokens = tokens[tokens != pad]
73 |     return tokens
74 |
75 |
--------------------------------------------------------------------------------
/layout_vae/box.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from count import LabelSetEncoder, SingleLabelEncoder
6 |
7 |
8 | class BoxHistoryMemory(nn.Module):
9 |     def __init__(self, number_labels):
10 |         super(BoxHistoryMemory, self).__init__()
11 |
12 |         self.memory = nn.LSTM(number_labels + 4, 128)
13 |
14 |     def forward(self, labels, boxes, state=None):
15 |         # I assume if there is no history, we just return the zero vector?
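        # (the 128-dim zero vector matches the LSTM hidden size, so the concatenation
        # inside BoxConditioningMLP still lines up when there is no box history yet)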
16 |         device = labels.device
17 |         batch_size = labels.size(0)
18 |         history_length = labels.size(1)
19 |         if history_length == 0:
20 |             return torch.zeros((batch_size, 128)).to(device), None
21 |
22 |         # concatenate labels and boxes, then make the input sequence-first (L x N x F) for the LSTM
23 |         input = torch.cat((labels, boxes), dim=-1).permute(1, 0, 2)
24 |
25 |         output, state = self.memory(input, state)
26 |         output = output[-1, :]
27 |
28 |         return output, state
29 |
30 |
31 | class BoxConditioningMLP(nn.Module):
32 |     def __init__(self, number_labels):
33 |         super(BoxConditioningMLP, self).__init__()
34 |
35 |         self.encode_label_set = LabelSetEncoder(number_labels)
36 |         self.encode_single_label = SingleLabelEncoder(number_labels)
37 |         self.encode_box = BoxHistoryMemory(number_labels)
38 |
39 |         self.fc = nn.Linear(128 * 3, 128)
40 |
41 |     # the LayoutVAE paper seems to indicate they teacher-force
42 |     # at evaluation time... I can't imagine that is correct?
43 |     def forward(self, label_set, current_label, labels_so_far, boxes_so_far, state=None):
44 |         label_set = self.encode_label_set(label_set)
45 |         current_label = self.encode_single_label(current_label)
46 |         boxes_so_far, state = self.encode_box(labels_so_far, boxes_so_far, state)
47 |
48 |         aggregate = torch.cat((label_set, current_label, boxes_so_far), dim=-1)
49 |         aggregate = self.fc(aggregate)
50 |
51 |         return aggregate, state
52 |
53 |
54 | class BoxInputEncoder(nn.Module):
55 |     def __init__(self):
56 |         super(BoxInputEncoder, self).__init__()
57 |
58 |         self.fc1 = nn.Linear(4, 128)
59 |         self.fc2 = nn.Linear(128, 128)
60 |
61 |     def forward(self, x):
62 |         x = F.relu(self.fc1(x))
63 |         x = self.fc2(x)
64 |
65 |         return x
66 |
67 |
68 | class AutoregressiveBoxEncoder(nn.Module):
69 |     def __init__(self, number_labels, conditioning_size, representation_size=32):
70 |         super(AutoregressiveBoxEncoder, self).__init__()
71 |
72 |         self.number_labels = number_labels
73 |
74 |         self.input_encoder = BoxInputEncoder()
75 |         self.conditioning = BoxConditioningMLP(self.number_labels)
76 |
77 |         self.fc = nn.Linear(128 + conditioning_size, representation_size)
78 |         self.project_mu = nn.Linear(representation_size, representation_size)
79 |         self.project_s = nn.Linear(representation_size, representation_size)
80 |
81 |     # x is the box to be encoded.
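    # label_set / current_label mirror the count model in count.py; labels_so_far and
    # boxes_so_far are the already-placed elements, summarized by the LSTM in BoxHistoryMemory.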
82 | def forward(self, x, label_set, current_label, labels_so_far, boxes_so_far, state=None): 83 | x = self.input_encoder(x) 84 | condition, state = self.conditioning(label_set, current_label, labels_so_far, boxes_so_far, state=state) 85 | 86 | x = torch.cat((x, condition), dim=-1) 87 | x = F.relu(self.fc(x)) 88 | 89 | mu = self.project_mu(x) 90 | s = self.project_s(x) 91 | 92 | return mu, s, condition, state 93 | 94 | 95 | class AutoregressiveBoxDecoder(nn.Module): 96 | def __init__(self, conditioning_size, representation_size=32): 97 | super(AutoregressiveBoxDecoder, self).__init__() 98 | 99 | self.fc1 = nn.Linear(conditioning_size + representation_size, 2048) 100 | self.fc2 = nn.Linear(2048, 2048) 101 | self.fc3 = nn.Linear(2048, 512) 102 | self.project = nn.Linear(512, 4) 103 | 104 | self.actvn = nn.LeakyReLU(0.2, False) 105 | 106 | def forward(self, z, condition): 107 | x = torch.cat((z, condition), dim=-1) 108 | 109 | x = self.actvn(self.fc1(x)) 110 | x = self.actvn(self.fc2(x)) 111 | x = self.actvn(self.fc3(x)) 112 | x = F.sigmoid(self.project(x)) 113 | 114 | return x 115 | -------------------------------------------------------------------------------- /layout_vae/count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSetEncoder(nn.Module): 7 | def __init__(self, number_labels): 8 | super(LabelSetEncoder, self).__init__() 9 | 10 | self.number_labels = number_labels 11 | self.fc1 = nn.Linear(self.number_labels, 128) 12 | self.fc2 = nn.Linear(128, 128) 13 | 14 | def forward(self, x): 15 | x = F.relu(self.fc1(x)) 16 | x = F.relu(self.fc2(x)) 17 | 18 | return x 19 | 20 | 21 | class SingleLabelEncoder(nn.Module): 22 | def __init__(self, number_labels): 23 | super(SingleLabelEncoder, self).__init__() 24 | 25 | self.number_labels = number_labels 26 | self.fc1 = nn.Linear(self.number_labels, 128) 27 | self.fc2 = nn.Linear(128, 128) 28 | 29 | def forward(self, x): 30 | x = F.relu(self.fc1(x)) 31 | x = F.relu(self.fc2(x)) 32 | 33 | return x 34 | 35 | 36 | class CountsEncoder(nn.Module): 37 | def __init__(self, number_labels): 38 | super(CountsEncoder, self).__init__() 39 | 40 | self.number_labels = number_labels 41 | self.fc1 = nn.Linear(self.number_labels, 128) 42 | self.fc2 = nn.Linear(128, 128) 43 | 44 | def forward(self, x): 45 | x = F.relu(self.fc1(x)) 46 | x = F.relu(self.fc2(x)) 47 | 48 | return x 49 | 50 | 51 | class CountConditioningMLP(nn.Module): 52 | def __init__(self, number_labels): 53 | super(CountConditioningMLP, self).__init__() 54 | 55 | self.encode_label_set = LabelSetEncoder(number_labels) 56 | self.encode_single_label = SingleLabelEncoder(number_labels) 57 | self.encode_counts = CountsEncoder(number_labels) 58 | 59 | self.fc = nn.Linear(128 * 3, 128) 60 | 61 | # the LayoutVAE paper seems to indicate they teacher-force 62 | # at evaluation time... I can't imagine that is correct? 
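    # label_set: multi-hot vector of the labels present in the layout; current_label: indicator
    # of the label whose count is being predicted; count_so_far: per-label counts accumulated
    # from the previous autoregressive steps (ground truth when teacher-forcing).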
63 | def forward(self, label_set, current_label, count_so_far): 64 | label_set = self.encode_label_set(label_set) 65 | current_label = self.encode_single_label(current_label) 66 | count_so_far = self.encode_counts(count_so_far) 67 | 68 | aggregate = torch.cat((label_set, current_label, count_so_far), dim=-1) 69 | aggregate = self.fc(aggregate) 70 | 71 | return aggregate 72 | 73 | 74 | class CountInputEncoder(nn.Module): 75 | def __init__(self): 76 | super(CountInputEncoder, self).__init__() 77 | 78 | self.fc1 = nn.Linear(1, 128) 79 | self.fc2 = nn.Linear(128, 128) 80 | 81 | def forward(self, x): 82 | x = F.relu(self.fc1(x)) 83 | x = self.fc2(x) 84 | 85 | return x 86 | 87 | 88 | class AutoregressiveCountEncoder(nn.Module): 89 | def __init__(self, number_labels, conditioning_size, representation_size=32): 90 | super(AutoregressiveCountEncoder, self).__init__() 91 | 92 | self.number_labels = number_labels 93 | 94 | self.input_encoder = CountInputEncoder() 95 | self.conditioning = CountConditioningMLP(self.number_labels) 96 | 97 | self.fc = nn.Linear(128 + conditioning_size, representation_size) 98 | self.project_mu = nn.Linear(representation_size, representation_size) 99 | self.project_s = nn.Linear(representation_size, representation_size) 100 | 101 | # x is the count to be encoded. 102 | def forward(self, x, label_set, current_label, count_so_far): 103 | x = self.input_encoder(x) 104 | condition = self.conditioning(label_set, current_label, count_so_far) 105 | 106 | x = torch.cat((x, condition), dim=-1) 107 | x = F.relu(self.fc(x)) 108 | 109 | mu = self.project_mu(x) 110 | s = self.project_s(x) 111 | 112 | return mu, s, condition 113 | 114 | 115 | class AutoregressiveCountDecoder(nn.Module): 116 | def __init__(self, conditioning_size, representation_size=32): 117 | super(AutoregressiveCountDecoder, self).__init__() 118 | 119 | self.fc1 = nn.Linear(conditioning_size + representation_size, 128) 120 | self.fc2 = nn.Linear(128, 64) 121 | self.project = nn.Linear(64, 1) 122 | 123 | def forward(self, z, condition): 124 | x = torch.cat((z, condition), dim=-1) 125 | 126 | x = F.relu(self.fc1(x)) 127 | x = F.relu(self.fc2(x)) 128 | 129 | # note, we are returning log(lambda) 130 | x = self.project(x) 131 | 132 | return x 133 | -------------------------------------------------------------------------------- /layout_vae/layout.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BatchCollator(object): 8 | def __call__(self, batch): 9 | transposed_batch = list(zip(*batch)) 10 | indexes = transposed_batch[0] 11 | targets = transposed_batch[1] 12 | 13 | return indexes, targets 14 | 15 | 16 | class TargetLayout(object): 17 | def __init__(self, label_set, count, bbox, label, width, height, annotation_id, permutation, image_id): 18 | device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu") 19 | 20 | self.label_set = torch.as_tensor(label_set, dtype=torch.float32, device=device) 21 | self.count = torch.as_tensor(count, dtype=torch.float32, device=device) 22 | self.bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device) 23 | self.label = torch.as_tensor(label, dtype=torch.float32, device=device) 24 | self.width = width 25 | self.height = height 26 | self.annotation_id = torch.as_tensor(annotation_id, device=device) 27 | self.permutation = torch.as_tensor(permutation, device=device) 28 | self.image_id = image_id 29 | 30 | def to(self, 
device): 31 | result = TargetLayout( 32 | self.label_set.to(device), 33 | self.count.to(device), 34 | self.bbox.to(device), 35 | self.label.to(device), 36 | self.width, 37 | self.height, 38 | self.annotation_id.to(device), 39 | self.permutation.to(device), 40 | self.image_id) 41 | 42 | return result 43 | 44 | def __len__(self): 45 | return self.bbox.shape[0] 46 | 47 | 48 | class LayoutDataset(Dataset): 49 | def __init__(self, annotations_path, max_length=128): 50 | super(LayoutDataset, self).__init__() 51 | self.max_length = max_length 52 | self.annotations_path = annotations_path 53 | 54 | # load annotations. 55 | with open(self.annotations_path, "r") as f: 56 | self.data = json.load(f) 57 | 58 | self.categories = {c["id"]: c for c in self.data["categories"]} 59 | self.number_labels = len(self.categories) 60 | print("label set size: {0}".format(self.number_labels)) 61 | 62 | self.json_category_id_to_contiguous_id = { 63 | v: i + 1 for i, v in enumerate([c["id"] for c in self.categories.values()]) 64 | } 65 | 66 | self.contiguous_category_id_to_json_id = { 67 | v: k for k, v in self.json_category_id_to_contiguous_id.items() 68 | } 69 | 70 | self.image_to_annotations = {} 71 | for annotation in self.data["annotations"]: 72 | image_id = annotation["image_id"] 73 | 74 | if not (image_id in self.image_to_annotations): 75 | self.image_to_annotations[image_id] = [] 76 | 77 | self.image_to_annotations[image_id].append(annotation) 78 | 79 | label_sets = [] 80 | counts = [] 81 | boxes = [] 82 | labels = [] 83 | annotation_ids = [] 84 | widths = [] 85 | heights = [] 86 | image_ids = [] 87 | permutations = [] 88 | 89 | self.images = [] 90 | self.annotations = [] 91 | 92 | for image in self.data["images"]: 93 | image_id = image["id"] 94 | height, width = float(image["height"]), float(image["width"]) 95 | 96 | if image_id not in self.image_to_annotations: 97 | continue 98 | 99 | annotations = self.image_to_annotations[image_id] 100 | 101 | if (self.max_length is not None) and (len(annotations) > self.max_length): 102 | annotations = annotations[:self.max_length] 103 | 104 | # hack. 105 | for i, annotation in enumerate(annotations): 106 | annotation["index"] = i 107 | 108 | # sort the annotations left to right with labels (smallest first). 109 | sorted_annotations = [] 110 | for label_index in range(self.number_labels): 111 | category_id = self.contiguous_category_id_to_json_id[label_index + 1] 112 | annotations_of_label = [a for a in annotations if a["category_id"] == category_id] 113 | annotations_of_label = list(sorted(annotations_of_label, key=lambda a: a["bbox"][0])) 114 | sorted_annotations += annotations_of_label 115 | 116 | self.annotations.append(sorted_annotations) 117 | 118 | label_set = np.zeros((self.number_labels,)).astype(np.uint8) 119 | count = np.zeros((self.number_labels,)).astype(np.uint8) 120 | box = np.zeros((len(sorted_annotations), 4)) 121 | label = np.zeros((len(sorted_annotations),)) 122 | annotation_id = np.zeros((len(sorted_annotations),)) 123 | 124 | for annotation_index, annotation in enumerate(sorted_annotations): 125 | contiguous_id = self.json_category_id_to_contiguous_id[annotation["category_id"]] 126 | label_set[contiguous_id - 1] = 1 127 | count[contiguous_id - 1] += 1 128 | x, y, w, h = annotation["bbox"] 129 | 130 | # a good question is if we should divide by the long edge only. 
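# The assignment below normalizes each coordinate by its own axis: x and w are divided by
# the image width, y and h by the image height, so every box lands in [0, 1]^4 regardless
# of aspect ratio. A hedged alternative (raised by the comment above but not used here)
# would be to divide all four values by the longer edge so aspect ratio is preserved, e.g.:
#   long_edge = max(width, height)
#   box[annotation_index] = np.array([x, y, w, h]) / long_edge
# The per-axis version is what the plotting code in train_layouts.py assumes, so it is kept.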
131 | box[annotation_index] = np.array([x / width, y / height, w / width, h / height]) 132 | label[annotation_index] = contiguous_id 133 | annotation_id[annotation_index] = annotation["id"] 134 | 135 | permutation = np.array([a["index"] for a in sorted_annotations]).astype(np.int) 136 | 137 | label_sets.append(label_set) 138 | counts.append(count) 139 | boxes.append(box) 140 | labels.append(label) 141 | widths.append(width) 142 | heights.append(height) 143 | annotation_ids.append(annotation_id) 144 | image_ids.append(image_id) 145 | permutations.append(permutation) 146 | self.images.append(image) 147 | 148 | self.label_sets = np.stack(label_sets, axis=0) 149 | self.counts = np.stack(counts, axis=0) 150 | self.boxes = boxes 151 | self.labels = labels 152 | self.widths = widths 153 | self.heights = heights 154 | self.annotation_ids = annotation_ids 155 | self.image_ids = image_ids 156 | self.permutations = permutations 157 | 158 | print("{0} images retained".format(len(self))) 159 | 160 | def __len__(self): 161 | return self.counts.shape[0] 162 | 163 | def __getitem__(self, index): 164 | # image_data = self.images[index] 165 | 166 | label_set = torch.from_numpy(self.label_sets[index]) 167 | count = torch.from_numpy(self.counts[index]) 168 | box = torch.from_numpy(self.boxes[index]) 169 | label = torch.from_numpy(self.labels[index]) 170 | width = self.widths[index] 171 | height = self.heights[index] 172 | annotation_id = torch.from_numpy(self.annotation_ids[index]) 173 | image_id = self.image_ids[index] 174 | permutation = self.permutations[index] 175 | 176 | target = TargetLayout(label_set, count, box, label, width, height, annotation_id, permutation, image_id) 177 | 178 | return index, target 179 | -------------------------------------------------------------------------------- /layout_vae/train_counts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import argparse 5 | from datetime import datetime 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.optim as optim 11 | import torch.utils.data 12 | 13 | from count import AutoregressiveCountEncoder, AutoregressiveCountDecoder 14 | from layout import BatchCollator, LayoutDataset 15 | 16 | 17 | def evaluate(model, loader, loss): 18 | errors = [] 19 | model.eval() 20 | losses = None 21 | count_losses = [] 22 | divergence_losses = [] 23 | 24 | for batch_i, (indexes, target) in tqdm(enumerate(loader)): 25 | label_set = torch.stack([t.label_set for t in target], dim=0).float().to(device) 26 | counts = torch.stack([t.count for t in target], dim=0).float().to(device) 27 | batch_size = label_set.size(0) 28 | 29 | label_set_size = torch.sum(label_set > 0, dim=1).float() 30 | 31 | # depending on forcing, this can be prediction or ground truth. 
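# Autoregressive evaluation: labels are visited in a fixed order 0 .. NUMBER_LABELS - 1.
# At step label_i the model is conditioned on (label_set, one-hot of label_i, previous_counts),
# where previous_counts holds the counts of labels already visited. Below, previous_counts is
# overwritten with the ground-truth counts through a mask (teacher forcing), while
# predicted_counts keeps the model's own rate-based predictions for the per-class error report.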
32 | previous_counts = torch.zeros((batch_size, NUMBER_LABELS)).to(device) 33 | predicted_counts = torch.zeros((batch_size, NUMBER_LABELS)).to(device) 34 | 35 | batch_errors = torch.zeros((NUMBER_LABELS,), device=device).to(device) 36 | for label_i in range(NUMBER_LABELS): 37 | current_count = counts[:, label_i] 38 | has_nonzero = current_count > 0 39 | nonzero_batch_size = has_nonzero.nonzero().size(0) 40 | 41 | if nonzero_batch_size > 0: 42 | current_label_set = label_set[has_nonzero, :] 43 | current_previous_counts = previous_counts[has_nonzero, :] 44 | current_count = current_count[has_nonzero].unsqueeze(-1) #.unsqueeze(-1) 45 | current_label = label_encodings[label_i].unsqueeze(0).repeat(nonzero_batch_size, 1) 46 | 47 | log_rate, kl_divergence, _ = model(current_count, current_label_set, current_label, current_previous_counts) 48 | 49 | current_loss = loss(log_rate, current_count) 50 | losses = current_loss if losses is None else torch.cat([losses, current_loss]) 51 | 52 | count_loss_i = count_loss(log_rate, current_count) 53 | # print(count_loss_i) 54 | count_losses.append(count_loss_i.reshape(-1)) 55 | divergence_losses.append(kl_divergence.reshape(-1)) 56 | 57 | # # predict the counts (try the lame way for now) 58 | # predicted_count = [] 59 | 60 | # for nz_i in range(nonzero_batch_size): 61 | # rate_i = torch.exp(log_rate[nz_i]) 62 | # dist = Poisson(rate_i) 63 | # predicted_count_i = dist.sample((1,))[0] + 1 64 | # predicted_count.append(predicted_count_i) 65 | 66 | # predicted_count = torch.cat(predicted_count, dim=0) 67 | # I think technically this needs to take care of the integer case. 68 | predicted_count = torch.floor(torch.exp(torch.squeeze(log_rate, dim=-1))) + 1 69 | predicted_counts[has_nonzero, label_i] = predicted_count 70 | 71 | batch_errors[label_i] = torch.mean(torch.abs(torch.squeeze(current_count, dim=-1) - predicted_count)) 72 | 73 | # teacher forcing when evaluating reconstructions? 74 | previous_counts_mask = torch.cat(( 75 | torch.ones((batch_size, label_i + 1), device=device), 76 | torch.zeros((batch_size, NUMBER_LABELS - label_i - 1), 77 | device=device)), dim=-1) 78 | previous_counts = previous_counts_mask * counts 79 | 80 | errors.append(batch_errors) 81 | 82 | errors = torch.stack(errors, dim=0) 83 | average_error = torch.mean(errors, dim=0) 84 | average_loss = torch.mean(losses) 85 | print(f"validation: average error per class: {average_error}") 86 | print(f"validation: average loss: {average_loss}") 87 | count_losses = torch.cat(count_losses) 88 | divergence_losses = torch.cat(divergence_losses) 89 | loss_epoch = torch.mean(count_losses) + torch.mean(divergence_losses) 90 | 91 | return loss_epoch.item() 92 | 93 | 94 | class PoissonLogLikelihood(nn.Module): 95 | def __init__(self): 96 | super(PoissonLogLikelihood, self).__init__() 97 | 98 | def forward(self, log_rate, count): 99 | # learned over count - 1? 
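# Shifted Poisson negative log-likelihood. The callers mask out zero counts (only labels
# with count > 0 reach this loss), so the model is trained on k = count - 1, which can
# legitimately be 0. With rate lambda = exp(log_rate):
#   log P(k) = -lambda + k * log(lambda) - log(k!)
# and torch.lgamma(k + 1) == log(k!). The decoder emits log(lambda) rather than lambda so
# the rate stays positive without an explicit activation, hence torch.exp(log_rate) below.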
100 | count = count - 1 101 | log_factorial = torch.lgamma(count + 1) 102 | log_likelihood = -torch.exp(log_rate) + count * log_rate - log_factorial 103 | 104 | # I assume this will be like N x [max # labels] 105 | return -log_likelihood 106 | 107 | 108 | class AutoregressiveVariationalAutoencoder(nn.Module): 109 | def __init__(self, number_labels, conditioning_size, representation_size): 110 | super(AutoregressiveVariationalAutoencoder, self).__init__() 111 | 112 | self.representation_size = representation_size 113 | 114 | self.encoder = AutoregressiveCountEncoder(number_labels, conditioning_size, representation_size) 115 | self.decoder = AutoregressiveCountDecoder(conditioning_size, representation_size) 116 | 117 | def sample(self, mu, log_var): 118 | batch_size = mu.size(0) 119 | device = mu.device 120 | 121 | standard_normal = torch.randn((batch_size, self.representation_size), device=device) 122 | z = mu + standard_normal * torch.exp(0.5 * log_var) 123 | 124 | kl_divergence = -0.5 * torch.sum( 125 | 1 + log_var - (mu ** 2) - torch.exp(log_var), dim=1) 126 | 127 | return z, kl_divergence 128 | 129 | def forward(self, x, label_set, current_label, count_so_far): 130 | mu, s, condition = self.encoder(x, label_set, current_label, count_so_far) 131 | z, kl_divergence = self.sample(mu, s) 132 | log_rate = self.decoder(z, condition) 133 | 134 | return log_rate, kl_divergence, z 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser('Count VAE') 139 | parser.add_argument("--exp", default="count_vae", help="postfix for experiment name") 140 | parser.add_argument("--log_dir", default="./logs", help="/path/to/logs/dir") 141 | parser.add_argument("--train_json", default="./instances_train.json", help="/path/to/train/json") 142 | parser.add_argument("--val_json", default="./instances_val.json", help="/path/to/val/json") 143 | 144 | parser.add_argument("--max_length", type=int, default=128, help="batch size") 145 | 146 | parser.add_argument("--seed", type=int, default=42, help="random seed") 147 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs") 148 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 149 | parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") 150 | parser.add_argument("--beta_1", type=float, default=0.9, help="beta_1 for adam") 151 | parser.add_argument('--evaluate', action='store_true', help="evaluate only") 152 | parser.add_argument('--save_every', type=int, default=10, help="evaluate only") 153 | 154 | args = parser.parse_args() 155 | 156 | if not args.evaluate: 157 | now = datetime.now().strftime("%m%d%y_%H%M%S") 158 | log_dir = os.path.join(args.log_dir, f"{now}_{args.exp}") 159 | ckpt_dir = os.path.join(log_dir, "checkpoints") 160 | os.makedirs(ckpt_dir, exist_ok=True) 161 | else: 162 | log_dir = args.log_dir 163 | ckpt_dir = os.path.join(log_dir, "checkpoints") 164 | 165 | random.seed(args.seed) 166 | torch.manual_seed(args.seed) 167 | 168 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 169 | print(f"using device: {device}") 170 | 171 | collator = BatchCollator() 172 | train_dataset = LayoutDataset(args.train_json, args.max_length) 173 | train_loader = torch.utils.data.DataLoader( 174 | train_dataset, 175 | batch_size=args.batch_size, 176 | shuffle=True, 177 | num_workers=0, 178 | collate_fn=collator) 179 | 180 | validation_dataset = LayoutDataset(args.val_json, args.max_length) 181 | validation_loader = torch.utils.data.DataLoader( 182 | 
validation_dataset, 183 | batch_size=args.batch_size, 184 | shuffle=False, 185 | num_workers=0, 186 | collate_fn=collator) 187 | 188 | NUMBER_LABELS = train_dataset.number_labels 189 | 190 | label_encodings = torch.eye(NUMBER_LABELS).float().to(device) 191 | count_loss = PoissonLogLikelihood().to(device) 192 | 193 | autoencoder = AutoregressiveVariationalAutoencoder( 194 | NUMBER_LABELS, 195 | conditioning_size=128, 196 | representation_size=32).to(device) 197 | 198 | # evaluate the model 199 | if args.evaluate: 200 | min_epoch = -1 201 | min_loss = 1e100 202 | for epoch in range(args.epochs): 203 | checkpoint_path = os.path.join(log_dir, "checkpoints", 'epoch_%d.pth' % epoch) 204 | if not os.path.exists(checkpoint_path): 205 | continue 206 | print('Evaluating', checkpoint_path) 207 | checkpoint = torch.load(checkpoint_path) 208 | autoencoder.load_state_dict(checkpoint["model_state_dict"], strict=True) 209 | loss = evaluate(autoencoder, validation_loader, count_loss) 210 | print('End of epoch %d : %f' % (epoch, loss)) 211 | if loss < min_loss: 212 | min_loss = loss 213 | min_epoch = epoch 214 | print('Best epoch: %d Best nll: %f' % (min_epoch, min_loss)) 215 | sys.exit(0) 216 | 217 | opt = optim.Adam(autoencoder.parameters(), lr=args.lr, betas=(args.beta_1, 0.999)) 218 | epoch_number = 0 219 | while True: 220 | if (epoch_number > 0) and (epoch_number == args.epochs): 221 | print("done!") 222 | break 223 | 224 | print(f"starting epoch {epoch_number+1}") 225 | autoencoder.train() 226 | 227 | with tqdm(enumerate(train_loader)) as tq: 228 | for batch_i, (indexes, target) in tq: 229 | autoencoder.zero_grad() 230 | count_loss.zero_grad() 231 | 232 | label_set = torch.stack([t.label_set for t in target], dim=0).float().to(device) 233 | counts = torch.stack([t.count for t in target], dim=0).float().to(device) 234 | 235 | batch_size = label_set.size(0) 236 | 237 | label_set_size = torch.sum(label_set > 0, dim=1).float() 238 | 239 | count_losses = [] 240 | divergence_losses = [] 241 | 242 | previous_counts = torch.zeros((batch_size, NUMBER_LABELS)).to(device) 243 | 244 | for label_i in range(NUMBER_LABELS): 245 | current_count_loss = torch.zeros((batch_size,)).to(device) 246 | current_divergence_loss = torch.zeros((batch_size,)).to(device) 247 | 248 | current_count = counts[:, label_i] 249 | has_nonzero = current_count > 0 250 | nonzero_batch_size = has_nonzero.nonzero().size(0) 251 | if nonzero_batch_size > 0: 252 | current_label_set = label_set[has_nonzero, :] 253 | current_previous_counts = previous_counts[has_nonzero, :] 254 | current_count = current_count[has_nonzero].unsqueeze(-1) # .unsqueeze(-1) 255 | current_label = label_encodings[label_i].unsqueeze(0).repeat(nonzero_batch_size, 1) 256 | 257 | log_rate, kl_divergence, z = autoencoder(current_count, current_label_set, current_label, 258 | current_previous_counts) 259 | count_loss_i = count_loss(log_rate, current_count) 260 | current_count_loss[has_nonzero] = count_loss_i[:, 0] 261 | count_losses.append(current_count_loss) 262 | 263 | current_divergence_loss[has_nonzero] = kl_divergence 264 | divergence_losses.append(current_divergence_loss) 265 | # unsure if we do backward() here? 266 | 267 | # teacher forcing! 
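# Teacher forcing: rather than feeding back the model's own predictions, the conditioning
# vector for the next step exposes the ground-truth counts of every label processed so far
# and zeros out the rest. For example, with NUMBER_LABELS = 4 and label_i = 1 the mask is
# [1, 1, 0, 0] per row, so previous_counts becomes [count_0, count_1, 0, 0]. At sampling
# time one would feed back sampled counts instead.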
268 | previous_counts_mask = torch.cat(( 269 | torch.ones((batch_size, label_i + 1), device=device), 270 | torch.zeros((batch_size, NUMBER_LABELS - label_i - 1), device=device)), dim=-1) 271 | previous_counts = previous_counts_mask * counts 272 | 273 | count_losses = torch.stack(count_losses, dim=-1) 274 | count_loss_batch = torch.mean(torch.sum(count_losses, dim=-1) / label_set_size) 275 | 276 | divergence_losses = torch.stack(divergence_losses, dim=-1) 277 | divergence_loss_batch = torch.mean(torch.sum(divergence_losses, dim=-1) / label_set_size) 278 | 279 | loss_batch = count_loss_batch + 0.01 * divergence_loss_batch 280 | loss_batch.backward() 281 | opt.step() 282 | 283 | tq.set_description(f"{epoch_number+1}/{args.epochs} count_loss: {count_loss_batch.item()} " 284 | f"divergence_loss: {divergence_loss_batch.item()}") 285 | 286 | evaluate(autoencoder, validation_loader, count_loss) 287 | # print(f"validation loss [{epoch_number}/{args.epochs}: {validation_loss.item():4f}") 288 | 289 | # write out a checkpoint too. 290 | if (epoch_number + 1) % args.save_every == 0: 291 | torch.save({ 292 | "epoch": epoch_number, 293 | "model_state_dict": autoencoder.state_dict(), 294 | }, os.path.join(ckpt_dir, "epoch_{0}.pth".format(epoch_number))) 295 | 296 | epoch_number += 1 297 | -------------------------------------------------------------------------------- /layout_vae/train_layouts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import random 5 | import sys 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.optim as optim 11 | import torch.utils.data 12 | from datetime import datetime 13 | from PIL import Image, ImageDraw 14 | import seaborn as sns 15 | 16 | from box import AutoregressiveBoxEncoder, AutoregressiveBoxDecoder 17 | from layout import BatchCollator, LayoutDataset 18 | 19 | 20 | def gen_colors(num_colors): 21 | """ 22 | Generate uniformly distributed `num_colors` colors 23 | :param num_colors: 24 | :return: 25 | """ 26 | palette = sns.color_palette(None, num_colors) 27 | rgb_triples = [[int(x[0]*255), int(x[1]*255), int(x[2]*255)] for x in palette] 28 | return rgb_triples 29 | 30 | 31 | def plot_layout(real_boxes, predicted_boxes, labels, width, height, colors=None): 32 | blank_image = Image.new("RGB", (int(width), int(height)), (255, 255, 255)) 33 | blank_draw = ImageDraw.Draw(blank_image) 34 | 35 | number_boxes = real_boxes.shape[0] 36 | for i in range(number_boxes): 37 | real_box = real_boxes[i].tolist() 38 | predicted_box = predicted_boxes[i].tolist() 39 | label = int(labels[i]) 40 | 41 | real_x1, real_y1 = int(real_box[0] * width), int(real_box[1] * height) 42 | real_x2, real_y2 = real_x1 + int(real_box[2] * width), real_y1 + int(real_box[3] * height) 43 | 44 | predicted_x1, predicted_y1 = int(predicted_box[0] * width), int(predicted_box[1] * height) 45 | predicted_x2, predicted_y2 = predicted_x1 + int(predicted_box[2] * width), predicted_y1 + int( 46 | predicted_box[3] * height) 47 | 48 | real_color = (0, 0, 0) 49 | if colors is not None: 50 | real_color = tuple(colors[label]) 51 | 52 | blank_draw.rectangle([(real_x1, real_y1), (real_x2, real_y2)], outline=real_color) 53 | blank_draw.rectangle([(predicted_x1, predicted_y1), (predicted_x2, predicted_y2)], outline=(0, 0, 0)) 54 | 55 | return blank_image 56 | 57 | 58 | def evaluate(model, loader, loss, prefix='', colors=None): 59 | errors = [] 60 | model.eval() 61 | 
losses = None 62 | box_losses = [] 63 | divergence_losses = [] 64 | 65 | for batch_i, (indexes, target) in tqdm(enumerate(loader)): 66 | label_set = torch.stack([t.label_set for t in target], dim=0).to(device) 67 | counts = torch.stack([t.count for t in target], dim=0).to(device) 68 | boxes = [t.bbox.to(device) for t in target] 69 | labels = [t.label.to(device) for t in target] 70 | number_boxes = np.stack([len(t) for t in target], axis=0) 71 | max_number_boxes = np.max(number_boxes) 72 | batch_size = label_set.size(0) 73 | 74 | predicted_boxes = torch.zeros((batch_size, max_number_boxes, 4)).to(device) 75 | # import ipdb; ipdb.set_trace() 76 | for step in range(max_number_boxes): 77 | # determine who has a box. 78 | has_box = number_boxes > step 79 | 80 | # determine their history of box/labels. 81 | current_label_set = label_set[has_box, :] 82 | current_counts = counts[has_box, :] 83 | 84 | all_boxes = [boxes[i] for i, has in enumerate(has_box) if has] 85 | all_labels = [labels[i] for i, has in enumerate(has_box) if has] 86 | current_label = torch.stack([l[step] for l in all_labels], dim=0).to(device) 87 | current_label = label_encodings[current_label.long() - 1] 88 | current_box = torch.stack([b[step] for b in all_boxes], dim=0).to(device) 89 | 90 | # now, consider the history. 91 | if step == 0: 92 | previous_labels = torch.zeros((batch_size, 0, 7)).to(device) 93 | previous_boxes = torch.zeros((batch_size, 0, 4)).to(device) 94 | else: 95 | previous_labels = torch.stack([l[step - 1] for l in all_labels], dim=0).unsqueeze(1) 96 | previous_labels = label_encodings[previous_labels.long() - 1] 97 | 98 | # we need to 1-hot these. only take the previous one since 99 | # we'll accumulate state instead. 100 | previous_boxes = torch.stack([b[step - 1] for b in all_boxes], dim=0).unsqueeze(1) 101 | 102 | # take a step. x, label_set, current_label, count_so_far): 103 | state = (h[has_box].unsqueeze(0), c[has_box].unsqueeze(0)) if step > 1 else None 104 | predicted_boxes_step, kl_divergence, z, state = model(current_box, current_label_set, current_label, 105 | previous_labels, previous_boxes, state=state) 106 | predicted_boxes[has_box, step] = predicted_boxes_step 107 | 108 | box_loss_step = loss(predicted_boxes_step, current_box) 109 | losses = box_loss_step if losses is None else torch.cat([losses, box_loss_step]) 110 | 111 | box_losses.append(box_loss_step.reshape(-1)) 112 | divergence_losses.append(kl_divergence.reshape(-1)) 113 | 114 | if state is not None: 115 | h, c = torch.zeros((batch_size, 128)).to(device), torch.zeros((batch_size, 128)).to(device) 116 | h[has_box, :] = state[0][-1] 117 | c[has_box, :] = state[1][-1] 118 | 119 | if batch_i == 0 and colors is not None: 120 | # try plotting the first batch. 
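# Plotting the first validation batch: plot_layout draws the ground-truth boxes in their
# per-class color and the reconstructed boxes in black on a white canvas of the original
# image size; one PNG per layout is written as f"{prefix}_{i:05d}.png", where prefix points
# into the run's samples directory during training (see the evaluate() call near the bottom).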
121 | for i in range(batch_size): 122 | count = number_boxes[i] 123 | plotted = plot_layout( 124 | boxes[i].detach().cpu().numpy(), 125 | predicted_boxes[i, :count], 126 | labels[i].detach().cpu().numpy()-1, 127 | target[i].width, 128 | target[i].height, 129 | colors=colors) 130 | 131 | plotted.save(f"{prefix}_{i:05d}.png") 132 | 133 | # pdb.set_trace() 134 | average_loss = torch.mean(losses) 135 | print(f"validation: average loss: {average_loss}") 136 | count_losses = torch.cat(box_losses) 137 | divergence_losses = torch.cat(divergence_losses) 138 | loss_epoch = torch.mean(count_losses) + torch.mean(divergence_losses) 139 | 140 | return loss_epoch.item() 141 | 142 | 143 | class GaussianLogLikelihood(nn.Module): 144 | def __init__(self): 145 | super(GaussianLogLikelihood, self).__init__() 146 | 147 | self.var = 0.02 ** 2 148 | 149 | def forward(self, predicted, expected): 150 | # not really sure if I am supposed to use the variance 151 | # stated in the paper. 152 | error = torch.mean((predicted - expected) ** 2, dim=-1) 153 | return error 154 | 155 | 156 | class AutoregressiveBoxVariationalAutoencoder(nn.Module): 157 | def __init__(self, number_labels, conditioning_size, representation_size): 158 | super(AutoregressiveBoxVariationalAutoencoder, self).__init__() 159 | 160 | self.representation_size = representation_size 161 | 162 | self.encoder = AutoregressiveBoxEncoder(number_labels, conditioning_size, representation_size) 163 | self.decoder = AutoregressiveBoxDecoder(conditioning_size, representation_size) 164 | 165 | def sample(self, mu, log_var): 166 | batch_size = mu.size(0) 167 | device = mu.device 168 | 169 | standard_normal = torch.randn((batch_size, self.representation_size), device=device) 170 | z = mu + standard_normal * torch.exp(0.5 * log_var) 171 | 172 | kl_divergence = -0.5 * torch.sum( 173 | 1 + log_var - (mu ** 2) - torch.exp(log_var), dim=1) 174 | 175 | return z, kl_divergence 176 | 177 | def forward(self, x, label_set, current_label, labels_so_far, boxes_so_far, state=None): 178 | mu, s, condition, state = self.encoder(x, label_set, current_label, labels_so_far, boxes_so_far, state) 179 | 180 | z, kl_divergence = self.sample(mu, s) 181 | boxes = self.decoder(z, condition) 182 | 183 | return boxes, kl_divergence, z, state 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser('Box VAE') 188 | parser.add_argument("--exp", default="box_vae", help="postfix for experiment name") 189 | parser.add_argument("--log_dir", default="./logs", help="/path/to/logs/dir") 190 | parser.add_argument("--train_json", default="./instances_train.json", help="/path/to/train/json") 191 | parser.add_argument("--val_json", default="./instances_val.json", help="/path/to/val/json") 192 | 193 | parser.add_argument("--max_length", type=int, default=128, help="batch size") 194 | 195 | parser.add_argument("--seed", type=int, default=42, help="random seed") 196 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs") 197 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 198 | parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") 199 | parser.add_argument("--beta_1", type=float, default=0.9, help="beta_1 for adam") 200 | parser.add_argument('--evaluate', action='store_true', help="evaluate only") 201 | parser.add_argument('--save_every', type=int, default=10, help="evaluate only") 202 | 203 | args = parser.parse_args() 204 | 205 | if not args.evaluate: 206 | now = 
datetime.now().strftime("%m%d%y_%H%M%S") 207 | log_dir = os.path.join(args.log_dir, f"{now}_{args.exp}") 208 | samples_dir = os.path.join(log_dir, "samples") 209 | ckpt_dir = os.path.join(log_dir, "checkpoints") 210 | os.makedirs(samples_dir, exist_ok=True) 211 | os.makedirs(ckpt_dir, exist_ok=True) 212 | else: 213 | log_dir = args.log_dir 214 | samples_dir = os.path.join(log_dir, "samples") 215 | ckpt_dir = os.path.join(log_dir, "checkpoints") 216 | 217 | random.seed(args.seed) 218 | torch.manual_seed(args.seed) 219 | 220 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 221 | print(f"using device: {device}") 222 | 223 | collator = BatchCollator() 224 | train_dataset = LayoutDataset(args.train_json, args.max_length) 225 | train_loader = torch.utils.data.DataLoader( 226 | train_dataset, 227 | batch_size=args.batch_size, 228 | shuffle=False, 229 | num_workers=0, 230 | collate_fn=collator) 231 | 232 | validation_dataset = LayoutDataset(args.val_json, args.max_length) 233 | validation_loader = torch.utils.data.DataLoader( 234 | validation_dataset, 235 | batch_size=args.batch_size, 236 | shuffle=True, 237 | num_workers=0, 238 | collate_fn=collator) 239 | 240 | NUMBER_LABELS = train_dataset.number_labels 241 | colors = gen_colors(NUMBER_LABELS) 242 | 243 | label_encodings = torch.eye(NUMBER_LABELS).float().to(device) 244 | box_loss = GaussianLogLikelihood().to(device) 245 | 246 | autoencoder = AutoregressiveBoxVariationalAutoencoder( 247 | NUMBER_LABELS, 248 | conditioning_size=128, 249 | representation_size=32).to(device) 250 | 251 | # evaluate the model 252 | if args.evaluate: 253 | min_epoch = -1 254 | min_loss = 1e100 255 | for epoch in range(args.epochs): 256 | checkpoint_path = os.path.join(log_dir, "checkpoints", 'epoch_%d.pth' % epoch) 257 | if not os.path.exists(checkpoint_path): 258 | continue 259 | print('Evaluating', checkpoint_path) 260 | checkpoint = torch.load(checkpoint_path) 261 | autoencoder.load_state_dict(checkpoint["model_state_dict"], strict=True) 262 | loss = evaluate(autoencoder, validation_loader, box_loss) 263 | print('End of epoch %d : %f' % (epoch, loss)) 264 | if loss < min_loss: 265 | min_loss = loss 266 | min_epoch = epoch 267 | print('Best epoch: %d Best nll: %f' % (min_epoch, min_loss)) 268 | sys.exit(0) 269 | 270 | opt = optim.Adam(autoencoder.parameters(), lr=args.lr, betas=(args.beta_1, 0.999)) 271 | epoch_number = 0 272 | while True: 273 | if (epoch_number > 0) and (epoch_number == args.epochs): 274 | print("done!") 275 | break 276 | 277 | print(f"starting epoch {epoch_number+1}") 278 | autoencoder.train() 279 | 280 | with tqdm(enumerate(train_loader)) as tq: 281 | for batch_i, (indexes, target) in tq: 282 | autoencoder.zero_grad() 283 | box_loss.zero_grad() 284 | label_set = torch.stack([t.label_set for t in target], dim=0).to(device) 285 | counts = torch.stack([t.count for t in target], dim=0).to(device) 286 | boxes = [t.bbox.to(device) for t in target] 287 | labels = [t.label.to(device) for t in target] 288 | number_boxes = np.stack([len(t) for t in target], axis=0) 289 | max_number_boxes = np.max(number_boxes) 290 | 291 | batch_size = label_set.size(0) 292 | # previous_boxes = torch.zeros((batch_size, max_number_boxes, 4)).to(device) 293 | 294 | box_losses = [] 295 | divergence_losses = [] 296 | 297 | current_box_loss = torch.zeros((batch_size, max_number_boxes)).to(device) 298 | current_divergence_loss = torch.zeros((batch_size, max_number_boxes)).to(device) 299 | 300 | for step in range(max_number_boxes): 301 | 302 | # determine 
who has a box. 303 | has_box = number_boxes > step 304 | 305 | # determine their history of box/labels. 306 | current_label_set = label_set[has_box, :] 307 | current_counts = counts[has_box, :] 308 | 309 | all_boxes = [boxes[i] for i, has in enumerate(has_box) if has] 310 | all_labels = [labels[i] for i, has in enumerate(has_box) if has] 311 | current_label = torch.stack([l[step] for l in all_labels], dim=0).to(device) 312 | current_label = label_encodings[current_label.long() - 1] 313 | current_box = torch.stack([b[step] for b in all_boxes], dim=0).to(device) 314 | 315 | # now, consider the history. 316 | if step == 0: 317 | previous_labels = torch.zeros((batch_size, 0, 7)).to(device) 318 | previous_boxes = torch.zeros((batch_size, 0, 4)).to(device) 319 | else: 320 | previous_labels = torch.stack([l[step - 1] for l in all_labels], dim=0).unsqueeze(1) 321 | previous_labels = label_encodings[previous_labels.long() - 1] 322 | 323 | # we need to 1-hot these. only take the previous one since 324 | # we'll accumulate state instead. 325 | previous_boxes = torch.stack([b[step - 1] for b in all_boxes], dim=0).unsqueeze(1) 326 | 327 | # take a step. x, label_set, current_label, count_so_far): 328 | state = (h[has_box].unsqueeze(0), c[has_box].unsqueeze(0)) if step > 1 else None 329 | predicted_boxes, kl_divergence, z, state = autoencoder(current_box, current_label_set, current_label, 330 | previous_labels, previous_boxes, state=state) 331 | if not (state is None): 332 | h, c = torch.zeros((batch_size, 128)).to(device), torch.zeros((batch_size, 128)).to(device) 333 | h[has_box, :] = state[0][-1] 334 | c[has_box, :] = state[1][-1] 335 | 336 | box_loss_step = box_loss(predicted_boxes, current_box) 337 | 338 | current_box_loss[has_box, step] = box_loss_step 339 | current_divergence_loss[has_box, step] = kl_divergence 340 | 341 | number_boxes = torch.from_numpy(number_boxes).to(device).float() 342 | box_loss_batch = torch.mean(torch.sum(current_box_loss, dim=-1) / number_boxes) 343 | divergence_loss_batch = torch.mean(torch.sum(current_divergence_loss, dim=-1) / number_boxes) 344 | 345 | loss_batch = box_loss_batch + 0.0001 * divergence_loss_batch 346 | loss_batch.backward() 347 | opt.step() 348 | 349 | tq.set_description(f"{epoch_number+1}/{args.epochs} box_loss: {box_loss_batch.item()}" 350 | f"kl: {divergence_loss_batch.item()}") 351 | 352 | # if (epoch_number + 1) % 1 == 0: 353 | # validation_loss, validation_accuracy = evaluate() 354 | # print("validation loss [{0}/{1}: {2:4f}".format(epoch_number, NUMBER_EPOCHS, validation_loss.item())) 355 | # # write out a checkpoint too. 356 | prefix = os.path.join(samples_dir, f"epoch_{epoch_number+1:03d}") 357 | evaluate(autoencoder, validation_loader, box_loss, prefix=prefix, colors=colors) 358 | torch.save({ 359 | "epoch": epoch_number, 360 | "model_state_dict": autoencoder.state_dict(), 361 | }, os.path.join(ckpt_dir, "epoch_{0}.pth".format(epoch_number))) 362 | 363 | epoch_number += 1 364 | --------------------------------------------------------------------------------
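Note: the following is a minimal sketch, not code from this repository. It illustrates how the trained count VAE could be used to sample a count vector at inference time, building on the commented-out Poisson sampling in evaluate() (train_counts.py) and on the fact that AutoregressiveCountDecoder returns log(lambda). The helper name sample_counts and the choice of the prior z ~ N(0, I) are assumptions for illustration only.

import torch
from torch.distributions import Poisson


@torch.no_grad()
def sample_counts(autoencoder, label_set, label_encodings, number_labels, device):
    # label_set: (batch, number_labels) binary float tensor marking which classes appear.
    batch_size = label_set.size(0)
    counts = torch.zeros((batch_size, number_labels), device=device)

    for label_i in range(number_labels):
        present = label_set[:, label_i] > 0
        n = int(present.sum().item())
        if n == 0:
            continue

        current_label = label_encodings[label_i].unsqueeze(0).repeat(n, 1)

        # Reuse the conditioning MLP inside the encoder to build the condition vector,
        # then decode from the prior z ~ N(0, I) instead of an encoded posterior.
        condition = autoencoder.encoder.conditioning(label_set[present], current_label, counts[present])
        z = torch.randn((n, autoencoder.representation_size), device=device)
        log_rate = autoencoder.decoder(z, condition)

        # Training used (count - 1) ~ Poisson(lambda), so add 1 back after sampling.
        sampled = Poisson(torch.exp(log_rate.squeeze(-1))).sample() + 1
        counts[present, label_i] = sampled

    return counts

A call would mirror the training script's setup, e.g. sample_counts(autoencoder, label_set, label_encodings, NUMBER_LABELS, device); because previously sampled counts are fed back through counts[present], the sampling stays autoregressive without teacher forcing.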