├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── TRAINING.md ├── images ├── sheep │ ├── img000000.png │ ├── img000001.png │ ├── img000002.png │ ├── img000003.png │ ├── img000004.png │ ├── img000005.png │ ├── img000006.png │ ├── sg000000.png │ ├── sg000001.png │ ├── sg000002.png │ ├── sg000003.png │ ├── sg000004.png │ ├── sg000005.png │ └── sg000006.png └── system.png ├── requirements.txt ├── scene_graphs ├── figure_5_coco.json ├── figure_5_vg.json ├── figure_6_sheep.json └── figure_6_street.json ├── scripts ├── download_ablated_models.sh ├── download_coco.sh ├── download_full_models.sh ├── download_models.sh ├── download_vg.sh ├── preprocess_vg.py ├── print_args.py ├── run_model.py ├── sample_images.py ├── strip_checkpoint.py ├── strip_old_args.py └── train.py └── sg2im ├── __init__.py ├── bilinear.py ├── box_utils.py ├── crn.py ├── data ├── __init__.py ├── coco.py ├── utils.py ├── vg.py └── vg_splits.json ├── discriminators.py ├── graph.py ├── layers.py ├── layout.py ├── losses.py ├── metrics.py ├── model.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.DS_Store 4 | outputs/ 5 | sg2im-models/ 6 | env/ 7 | datasets/ 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sg2im 2 | 3 | This is the code for the paper 4 | 5 | **Image Generation from Scene Graphs** 6 |
7 | Justin Johnson, 8 | Agrim Gupta, 9 | Li Fei-Fei 10 |
11 | Presented at [CVPR 2018](http://cvpr2018.thecvf.com/) 12 | 13 | Please note that this is not an officially supported Google product. 14 | 15 | A **scene graph** is a structured representation of a visual scene where nodes represent *objects* in the scene and edges represent *relationships* between objects. In this paper we present an end-to-end neural network model that inputs a scene graph and outputs an image. 16 | 17 | Below we show some example scene graphs along with images generated from those scene graphs using our model. By modifying the input scene graph we can exercise fine-grained control over the objects in the generated image. A sketch of the JSON format used to encode these scene graphs is given after the images. 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 |
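For reference, each of the scene graphs above is specified in a simple human-readable JSON format; the files under `scene_graphs/` (for example `scene_graphs/figure_6_sheep.json`) show the exact format consumed by `scripts/run_model.py`. The snippet below is an illustrative sketch of writing such a file from Python; the output filename is arbitrary.

```python
# Sketch of the scene-graph JSON format (taken from scene_graphs/figure_6_sheep.json).
# Objects are listed by name; each relationship is a triple
# [subject_index, predicate, object_index] indexing into the object list.
import json

scene_graphs = [
    {
        "objects": ["sky", "grass", "sheep"],
        "relationships": [
            [0, "above", 1],        # sky above grass
            [2, "standing on", 1],  # sheep standing on grass
        ],
    },
    {   # Adding an object (and a relationship for it) edits just that part of the image.
        "objects": ["sky", "grass", "sheep", "sheep"],
        "relationships": [
            [0, "above", 1],
            [2, "standing on", 1],
            [3, "by", 2],           # second sheep by the first
        ],
    },
]

with open("my_scene_graphs.json", "w") as f:
    json.dump(scene_graphs, f, indent=2)
```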
36 | 37 | If you find this code useful in your research then please cite 38 | ``` 39 | @inproceedings{johnson2018image, 40 | title={Image Generation from Scene Graphs}, 41 | author={Johnson, Justin and Gupta, Agrim and Fei-Fei, Li}, 42 | booktitle={CVPR}, 43 | year={2018} 44 | } 45 | ``` 46 | 47 | ## Model 48 | The input scene graph is processed with a *graph convolution network* which passes information along edges to compute embedding vectors for all objects. These vectors are used to predict bounding boxes and segmentation masks for all objects, which are combined to form a coarse *scene layout*. The layout is passed to a *cascaded refinement network* (Chen and Koltun, ICCV 2017) which generates an output image at increasing spatial scales. The model is trained adversarially against a pair of *discriminator networks* which ensure that output images look realistic. 49 | 50 |
51 | 52 |
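For intuition, the following is a heavily simplified, self-contained sketch of a single graph convolution layer of the kind described above. This is **not** the implementation in `sg2im/graph.py`; the class name, dimensions, and average pooling are illustrative assumptions.

```python
import torch
import torch.nn as nn


class SimpleGraphConv(nn.Module):
    """Toy graph-convolution layer: each (subject, predicate, object) triple is
    fed through an MLP, and the resulting candidate vectors are averaged back
    onto the objects they refer to."""

    def __init__(self, dim=128, hidden=512):
        super(SimpleGraphConv, self).__init__()
        self.net = nn.Sequential(nn.Linear(3 * dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 3 * dim))

    def forward(self, obj_vecs, pred_vecs, edges):
        # obj_vecs: (O, dim), pred_vecs: (T, dim), edges: (T, 2) LongTensor
        # holding [subject_index, object_index] for each triple.
        s, o = edges[:, 0], edges[:, 1]
        triples = torch.cat([obj_vecs[s], pred_vecs, obj_vecs[o]], dim=1)
        new_s, new_p, new_o = self.net(triples).chunk(3, dim=1)

        # Average all candidate vectors that arrive at the same object node.
        pooled = torch.zeros_like(obj_vecs)
        counts = obj_vecs.new_zeros(obj_vecs.size(0), 1)
        pooled.index_add_(0, s, new_s)
        pooled.index_add_(0, o, new_o)
        counts.index_add_(0, s, counts.new_ones(s.size(0), 1))
        counts.index_add_(0, o, counts.new_ones(o.size(0), 1))
        return pooled / counts.clamp(min=1), new_p


# Tiny example mirroring scene_graphs/figure_6_sheep.json: objects
# ["sky", "grass", "sheep"] with relationships [0, "above", 1] and [2, "standing on", 1].
gconv = SimpleGraphConv()
obj_vecs = torch.randn(3, 128)   # one embedding per object
pred_vecs = torch.randn(2, 128)  # one embedding per relationship
edges = torch.tensor([[0, 1], [2, 1]])
new_obj_vecs, new_pred_vecs = gconv(obj_vecs, pred_vecs, edges)
print(new_obj_vecs.shape, new_pred_vecs.shape)  # torch.Size([3, 128]) torch.Size([2, 128])
```

In the full model several such layers are stacked (`--gconv_num_layers`, see TRAINING.md), and the refined object vectors drive box and mask prediction to form the scene layout that the cascaded refinement network turns into an image.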
53 | 54 | ## Setup 55 | All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4. 56 | 57 | You can set up a virtual environment to run the code like this: 58 | 59 | ```bash 60 | python3 -m venv env # Create a virtual environment 61 | source env/bin/activate # Activate virtual environment 62 | pip install -r requirements.txt # Install dependencies 63 | echo $PWD > env/lib/python3.5/site-packages/sg2im.pth # Add current directory to python path 64 | # Work for a while ... 65 | deactivate # Exit virtual environment 66 | ``` 67 | 68 | ## Pretrained Models 69 | You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models, and will require about 355 MB of disk space: 70 | 71 | - `sg2im-models/coco64.pt`: Trained to generate 64 x 64 images on the COCO-Stuff dataset. This model was used to generate the COCO images in Figure 5 from the paper. 72 | - `sg2im-models/vg64.pt`: Trained to generate 64 x 64 images on the Visual Genome dataset. This model was used to generate the Visual Genome images in Figure 5 from the paper. 73 | - `sg2im-models/vg128.pt`: Trained to generate 128 x 128 images on the Visual Genome dataset. This model was used to generate the images in Figure 6 from the paper. 74 | 75 | Table 1 in the paper presents an ablation study where we disable various components of the full model. You can download the additional models used in this ablation study by running the script `bash scripts/download_ablated_models.sh`. This will download 12 additional models, requiring an additional 1.25 GB of disk space. 76 | 77 | ## Running Models 78 | You can use the script `scripts/run_model.py` to easily run any of the pretrained models on new scene graphs using a simple human-readable JSON format. For example you can replicate the sheep images above like this: 79 | 80 | ```bash 81 | python scripts/run_model.py \ 82 | --checkpoint sg2im-models/vg128.pt \ 83 | --scene_graphs scene_graphs/figure_6_sheep.json \ 84 | --output_dir outputs 85 | ``` 86 | 87 | The generated images will be saved to the directory specified by the `--output_dir` flag. You can control whether the model runs on CPU or GPU by passing the flag `--device cpu` or `--device gpu`. 88 | 89 | We provide JSON files and pretrained models allowing you to recreate all images from Figures 5 and 6 from the paper. 90 | 91 | #### (Optional): GraphViz 92 | This script can also draw images for the scene graphs themselves using [GraphViz](http://www.graphviz.org/); to enable this option just add the flag `--draw_scene_graphs 1` and the scene graph images will also be saved in the output directory. For this option to work you must install GraphViz; on Ubuntu 16.04 you can simply run `sudo apt-get install graphviz`. 93 | 94 | ## Training new models 95 | Instructions for training new models can be [found here](TRAINING.md). 96 | -------------------------------------------------------------------------------- /TRAINING.md: -------------------------------------------------------------------------------- 1 | You can train your own model by following these instructions: 2 | 3 | ## Step 1: Install COCO API 4 | To train new models you will need to install the [COCO Python API](https://github.com/cocodataset/cocoapi). 
Unfortunately installing this package via pip often leads to build errors, but you can install it from source like this: 5 | 6 | ```bash 7 | cd ~ 8 | git clone https://github.com/cocodataset/cocoapi.git 9 | cd cocoapi/PythonAPI/ 10 | python setup.py install 11 | ``` 12 | 13 | ## Step 2: Preparing the data 14 | 15 | ### Visual Genome 16 | Run the following script to download and unpack the relevant parts of the Visual Genome dataset: 17 | 18 | ```bash 19 | bash scripts/download_vg.sh 20 | ``` 21 | 22 | This will create the directory `datasets/vg` and will download about 15 GB of data to this directory; after unpacking it will take about 30 GB of disk space. 23 | 24 | After downloading the Visual Genome dataset, we need to preprocess it. This will split the data into train / val / test splits, consolidate all scene graphs into HDF5 files, and apply several heuristics to clean the data. In particular we ignore images that are too small, and only consider object and attribute categories that appear some number of times in the training set; we also ignore objects that are too small, and set minimum and maximum values on the number of objects and relationships that appear per image. 25 | 26 | ```bash 27 | python scripts/preprocess_vg.py 28 | ``` 29 | 30 | This will create files `train.h5`, `val.h5`, `test.h5`, and `vocab.json` in the directory `datasets/vg`. 31 | 32 | ### COCO 33 | Run the following script to download and unpack the relevant parts of the COCO dataset: 34 | 35 | ```bash 36 | bash scripts/download_coco.sh 37 | ``` 38 | 39 | This will create the directory `datasets/coco` and will download about 21 GB of data to this directory; after unpacking it will take about 60 GB of disk space. 40 | 41 | ## Step 3: Train a model 42 | 43 | Now you can train a new model by running the script: 44 | 45 | ```bash 46 | python scripts/train.py 47 | ``` 48 | 49 | By default this will train a model on COCO, periodically saving checkpoint files `checkpoint_with_model.pt` and `checkpoint_no_model.pt` to the current working directory. The training script has a number of command-line flags that you can use to configure the model architecture, hyperparameters, and input / output settings: 50 | 51 | ### Optimization 52 | 53 | - `--batch_size`: How many pairs of (scene graph, image) to use in each minibatch during training. Default is 32. 54 | - `--num_iterations`: Number of training iterations. Default is 1,000,000. 55 | - `--learning_rate`: Learning rate to use in Adam optimizer for the generator and discriminators; default is 1e-4. 56 | - `--eval_mode_after`: The generator is trained in "train" mode for this many iterations, after which training continues in "eval" mode. We found that if the model is trained exclusively in "train" mode then generated images can have severe artifacts if test batches have a different size or composition than those used during training. 57 | 58 | ### Dataset options 59 | 60 | - `--dataset`: The dataset to use for training; must be either `coco` or `vg`. Default is `coco`. 61 | - `--image_size`: The size of images to generate, as a tuple of integers. Default is `64,64`. This is also the resolution at which scene layouts are predicted. 62 | - `--num_train_samples`: The number of images from the training set to use. Default is None, which means the entire training set will be used. 63 | - `--num_val_samples`: The number of images from the validation set to use. Default is 1024. 
This is particularly important for the COCO dataset, since we partition the COCO validation images into our own validation and test sets; this flag thus controls the number of COCO validation images which we will use as our own validation set, and the remaining images will serve as our test set. 64 | - `--shuffle_val`: Whether to shuffle the samples from the validation set. Default is True. 65 | - `--loader_num_workers`: The number of background threads to use for data loading. Default is 4. 66 | - `--include_relationships`: Whether to include relationships in the scene graphs; default is 1 which means use relationships, 0 means omit them. This is used to train the "no relationships" ablated model. 67 | 68 | **Visual Genome options**: 69 | These flags only take effect if `--dataset` is set to `vg`: 70 | 71 | - `--vg_image_dir`: Directory from which to load Visual Genome images. Default is `datasets/vg/images`. 72 | - `--train_h5`: Path to HDF5 file containing data for the training split, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/train.h5`. 73 | - `--val_h5`: Path to HDF5 file containing data for the validation split, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/val.h5`. 74 | - `--vocab_json`: Path to JSON file containing Visual Genome vocabulary, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/vocab.json`. 75 | - `--max_objects_per_image`: The maximum number of objects to use per scene graph during training; default is 10. Note that `scripts/preprocess_vg.py` also defines a maximum number of objects per image, but the two settings are different. The max value in the preprocessing script causes images with more than the max number of objects to be skipped entirely; in contrast during training if we encounter images with more than the max number of objects then they are randomly subsampled to the max value as a form of data augmentation. 76 | - `--vg_use_orphaned_objects`: Whether to include objects which do not participate in any relationships; 1 means use them (default), 0 means skip them. 77 | 78 | **COCO options**: 79 | These flags only take effect if `--dataset` is set to `coco`: 80 | 81 | - `--coco_train_image_dir`: Directory from which to load COCO training images; default is `datasets/coco/images/train2017`. 82 | - `--coco_val_image_dir`: Directory from which to load COCO validation images; default is `datasets/coco/images/val2017`. 83 | - `--coco_train_instances_json`: Path to JSON file containing object annotations for the COCO training images; default is `datasets/coco/annotations/instances_train2017.json`. 84 | - `--coco_train_stuff_json`: Path to JSON file containing stuff annotations for the COCO training images; default is `datasets/coco/annotations/stuff_train2017.json`. 85 | - `--coco_val_instances_json`: Path to JSON file containing object annotations for COCO validation images; default is `datasets/coco/instances_val2017.json`. 86 | - `--coco_val_stuff_json`: Path to JSON file containing stuff annotations for COCO validation images; default is `datasets/coco/stuff_val2017.json`. 87 | - `--instance_whitelist`: The default behavior is to train the model to generate all object categories; however by passing a comma-separated list to this flag we can train the model to generate only a subset of object categories. 
88 | - `--stuff_whitelist`: The default behavior is to train the model to generate all stuff categories (except other, see below); however by passing a comma-separated list to this flag we can train the model to generate only a subset of stuff categories. 89 | - `--coco_include_other`: The COCO-Stuff annotations include an "other" category for objects which do not fall into any of the other object categories. Due to the variety in this category I found that the model was unable to learn it, so setting this flag to 0 (default) causes COCO-Stuff annotations with the "other" category to be ignored. Setting it to 1 will include these "other" annotations. 90 | - `--coco_stuff_only`: The 2017 COCO training split contains 115K images. Object annotations are provided for all of these images, but Stuff annotations are only provided for 40K of these images. Setting this flag to 1 (default) will only train using images for which Stuff annotations are available; setting this flag to 0 will use all 115K images for training, including Stuff annotations only for the images for which they are available. 91 | 92 | ### Generator options 93 | These flags control the architecture and loss hyperparameters for the generator, which inputs scene graphs and outputs images. 94 | 95 | - `--mask_size`: Integer giving the resolution at which instance segmentation masks are predicted for objects. Default is 16. Setting this to 0 causes the model to omit the mask prediction subnetwork, instead using the entire object bounding box as the mask. Since mask prediction is differentiable the model can predict masks even when the training dataset does not provide masks; in particular Visual Genome does not provide masks, but all VG models were trained with `--mask_size 16`. 96 | - `--embedding_dim`: Integer giving the dimension for the embedding layer for objects and relationships prior to the first graph convolution layer. Default is 128. 97 | - `--gconv_dim`: Integer giving the dimension for the vectors in graph convolution layers. Default is 128. 98 | - `--gconv_hidden_dim`: Integer giving the hidden dimension inside each graph convolution layer; this is the dimension of the candidate vectors V^s_i and V^s_o from Equations 1 and 2 in the paper. Default is 512. 99 | - `--gconv_num_layers`: The number of graph convolution layers to use. Default is 5. 100 | - `--mlp_normalization`: The type of normalization (if any) to use for linear layers in MLPs inside graph convolution layers and the box prediction subnetwork. Choices are 'none' (default), which means to use no normalization, or 'batch' which means to use batch normalization. 101 | - `--refinement_network_dims`: Comma-separated list of integers specifying the architecture of the cascaded refinement network used to generate images; default is `1024,512,256,128,64` which means to use five refinement modules, the first with 1024 feature maps, the second with 512 feature maps, etc. Spatial resolution of the feature maps doubles between successive refinement modules. 102 | - `--normalization`: The type of normalization layer to use in the cascaded refinement network. Options are 'batch' (default) for batch normalization, 'instance' for instance normalization, or 'none' for no normalization. 103 | - `--activation`: Activation function to use in the cascaded refinement network; default is `leakyrelu-0.2` which is a Leaky ReLU with a negative slope of 0.2. Can also be `relu`. 
104 | - `--layout_noise_dim`: The number of channels of random noise that is concatenated with the scene layout before feeding to the cascaded refinement network. Default is 32. 105 | 106 | **Generator Losses**: These flags control the non-adversarial losses used to train the generator: 107 | 108 | - `--l1_pixel_loss_weight`: Float giving the weight to give the L1 difference between the generated and ground-truth image. Default is 1.0. 109 | - `--mask_loss_weight`: Float giving the weight to give mask prediction in the overall model loss. Setting this to 0 (default) means that masks are weakly supervised, which is required when training on Visual Genome. For COCO I found that setting `--mask_loss_weight` to 0.1 works well. 110 | - `--bbox_pred_loss_weight`: Float giving the weight to assign to regressing the bounding boxes for objects. Default is 10. 111 | 112 | ### Discriminator options 113 | The generator is trained adversarially against two discriminators: a patch-based image discriminator ensuring that patches of the generated image look realistic, and an object discriminator that ensures that generated objects are realistic. These flags apply to both discriminators: 114 | 115 | - `--discriminator_loss_weight`: The weight to assign to discriminator losses when training the generator. Default is 0.01. 116 | - `--gan_loss_type`: The GAN loss function to use. Default is 'gan' which is the original GAN loss function; can also be 'lsgan' for least-squares GAN or 'wgan' for Wasserstein GAN loss. 117 | - `--d_clip`: Value at which to clip discriminator weights, for WGAN. Default is no clipping. 118 | - `--d_normalization`: The type of normalization to use in discriminators. Default is 'batch' for batch normalization, but like CRN normalization this can also be 'none' or 'instance'. 119 | - `--d_padding`: The type of padding to use for convolutions in the discriminators, either 'valid' (default) or 'same'. 120 | - `--d_activation`: Activation function to use in the discriminators. Like CRN the default is `leakyrelu-0.2`. 121 | 122 | **Object Discriminator**: These flags only apply to the object discriminator: 123 | 124 | - `--d_obj_arch`: String giving the architecture of the object discriminator; the semantics of architecture strings [are described here](https://github.com/jcjohnson/sg2im-release/blob/master/sg2im/layers.py#L116). 125 | - `--crop_size`: The object discriminator crops out each object in images; this gives the spatial size to which these crops are (differentiably) resized. Default is 32. 126 | - `--d_obj_weight`: Weight for real / fake classification in the object discriminator. During training the weight given to fooling the object discriminator is `--discriminator_loss_weight * --d_obj_weight`. Default is 1.0. 127 | - `--ac_loss_weight`: Weight for the auxiliary classifier in the object discriminator that attempts to predict the category of each object; the weight assigned to this loss is `--discriminator_loss_weight * --ac_loss_weight`. Default is 0.1. 128 | 129 | **Image Discriminator**: These flags only apply to the image discriminator: 130 | 131 | - `--d_img_arch`: String giving the architecture of the image discriminator; the semantics of architecture strings [are described here](https://github.com/jcjohnson/sg2im-release/blob/master/sg2im/layers.py#L116). 132 | - `--d_img_weight`: The weight assigned to fooling the image discriminator is `--discriminator_loss_weight * --d_img_weight`. Default is 1.0. A sketch of how these loss weights combine into the generator's total loss is given below. 
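To make the weighting scheme above concrete, here is an illustrative sketch of how these flags combine into the generator's total loss. This is **not** the code in `scripts/train.py`; the loss-dictionary keys and the use of `argparse.Namespace` are assumptions for exposition, while the flag names and default values come from the descriptions above.

```python
from argparse import Namespace


def total_generator_loss(losses, args):
    """Combine raw (unweighted) generator losses using the flag values above."""
    w_d = args.discriminator_loss_weight                        # default 0.01
    return (args.l1_pixel_loss_weight * losses['l1_pixel']      # default 1.0
            + args.bbox_pred_loss_weight * losses['bbox_pred']  # default 10
            + args.mask_loss_weight * losses['mask']            # default 0 (weakly supervised masks)
            + w_d * args.d_obj_weight * losses['fool_d_obj']    # fool the object discriminator
            + w_d * args.ac_loss_weight * losses['ac']          # auxiliary classifier loss
            + w_d * args.d_img_weight * losses['fool_d_img'])   # fool the image discriminator


args = Namespace(l1_pixel_loss_weight=1.0, bbox_pred_loss_weight=10.0,
                 mask_loss_weight=0.0, discriminator_loss_weight=0.01,
                 d_obj_weight=1.0, ac_loss_weight=0.1, d_img_weight=1.0)
dummy_losses = dict(l1_pixel=0.5, bbox_pred=0.02, mask=0.0,
                    fool_d_obj=0.7, ac=1.2, fool_d_img=0.8)
print(total_generator_loss(dummy_losses, args))  # ~0.7162
```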
133 | 134 | ### Output Options 135 | These flags control outputs from the training script: 136 | 137 | - `--print_every`: Training losses are printed and recorded every `--print_every` iterations. Default is 10. 138 | - `--timing`: If this flag is set to 1 then measure and print the time that each model component takes to execute. 139 | - `--checkpoint_every`: Checkpoints are saved to disk every `--checkpoint_every` iterations. Default is 10000. Each checkpoint contains a history of training losses, a history of images generated from training-set and val-set scene graphs, the current state of the generator, discriminators, and optimizers, as well as all other state information needed to resume training in case it is interrupted. We actually save two checkpoints: one with all information, and one without model parameters; the latter is much smaller, and is convenient for exploring the results of a large hyperparameter sweep without actually loading model parameters. 140 | - `--output_dir`: Directory to which checkpoints will be saved. Default is current directory. 141 | - `--checkpoint_name`: Base filename for saved checkpoints; default is 'checkpoint', so the filename for the checkpoint with model parameters will be 'checkpoint_with_model.pt' and the filename for the checkpoint without model parameters will be 'checkpoint_no_model.pt'. 142 | - `--restore_from_checkpoint`: Default behavior is to start training from scratch, and overwrite the output checkpoint path if it already exists. If this flag is set to 1 then instead resume training from the output checkpoint file if it already exists. This is useful when running in an environment where jobs can be preempted. 143 | - `--checkpoint_start_from`: Default behavior is to start training from scratch; if this flag is given then instead resume training from the specified checkpoint. This takes precedence over `--restore_from_checkpoint` if both are given. 144 | 145 | ## (Optional) Step 4: Strip Checkpoint 146 | Checkpoints saved by train.py contain not only model parameters but also optimizer states, losses, a history of generated images, and other statistics. This information is very useful for development and debugging models, but makes the saved checkpoints very large. You can use the script `scripts/strip_checkpoint.py` to remove all this extra information from a saved checkpoint, and save only the trained models: 147 | 148 | ```bash 149 | python scripts/strip_checkpoint.py \ 150 | --input_checkpoint checkpoint_with_model.pt \ 151 | --output_checkpoint checkpoint_stripped.pt 152 | ``` 153 | 154 | ## Step 5: Run the model 155 | You can use the script `scripts/run_model.py` to run the model on arbitrary scene graphs specified in a simple JSON format. 
For example to generate images for the scene graphs used in Figure 6 of the paper you can run: 156 | 157 | ```bash 158 | python scripts/run_model.py \ 159 | --checkpoint checkpoint_with_model.pt \ 160 | --scene_graphs_json scene_graphs/figure_6_sheep.json 161 | ``` 162 | -------------------------------------------------------------------------------- /images/sheep/img000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000000.png -------------------------------------------------------------------------------- /images/sheep/img000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000001.png -------------------------------------------------------------------------------- /images/sheep/img000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000002.png -------------------------------------------------------------------------------- /images/sheep/img000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000003.png -------------------------------------------------------------------------------- /images/sheep/img000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000004.png -------------------------------------------------------------------------------- /images/sheep/img000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000005.png -------------------------------------------------------------------------------- /images/sheep/img000006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000006.png -------------------------------------------------------------------------------- /images/sheep/sg000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000000.png -------------------------------------------------------------------------------- /images/sheep/sg000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000001.png -------------------------------------------------------------------------------- /images/sheep/sg000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000002.png -------------------------------------------------------------------------------- /images/sheep/sg000003.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000003.png -------------------------------------------------------------------------------- /images/sheep/sg000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000004.png -------------------------------------------------------------------------------- /images/sheep/sg000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000005.png -------------------------------------------------------------------------------- /images/sheep/sg000006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000006.png -------------------------------------------------------------------------------- /images/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/system.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle==0.5.3 2 | cycler==0.10.0 3 | Cython==0.28.3 4 | dask==0.17.5 5 | decorator==4.3.0 6 | h5py==2.8.0 7 | imageio==2.3.0 8 | kiwisolver==1.0.1 9 | matplotlib==2.2.2 10 | networkx==2.1 11 | numpy==1.14.4 12 | Pillow==5.1.0 13 | pyparsing==2.2.0 14 | python-dateutil==2.7.3 15 | pytz==2018.4 16 | PyWavelets==0.5.2 17 | scikit-image==0.14.0 18 | scipy==1.1.0 19 | six==1.11.0 20 | toolz==0.9.0 21 | torch==0.4.0 22 | torchvision==0.2.1 23 | -------------------------------------------------------------------------------- /scene_graphs/figure_5_coco.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "objects": [ 4 | "car", "car", "cage", "person", "grass", "tree", "playingfield", "person" 5 | ], 6 | "relationships": [ 7 | [0, "left of", 1], 8 | [0, "above", 6], 9 | [1, "above", 4], 10 | [3, "left of", 7], 11 | [5, "above", 7], 12 | [6, "below", 2], 13 | [6, "below", 3], 14 | [7, "left of", 4] 15 | ] 16 | }, 17 | { 18 | "objects": [ "broccoli", "broccoli", "vegetable", "carrot"], 19 | "relationships": [ 20 | [0, "left of", 1], 21 | [0, "left of", 1], 22 | [1, "inside", 2], 23 | [3, "below", 1] 24 | ] 25 | }, 26 | { 27 | "objects": ["person", "person", "person", "fence"], 28 | "relationships": [ 29 | [0, "left of", 1], 30 | [0, "left of", 1], 31 | [0, "inside", 3], 32 | [2, "inside", 3] 33 | ] 34 | }, 35 | { 36 | "objects": ["sky-other", "person", "skateboard", "tree"], 37 | "relationships": [ 38 | [0, "surrounding", 2], 39 | [1, "inside", 0], 40 | [1, "above", 3], 41 | [3, "below", 0] 42 | ] 43 | }, 44 | { 45 | "objects": ["person", "tie", "wall-panel", "clothes"], 46 | "relationships": [ 47 | [0, "surrounding", 1], 48 | [1, "inside", 0], 49 | [1, "above", 3], 50 | [2, "surrounding", 0] 51 | ] 52 | }, 53 | { 54 | "objects": ["clouds", "tree", "person", "horse", "grass"], 55 | "relationships": [ 56 | [0, "above", 4], 57 | [0, "above", 4], 58 | [1, "right of", 2], 59 | [2, "left of", 3], 60 | [3, "above", 4] 61 | ] 62 | }, 63 | { 64 | "objects": ["elephant", "elephant", 
"tree", "grass"], 65 | "relationships": [ 66 | [0, "inside", 2], 67 | [0, "above", 3], 68 | [2, "surrounding", 1] 69 | ] 70 | }, 71 | { 72 | "objects": ["clouds", "building-other", "boat", "tree", "river"], 73 | "relationships": [ 74 | [0, "above", 1], 75 | [0, "above", 1], 76 | [1, "above", 4], 77 | [3, "left of", 4], 78 | [4, "below", 0] 79 | ] 80 | } 81 | ] 82 | -------------------------------------------------------------------------------- /scene_graphs/figure_5_vg.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "objects": [ 4 | "sky", "cloud", "mountain", "sheep", "grass", "tree", 5 | "stone", "rock", "sheep" 6 | ], 7 | "relationships": [ 8 | [0, "has", 1], 9 | [2, "behind", 5], 10 | [3, "eating", 4], 11 | [3, "eating", 4], 12 | [5, "in front of", 2] 13 | ] 14 | }, 15 | { 16 | "objects": [ 17 | "sky", "water", "person", "wave", "board", 18 | "cloud", "background", "edge" 19 | ], 20 | "relationships": [ 21 | [0, "above", 1], 22 | [2, "by", 1], 23 | [2, "riding", 3], 24 | [2, "riding", 4] 25 | ] 26 | }, 27 | { 28 | "objects": ["boy", "grass", "sky", "kite", "field", "mountain", "brick"], 29 | "relationships": [ 30 | [0, "on top of", 1], 31 | [0, "standing on", 1], 32 | [0, "looking at", 2], 33 | [0, "looking at", 3], 34 | [4, "under", 5] 35 | ] 36 | }, 37 | { 38 | "objects": [ 39 | "building", "bus", "tree", "bus", "windshield", "windshield", 40 | "sky", "line", "sign" 41 | ], 42 | "relationships": [ 43 | [0, "next to", 0], 44 | [0, "next to", 0], 45 | [0, "next to", 0], 46 | [0, "next to", 0], 47 | [1, "has", 4], 48 | [1, "behind", 3], 49 | [2, "behind", 3], 50 | [3, "has", 5] 51 | ] 52 | }, 53 | { 54 | "objects": [ 55 | "car", "tree", "house", "street", "window", "roof", "house", "bush", "car" 56 | ], 57 | "relationships": [ 58 | [0, "parked on", 3], 59 | [1, "along", 3], 60 | [2, "has", 5], 61 | [4, "in front of", 6] 62 | ] 63 | }, 64 | { 65 | "objects": [ 66 | "sky", "man", "leg", "horse", "tail", "leg", 67 | "short", "hill", "hill" 68 | ], 69 | "relationships": [ 70 | [0, "above", 1], 71 | [1, "has", 2], 72 | [1, "riding", 3], 73 | [3, "has", 4], 74 | [3, "has", 4], 75 | [3, "has", 5] 76 | ] 77 | }, 78 | { 79 | "objects": ["boat", "water", "sky", "rock", "bird"], 80 | "relationships": [[0, "on top of", 1]] 81 | }, 82 | { 83 | "objects": ["food", "glass", "glass", "plate", "plate"], 84 | "relationships": [ 85 | [0, "on top of", 3], 86 | [1, "by", 3], 87 | [2, "on top of", 4] 88 | ] 89 | } 90 | ] 91 | -------------------------------------------------------------------------------- /scene_graphs/figure_6_sheep.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "objects": ["sky", "grass", "zebra"], 4 | "relationships": [ 5 | [0, "above", 1], 6 | [2, "standing on", 1] 7 | ] 8 | }, 9 | { 10 | "objects": ["sky", "grass", "sheep"], 11 | "relationships": [ 12 | [0, "above", 1], 13 | [2, "standing on", 1] 14 | ] 15 | }, 16 | { 17 | "objects": ["sky", "grass", "sheep", "sheep"], 18 | "relationships": [ 19 | [0, "above", 1], 20 | [2, "standing on", 1], 21 | [3, "by", 2] 22 | ] 23 | }, 24 | { 25 | "objects": ["sky", "grass", "sheep", "sheep", "tree"], 26 | "relationships": [ 27 | [0, "above", 1], 28 | [2, "standing on", 1], 29 | [3, "by", 2], 30 | [4, "behind", 2] 31 | ] 32 | }, 33 | { 34 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean"], 35 | "relationships": [ 36 | [0, "above", 1], 37 | [2, "standing on", 1], 38 | [3, "by", 2], 39 | [4, "behind", 2], 40 | [5, "by", 4] 41 
| ] 42 | }, 43 | { 44 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean", "boat"], 45 | "relationships": [ 46 | [0, "above", 1], 47 | [2, "standing on", 1], 48 | [3, "by", 2], 49 | [4, "behind", 2], 50 | [5, "by", 4], 51 | [6, "in", 5] 52 | ] 53 | }, 54 | { 55 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean", "boat"], 56 | "relationships": [ 57 | [0, "above", 1], 58 | [2, "standing on", 1], 59 | [3, "by", 2], 60 | [4, "behind", 2], 61 | [5, "by", 4], 62 | [6, "on", 1] 63 | ] 64 | } 65 | ] 66 | -------------------------------------------------------------------------------- /scene_graphs/figure_6_street.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "objects": ["car", "street", "line", "sky"], 4 | "relationships": [ 5 | [0, "on", 1], 6 | [2, "on", 1], 7 | [3, "above", 1] 8 | ] 9 | }, 10 | { 11 | "objects": ["bus", "street", "line", "sky"], 12 | "relationships": [ 13 | [0, "on", 1], 14 | [2, "on", 1], 15 | [3, "above", 1] 16 | ] 17 | }, 18 | { 19 | "objects": ["bus", "street", "line", "sky", "car"], 20 | "relationships": [ 21 | [0, "on", 1], 22 | [2, "on", 1], 23 | [3, "above", 1], 24 | [4, "on", 1] 25 | ] 26 | }, 27 | { 28 | "objects": ["bus", "street", "line", "sky", "car", "kite"], 29 | "relationships": [ 30 | [0, "on", 1], 31 | [2, "on", 1], 32 | [3, "above", 1], 33 | [4, "on", 1], 34 | [5, "in", 3] 35 | ] 36 | }, 37 | { 38 | "objects": ["bus", "street", "line", "sky", "car", "kite"], 39 | "relationships": [ 40 | [0, "on", 1], 41 | [2, "on", 1], 42 | [3, "above", 1], 43 | [4, "on", 1], 44 | [5, "in", 3], 45 | [4, "below", 5] 46 | ] 47 | }, 48 | { 49 | "objects": ["bus", "street", "line", "sky", "car", "building"], 50 | "relationships": [ 51 | [0, "on", 1], 52 | [2, "on", 1], 53 | [3, "above", 1], 54 | [4, "on", 1], 55 | [5, "behind", 1] 56 | ] 57 | }, 58 | { 59 | "objects": ["bus", "street", "line", "sky", "car", "building", "window"], 60 | "relationships": [ 61 | [0, "on", 1], 62 | [2, "on", 1], 63 | [3, "above", 1], 64 | [4, "on", 1], 65 | [5, "behind", 1], 66 | [6, "on", 5] 67 | ] 68 | } 69 | ] 70 | -------------------------------------------------------------------------------- /scripts/download_ablated_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | # Download all models from the ablation study 18 | mkdir -p sg2im-models 19 | 20 | # COCO models 21 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_gconv.pt -O sg2im-models/coco64_no_gconv.pt 22 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_relations.pt -O sg2im-models/coco64_no_relations.pt 23 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_discriminators.pt -O sg2im-models/coco64_no_discriminators.pt 24 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_img_discriminator.pt -O sg2im-models/coco64_no_img_discriminator.pt 25 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_obj_discriminator.pt -O sg2im-models/coco64_no_obj_discriminator.pt 26 | wget https://storage.googleapis.com/sg2im-data/small/coco64_gt_layout.pt -O sg2im-models/coco64_gt_layout.pt 27 | wget https://storage.googleapis.com/sg2im-data/small/coco64_gt_layout_no_gconv.pt -O sg2im-models/coco64_gt_layout_no_gconv.pt 28 | 29 | # VG models 30 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_relations.pt -O sg2im-models/vg64_no_relations.pt 31 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_gconv.pt -O sg2im-models/vg64_no_gconv.pt 32 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_discriminators.pt -O sg2im-models/vg64_no_discriminators.pt 33 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_img_discriminator.pt -O sg2im-models/vg64_no_img_discriminator.pt 34 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_obj_discriminator.pt -O sg2im-models/vg64_no_obj_discriminator.pt 35 | -------------------------------------------------------------------------------- /scripts/download_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | COCO_DIR=datasets/coco 18 | mkdir -p $COCO_DIR 19 | 20 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O $COCO_DIR/annotations_trainval2017.zip 21 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip -O $COCO_DIR/stuff_annotations_trainval2017.zip 22 | wget http://images.cocodataset.org/zips/train2017.zip -O $COCO_DIR/train2017.zip 23 | wget http://images.cocodataset.org/zips/val2017.zip -O $COCO_DIR/val2017.zip 24 | 25 | unzip $COCO_DIR/annotations_trainval2017.zip -d $COCO_DIR 26 | unzip $COCO_DIR/stuff_annotations_trainval2017.zip -d $COCO_DIR 27 | unzip $COCO_DIR/train2017.zip -d $COCO_DIR/images 28 | unzip $COCO_DIR/val2017.zip -d $COCO_DIR/images 29 | -------------------------------------------------------------------------------- /scripts/download_full_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Download full versions of all models which include training history 18 | mkdir -p sg2im-models/full 19 | 20 | # COCO models 21 | wget https://storage.googleapis.com/sg2im-data/full/coco64.pt -O sg2im-models/full/coco64.pt 22 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_gconv.pt -O sg2im-models/full/coco64_no_gconv.pt 23 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_relations.pt -O sg2im-models/full/coco64_no_relations.pt 24 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_discriminators.pt -O sg2im-models/full/coco64_no_discriminators.pt 25 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_img_discriminator.pt -O sg2im-models/full/coco64_no_img_discriminator.pt 26 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_obj_discriminator.pt -O sg2im-models/full/coco64_no_obj_discriminator.pt 27 | wget https://storage.googleapis.com/sg2im-data/full/coco64_gt_layout.pt -O sg2im-models/full/coco64_gt_layout.pt 28 | wget https://storage.googleapis.com/sg2im-data/full/coco64_gt_layout_no_gconv.pt -O sg2im-models/full/coco64_gt_layout_no_gconv.pt 29 | 30 | # VG models 31 | wget https://storage.googleapis.com/sg2im-data/full/vg64.pt -O sg2im-models/full/vg64.pt 32 | wget https://storage.googleapis.com/sg2im-data/full/vg128.pt -O sg2im-models/full/vg128.pt 33 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_relations.pt -O sg2im-models/full/vg64_no_relations.pt 34 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_gconv.pt -O sg2im-models/full/vg64_no_gconv.pt 35 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_discriminators.pt -O sg2im-models/full/vg64_no_discriminators.pt 36 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_img_discriminator.pt -O sg2im-models/full/vg64_no_img_discriminator.pt 37 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_obj_discriminator.pt -O 
sg2im-models/full/vg64_no_obj_discriminator.pt 38 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Download the main models: 64 x 64 for coco and vg, and 128 x 128 for vg 18 | mkdir -p sg2im-models 19 | wget https://storage.googleapis.com/sg2im-data/small/coco64.pt -O sg2im-models/coco64.pt 20 | wget https://storage.googleapis.com/sg2im-data/small/vg64.pt -O sg2im-models/vg64.pt 21 | wget https://storage.googleapis.com/sg2im-data/small/vg128.pt -O sg2im-models/vg128.pt 22 | -------------------------------------------------------------------------------- /scripts/download_vg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | VG_DIR=datasets/vg 18 | mkdir -p $VG_DIR 19 | 20 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip 21 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip 22 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip 23 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt 24 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt 25 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip 26 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip 27 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip 28 | 29 | unzip $VG_DIR/objects.json.zip -d $VG_DIR 30 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR 31 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR 32 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR 33 | unzip $VG_DIR/images.zip -d $VG_DIR/images 34 | unzip $VG_DIR/images2.zip -d $VG_DIR/images 35 | -------------------------------------------------------------------------------- /scripts/preprocess_vg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse, json, os 18 | from collections import Counter, defaultdict 19 | 20 | import numpy as np 21 | import h5py 22 | from scipy.misc import imread, imresize 23 | 24 | 25 | """ 26 | vocab for objects contains a special entry "__image__" intended to be used for 27 | dummy nodes encompassing the entire image; vocab for predicates includes a 28 | special entry "__in_image__" to be used for dummy relationships making the graph 29 | fully-connected. 
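
As a rough illustration (hypothetical, abbreviated entries; the real category
lists depend on the min-instance thresholds defined below), the vocab written
to vocab.json has the form:

    vocab = {
      'object_name_to_idx':    {'__image__': 0, 'sky': 1, 'tree': 2, ...},
      'object_idx_to_name':    ['__image__', 'sky', 'tree', ...],
      'attribute_name_to_idx': {'white': 0, 'green': 1, ...},
      'attribute_idx_to_name': ['white', 'green', ...],
      'pred_name_to_idx':      {'__in_image__': 0, 'on': 1, ...},
      'pred_idx_to_name':      ['__in_image__', 'on', ...],
    }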
30 | """ 31 | 32 | 33 | VG_DIR = 'datasets/vg' 34 | 35 | parser = argparse.ArgumentParser() 36 | 37 | # Input data 38 | parser.add_argument('--splits_json', default='sg2im/data/vg_splits.json') 39 | parser.add_argument('--images_json', 40 | default=os.path.join(VG_DIR, 'image_data.json')) 41 | parser.add_argument('--objects_json', 42 | default=os.path.join(VG_DIR, 'objects.json')) 43 | parser.add_argument('--attributes_json', 44 | default=os.path.join(VG_DIR, 'attributes.json')) 45 | parser.add_argument('--object_aliases', 46 | default=os.path.join(VG_DIR, 'object_alias.txt')) 47 | parser.add_argument('--relationship_aliases', 48 | default=os.path.join(VG_DIR, 'relationship_alias.txt')) 49 | parser.add_argument('--relationships_json', 50 | default=os.path.join(VG_DIR, 'relationships.json')) 51 | 52 | # Arguments for images 53 | parser.add_argument('--min_image_size', default=200, type=int) 54 | parser.add_argument('--train_split', default='train') 55 | 56 | # Arguments for objects 57 | parser.add_argument('--min_object_instances', default=2000, type=int) 58 | parser.add_argument('--min_attribute_instances', default=2000, type=int) 59 | parser.add_argument('--min_object_size', default=32, type=int) 60 | parser.add_argument('--min_objects_per_image', default=3, type=int) 61 | parser.add_argument('--max_objects_per_image', default=30, type=int) 62 | parser.add_argument('--max_attributes_per_image', default=30, type=int) 63 | 64 | # Arguments for relationships 65 | parser.add_argument('--min_relationship_instances', default=500, type=int) 66 | parser.add_argument('--min_relationships_per_image', default=1, type=int) 67 | parser.add_argument('--max_relationships_per_image', default=30, type=int) 68 | 69 | # Output 70 | parser.add_argument('--output_vocab_json', 71 | default=os.path.join(VG_DIR, 'vocab.json')) 72 | parser.add_argument('--output_h5_dir', default=VG_DIR) 73 | 74 | 75 | def main(args): 76 | print('Loading image info from "%s"' % args.images_json) 77 | with open(args.images_json, 'r') as f: 78 | images = json.load(f) 79 | image_id_to_image = {i['image_id']: i for i in images} 80 | 81 | with open(args.splits_json, 'r') as f: 82 | splits = json.load(f) 83 | 84 | # Filter images for being too small 85 | splits = remove_small_images(args, image_id_to_image, splits) 86 | 87 | obj_aliases = load_aliases(args.object_aliases) 88 | rel_aliases = load_aliases(args.relationship_aliases) 89 | 90 | print('Loading objects from "%s"' % args.objects_json) 91 | with open(args.objects_json, 'r') as f: 92 | objects = json.load(f) 93 | 94 | # Vocab for objects and relationships 95 | vocab = {} 96 | train_ids = splits[args.train_split] 97 | create_object_vocab(args, train_ids, objects, obj_aliases, vocab) 98 | 99 | print('Loading attributes from "%s"' % args.attributes_json) 100 | with open(args.attributes_json, 'r') as f: 101 | attributes = json.load(f) 102 | 103 | # Vocab for attributes 104 | create_attribute_vocab(args, train_ids, attributes, vocab) 105 | 106 | object_id_to_obj = filter_objects(args, objects, obj_aliases, vocab, splits) 107 | print('After filtering there are %d object instances' 108 | % len(object_id_to_obj)) 109 | 110 | print('Loading relationshps from "%s"' % args.relationships_json) 111 | with open(args.relationships_json, 'r') as f: 112 | relationships = json.load(f) 113 | 114 | create_rel_vocab(args, train_ids, relationships, object_id_to_obj, 115 | rel_aliases, vocab) 116 | 117 | print('Encoding objects and relationships ...') 118 | numpy_arrays = encode_graphs(args, splits, 
objects, relationships, vocab, 119 | object_id_to_obj, attributes) 120 | 121 | print('Writing HDF5 output files') 122 | for split_name, split_arrays in numpy_arrays.items(): 123 | image_ids = list(split_arrays['image_ids'].astype(int)) 124 | h5_path = os.path.join(args.output_h5_dir, '%s.h5' % split_name) 125 | print('Writing file "%s"' % h5_path) 126 | with h5py.File(h5_path, 'w') as h5_file: 127 | for name, ary in split_arrays.items(): 128 | print('Creating datset: ', name, ary.shape, ary.dtype) 129 | h5_file.create_dataset(name, data=ary) 130 | print('Writing image paths') 131 | image_paths = get_image_paths(image_id_to_image, image_ids) 132 | path_dtype = h5py.special_dtype(vlen=str) 133 | path_shape = (len(image_paths),) 134 | path_dset = h5_file.create_dataset('image_paths', path_shape, 135 | dtype=path_dtype) 136 | for i, p in enumerate(image_paths): 137 | path_dset[i] = p 138 | print() 139 | 140 | print('Writing vocab to "%s"' % args.output_vocab_json) 141 | with open(args.output_vocab_json, 'w') as f: 142 | json.dump(vocab, f) 143 | 144 | def remove_small_images(args, image_id_to_image, splits): 145 | new_splits = {} 146 | for split_name, image_ids in splits.items(): 147 | new_image_ids = [] 148 | num_skipped = 0 149 | for image_id in image_ids: 150 | image = image_id_to_image[image_id] 151 | height, width = image['height'], image['width'] 152 | if min(height, width) < args.min_image_size: 153 | num_skipped += 1 154 | continue 155 | new_image_ids.append(image_id) 156 | new_splits[split_name] = new_image_ids 157 | print('Removed %d images from split "%s" for being too small' % 158 | (num_skipped, split_name)) 159 | 160 | return new_splits 161 | 162 | 163 | def get_image_paths(image_id_to_image, image_ids): 164 | paths = [] 165 | for image_id in image_ids: 166 | image = image_id_to_image[image_id] 167 | base, filename = os.path.split(image['url']) 168 | path = os.path.join(os.path.basename(base), filename) 169 | paths.append(path) 170 | return paths 171 | 172 | 173 | def handle_images(args, image_ids, h5_file): 174 | with open(args.images_json, 'r') as f: 175 | images = json.load(f) 176 | if image_ids: 177 | image_ids = set(image_ids) 178 | 179 | image_heights, image_widths = [], [] 180 | image_ids_out, image_paths = [], [] 181 | for image in images: 182 | image_id = image['image_id'] 183 | if image_ids and image_id not in image_ids: 184 | continue 185 | height, width = image['height'], image['width'] 186 | 187 | base, filename = os.path.split(image['url']) 188 | path = os.path.join(os.path.basename(base), filename) 189 | image_paths.append(path) 190 | image_heights.append(height) 191 | image_widths.append(width) 192 | image_ids_out.append(image_id) 193 | 194 | image_ids_np = np.asarray(image_ids_out, dtype=int) 195 | h5_file.create_dataset('image_ids', data=image_ids_np) 196 | 197 | image_heights = np.asarray(image_heights, dtype=int) 198 | h5_file.create_dataset('image_heights', data=image_heights) 199 | 200 | image_widths = np.asarray(image_widths, dtype=int) 201 | h5_file.create_dataset('image_widths', data=image_widths) 202 | 203 | return image_paths 204 | 205 | 206 | def load_aliases(alias_path): 207 | aliases = {} 208 | print('Loading aliases from "%s"' % alias_path) 209 | with open(alias_path, 'r') as f: 210 | for line in f: 211 | line = [s.strip() for s in line.split(',')] 212 | for s in line: 213 | aliases[s] = line[0] 214 | return aliases 215 | 216 | 217 | def create_object_vocab(args, image_ids, objects, aliases, vocab): 218 | image_ids = set(image_ids) 219 | 220 | 
print('Making object vocab from %d training images' % len(image_ids)) 221 | object_name_counter = Counter() 222 | for image in objects: 223 | if image['image_id'] not in image_ids: 224 | continue 225 | for obj in image['objects']: 226 | names = set() 227 | for name in obj['names']: 228 | names.add(aliases.get(name, name)) 229 | object_name_counter.update(names) 230 | 231 | object_names = ['__image__'] 232 | for name, count in object_name_counter.most_common(): 233 | if count >= args.min_object_instances: 234 | object_names.append(name) 235 | print('Found %d object categories with >= %d training instances' % 236 | (len(object_names), args.min_object_instances)) 237 | 238 | object_name_to_idx = {} 239 | object_idx_to_name = [] 240 | for idx, name in enumerate(object_names): 241 | object_name_to_idx[name] = idx 242 | object_idx_to_name.append(name) 243 | 244 | vocab['object_name_to_idx'] = object_name_to_idx 245 | vocab['object_idx_to_name'] = object_idx_to_name 246 | 247 | def create_attribute_vocab(args, image_ids, attributes, vocab): 248 | image_ids = set(image_ids) 249 | print('Making attribute vocab from %d training images' % len(image_ids)) 250 | attribute_name_counter = Counter() 251 | for image in attributes: 252 | if image['image_id'] not in image_ids: 253 | continue 254 | for attribute in image['attributes']: 255 | names = set() 256 | try: 257 | for name in attribute['attributes']: 258 | names.add(name) 259 | attribute_name_counter.update(names) 260 | except KeyError: 261 | pass 262 | attribute_names = [] 263 | for name, count in attribute_name_counter.most_common(): 264 | if count >= args.min_attribute_instances: 265 | attribute_names.append(name) 266 | print('Found %d attribute categories with >= %d training instances' % 267 | (len(attribute_names), args.min_attribute_instances)) 268 | 269 | attribute_name_to_idx = {} 270 | attribute_idx_to_name = [] 271 | for idx, name in enumerate(attribute_names): 272 | attribute_name_to_idx[name] = idx 273 | attribute_idx_to_name.append(name) 274 | vocab['attribute_name_to_idx'] = attribute_name_to_idx 275 | vocab['attribute_idx_to_name'] = attribute_idx_to_name 276 | 277 | def filter_objects(args, objects, aliases, vocab, splits): 278 | object_id_to_objects = {} 279 | all_image_ids = set() 280 | for image_ids in splits.values(): 281 | all_image_ids |= set(image_ids) 282 | 283 | object_name_to_idx = vocab['object_name_to_idx'] 284 | object_id_to_obj = {} 285 | 286 | num_too_small = 0 287 | for image in objects: 288 | image_id = image['image_id'] 289 | if image_id not in all_image_ids: 290 | continue 291 | for obj in image['objects']: 292 | object_id = obj['object_id'] 293 | final_name = None 294 | final_name_idx = None 295 | for name in obj['names']: 296 | name = aliases.get(name, name) 297 | if name in object_name_to_idx: 298 | final_name = name 299 | final_name_idx = object_name_to_idx[final_name] 300 | break 301 | w, h = obj['w'], obj['h'] 302 | too_small = (w < args.min_object_size) or (h < args.min_object_size) 303 | if too_small: 304 | num_too_small += 1 305 | if final_name is not None and not too_small: 306 | object_id_to_obj[object_id] = { 307 | 'name': final_name, 308 | 'name_idx': final_name_idx, 309 | 'box': [obj['x'], obj['y'], obj['w'], obj['h']], 310 | } 311 | print('Skipped %d objects with size < %d' % (num_too_small, args.min_object_size)) 312 | return object_id_to_obj 313 | 314 | 315 | def create_rel_vocab(args, image_ids, relationships, object_id_to_obj, 316 | rel_aliases, vocab): 317 | pred_counter = defaultdict(int) 318 | 
image_ids_set = set(image_ids) 319 | for image in relationships: 320 | image_id = image['image_id'] 321 | if image_id not in image_ids_set: 322 | continue 323 | for rel in image['relationships']: 324 | sid = rel['subject']['object_id'] 325 | oid = rel['object']['object_id'] 326 | found_subject = sid in object_id_to_obj 327 | found_object = oid in object_id_to_obj 328 | if not found_subject or not found_object: 329 | continue 330 | pred = rel['predicate'].lower().strip() 331 | pred = rel_aliases.get(pred, pred) 332 | rel['predicate'] = pred 333 | pred_counter[pred] += 1 334 | 335 | pred_names = ['__in_image__'] 336 | for pred, count in pred_counter.items(): 337 | if count >= args.min_relationship_instances: 338 | pred_names.append(pred) 339 | print('Found %d relationship types with >= %d training instances' 340 | % (len(pred_names), args.min_relationship_instances)) 341 | 342 | pred_name_to_idx = {} 343 | pred_idx_to_name = [] 344 | for idx, name in enumerate(pred_names): 345 | pred_name_to_idx[name] = idx 346 | pred_idx_to_name.append(name) 347 | 348 | vocab['pred_name_to_idx'] = pred_name_to_idx 349 | vocab['pred_idx_to_name'] = pred_idx_to_name 350 | 351 | 352 | def encode_graphs(args, splits, objects, relationships, vocab, 353 | object_id_to_obj, attributes): 354 | 355 | image_id_to_objects = {} 356 | for image in objects: 357 | image_id = image['image_id'] 358 | image_id_to_objects[image_id] = image['objects'] 359 | image_id_to_relationships = {} 360 | for image in relationships: 361 | image_id = image['image_id'] 362 | image_id_to_relationships[image_id] = image['relationships'] 363 | image_id_to_attributes = {} 364 | for image in attributes: 365 | image_id = image['image_id'] 366 | image_id_to_attributes[image_id] = image['attributes'] 367 | 368 | numpy_arrays = {} 369 | for split, image_ids in splits.items(): 370 | skip_stats = defaultdict(int) 371 | # We need to filter *again* based on number of objects and relationships 372 | final_image_ids = [] 373 | object_ids = [] 374 | object_names = [] 375 | object_boxes = [] 376 | objects_per_image = [] 377 | relationship_ids = [] 378 | relationship_subjects = [] 379 | relationship_predicates = [] 380 | relationship_objects = [] 381 | relationships_per_image = [] 382 | attribute_ids = [] 383 | attributes_per_object = [] 384 | object_attributes = [] 385 | for image_id in image_ids: 386 | image_object_ids = [] 387 | image_object_names = [] 388 | image_object_boxes = [] 389 | object_id_to_idx = {} 390 | for obj in image_id_to_objects[image_id]: 391 | object_id = obj['object_id'] 392 | if object_id not in object_id_to_obj: 393 | continue 394 | obj = object_id_to_obj[object_id] 395 | object_id_to_idx[object_id] = len(image_object_ids) 396 | image_object_ids.append(object_id) 397 | image_object_names.append(obj['name_idx']) 398 | image_object_boxes.append(obj['box']) 399 | num_objects = len(image_object_ids) 400 | too_few = num_objects < args.min_objects_per_image 401 | too_many = num_objects > args.max_objects_per_image 402 | if too_few: 403 | skip_stats['too_few_objects'] += 1 404 | continue 405 | if too_many: 406 | skip_stats['too_many_objects'] += 1 407 | continue 408 | image_rel_ids = [] 409 | image_rel_subs = [] 410 | image_rel_preds = [] 411 | image_rel_objs = [] 412 | for rel in image_id_to_relationships[image_id]: 413 | relationship_id = rel['relationship_id'] 414 | pred = rel['predicate'] 415 | pred_idx = vocab['pred_name_to_idx'].get(pred, None) 416 | if pred_idx is None: 417 | continue 418 | sid = rel['subject']['object_id'] 419 | 
sidx = object_id_to_idx.get(sid, None) 420 | oid = rel['object']['object_id'] 421 | oidx = object_id_to_idx.get(oid, None) 422 | if sidx is None or oidx is None: 423 | continue 424 | image_rel_ids.append(relationship_id) 425 | image_rel_subs.append(sidx) 426 | image_rel_preds.append(pred_idx) 427 | image_rel_objs.append(oidx) 428 | num_relationships = len(image_rel_ids) 429 | too_few = num_relationships < args.min_relationships_per_image 430 | too_many = num_relationships > args.max_relationships_per_image 431 | if too_few: 432 | skip_stats['too_few_relationships'] += 1 433 | continue 434 | if too_many: 435 | skip_stats['too_many_relationships'] += 1 436 | continue 437 | 438 | obj_id_to_attributes = {} 439 | num_attributes = [] 440 | for obj_attribute in image_id_to_attributes[image_id]: 441 | obj_id_to_attributes[obj_attribute['object_id']] = obj_attribute.get('attributes', None) 442 | for object_id in image_object_ids: 443 | attributes = obj_id_to_attributes.get(object_id, None) 444 | if attributes is None: 445 | object_attributes.append([-1] * args.max_attributes_per_image) 446 | num_attributes.append(0) 447 | else: 448 | attribute_ids = [] 449 | for attribute in attributes: 450 | if attribute in vocab['attribute_name_to_idx']: 451 | attribute_ids.append(vocab['attribute_name_to_idx'][attribute]) 452 | if len(attribute_ids) >= args.max_attributes_per_image: 453 | break 454 | num_attributes.append(len(attribute_ids)) 455 | pad_len = args.max_attributes_per_image - len(attribute_ids) 456 | attribute_ids = attribute_ids + [-1] * pad_len 457 | object_attributes.append(attribute_ids) 458 | 459 | # Pad object info out to max_objects_per_image 460 | while len(image_object_ids) < args.max_objects_per_image: 461 | image_object_ids.append(-1) 462 | image_object_names.append(-1) 463 | image_object_boxes.append([-1, -1, -1, -1]) 464 | num_attributes.append(-1) 465 | 466 | # Pad relationship info out to max_relationships_per_image 467 | while len(image_rel_ids) < args.max_relationships_per_image: 468 | image_rel_ids.append(-1) 469 | image_rel_subs.append(-1) 470 | image_rel_preds.append(-1) 471 | image_rel_objs.append(-1) 472 | 473 | final_image_ids.append(image_id) 474 | object_ids.append(image_object_ids) 475 | object_names.append(image_object_names) 476 | object_boxes.append(image_object_boxes) 477 | objects_per_image.append(num_objects) 478 | relationship_ids.append(image_rel_ids) 479 | relationship_subjects.append(image_rel_subs) 480 | relationship_predicates.append(image_rel_preds) 481 | relationship_objects.append(image_rel_objs) 482 | relationships_per_image.append(num_relationships) 483 | attributes_per_object.append(num_attributes) 484 | 485 | print('Skip stats for split "%s"' % split) 486 | for stat, count in skip_stats.items(): 487 | print(stat, count) 488 | print() 489 | numpy_arrays[split] = { 490 | 'image_ids': np.asarray(final_image_ids), 491 | 'object_ids': np.asarray(object_ids), 492 | 'object_names': np.asarray(object_names), 493 | 'object_boxes': np.asarray(object_boxes), 494 | 'objects_per_image': np.asarray(objects_per_image), 495 | 'relationship_ids': np.asarray(relationship_ids), 496 | 'relationship_subjects': np.asarray(relationship_subjects), 497 | 'relationship_predicates': np.asarray(relationship_predicates), 498 | 'relationship_objects': np.asarray(relationship_objects), 499 | 'relationships_per_image': np.asarray(relationships_per_image), 500 | 'attributes_per_object': np.asarray(attributes_per_object), 501 | 'object_attributes': np.asarray(object_attributes), 502 | } 
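    # Note: the per-image rows above are padded with -1 out to
    # max_objects_per_image, max_relationships_per_image and
    # max_attributes_per_image; objects_per_image, relationships_per_image and
    # attributes_per_object record how many entries of each row are real, so
    # the -1 padding can presumably be masked out by downstream loaders.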
503 | for k, v in numpy_arrays[split].items(): 504 | if v.dtype == np.int64: 505 | numpy_arrays[split][k] = v.astype(np.int32) 506 | return numpy_arrays 507 | 508 | 509 | if __name__ == '__main__': 510 | args = parser.parse_args() 511 | main(args) 512 | -------------------------------------------------------------------------------- /scripts/print_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import torch 19 | 20 | 21 | """ 22 | Tiny utility to print the command-line args used for a checkpoint 23 | """ 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('checkpoint') 28 | 29 | 30 | def main(args): 31 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 32 | for k, v in checkpoint['args'].items(): 33 | print(k, v) 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | main(args) 39 | 40 | -------------------------------------------------------------------------------- /scripts/run_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse, json, os 18 | 19 | from imageio import imwrite 20 | import torch 21 | 22 | from sg2im.model import Sg2ImModel 23 | from sg2im.data.utils import imagenet_deprocess_batch 24 | import sg2im.vis as vis 25 | 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--checkpoint', default='sg2im-models/vg128.pt') 29 | parser.add_argument('--scene_graphs_json', default='scene_graphs/figure_6_sheep.json') 30 | parser.add_argument('--output_dir', default='outputs') 31 | parser.add_argument('--draw_scene_graphs', type=int, default=0) 32 | parser.add_argument('--device', default='gpu', choices=['cpu', 'gpu']) 33 | 34 | 35 | def main(args): 36 | if not os.path.isfile(args.checkpoint): 37 | print('ERROR: Checkpoint file "%s" not found' % args.checkpoint) 38 | print('Maybe you forgot to download pretraind models? 
Try running:') 39 | print('bash scripts/download_models.sh') 40 | return 41 | 42 | if not os.path.isdir(args.output_dir): 43 | print('Output directory "%s" does not exist; creating it' % args.output_dir) 44 | os.makedirs(args.output_dir) 45 | 46 | if args.device == 'cpu': 47 | device = torch.device('cpu') 48 | elif args.device == 'gpu': 49 | device = torch.device('cuda:0') 50 | if not torch.cuda.is_available(): 51 | print('WARNING: CUDA not available; falling back to CPU') 52 | device = torch.device('cpu') 53 | 54 | # Load the model, with a bit of care in case there are no GPUs 55 | map_location = 'cpu' if device == torch.device('cpu') else None 56 | checkpoint = torch.load(args.checkpoint, map_location=map_location) 57 | model = Sg2ImModel(**checkpoint['model_kwargs']) 58 | model.load_state_dict(checkpoint['model_state']) 59 | model.eval() 60 | model.to(device) 61 | 62 | # Load the scene graphs 63 | with open(args.scene_graphs_json, 'r') as f: 64 | scene_graphs = json.load(f) 65 | 66 | # Run the model forward 67 | with torch.no_grad(): 68 | imgs, boxes_pred, masks_pred, _ = model.forward_json(scene_graphs) 69 | imgs = imagenet_deprocess_batch(imgs) 70 | 71 | # Save the generated images 72 | for i in range(imgs.shape[0]): 73 | img_np = imgs[i].numpy().transpose(1, 2, 0) 74 | img_path = os.path.join(args.output_dir, 'img%06d.png' % i) 75 | imwrite(img_path, img_np) 76 | 77 | # Draw the scene graphs 78 | if args.draw_scene_graphs == 1: 79 | for i, sg in enumerate(scene_graphs): 80 | sg_img = vis.draw_scene_graph(sg['objects'], sg['relationships']) 81 | sg_img_path = os.path.join(args.output_dir, 'sg%06d.png' % i) 82 | imwrite(sg_img_path, sg_img) 83 | 84 | 85 | if __name__ == '__main__': 86 | args = parser.parse_args() 87 | main(args) 88 | 89 | -------------------------------------------------------------------------------- /scripts/sample_images.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | This script can be used to sample many images from a model for evaluation. 
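
A typical invocation (paths are illustrative and depend on where checkpoints
and datasets were placed) looks something like:

    python scripts/sample_images.py \
      --checkpoint sg2im-models/vg64.pt \
      --dataset vg \
      --output_dir outputs/vg64_val_samples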
19 | """ 20 | 21 | 22 | import argparse, json 23 | import os 24 | 25 | import torch 26 | from torch.autograd import Variable 27 | from torch.utils.data import DataLoader 28 | 29 | from scipy.misc import imsave, imresize 30 | 31 | from sg2im.data import imagenet_deprocess_batch 32 | from sg2im.data.coco import CocoSceneGraphDataset, coco_collate_fn 33 | from sg2im.data.vg import VgSceneGraphDataset, vg_collate_fn 34 | from sg2im.data.utils import split_graph_batch 35 | from sg2im.model import Sg2ImModel 36 | from sg2im.utils import int_tuple, bool_flag 37 | from sg2im.vis import draw_scene_graph 38 | 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--checkpoint', default='sg2im-models/vg64.pt') 42 | parser.add_argument('--checkpoint_list', default=None) 43 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval']) 44 | 45 | # Shared dataset options 46 | parser.add_argument('--dataset', default='vg', choices=['coco', 'vg']) 47 | parser.add_argument('--image_size', default=(64, 64), type=int_tuple) 48 | parser.add_argument('--batch_size', default=24, type=int) 49 | parser.add_argument('--shuffle', default=False, type=bool_flag) 50 | parser.add_argument('--loader_num_workers', default=4, type=int) 51 | parser.add_argument('--num_samples', default=10000, type=int) 52 | parser.add_argument('--save_gt_imgs', default=False, type=bool_flag) 53 | parser.add_argument('--save_graphs', default=False, type=bool_flag) 54 | parser.add_argument('--use_gt_boxes', default=False, type=bool_flag) 55 | parser.add_argument('--use_gt_masks', default=False, type=bool_flag) 56 | parser.add_argument('--save_layout', default=True, type=bool_flag) 57 | 58 | parser.add_argument('--output_dir', default='output') 59 | 60 | # For VG 61 | VG_DIR = os.path.expanduser('datasets/vg') 62 | parser.add_argument('--vg_h5', default=os.path.join(VG_DIR, 'val.h5')) 63 | parser.add_argument('--vg_image_dir', 64 | default=os.path.join(VG_DIR, 'images')) 65 | 66 | # For COCO 67 | COCO_DIR = os.path.expanduser('~/datasets/coco/2017') 68 | parser.add_argument('--coco_image_dir', 69 | default=os.path.join(COCO_DIR, 'images/val2017')) 70 | parser.add_argument('--instances_json', 71 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json')) 72 | parser.add_argument('--stuff_json', 73 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json')) 74 | 75 | 76 | def build_coco_dset(args, checkpoint): 77 | checkpoint_args = checkpoint['args'] 78 | print('include other: ', checkpoint_args.get('coco_include_other')) 79 | dset_kwargs = { 80 | 'image_dir': args.coco_image_dir, 81 | 'instances_json': args.instances_json, 82 | 'stuff_json': args.stuff_json, 83 | 'stuff_only': checkpoint_args['coco_stuff_only'], 84 | 'image_size': args.image_size, 85 | 'mask_size': checkpoint_args['mask_size'], 86 | 'max_samples': args.num_samples, 87 | 'min_object_size': checkpoint_args['min_object_size'], 88 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'], 89 | 'instance_whitelist': checkpoint_args['instance_whitelist'], 90 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'], 91 | 'include_other': checkpoint_args.get('coco_include_other', True), 92 | } 93 | dset = CocoSceneGraphDataset(**dset_kwargs) 94 | return dset 95 | 96 | 97 | def build_vg_dset(args, checkpoint): 98 | vocab = checkpoint['model_kwargs']['vocab'] 99 | dset_kwargs = { 100 | 'vocab': vocab, 101 | 'h5_path': args.vg_h5, 102 | 'image_dir': args.vg_image_dir, 103 | 'image_size': args.image_size, 104 | 'max_samples': 
args.num_samples, 105 | 'max_objects': checkpoint['args']['max_objects_per_image'], 106 | 'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'], 107 | } 108 | dset = VgSceneGraphDataset(**dset_kwargs) 109 | return dset 110 | 111 | 112 | def build_loader(args, checkpoint): 113 | if args.dataset == 'coco': 114 | dset = build_coco_dset(args, checkpoint) 115 | collate_fn = coco_collate_fn 116 | elif args.dataset == 'vg': 117 | dset = build_vg_dset(args, checkpoint) 118 | collate_fn = vg_collate_fn 119 | 120 | loader_kwargs = { 121 | 'batch_size': args.batch_size, 122 | 'num_workers': args.loader_num_workers, 123 | 'shuffle': args.shuffle, 124 | 'collate_fn': collate_fn, 125 | } 126 | loader = DataLoader(dset, **loader_kwargs) 127 | return loader 128 | 129 | 130 | def build_model(args, checkpoint): 131 | kwargs = checkpoint['model_kwargs'] 132 | model = Sg2ImModel(**checkpoint['model_kwargs']) 133 | model.load_state_dict(checkpoint['model_state']) 134 | if args.model_mode == 'eval': 135 | model.eval() 136 | elif args.model_mode == 'train': 137 | model.train() 138 | model.image_size = args.image_size 139 | model.cuda() 140 | return model 141 | 142 | 143 | def makedir(base, name, flag=True): 144 | dir_name = None 145 | if flag: 146 | dir_name = os.path.join(base, name) 147 | if not os.path.isdir(dir_name): 148 | os.makedirs(dir_name) 149 | return dir_name 150 | 151 | 152 | def run_model(args, checkpoint, output_dir, loader=None): 153 | vocab = checkpoint['model_kwargs']['vocab'] 154 | model = build_model(args, checkpoint) 155 | if loader is None: 156 | loader = build_loader(args, checkpoint) 157 | 158 | img_dir = makedir(output_dir, 'images') 159 | graph_dir = makedir(output_dir, 'graphs', args.save_graphs) 160 | gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs) 161 | data_path = os.path.join(output_dir, 'data.pt') 162 | 163 | data = { 164 | 'vocab': vocab, 165 | 'objs': [], 166 | 'masks_pred': [], 167 | 'boxes_pred': [], 168 | 'masks_gt': [], 169 | 'boxes_gt': [], 170 | 'filenames': [], 171 | } 172 | 173 | img_idx = 0 174 | for batch in loader: 175 | masks = None 176 | if len(batch) == 6: 177 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = [x.cuda() for x in batch] 178 | elif len(batch) == 7: 179 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = [x.cuda() for x in batch] 180 | 181 | imgs_gt = imagenet_deprocess_batch(imgs) 182 | boxes_gt = None 183 | masks_gt = None 184 | if args.use_gt_boxes: 185 | boxes_gt = boxes 186 | if args.use_gt_masks: 187 | masks_gt = masks 188 | 189 | # Run the model with predicted masks 190 | model_out = model(objs, triples, obj_to_img, 191 | boxes_gt=boxes_gt, masks_gt=masks_gt) 192 | imgs_pred, boxes_pred, masks_pred, _ = model_out 193 | imgs_pred = imagenet_deprocess_batch(imgs_pred) 194 | 195 | obj_data = [objs, boxes_pred, masks_pred] 196 | _, obj_data = split_graph_batch(triples, obj_data, obj_to_img, 197 | triple_to_img) 198 | objs, boxes_pred, masks_pred = obj_data 199 | 200 | obj_data_gt = [boxes.data] 201 | if masks is not None: 202 | obj_data_gt.append(masks.data) 203 | triples, obj_data_gt = split_graph_batch(triples, obj_data_gt, 204 | obj_to_img, triple_to_img) 205 | boxes_gt, masks_gt = obj_data_gt[0], None 206 | if masks is not None: 207 | masks_gt = obj_data_gt[1] 208 | 209 | for i in range(imgs_pred.size(0)): 210 | img_filename = '%04d.png' % img_idx 211 | if args.save_gt_imgs: 212 | img_gt = imgs_gt[i].numpy().transpose(1, 2, 0) 213 | img_gt_path = os.path.join(gt_img_dir, img_filename) 214 | 
imsave(img_gt_path, img_gt) 215 | 216 | img_pred = imgs_pred[i] 217 | img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0) 218 | img_path = os.path.join(img_dir, img_filename) 219 | imsave(img_path, img_pred_np) 220 | 221 | data['objs'].append(objs[i].cpu().clone()) 222 | data['masks_pred'].append(masks_pred[i].cpu().clone()) 223 | data['boxes_pred'].append(boxes_pred[i].cpu().clone()) 224 | data['boxes_gt'].append(boxes_gt[i].cpu().clone()) 225 | data['filenames'].append(img_filename) 226 | 227 | cur_masks_gt = None 228 | if masks_gt is not None: 229 | cur_masks_gt = masks_gt[i].cpu().clone() 230 | data['masks_gt'].append(cur_masks_gt) 231 | 232 | if args.save_graphs: 233 | graph_img = draw_scene_graph(vocab, objs[i], triples[i]) 234 | graph_path = os.path.join(graph_dir, img_filename) 235 | imsave(graph_path, graph_img) 236 | 237 | img_idx += 1 238 | 239 | torch.save(data, data_path) 240 | print('Saved %d images' % img_idx) 241 | 242 | 243 | def main(args): 244 | got_checkpoint = args.checkpoint is not None 245 | got_checkpoint_list = args.checkpoint_list is not None 246 | if got_checkpoint == got_checkpoint_list: 247 | raise ValueError('Must specify exactly one of --checkpoint and --checkpoint_list') 248 | 249 | if got_checkpoint: 250 | checkpoint = torch.load(args.checkpoint) 251 | print('Loading model from ', args.checkpoint) 252 | run_model(args, checkpoint, args.output_dir) 253 | elif got_checkpoint_list: 254 | # For efficiency, use the same loader for all checkpoints 255 | loader = None 256 | with open(args.checkpoint_list, 'r') as f: 257 | checkpoint_list = [line.strip() for line in f] 258 | for i, path in enumerate(checkpoint_list): 259 | if os.path.isfile(path): 260 | print('Loading model from ', path) 261 | checkpoint = torch.load(path) 262 | if loader is None: 263 | loader = build_loader(args, checkpoint) 264 | output_dir = os.path.join(args.output_dir, 'result%03d' % (i + 1)) 265 | run_model(args, checkpoint, output_dir, loader) 266 | elif os.path.isdir(path): 267 | # Look for snapshots in this dir 268 | for fn in sorted(os.listdir(path)): 269 | if 'snapshot' not in fn: 270 | continue 271 | checkpoint_path = os.path.join(path, fn) 272 | print('Loading model from ', checkpoint_path) 273 | checkpoint = torch.load(checkpoint_path) 274 | if loader is None: 275 | loader = build_loader(args, checkpoint) 276 | 277 | # Snapshots have names like "snapshot_00100K.pt'; we want to 278 | # extract the "00100K" part 279 | snapshot_name = os.path.splitext(fn)[0].split('_')[1] 280 | output_dir = 'result%03d_%s' % (i, snapshot_name) 281 | output_dir = os.path.join(args.output_dir, output_dir) 282 | 283 | run_model(args, checkpoint, output_dir, loader) 284 | 285 | 286 | if __name__ == '__main__': 287 | args = parser.parse_args() 288 | main(args) 289 | 290 | 291 | -------------------------------------------------------------------------------- /scripts/strip_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse, os 18 | import torch 19 | 20 | 21 | """ 22 | Checkpoints saved by train.py contain not only model parameters but also 23 | optimizer states, losses, a history of generated images, and other statistics. 24 | This information is very useful for development and debugging models, but makes 25 | the saved checkpoints very large. This utility script strips away all extra 26 | information from saved checkpoints, keeping only the saved models. 27 | """ 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--input_checkpoint', default=None) 32 | parser.add_argument('--output_checkpoint', default=None) 33 | parser.add_argument('--input_dir', default=None) 34 | parser.add_argument('--output_dir', default=None) 35 | parser.add_argument('--keep_discriminators', type=int, default=1) 36 | 37 | 38 | def main(args): 39 | if args.input_checkpoint is not None: 40 | handle_checkpoint(args, args.input_checkpoint, args.output_checkpoint) 41 | if args.input_dir is not None: 42 | handle_dir(args, args.input_dir, args.output_dir) 43 | 44 | 45 | def handle_dir(args, input_dir, output_dir): 46 | for fn in os.listdir(input_dir): 47 | if not fn.endswith('.pt'): 48 | continue 49 | input_path = os.path.join(input_dir, fn) 50 | output_path = os.path.join(output_dir, fn) 51 | handle_checkpoint(args, input_path, output_path) 52 | 53 | 54 | def handle_checkpoint(args, input_path, output_path): 55 | input_checkpoint = torch.load(input_path) 56 | keep = ['args', 'model_state', 'model_kwargs'] 57 | if args.keep_discriminators == 1: 58 | keep += ['d_img_state', 'd_img_kwargs', 'd_obj_state', 'd_obj_kwargs'] 59 | output_checkpoint = {} 60 | for k, v in input_checkpoint.items(): 61 | if k in keep: 62 | output_checkpoint[k] = v 63 | torch.save(output_checkpoint, output_path) 64 | 65 | 66 | if __name__ == '__main__': 67 | args = parser.parse_args() 68 | main(args) 69 | 70 | -------------------------------------------------------------------------------- /scripts/strip_old_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse, os 18 | import torch 19 | 20 | 21 | """ 22 | This utility script removes deprecated kwargs in checkpoints. 
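
It takes exactly one of --input_checkpoint (a single .pt file) or --input_dir
(a directory of .pt files) and rewrites each checkpoint in place, e.g.
(illustrative paths):

    python scripts/strip_old_args.py --input_checkpoint sg2im-models/vg64.pt
    python scripts/strip_old_args.py --input_dir sg2im-models/full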
23 | """ 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--input_checkpoint', default=None) 28 | parser.add_argument('--input_dir', default=None) 29 | 30 | 31 | DEPRECATED_KWARGS = { 32 | 'model_kwargs': [ 33 | 'vec_noise_dim', 'gconv_mode', 'box_anchor', 'decouple_obj_predictions', 34 | ], 35 | } 36 | 37 | 38 | def main(args): 39 | got_checkpoint = (args.input_checkpoint is not None) 40 | got_dir = (args.input_dir is not None) 41 | assert got_checkpoint != got_dir, "Must give exactly one of checkpoint or dir" 42 | if got_checkpoint: 43 | handle_checkpoint(args.input_checkpoint) 44 | elif got_dir: 45 | handle_dir(args.input_dir) 46 | 47 | 48 | 49 | def handle_dir(dir_path): 50 | for fn in os.listdir(dir_path): 51 | if not fn.endswith('.pt'): 52 | continue 53 | checkpoint_path = os.path.join(dir_path, fn) 54 | handle_checkpoint(checkpoint_path) 55 | 56 | 57 | def handle_checkpoint(checkpoint_path): 58 | print('Stripping old args from checkpoint "%s"' % checkpoint_path) 59 | checkpoint = torch.load(checkpoint_path) 60 | for group, deprecated in DEPRECATED_KWARGS.items(): 61 | assert group in checkpoint 62 | for k in deprecated: 63 | if k in checkpoint[group]: 64 | print('Removing key "%s" from "%s"' % (k, group)) 65 | del checkpoint[group][k] 66 | torch.save(checkpoint, checkpoint_path) 67 | 68 | 69 | if __name__ == '__main__': 70 | args = parser.parse_args() 71 | main(args) 72 | 73 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import argparse 18 | import functools 19 | import os 20 | import json 21 | import math 22 | from collections import defaultdict 23 | import random 24 | 25 | import numpy as np 26 | import torch 27 | import torch.optim as optim 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | from torch.utils.data import DataLoader 31 | 32 | from sg2im.data import imagenet_deprocess_batch 33 | from sg2im.data.coco import CocoSceneGraphDataset, coco_collate_fn 34 | from sg2im.data.vg import VgSceneGraphDataset, vg_collate_fn 35 | from sg2im.discriminators import PatchDiscriminator, AcCropDiscriminator 36 | from sg2im.losses import get_gan_losses 37 | from sg2im.metrics import jaccard 38 | from sg2im.model import Sg2ImModel 39 | from sg2im.utils import int_tuple, float_tuple, str_tuple 40 | from sg2im.utils import timeit, bool_flag, LossManager 41 | 42 | torch.backends.cudnn.benchmark = True 43 | 44 | VG_DIR = os.path.expanduser('datasets/vg') 45 | COCO_DIR = os.path.expanduser('datasets/coco') 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--dataset', default='coco', choices=['vg', 'coco']) 49 | 50 | # Optimization hyperparameters 51 | parser.add_argument('--batch_size', default=32, type=int) 52 | parser.add_argument('--num_iterations', default=1000000, type=int) 53 | parser.add_argument('--learning_rate', default=1e-4, type=float) 54 | 55 | # Switch the generator to eval mode after this many iterations 56 | parser.add_argument('--eval_mode_after', default=100000, type=int) 57 | 58 | # Dataset options common to both VG and COCO 59 | parser.add_argument('--image_size', default='64,64', type=int_tuple) 60 | parser.add_argument('--num_train_samples', default=None, type=int) 61 | parser.add_argument('--num_val_samples', default=1024, type=int) 62 | parser.add_argument('--shuffle_val', default=True, type=bool_flag) 63 | parser.add_argument('--loader_num_workers', default=4, type=int) 64 | parser.add_argument('--include_relationships', default=True, type=bool_flag) 65 | 66 | # VG-specific options 67 | parser.add_argument('--vg_image_dir', default=os.path.join(VG_DIR, 'images')) 68 | parser.add_argument('--train_h5', default=os.path.join(VG_DIR, 'train.h5')) 69 | parser.add_argument('--val_h5', default=os.path.join(VG_DIR, 'val.h5')) 70 | parser.add_argument('--vocab_json', default=os.path.join(VG_DIR, 'vocab.json')) 71 | parser.add_argument('--max_objects_per_image', default=10, type=int) 72 | parser.add_argument('--vg_use_orphaned_objects', default=True, type=bool_flag) 73 | 74 | # COCO-specific options 75 | parser.add_argument('--coco_train_image_dir', 76 | default=os.path.join(COCO_DIR, 'images/train2017')) 77 | parser.add_argument('--coco_val_image_dir', 78 | default=os.path.join(COCO_DIR, 'images/val2017')) 79 | parser.add_argument('--coco_train_instances_json', 80 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json')) 81 | parser.add_argument('--coco_train_stuff_json', 82 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json')) 83 | parser.add_argument('--coco_val_instances_json', 84 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json')) 85 | parser.add_argument('--coco_val_stuff_json', 86 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json')) 87 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple) 88 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple) 89 | parser.add_argument('--coco_include_other', default=False, type=bool_flag) 90 | 
parser.add_argument('--min_object_size', default=0.02, type=float) 91 | parser.add_argument('--min_objects_per_image', default=3, type=int) 92 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag) 93 | 94 | # Generator options 95 | parser.add_argument('--mask_size', default=16, type=int) # Set this to 0 to use no masks 96 | parser.add_argument('--embedding_dim', default=128, type=int) 97 | parser.add_argument('--gconv_dim', default=128, type=int) 98 | parser.add_argument('--gconv_hidden_dim', default=512, type=int) 99 | parser.add_argument('--gconv_num_layers', default=5, type=int) 100 | parser.add_argument('--mlp_normalization', default='none', type=str) 101 | parser.add_argument('--refinement_network_dims', default='1024,512,256,128,64', type=int_tuple) 102 | parser.add_argument('--normalization', default='batch') 103 | parser.add_argument('--activation', default='leakyrelu-0.2') 104 | parser.add_argument('--layout_noise_dim', default=32, type=int) 105 | parser.add_argument('--use_boxes_pred_after', default=-1, type=int) 106 | 107 | # Generator losses 108 | parser.add_argument('--mask_loss_weight', default=0, type=float) 109 | parser.add_argument('--l1_pixel_loss_weight', default=1.0, type=float) 110 | parser.add_argument('--bbox_pred_loss_weight', default=10, type=float) 111 | parser.add_argument('--predicate_pred_loss_weight', default=0, type=float) # DEPRECATED 112 | 113 | # Generic discriminator options 114 | parser.add_argument('--discriminator_loss_weight', default=0.01, type=float) 115 | parser.add_argument('--gan_loss_type', default='gan') 116 | parser.add_argument('--d_clip', default=None, type=float) 117 | parser.add_argument('--d_normalization', default='batch') 118 | parser.add_argument('--d_padding', default='valid') 119 | parser.add_argument('--d_activation', default='leakyrelu-0.2') 120 | 121 | # Object discriminator 122 | parser.add_argument('--d_obj_arch', 123 | default='C4-64-2,C4-128-2,C4-256-2') 124 | parser.add_argument('--crop_size', default=32, type=int) 125 | parser.add_argument('--d_obj_weight', default=1.0, type=float) # multiplied by d_loss_weight 126 | parser.add_argument('--ac_loss_weight', default=0.1, type=float) 127 | 128 | # Image discriminator 129 | parser.add_argument('--d_img_arch', 130 | default='C4-64-2,C4-128-2,C4-256-2') 131 | parser.add_argument('--d_img_weight', default=1.0, type=float) # multiplied by d_loss_weight 132 | 133 | # Output options 134 | parser.add_argument('--print_every', default=10, type=int) 135 | parser.add_argument('--timing', default=False, type=bool_flag) 136 | parser.add_argument('--checkpoint_every', default=10000, type=int) 137 | parser.add_argument('--output_dir', default=os.getcwd()) 138 | parser.add_argument('--checkpoint_name', default='checkpoint') 139 | parser.add_argument('--checkpoint_start_from', default=None) 140 | parser.add_argument('--restore_from_checkpoint', default=False, type=bool_flag) 141 | 142 | 143 | def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1): 144 | curr_loss = curr_loss * weight 145 | loss_dict[loss_name] = curr_loss.item() 146 | if total_loss is not None: 147 | total_loss += curr_loss 148 | else: 149 | total_loss = curr_loss 150 | return total_loss 151 | 152 | 153 | def check_args(args): 154 | H, W = args.image_size 155 | for _ in args.refinement_network_dims[1:]: 156 | H = H // 2 157 | if H == 0: 158 | raise ValueError("Too many layers in refinement network") 159 | 160 | 161 | def build_model(args, vocab): 162 | if args.checkpoint_start_from is not None: 
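    # The loop below strips a leading 'module.' from parameter names (the
    # prefix added when a model wrapped in nn.DataParallel is saved), so that
    # such checkpoints load into a plain Sg2ImModel.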
163 | checkpoint = torch.load(args.checkpoint_start_from) 164 | kwargs = checkpoint['model_kwargs'] 165 | model = Sg2ImModel(**kwargs) 166 | raw_state_dict = checkpoint['model_state'] 167 | state_dict = {} 168 | for k, v in raw_state_dict.items(): 169 | if k.startswith('module.'): 170 | k = k[7:] 171 | state_dict[k] = v 172 | model.load_state_dict(state_dict) 173 | else: 174 | kwargs = { 175 | 'vocab': vocab, 176 | 'image_size': args.image_size, 177 | 'embedding_dim': args.embedding_dim, 178 | 'gconv_dim': args.gconv_dim, 179 | 'gconv_hidden_dim': args.gconv_hidden_dim, 180 | 'gconv_num_layers': args.gconv_num_layers, 181 | 'mlp_normalization': args.mlp_normalization, 182 | 'refinement_dims': args.refinement_network_dims, 183 | 'normalization': args.normalization, 184 | 'activation': args.activation, 185 | 'mask_size': args.mask_size, 186 | 'layout_noise_dim': args.layout_noise_dim, 187 | } 188 | model = Sg2ImModel(**kwargs) 189 | return model, kwargs 190 | 191 | 192 | def build_obj_discriminator(args, vocab): 193 | discriminator = None 194 | d_kwargs = {} 195 | d_weight = args.discriminator_loss_weight 196 | d_obj_weight = args.d_obj_weight 197 | if d_weight == 0 or d_obj_weight == 0: 198 | return discriminator, d_kwargs 199 | 200 | d_kwargs = { 201 | 'vocab': vocab, 202 | 'arch': args.d_obj_arch, 203 | 'normalization': args.d_normalization, 204 | 'activation': args.d_activation, 205 | 'padding': args.d_padding, 206 | 'object_size': args.crop_size, 207 | } 208 | discriminator = AcCropDiscriminator(**d_kwargs) 209 | return discriminator, d_kwargs 210 | 211 | 212 | def build_img_discriminator(args, vocab): 213 | discriminator = None 214 | d_kwargs = {} 215 | d_weight = args.discriminator_loss_weight 216 | d_img_weight = args.d_img_weight 217 | if d_weight == 0 or d_img_weight == 0: 218 | return discriminator, d_kwargs 219 | 220 | d_kwargs = { 221 | 'arch': args.d_img_arch, 222 | 'normalization': args.d_normalization, 223 | 'activation': args.d_activation, 224 | 'padding': args.d_padding, 225 | } 226 | discriminator = PatchDiscriminator(**d_kwargs) 227 | return discriminator, d_kwargs 228 | 229 | 230 | def build_coco_dsets(args): 231 | dset_kwargs = { 232 | 'image_dir': args.coco_train_image_dir, 233 | 'instances_json': args.coco_train_instances_json, 234 | 'stuff_json': args.coco_train_stuff_json, 235 | 'stuff_only': args.coco_stuff_only, 236 | 'image_size': args.image_size, 237 | 'mask_size': args.mask_size, 238 | 'max_samples': args.num_train_samples, 239 | 'min_object_size': args.min_object_size, 240 | 'min_objects_per_image': args.min_objects_per_image, 241 | 'instance_whitelist': args.instance_whitelist, 242 | 'stuff_whitelist': args.stuff_whitelist, 243 | 'include_other': args.coco_include_other, 244 | 'include_relationships': args.include_relationships, 245 | } 246 | train_dset = CocoSceneGraphDataset(**dset_kwargs) 247 | num_objs = train_dset.total_objects() 248 | num_imgs = len(train_dset) 249 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs)) 250 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs)) 251 | 252 | dset_kwargs['image_dir'] = args.coco_val_image_dir 253 | dset_kwargs['instances_json'] = args.coco_val_instances_json 254 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json 255 | dset_kwargs['max_samples'] = args.num_val_samples 256 | val_dset = CocoSceneGraphDataset(**dset_kwargs) 257 | 258 | assert train_dset.vocab == val_dset.vocab 259 | vocab = json.loads(json.dumps(train_dset.vocab)) 260 | 261 | return vocab, train_dset, 
val_dset 262 | 263 | 264 | def build_vg_dsets(args): 265 | with open(args.vocab_json, 'r') as f: 266 | vocab = json.load(f) 267 | dset_kwargs = { 268 | 'vocab': vocab, 269 | 'h5_path': args.train_h5, 270 | 'image_dir': args.vg_image_dir, 271 | 'image_size': args.image_size, 272 | 'max_samples': args.num_train_samples, 273 | 'max_objects': args.max_objects_per_image, 274 | 'use_orphaned_objects': args.vg_use_orphaned_objects, 275 | 'include_relationships': args.include_relationships, 276 | } 277 | train_dset = VgSceneGraphDataset(**dset_kwargs) 278 | iter_per_epoch = len(train_dset) // args.batch_size 279 | print('There are %d iterations per epoch' % iter_per_epoch) 280 | 281 | dset_kwargs['h5_path'] = args.val_h5 282 | del dset_kwargs['max_samples'] 283 | val_dset = VgSceneGraphDataset(**dset_kwargs) 284 | 285 | return vocab, train_dset, val_dset 286 | 287 | 288 | def build_loaders(args): 289 | if args.dataset == 'vg': 290 | vocab, train_dset, val_dset = build_vg_dsets(args) 291 | collate_fn = vg_collate_fn 292 | elif args.dataset == 'coco': 293 | vocab, train_dset, val_dset = build_coco_dsets(args) 294 | collate_fn = coco_collate_fn 295 | 296 | loader_kwargs = { 297 | 'batch_size': args.batch_size, 298 | 'num_workers': args.loader_num_workers, 299 | 'shuffle': True, 300 | 'collate_fn': collate_fn, 301 | } 302 | train_loader = DataLoader(train_dset, **loader_kwargs) 303 | 304 | loader_kwargs['shuffle'] = args.shuffle_val 305 | val_loader = DataLoader(val_dset, **loader_kwargs) 306 | return vocab, train_loader, val_loader 307 | 308 | 309 | def check_model(args, t, loader, model): 310 | float_dtype = torch.cuda.FloatTensor 311 | long_dtype = torch.cuda.LongTensor 312 | num_samples = 0 313 | all_losses = defaultdict(list) 314 | total_iou = 0 315 | total_boxes = 0 316 | with torch.no_grad(): 317 | for batch in loader: 318 | batch = [tensor.cuda() for tensor in batch] 319 | masks = None 320 | if len(batch) == 6: 321 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch 322 | elif len(batch) == 7: 323 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch 324 | predicates = triples[:, 1] 325 | 326 | # Run the model as it has been run during training 327 | model_masks = masks 328 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=model_masks) 329 | imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out 330 | 331 | skip_pixel_loss = False 332 | total_loss, losses = calculate_model_losses( 333 | args, skip_pixel_loss, model, imgs, imgs_pred, 334 | boxes, boxes_pred, masks, masks_pred, 335 | predicates, predicate_scores) 336 | 337 | total_iou += jaccard(boxes_pred, boxes) 338 | total_boxes += boxes_pred.size(0) 339 | 340 | for loss_name, loss_val in losses.items(): 341 | all_losses[loss_name].append(loss_val) 342 | num_samples += imgs.size(0) 343 | if num_samples >= args.num_val_samples: 344 | break 345 | 346 | samples = {} 347 | samples['gt_img'] = imgs 348 | 349 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks) 350 | samples['gt_box_gt_mask'] = model_out[0] 351 | 352 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes) 353 | samples['gt_box_pred_mask'] = model_out[0] 354 | 355 | model_out = model(objs, triples, obj_to_img) 356 | samples['pred_box_pred_mask'] = model_out[0] 357 | 358 | for k, v in samples.items(): 359 | samples[k] = imagenet_deprocess_batch(v) 360 | 361 | mean_losses = {k: np.mean(v) for k, v in all_losses.items()} 362 | avg_iou = total_iou / total_boxes 363 | 364 | masks_to_store = masks 
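  # The tensors gathered into batch_data below are detached and copied to the
  # CPU so the last validation batch can be kept around (e.g. inside the saved
  # training checkpoint) without holding on to GPU memory.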
365 | if masks_to_store is not None: 366 | masks_to_store = masks_to_store.data.cpu().clone() 367 | 368 | masks_pred_to_store = masks_pred 369 | if masks_pred_to_store is not None: 370 | masks_pred_to_store = masks_pred_to_store.data.cpu().clone() 371 | 372 | batch_data = { 373 | 'objs': objs.detach().cpu().clone(), 374 | 'boxes_gt': boxes.detach().cpu().clone(), 375 | 'masks_gt': masks_to_store, 376 | 'triples': triples.detach().cpu().clone(), 377 | 'obj_to_img': obj_to_img.detach().cpu().clone(), 378 | 'triple_to_img': triple_to_img.detach().cpu().clone(), 379 | 'boxes_pred': boxes_pred.detach().cpu().clone(), 380 | 'masks_pred': masks_pred_to_store 381 | } 382 | out = [mean_losses, samples, batch_data, avg_iou] 383 | 384 | return tuple(out) 385 | 386 | 387 | def calculate_model_losses(args, skip_pixel_loss, model, img, img_pred, 388 | bbox, bbox_pred, masks, masks_pred, 389 | predicates, predicate_scores): 390 | total_loss = torch.zeros(1).to(img) 391 | losses = {} 392 | 393 | l1_pixel_weight = args.l1_pixel_loss_weight 394 | if skip_pixel_loss: 395 | l1_pixel_weight = 0 396 | l1_pixel_loss = F.l1_loss(img_pred, img) 397 | total_loss = add_loss(total_loss, l1_pixel_loss, losses, 'L1_pixel_loss', 398 | l1_pixel_weight) 399 | loss_bbox = F.mse_loss(bbox_pred, bbox) 400 | total_loss = add_loss(total_loss, loss_bbox, losses, 'bbox_pred', 401 | args.bbox_pred_loss_weight) 402 | 403 | if args.predicate_pred_loss_weight > 0: 404 | loss_predicate = F.cross_entropy(predicate_scores, predicates) 405 | total_loss = add_loss(total_loss, loss_predicate, losses, 'predicate_pred', 406 | args.predicate_pred_loss_weight) 407 | 408 | if args.mask_loss_weight > 0 and masks is not None and masks_pred is not None: 409 | mask_loss = F.binary_cross_entropy(masks_pred, masks.float()) 410 | total_loss = add_loss(total_loss, mask_loss, losses, 'mask_loss', 411 | args.mask_loss_weight) 412 | return total_loss, losses 413 | 414 | 415 | def main(args): 416 | print(args) 417 | check_args(args) 418 | float_dtype = torch.cuda.FloatTensor 419 | long_dtype = torch.cuda.LongTensor 420 | 421 | vocab, train_loader, val_loader = build_loaders(args) 422 | model, model_kwargs = build_model(args, vocab) 423 | model.type(float_dtype) 424 | print(model) 425 | 426 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 427 | 428 | obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab) 429 | img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab) 430 | gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type) 431 | 432 | if obj_discriminator is not None: 433 | obj_discriminator.type(float_dtype) 434 | obj_discriminator.train() 435 | print(obj_discriminator) 436 | optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(), 437 | lr=args.learning_rate) 438 | 439 | if img_discriminator is not None: 440 | img_discriminator.type(float_dtype) 441 | img_discriminator.train() 442 | print(img_discriminator) 443 | optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(), 444 | lr=args.learning_rate) 445 | 446 | restore_path = None 447 | if args.restore_from_checkpoint: 448 | restore_path = '%s_with_model.pt' % args.checkpoint_name 449 | restore_path = os.path.join(args.output_dir, restore_path) 450 | if restore_path is not None and os.path.isfile(restore_path): 451 | print('Restoring from checkpoint:') 452 | print(restore_path) 453 | checkpoint = torch.load(restore_path) 454 | model.load_state_dict(checkpoint['model_state']) 455 | 
optimizer.load_state_dict(checkpoint['optim_state']) 456 | 457 | if obj_discriminator is not None: 458 | obj_discriminator.load_state_dict(checkpoint['d_obj_state']) 459 | optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state']) 460 | 461 | if img_discriminator is not None: 462 | img_discriminator.load_state_dict(checkpoint['d_img_state']) 463 | optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state']) 464 | 465 | t = checkpoint['counters']['t'] 466 | if 0 <= args.eval_mode_after <= t: 467 | model.eval() 468 | else: 469 | model.train() 470 | epoch = checkpoint['counters']['epoch'] 471 | else: 472 | t, epoch = 0, 0 473 | checkpoint = { 474 | 'args': args.__dict__, 475 | 'vocab': vocab, 476 | 'model_kwargs': model_kwargs, 477 | 'd_obj_kwargs': d_obj_kwargs, 478 | 'd_img_kwargs': d_img_kwargs, 479 | 'losses_ts': [], 480 | 'losses': defaultdict(list), 481 | 'd_losses': defaultdict(list), 482 | 'checkpoint_ts': [], 483 | 'train_batch_data': [], 484 | 'train_samples': [], 485 | 'train_iou': [], 486 | 'val_batch_data': [], 487 | 'val_samples': [], 488 | 'val_losses': defaultdict(list), 489 | 'val_iou': [], 490 | 'norm_d': [], 491 | 'norm_g': [], 492 | 'counters': { 493 | 't': None, 494 | 'epoch': None, 495 | }, 496 | 'model_state': None, 'model_best_state': None, 'optim_state': None, 497 | 'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None, 498 | 'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None, 499 | 'best_t': [], 500 | } 501 | 502 | while True: 503 | if t >= args.num_iterations: 504 | break 505 | epoch += 1 506 | print('Starting epoch %d' % epoch) 507 | 508 | for batch in train_loader: 509 | if t == args.eval_mode_after: 510 | print('switching to eval mode') 511 | model.eval() 512 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 513 | t += 1 514 | batch = [tensor.cuda() for tensor in batch] 515 | masks = None 516 | if len(batch) == 6: 517 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch 518 | elif len(batch) == 7: 519 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch 520 | else: 521 | assert False 522 | predicates = triples[:, 1] 523 | 524 | with timeit('forward', args.timing): 525 | model_boxes = boxes 526 | model_masks = masks 527 | model_out = model(objs, triples, obj_to_img, 528 | boxes_gt=model_boxes, masks_gt=model_masks) 529 | imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out 530 | with timeit('loss', args.timing): 531 | # Skip the pixel loss if using GT boxes 532 | skip_pixel_loss = (model_boxes is None) 533 | total_loss, losses = calculate_model_losses( 534 | args, skip_pixel_loss, model, imgs, imgs_pred, 535 | boxes, boxes_pred, masks, masks_pred, 536 | predicates, predicate_scores) 537 | 538 | if obj_discriminator is not None: 539 | scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes, obj_to_img) 540 | total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss', 541 | args.ac_loss_weight) 542 | weight = args.discriminator_loss_weight * args.d_obj_weight 543 | total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses, 544 | 'g_gan_obj_loss', weight) 545 | 546 | if img_discriminator is not None: 547 | scores_fake = img_discriminator(imgs_pred) 548 | weight = args.discriminator_loss_weight * args.d_img_weight 549 | total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses, 550 | 'g_gan_img_loss', weight) 551 | 552 | losses['total_loss'] = total_loss.item() 553 | if not math.isfinite(losses['total_loss']): 554 | print('WARNING: Got 
loss = NaN, not backpropping') 555 | continue 556 | 557 | optimizer.zero_grad() 558 | with timeit('backward', args.timing): 559 | total_loss.backward() 560 | optimizer.step() 561 | total_loss_d = None 562 | ac_loss_real = None 563 | ac_loss_fake = None 564 | d_losses = {} 565 | 566 | if obj_discriminator is not None: 567 | d_obj_losses = LossManager() 568 | imgs_fake = imgs_pred.detach() 569 | scores_fake, ac_loss_fake = obj_discriminator(imgs_fake, objs, boxes, obj_to_img) 570 | scores_real, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img) 571 | 572 | d_obj_gan_loss = gan_d_loss(scores_real, scores_fake) 573 | d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss') 574 | d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real') 575 | d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake') 576 | 577 | optimizer_d_obj.zero_grad() 578 | d_obj_losses.total_loss.backward() 579 | optimizer_d_obj.step() 580 | 581 | if img_discriminator is not None: 582 | d_img_losses = LossManager() 583 | imgs_fake = imgs_pred.detach() 584 | scores_fake = img_discriminator(imgs_fake) 585 | scores_real = img_discriminator(imgs) 586 | 587 | d_img_gan_loss = gan_d_loss(scores_real, scores_fake) 588 | d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss') 589 | 590 | optimizer_d_img.zero_grad() 591 | d_img_losses.total_loss.backward() 592 | optimizer_d_img.step() 593 | 594 | if t % args.print_every == 0: 595 | print('t = %d / %d' % (t, args.num_iterations)) 596 | for name, val in losses.items(): 597 | print(' G [%s]: %.4f' % (name, val)) 598 | checkpoint['losses'][name].append(val) 599 | checkpoint['losses_ts'].append(t) 600 | 601 | if obj_discriminator is not None: 602 | for name, val in d_obj_losses.items(): 603 | print(' D_obj [%s]: %.4f' % (name, val)) 604 | checkpoint['d_losses'][name].append(val) 605 | 606 | if img_discriminator is not None: 607 | for name, val in d_img_losses.items(): 608 | print(' D_img [%s]: %.4f' % (name, val)) 609 | checkpoint['d_losses'][name].append(val) 610 | 611 | if t % args.checkpoint_every == 0: 612 | print('checking on train') 613 | train_results = check_model(args, t, train_loader, model) 614 | t_losses, t_samples, t_batch_data, t_avg_iou = train_results 615 | 616 | checkpoint['train_batch_data'].append(t_batch_data) 617 | checkpoint['train_samples'].append(t_samples) 618 | checkpoint['checkpoint_ts'].append(t) 619 | checkpoint['train_iou'].append(t_avg_iou) 620 | 621 | print('checking on val') 622 | val_results = check_model(args, t, val_loader, model) 623 | val_losses, val_samples, val_batch_data, val_avg_iou = val_results 624 | checkpoint['val_samples'].append(val_samples) 625 | checkpoint['val_batch_data'].append(val_batch_data) 626 | checkpoint['val_iou'].append(val_avg_iou) 627 | 628 | print('train iou: ', t_avg_iou) 629 | print('val iou: ', val_avg_iou) 630 | 631 | for k, v in val_losses.items(): 632 | checkpoint['val_losses'][k].append(v) 633 | checkpoint['model_state'] = model.state_dict() 634 | 635 | if obj_discriminator is not None: 636 | checkpoint['d_obj_state'] = obj_discriminator.state_dict() 637 | checkpoint['d_obj_optim_state'] = optimizer_d_obj.state_dict() 638 | 639 | if img_discriminator is not None: 640 | checkpoint['d_img_state'] = img_discriminator.state_dict() 641 | checkpoint['d_img_optim_state'] = optimizer_d_img.state_dict() 642 | 643 | checkpoint['optim_state'] = optimizer.state_dict() 644 | checkpoint['counters']['t'] = t 645 | checkpoint['counters']['epoch'] = epoch 646 | checkpoint_path = os.path.join(args.output_dir, 647 | 
'%s_with_model.pt' % args.checkpoint_name) 648 | print('Saving checkpoint to ', checkpoint_path) 649 | torch.save(checkpoint, checkpoint_path) 650 | 651 | # Save another checkpoint without any model or optim state 652 | checkpoint_path = os.path.join(args.output_dir, 653 | '%s_no_model.pt' % args.checkpoint_name) 654 | key_blacklist = ['model_state', 'optim_state', 'model_best_state', 655 | 'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state', 656 | 'd_img_state', 'd_img_optim_state', 'd_img_best_state'] 657 | small_checkpoint = {} 658 | for k, v in checkpoint.items(): 659 | if k not in key_blacklist: 660 | small_checkpoint[k] = v 661 | torch.save(small_checkpoint, checkpoint_path) 662 | 663 | 664 | if __name__ == '__main__': 665 | args = parser.parse_args() 666 | main(args) 667 | 668 | -------------------------------------------------------------------------------- /sg2im/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | -------------------------------------------------------------------------------- /sg2im/bilinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from sg2im.utils import timeit 20 | 21 | 22 | """ 23 | Functions for performing differentiable bilinear cropping of images, for use in 24 | the object discriminator 25 | """ 26 | 27 | 28 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'): 29 | """ 30 | Inputs: 31 | - feats: FloatTensor of shape (N, C, H, W) 32 | - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates 33 | - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps; 34 | each element is in the range [0, N) and bbox_to_feats[b] = i means that 35 | bbox[b] will be cropped from feats[i]. 36 | - HH, WW: Size of the output crops 37 | 38 | Returns: 39 | - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to 40 | crop from feats[bbox_to_feats[i]]. 
41 | """ 42 | if backend == 'cudnn': 43 | return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW) 44 | N, C, H, W = feats.size() 45 | B = bbox.size(0) 46 | if WW is None: WW = HH 47 | dtype, device = feats.dtype, feats.device 48 | crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device) 49 | for i in range(N): 50 | idx = (bbox_to_feats.data == i).nonzero() 51 | if idx.dim() == 0: 52 | continue 53 | idx = idx.view(-1) 54 | n = idx.size(0) 55 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 56 | cur_bbox = bbox[idx] 57 | cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW) 58 | crops[idx] = cur_crops 59 | return crops 60 | 61 | 62 | def _invperm(p): 63 | N = p.size(0) 64 | eye = torch.arange(0, N).type_as(p) 65 | pp = (eye[:, None] == p).nonzero()[:, 1] 66 | return pp 67 | 68 | 69 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None): 70 | N, C, H, W = feats.size() 71 | B = bbox.size(0) 72 | if WW is None: WW = HH 73 | dtype = feats.data.type() 74 | 75 | feats_flat, bbox_flat, all_idx = [], [], [] 76 | for i in range(N): 77 | idx = (bbox_to_feats.data == i).nonzero() 78 | if idx.dim() == 0: 79 | continue 80 | idx = idx.view(-1) 81 | n = idx.size(0) 82 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 83 | cur_bbox = bbox[idx] 84 | 85 | feats_flat.append(cur_feats) 86 | bbox_flat.append(cur_bbox) 87 | all_idx.append(idx) 88 | 89 | feats_flat = torch.cat(feats_flat, dim=0) 90 | bbox_flat = torch.cat(bbox_flat, dim=0) 91 | crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn') 92 | 93 | # If the crops were sequential (all_idx is identity permutation) then we can 94 | # simply return them; otherwise we need to permute crops by the inverse 95 | # permutation from all_idx. 96 | all_idx = torch.cat(all_idx, dim=0) 97 | eye = torch.arange(0, B).type_as(all_idx) 98 | if (all_idx == eye).all(): 99 | return crops 100 | return crops[_invperm(all_idx)] 101 | 102 | 103 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'): 104 | """ 105 | Take differentiable crops of feats specified by bbox. 106 | 107 | Inputs: 108 | - feats: Tensor of shape (N, C, H, W) 109 | - bbox: Bounding box coordinates of shape (N, 4) in the format 110 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 111 | - HH, WW: Size of the output crops. 112 | 113 | Returns: 114 | - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of 115 | feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling. 116 | """ 117 | N = feats.size(0) 118 | assert bbox.size(0) == N 119 | assert bbox.size(1) == 4 120 | if WW is None: WW = HH 121 | if backend == 'cudnn': 122 | # Change box from [0, 1] to [-1, 1] coordinate system 123 | bbox = 2 * bbox - 1 124 | x0, y0 = bbox[:, 0], bbox[:, 1] 125 | x1, y1 = bbox[:, 2], bbox[:, 3] 126 | X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW) 127 | Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW) 128 | if backend == 'jj': 129 | return bilinear_sample(feats, X, Y) 130 | elif backend == 'cudnn': 131 | grid = torch.stack([X, Y], dim=3) 132 | return F.grid_sample(feats, grid) 133 | 134 | 135 | 136 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0): 137 | """ 138 | Inverse operation to crop_bbox; construct output images where the feature maps 139 | from feats have been reshaped and placed into the positions specified by bbox. 
140 | 141 | Inputs: 142 | - feats: Tensor of shape (N, C, HH, WW) 143 | - bbox: Bounding box coordinates of shape (N, 4) in the format 144 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 145 | - H, W: Size of output. 146 | - fill_value: Portions of the output image that are outside the bounding box 147 | will be filled with this value. 148 | 149 | Returns: 150 | - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by 151 | bbox[i] contains feats[i], reshaped using bilinear sampling. 152 | """ 153 | N, C = feats.size(0), feats.size(1) 154 | assert bbox.size(0) == N 155 | assert bbox.size(1) == 4 156 | if W is None: H = W 157 | 158 | x0, y0 = bbox[:, 0], bbox[:, 1] 159 | x1, y1 = bbox[:, 2], bbox[:, 3] 160 | ww = x1 - x0 161 | hh = y1 - y0 162 | 163 | x0 = x0.contiguous().view(N, 1).expand(N, H) 164 | x1 = x1.contiguous().view(N, 1).expand(N, H) 165 | ww = ww.view(N, 1).expand(N, H) 166 | 167 | y0 = y0.contiguous().view(N, 1).expand(N, W) 168 | y1 = y1.contiguous().view(N, 1).expand(N, W) 169 | hh = hh.view(N, 1).expand(N, W) 170 | 171 | X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats) 172 | Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats) 173 | 174 | X = (X - x0) / ww 175 | Y = (Y - y0) / hh 176 | 177 | # For ByteTensors, (x + y).clamp(max=1) gives logical_or 178 | X_out_mask = ((X < 0) + (X > 1)).view(N, 1, W).expand(N, H, W) 179 | Y_out_mask = ((Y < 0) + (Y > 1)).view(N, H, 1).expand(N, H, W) 180 | out_mask = (X_out_mask + Y_out_mask).clamp(max=1) 181 | out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W) 182 | 183 | X = X.view(N, 1, W).expand(N, H, W) 184 | Y = Y.view(N, H, 1).expand(N, H, W) 185 | 186 | out = bilinear_sample(feats, X, Y) 187 | out[out_mask] = fill_value 188 | return out 189 | 190 | 191 | def bilinear_sample(feats, X, Y): 192 | """ 193 | Perform bilinear sampling on the features in feats using the sampling grid 194 | given by X and Y. 195 | 196 | Inputs: 197 | - feats: Tensor holding input feature map, of shape (N, C, H, W) 198 | - X, Y: Tensors holding x and y coordinates of the sampling 199 | grids; both have shape shape (N, HH, WW) and have elements in the range [0, 1]. 200 | Returns: 201 | - out: Tensor of shape (B, C, HH, WW) where out[i] is computed 202 | by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]). 203 | """ 204 | N, C, H, W = feats.size() 205 | assert X.size() == Y.size() 206 | assert X.size(0) == N 207 | _, HH, WW = X.size() 208 | 209 | X = X.mul(W) 210 | Y = Y.mul(H) 211 | 212 | # Get the x and y coordinates for the four samples 213 | x0 = X.floor().clamp(min=0, max=W-1) 214 | x1 = (x0 + 1).clamp(min=0, max=W-1) 215 | y0 = Y.floor().clamp(min=0, max=H-1) 216 | y1 = (y0 + 1).clamp(min=0, max=H-1) 217 | 218 | # In numpy we could do something like feats[i, :, y0, x0] to pull out 219 | # the elements of feats at coordinates y0 and x0, but PyTorch doesn't 220 | # yet support this style of indexing. Instead we have to use the gather 221 | # method, which only allows us to index along one dimension at a time; 222 | # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W) 223 | # and index along the last dimension. Below we generate linear indices into 224 | # the collapsed last dimension for each of the four combinations we need. 
225 | y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 226 | y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 227 | y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 228 | y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 229 | 230 | # Actually use gather to pull out the values from feats corresponding 231 | # to our four samples, then reshape them to (BB, C, HH, WW) 232 | feats_flat = feats.view(N, C, H * W) 233 | v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW) 234 | v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW) 235 | v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW) 236 | v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW) 237 | 238 | # Compute the weights for the four samples 239 | w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 240 | w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 241 | w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 242 | w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 243 | 244 | # Multiply the samples by the weights to give our interpolated results. 245 | out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 246 | return out 247 | 248 | 249 | def tensor_linspace(start, end, steps=10): 250 | """ 251 | Vectorized version of torch.linspace. 252 | 253 | Inputs: 254 | - start: Tensor of any shape 255 | - end: Tensor of the same shape as start 256 | - steps: Integer 257 | 258 | Returns: 259 | - out: Tensor of shape start.size() + (steps,), such that 260 | out.select(-1, 0) == start, out.select(-1, -1) == end, 261 | and the other elements of out linearly interpolate between 262 | start and end. 263 | """ 264 | assert start.size() == end.size() 265 | view_size = start.size() + (1,) 266 | w_size = (1,) * start.dim() + (steps,) 267 | out_size = start.size() + (steps,) 268 | 269 | start_w = torch.linspace(1, 0, steps=steps).to(start) 270 | start_w = start_w.view(w_size).expand(out_size) 271 | end_w = torch.linspace(0, 1, steps=steps).to(start) 272 | end_w = end_w.view(w_size).expand(out_size) 273 | 274 | start = start.contiguous().view(view_size).expand(out_size) 275 | end = end.contiguous().view(view_size).expand(out_size) 276 | 277 | out = start_w * start + end_w * end 278 | return out 279 | 280 | 281 | if __name__ == '__main__': 282 | import numpy as np 283 | from scipy.misc import imread, imsave, imresize 284 | 285 | cat = imresize(imread('cat.jpg'), (256, 256)) 286 | dog = imresize(imread('dog.jpg'), (256, 256)) 287 | feats = torch.stack([ 288 | torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)), 289 | torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))], 290 | dim=0) 291 | 292 | boxes = torch.FloatTensor([ 293 | [0, 0, 1, 1], 294 | [0.25, 0.25, 0.75, 0.75], 295 | [0, 0, 0.5, 0.5], 296 | ]) 297 | 298 | box_to_feats = torch.LongTensor([1, 0, 1]).cuda() 299 | 300 | feats, boxes = feats.cuda(), boxes.cuda() 301 | crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128) 302 | for i in range(crops.size(0)): 303 | crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8) 304 | imsave('out%d.png' % i, crop_np) 305 | -------------------------------------------------------------------------------- /sg2im/box_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this 
file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | 19 | """ 20 | Utilities for dealing with bounding boxes 21 | """ 22 | 23 | 24 | def apply_box_transform(anchors, transforms): 25 | """ 26 | Apply box transforms to a set of anchor boxes. 27 | 28 | Inputs: 29 | - anchors: Anchor boxes of shape (N, 4), where each anchor is specified 30 | in the form [xc, yc, w, h] 31 | - transforms: Box transforms of shape (N, 4) where each transform is 32 | specified as [tx, ty, tw, th] 33 | 34 | Returns: 35 | - boxes: Transformed boxes of shape (N, 4) where each box is in the 36 | format [xc, yc, w, h] 37 | """ 38 | # Unpack anchors 39 | xa, ya = anchors[:, 0], anchors[:, 1] 40 | wa, ha = anchors[:, 2], anchors[:, 3] 41 | 42 | # Unpack transforms 43 | tx, ty = transforms[:, 0], transforms[:, 1] 44 | tw, th = transforms[:, 2], transforms[:, 3] 45 | 46 | x = xa + tx * wa 47 | y = ya + ty * ha 48 | w = wa * tw.exp() 49 | h = ha * th.exp() 50 | 51 | boxes = torch.stack([x, y, w, h], dim=1) 52 | return boxes 53 | 54 | 55 | def invert_box_transform(anchors, boxes): 56 | """ 57 | Compute the box transform that, when applied to anchors, would give boxes. 58 | 59 | Inputs: 60 | - anchors: Box anchors of shape (N, 4) in the format [xc, yc, w, h] 61 | - boxes: Target boxes of shape (N, 4) in the format [xc, yc, w, h] 62 | 63 | Returns: 64 | - transforms: Box transforms of shape (N, 4) in the format [tx, ty, tw, th] 65 | """ 66 | # Unpack anchors 67 | xa, ya = anchors[:, 0], anchors[:, 1] 68 | wa, ha = anchors[:, 2], anchors[:, 3] 69 | 70 | # Unpack boxes 71 | x, y = boxes[:, 0], boxes[:, 1] 72 | w, h = boxes[:, 2], boxes[:, 3] 73 | 74 | tx = (x - xa) / wa 75 | ty = (y - ya) / ha 76 | tw = w.log() - wa.log() 77 | th = h.log() - ha.log() 78 | 79 | transforms = torch.stack([tx, ty, tw, th], dim=1) 80 | return transforms 81 | 82 | 83 | def centers_to_extents(boxes): 84 | """ 85 | Convert boxes from [xc, yc, w, h] format to [x0, y0, x1, y1] format 86 | 87 | Input: 88 | - boxes: Input boxes of shape (N, 4) in [xc, yc, w, h] format 89 | 90 | Returns: 91 | - boxes: Output boxes of shape (N, 4) in [x0, y0, x1, y1] format 92 | """ 93 | xc, yc = boxes[:, 0], boxes[:, 1] 94 | w, h = boxes[:, 2], boxes[:, 3] 95 | 96 | x0 = xc - w / 2 97 | x1 = x0 + w 98 | y0 = yc - h / 2 99 | y1 = y0 + h 100 | 101 | boxes_out = torch.stack([x0, y0, x1, y1], dim=1) 102 | return boxes_out 103 | 104 | 105 | def extents_to_centers(boxes): 106 | """ 107 | Convert boxes from [x0, y0, x1, y1] format to [xc, yc, w, h] format 108 | 109 | Input: 110 | - boxes: Input boxes of shape (N, 4) in [x0, y0, x1, y1] format 111 | 112 | Returns: 113 | - boxes: Output boxes of shape (N, 4) in [xc, yc, w, h] format 114 | """ 115 | x0, y0 = boxes[:, 0], boxes[:, 1] 116 | x1, y1 = boxes[:, 2], boxes[:, 3] 117 | 118 | xc = 0.5 * (x0 + x1) 119 | yc = 0.5 * (y0 + y1) 120 | w = x1 - x0 121 | h = y1 - y0 122 | 123 | boxes_out = torch.stack([xc, yc, w, h], dim=1) 124 | return boxes_out 125 | 126 | -------------------------------------------------------------------------------- /sg2im/crn.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from sg2im.layers import get_normalization_2d 22 | from sg2im.layers import get_activation 23 | from sg2im.utils import timeit, lineno, get_gpu_memory 24 | 25 | 26 | """ 27 | Cascaded refinement network architecture, as described in: 28 | 29 | Qifeng Chen and Vladlen Koltun, 30 | "Photographic Image Synthesis with Cascaded Refinement Networks", 31 | ICCV 2017 32 | """ 33 | 34 | 35 | class RefinementModule(nn.Module): 36 | def __init__(self, layout_dim, input_dim, output_dim, 37 | normalization='instance', activation='leakyrelu'): 38 | super(RefinementModule, self).__init__() 39 | 40 | layers = [] 41 | layers.append(nn.Conv2d(layout_dim + input_dim, output_dim, 42 | kernel_size=3, padding=1)) 43 | layers.append(get_normalization_2d(output_dim, normalization)) 44 | layers.append(get_activation(activation)) 45 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1)) 46 | layers.append(get_normalization_2d(output_dim, normalization)) 47 | layers.append(get_activation(activation)) 48 | layers = [layer for layer in layers if layer is not None] 49 | for layer in layers: 50 | if isinstance(layer, nn.Conv2d): 51 | nn.init.kaiming_normal_(layer.weight) 52 | self.net = nn.Sequential(*layers) 53 | 54 | def forward(self, layout, feats): 55 | _, _, HH, WW = layout.size() 56 | _, _, H, W = feats.size() 57 | assert HH >= H 58 | if HH > H: 59 | factor = round(HH // H) 60 | assert HH % factor == 0 61 | assert WW % factor == 0 and WW // factor == W 62 | layout = F.avg_pool2d(layout, kernel_size=factor, stride=factor) 63 | net_input = torch.cat([layout, feats], dim=1) 64 | out = self.net(net_input) 65 | return out 66 | 67 | 68 | class RefinementNetwork(nn.Module): 69 | def __init__(self, dims, normalization='instance', activation='leakyrelu'): 70 | super(RefinementNetwork, self).__init__() 71 | layout_dim = dims[0] 72 | self.refinement_modules = nn.ModuleList() 73 | for i in range(1, len(dims)): 74 | input_dim = 1 if i == 1 else dims[i - 1] 75 | output_dim = dims[i] 76 | mod = RefinementModule(layout_dim, input_dim, output_dim, 77 | normalization=normalization, activation=activation) 78 | self.refinement_modules.append(mod) 79 | output_conv_layers = [ 80 | nn.Conv2d(dims[-1], dims[-1], kernel_size=3, padding=1), 81 | get_activation(activation), 82 | nn.Conv2d(dims[-1], 3, kernel_size=1, padding=0) 83 | ] 84 | nn.init.kaiming_normal_(output_conv_layers[0].weight) 85 | nn.init.kaiming_normal_(output_conv_layers[2].weight) 86 | self.output_conv = nn.Sequential(*output_conv_layers) 87 | 88 | def forward(self, layout): 89 | """ 90 | Output will have same size as layout 91 | """ 92 | # H, W = self.output_size 93 | N, _, H, W = layout.size() 94 | self.layout = layout 95 | 96 | # 
Figure out size of input 97 | input_H, input_W = H, W 98 | for _ in range(len(self.refinement_modules)): 99 | input_H //= 2 100 | input_W //= 2 101 | 102 | assert input_H != 0 103 | assert input_W != 0 104 | 105 | feats = torch.zeros(N, 1, input_H, input_W).to(layout) 106 | for mod in self.refinement_modules: 107 | feats = F.upsample(feats, scale_factor=2, mode='nearest') 108 | feats = mod(layout, feats) 109 | 110 | out = self.output_conv(feats) 111 | return out 112 | 113 | -------------------------------------------------------------------------------- /sg2im/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .utils import imagenet_preprocess, imagenet_deprocess 18 | from .utils import imagenet_deprocess_batch 19 | -------------------------------------------------------------------------------- /sg2im/data/coco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import json, os, random, math 18 | from collections import defaultdict 19 | 20 | import torch 21 | from torch.utils.data import Dataset 22 | import torchvision.transforms as T 23 | 24 | import numpy as np 25 | import PIL 26 | from skimage.transform import resize as imresize 27 | import pycocotools.mask as mask_utils 28 | 29 | from .utils import imagenet_preprocess, Resize 30 | 31 | 32 | class CocoSceneGraphDataset(Dataset): 33 | def __init__(self, image_dir, instances_json, stuff_json=None, 34 | stuff_only=True, image_size=(64, 64), mask_size=16, 35 | normalize_images=True, max_samples=None, 36 | include_relationships=True, min_object_size=0.02, 37 | min_objects_per_image=3, max_objects_per_image=8, 38 | include_other=False, instance_whitelist=None, stuff_whitelist=None): 39 | """ 40 | A PyTorch Dataset for loading Coco and Coco-Stuff annotations and converting 41 | them to scene graphs on the fly. 
42 | 43 | Inputs: 44 | - image_dir: Path to a directory where images are held 45 | - instances_json: Path to a JSON file giving COCO annotations 46 | - stuff_json: (optional) Path to a JSON file giving COCO-Stuff annotations 47 | - stuff_only: (optional, default True) If True then only iterate over 48 | images which appear in stuff_json; if False then iterate over all images 49 | in instances_json. 50 | - image_size: Size (H, W) at which to load images. Default (64, 64). 51 | - mask_size: Size M for object segmentation masks; default 16. 52 | - normalize_image: If True then normalize images by subtracting ImageNet 53 | mean pixel and dividing by ImageNet std pixel. 54 | - max_samples: If None use all images. Other wise only use images in the 55 | range [0, max_samples). Default None. 56 | - include_relationships: If True then include spatial relationships; if 57 | False then only include the trivial __in_image__ relationship. 58 | - min_object_size: Ignore objects whose bounding box takes up less than 59 | this fraction of the image. 60 | - min_objects_per_image: Ignore images which have fewer than this many 61 | object annotations. 62 | - max_objects_per_image: Ignore images which have more than this many 63 | object annotations. 64 | - include_other: If True, include COCO-Stuff annotations which have category 65 | "other". Default is False, because I found that these were really noisy 66 | and pretty much impossible for the system to model. 67 | - instance_whitelist: None means use all instance categories. Otherwise a 68 | list giving a whitelist of instance category names to use. 69 | - stuff_whitelist: None means use all stuff categories. Otherwise a list 70 | giving a whitelist of stuff category names to use. 71 | """ 72 | super(Dataset, self).__init__() 73 | 74 | if stuff_only and stuff_json is None: 75 | print('WARNING: Got stuff_only=True but stuff_json=None.') 76 | print('Falling back to stuff_only=False.') 77 | 78 | self.image_dir = image_dir 79 | self.mask_size = mask_size 80 | self.max_samples = max_samples 81 | self.normalize_images = normalize_images 82 | self.include_relationships = include_relationships 83 | self.set_image_size(image_size) 84 | 85 | with open(instances_json, 'r') as f: 86 | instances_data = json.load(f) 87 | 88 | stuff_data = None 89 | if stuff_json is not None and stuff_json != '': 90 | with open(stuff_json, 'r') as f: 91 | stuff_data = json.load(f) 92 | 93 | self.image_ids = [] 94 | self.image_id_to_filename = {} 95 | self.image_id_to_size = {} 96 | for image_data in instances_data['images']: 97 | image_id = image_data['id'] 98 | filename = image_data['file_name'] 99 | width = image_data['width'] 100 | height = image_data['height'] 101 | self.image_ids.append(image_id) 102 | self.image_id_to_filename[image_id] = filename 103 | self.image_id_to_size[image_id] = (width, height) 104 | 105 | self.vocab = { 106 | 'object_name_to_idx': {}, 107 | 'pred_name_to_idx': {}, 108 | } 109 | object_idx_to_name = {} 110 | all_instance_categories = [] 111 | for category_data in instances_data['categories']: 112 | category_id = category_data['id'] 113 | category_name = category_data['name'] 114 | all_instance_categories.append(category_name) 115 | object_idx_to_name[category_id] = category_name 116 | self.vocab['object_name_to_idx'][category_name] = category_id 117 | all_stuff_categories = [] 118 | if stuff_data: 119 | for category_data in stuff_data['categories']: 120 | category_name = category_data['name'] 121 | category_id = category_data['id'] 122 | 
all_stuff_categories.append(category_name) 123 | object_idx_to_name[category_id] = category_name 124 | self.vocab['object_name_to_idx'][category_name] = category_id 125 | 126 | if instance_whitelist is None: 127 | instance_whitelist = all_instance_categories 128 | if stuff_whitelist is None: 129 | stuff_whitelist = all_stuff_categories 130 | category_whitelist = set(instance_whitelist) | set(stuff_whitelist) 131 | 132 | # Add object data from instances 133 | self.image_id_to_objects = defaultdict(list) 134 | for object_data in instances_data['annotations']: 135 | image_id = object_data['image_id'] 136 | _, _, w, h = object_data['bbox'] 137 | W, H = self.image_id_to_size[image_id] 138 | box_area = (w * h) / (W * H) 139 | box_ok = box_area > min_object_size 140 | object_name = object_idx_to_name[object_data['category_id']] 141 | category_ok = object_name in category_whitelist 142 | other_ok = object_name != 'other' or include_other 143 | if box_ok and category_ok and other_ok: 144 | self.image_id_to_objects[image_id].append(object_data) 145 | 146 | # Add object data from stuff 147 | if stuff_data: 148 | image_ids_with_stuff = set() 149 | for object_data in stuff_data['annotations']: 150 | image_id = object_data['image_id'] 151 | image_ids_with_stuff.add(image_id) 152 | _, _, w, h = object_data['bbox'] 153 | W, H = self.image_id_to_size[image_id] 154 | box_area = (w * h) / (W * H) 155 | box_ok = box_area > min_object_size 156 | object_name = object_idx_to_name[object_data['category_id']] 157 | category_ok = object_name in category_whitelist 158 | other_ok = object_name != 'other' or include_other 159 | if box_ok and category_ok and other_ok: 160 | self.image_id_to_objects[image_id].append(object_data) 161 | if stuff_only: 162 | new_image_ids = [] 163 | for image_id in self.image_ids: 164 | if image_id in image_ids_with_stuff: 165 | new_image_ids.append(image_id) 166 | self.image_ids = new_image_ids 167 | 168 | all_image_ids = set(self.image_id_to_filename.keys()) 169 | image_ids_to_remove = all_image_ids - image_ids_with_stuff 170 | for image_id in image_ids_to_remove: 171 | self.image_id_to_filename.pop(image_id, None) 172 | self.image_id_to_size.pop(image_id, None) 173 | self.image_id_to_objects.pop(image_id, None) 174 | 175 | # COCO category labels start at 1, so use 0 for __image__ 176 | self.vocab['object_name_to_idx']['__image__'] = 0 177 | 178 | # Build object_idx_to_name 179 | name_to_idx = self.vocab['object_name_to_idx'] 180 | assert len(name_to_idx) == len(set(name_to_idx.values())) 181 | max_object_idx = max(name_to_idx.values()) 182 | idx_to_name = ['NONE'] * (1 + max_object_idx) 183 | for name, idx in self.vocab['object_name_to_idx'].items(): 184 | idx_to_name[idx] = name 185 | self.vocab['object_idx_to_name'] = idx_to_name 186 | 187 | # Prune images that have too few or too many objects 188 | new_image_ids = [] 189 | total_objs = 0 190 | for image_id in self.image_ids: 191 | num_objs = len(self.image_id_to_objects[image_id]) 192 | total_objs += num_objs 193 | if min_objects_per_image <= num_objs <= max_objects_per_image: 194 | new_image_ids.append(image_id) 195 | self.image_ids = new_image_ids 196 | 197 | self.vocab['pred_idx_to_name'] = [ 198 | '__in_image__', 199 | 'left of', 200 | 'right of', 201 | 'above', 202 | 'below', 203 | 'inside', 204 | 'surrounding', 205 | ] 206 | self.vocab['pred_name_to_idx'] = {} 207 | for idx, name in enumerate(self.vocab['pred_idx_to_name']): 208 | self.vocab['pred_name_to_idx'][name] = idx 209 | 210 | def set_image_size(self, image_size): 211 | 
print('called set_image_size', image_size) 212 | transform = [Resize(image_size), T.ToTensor()] 213 | if self.normalize_images: 214 | transform.append(imagenet_preprocess()) 215 | self.transform = T.Compose(transform) 216 | self.image_size = image_size 217 | 218 | def total_objects(self): 219 | total_objs = 0 220 | for i, image_id in enumerate(self.image_ids): 221 | if self.max_samples and i >= self.max_samples: 222 | break 223 | num_objs = len(self.image_id_to_objects[image_id]) 224 | total_objs += num_objs 225 | return total_objs 226 | 227 | def __len__(self): 228 | if self.max_samples is None: 229 | return len(self.image_ids) 230 | return min(len(self.image_ids), self.max_samples) 231 | 232 | def __getitem__(self, index): 233 | """ 234 | Get the pixels of an image, and a random synthetic scene graph for that 235 | image constructed on-the-fly from its COCO object annotations. We assume 236 | that the image will have height H, width W, C channels; there will be O 237 | object annotations, each of which will have both a bounding box and a 238 | segmentation mask of shape (M, M). There will be T triples in the scene 239 | graph. 240 | 241 | Returns a tuple of: 242 | - image: FloatTensor of shape (C, H, W) 243 | - objs: LongTensor of shape (O,) 244 | - boxes: FloatTensor of shape (O, 4) giving boxes for objects in 245 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system 246 | - masks: LongTensor of shape (O, M, M) giving segmentation masks for 247 | objects, where 0 is background and 1 is object. 248 | - triples: LongTensor of shape (T, 3) where triples[t] = [i, p, j] 249 | means that (objs[i], p, objs[j]) is a triple. 250 | """ 251 | image_id = self.image_ids[index] 252 | 253 | filename = self.image_id_to_filename[image_id] 254 | image_path = os.path.join(self.image_dir, filename) 255 | with open(image_path, 'rb') as f: 256 | with PIL.Image.open(f) as image: 257 | WW, HH = image.size 258 | image = self.transform(image.convert('RGB')) 259 | 260 | H, W = self.image_size 261 | objs, boxes, masks = [], [], [] 262 | for object_data in self.image_id_to_objects[image_id]: 263 | objs.append(object_data['category_id']) 264 | x, y, w, h = object_data['bbox'] 265 | x0 = x / WW 266 | y0 = y / HH 267 | x1 = (x + w) / WW 268 | y1 = (y + h) / HH 269 | boxes.append(torch.FloatTensor([x0, y0, x1, y1])) 270 | 271 | # This will give a numpy array of shape (HH, WW) 272 | mask = seg_to_mask(object_data['segmentation'], WW, HH) 273 | 274 | # Crop the mask according to the bounding box, being careful to 275 | # ensure that we don't crop a zero-area region 276 | mx0, mx1 = int(round(x)), int(round(x + w)) 277 | my0, my1 = int(round(y)), int(round(y + h)) 278 | mx1 = max(mx0 + 1, mx1) 279 | my1 = max(my0 + 1, my1) 280 | mask = mask[my0:my1, mx0:mx1] 281 | mask = imresize(255.0 * mask, (self.mask_size, self.mask_size), 282 | mode='constant') 283 | mask = torch.from_numpy((mask > 128).astype(np.int64)) 284 | masks.append(mask) 285 | 286 | # Add dummy __image__ object 287 | objs.append(self.vocab['object_name_to_idx']['__image__']) 288 | boxes.append(torch.FloatTensor([0, 0, 1, 1])) 289 | masks.append(torch.ones(self.mask_size, self.mask_size).long()) 290 | 291 | objs = torch.LongTensor(objs) 292 | boxes = torch.stack(boxes, dim=0) 293 | masks = torch.stack(masks, dim=0) 294 | 295 | box_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 296 | 297 | # Compute centers of all objects 298 | obj_centers = [] 299 | _, MH, MW = masks.size() 300 | for i, obj_idx in enumerate(objs): 301 | x0, y0, x1, y1 = 
boxes[i] 302 | mask = (masks[i] == 1) 303 | xs = torch.linspace(x0, x1, MW).view(1, MW).expand(MH, MW) 304 | ys = torch.linspace(y0, y1, MH).view(MH, 1).expand(MH, MW) 305 | if mask.sum() == 0: 306 | mean_x = 0.5 * (x0 + x1) 307 | mean_y = 0.5 * (y0 + y1) 308 | else: 309 | mean_x = xs[mask].mean() 310 | mean_y = ys[mask].mean() 311 | obj_centers.append([mean_x, mean_y]) 312 | obj_centers = torch.FloatTensor(obj_centers) 313 | 314 | # Add triples 315 | triples = [] 316 | num_objs = objs.size(0) 317 | __image__ = self.vocab['object_name_to_idx']['__image__'] 318 | real_objs = [] 319 | if num_objs > 1: 320 | real_objs = (objs != __image__).nonzero().squeeze(1) 321 | for cur in real_objs: 322 | choices = [obj for obj in real_objs if obj != cur] 323 | if len(choices) == 0 or not self.include_relationships: 324 | break 325 | other = random.choice(choices) 326 | if random.random() > 0.5: 327 | s, o = cur, other 328 | else: 329 | s, o = other, cur 330 | 331 | # Check for inside / surrounding 332 | sx0, sy0, sx1, sy1 = boxes[s] 333 | ox0, oy0, ox1, oy1 = boxes[o] 334 | d = obj_centers[s] - obj_centers[o] 335 | theta = math.atan2(d[1], d[0]) 336 | 337 | if sx0 < ox0 and sx1 > ox1 and sy0 < oy0 and sy1 > oy1: 338 | p = 'surrounding' 339 | elif sx0 > ox0 and sx1 < ox1 and sy0 > oy0 and sy1 < oy1: 340 | p = 'inside' 341 | elif theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4: 342 | p = 'left of' 343 | elif -3 * math.pi / 4 <= theta < -math.pi / 4: 344 | p = 'above' 345 | elif -math.pi / 4 <= theta < math.pi / 4: 346 | p = 'right of' 347 | elif math.pi / 4 <= theta < 3 * math.pi / 4: 348 | p = 'below' 349 | p = self.vocab['pred_name_to_idx'][p] 350 | triples.append([s, p, o]) 351 | 352 | # Add __in_image__ triples 353 | O = objs.size(0) 354 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 355 | for i in range(O - 1): 356 | triples.append([i, in_image, O - 1]) 357 | 358 | triples = torch.LongTensor(triples) 359 | return image, objs, boxes, masks, triples 360 | 361 | 362 | def seg_to_mask(seg, width=1.0, height=1.0): 363 | """ 364 | Tiny utility for decoding segmentation masks using the pycocotools API. 365 | """ 366 | if type(seg) == list: 367 | rles = mask_utils.frPyObjects(seg, height, width) 368 | rle = mask_utils.merge(rles) 369 | elif type(seg['counts']) == list: 370 | rle = mask_utils.frPyObjects(seg, height, width) 371 | else: 372 | rle = seg 373 | return mask_utils.decode(rle) 374 | 375 | 376 | def coco_collate_fn(batch): 377 | """ 378 | Collate function to be used when wrapping CocoSceneGraphDataset in a 379 | DataLoader. 
Returns a tuple of the following: 380 | 381 | - imgs: FloatTensor of shape (N, C, H, W) 382 | - objs: LongTensor of shape (O,) giving object categories 383 | - boxes: FloatTensor of shape (O, 4) 384 | - masks: FloatTensor of shape (O, M, M) 385 | - triples: LongTensor of shape (T, 3) giving triples 386 | - obj_to_img: LongTensor of shape (O,) mapping objects to images 387 | - triple_to_img: LongTensor of shape (T,) mapping triples to images 388 | """ 389 | all_imgs, all_objs, all_boxes, all_masks, all_triples = [], [], [], [], [] 390 | all_obj_to_img, all_triple_to_img = [], [] 391 | obj_offset = 0 392 | for i, (img, objs, boxes, masks, triples) in enumerate(batch): 393 | all_imgs.append(img[None]) 394 | if objs.dim() == 0 or triples.dim() == 0: 395 | continue 396 | O, T = objs.size(0), triples.size(0) 397 | all_objs.append(objs) 398 | all_boxes.append(boxes) 399 | all_masks.append(masks) 400 | triples = triples.clone() 401 | triples[:, 0] += obj_offset 402 | triples[:, 2] += obj_offset 403 | all_triples.append(triples) 404 | 405 | all_obj_to_img.append(torch.LongTensor(O).fill_(i)) 406 | all_triple_to_img.append(torch.LongTensor(T).fill_(i)) 407 | obj_offset += O 408 | 409 | all_imgs = torch.cat(all_imgs) 410 | all_objs = torch.cat(all_objs) 411 | all_boxes = torch.cat(all_boxes) 412 | all_masks = torch.cat(all_masks) 413 | all_triples = torch.cat(all_triples) 414 | all_obj_to_img = torch.cat(all_obj_to_img) 415 | all_triple_to_img = torch.cat(all_triple_to_img) 416 | 417 | out = (all_imgs, all_objs, all_boxes, all_masks, all_triples, 418 | all_obj_to_img, all_triple_to_img) 419 | return out 420 | 421 | -------------------------------------------------------------------------------- /sg2im/data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
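The COCO __getitem__ above synthesizes predicates geometrically: 'surrounding' and 'inside' come from box containment, and the remaining four predicates come from the angle of the vector between object centers. A small standalone check of just that angle rule, using the same atan2 thresholds; the helper name is ours and the containment cases are intentionally omitted:

```python
import math

def predicate_from_centers(sub_center, obj_center):
    """Map the subject-minus-object center displacement onto the four directional predicates."""
    dx = sub_center[0] - obj_center[0]
    dy = sub_center[1] - obj_center[1]
    theta = math.atan2(dy, dx)
    if theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4:
        return 'left of'
    if -3 * math.pi / 4 <= theta < -math.pi / 4:
        return 'above'
    if -math.pi / 4 <= theta < math.pi / 4:
        return 'right of'
    return 'below'

# Image coordinates grow downward in y, so a subject whose center sits higher
# in the image than the object's center has dy < 0 and maps to 'above'.
print(predicate_from_centers((0.5, 0.2), (0.5, 0.8)))  # above
print(predicate_from_centers((0.1, 0.5), (0.9, 0.5)))  # left of
```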
16 | 17 | import PIL 18 | import torch 19 | import torchvision.transforms as T 20 | 21 | 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 23 | IMAGENET_STD = [0.229, 0.224, 0.225] 24 | 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 27 | 28 | 29 | def imagenet_preprocess(): 30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | 32 | 33 | def rescale(x): 34 | lo, hi = x.min(), x.max() 35 | return x.sub(lo).div(hi - lo) 36 | 37 | 38 | def imagenet_deprocess(rescale_image=True): 39 | transforms = [ 40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 42 | ] 43 | if rescale_image: 44 | transforms.append(rescale) 45 | return T.Compose(transforms) 46 | 47 | 48 | def imagenet_deprocess_batch(imgs, rescale=True): 49 | """ 50 | Input: 51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 52 | 53 | Output: 54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images 55 | in the range [0, 255] 56 | """ 57 | if isinstance(imgs, torch.autograd.Variable): 58 | imgs = imgs.data 59 | imgs = imgs.cpu().clone() 60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 61 | imgs_de = [] 62 | for i in range(imgs.size(0)): 63 | img_de = deprocess_fn(imgs[i])[None] 64 | img_de = img_de.mul(255).clamp(0, 255).byte() 65 | imgs_de.append(img_de) 66 | imgs_de = torch.cat(imgs_de, dim=0) 67 | return imgs_de 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, size, interp=PIL.Image.BILINEAR): 72 | if isinstance(size, tuple): 73 | H, W = size 74 | self.size = (W, H) 75 | else: 76 | self.size = (size, size) 77 | self.interp = interp 78 | 79 | def __call__(self, img): 80 | return img.resize(self.size, self.interp) 81 | 82 | 83 | def unpack_var(v): 84 | if isinstance(v, torch.autograd.Variable): 85 | return v.data 86 | return v 87 | 88 | 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 90 | triples = unpack_var(triples) 91 | obj_data = [unpack_var(o) for o in obj_data] 92 | obj_to_img = unpack_var(obj_to_img) 93 | triple_to_img = unpack_var(triple_to_img) 94 | 95 | triples_out = [] 96 | obj_data_out = [[] for _ in obj_data] 97 | obj_offset = 0 98 | N = obj_to_img.max() + 1 99 | for i in range(N): 100 | o_idxs = (obj_to_img == i).nonzero().view(-1) 101 | t_idxs = (triple_to_img == i).nonzero().view(-1) 102 | 103 | cur_triples = triples[t_idxs].clone() 104 | cur_triples[:, 0] -= obj_offset 105 | cur_triples[:, 2] -= obj_offset 106 | triples_out.append(cur_triples) 107 | 108 | for j, o_data in enumerate(obj_data): 109 | cur_o_data = None 110 | if o_data is not None: 111 | cur_o_data = o_data[o_idxs] 112 | obj_data_out[j].append(cur_o_data) 113 | 114 | obj_offset += o_idxs.size(0) 115 | 116 | return triples_out, obj_data_out 117 | 118 | -------------------------------------------------------------------------------- /sg2im/data/vg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
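Stepping back to sg2im/data/utils.py above: the two chained Normalize transforms in imagenet_deprocess undo imagenet_preprocess exactly (divide by 1/std, then subtract -mean). A quick sanity check built only from the helpers exported there; the tensor contents are arbitrary:

```python
import torch
from sg2im.data import imagenet_preprocess, imagenet_deprocess

img = torch.rand(3, 64, 64)                       # fake image in [0, 1]
pre = imagenet_preprocess()(img.clone())          # clone: some torchvision versions normalize in place
post = imagenet_deprocess(rescale_image=False)(pre)
print(torch.allclose(post, img, atol=1e-5))       # True: deprocess inverts preprocess
```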
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import random 19 | from collections import defaultdict 20 | 21 | import torch 22 | from torch.utils.data import Dataset 23 | import torchvision.transforms as T 24 | 25 | import numpy as np 26 | import h5py 27 | import PIL 28 | 29 | from .utils import imagenet_preprocess, Resize 30 | 31 | 32 | class VgSceneGraphDataset(Dataset): 33 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256), 34 | normalize_images=True, max_objects=10, max_samples=None, 35 | include_relationships=True, use_orphaned_objects=True): 36 | super(VgSceneGraphDataset, self).__init__() 37 | 38 | self.image_dir = image_dir 39 | self.image_size = image_size 40 | self.vocab = vocab 41 | self.num_objects = len(vocab['object_idx_to_name']) 42 | self.use_orphaned_objects = use_orphaned_objects 43 | self.max_objects = max_objects 44 | self.max_samples = max_samples 45 | self.include_relationships = include_relationships 46 | 47 | transform = [Resize(image_size), T.ToTensor()] 48 | if normalize_images: 49 | transform.append(imagenet_preprocess()) 50 | self.transform = T.Compose(transform) 51 | 52 | self.data = {} 53 | with h5py.File(h5_path, 'r') as f: 54 | for k, v in f.items(): 55 | if k == 'image_paths': 56 | self.image_paths = list(v) 57 | else: 58 | self.data[k] = torch.IntTensor(np.asarray(v)) 59 | 60 | def __len__(self): 61 | num = self.data['object_names'].size(0) 62 | if self.max_samples is not None: 63 | return min(self.max_samples, num) 64 | return num 65 | 66 | def __getitem__(self, index): 67 | """ 68 | Returns a tuple of: 69 | - image: FloatTensor of shape (C, H, W) 70 | - objs: LongTensor of shape (O,) 71 | - boxes: FloatTensor of shape (O, 4) giving boxes for objects in 72 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system. 73 | - triples: LongTensor of shape (T, 3) where triples[t] = [i, p, j] 74 | means that (objs[i], p, objs[j]) is a triple. 
75 | """ 76 | img_path = os.path.join(self.image_dir, self.image_paths[index]) 77 | 78 | with open(img_path, 'rb') as f: 79 | with PIL.Image.open(f) as image: 80 | WW, HH = image.size 81 | image = self.transform(image.convert('RGB')) 82 | 83 | H, W = self.image_size 84 | 85 | # Figure out which objects appear in relationships and which don't 86 | obj_idxs_with_rels = set() 87 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 88 | for r_idx in range(self.data['relationships_per_image'][index]): 89 | s = self.data['relationship_subjects'][index, r_idx].item() 90 | o = self.data['relationship_objects'][index, r_idx].item() 91 | obj_idxs_with_rels.add(s) 92 | obj_idxs_with_rels.add(o) 93 | obj_idxs_without_rels.discard(s) 94 | obj_idxs_without_rels.discard(o) 95 | 96 | obj_idxs = list(obj_idxs_with_rels) 97 | obj_idxs_without_rels = list(obj_idxs_without_rels) 98 | if len(obj_idxs) > self.max_objects - 1: 99 | obj_idxs = random.sample(obj_idxs, self.max_objects) 100 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 101 | num_to_add = self.max_objects - 1 - len(obj_idxs) 102 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 103 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add) 104 | O = len(obj_idxs) + 1 105 | 106 | objs = torch.LongTensor(O).fill_(-1) 107 | 108 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(O, 1) 109 | obj_idx_mapping = {} 110 | for i, obj_idx in enumerate(obj_idxs): 111 | objs[i] = self.data['object_names'][index, obj_idx].item() 112 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 113 | x0 = float(x) / WW 114 | y0 = float(y) / HH 115 | x1 = float(x + w) / WW 116 | y1 = float(y + h) / HH 117 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1]) 118 | obj_idx_mapping[obj_idx] = i 119 | 120 | # The last object will be the special __image__ object 121 | objs[O - 1] = self.vocab['object_name_to_idx']['__image__'] 122 | 123 | triples = [] 124 | for r_idx in range(self.data['relationships_per_image'][index].item()): 125 | if not self.include_relationships: 126 | break 127 | s = self.data['relationship_subjects'][index, r_idx].item() 128 | p = self.data['relationship_predicates'][index, r_idx].item() 129 | o = self.data['relationship_objects'][index, r_idx].item() 130 | s = obj_idx_mapping.get(s, None) 131 | o = obj_idx_mapping.get(o, None) 132 | if s is not None and o is not None: 133 | triples.append([s, p, o]) 134 | 135 | # Add dummy __in_image__ relationships for all objects 136 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 137 | for i in range(O - 1): 138 | triples.append([i, in_image, O - 1]) 139 | 140 | triples = torch.LongTensor(triples) 141 | return image, objs, boxes, triples 142 | 143 | 144 | def vg_collate_fn(batch): 145 | """ 146 | Collate function to be used when wrapping a VgSceneGraphDataset in a 147 | DataLoader. Returns a tuple of the following: 148 | 149 | - imgs: FloatTensor of shape (N, C, H, W) 150 | - objs: LongTensor of shape (O,) giving categories for all objects 151 | - boxes: FloatTensor of shape (O, 4) giving boxes for all objects 152 | - triples: FloatTensor of shape (T, 3) giving all triples, where 153 | triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple 154 | - obj_to_img: LongTensor of shape (O,) mapping objects to images; 155 | obj_to_img[i] = n means that objs[i] belongs to imgs[n] 156 | - triple_to_img: LongTensor of shape (T,) mapping triples to images; 157 | triple_to_img[t] = n means that triples[t] belongs to imgs[n]. 
158 | """ 159 | # batch is a list, and each element is (image, objs, boxes, triples) 160 | all_imgs, all_objs, all_boxes, all_triples = [], [], [], [] 161 | all_obj_to_img, all_triple_to_img = [], [] 162 | obj_offset = 0 163 | for i, (img, objs, boxes, triples) in enumerate(batch): 164 | all_imgs.append(img[None]) 165 | O, T = objs.size(0), triples.size(0) 166 | all_objs.append(objs) 167 | all_boxes.append(boxes) 168 | triples = triples.clone() 169 | triples[:, 0] += obj_offset 170 | triples[:, 2] += obj_offset 171 | all_triples.append(triples) 172 | 173 | all_obj_to_img.append(torch.LongTensor(O).fill_(i)) 174 | all_triple_to_img.append(torch.LongTensor(T).fill_(i)) 175 | obj_offset += O 176 | 177 | all_imgs = torch.cat(all_imgs) 178 | all_objs = torch.cat(all_objs) 179 | all_boxes = torch.cat(all_boxes) 180 | all_triples = torch.cat(all_triples) 181 | all_obj_to_img = torch.cat(all_obj_to_img) 182 | all_triple_to_img = torch.cat(all_triple_to_img) 183 | 184 | out = (all_imgs, all_objs, all_boxes, all_triples, 185 | all_obj_to_img, all_triple_to_img) 186 | return out 187 | 188 | 189 | def vg_uncollate_fn(batch): 190 | """ 191 | Inverse operation to the above. 192 | """ 193 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch 194 | out = [] 195 | obj_offset = 0 196 | for i in range(imgs.size(0)): 197 | cur_img = imgs[i] 198 | o_idxs = (obj_to_img == i).nonzero().view(-1) 199 | t_idxs = (triple_to_img == i).nonzero().view(-1) 200 | cur_objs = objs[o_idxs] 201 | cur_boxes = boxes[o_idxs] 202 | cur_triples = triples[t_idxs].clone() 203 | cur_triples[:, 0] -= obj_offset 204 | cur_triples[:, 2] -= obj_offset 205 | obj_offset += cur_objs.size(0) 206 | out.append((cur_img, cur_objs, cur_boxes, cur_triples)) 207 | return out 208 | 209 | -------------------------------------------------------------------------------- /sg2im/discriminators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
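vg_collate_fn above flattens the per-image object lists into one long tensor, shifting each image's triple indices by the running obj_offset so they still point at the right rows; vg_uncollate_fn reverses the shift. A toy round trip through those two functions, with two fabricated "images" holding 2 and 3 objects (the fake_sample helper and predicate id 0 are placeholders for illustration):

```python
import torch
from sg2im.data.vg import vg_collate_fn, vg_uncollate_fn

def fake_sample(num_objs):
    img = torch.zeros(3, 8, 8)
    objs = torch.arange(num_objs)
    boxes = torch.rand(num_objs, 4)
    # one triple: first object related (predicate 0) to last object
    triples = torch.LongTensor([[0, 0, num_objs - 1]])
    return img, objs, boxes, triples

batch = [fake_sample(2), fake_sample(3)]
out = vg_collate_fn(batch)
imgs, objs, boxes, triples, obj_to_img, triple_to_img = out
print(triples)  # second triple becomes [2, 0, 4]: its indices were offset by the first image's 2 objects

round_trip = vg_uncollate_fn(out)
print(torch.equal(round_trip[1][3], batch[1][3]))  # True: uncollate removes the offsets again
```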
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from sg2im.bilinear import crop_bbox_batch 22 | from sg2im.layers import GlobalAvgPool, Flatten, get_activation, build_cnn 23 | 24 | 25 | class PatchDiscriminator(nn.Module): 26 | def __init__(self, arch, normalization='batch', activation='leakyrelu-0.2', 27 | padding='same', pooling='avg', input_size=(128,128), 28 | layout_dim=0): 29 | super(PatchDiscriminator, self).__init__() 30 | input_dim = 3 + layout_dim 31 | arch = 'I%d,%s' % (input_dim, arch) 32 | cnn_kwargs = { 33 | 'arch': arch, 34 | 'normalization': normalization, 35 | 'activation': activation, 36 | 'pooling': pooling, 37 | 'padding': padding, 38 | } 39 | self.cnn, output_dim = build_cnn(**cnn_kwargs) 40 | self.classifier = nn.Conv2d(output_dim, 1, kernel_size=1, stride=1) 41 | 42 | def forward(self, x, layout=None): 43 | if layout is not None: 44 | x = torch.cat([x, layout], dim=1) 45 | return self.cnn(x) 46 | 47 | 48 | class AcDiscriminator(nn.Module): 49 | def __init__(self, vocab, arch, normalization='none', activation='relu', 50 | padding='same', pooling='avg'): 51 | super(AcDiscriminator, self).__init__() 52 | self.vocab = vocab 53 | 54 | cnn_kwargs = { 55 | 'arch': arch, 56 | 'normalization': normalization, 57 | 'activation': activation, 58 | 'pooling': pooling, 59 | 'padding': padding, 60 | } 61 | cnn, D = build_cnn(**cnn_kwargs) 62 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(D, 1024)) 63 | num_objects = len(vocab['object_idx_to_name']) 64 | 65 | self.real_classifier = nn.Linear(1024, 1) 66 | self.obj_classifier = nn.Linear(1024, num_objects) 67 | 68 | def forward(self, x, y): 69 | if x.dim() == 3: 70 | x = x[:, None] 71 | vecs = self.cnn(x) 72 | real_scores = self.real_classifier(vecs) 73 | obj_scores = self.obj_classifier(vecs) 74 | ac_loss = F.cross_entropy(obj_scores, y) 75 | return real_scores, ac_loss 76 | 77 | 78 | class AcCropDiscriminator(nn.Module): 79 | def __init__(self, vocab, arch, normalization='none', activation='relu', 80 | object_size=64, padding='same', pooling='avg'): 81 | super(AcCropDiscriminator, self).__init__() 82 | self.vocab = vocab 83 | self.discriminator = AcDiscriminator(vocab, arch, normalization, 84 | activation, padding, pooling) 85 | self.object_size = object_size 86 | 87 | def forward(self, imgs, objs, boxes, obj_to_img): 88 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size) 89 | real_scores, ac_loss = self.discriminator(crops, objs) 90 | return real_scores, ac_loss 91 | -------------------------------------------------------------------------------- /sg2im/graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | from sg2im.layers import build_mlp 20 | 21 | """ 22 | PyTorch modules for dealing with graphs. 
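A minimal sketch of driving the object discriminator defined above in sg2im/discriminators.py; the architecture string, sizes, and vocab are illustrative stand-ins rather than the training configuration:

import torch
from sg2im.discriminators import AcCropDiscriminator

vocab = {'object_idx_to_name': ['__image__', 'cat', 'dog']}
d_obj = AcCropDiscriminator(vocab, arch='C3-64-2,C3-128-2', object_size=32)

imgs = torch.randn(2, 3, 64, 64)             # real or generated images
objs = torch.tensor([1, 2, 1])               # category index of each object
boxes = torch.tensor([[0.1, 0.1, 0.5, 0.5],  # [x0, y0, x1, y1] in [0, 1] coordinates
                      [0.4, 0.2, 0.9, 0.8],
                      [0.0, 0.0, 1.0, 1.0]])
obj_to_img = torch.tensor([0, 0, 1])         # first two objects belong to image 0

real_scores, ac_loss = d_obj(imgs, objs, boxes, obj_to_img)
# real_scores holds one realism score per object crop; ac_loss is the auxiliary
# classification loss pushing each crop to look like its labelled category.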
23 | """ 24 | 25 | 26 | def _init_weights(module): 27 | if hasattr(module, 'weight'): 28 | if isinstance(module, nn.Linear): 29 | nn.init.kaiming_normal_(module.weight) 30 | 31 | 32 | class GraphTripleConv(nn.Module): 33 | """ 34 | A single layer of scene graph convolution. 35 | """ 36 | def __init__(self, input_dim, output_dim=None, hidden_dim=512, 37 | pooling='avg', mlp_normalization='none'): 38 | super(GraphTripleConv, self).__init__() 39 | if output_dim is None: 40 | output_dim = input_dim 41 | self.input_dim = input_dim 42 | self.output_dim = output_dim 43 | self.hidden_dim = hidden_dim 44 | 45 | assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling 46 | self.pooling = pooling 47 | net1_layers = [3 * input_dim, hidden_dim, 2 * hidden_dim + output_dim] 48 | net1_layers = [l for l in net1_layers if l is not None] 49 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization) 50 | self.net1.apply(_init_weights) 51 | 52 | net2_layers = [hidden_dim, hidden_dim, output_dim] 53 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization) 54 | self.net2.apply(_init_weights) 55 | 56 | def forward(self, obj_vecs, pred_vecs, edges): 57 | """ 58 | Inputs: 59 | - obj_vecs: FloatTensor of shape (O, D) giving vectors for all objects 60 | - pred_vecs: FloatTensor of shape (T, D) giving vectors for all predicates 61 | - edges: LongTensor of shape (T, 2) where edges[k] = [i, j] indicates the 62 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]] 63 | 64 | Outputs: 65 | - new_obj_vecs: FloatTensor of shape (O, D) giving new vectors for objects 66 | - new_pred_vecs: FloatTensor of shape (T, D) giving new vectors for predicates 67 | """ 68 | dtype, device = obj_vecs.dtype, obj_vecs.device 69 | O, T = obj_vecs.size(0), pred_vecs.size(0) 70 | Din, H, Dout = self.input_dim, self.hidden_dim, self.output_dim 71 | 72 | # Break apart indices for subjects and objects; these have shape (T,) 73 | s_idx = edges[:, 0].contiguous() 74 | o_idx = edges[:, 1].contiguous() 75 | 76 | # Get current vectors for subjects and objects; these have shape (T, Din) 77 | cur_s_vecs = obj_vecs[s_idx] 78 | cur_o_vecs = obj_vecs[o_idx] 79 | 80 | # Get current vectors for triples; shape is (T, 3 * Din) 81 | # Pass through net1 to get new triple vecs; shape is (T, 2 * H + Dout) 82 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1) 83 | new_t_vecs = self.net1(cur_t_vecs) 84 | 85 | # Break apart into new s, p, and o vecs; s and o vecs have shape (T, H) and 86 | # p vecs have shape (T, Dout) 87 | new_s_vecs = new_t_vecs[:, :H] 88 | new_p_vecs = new_t_vecs[:, H:(H+Dout)] 89 | new_o_vecs = new_t_vecs[:, (H+Dout):(2 * H + Dout)] 90 | 91 | # Allocate space for pooled object vectors of shape (O, H) 92 | pooled_obj_vecs = torch.zeros(O, H, dtype=dtype, device=device) 93 | 94 | # Use scatter_add to sum vectors for objects that appear in multiple triples; 95 | # we first need to expand the indices to have shape (T, D) 96 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs) 97 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs) 98 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs) 99 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs) 100 | 101 | if self.pooling == 'avg': 102 | # Figure out how many times each object has appeared, again using 103 | # some scatter_add trickery. 
104 | obj_counts = torch.zeros(O, dtype=dtype, device=device) 105 | ones = torch.ones(T, dtype=dtype, device=device) 106 | obj_counts = obj_counts.scatter_add(0, s_idx, ones) 107 | obj_counts = obj_counts.scatter_add(0, o_idx, ones) 108 | 109 | # Divide the new object vectors by the number of times they 110 | # appeared, but first clamp at 1 to avoid dividing by zero; 111 | # objects that appear in no triples will have output vector 0 112 | # so this will not affect them. 113 | obj_counts = obj_counts.clamp(min=1) 114 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1) 115 | 116 | # Send pooled object vectors through net2 to get output object vectors, 117 | # of shape (O, Dout) 118 | new_obj_vecs = self.net2(pooled_obj_vecs) 119 | 120 | return new_obj_vecs, new_p_vecs 121 | 122 | 123 | class GraphTripleConvNet(nn.Module): 124 | """ A sequence of scene graph convolution layers """ 125 | def __init__(self, input_dim, num_layers=5, hidden_dim=512, pooling='avg', 126 | mlp_normalization='none'): 127 | super(GraphTripleConvNet, self).__init__() 128 | 129 | self.num_layers = num_layers 130 | self.gconvs = nn.ModuleList() 131 | gconv_kwargs = { 132 | 'input_dim': input_dim, 133 | 'hidden_dim': hidden_dim, 134 | 'pooling': pooling, 135 | 'mlp_normalization': mlp_normalization, 136 | } 137 | for _ in range(self.num_layers): 138 | self.gconvs.append(GraphTripleConv(**gconv_kwargs)) 139 | 140 | def forward(self, obj_vecs, pred_vecs, edges): 141 | for i in range(self.num_layers): 142 | gconv = self.gconvs[i] 143 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges) 144 | return obj_vecs, pred_vecs 145 | 146 | 147 | -------------------------------------------------------------------------------- /sg2im/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
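A usage sketch for the graph convolution modules defined above in sg2im/graph.py; the feature dimension and the tiny three-object graph are illustrative:

import torch
from sg2im.graph import GraphTripleConv, GraphTripleConvNet

O, T, D = 3, 2, 128
obj_vecs = torch.randn(O, D)                # one vector per object
pred_vecs = torch.randn(T, D)               # one vector per predicate
edges = torch.LongTensor([[0, 1], [1, 2]])  # triples (0, p0, 1) and (1, p1, 2)

gconv = GraphTripleConv(input_dim=D, output_dim=D, hidden_dim=512, pooling='avg')
new_obj_vecs, new_pred_vecs = gconv(obj_vecs, pred_vecs, edges)  # shapes (O, D) and (T, D)

gconv_net = GraphTripleConvNet(input_dim=D, num_layers=5)        # a stack of five such layers
obj_out, pred_out = gconv_net(obj_vecs, pred_vecs, edges)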
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def get_normalization_2d(channels, normalization): 23 | if normalization == 'instance': 24 | return nn.InstanceNorm2d(channels) 25 | elif normalization == 'batch': 26 | return nn.BatchNorm2d(channels) 27 | elif normalization == 'none': 28 | return None 29 | else: 30 | raise ValueError('Unrecognized normalization type "%s"' % normalization) 31 | 32 | 33 | def get_activation(name): 34 | kwargs = {} 35 | if name.lower().startswith('leakyrelu'): 36 | if '-' in name: 37 | slope = float(name.split('-')[1]) 38 | kwargs = {'negative_slope': slope} 39 | name = 'leakyrelu' 40 | activations = { 41 | 'relu': nn.ReLU, 42 | 'leakyrelu': nn.LeakyReLU, 43 | } 44 | if name.lower() not in activations: 45 | raise ValueError('Invalid activation "%s"' % name) 46 | return activations[name.lower()](**kwargs) 47 | 48 | 49 | 50 | 51 | def _init_conv(layer, method): 52 | if not isinstance(layer, nn.Conv2d): 53 | return 54 | if method == 'default': 55 | return 56 | elif method == 'kaiming-normal': 57 | nn.init.kaiming_normal_(layer.weight) 58 | elif method == 'kaiming-uniform': 59 | nn.init.kaiming_uniform_(layer.weight) 60 | 61 | 62 | class Flatten(nn.Module): 63 | def forward(self, x): 64 | return x.view(x.size(0), -1) 65 | 66 | def __repr__(self): 67 | return 'Flatten()' 68 | 69 | 70 | class Unflatten(nn.Module): 71 | def __init__(self, size): 72 | super(Unflatten, self).__init__() 73 | self.size = size 74 | 75 | def forward(self, x): 76 | return x.view(*self.size) 77 | 78 | def __repr__(self): 79 | size_str = ', '.join('%d' % d for d in self.size) 80 | return 'Unflatten(%s)' % size_str 81 | 82 | 83 | class GlobalAvgPool(nn.Module): 84 | def forward(self, x): 85 | N, C = x.size(0), x.size(1) 86 | return x.view(N, C, -1).mean(dim=2) 87 | 88 | 89 | class ResidualBlock(nn.Module): 90 | def __init__(self, channels, normalization='batch', activation='relu', 91 | padding='same', kernel_size=3, init='default'): 92 | super(ResidualBlock, self).__init__() 93 | 94 | K = kernel_size 95 | P = _get_padding(K, padding) 96 | C = channels 97 | self.padding, self.kernel_size = P, K 98 | layers = [ 99 | get_normalization_2d(C, normalization), 100 | get_activation(activation), 101 | nn.Conv2d(C, C, kernel_size=K, padding=P), 102 | get_normalization_2d(C, normalization), 103 | get_activation(activation), 104 | nn.Conv2d(C, C, kernel_size=K, padding=P), 105 | ] 106 | layers = [layer for layer in layers if layer is not None] 107 | for layer in layers: 108 | _init_conv(layer, method=init) 109 | self.net = nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | P = self.padding 113 | shortcut = x 114 | if P == 0: 115 | crop = self.kernel_size - 1  # the two 'valid' convolutions remove this many pixels per side 116 | shortcut = x[:, :, crop:-crop, crop:-crop] 117 | y = self.net(x) 118 | return shortcut + y 119 | 120 | def _get_padding(K, mode): 121 | """ Helper method to compute padding size """ 122 | if mode == 'valid': 123 | return 0 124 | elif mode == 'same': 125 | assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K 126 | return (K - 1) // 2 127 | 128 | 129 | def build_cnn(arch, normalization='batch', activation='relu', padding='same', 130 | pooling='max', init='default'): 131 | """ 132 | Build a CNN from an architecture string, which is a list of layer 133 | specification strings. The overall architecture can be given as a list or as 134 | a comma-separated string. 135 | 136 | All convolutions *except for the first* are preceded by normalization and 137 | nonlinearity.
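For example, a small strided encoder could be specified as follows (the exact string is illustrative; the supported layer types are listed below):

  cnn, C = build_cnn('I3,C3-64,C3-128-2,C3-256-2')
  # cnn maps (N, 3, 64, 64) images to (N, 256, 16, 16) features, and C == 256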
138 | 139 | All other layers support the following: 140 | - IX: Indicates that the number of input channels to the network is X. 141 | Can only be used at the first layer; if not present then we assume 142 | 3 input channels. 143 | - CK-X: KxK convolution with X output channels 144 | - CK-X-S: KxK convolution with X output channels and stride S 145 | - R: Residual block keeping the same number of channels 146 | - UX: Nearest-neighbor upsampling with factor X 147 | - PX: Spatial pooling with factor X 148 | - FC-X-Y: Flatten followed by fully-connected layer 149 | 150 | Returns a tuple of: 151 | - cnn: An nn.Sequential 152 | - channels: Number of output channels 153 | """ 154 | if isinstance(arch, str): 155 | arch = arch.split(',') 156 | cur_C = 3 157 | if len(arch) > 0 and arch[0][0] == 'I': 158 | cur_C = int(arch[0][1:]) 159 | arch = arch[1:] 160 | 161 | first_conv = True 162 | flat = False 163 | layers = [] 164 | for i, s in enumerate(arch): 165 | if s[0] == 'C': 166 | if not first_conv: 167 | layers.append(get_normalization_2d(cur_C, normalization)) 168 | layers.append(get_activation(activation)) 169 | first_conv = False 170 | vals = [int(i) for i in s[1:].split('-')] 171 | if len(vals) == 2: 172 | K, next_C = vals 173 | stride = 1 174 | elif len(vals) == 3: 175 | K, next_C, stride = vals 176 | # K, next_C = (int(i) for i in s[1:].split('-')) 177 | P = _get_padding(K, padding) 178 | conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride) 179 | layers.append(conv) 180 | _init_conv(layers[-1], init) 181 | cur_C = next_C 182 | elif s[0] == 'R': 183 | norm = 'none' if first_conv else normalization 184 | res = ResidualBlock(cur_C, normalization=norm, activation=activation, 185 | padding=padding, init=init) 186 | layers.append(res) 187 | first_conv = False 188 | elif s[0] == 'U': 189 | factor = int(s[1:]) 190 | layers.append(nn.Upsample(scale_factor=factor, mode='nearest')) 191 | elif s[0] == 'P': 192 | factor = int(s[1:]) 193 | if pooling == 'max': 194 | pool = nn.MaxPool2d(kernel_size=factor, stride=factor) 195 | elif pooling == 'avg': 196 | pool = nn.AvgPool2d(kernel_size=factor, stride=factor) 197 | layers.append(pool) 198 | elif s[:2] == 'FC': 199 | _, Din, Dout = s.split('-') 200 | Din, Dout = int(Din), int(Dout) 201 | if not flat: 202 | layers.append(Flatten()) 203 | flat = True 204 | layers.append(nn.Linear(Din, Dout)) 205 | if i + 1 < len(arch): 206 | layers.append(get_activation(activation)) 207 | cur_C = Dout 208 | else: 209 | raise ValueError('Invalid layer "%s"' % s) 210 | layers = [layer for layer in layers if layer is not None] 211 | for layer in layers: 212 | print(layer) 213 | return nn.Sequential(*layers), cur_C 214 | 215 | 216 | def build_mlp(dim_list, activation='relu', batch_norm='none', 217 | dropout=0, final_nonlinearity=True): 218 | layers = [] 219 | for i in range(len(dim_list) - 1): 220 | dim_in, dim_out = dim_list[i], dim_list[i + 1] 221 | layers.append(nn.Linear(dim_in, dim_out)) 222 | final_layer = (i == len(dim_list) - 2) 223 | if not final_layer or final_nonlinearity: 224 | if batch_norm == 'batch': 225 | layers.append(nn.BatchNorm1d(dim_out)) 226 | if activation == 'relu': 227 | layers.append(nn.ReLU()) 228 | elif activation == 'leakyrelu': 229 | layers.append(nn.LeakyReLU()) 230 | if dropout > 0: 231 | layers.append(nn.Dropout(p=dropout)) 232 | return nn.Sequential(*layers) 233 | 234 | -------------------------------------------------------------------------------- /sg2im/layout.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from sg2im.utils import timeit, get_gpu_memory, lineno 21 | 22 | 23 | """ 24 | Functions for computing image layouts from object vectors, bounding boxes, 25 | and segmentation masks. These are used to compute course scene layouts which 26 | are then fed as input to the cascaded refinement network. 27 | """ 28 | 29 | 30 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'): 31 | """ 32 | Inputs: 33 | - vecs: Tensor of shape (O, D) giving vectors 34 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 35 | [x0, y0, x1, y1] in the [0, 1] coordinate space 36 | - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to 37 | an image, where each element is in the range [0, N). If obj_to_img[i] = j 38 | then vecs[i] belongs to image j. 39 | - H, W: Size of the output 40 | 41 | Returns: 42 | - out: Tensor of shape (N, D, H, W) 43 | """ 44 | O, D = vecs.size() 45 | if W is None: 46 | W = H 47 | 48 | grid = _boxes_to_grid(boxes, H, W) 49 | 50 | # If we don't add extra spatial dimensions here then out-of-bounds 51 | # elements won't be automatically set to 0 52 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8) 53 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W) 54 | 55 | # Explicitly masking makes everything quite a bit slower. 56 | # If we rely on implicit masking the interpolated boxes end up 57 | # blurred around the edges, but it should be fine. 58 | # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1) 59 | # sampled[mask[:, None]] = 0 60 | 61 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 62 | 63 | return out 64 | 65 | 66 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum'): 67 | """ 68 | Inputs: 69 | - vecs: Tensor of shape (O, D) giving vectors 70 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 71 | [x0, y0, x1, y1] in the [0, 1] coordinate space 72 | - masks: Tensor of shape (O, M, M) giving binary masks for each object 73 | - obj_to_img: LongTensor of shape (O,) mapping objects to images 74 | - H, W: Size of the output image. 
75 | 76 | Returns: 77 | - out: Tensor of shape (N, D, H, W) 78 | """ 79 | O, D = vecs.size() 80 | M = masks.size(1) 81 | assert masks.size() == (O, M, M) 82 | if W is None: 83 | W = H 84 | 85 | grid = _boxes_to_grid(boxes, H, W) 86 | 87 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M) 88 | sampled = F.grid_sample(img_in, grid) 89 | 90 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 91 | return out 92 | 93 | 94 | def _boxes_to_grid(boxes, H, W): 95 | """ 96 | Input: 97 | - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1] 98 | format in the [0, 1] coordinate space 99 | - H, W: Scalars giving size of output 100 | 101 | Returns: 102 | - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample 103 | """ 104 | O = boxes.size(0) 105 | 106 | boxes = boxes.view(O, 4, 1, 1) 107 | 108 | # All these are (O, 1, 1) 109 | x0, y0 = boxes[:, 0], boxes[:, 1] 110 | x1, y1 = boxes[:, 2], boxes[:, 3] 111 | ww = x1 - x0 112 | hh = y1 - y0 113 | 114 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes) 115 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes) 116 | 117 | X = (X - x0) / ww # (O, 1, W) 118 | Y = (Y - y0) / hh # (O, H, 1) 119 | 120 | # Stack does not broadcast its arguments so we need to expand explicitly 121 | X = X.expand(O, H, W) 122 | Y = Y.expand(O, H, W) 123 | grid = torch.stack([X, Y], dim=3) # (O, H, W, 2) 124 | 125 | # Right now grid is in [0, 1] space; transform to [-1, 1] 126 | grid = grid.mul(2).sub(1) 127 | 128 | return grid 129 | 130 | 131 | def _pool_samples(samples, obj_to_img, pooling='sum'): 132 | """ 133 | Input: 134 | - samples: FloatTensor of shape (O, D, H, W) 135 | - obj_to_img: LongTensor of shape (O,) with each element in the range 136 | [0, N) mapping elements of samples to output images 137 | 138 | Output: 139 | - pooled: FloatTensor of shape (N, D, H, W) 140 | """ 141 | dtype, device = samples.dtype, samples.device 142 | O, D, H, W = samples.size() 143 | N = obj_to_img.data.max().item() + 1 144 | 145 | # Use scatter_add to sum the sampled outputs for each image 146 | out = torch.zeros(N, D, H, W, dtype=dtype, device=device) 147 | idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W) 148 | out = out.scatter_add(0, idx, samples) 149 | 150 | if pooling == 'avg': 151 | # Divide each output mask by the number of objects; use scatter_add again 152 | # to count the number of objects per image. 
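# For example, obj_to_img = [0, 0, 1] over N = 2 images gives
# obj_counts = [2, 1], so the summed layout for image 0 is halved while
# image 1 is left unchanged.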
153 | ones = torch.ones(O, dtype=dtype, device=device) 154 | obj_counts = torch.zeros(N, dtype=dtype, device=device) 155 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones) 156 | print(obj_counts) 157 | obj_counts = obj_counts.clamp(min=1) 158 | out = out / obj_counts.view(N, 1, 1, 1) 159 | elif pooling != 'sum': 160 | raise ValueError('Invalid pooling "%s"' % pooling) 161 | 162 | return out 163 | 164 | 165 | if __name__ == '__main__': 166 | vecs = torch.FloatTensor([ 167 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 168 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 169 | ]) 170 | boxes = torch.FloatTensor([ 171 | [0.25, 0.125, 0.5, 0.875], 172 | [0, 0, 1, 0.25], 173 | [0.6125, 0, 0.875, 1], 174 | [0, 0.8, 1, 1.0], 175 | [0.25, 0.125, 0.5, 0.875], 176 | [0.6125, 0, 0.875, 1], 177 | ]) 178 | obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1]) 179 | # vecs = torch.FloatTensor([[[1]]]) 180 | # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]]) 181 | vecs, boxes = vecs.cuda(), boxes.cuda() 182 | obj_to_img = obj_to_img.cuda() 183 | out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum') 184 | 185 | from torchvision.utils import save_image 186 | save_image(out.data, 'out.png') 187 | 188 | 189 | masks = torch.FloatTensor([ 190 | [ 191 | [0, 0, 1, 0, 0], 192 | [0, 1, 1, 1, 0], 193 | [1, 1, 1, 1, 1], 194 | [0, 1, 1, 1, 0], 195 | [0, 0, 1, 0, 0], 196 | ], 197 | [ 198 | [0, 0, 1, 0, 0], 199 | [0, 1, 0, 1, 0], 200 | [1, 0, 0, 0, 1], 201 | [0, 1, 0, 1, 0], 202 | [0, 0, 1, 0, 0], 203 | ], 204 | [ 205 | [0, 0, 1, 0, 0], 206 | [0, 1, 1, 1, 0], 207 | [1, 1, 1, 1, 1], 208 | [0, 1, 1, 1, 0], 209 | [0, 0, 1, 0, 0], 210 | ], 211 | [ 212 | [0, 0, 1, 0, 0], 213 | [0, 1, 1, 1, 0], 214 | [1, 1, 1, 1, 1], 215 | [0, 1, 1, 1, 0], 216 | [0, 0, 1, 0, 0], 217 | ], 218 | [ 219 | [0, 0, 1, 0, 0], 220 | [0, 1, 1, 1, 0], 221 | [1, 1, 1, 1, 1], 222 | [0, 1, 1, 1, 0], 223 | [0, 0, 1, 0, 0], 224 | ], 225 | [ 226 | [0, 0, 1, 0, 0], 227 | [0, 1, 1, 1, 0], 228 | [1, 1, 1, 1, 1], 229 | [0, 1, 1, 1, 0], 230 | [0, 0, 1, 0, 0], 231 | ] 232 | ]) 233 | masks = masks.cuda() 234 | out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256) 235 | save_image(out.data, 'out_masks.png') 236 | -------------------------------------------------------------------------------- /sg2im/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | 21 | def get_gan_losses(gan_type): 22 | """ 23 | Returns the generator and discriminator loss for a particular GAN type. 
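A small usage sketch; the toy critic, tensor shapes, and the 10.0 penalty weight are illustrative assumptions rather than the repo's training settings, and gradient_penalty is defined at the bottom of this file:

import torch
import torch.nn as nn
from sg2im.losses import get_gan_losses, gradient_penalty

netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1))  # stand-in critic
x_real = torch.randn(4, 3, 32, 32)
x_fake = torch.randn(4, 3, 32, 32, requires_grad=True)         # normally the generator output

g_loss_fn, d_loss_fn = get_gan_losses('wgan')
loss_d = d_loss_fn(netD(x_real), netD(x_fake))
loss_d = loss_d + 10.0 * gradient_penalty(x_real, x_fake, netD)
loss_g = g_loss_fn(netD(x_fake))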
24 | 25 | The returned functions have the following API: 26 | loss_g = g_loss(scores_fake) 27 | loss_d = d_loss(scores_real, scores_fake) 28 | """ 29 | if gan_type == 'gan': 30 | return gan_g_loss, gan_d_loss 31 | elif gan_type == 'wgan': 32 | return wgan_g_loss, wgan_d_loss 33 | elif gan_type == 'lsgan': 34 | return lsgan_g_loss, lsgan_d_loss 35 | else: 36 | raise ValueError('Unrecognized GAN type "%s"' % gan_type) 37 | 38 | 39 | def bce_loss(input, target): 40 | """ 41 | Numerically stable version of the binary cross-entropy loss function. 42 | 43 | As per https://github.com/pytorch/pytorch/issues/751 44 | See the TensorFlow docs for a derivation of this formula: 45 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits 46 | 47 | Inputs: 48 | - input: PyTorch Tensor of shape (N, ) giving scores. 49 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets. 50 | 51 | Returns: 52 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of 53 | input data. 54 | """ 55 | neg_abs = -input.abs() 56 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 57 | return loss.mean() 58 | 59 | 60 | def _make_targets(x, y): 61 | """ 62 | Inputs: 63 | - x: PyTorch Tensor 64 | - y: Python scalar 65 | 66 | Outputs: 67 | - out: PyTorch Variable with same shape and dtype as x, but filled with y 68 | """ 69 | return torch.full_like(x, y) 70 | 71 | 72 | def gan_g_loss(scores_fake): 73 | """ 74 | Input: 75 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 76 | 77 | Output: 78 | - loss: Variable of shape (,) giving GAN generator loss 79 | """ 80 | if scores_fake.dim() > 1: 81 | scores_fake = scores_fake.view(-1) 82 | y_fake = _make_targets(scores_fake, 1) 83 | return bce_loss(scores_fake, y_fake) 84 | 85 | 86 | def gan_d_loss(scores_real, scores_fake): 87 | """ 88 | Input: 89 | - scores_real: Tensor of shape (N,) giving scores for real samples 90 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 91 | 92 | Output: 93 | - loss: Tensor of shape (,) giving GAN discriminator loss 94 | """ 95 | assert scores_real.size() == scores_fake.size() 96 | if scores_real.dim() > 1: 97 | scores_real = scores_real.view(-1) 98 | scores_fake = scores_fake.view(-1) 99 | y_real = _make_targets(scores_real, 1) 100 | y_fake = _make_targets(scores_fake, 0) 101 | loss_real = bce_loss(scores_real, y_real) 102 | loss_fake = bce_loss(scores_fake, y_fake) 103 | return loss_real + loss_fake 104 | 105 | 106 | def wgan_g_loss(scores_fake): 107 | """ 108 | Input: 109 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 110 | 111 | Output: 112 | - loss: Tensor of shape (,) giving WGAN generator loss 113 | """ 114 | return -scores_fake.mean() 115 | 116 | 117 | def wgan_d_loss(scores_real, scores_fake): 118 | """ 119 | Input: 120 | - scores_real: Tensor of shape (N,) giving scores for real samples 121 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 122 | 123 | Output: 124 | - loss: Tensor of shape (,) giving WGAN discriminator loss 125 | """ 126 | return scores_fake.mean() - scores_real.mean() 127 | 128 | 129 | def lsgan_g_loss(scores_fake): 130 | if scores_fake.dim() > 1: 131 | scores_fake = scores_fake.view(-1) 132 | y_fake = _make_targets(scores_fake, 1) 133 | return F.mse_loss(scores_fake.sigmoid(), y_fake) 134 | 135 | 136 | def lsgan_d_loss(scores_real, scores_fake): 137 | assert scores_real.size() == scores_fake.size() 138 | if scores_real.dim() > 1: 139 | scores_real = 
scores_real.view(-1) 140 | scores_fake = scores_fake.view(-1) 141 | y_real = _make_targets(scores_real, 1) 142 | y_fake = _make_targets(scores_fake, 0) 143 | loss_real = F.mse_loss(scores_real.sigmoid(), y_real) 144 | loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake) 145 | return loss_real + loss_fake 146 | 147 | 148 | def gradient_penalty(x_real, x_fake, f, gamma=1.0): 149 | N = x_real.size(0) 150 | device, dtype = x_real.device, x_real.dtype 151 | eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype) 152 | x_hat = eps * x_real + (1 - eps) * x_fake 153 | x_hat_score = f(x_hat) 154 | if x_hat_score.dim() > 1: 155 | x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1) 156 | x_hat_score = x_hat_score.sum() 157 | grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True) 158 | grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1) 159 | gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean() 160 | return gp_loss 161 | 162 | 163 | -------------------------------------------------------------------------------- /sg2im/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | 19 | 20 | def intersection(bbox_pred, bbox_gt): 21 | max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:]) 22 | min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2]) 23 | inter = torch.clamp((max_xy - min_xy), min=0) 24 | return inter[:, 0] * inter[:, 1] 25 | 26 | 27 | def jaccard(bbox_pred, bbox_gt): 28 | inter = intersection(bbox_pred, bbox_gt) 29 | area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] - 30 | bbox_pred[:, 1]) 31 | area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] - 32 | bbox_gt[:, 1]) 33 | union = area_pred + area_gt - inter 34 | iou = torch.div(inter, union) 35 | return torch.sum(iou) 36 | 37 | def get_total_norm(parameters, norm_type=2): 38 | if norm_type == float('inf'): 39 | total_norm = max(p.grad.data.abs().max() for p in parameters) 40 | else: 41 | total_norm = 0 42 | for p in parameters: 43 | try: 44 | param_norm = p.grad.data.norm(norm_type) 45 | total_norm += param_norm ** norm_type 46 | total_norm = total_norm ** (1. / norm_type) 47 | except: 48 | continue 49 | return total_norm 50 | 51 | -------------------------------------------------------------------------------- /sg2im/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import math 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | import sg2im.box_utils as box_utils 23 | from sg2im.graph import GraphTripleConv, GraphTripleConvNet 24 | from sg2im.crn import RefinementNetwork 25 | from sg2im.layout import boxes_to_layout, masks_to_layout 26 | from sg2im.layers import build_mlp 27 | 28 | 29 | class Sg2ImModel(nn.Module): 30 | def __init__(self, vocab, image_size=(64, 64), embedding_dim=64, 31 | gconv_dim=128, gconv_hidden_dim=512, 32 | gconv_pooling='avg', gconv_num_layers=5, 33 | refinement_dims=(1024, 512, 256, 128, 64), 34 | normalization='batch', activation='leakyrelu-0.2', 35 | mask_size=None, mlp_normalization='none', layout_noise_dim=0, 36 | **kwargs): 37 | super(Sg2ImModel, self).__init__() 38 | 39 | # We used to have some additional arguments: 40 | # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions 41 | if len(kwargs) > 0: 42 | print('WARNING: Model got unexpected kwargs ', kwargs) 43 | 44 | self.vocab = vocab 45 | self.image_size = image_size 46 | self.layout_noise_dim = layout_noise_dim 47 | 48 | num_objs = len(vocab['object_idx_to_name']) 49 | num_preds = len(vocab['pred_idx_to_name']) 50 | self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim) 51 | self.pred_embeddings = nn.Embedding(num_preds, embedding_dim) 52 | 53 | if gconv_num_layers == 0: 54 | self.gconv = nn.Linear(embedding_dim, gconv_dim) 55 | elif gconv_num_layers > 0: 56 | gconv_kwargs = { 57 | 'input_dim': embedding_dim, 58 | 'output_dim': gconv_dim, 59 | 'hidden_dim': gconv_hidden_dim, 60 | 'pooling': gconv_pooling, 61 | 'mlp_normalization': mlp_normalization, 62 | } 63 | self.gconv = GraphTripleConv(**gconv_kwargs) 64 | 65 | self.gconv_net = None 66 | if gconv_num_layers > 1: 67 | gconv_kwargs = { 68 | 'input_dim': gconv_dim, 69 | 'hidden_dim': gconv_hidden_dim, 70 | 'pooling': gconv_pooling, 71 | 'num_layers': gconv_num_layers - 1, 72 | 'mlp_normalization': mlp_normalization, 73 | } 74 | self.gconv_net = GraphTripleConvNet(**gconv_kwargs) 75 | 76 | box_net_dim = 4 77 | box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim] 78 | self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization) 79 | 80 | self.mask_net = None 81 | if mask_size is not None and mask_size > 0: 82 | self.mask_net = self._build_mask_net(num_objs, gconv_dim, mask_size) 83 | 84 | rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds] 85 | self.rel_aux_net = build_mlp(rel_aux_layers, batch_norm=mlp_normalization) 86 | 87 | refinement_kwargs = { 88 | 'dims': (gconv_dim + layout_noise_dim,) + refinement_dims, 89 | 'normalization': normalization, 90 | 'activation': activation, 91 | } 92 | self.refinement_net = RefinementNetwork(**refinement_kwargs) 93 | 94 | def _build_mask_net(self, num_objs, dim, mask_size): 95 | output_dim = 1 96 | layers, cur_size = [], 1 97 | while cur_size < mask_size: 98 | layers.append(nn.Upsample(scale_factor=2, mode='nearest')) 99 | layers.append(nn.BatchNorm2d(dim)) 100 | layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=1)) 101 | 
layers.append(nn.ReLU()) 102 | cur_size *= 2 103 | if cur_size != mask_size: 104 | raise ValueError('Mask size must be a power of 2') 105 | layers.append(nn.Conv2d(dim, output_dim, kernel_size=1)) 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, objs, triples, obj_to_img=None, 109 | boxes_gt=None, masks_gt=None): 110 | """ 111 | Required Inputs: 112 | - objs: LongTensor of shape (O,) giving categories for all objects 113 | - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o] 114 | means that there is a triple (objs[s], p, objs[o]) 115 | 116 | Optional Inputs: 117 | - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i 118 | means that objects[o] is an object in image i. If not given then 119 | all objects are assumed to belong to the same image. 120 | - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing 121 | the spatial layout; if not given then use predicted boxes. 122 | """ 123 | O, T = objs.size(0), triples.size(0) 124 | s, p, o = triples.chunk(3, dim=1) # All have shape (T, 1) 125 | s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,) 126 | edges = torch.stack([s, o], dim=1) # Shape is (T, 2) 127 | 128 | if obj_to_img is None: 129 | obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device) 130 | 131 | obj_vecs = self.obj_embeddings(objs) 132 | obj_vecs_orig = obj_vecs 133 | pred_vecs = self.pred_embeddings(p) 134 | 135 | if isinstance(self.gconv, nn.Linear): 136 | obj_vecs = self.gconv(obj_vecs) 137 | else: 138 | obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges) 139 | if self.gconv_net is not None: 140 | obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges) 141 | 142 | boxes_pred = self.box_net(obj_vecs) 143 | 144 | masks_pred = None 145 | if self.mask_net is not None: 146 | mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1)) 147 | masks_pred = mask_scores.squeeze(1).sigmoid() 148 | 149 | s_boxes, o_boxes = boxes_pred[s], boxes_pred[o] 150 | s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o] 151 | rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1) 152 | rel_scores = self.rel_aux_net(rel_aux_input) 153 | 154 | H, W = self.image_size 155 | layout_boxes = boxes_pred if boxes_gt is None else boxes_gt 156 | 157 | if masks_pred is None: 158 | layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W) 159 | else: 160 | layout_masks = masks_pred if masks_gt is None else masks_gt 161 | layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks, 162 | obj_to_img, H, W) 163 | 164 | if self.layout_noise_dim > 0: 165 | N, C, H, W = layout.size() 166 | noise_shape = (N, self.layout_noise_dim, H, W) 167 | layout_noise = torch.randn(noise_shape, dtype=layout.dtype, 168 | device=layout.device) 169 | layout = torch.cat([layout, layout_noise], dim=1) 170 | img = self.refinement_net(layout) 171 | return img, boxes_pred, masks_pred, rel_scores 172 | 173 | def encode_scene_graphs(self, scene_graphs): 174 | """ 175 | Encode one or more scene graphs using this model's vocabulary. Inputs to 176 | this method are scene graphs represented as dictionaries like the following: 177 | 178 | { 179 | "objects": ["cat", "dog", "sky"], 180 | "relationships": [ 181 | [0, "next to", 1], 182 | [0, "beneath", 2], 183 | [2, "above", 1], 184 | ] 185 | } 186 | 187 | This scene graph has three relationshps: cat next to dog, cat beneath sky, 188 | and sky above dog. 
189 | 190 | Inputs: 191 | - scene_graphs: A dictionary giving a single scene graph, or a list of 192 | dictionaries giving a sequence of scene graphs. 193 | 194 | Returns a tuple of LongTensors (objs, triples, obj_to_img) that have the 195 | same semantics as self.forward. The returned LongTensors will be on the 196 | same device as the model parameters. 197 | """ 198 | if isinstance(scene_graphs, dict): 199 | # We just got a single scene graph, so promote it to a list 200 | scene_graphs = [scene_graphs] 201 | 202 | objs, triples, obj_to_img = [], [], [] 203 | obj_offset = 0 204 | for i, sg in enumerate(scene_graphs): 205 | # Insert dummy __image__ object and __in_image__ relationships 206 | sg['objects'].append('__image__') 207 | image_idx = len(sg['objects']) - 1 208 | for j in range(image_idx): 209 | sg['relationships'].append([j, '__in_image__', image_idx]) 210 | 211 | for obj in sg['objects']: 212 | obj_idx = self.vocab['object_name_to_idx'].get(obj, None) 213 | if obj_idx is None: 214 | raise ValueError('Object "%s" not in vocab' % obj) 215 | objs.append(obj_idx) 216 | obj_to_img.append(i) 217 | for s, p, o in sg['relationships']: 218 | pred_idx = self.vocab['pred_name_to_idx'].get(p, None) 219 | if pred_idx is None: 220 | raise ValueError('Relationship "%s" not in vocab' % p) 221 | triples.append([s + obj_offset, pred_idx, o + obj_offset]) 222 | obj_offset += len(sg['objects']) 223 | device = next(self.parameters()).device 224 | objs = torch.tensor(objs, dtype=torch.int64, device=device) 225 | triples = torch.tensor(triples, dtype=torch.int64, device=device) 226 | obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64, device=device) 227 | return objs, triples, obj_to_img 228 | 229 | def forward_json(self, scene_graphs): 230 | """ Convenience method that combines encode_scene_graphs and forward. """ 231 | objs, triples, obj_to_img = self.encode_scene_graphs(scene_graphs) 232 | return self.forward(objs, triples, obj_to_img) 233 | 234 | -------------------------------------------------------------------------------- /sg2im/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
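A sketch of feeding a scene-graph dictionary to the Sg2ImModel defined above in sg2im/model.py; the tiny vocab is illustrative (pretrained checkpoints ship their own):

from sg2im.model import Sg2ImModel

vocab = {
  'object_idx_to_name': ['__image__', 'cat', 'dog', 'sky'],
  'object_name_to_idx': {'__image__': 0, 'cat': 1, 'dog': 2, 'sky': 3},
  'pred_idx_to_name': ['__in_image__', 'next to', 'above'],
  'pred_name_to_idx': {'__in_image__': 0, 'next to': 1, 'above': 2},
}
model = Sg2ImModel(vocab, image_size=(64, 64))

sg = {'objects': ['cat', 'dog'], 'relationships': [[0, 'next to', 1]]}
objs, triples, obj_to_img = model.encode_scene_graphs(sg)
# objs: [cat, dog, __image__] -> shape (3,); triples: the 'next to' triple plus
# two dummy __in_image__ triples -> shape (3, 3); obj_to_img: all zeros.
# forward_json(...) combines this encoding step with forward() and returns
# (img, boxes_pred, masks_pred, rel_scores).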
16 | 17 | import time 18 | import inspect 19 | import subprocess 20 | from contextlib import contextmanager 21 | 22 | import torch 23 | 24 | 25 | def int_tuple(s): 26 | return tuple(int(i) for i in s.split(',')) 27 | 28 | 29 | def float_tuple(s): 30 | return tuple(float(i) for i in s.split(',')) 31 | 32 | 33 | def str_tuple(s): 34 | return tuple(s.split(',')) 35 | 36 | 37 | def bool_flag(s): 38 | if s == '1': 39 | return True 40 | elif s == '0': 41 | return False 42 | msg = 'Invalid value "%s" for bool flag (should be 0 or 1)' 43 | raise ValueError(msg % s) 44 | 45 | 46 | def lineno(): 47 | return inspect.currentframe().f_back.f_lineno 48 | 49 | 50 | def get_gpu_memory(): 51 | torch.cuda.synchronize() 52 | opts = [ 53 | 'nvidia-smi', '-q', '--gpu=' + str(0), '|', 'grep', '"Used GPU Memory"' 54 | ] 55 | cmd = str.join(' ', opts) 56 | ps = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) 57 | output = ps.communicate()[0].decode('utf-8') 58 | output = output.split("\n")[1].split(":") 59 | consumed_mem = int(output[1].strip().split(" ")[0]) 60 | return consumed_mem 61 | 62 | 63 | @contextmanager 64 | def timeit(msg, should_time=True): 65 | if should_time: 66 | torch.cuda.synchronize() 67 | t0 = time.time() 68 | yield 69 | if should_time: 70 | torch.cuda.synchronize() 71 | t1 = time.time() 72 | duration = (t1 - t0) * 1000.0 73 | print('%s: %.2f ms' % (msg, duration)) 74 | 75 | 76 | class LossManager(object): 77 | def __init__(self): 78 | self.total_loss = None 79 | self.all_losses = {} 80 | 81 | def add_loss(self, loss, name, weight=1.0): 82 | cur_loss = loss * weight 83 | if self.total_loss is not None: 84 | self.total_loss += cur_loss 85 | else: 86 | self.total_loss = cur_loss 87 | 88 | self.all_losses[name] = cur_loss.data.cpu().item() 89 | 90 | def items(self): 91 | return self.all_losses.items() 92 | 93 | -------------------------------------------------------------------------------- /sg2im/vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import tempfile, os 18 | import torch 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | from matplotlib.patches import Rectangle 22 | from imageio import imread 23 | 24 | 25 | """ 26 | Utilities for making visualizations. 
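A few of the helpers defined above in sg2im/utils.py in action; the flag names are illustrative, in the style of scripts/train.py:

import argparse
import torch
from sg2im.utils import int_tuple, bool_flag, LossManager

parser = argparse.ArgumentParser()
parser.add_argument('--image_size', default='64,64', type=int_tuple)
parser.add_argument('--use_gt_boxes', default='0', type=bool_flag)
args = parser.parse_args(['--image_size', '128,128'])  # args.image_size == (128, 128)

losses = LossManager()
losses.add_loss(torch.tensor(0.5), 'L1_pixel_loss')
losses.add_loss(torch.tensor(2.0), 'ac_loss', weight=0.1)
total = losses.total_loss               # weighted sum as a single tensor (0.5 + 0.2)
for name, value in losses.items():      # per-loss floats, handy for logging
  print(name, value)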
27 | """ 28 | 29 | 30 | def draw_layout(vocab, objs, boxes, masks=None, size=256, 31 | show_boxes=False, bgcolor=(0, 0, 0)): 32 | if bgcolor == 'white': 33 | bgcolor = (255, 255, 255) 34 | 35 | cmap = plt.get_cmap('rainbow') 36 | colors = cmap(np.linspace(0, 1, len(objs))) 37 | 38 | with torch.no_grad(): 39 | objs = objs.cpu().clone() 40 | boxes = boxes.cpu().clone() 41 | boxes *= size 42 | 43 | if masks is not None: 44 | masks = masks.cpu().clone() 45 | 46 | bgcolor = np.asarray(bgcolor) 47 | bg = np.ones((size, size, 1)) * bgcolor 48 | plt.imshow(bg.astype(np.uint8)) 49 | 50 | plt.gca().set_xlim(0, size) 51 | plt.gca().set_ylim(size, 0) 52 | plt.gca().set_aspect(1.0, adjustable='box') 53 | 54 | for i, obj in enumerate(objs): 55 | name = vocab['object_idx_to_name'][obj] 56 | if name == '__image__': 57 | continue 58 | box = boxes[i] 59 | 60 | if masks is None: 61 | continue 62 | mask = masks[i].numpy() 63 | mask /= mask.max() 64 | 65 | r, g, b, a = colors[i] 66 | colored_mask = mask[:, :, None] * np.asarray(colors[i]) 67 | 68 | x0, y0, x1, y1 = box 69 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0), 70 | interpolation='bicubic', alpha=1.0) 71 | 72 | if show_boxes: 73 | for i, obj in enumerate(objs): 74 | name = vocab['object_idx_to_name'][obj] 75 | if name == '__image__': 76 | continue 77 | box = boxes[i] 78 | 79 | draw_box(box, colors[i], name) 80 | 81 | 82 | def draw_box(box, color, text=None): 83 | """ 84 | Draw a bounding box using pyplot, optionally with a text box label. 85 | 86 | Inputs: 87 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H] 88 | coordinate system. 89 | - color: pyplot color to use for the box. 90 | - text: (Optional) String; if provided then draw a label for this box. 91 | """ 92 | TEXT_BOX_HEIGHT = 10 93 | if torch.is_tensor(box) and box.dim() == 2: 94 | box = box.view(-1) 95 | assert box.size(0) == 4 96 | x0, y0, x1, y1 = box 97 | assert y1 > y0, box 98 | assert x1 > x0, box 99 | w, h = x1 - x0, y1 - y0 100 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color) 101 | plt.gca().add_patch(rect) 102 | if text is not None: 103 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5) 104 | plt.gca().add_patch(text_rect) 105 | tx = 0.5 * (x0 + x1) 106 | ty = y0 + TEXT_BOX_HEIGHT / 2.0 107 | plt.text(tx, ty, text, va='center', ha='center') 108 | 109 | 110 | def draw_scene_graph(objs, triples, vocab=None, **kwargs): 111 | """ 112 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume 113 | that objs and triples are python lists containing strings for object and 114 | relationship names. 115 | 116 | Using this requires that GraphViz is installed. 
On Ubuntu 16.04 this is easy: 117 | sudo apt-get install graphviz 118 | """ 119 | output_filename = kwargs.pop('output_filename', 'graph.png') 120 | orientation = kwargs.pop('orientation', 'V') 121 | edge_width = kwargs.pop('edge_width', 6) 122 | arrow_size = kwargs.pop('arrow_size', 1.5) 123 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2) 124 | ignore_dummies = kwargs.pop('ignore_dummies', True) 125 | 126 | if orientation not in ['V', 'H']: 127 | raise ValueError('Invalid orientation "%s"' % orientation) 128 | rankdir = {'H': 'LR', 'V': 'TD'}[orientation] 129 | 130 | if vocab is not None: 131 | # Decode object and relationship names 132 | assert torch.is_tensor(objs) 133 | assert torch.is_tensor(triples) 134 | objs_list, triples_list = [], [] 135 | for i in range(objs.size(0)): 136 | objs_list.append(vocab['object_idx_to_name'][objs[i].item()]) 137 | for i in range(triples.size(0)): 138 | s = triples[i, 0].item() 139 | p = vocab['pred_idx_to_name'][triples[i, 1].item()] 140 | o = triples[i, 2].item() 141 | triples_list.append([s, p, o]) 142 | objs, triples = objs_list, triples_list 143 | 144 | # General setup, and style for object nodes 145 | lines = [ 146 | 'digraph{', 147 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="transparent"]', 148 | 'rankdir=%s' % rankdir, 149 | 'nodesep="0.5"', 150 | 'ranksep="0.5"', 151 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]', 152 | 'node [fillcolor="lightpink1"]', 153 | ] 154 | # Output nodes for objects 155 | for i, obj in enumerate(objs): 156 | if ignore_dummies and obj == '__image__': 157 | continue 158 | lines.append('%d [label="%s"]' % (i, obj)) 159 | 160 | # Output relationships 161 | next_node_id = len(objs) 162 | lines.append('node [fillcolor="lightblue1"]') 163 | for s, p, o in triples: 164 | if ignore_dummies and p == '__in_image__': 165 | continue 166 | lines += [ 167 | '%d [label="%s"]' % (next_node_id, p), 168 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 169 | s, next_node_id, edge_width, arrow_size, binary_edge_weight), 170 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 171 | next_node_id, o, edge_width, arrow_size, binary_edge_weight) 172 | ] 173 | next_node_id += 1 174 | lines.append('}') 175 | 176 | # Now it gets slightly hacky. Write the graphviz spec to a temporary 177 | # text file 178 | ff, dot_filename = tempfile.mkstemp() 179 | with open(dot_filename, 'w') as f: 180 | for line in lines: 181 | f.write('%s\n' % line) 182 | os.close(ff) 183 | 184 | # Shell out to invoke graphviz; this will save the resulting image to disk, 185 | # so we read it, delete it, then return it.
186 | output_format = os.path.splitext(output_filename)[1][1:] 187 | os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename)) 188 | os.remove(dot_filename) 189 | img = imread(output_filename) 190 | os.remove(output_filename) 191 | 192 | return img 193 | 194 | 195 | if __name__ == '__main__': 196 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard'] 197 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above'] 198 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)} 199 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)} 200 | vocab = { 201 | 'object_idx_to_name': o_idx_to_name, 202 | 'object_name_to_idx': o_name_to_idx, 203 | 'pred_idx_to_name': p_idx_to_name, 204 | 'pred_name_to_idx': p_name_to_idx, 205 | } 206 | 207 | objs = [ 208 | 'cat', 209 | 'cat', 210 | 'skateboard', 211 | 'hat', 212 | ] 213 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs]) 214 | triples = [ 215 | [0, 'next to', 1], 216 | [0, 'riding', 2], 217 | [1, 'wearing', 3], 218 | [3, 'above', 2], 219 | ] 220 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples] 221 | triples = torch.LongTensor(triples) 222 | 223 | draw_scene_graph(objs, triples, vocab, orientation='V') 224 | 225 | --------------------------------------------------------------------------------
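When no vocab is available, draw_scene_graph also accepts plain Python lists of object and predicate names, as its docstring notes. A minimal sketch (the names are arbitrary, and the GraphViz dot binary must be on the PATH):

from sg2im.vis import draw_scene_graph

objs = ['person', 'horse', 'field']
triples = [
    [0, 'riding', 1],
    [1, 'standing in', 2],
]
img = draw_scene_graph(objs, triples, vocab=None, orientation='H')
# img is the rendered graph as an image array, ready for plt.imshow(img).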