├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── TRAINING.md
├── images
│   ├── sheep
│   │   ├── img000000.png
│   │   ├── img000001.png
│   │   ├── img000002.png
│   │   ├── img000003.png
│   │   ├── img000004.png
│   │   ├── img000005.png
│   │   ├── img000006.png
│   │   ├── sg000000.png
│   │   ├── sg000001.png
│   │   ├── sg000002.png
│   │   ├── sg000003.png
│   │   ├── sg000004.png
│   │   ├── sg000005.png
│   │   └── sg000006.png
│   └── system.png
├── requirements.txt
├── scene_graphs
│   ├── figure_5_coco.json
│   ├── figure_5_vg.json
│   ├── figure_6_sheep.json
│   └── figure_6_street.json
├── scripts
│   ├── download_ablated_models.sh
│   ├── download_coco.sh
│   ├── download_full_models.sh
│   ├── download_models.sh
│   ├── download_vg.sh
│   ├── preprocess_vg.py
│   ├── print_args.py
│   ├── run_model.py
│   ├── sample_images.py
│   ├── strip_checkpoint.py
│   ├── strip_old_args.py
│   └── train.py
└── sg2im
    ├── __init__.py
    ├── bilinear.py
    ├── box_utils.py
    ├── crn.py
    ├── data
    │   ├── __init__.py
    │   ├── coco.py
    │   ├── utils.py
    │   ├── vg.py
    │   └── vg_splits.json
    ├── discriminators.py
    ├── graph.py
    ├── layers.py
    ├── layout.py
    ├── losses.py
    ├── metrics.py
    ├── model.py
    ├── utils.py
    └── vis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *.DS_Store
4 | outputs/
5 | sg2im-models/
6 | env/
7 | datasets/
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sg2im
2 |
3 | This is the code for the paper
4 |
5 | **Image Generation from Scene Graphs**
6 |
7 | Justin Johnson,
8 | Agrim Gupta,
9 | Li Fei-Fei
10 |
11 | Presented at [CVPR 2018](http://cvpr2018.thecvf.com/)
12 |
13 | Please note that this is not an officially supported Google product.
14 |
15 | A **scene graph** is a structured representation of a visual scene where nodes represent *objects* in the scene and edges represent *relationships* between objects. In this paper we present an end-to-end neural network model that inputs a scene graph and outputs an image.
16 |
17 | Below we show some example scene graphs along with images generated from those scene graphs using our model. By modifying the input scene graph we can exercise fine-grained control over the objects in the generated image.
18 |
19 | <!-- Example scene graphs (images/sheep/sg000000.png to sg000006.png) and the corresponding generated images (images/sheep/img000000.png to img000006.png) appear here. -->
20 |
37 | If you find this code useful in your research then please cite
38 | ```
39 | @inproceedings{johnson2018image,
40 | title={Image Generation from Scene Graphs},
41 | author={Johnson, Justin and Gupta, Agrim and Fei-Fei, Li},
42 | booktitle={CVPR},
43 | year={2018}
44 | }
45 | ```
46 |
47 | ## Model
48 | The input scene graph is processed with a *graph convolution network* which passes information along edges to compute embedding vectors for all objects. These vectors are used to predict bounding boxes and segmentation masks for all objects, which are combined to form a coarse *scene layout*. The layout is passed to a *cascaded refinement network* (Chen and Koltun, ICCV 2017) which generates an output image at increasing spatial scales. The model is trained adversarially against a pair of *discriminator networks* which ensure that output images look realistic.
49 |
50 |
51 | <!-- System overview figure: images/system.png -->
52 |
53 |
54 | ## Setup
55 | All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4.
56 |
57 | You can set up a virtual environment to run the code like this:
58 |
59 | ```bash
60 | python3 -m venv env # Create a virtual environment
61 | source env/bin/activate # Activate virtual environment
62 | pip install -r requirements.txt # Install dependencies
63 | echo $PWD > env/lib/python3.5/site-packages/sg2im.pth # Add current directory to python path
64 | # Work for a while ...
65 | deactivate # Exit virtual environment
66 | ```
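
As a quick, optional sanity check (assuming the virtual environment is activated and the `sg2im.pth` step above was run), you can verify that PyTorch and the `sg2im` package are importable:

```bash
# Optional sanity check: both imports should succeed inside the activated environment.
python -c "import torch, sg2im; print(torch.__version__)"
```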
67 |
68 | ## Pretrained Models
69 | You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models, and will require about 355 MB of disk space:
70 |
71 | - `sg2im-models/coco64.pt`: Trained to generate 64 x 64 images on the COCO-Stuff dataset. This model was used to generate the COCO images in Figure 5 from the paper.
72 | - `sg2im-models/vg64.pt`: Trained to generate 64 x 64 images on the Visual Genome dataset. This model was used to generate the Visual Genome images in Figure 5 from the paper.
73 | - `sg2im-models/vg128.pt`: Trained to generate 128 x 128 images on the Visual Genome dataset. This model was used to generate the images in Figure 6 from the paper.
74 |
75 | Table 1 in the paper presents an ablation study where we disable various components of the full model. You can download the additional models used in this ablation study by running the script `bash scripts/download_ablated_models.sh`. This will download 12 additional models, requiring an additional 1.25 GB of disk space.
76 |
77 | ## Running Models
78 | You can use the script `scripts/run_model.py` to easily run any of the pretrained models on new scene graphs using a simple human-readable JSON format. For example you can replicate the sheep images above like this:
79 |
80 | ```bash
81 | python scripts/run_model.py \
82 | --checkpoint sg2im-models/vg128.pt \
83 | --scene_graphs_json scene_graphs/figure_6_sheep.json \
84 | --output_dir outputs
85 | ```
86 |
87 | The generated images will be saved to the directory specified by the `--output_dir` flag. You can control whether the model runs on CPU or GPU by passing the flag `--device cpu` or `--device gpu`.
88 |
89 | We provide JSON files and pretrained models allowing you to recreate all images from Figures 5 and 6 from the paper.
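
Each entry in these JSON files is a dictionary with an `objects` list of category names and a `relationships` list of `[subject_index, predicate, object_index]` triples (see for example `scene_graphs/figure_6_sheep.json`). As a sketch, you can write your own scene graph in the same format and run a pretrained model on it; the file name below is arbitrary, and the object and predicate names must exist in the model's vocabulary:

```bash
# Sketch: a custom scene graph in the same JSON format as scene_graphs/figure_6_sheep.json.
cat > my_scene_graph.json <<'EOF'
[
  {
    "objects": ["sky", "grass", "sheep", "tree"],
    "relationships": [
      [0, "above", 1],
      [2, "standing on", 1],
      [3, "behind", 2]
    ]
  }
]
EOF

python scripts/run_model.py \
  --checkpoint sg2im-models/vg128.pt \
  --scene_graphs_json my_scene_graph.json \
  --output_dir outputs
```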
90 |
91 | #### (Optional): GraphViz
92 | This script can also draw images for the scene graphs themselves using [GraphViz](http://www.graphviz.org/); to enable this option just add the flag `--draw_scene_graphs 1` and the scene graph images will also be saved in the output directory. For this option to work you must install GraphViz; on Ubuntu 16.04 you can simply run `sudo apt-get install graphviz`.
93 |
94 | ## Training new models
95 | Instructions for training new models can be [found here](TRAINING.md).
96 |
--------------------------------------------------------------------------------
/TRAINING.md:
--------------------------------------------------------------------------------
1 | You can train your own model by following these instructions:
2 |
3 | ## Step 1: Install COCO API
4 | To train new models you will need to install the [COCO Python API](https://github.com/cocodataset/cocoapi). Unfortunately installing this package via pip often leads to build errors, but you can install it from source like this:
5 |
6 | ```bash
7 | cd ~
8 | git clone https://github.com/cocodataset/cocoapi.git
9 | cd cocoapi/PythonAPI/
10 | python setup.py install
11 | ```
12 |
13 | ## Step 2: Preparing the data
14 |
15 | ### Visual Genome
16 | Run the following script to download and unpack the relevant parts of the Visual Genome dataset:
17 |
18 | ```bash
19 | bash scripts/download_vg.sh
20 | ```
21 |
22 | This will create the directory `datasets/vg` and will download about 15 GB of data to this directory; after unpacking it will take about 30 GB of disk space.
23 |
24 | After downloading the Visual Genome dataset, we need to preprocess it. This will split the data into train / val / test splits, consolidate all scene graphs into HDF5 files, and apply several heuristics to clean the data. In particular we ignore images that are too small, and only consider object and attribute categories that appear at least a minimum number of times in the training set; we also ignore objects that are too small, and set minimum and maximum values on the number of objects and relationships that appear per image.
25 |
26 | ```bash
27 | python scripts/preprocess_vg.py
28 | ```
29 |
30 | This will create files `train.h5`, `val.h5`, `test.h5`, and `vocab.json` in the directory `datasets/vg`.
31 |
32 | ### COCO
33 | Run the following script to download and unpack the relevant parts of the COCO dataset:
34 |
35 | ```bash
36 | bash scripts/download_coco.sh
37 | ```
38 |
39 | This will create the directory `datasets/coco` and will download about 21 GB of data to this directory; after unpacking it will take about 60 GB of disk space.
40 |
41 | ## Step 3: Train a model
42 |
43 | Now you can train a new model by running the script:
44 |
45 | ```bash
46 | python scripts/train.py
47 | ```
48 |
49 | By default this will train a model on COCO, periodically saving checkpoint files `checkpoint_with_model.pt` and `checkpoint_no_model.pt` to the current working directory. The training script has a number of command-line flags that you can use to configure the model architecture, hyperparameters, and input / output settings:
50 |
51 | ### Optimization
52 |
53 | - `--batch_size`: How many pairs of (scene graph, image) to use in each minibatch during training. Default is 32.
54 | - `--num_iterations`: Number of training iterations. Default is 1,000,000.
55 | - `--learning_rate`: Learning rate to use in Adam optimizer for the generator and discriminators; default is 1e-4.
56 | - `--eval_mode_after`: The generator is trained in "train" mode for this many iterations, after which training continues in "eval" mode. We found that if the model is trained exclusively in "train" mode then generated images can have severe artifacts if test batches have a different size or composition than those used during training.
57 |
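For example, a run that overrides these optimization settings might look like the following sketch (the values are purely illustrative, not recommended settings):

```bash
# Sketch: override the optimization flags described above.
python scripts/train.py \
  --batch_size 16 \
  --learning_rate 1e-4 \
  --num_iterations 500000 \
  --eval_mode_after 100000
```
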
58 | ### Dataset options
59 |
60 | - `--dataset`: The dataset to use for training; must be either `coco` or `vg`. Default is `coco`.
61 | - `--image_size`: The size of images to generate, as a tuple of integers. Default is `64,64`. This is also the resolution at which scene layouts are predicted.
62 | - `--num_train_samples`: The number of images from the training set to use. Default is None, which means the entire training set will be used.
63 | - `--num_val_samples`: The number of images from the validation set to use. Default is 1024. This is particularly important for the COCO dataset, since we partition the COCO validation images into our own validation and test sets; this flag thus controls the number of COCO validation images which we will use as our own validation set, and the remaining images will serve as our test set.
64 | - `--shuffle_val`: Whether to shuffle the samples from the validation set. Default is True.
65 | - `--loader_num_workers`: The number of background threads to use for data loading. Default is 4.
66 | - `--include_relationships`: Whether to include relationships in the scene graphs; default is 1 which means use relationships, 0 means omit them. This is used to train the "no relationships" ablated model.
67 |
68 | **Visual Genome options**:
69 | These flags only take effect if `--dataset` is set to `vg`:
70 |
71 | - `--vg_image_dir`: Directory from which to load Visual Genome images. Default is `datasets/vg/images`.
72 | - `--train_h5`: Path to HDF5 file containing data for the training split, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/train.h5`.
73 | - `--val_h5`: Path to HDF5 file containing data for the validation split, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/val.h5`.
74 | - `--vocab_json`: Path to JSON file containing Visual Genome vocabulary, created by `scripts/preprocess_vg.py`. Default is `datasets/vg/vocab.json`.
75 | - `--max_objects_per_image`: The maximum number of objects to use per scene graph during training; default is 10. Note that `scripts/preprocess_vg.py` also defines a maximum number of objects per image, but the two settings are different. The max value in the preprocessing script causes images with more than the max number of objects to be skipped entirely; in contrast during training if we encounter images with more than the max number of objects then they are randomly subsampled to the max value as a form of data augmentation.
76 | - `--vg_use_orphaned_objects`: Whether to include objects which do not participate in any relationships; 1 means use them (default), 0 means skip them.
77 |
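Putting `--dataset vg` together with the Visual Genome flags above, a training run on the preprocessed data from Step 2 might be sketched as follows (the paths are the documented defaults):

```bash
# Sketch: train on Visual Genome using the HDF5 files and vocab produced by preprocess_vg.py.
python scripts/train.py \
  --dataset vg \
  --vg_image_dir datasets/vg/images \
  --train_h5 datasets/vg/train.h5 \
  --val_h5 datasets/vg/val.h5 \
  --vocab_json datasets/vg/vocab.json \
  --max_objects_per_image 10
```
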
78 | **COCO options**:
79 | These flags only take effect if `--dataset` is set to `coco`:
80 |
81 | - `--coco_train_image_dir`: Directory from which to load COCO training images; default is `datasets/coco/images/train2017`.
82 | - `--coco_val_image_dir`: Directory from which to load COCO validation images; default is `datasets/coco/images/val2017`.
83 | - `--coco_train_instances_json`: Path to JSON file containing object annotations for the COCO training images; default is `datasets/coco/annotations/instances_train2017.json`.
84 | - `--coco_train_stuff_json`: Path to JSON file containing stuff annotations for the COCO training images; default is `datasets/coco/annotations/stuff_train2017.json`.
85 | - `--coco_val_instances_json`: Path to JSON file containing object annotations for COCO validation images; default is `datasets/coco/annotations/instances_val2017.json`.
86 | - `--coco_val_stuff_json`: Path to JSON file containing stuff annotations for COCO validation images; default is `datasets/coco/annotations/stuff_val2017.json`.
87 | - `--instance_whitelist`: The default behavior is to train the model to generate all object categories; however by passing a comma-separated list to this flag we can train the model to generate only a subset of object categories.
88 | - `--stuff_whitelist`: The default behavior is to train the model to generate all stuff categories (except other, see below); however by passing a comma-separated list to this flag we can train the model to generate only a subset of stuff categories.
89 | - `--coco_include_other`: The COCO-Stuff annotations include an "other" category for objects which do not fall into any of the other object categories. Due to the variety in this category I found that the model was unable to learn it, so setting this flag to 0 (default) causes COCO-Stuff annotations with the "other" category to be ignored. Setting it to 1 will include these "other" annotations.
90 | - `--coco_stuff_only`: The 2017 COCO training split contains 115K images. Object annotations are provided for all of these images, but Stuff annotations are only provided for 40K of these images. Setting this flag to 1 (default) will only train using images for which Stuff annotations are available; setting this flag to 0 will use all 115K images for training, including Stuff annotations only for the images for which they are available.
91 |
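As a sketch, restricting a COCO run to a handful of categories with the whitelist flags above might look like this (the category names are examples only):

```bash
# Sketch: train only on a subset of COCO object and stuff categories.
python scripts/train.py \
  --dataset coco \
  --instance_whitelist person,car,elephant \
  --stuff_whitelist sky-other,grass,tree \
  --coco_stuff_only 1
```
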
92 | ### Generator options
93 | These flags control the architecture and loss hyperparameters for the generator, which inputs scene graphs and outputs images.
94 |
95 | - `--mask_size`: Integer giving the resolution at which instance segmentation masks are predicted for objects. Default is 16. Setting this to 0 causes the model to omit the mask prediction subnetwork, instead using the entire object bounding box as the mask. Since mask prediction is differentiable the model can predict masks even when the training dataset does not provide masks; in particular Visual Genome does not provide masks, but all VG models were trained with `--mask_size 16`.
96 | - `--embedding_dim`: Integer giving the dimension for the embedding layer for objects and relationships prior to the first graph convolution layer. Default is 128.
97 | - `--gconv_dim`: Integer giving the dimension for the vectors in graph convolution layers. Default is 128.
98 | - `--gconv_hidden_dim`: Integer giving the dimension for the hidden dimension inside each graph convolution layer; this is the dimension of the candidate vectors V^s_i and V^s_o from Equations 1 and 2 in the paper. Default is 512.
99 | - `--gconv_num_layers`: The number of graph convolution layers to use. Default is 5.
100 | - `--mlp_normalization`: The type of normalization (if any) to use for linear layers in MLPs inside graph convolution layers and the box prediction subnetwork. Choices are 'none' (default), which means to use no normalization, or 'batch' which means to use batch normalization.
101 | - `--refinement_network_dims`: Comma-separated list of integers specifying the architecture of the cascaded refinement network used to generate images; default is `1024,512,256,128,64` which means to use five refinement modules, the first with 1024 feature maps, the second with 512 feature maps, etc. The spatial resolution of the feature maps doubles with each successive refinement module.
102 | - `--normalization`: The type of normalization layer to use in the cascaded refinement network. Options are 'batch' (default) for batch normalization, 'instance' for instance normalization, or 'none' for no normalization.
103 | - `--activation`: Activation function to use in the cascaded refinement network; default is `leakyrelu-0.2` which is a Leaky ReLU with a negative slope of 0.2. Can also be `relu`.
104 | - `--layout_noise_dim`: The number of channels of random noise that is concatenated with the scene layout before feeding to the cascaded refinement network. Default is 32.
105 |
106 | **Generator Losses**: These flags control the non-adversarial losses used to train the generator:
107 |
108 | - `--l1_pixel_loss_weight`: Float giving the weight to give L1 difference between generated and ground-truth image. Default is 1.0.
109 | - `--mask_loss_weight`: Float giving the weight to give mask prediction in the overall model loss. Setting this to 0 (default) means that masks are weakly supervised, which is required when training on Visual Genome. For COCO I found that setting `--mask_loss_weight` to 0.1 works well.
110 | - `--bbox_pred_loss_weight`: Float giving the weight to assign to regressing the bounding boxes for objects. Default is 10.
111 |
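For example, a generator configuration for COCO using the flags above might be sketched as follows (`--mask_loss_weight 0.1` follows the suggestion above; the remaining values are the documented defaults):

```bash
# Sketch: generator architecture and loss settings for a COCO run.
python scripts/train.py \
  --mask_size 16 \
  --embedding_dim 128 \
  --gconv_dim 128 \
  --gconv_num_layers 5 \
  --refinement_network_dims 1024,512,256,128,64 \
  --l1_pixel_loss_weight 1.0 \
  --mask_loss_weight 0.1
```
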
112 | ### Discriminator options
113 | The generator is trained adversarially against two discriminators: a patch-based image discriminator that ensures patches of the generated image look realistic, and an object discriminator that ensures generated objects look realistic. These flags apply to both discriminators:
114 |
115 | - `--discriminator_loss_weight`: The weight to assign to discriminator losses when training the generator. Default is 0.01.
116 | - `--gan_loss_type`: The GAN loss function to use. Default is 'gan' which is the original GAN loss function; can also be 'lsgan' for least-squares GAN or 'wgan' for Wasserstein GAN loss.
117 | - `--d_clip`: Value at which to clip discriminator weights, for WGAN. Default is no clipping.
118 | - `--d_normalization`: The type of normalization to use in discriminators. Default is 'batch' for batch normalization, but like CRN normalization this can also be 'none' or 'instance'.
119 | - `--d_padding`: The type of padding to use for convolutions in the discriminators, either 'valid' (default) or 'same'.
120 | - `--d_activation`: Activation function to use in the discriminators. Like CRN the default is `leakyrelu-0.2`.
121 |
122 | **Object Discriminator**: These flags only apply to the object discriminator:
123 |
124 | - `--d_obj_arch`: String giving the architecture of the object discriminator; the semantics for architecture strings [is described here](https://github.com/jcjohnson/sg2im-release/blob/master/sg2im/layers.py#L116).
125 | - `--crop_size`: The object discriminator crops out each object in images; this gives the spatial size to which these crops are (differentiably) resized. Default is 32.
126 | - `--d_obj_weight`: Weight for real / fake classification in the object discriminator. During training the weight given to fooling the object discriminator is `--discriminator_loss_weight * --d_obj_weight`. Default is 1.0
127 | - `--ac_loss_weight`: Weight for the auxiliary classifier in the object discriminator that attempts to predict the object category of objects; the weight assigned to this loss is `--discriminator_loss_weight * --ac_loss_weight`. Default is 0.1.
128 |
129 | **Image Discriminator**: These flags only apply to the image discriminator:
130 |
131 | - `--d_img_arch`: String giving the architecture of the image discriminator; the semantics for architecture strings [is described here](https://github.com/jcjohnson/sg2im-release/blob/master/sg2im/layers.py#L116).
132 | - `--d_img_weight`: The weight assigned to fooling the image discriminator is `--discriminator_loss_weight * --d_img_weight`. Default is 1.0.
133 |
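A sketch combining the discriminator flags above (the values shown simply restate the documented defaults):

```bash
# Sketch: discriminator settings using the flags described above.
python scripts/train.py \
  --discriminator_loss_weight 0.01 \
  --gan_loss_type gan \
  --d_normalization batch \
  --d_obj_weight 1.0 \
  --ac_loss_weight 0.1 \
  --d_img_weight 1.0
```
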
134 | ### Output Options
135 | These flags control outputs from the training script:
136 |
137 | - `--print_every`: Training losses are printed and recorded every `--print_every` iterations. Default is 10.
138 | - `--timing`: If this flag is set to 1 then measure and print the time that each model component takes to execute.
139 | - `--checkpoint_every`: Checkpoints are saved to disk every `--checkpoint_every` iterations. Default is 10000. Each checkpoint contains a history of training losses, a history of images generated from training-set and val-set scene graphs, the current state of the generator, discriminators, and optimizers, as well as all other state information needed to resume training in case it is interrupted. We actually save two checkpoints: one with all information, and one without model parameters; the latter is much smaller, and is convenient for exploring the results of a large hyperparameter sweep without actually loading model parameters.
140 | - `--output_dir`: Directory to which checkpoints will be saved. Default is current directory.
141 | - `--checkpoint_name`: Base filename for saved checkpoints; default is 'checkpoint', so the filename for the checkpoint with model parameters will be 'checkpoint_with_model.pt' and the filename for the checkpoint without model parameters will be 'checkpoint_no_model.pt'.
142 | - `--restore_from_checkpoint`: Default behavior is to start training from scratch, and overwrite the output checkpoint path if it already exists. If this flag is set to 1 then instead resume training from the output checkpoint file if it already exists. This is useful when running in an environment where jobs can be preempted.
143 | - `--checkpoint_start_from`: Default behavior is to start training from scratch; if this flag is given then instead resume training from the specified checkpoint. This takes precedence over `--restore_from_checkpoint` if both are given.
144 |
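For example, a preemptible job that checkpoints into its own directory might use the output flags above like this (the directory and base name are illustrative):

```bash
# Sketch: save checkpoints to outputs/experiment1 and resume from them if the job restarts.
python scripts/train.py \
  --output_dir outputs/experiment1 \
  --checkpoint_name experiment1 \
  --checkpoint_every 10000 \
  --print_every 10 \
  --restore_from_checkpoint 1
```
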
145 | ## (Optional) Step 4: Strip Checkpoint
146 | Checkpoints saved by train.py contain not only model parameters but also optimizer states, losses, a history of generated images, and other statistics. This information is very useful for developing and debugging models, but it makes the saved checkpoints very large. You can use the script `scripts/strip_checkpoint.py` to remove all this extra information from a saved checkpoint and save only the trained models:
147 |
148 | ```bash
149 | python scripts/strip_checkpoint.py \
150 | --input_checkpoint checkpoint_with_model.pt \
151 | --output_checkpoint checkpoint_stripped.pt
152 | ```
153 |
154 | ## Step 5: Run the model
155 | You can use the script `scripts/run_model.py` to run the model on arbitrary scene graphs specified in a simple JSON format. For example to generate images for the scene graphs used in Figure 6 of the paper you can run:
156 |
157 | ```bash
158 | python scripts/run_model.py \
159 | --checkpoint checkpoint_with_model.pt \
160 | --scene_graphs_json scene_graphs/figure_6_sheep.json
161 | ```
162 |
--------------------------------------------------------------------------------
/images/sheep/img000000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000000.png
--------------------------------------------------------------------------------
/images/sheep/img000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000001.png
--------------------------------------------------------------------------------
/images/sheep/img000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000002.png
--------------------------------------------------------------------------------
/images/sheep/img000003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000003.png
--------------------------------------------------------------------------------
/images/sheep/img000004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000004.png
--------------------------------------------------------------------------------
/images/sheep/img000005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000005.png
--------------------------------------------------------------------------------
/images/sheep/img000006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/img000006.png
--------------------------------------------------------------------------------
/images/sheep/sg000000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000000.png
--------------------------------------------------------------------------------
/images/sheep/sg000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000001.png
--------------------------------------------------------------------------------
/images/sheep/sg000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000002.png
--------------------------------------------------------------------------------
/images/sheep/sg000003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000003.png
--------------------------------------------------------------------------------
/images/sheep/sg000004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000004.png
--------------------------------------------------------------------------------
/images/sheep/sg000005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000005.png
--------------------------------------------------------------------------------
/images/sheep/sg000006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/sheep/sg000006.png
--------------------------------------------------------------------------------
/images/system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sg2im/2c1bf4a150f8a70c0977200a6719519213367b5c/images/system.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cloudpickle==0.5.3
2 | cycler==0.10.0
3 | Cython==0.28.3
4 | dask==0.17.5
5 | decorator==4.3.0
6 | h5py==2.8.0
7 | imageio==2.3.0
8 | kiwisolver==1.0.1
9 | matplotlib==2.2.2
10 | networkx==2.1
11 | numpy==1.14.4
12 | Pillow==5.1.0
13 | pyparsing==2.2.0
14 | python-dateutil==2.7.3
15 | pytz==2018.4
16 | PyWavelets==0.5.2
17 | scikit-image==0.14.0
18 | scipy==1.1.0
19 | six==1.11.0
20 | toolz==0.9.0
21 | torch==0.4.0
22 | torchvision==0.2.1
23 |
--------------------------------------------------------------------------------
/scene_graphs/figure_5_coco.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "objects": [
4 | "car", "car", "cage", "person", "grass", "tree", "playingfield", "person"
5 | ],
6 | "relationships": [
7 | [0, "left of", 1],
8 | [0, "above", 6],
9 | [1, "above", 4],
10 | [3, "left of", 7],
11 | [5, "above", 7],
12 | [6, "below", 2],
13 | [6, "below", 3],
14 | [7, "left of", 4]
15 | ]
16 | },
17 | {
18 | "objects": [ "broccoli", "broccoli", "vegetable", "carrot"],
19 | "relationships": [
20 | [0, "left of", 1],
21 | [0, "left of", 1],
22 | [1, "inside", 2],
23 | [3, "below", 1]
24 | ]
25 | },
26 | {
27 | "objects": ["person", "person", "person", "fence"],
28 | "relationships": [
29 | [0, "left of", 1],
30 | [0, "left of", 1],
31 | [0, "inside", 3],
32 | [2, "inside", 3]
33 | ]
34 | },
35 | {
36 | "objects": ["sky-other", "person", "skateboard", "tree"],
37 | "relationships": [
38 | [0, "surrounding", 2],
39 | [1, "inside", 0],
40 | [1, "above", 3],
41 | [3, "below", 0]
42 | ]
43 | },
44 | {
45 | "objects": ["person", "tie", "wall-panel", "clothes"],
46 | "relationships": [
47 | [0, "surrounding", 1],
48 | [1, "inside", 0],
49 | [1, "above", 3],
50 | [2, "surrounding", 0]
51 | ]
52 | },
53 | {
54 | "objects": ["clouds", "tree", "person", "horse", "grass"],
55 | "relationships": [
56 | [0, "above", 4],
57 | [0, "above", 4],
58 | [1, "right of", 2],
59 | [2, "left of", 3],
60 | [3, "above", 4]
61 | ]
62 | },
63 | {
64 | "objects": ["elephant", "elephant", "tree", "grass"],
65 | "relationships": [
66 | [0, "inside", 2],
67 | [0, "above", 3],
68 | [2, "surrounding", 1]
69 | ]
70 | },
71 | {
72 | "objects": ["clouds", "building-other", "boat", "tree", "river"],
73 | "relationships": [
74 | [0, "above", 1],
75 | [0, "above", 1],
76 | [1, "above", 4],
77 | [3, "left of", 4],
78 | [4, "below", 0]
79 | ]
80 | }
81 | ]
82 |
--------------------------------------------------------------------------------
/scene_graphs/figure_5_vg.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "objects": [
4 | "sky", "cloud", "mountain", "sheep", "grass", "tree",
5 | "stone", "rock", "sheep"
6 | ],
7 | "relationships": [
8 | [0, "has", 1],
9 | [2, "behind", 5],
10 | [3, "eating", 4],
11 | [3, "eating", 4],
12 | [5, "in front of", 2]
13 | ]
14 | },
15 | {
16 | "objects": [
17 | "sky", "water", "person", "wave", "board",
18 | "cloud", "background", "edge"
19 | ],
20 | "relationships": [
21 | [0, "above", 1],
22 | [2, "by", 1],
23 | [2, "riding", 3],
24 | [2, "riding", 4]
25 | ]
26 | },
27 | {
28 | "objects": ["boy", "grass", "sky", "kite", "field", "mountain", "brick"],
29 | "relationships": [
30 | [0, "on top of", 1],
31 | [0, "standing on", 1],
32 | [0, "looking at", 2],
33 | [0, "looking at", 3],
34 | [4, "under", 5]
35 | ]
36 | },
37 | {
38 | "objects": [
39 | "building", "bus", "tree", "bus", "windshield", "windshield",
40 | "sky", "line", "sign"
41 | ],
42 | "relationships": [
43 | [0, "next to", 0],
44 | [0, "next to", 0],
45 | [0, "next to", 0],
46 | [0, "next to", 0],
47 | [1, "has", 4],
48 | [1, "behind", 3],
49 | [2, "behind", 3],
50 | [3, "has", 5]
51 | ]
52 | },
53 | {
54 | "objects": [
55 | "car", "tree", "house", "street", "window", "roof", "house", "bush", "car"
56 | ],
57 | "relationships": [
58 | [0, "parked on", 3],
59 | [1, "along", 3],
60 | [2, "has", 5],
61 | [4, "in front of", 6]
62 | ]
63 | },
64 | {
65 | "objects": [
66 | "sky", "man", "leg", "horse", "tail", "leg",
67 | "short", "hill", "hill"
68 | ],
69 | "relationships": [
70 | [0, "above", 1],
71 | [1, "has", 2],
72 | [1, "riding", 3],
73 | [3, "has", 4],
74 | [3, "has", 4],
75 | [3, "has", 5]
76 | ]
77 | },
78 | {
79 | "objects": ["boat", "water", "sky", "rock", "bird"],
80 | "relationships": [[0, "on top of", 1]]
81 | },
82 | {
83 | "objects": ["food", "glass", "glass", "plate", "plate"],
84 | "relationships": [
85 | [0, "on top of", 3],
86 | [1, "by", 3],
87 | [2, "on top of", 4]
88 | ]
89 | }
90 | ]
91 |
--------------------------------------------------------------------------------
/scene_graphs/figure_6_sheep.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "objects": ["sky", "grass", "zebra"],
4 | "relationships": [
5 | [0, "above", 1],
6 | [2, "standing on", 1]
7 | ]
8 | },
9 | {
10 | "objects": ["sky", "grass", "sheep"],
11 | "relationships": [
12 | [0, "above", 1],
13 | [2, "standing on", 1]
14 | ]
15 | },
16 | {
17 | "objects": ["sky", "grass", "sheep", "sheep"],
18 | "relationships": [
19 | [0, "above", 1],
20 | [2, "standing on", 1],
21 | [3, "by", 2]
22 | ]
23 | },
24 | {
25 | "objects": ["sky", "grass", "sheep", "sheep", "tree"],
26 | "relationships": [
27 | [0, "above", 1],
28 | [2, "standing on", 1],
29 | [3, "by", 2],
30 | [4, "behind", 2]
31 | ]
32 | },
33 | {
34 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean"],
35 | "relationships": [
36 | [0, "above", 1],
37 | [2, "standing on", 1],
38 | [3, "by", 2],
39 | [4, "behind", 2],
40 | [5, "by", 4]
41 | ]
42 | },
43 | {
44 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean", "boat"],
45 | "relationships": [
46 | [0, "above", 1],
47 | [2, "standing on", 1],
48 | [3, "by", 2],
49 | [4, "behind", 2],
50 | [5, "by", 4],
51 | [6, "in", 5]
52 | ]
53 | },
54 | {
55 | "objects": ["sky", "grass", "sheep", "sheep", "tree", "ocean", "boat"],
56 | "relationships": [
57 | [0, "above", 1],
58 | [2, "standing on", 1],
59 | [3, "by", 2],
60 | [4, "behind", 2],
61 | [5, "by", 4],
62 | [6, "on", 1]
63 | ]
64 | }
65 | ]
66 |
--------------------------------------------------------------------------------
/scene_graphs/figure_6_street.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "objects": ["car", "street", "line", "sky"],
4 | "relationships": [
5 | [0, "on", 1],
6 | [2, "on", 1],
7 | [3, "above", 1]
8 | ]
9 | },
10 | {
11 | "objects": ["bus", "street", "line", "sky"],
12 | "relationships": [
13 | [0, "on", 1],
14 | [2, "on", 1],
15 | [3, "above", 1]
16 | ]
17 | },
18 | {
19 | "objects": ["bus", "street", "line", "sky", "car"],
20 | "relationships": [
21 | [0, "on", 1],
22 | [2, "on", 1],
23 | [3, "above", 1],
24 | [4, "on", 1]
25 | ]
26 | },
27 | {
28 | "objects": ["bus", "street", "line", "sky", "car", "kite"],
29 | "relationships": [
30 | [0, "on", 1],
31 | [2, "on", 1],
32 | [3, "above", 1],
33 | [4, "on", 1],
34 | [5, "in", 3]
35 | ]
36 | },
37 | {
38 | "objects": ["bus", "street", "line", "sky", "car", "kite"],
39 | "relationships": [
40 | [0, "on", 1],
41 | [2, "on", 1],
42 | [3, "above", 1],
43 | [4, "on", 1],
44 | [5, "in", 3],
45 | [4, "below", 5]
46 | ]
47 | },
48 | {
49 | "objects": ["bus", "street", "line", "sky", "car", "building"],
50 | "relationships": [
51 | [0, "on", 1],
52 | [2, "on", 1],
53 | [3, "above", 1],
54 | [4, "on", 1],
55 | [5, "behind", 1]
56 | ]
57 | },
58 | {
59 | "objects": ["bus", "street", "line", "sky", "car", "building", "window"],
60 | "relationships": [
61 | [0, "on", 1],
62 | [2, "on", 1],
63 | [3, "above", 1],
64 | [4, "on", 1],
65 | [5, "behind", 1],
66 | [6, "on", 5]
67 | ]
68 | }
69 | ]
70 |
--------------------------------------------------------------------------------
/scripts/download_ablated_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Download all models from the ablation study
18 | mkdir -p sg2im-models
19 |
20 | # COCO models
21 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_gconv.pt -O sg2im-models/coco64_no_gconv.pt
22 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_relations.pt -O sg2im-models/coco64_no_relations.pt
23 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_discriminators.pt -O sg2im-models/coco64_no_discriminators.pt
24 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_img_discriminator.pt -O sg2im-models/coco64_no_img_discriminator.pt
25 | wget https://storage.googleapis.com/sg2im-data/small/coco64_no_obj_discriminator.pt -O sg2im-models/coco64_no_obj_discriminator.pt
26 | wget https://storage.googleapis.com/sg2im-data/small/coco64_gt_layout.pt -O sg2im-models/coco64_gt_layout.pt
27 | wget https://storage.googleapis.com/sg2im-data/small/coco64_gt_layout_no_gconv.pt -O sg2im-models/coco64_gt_layout_no_gconv.pt
28 |
29 | # VG models
30 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_relations.pt -O sg2im-models/vg64_no_relations.pt
31 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_gconv.pt -O sg2im-models/vg64_no_gconv.pt
32 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_discriminators.pt -O sg2im-models/vg64_no_discriminators.pt
33 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_img_discriminator.pt -O sg2im-models/vg64_no_img_discriminator.pt
34 | wget https://storage.googleapis.com/sg2im-data/small/vg64_no_obj_discriminator.pt -O sg2im-models/vg64_no_obj_discriminator.pt
35 |
--------------------------------------------------------------------------------
/scripts/download_coco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | COCO_DIR=datasets/coco
18 | mkdir -p $COCO_DIR
19 |
20 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O $COCO_DIR/annotations_trainval2017.zip
21 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip -O $COCO_DIR/stuff_annotations_trainval2017.zip
22 | wget http://images.cocodataset.org/zips/train2017.zip -O $COCO_DIR/train2017.zip
23 | wget http://images.cocodataset.org/zips/val2017.zip -O $COCO_DIR/val2017.zip
24 |
25 | unzip $COCO_DIR/annotations_trainval2017.zip -d $COCO_DIR
26 | unzip $COCO_DIR/stuff_annotations_trainval2017.zip -d $COCO_DIR
27 | unzip $COCO_DIR/train2017.zip -d $COCO_DIR/images
28 | unzip $COCO_DIR/val2017.zip -d $COCO_DIR/images
29 |
--------------------------------------------------------------------------------
/scripts/download_full_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Download full versions of all models which include training history
18 | mkdir -p sg2im-models/full
19 |
20 | # COCO models
21 | wget https://storage.googleapis.com/sg2im-data/full/coco64.pt -O sg2im-models/full/coco64.pt
22 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_gconv.pt -O sg2im-models/full/coco64_no_gconv.pt
23 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_relations.pt -O sg2im-models/full/coco64_no_relations.pt
24 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_discriminators.pt -O sg2im-models/full/coco64_no_discriminators.pt
25 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_img_discriminator.pt -O sg2im-models/full/coco64_no_img_discriminator.pt
26 | wget https://storage.googleapis.com/sg2im-data/full/coco64_no_obj_discriminator.pt -O sg2im-models/full/coco64_no_obj_discriminator.pt
27 | wget https://storage.googleapis.com/sg2im-data/full/coco64_gt_layout.pt -O sg2im-models/full/coco64_gt_layout.pt
28 | wget https://storage.googleapis.com/sg2im-data/full/coco64_gt_layout_no_gconv.pt -O sg2im-models/full/coco64_gt_layout_no_gconv.pt
29 |
30 | # VG models
31 | wget https://storage.googleapis.com/sg2im-data/full/vg64.pt -O sg2im-models/full/vg64.pt
32 | wget https://storage.googleapis.com/sg2im-data/full/vg128.pt -O sg2im-models/full/vg128.pt
33 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_relations.pt -O sg2im-models/full/vg64_no_relations.pt
34 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_gconv.pt -O sg2im-models/full/vg64_no_gconv.pt
35 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_discriminators.pt -O sg2im-models/full/vg64_no_discriminators.pt
36 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_img_discriminator.pt -O sg2im-models/full/vg64_no_img_discriminator.pt
37 | wget https://storage.googleapis.com/sg2im-data/full/vg64_no_obj_discriminator.pt -O sg2im-models/full/vg64_no_obj_discriminator.pt
38 |
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Download the main models: 64 x 64 for coco and vg, and 128 x 128 for vg
18 | mkdir -p sg2im-models
19 | wget https://storage.googleapis.com/sg2im-data/small/coco64.pt -O sg2im-models/coco64.pt
20 | wget https://storage.googleapis.com/sg2im-data/small/vg64.pt -O sg2im-models/vg64.pt
21 | wget https://storage.googleapis.com/sg2im-data/small/vg128.pt -O sg2im-models/vg128.pt
22 |
--------------------------------------------------------------------------------
/scripts/download_vg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | VG_DIR=datasets/vg
18 | mkdir -p $VG_DIR
19 |
20 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip
21 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip
22 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip
23 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt
24 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt
25 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip
26 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip
27 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip
28 |
29 | unzip $VG_DIR/objects.json.zip -d $VG_DIR
30 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR
31 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR
32 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR
33 | unzip $VG_DIR/images.zip -d $VG_DIR/images
34 | unzip $VG_DIR/images2.zip -d $VG_DIR/images
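   |
   | # A likely next step (an assumption based on scripts/preprocess_vg.py's defaults,
   | # which point at the same datasets/vg layout used above) is to encode the raw
   | # annotations into HDF5 files and a vocab before training:
   | #
   | #   python scripts/preprocess_vg.py
   | #
   | # This writes vocab.json plus one <split>.h5 file per split into $VG_DIR.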
35 |
--------------------------------------------------------------------------------
/scripts/preprocess_vg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse, json, os
18 | from collections import Counter, defaultdict
19 |
20 | import numpy as np
21 | import h5py
22 | from scipy.misc import imread, imresize
23 |
24 |
25 | """
26 | vocab for objects contains a special entry "__image__" intended to be used for
27 | dummy nodes encompassing the entire image; vocab for predicates includes a
28 | special entry "__in_image__" to be used for dummy relationships making the graph
29 | fully-connected.
30 | """
31 |
32 |
33 | VG_DIR = 'datasets/vg'
34 |
35 | parser = argparse.ArgumentParser()
36 |
37 | # Input data
38 | parser.add_argument('--splits_json', default='sg2im/data/vg_splits.json')
39 | parser.add_argument('--images_json',
40 | default=os.path.join(VG_DIR, 'image_data.json'))
41 | parser.add_argument('--objects_json',
42 | default=os.path.join(VG_DIR, 'objects.json'))
43 | parser.add_argument('--attributes_json',
44 | default=os.path.join(VG_DIR, 'attributes.json'))
45 | parser.add_argument('--object_aliases',
46 | default=os.path.join(VG_DIR, 'object_alias.txt'))
47 | parser.add_argument('--relationship_aliases',
48 | default=os.path.join(VG_DIR, 'relationship_alias.txt'))
49 | parser.add_argument('--relationships_json',
50 | default=os.path.join(VG_DIR, 'relationships.json'))
51 |
52 | # Arguments for images
53 | parser.add_argument('--min_image_size', default=200, type=int)
54 | parser.add_argument('--train_split', default='train')
55 |
56 | # Arguments for objects
57 | parser.add_argument('--min_object_instances', default=2000, type=int)
58 | parser.add_argument('--min_attribute_instances', default=2000, type=int)
59 | parser.add_argument('--min_object_size', default=32, type=int)
60 | parser.add_argument('--min_objects_per_image', default=3, type=int)
61 | parser.add_argument('--max_objects_per_image', default=30, type=int)
62 | parser.add_argument('--max_attributes_per_image', default=30, type=int)
63 |
64 | # Arguments for relationships
65 | parser.add_argument('--min_relationship_instances', default=500, type=int)
66 | parser.add_argument('--min_relationships_per_image', default=1, type=int)
67 | parser.add_argument('--max_relationships_per_image', default=30, type=int)
68 |
69 | # Output
70 | parser.add_argument('--output_vocab_json',
71 | default=os.path.join(VG_DIR, 'vocab.json'))
72 | parser.add_argument('--output_h5_dir', default=VG_DIR)
73 |
74 |
75 | def main(args):
76 | print('Loading image info from "%s"' % args.images_json)
77 | with open(args.images_json, 'r') as f:
78 | images = json.load(f)
79 | image_id_to_image = {i['image_id']: i for i in images}
80 |
81 | with open(args.splits_json, 'r') as f:
82 | splits = json.load(f)
83 |
84 | # Filter images for being too small
85 | splits = remove_small_images(args, image_id_to_image, splits)
86 |
87 | obj_aliases = load_aliases(args.object_aliases)
88 | rel_aliases = load_aliases(args.relationship_aliases)
89 |
90 | print('Loading objects from "%s"' % args.objects_json)
91 | with open(args.objects_json, 'r') as f:
92 | objects = json.load(f)
93 |
94 | # Vocab for objects and relationships
95 | vocab = {}
96 | train_ids = splits[args.train_split]
97 | create_object_vocab(args, train_ids, objects, obj_aliases, vocab)
98 |
99 | print('Loading attributes from "%s"' % args.attributes_json)
100 | with open(args.attributes_json, 'r') as f:
101 | attributes = json.load(f)
102 |
103 | # Vocab for attributes
104 | create_attribute_vocab(args, train_ids, attributes, vocab)
105 |
106 | object_id_to_obj = filter_objects(args, objects, obj_aliases, vocab, splits)
107 | print('After filtering there are %d object instances'
108 | % len(object_id_to_obj))
109 |
110 | print('Loading relationships from "%s"' % args.relationships_json)
111 | with open(args.relationships_json, 'r') as f:
112 | relationships = json.load(f)
113 |
114 | create_rel_vocab(args, train_ids, relationships, object_id_to_obj,
115 | rel_aliases, vocab)
116 |
117 | print('Encoding objects and relationships ...')
118 | numpy_arrays = encode_graphs(args, splits, objects, relationships, vocab,
119 | object_id_to_obj, attributes)
120 |
121 | print('Writing HDF5 output files')
122 | for split_name, split_arrays in numpy_arrays.items():
123 | image_ids = list(split_arrays['image_ids'].astype(int))
124 | h5_path = os.path.join(args.output_h5_dir, '%s.h5' % split_name)
125 | print('Writing file "%s"' % h5_path)
126 | with h5py.File(h5_path, 'w') as h5_file:
127 | for name, ary in split_arrays.items():
128 | print('Creating dataset: ', name, ary.shape, ary.dtype)
129 | h5_file.create_dataset(name, data=ary)
130 | print('Writing image paths')
131 | image_paths = get_image_paths(image_id_to_image, image_ids)
132 | path_dtype = h5py.special_dtype(vlen=str)
133 | path_shape = (len(image_paths),)
134 | path_dset = h5_file.create_dataset('image_paths', path_shape,
135 | dtype=path_dtype)
136 | for i, p in enumerate(image_paths):
137 | path_dset[i] = p
138 | print()
139 |
140 | print('Writing vocab to "%s"' % args.output_vocab_json)
141 | with open(args.output_vocab_json, 'w') as f:
142 | json.dump(vocab, f)
143 |
144 | def remove_small_images(args, image_id_to_image, splits):
145 | new_splits = {}
146 | for split_name, image_ids in splits.items():
147 | new_image_ids = []
148 | num_skipped = 0
149 | for image_id in image_ids:
150 | image = image_id_to_image[image_id]
151 | height, width = image['height'], image['width']
152 | if min(height, width) < args.min_image_size:
153 | num_skipped += 1
154 | continue
155 | new_image_ids.append(image_id)
156 | new_splits[split_name] = new_image_ids
157 | print('Removed %d images from split "%s" for being too small' %
158 | (num_skipped, split_name))
159 |
160 | return new_splits
161 |
162 |
163 | def get_image_paths(image_id_to_image, image_ids):
164 | paths = []
165 | for image_id in image_ids:
166 | image = image_id_to_image[image_id]
167 | base, filename = os.path.split(image['url'])
168 | path = os.path.join(os.path.basename(base), filename)
169 | paths.append(path)
170 | return paths
171 |
172 |
173 | def handle_images(args, image_ids, h5_file):
174 | with open(args.images_json, 'r') as f:
175 | images = json.load(f)
176 | if image_ids:
177 | image_ids = set(image_ids)
178 |
179 | image_heights, image_widths = [], []
180 | image_ids_out, image_paths = [], []
181 | for image in images:
182 | image_id = image['image_id']
183 | if image_ids and image_id not in image_ids:
184 | continue
185 | height, width = image['height'], image['width']
186 |
187 | base, filename = os.path.split(image['url'])
188 | path = os.path.join(os.path.basename(base), filename)
189 | image_paths.append(path)
190 | image_heights.append(height)
191 | image_widths.append(width)
192 | image_ids_out.append(image_id)
193 |
194 | image_ids_np = np.asarray(image_ids_out, dtype=int)
195 | h5_file.create_dataset('image_ids', data=image_ids_np)
196 |
197 | image_heights = np.asarray(image_heights, dtype=int)
198 | h5_file.create_dataset('image_heights', data=image_heights)
199 |
200 | image_widths = np.asarray(image_widths, dtype=int)
201 | h5_file.create_dataset('image_widths', data=image_widths)
202 |
203 | return image_paths
204 |
205 |
206 | def load_aliases(alias_path):
207 | aliases = {}
208 | print('Loading aliases from "%s"' % alias_path)
209 | with open(alias_path, 'r') as f:
210 | for line in f:
211 | line = [s.strip() for s in line.split(',')]
212 | for s in line:
213 | aliases[s] = line[0]
214 | return aliases
215 |
216 |
217 | def create_object_vocab(args, image_ids, objects, aliases, vocab):
218 | image_ids = set(image_ids)
219 |
220 | print('Making object vocab from %d training images' % len(image_ids))
221 | object_name_counter = Counter()
222 | for image in objects:
223 | if image['image_id'] not in image_ids:
224 | continue
225 | for obj in image['objects']:
226 | names = set()
227 | for name in obj['names']:
228 | names.add(aliases.get(name, name))
229 | object_name_counter.update(names)
230 |
231 | object_names = ['__image__']
232 | for name, count in object_name_counter.most_common():
233 | if count >= args.min_object_instances:
234 | object_names.append(name)
235 | print('Found %d object categories with >= %d training instances' %
236 | (len(object_names), args.min_object_instances))
237 |
238 | object_name_to_idx = {}
239 | object_idx_to_name = []
240 | for idx, name in enumerate(object_names):
241 | object_name_to_idx[name] = idx
242 | object_idx_to_name.append(name)
243 |
244 | vocab['object_name_to_idx'] = object_name_to_idx
245 | vocab['object_idx_to_name'] = object_idx_to_name
246 |
247 | def create_attribute_vocab(args, image_ids, attributes, vocab):
248 | image_ids = set(image_ids)
249 | print('Making attribute vocab from %d training images' % len(image_ids))
250 | attribute_name_counter = Counter()
251 | for image in attributes:
252 | if image['image_id'] not in image_ids:
253 | continue
254 | for attribute in image['attributes']:
255 | names = set()
256 | try:
257 | for name in attribute['attributes']:
258 | names.add(name)
259 | attribute_name_counter.update(names)
260 | except KeyError:
261 | pass
262 | attribute_names = []
263 | for name, count in attribute_name_counter.most_common():
264 | if count >= args.min_attribute_instances:
265 | attribute_names.append(name)
266 | print('Found %d attribute categories with >= %d training instances' %
267 | (len(attribute_names), args.min_attribute_instances))
268 |
269 | attribute_name_to_idx = {}
270 | attribute_idx_to_name = []
271 | for idx, name in enumerate(attribute_names):
272 | attribute_name_to_idx[name] = idx
273 | attribute_idx_to_name.append(name)
274 | vocab['attribute_name_to_idx'] = attribute_name_to_idx
275 | vocab['attribute_idx_to_name'] = attribute_idx_to_name
276 |
277 | def filter_objects(args, objects, aliases, vocab, splits):
278 | object_id_to_objects = {}
279 | all_image_ids = set()
280 | for image_ids in splits.values():
281 | all_image_ids |= set(image_ids)
282 |
283 | object_name_to_idx = vocab['object_name_to_idx']
284 | object_id_to_obj = {}
285 |
286 | num_too_small = 0
287 | for image in objects:
288 | image_id = image['image_id']
289 | if image_id not in all_image_ids:
290 | continue
291 | for obj in image['objects']:
292 | object_id = obj['object_id']
293 | final_name = None
294 | final_name_idx = None
295 | for name in obj['names']:
296 | name = aliases.get(name, name)
297 | if name in object_name_to_idx:
298 | final_name = name
299 | final_name_idx = object_name_to_idx[final_name]
300 | break
301 | w, h = obj['w'], obj['h']
302 | too_small = (w < args.min_object_size) or (h < args.min_object_size)
303 | if too_small:
304 | num_too_small += 1
305 | if final_name is not None and not too_small:
306 | object_id_to_obj[object_id] = {
307 | 'name': final_name,
308 | 'name_idx': final_name_idx,
309 | 'box': [obj['x'], obj['y'], obj['w'], obj['h']],
310 | }
311 | print('Skipped %d objects with size < %d' % (num_too_small, args.min_object_size))
312 | return object_id_to_obj
313 |
314 |
315 | def create_rel_vocab(args, image_ids, relationships, object_id_to_obj,
316 | rel_aliases, vocab):
317 | pred_counter = defaultdict(int)
318 | image_ids_set = set(image_ids)
319 | for image in relationships:
320 | image_id = image['image_id']
321 | if image_id not in image_ids_set:
322 | continue
323 | for rel in image['relationships']:
324 | sid = rel['subject']['object_id']
325 | oid = rel['object']['object_id']
326 | found_subject = sid in object_id_to_obj
327 | found_object = oid in object_id_to_obj
328 | if not found_subject or not found_object:
329 | continue
330 | pred = rel['predicate'].lower().strip()
331 | pred = rel_aliases.get(pred, pred)
332 | rel['predicate'] = pred
333 | pred_counter[pred] += 1
334 |
335 | pred_names = ['__in_image__']
336 | for pred, count in pred_counter.items():
337 | if count >= args.min_relationship_instances:
338 | pred_names.append(pred)
339 | print('Found %d relationship types with >= %d training instances'
340 | % (len(pred_names), args.min_relationship_instances))
341 |
342 | pred_name_to_idx = {}
343 | pred_idx_to_name = []
344 | for idx, name in enumerate(pred_names):
345 | pred_name_to_idx[name] = idx
346 | pred_idx_to_name.append(name)
347 |
348 | vocab['pred_name_to_idx'] = pred_name_to_idx
349 | vocab['pred_idx_to_name'] = pred_idx_to_name
350 |
351 |
352 | def encode_graphs(args, splits, objects, relationships, vocab,
353 | object_id_to_obj, attributes):
354 |
355 | image_id_to_objects = {}
356 | for image in objects:
357 | image_id = image['image_id']
358 | image_id_to_objects[image_id] = image['objects']
359 | image_id_to_relationships = {}
360 | for image in relationships:
361 | image_id = image['image_id']
362 | image_id_to_relationships[image_id] = image['relationships']
363 | image_id_to_attributes = {}
364 | for image in attributes:
365 | image_id = image['image_id']
366 | image_id_to_attributes[image_id] = image['attributes']
367 |
368 | numpy_arrays = {}
369 | for split, image_ids in splits.items():
370 | skip_stats = defaultdict(int)
371 | # We need to filter *again* based on number of objects and relationships
372 | final_image_ids = []
373 | object_ids = []
374 | object_names = []
375 | object_boxes = []
376 | objects_per_image = []
377 | relationship_ids = []
378 | relationship_subjects = []
379 | relationship_predicates = []
380 | relationship_objects = []
381 | relationships_per_image = []
382 | attribute_ids = []
383 | attributes_per_object = []
384 | object_attributes = []
385 | for image_id in image_ids:
386 | image_object_ids = []
387 | image_object_names = []
388 | image_object_boxes = []
389 | object_id_to_idx = {}
390 | for obj in image_id_to_objects[image_id]:
391 | object_id = obj['object_id']
392 | if object_id not in object_id_to_obj:
393 | continue
394 | obj = object_id_to_obj[object_id]
395 | object_id_to_idx[object_id] = len(image_object_ids)
396 | image_object_ids.append(object_id)
397 | image_object_names.append(obj['name_idx'])
398 | image_object_boxes.append(obj['box'])
399 | num_objects = len(image_object_ids)
400 | too_few = num_objects < args.min_objects_per_image
401 | too_many = num_objects > args.max_objects_per_image
402 | if too_few:
403 | skip_stats['too_few_objects'] += 1
404 | continue
405 | if too_many:
406 | skip_stats['too_many_objects'] += 1
407 | continue
408 | image_rel_ids = []
409 | image_rel_subs = []
410 | image_rel_preds = []
411 | image_rel_objs = []
412 | for rel in image_id_to_relationships[image_id]:
413 | relationship_id = rel['relationship_id']
414 | pred = rel['predicate']
415 | pred_idx = vocab['pred_name_to_idx'].get(pred, None)
416 | if pred_idx is None:
417 | continue
418 | sid = rel['subject']['object_id']
419 | sidx = object_id_to_idx.get(sid, None)
420 | oid = rel['object']['object_id']
421 | oidx = object_id_to_idx.get(oid, None)
422 | if sidx is None or oidx is None:
423 | continue
424 | image_rel_ids.append(relationship_id)
425 | image_rel_subs.append(sidx)
426 | image_rel_preds.append(pred_idx)
427 | image_rel_objs.append(oidx)
428 | num_relationships = len(image_rel_ids)
429 | too_few = num_relationships < args.min_relationships_per_image
430 | too_many = num_relationships > args.max_relationships_per_image
431 | if too_few:
432 | skip_stats['too_few_relationships'] += 1
433 | continue
434 | if too_many:
435 | skip_stats['too_many_relationships'] += 1
436 | continue
437 |
438 | obj_id_to_attributes = {}
439 | num_attributes = []
440 | for obj_attribute in image_id_to_attributes[image_id]:
441 | obj_id_to_attributes[obj_attribute['object_id']] = obj_attribute.get('attributes', None)
442 | for object_id in image_object_ids:
443 | attributes = obj_id_to_attributes.get(object_id, None)
444 | if attributes is None:
445 | object_attributes.append([-1] * args.max_attributes_per_image)
446 | num_attributes.append(0)
447 | else:
448 | attribute_ids = []
449 | for attribute in attributes:
450 | if attribute in vocab['attribute_name_to_idx']:
451 | attribute_ids.append(vocab['attribute_name_to_idx'][attribute])
452 | if len(attribute_ids) >= args.max_attributes_per_image:
453 | break
454 | num_attributes.append(len(attribute_ids))
455 | pad_len = args.max_attributes_per_image - len(attribute_ids)
456 | attribute_ids = attribute_ids + [-1] * pad_len
457 | object_attributes.append(attribute_ids)
458 |
459 | # Pad object info out to max_objects_per_image
460 | while len(image_object_ids) < args.max_objects_per_image:
461 | image_object_ids.append(-1)
462 | image_object_names.append(-1)
463 | image_object_boxes.append([-1, -1, -1, -1])
464 | num_attributes.append(-1)
465 |
466 | # Pad relationship info out to max_relationships_per_image
467 | while len(image_rel_ids) < args.max_relationships_per_image:
468 | image_rel_ids.append(-1)
469 | image_rel_subs.append(-1)
470 | image_rel_preds.append(-1)
471 | image_rel_objs.append(-1)
472 |
473 | final_image_ids.append(image_id)
474 | object_ids.append(image_object_ids)
475 | object_names.append(image_object_names)
476 | object_boxes.append(image_object_boxes)
477 | objects_per_image.append(num_objects)
478 | relationship_ids.append(image_rel_ids)
479 | relationship_subjects.append(image_rel_subs)
480 | relationship_predicates.append(image_rel_preds)
481 | relationship_objects.append(image_rel_objs)
482 | relationships_per_image.append(num_relationships)
483 | attributes_per_object.append(num_attributes)
484 |
485 | print('Skip stats for split "%s"' % split)
486 | for stat, count in skip_stats.items():
487 | print(stat, count)
488 | print()
489 | numpy_arrays[split] = {
490 | 'image_ids': np.asarray(final_image_ids),
491 | 'object_ids': np.asarray(object_ids),
492 | 'object_names': np.asarray(object_names),
493 | 'object_boxes': np.asarray(object_boxes),
494 | 'objects_per_image': np.asarray(objects_per_image),
495 | 'relationship_ids': np.asarray(relationship_ids),
496 | 'relationship_subjects': np.asarray(relationship_subjects),
497 | 'relationship_predicates': np.asarray(relationship_predicates),
498 | 'relationship_objects': np.asarray(relationship_objects),
499 | 'relationships_per_image': np.asarray(relationships_per_image),
500 | 'attributes_per_object': np.asarray(attributes_per_object),
501 | 'object_attributes': np.asarray(object_attributes),
502 | }
503 | for k, v in numpy_arrays[split].items():
504 | if v.dtype == np.int64:
505 | numpy_arrays[split][k] = v.astype(np.int32)
506 | return numpy_arrays
507 |
508 |
509 | if __name__ == '__main__':
510 | args = parser.parse_args()
511 | main(args)
512 |
--------------------------------------------------------------------------------
/scripts/print_args.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import torch
19 |
20 |
21 | """
22 | Tiny utility to print the command-line args used for a checkpoint
23 | """
24 |
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('checkpoint')
28 |
29 |
30 | def main(args):
31 | checkpoint = torch.load(args.checkpoint, map_location='cpu')
32 | for k, v in checkpoint['args'].items():
33 | print(k, v)
34 |
35 |
36 | if __name__ == '__main__':
37 | args = parser.parse_args()
38 | main(args)
39 |
40 |
--------------------------------------------------------------------------------
/scripts/run_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse, json, os
18 |
19 | from imageio import imwrite
20 | import torch
21 |
22 | from sg2im.model import Sg2ImModel
23 | from sg2im.data.utils import imagenet_deprocess_batch
24 | import sg2im.vis as vis
25 |
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--checkpoint', default='sg2im-models/vg128.pt')
29 | parser.add_argument('--scene_graphs_json', default='scene_graphs/figure_6_sheep.json')
30 | parser.add_argument('--output_dir', default='outputs')
31 | parser.add_argument('--draw_scene_graphs', type=int, default=0)
32 | parser.add_argument('--device', default='gpu', choices=['cpu', 'gpu'])
33 |
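   | # Illustrative invocation (paths refer to the pretrained model downloaded by
   | # scripts/download_models.sh and the example scene graphs in scene_graphs/;
   | # --device cpu avoids the need for a GPU):
   | #
   | #   python scripts/run_model.py --checkpoint sg2im-models/vg128.pt \
   | #     --scene_graphs_json scene_graphs/figure_6_sheep.json \
   | #     --output_dir outputs --device cpu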
34 |
35 | def main(args):
36 | if not os.path.isfile(args.checkpoint):
37 | print('ERROR: Checkpoint file "%s" not found' % args.checkpoint)
38 | print('Maybe you forgot to download pretrained models? Try running:')
39 | print('bash scripts/download_models.sh')
40 | return
41 |
42 | if not os.path.isdir(args.output_dir):
43 | print('Output directory "%s" does not exist; creating it' % args.output_dir)
44 | os.makedirs(args.output_dir)
45 |
46 | if args.device == 'cpu':
47 | device = torch.device('cpu')
48 | elif args.device == 'gpu':
49 | device = torch.device('cuda:0')
50 | if not torch.cuda.is_available():
51 | print('WARNING: CUDA not available; falling back to CPU')
52 | device = torch.device('cpu')
53 |
54 | # Load the model, with a bit of care in case there are no GPUs
55 | map_location = 'cpu' if device == torch.device('cpu') else None
56 | checkpoint = torch.load(args.checkpoint, map_location=map_location)
57 | model = Sg2ImModel(**checkpoint['model_kwargs'])
58 | model.load_state_dict(checkpoint['model_state'])
59 | model.eval()
60 | model.to(device)
61 |
62 | # Load the scene graphs
63 | with open(args.scene_graphs_json, 'r') as f:
64 | scene_graphs = json.load(f)
65 |
66 | # Run the model forward
67 | with torch.no_grad():
68 | imgs, boxes_pred, masks_pred, _ = model.forward_json(scene_graphs)
69 | imgs = imagenet_deprocess_batch(imgs)
70 |
71 | # Save the generated images
72 | for i in range(imgs.shape[0]):
73 | img_np = imgs[i].numpy().transpose(1, 2, 0)
74 | img_path = os.path.join(args.output_dir, 'img%06d.png' % i)
75 | imwrite(img_path, img_np)
76 |
77 | # Draw the scene graphs
78 | if args.draw_scene_graphs == 1:
79 | for i, sg in enumerate(scene_graphs):
80 | sg_img = vis.draw_scene_graph(sg['objects'], sg['relationships'])
81 | sg_img_path = os.path.join(args.output_dir, 'sg%06d.png' % i)
82 | imwrite(sg_img_path, sg_img)
83 |
84 |
85 | if __name__ == '__main__':
86 | args = parser.parse_args()
87 | main(args)
88 |
89 |
--------------------------------------------------------------------------------
/scripts/sample_images.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """
18 | This script can be used to sample many images from a model for evaluation.
19 | """
20 |
21 |
22 | import argparse, json
23 | import os
24 |
25 | import torch
26 | from torch.autograd import Variable
27 | from torch.utils.data import DataLoader
28 |
29 | from scipy.misc import imsave, imresize
30 |
31 | from sg2im.data import imagenet_deprocess_batch
32 | from sg2im.data.coco import CocoSceneGraphDataset, coco_collate_fn
33 | from sg2im.data.vg import VgSceneGraphDataset, vg_collate_fn
34 | from sg2im.data.utils import split_graph_batch
35 | from sg2im.model import Sg2ImModel
36 | from sg2im.utils import int_tuple, bool_flag
37 | from sg2im.vis import draw_scene_graph
38 |
39 |
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('--checkpoint', default='sg2im-models/vg64.pt')
42 | parser.add_argument('--checkpoint_list', default=None)
43 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval'])
44 |
45 | # Shared dataset options
46 | parser.add_argument('--dataset', default='vg', choices=['coco', 'vg'])
47 | parser.add_argument('--image_size', default=(64, 64), type=int_tuple)
48 | parser.add_argument('--batch_size', default=24, type=int)
49 | parser.add_argument('--shuffle', default=False, type=bool_flag)
50 | parser.add_argument('--loader_num_workers', default=4, type=int)
51 | parser.add_argument('--num_samples', default=10000, type=int)
52 | parser.add_argument('--save_gt_imgs', default=False, type=bool_flag)
53 | parser.add_argument('--save_graphs', default=False, type=bool_flag)
54 | parser.add_argument('--use_gt_boxes', default=False, type=bool_flag)
55 | parser.add_argument('--use_gt_masks', default=False, type=bool_flag)
56 | parser.add_argument('--save_layout', default=True, type=bool_flag)
57 |
58 | parser.add_argument('--output_dir', default='output')
59 |
60 | # For VG
61 | VG_DIR = os.path.expanduser('datasets/vg')
62 | parser.add_argument('--vg_h5', default=os.path.join(VG_DIR, 'val.h5'))
63 | parser.add_argument('--vg_image_dir',
64 | default=os.path.join(VG_DIR, 'images'))
65 |
66 | # For COCO
67 | COCO_DIR = os.path.expanduser('~/datasets/coco/2017')
68 | parser.add_argument('--coco_image_dir',
69 | default=os.path.join(COCO_DIR, 'images/val2017'))
70 | parser.add_argument('--instances_json',
71 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json'))
72 | parser.add_argument('--stuff_json',
73 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json'))
74 |
75 |
76 | def build_coco_dset(args, checkpoint):
77 | checkpoint_args = checkpoint['args']
78 | print('include other: ', checkpoint_args.get('coco_include_other'))
79 | dset_kwargs = {
80 | 'image_dir': args.coco_image_dir,
81 | 'instances_json': args.instances_json,
82 | 'stuff_json': args.stuff_json,
83 | 'stuff_only': checkpoint_args['coco_stuff_only'],
84 | 'image_size': args.image_size,
85 | 'mask_size': checkpoint_args['mask_size'],
86 | 'max_samples': args.num_samples,
87 | 'min_object_size': checkpoint_args['min_object_size'],
88 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'],
89 | 'instance_whitelist': checkpoint_args['instance_whitelist'],
90 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'],
91 | 'include_other': checkpoint_args.get('coco_include_other', True),
92 | }
93 | dset = CocoSceneGraphDataset(**dset_kwargs)
94 | return dset
95 |
96 |
97 | def build_vg_dset(args, checkpoint):
98 | vocab = checkpoint['model_kwargs']['vocab']
99 | dset_kwargs = {
100 | 'vocab': vocab,
101 | 'h5_path': args.vg_h5,
102 | 'image_dir': args.vg_image_dir,
103 | 'image_size': args.image_size,
104 | 'max_samples': args.num_samples,
105 | 'max_objects': checkpoint['args']['max_objects_per_image'],
106 | 'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'],
107 | }
108 | dset = VgSceneGraphDataset(**dset_kwargs)
109 | return dset
110 |
111 |
112 | def build_loader(args, checkpoint):
113 | if args.dataset == 'coco':
114 | dset = build_coco_dset(args, checkpoint)
115 | collate_fn = coco_collate_fn
116 | elif args.dataset == 'vg':
117 | dset = build_vg_dset(args, checkpoint)
118 | collate_fn = vg_collate_fn
119 |
120 | loader_kwargs = {
121 | 'batch_size': args.batch_size,
122 | 'num_workers': args.loader_num_workers,
123 | 'shuffle': args.shuffle,
124 | 'collate_fn': collate_fn,
125 | }
126 | loader = DataLoader(dset, **loader_kwargs)
127 | return loader
128 |
129 |
130 | def build_model(args, checkpoint):
131 | kwargs = checkpoint['model_kwargs']
132 | model = Sg2ImModel(**checkpoint['model_kwargs'])
133 | model.load_state_dict(checkpoint['model_state'])
134 | if args.model_mode == 'eval':
135 | model.eval()
136 | elif args.model_mode == 'train':
137 | model.train()
138 | model.image_size = args.image_size
139 | model.cuda()
140 | return model
141 |
142 |
143 | def makedir(base, name, flag=True):
144 | dir_name = None
145 | if flag:
146 | dir_name = os.path.join(base, name)
147 | if not os.path.isdir(dir_name):
148 | os.makedirs(dir_name)
149 | return dir_name
150 |
151 |
152 | def run_model(args, checkpoint, output_dir, loader=None):
153 | vocab = checkpoint['model_kwargs']['vocab']
154 | model = build_model(args, checkpoint)
155 | if loader is None:
156 | loader = build_loader(args, checkpoint)
157 |
158 | img_dir = makedir(output_dir, 'images')
159 | graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
160 | gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
161 | data_path = os.path.join(output_dir, 'data.pt')
162 |
163 | data = {
164 | 'vocab': vocab,
165 | 'objs': [],
166 | 'masks_pred': [],
167 | 'boxes_pred': [],
168 | 'masks_gt': [],
169 | 'boxes_gt': [],
170 | 'filenames': [],
171 | }
172 |
173 | img_idx = 0
174 | for batch in loader:
175 | masks = None
176 | if len(batch) == 6:
177 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = [x.cuda() for x in batch]
178 | elif len(batch) == 7:
179 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = [x.cuda() for x in batch]
180 |
181 | imgs_gt = imagenet_deprocess_batch(imgs)
182 | boxes_gt = None
183 | masks_gt = None
184 | if args.use_gt_boxes:
185 | boxes_gt = boxes
186 | if args.use_gt_masks:
187 | masks_gt = masks
188 |
189 | # Run the model with predicted masks
190 | model_out = model(objs, triples, obj_to_img,
191 | boxes_gt=boxes_gt, masks_gt=masks_gt)
192 | imgs_pred, boxes_pred, masks_pred, _ = model_out
193 | imgs_pred = imagenet_deprocess_batch(imgs_pred)
194 |
195 | obj_data = [objs, boxes_pred, masks_pred]
196 | _, obj_data = split_graph_batch(triples, obj_data, obj_to_img,
197 | triple_to_img)
198 | objs, boxes_pred, masks_pred = obj_data
199 |
200 | obj_data_gt = [boxes.data]
201 | if masks is not None:
202 | obj_data_gt.append(masks.data)
203 | triples, obj_data_gt = split_graph_batch(triples, obj_data_gt,
204 | obj_to_img, triple_to_img)
205 | boxes_gt, masks_gt = obj_data_gt[0], None
206 | if masks is not None:
207 | masks_gt = obj_data_gt[1]
208 |
209 | for i in range(imgs_pred.size(0)):
210 | img_filename = '%04d.png' % img_idx
211 | if args.save_gt_imgs:
212 | img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
213 | img_gt_path = os.path.join(gt_img_dir, img_filename)
214 | imsave(img_gt_path, img_gt)
215 |
216 | img_pred = imgs_pred[i]
217 | img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
218 | img_path = os.path.join(img_dir, img_filename)
219 | imsave(img_path, img_pred_np)
220 |
221 | data['objs'].append(objs[i].cpu().clone())
222 | data['masks_pred'].append(masks_pred[i].cpu().clone())
223 | data['boxes_pred'].append(boxes_pred[i].cpu().clone())
224 | data['boxes_gt'].append(boxes_gt[i].cpu().clone())
225 | data['filenames'].append(img_filename)
226 |
227 | cur_masks_gt = None
228 | if masks_gt is not None:
229 | cur_masks_gt = masks_gt[i].cpu().clone()
230 | data['masks_gt'].append(cur_masks_gt)
231 |
232 | if args.save_graphs:
233 | graph_img = draw_scene_graph(vocab, objs[i], triples[i])
234 | graph_path = os.path.join(graph_dir, img_filename)
235 | imsave(graph_path, graph_img)
236 |
237 | img_idx += 1
238 |
239 | torch.save(data, data_path)
240 | print('Saved %d images' % img_idx)
241 |
242 |
243 | def main(args):
244 | got_checkpoint = args.checkpoint is not None
245 | got_checkpoint_list = args.checkpoint_list is not None
246 | if got_checkpoint == got_checkpoint_list:
247 | raise ValueError('Must specify exactly one of --checkpoint and --checkpoint_list')
248 |
249 | if got_checkpoint:
250 | checkpoint = torch.load(args.checkpoint)
251 | print('Loading model from ', args.checkpoint)
252 | run_model(args, checkpoint, args.output_dir)
253 | elif got_checkpoint_list:
254 | # For efficiency, use the same loader for all checkpoints
255 | loader = None
256 | with open(args.checkpoint_list, 'r') as f:
257 | checkpoint_list = [line.strip() for line in f]
258 | for i, path in enumerate(checkpoint_list):
259 | if os.path.isfile(path):
260 | print('Loading model from ', path)
261 | checkpoint = torch.load(path)
262 | if loader is None:
263 | loader = build_loader(args, checkpoint)
264 | output_dir = os.path.join(args.output_dir, 'result%03d' % (i + 1))
265 | run_model(args, checkpoint, output_dir, loader)
266 | elif os.path.isdir(path):
267 | # Look for snapshots in this dir
268 | for fn in sorted(os.listdir(path)):
269 | if 'snapshot' not in fn:
270 | continue
271 | checkpoint_path = os.path.join(path, fn)
272 | print('Loading model from ', checkpoint_path)
273 | checkpoint = torch.load(checkpoint_path)
274 | if loader is None:
275 | loader = build_loader(args, checkpoint)
276 |
277 | # Snapshots have names like "snapshot_00100K.pt"; we want to
278 | # extract the "00100K" part
279 | snapshot_name = os.path.splitext(fn)[0].split('_')[1]
280 | output_dir = 'result%03d_%s' % (i, snapshot_name)
281 | output_dir = os.path.join(args.output_dir, output_dir)
282 |
283 | run_model(args, checkpoint, output_dir, loader)
284 |
285 |
286 | if __name__ == '__main__':
287 | args = parser.parse_args()
288 | main(args)
289 |
290 |
291 |
--------------------------------------------------------------------------------
/scripts/strip_checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse, os
18 | import torch
19 |
20 |
21 | """
22 | Checkpoints saved by train.py contain not only model parameters but also
23 | optimizer states, losses, a history of generated images, and other statistics.
24 | This information is very useful for developing and debugging models, but makes
25 | the saved checkpoints very large. This utility script strips away all extra
26 | information from saved checkpoints, keeping only the saved models.
27 | """
28 |
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--input_checkpoint', default=None)
32 | parser.add_argument('--output_checkpoint', default=None)
33 | parser.add_argument('--input_dir', default=None)
34 | parser.add_argument('--output_dir', default=None)
35 | parser.add_argument('--keep_discriminators', type=int, default=1)
36 |
37 |
38 | def main(args):
39 | if args.input_checkpoint is not None:
40 | handle_checkpoint(args, args.input_checkpoint, args.output_checkpoint)
41 | if args.input_dir is not None:
42 | handle_dir(args, args.input_dir, args.output_dir)
43 |
44 |
45 | def handle_dir(args, input_dir, output_dir):
46 | for fn in os.listdir(input_dir):
47 | if not fn.endswith('.pt'):
48 | continue
49 | input_path = os.path.join(input_dir, fn)
50 | output_path = os.path.join(output_dir, fn)
51 | handle_checkpoint(args, input_path, output_path)
52 |
53 |
54 | def handle_checkpoint(args, input_path, output_path):
55 | input_checkpoint = torch.load(input_path)
56 | keep = ['args', 'model_state', 'model_kwargs']
57 | if args.keep_discriminators == 1:
58 | keep += ['d_img_state', 'd_img_kwargs', 'd_obj_state', 'd_obj_kwargs']
59 | output_checkpoint = {}
60 | for k, v in input_checkpoint.items():
61 | if k in keep:
62 | output_checkpoint[k] = v
63 | torch.save(output_checkpoint, output_path)
64 |
65 |
66 | if __name__ == '__main__':
67 | args = parser.parse_args()
68 | main(args)
69 |
70 |
--------------------------------------------------------------------------------
/scripts/strip_old_args.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse, os
18 | import torch
19 |
20 |
21 | """
22 | This utility script removes deprecated kwargs from saved checkpoints.
23 | """
24 |
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--input_checkpoint', default=None)
28 | parser.add_argument('--input_dir', default=None)
29 |
30 |
31 | DEPRECATED_KWARGS = {
32 | 'model_kwargs': [
33 | 'vec_noise_dim', 'gconv_mode', 'box_anchor', 'decouple_obj_predictions',
34 | ],
35 | }
36 |
37 |
38 | def main(args):
39 | got_checkpoint = (args.input_checkpoint is not None)
40 | got_dir = (args.input_dir is not None)
41 | assert got_checkpoint != got_dir, "Must give exactly one of checkpoint or dir"
42 | if got_checkpoint:
43 | handle_checkpoint(args.input_checkpoint)
44 | elif got_dir:
45 | handle_dir(args.input_dir)
46 |
47 |
48 |
49 | def handle_dir(dir_path):
50 | for fn in os.listdir(dir_path):
51 | if not fn.endswith('.pt'):
52 | continue
53 | checkpoint_path = os.path.join(dir_path, fn)
54 | handle_checkpoint(checkpoint_path)
55 |
56 |
57 | def handle_checkpoint(checkpoint_path):
58 | print('Stripping old args from checkpoint "%s"' % checkpoint_path)
59 | checkpoint = torch.load(checkpoint_path)
60 | for group, deprecated in DEPRECATED_KWARGS.items():
61 | assert group in checkpoint
62 | for k in deprecated:
63 | if k in checkpoint[group]:
64 | print('Removing key "%s" from "%s"' % (k, group))
65 | del checkpoint[group][k]
66 | torch.save(checkpoint, checkpoint_path)
67 |
68 |
69 | if __name__ == '__main__':
70 | args = parser.parse_args()
71 | main(args)
72 |
73 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import functools
19 | import os
20 | import json
21 | import math
22 | from collections import defaultdict
23 | import random
24 |
25 | import numpy as np
26 | import torch
27 | import torch.optim as optim
28 | import torch.nn as nn
29 | import torch.nn.functional as F
30 | from torch.utils.data import DataLoader
31 |
32 | from sg2im.data import imagenet_deprocess_batch
33 | from sg2im.data.coco import CocoSceneGraphDataset, coco_collate_fn
34 | from sg2im.data.vg import VgSceneGraphDataset, vg_collate_fn
35 | from sg2im.discriminators import PatchDiscriminator, AcCropDiscriminator
36 | from sg2im.losses import get_gan_losses
37 | from sg2im.metrics import jaccard
38 | from sg2im.model import Sg2ImModel
39 | from sg2im.utils import int_tuple, float_tuple, str_tuple
40 | from sg2im.utils import timeit, bool_flag, LossManager
41 |
42 | torch.backends.cudnn.benchmark = True
43 |
44 | VG_DIR = os.path.expanduser('datasets/vg')
45 | COCO_DIR = os.path.expanduser('datasets/coco')
46 |
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--dataset', default='coco', choices=['vg', 'coco'])
49 |
50 | # Optimization hyperparameters
51 | parser.add_argument('--batch_size', default=32, type=int)
52 | parser.add_argument('--num_iterations', default=1000000, type=int)
53 | parser.add_argument('--learning_rate', default=1e-4, type=float)
54 |
55 | # Switch the generator to eval mode after this many iterations
56 | parser.add_argument('--eval_mode_after', default=100000, type=int)
57 |
58 | # Dataset options common to both VG and COCO
59 | parser.add_argument('--image_size', default='64,64', type=int_tuple)
60 | parser.add_argument('--num_train_samples', default=None, type=int)
61 | parser.add_argument('--num_val_samples', default=1024, type=int)
62 | parser.add_argument('--shuffle_val', default=True, type=bool_flag)
63 | parser.add_argument('--loader_num_workers', default=4, type=int)
64 | parser.add_argument('--include_relationships', default=True, type=bool_flag)
65 |
66 | # VG-specific options
67 | parser.add_argument('--vg_image_dir', default=os.path.join(VG_DIR, 'images'))
68 | parser.add_argument('--train_h5', default=os.path.join(VG_DIR, 'train.h5'))
69 | parser.add_argument('--val_h5', default=os.path.join(VG_DIR, 'val.h5'))
70 | parser.add_argument('--vocab_json', default=os.path.join(VG_DIR, 'vocab.json'))
71 | parser.add_argument('--max_objects_per_image', default=10, type=int)
72 | parser.add_argument('--vg_use_orphaned_objects', default=True, type=bool_flag)
73 |
74 | # COCO-specific options
75 | parser.add_argument('--coco_train_image_dir',
76 | default=os.path.join(COCO_DIR, 'images/train2017'))
77 | parser.add_argument('--coco_val_image_dir',
78 | default=os.path.join(COCO_DIR, 'images/val2017'))
79 | parser.add_argument('--coco_train_instances_json',
80 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json'))
81 | parser.add_argument('--coco_train_stuff_json',
82 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json'))
83 | parser.add_argument('--coco_val_instances_json',
84 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json'))
85 | parser.add_argument('--coco_val_stuff_json',
86 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json'))
87 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple)
88 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple)
89 | parser.add_argument('--coco_include_other', default=False, type=bool_flag)
90 | parser.add_argument('--min_object_size', default=0.02, type=float)
91 | parser.add_argument('--min_objects_per_image', default=3, type=int)
92 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag)
93 |
94 | # Generator options
95 | parser.add_argument('--mask_size', default=16, type=int) # Set this to 0 to use no masks
96 | parser.add_argument('--embedding_dim', default=128, type=int)
97 | parser.add_argument('--gconv_dim', default=128, type=int)
98 | parser.add_argument('--gconv_hidden_dim', default=512, type=int)
99 | parser.add_argument('--gconv_num_layers', default=5, type=int)
100 | parser.add_argument('--mlp_normalization', default='none', type=str)
101 | parser.add_argument('--refinement_network_dims', default='1024,512,256,128,64', type=int_tuple)
102 | parser.add_argument('--normalization', default='batch')
103 | parser.add_argument('--activation', default='leakyrelu-0.2')
104 | parser.add_argument('--layout_noise_dim', default=32, type=int)
105 | parser.add_argument('--use_boxes_pred_after', default=-1, type=int)
106 |
107 | # Generator losses
108 | parser.add_argument('--mask_loss_weight', default=0, type=float)
109 | parser.add_argument('--l1_pixel_loss_weight', default=1.0, type=float)
110 | parser.add_argument('--bbox_pred_loss_weight', default=10, type=float)
111 | parser.add_argument('--predicate_pred_loss_weight', default=0, type=float) # DEPRECATED
112 |
113 | # Generic discriminator options
114 | parser.add_argument('--discriminator_loss_weight', default=0.01, type=float)
115 | parser.add_argument('--gan_loss_type', default='gan')
116 | parser.add_argument('--d_clip', default=None, type=float)
117 | parser.add_argument('--d_normalization', default='batch')
118 | parser.add_argument('--d_padding', default='valid')
119 | parser.add_argument('--d_activation', default='leakyrelu-0.2')
120 |
121 | # Object discriminator
122 | parser.add_argument('--d_obj_arch',
123 | default='C4-64-2,C4-128-2,C4-256-2')
124 | parser.add_argument('--crop_size', default=32, type=int)
125 | parser.add_argument('--d_obj_weight', default=1.0, type=float) # multiplied by d_loss_weight
126 | parser.add_argument('--ac_loss_weight', default=0.1, type=float)
127 |
128 | # Image discriminator
129 | parser.add_argument('--d_img_arch',
130 | default='C4-64-2,C4-128-2,C4-256-2')
131 | parser.add_argument('--d_img_weight', default=1.0, type=float) # multiplied by d_loss_weight
132 |
133 | # Output options
134 | parser.add_argument('--print_every', default=10, type=int)
135 | parser.add_argument('--timing', default=False, type=bool_flag)
136 | parser.add_argument('--checkpoint_every', default=10000, type=int)
137 | parser.add_argument('--output_dir', default=os.getcwd())
138 | parser.add_argument('--checkpoint_name', default='checkpoint')
139 | parser.add_argument('--checkpoint_start_from', default=None)
140 | parser.add_argument('--restore_from_checkpoint', default=False, type=bool_flag)
141 |
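   | # Illustrative training invocations (values shown simply restate some of the parser
   | # defaults; dataset locations follow scripts/download_coco.sh and download_vg.sh, and
   | # a CUDA-capable GPU is assumed since the code uses torch.cuda tensor types):
   | #
   | #   python scripts/train.py --dataset coco --batch_size 32 --output_dir outputs/coco64
   | #   python scripts/train.py --dataset vg --image_size 64,64 --output_dir outputs/vg64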
142 |
143 | def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1):
144 | curr_loss = curr_loss * weight
145 | loss_dict[loss_name] = curr_loss.item()
146 | if total_loss is not None:
147 | total_loss += curr_loss
148 | else:
149 | total_loss = curr_loss
150 | return total_loss
151 |
152 |
153 | def check_args(args):
154 | H, W = args.image_size
155 | for _ in args.refinement_network_dims[1:]:
156 | H = H // 2
157 | if H == 0:
158 | raise ValueError("Too many layers in refinement network")
159 |
160 |
161 | def build_model(args, vocab):
162 | if args.checkpoint_start_from is not None:
163 | checkpoint = torch.load(args.checkpoint_start_from)
164 | kwargs = checkpoint['model_kwargs']
165 | model = Sg2ImModel(**kwargs)
166 | raw_state_dict = checkpoint['model_state']
167 | state_dict = {}
168 | for k, v in raw_state_dict.items():
169 | if k.startswith('module.'):
170 | k = k[7:]
171 | state_dict[k] = v
172 | model.load_state_dict(state_dict)
173 | else:
174 | kwargs = {
175 | 'vocab': vocab,
176 | 'image_size': args.image_size,
177 | 'embedding_dim': args.embedding_dim,
178 | 'gconv_dim': args.gconv_dim,
179 | 'gconv_hidden_dim': args.gconv_hidden_dim,
180 | 'gconv_num_layers': args.gconv_num_layers,
181 | 'mlp_normalization': args.mlp_normalization,
182 | 'refinement_dims': args.refinement_network_dims,
183 | 'normalization': args.normalization,
184 | 'activation': args.activation,
185 | 'mask_size': args.mask_size,
186 | 'layout_noise_dim': args.layout_noise_dim,
187 | }
188 | model = Sg2ImModel(**kwargs)
189 | return model, kwargs
190 |
191 |
192 | def build_obj_discriminator(args, vocab):
193 | discriminator = None
194 | d_kwargs = {}
195 | d_weight = args.discriminator_loss_weight
196 | d_obj_weight = args.d_obj_weight
197 | if d_weight == 0 or d_obj_weight == 0:
198 | return discriminator, d_kwargs
199 |
200 | d_kwargs = {
201 | 'vocab': vocab,
202 | 'arch': args.d_obj_arch,
203 | 'normalization': args.d_normalization,
204 | 'activation': args.d_activation,
205 | 'padding': args.d_padding,
206 | 'object_size': args.crop_size,
207 | }
208 | discriminator = AcCropDiscriminator(**d_kwargs)
209 | return discriminator, d_kwargs
210 |
211 |
212 | def build_img_discriminator(args, vocab):
213 | discriminator = None
214 | d_kwargs = {}
215 | d_weight = args.discriminator_loss_weight
216 | d_img_weight = args.d_img_weight
217 | if d_weight == 0 or d_img_weight == 0:
218 | return discriminator, d_kwargs
219 |
220 | d_kwargs = {
221 | 'arch': args.d_img_arch,
222 | 'normalization': args.d_normalization,
223 | 'activation': args.d_activation,
224 | 'padding': args.d_padding,
225 | }
226 | discriminator = PatchDiscriminator(**d_kwargs)
227 | return discriminator, d_kwargs
228 |
229 |
230 | def build_coco_dsets(args):
231 | dset_kwargs = {
232 | 'image_dir': args.coco_train_image_dir,
233 | 'instances_json': args.coco_train_instances_json,
234 | 'stuff_json': args.coco_train_stuff_json,
235 | 'stuff_only': args.coco_stuff_only,
236 | 'image_size': args.image_size,
237 | 'mask_size': args.mask_size,
238 | 'max_samples': args.num_train_samples,
239 | 'min_object_size': args.min_object_size,
240 | 'min_objects_per_image': args.min_objects_per_image,
241 | 'instance_whitelist': args.instance_whitelist,
242 | 'stuff_whitelist': args.stuff_whitelist,
243 | 'include_other': args.coco_include_other,
244 | 'include_relationships': args.include_relationships,
245 | }
246 | train_dset = CocoSceneGraphDataset(**dset_kwargs)
247 | num_objs = train_dset.total_objects()
248 | num_imgs = len(train_dset)
249 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs))
250 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs))
251 |
252 | dset_kwargs['image_dir'] = args.coco_val_image_dir
253 | dset_kwargs['instances_json'] = args.coco_val_instances_json
254 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json
255 | dset_kwargs['max_samples'] = args.num_val_samples
256 | val_dset = CocoSceneGraphDataset(**dset_kwargs)
257 |
258 | assert train_dset.vocab == val_dset.vocab
259 | vocab = json.loads(json.dumps(train_dset.vocab))
260 |
261 | return vocab, train_dset, val_dset
262 |
263 |
264 | def build_vg_dsets(args):
265 | with open(args.vocab_json, 'r') as f:
266 | vocab = json.load(f)
267 | dset_kwargs = {
268 | 'vocab': vocab,
269 | 'h5_path': args.train_h5,
270 | 'image_dir': args.vg_image_dir,
271 | 'image_size': args.image_size,
272 | 'max_samples': args.num_train_samples,
273 | 'max_objects': args.max_objects_per_image,
274 | 'use_orphaned_objects': args.vg_use_orphaned_objects,
275 | 'include_relationships': args.include_relationships,
276 | }
277 | train_dset = VgSceneGraphDataset(**dset_kwargs)
278 | iter_per_epoch = len(train_dset) // args.batch_size
279 | print('There are %d iterations per epoch' % iter_per_epoch)
280 |
281 | dset_kwargs['h5_path'] = args.val_h5
282 | del dset_kwargs['max_samples']
283 | val_dset = VgSceneGraphDataset(**dset_kwargs)
284 |
285 | return vocab, train_dset, val_dset
286 |
287 |
288 | def build_loaders(args):
289 | if args.dataset == 'vg':
290 | vocab, train_dset, val_dset = build_vg_dsets(args)
291 | collate_fn = vg_collate_fn
292 | elif args.dataset == 'coco':
293 | vocab, train_dset, val_dset = build_coco_dsets(args)
294 | collate_fn = coco_collate_fn
295 |
296 | loader_kwargs = {
297 | 'batch_size': args.batch_size,
298 | 'num_workers': args.loader_num_workers,
299 | 'shuffle': True,
300 | 'collate_fn': collate_fn,
301 | }
302 | train_loader = DataLoader(train_dset, **loader_kwargs)
303 |
304 | loader_kwargs['shuffle'] = args.shuffle_val
305 | val_loader = DataLoader(val_dset, **loader_kwargs)
306 | return vocab, train_loader, val_loader
307 |
308 |
309 | def check_model(args, t, loader, model):
310 | float_dtype = torch.cuda.FloatTensor
311 | long_dtype = torch.cuda.LongTensor
312 | num_samples = 0
313 | all_losses = defaultdict(list)
314 | total_iou = 0
315 | total_boxes = 0
316 | with torch.no_grad():
317 | for batch in loader:
318 | batch = [tensor.cuda() for tensor in batch]
319 | masks = None
320 | if len(batch) == 6:
321 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
322 | elif len(batch) == 7:
323 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
324 | predicates = triples[:, 1]
325 |
326 | # Run the model as it has been run during training
327 | model_masks = masks
328 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=model_masks)
329 | imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
330 |
331 | skip_pixel_loss = False
332 | total_loss, losses = calculate_model_losses(
333 | args, skip_pixel_loss, model, imgs, imgs_pred,
334 | boxes, boxes_pred, masks, masks_pred,
335 | predicates, predicate_scores)
336 |
337 | total_iou += jaccard(boxes_pred, boxes)
338 | total_boxes += boxes_pred.size(0)
339 |
340 | for loss_name, loss_val in losses.items():
341 | all_losses[loss_name].append(loss_val)
342 | num_samples += imgs.size(0)
343 | if num_samples >= args.num_val_samples:
344 | break
345 |
346 | samples = {}
347 | samples['gt_img'] = imgs
348 |
349 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks)
350 | samples['gt_box_gt_mask'] = model_out[0]
351 |
352 | model_out = model(objs, triples, obj_to_img, boxes_gt=boxes)
353 | samples['gt_box_pred_mask'] = model_out[0]
354 |
355 | model_out = model(objs, triples, obj_to_img)
356 | samples['pred_box_pred_mask'] = model_out[0]
357 |
358 | for k, v in samples.items():
359 | samples[k] = imagenet_deprocess_batch(v)
360 |
361 | mean_losses = {k: np.mean(v) for k, v in all_losses.items()}
362 | avg_iou = total_iou / total_boxes
363 |
364 | masks_to_store = masks
365 | if masks_to_store is not None:
366 | masks_to_store = masks_to_store.data.cpu().clone()
367 |
368 | masks_pred_to_store = masks_pred
369 | if masks_pred_to_store is not None:
370 | masks_pred_to_store = masks_pred_to_store.data.cpu().clone()
371 |
372 | batch_data = {
373 | 'objs': objs.detach().cpu().clone(),
374 | 'boxes_gt': boxes.detach().cpu().clone(),
375 | 'masks_gt': masks_to_store,
376 | 'triples': triples.detach().cpu().clone(),
377 | 'obj_to_img': obj_to_img.detach().cpu().clone(),
378 | 'triple_to_img': triple_to_img.detach().cpu().clone(),
379 | 'boxes_pred': boxes_pred.detach().cpu().clone(),
380 | 'masks_pred': masks_pred_to_store
381 | }
382 | out = [mean_losses, samples, batch_data, avg_iou]
383 |
384 | return tuple(out)
385 |
386 |
387 | def calculate_model_losses(args, skip_pixel_loss, model, img, img_pred,
388 | bbox, bbox_pred, masks, masks_pred,
389 | predicates, predicate_scores):
390 | total_loss = torch.zeros(1).to(img)
391 | losses = {}
392 |
393 | l1_pixel_weight = args.l1_pixel_loss_weight
394 | if skip_pixel_loss:
395 | l1_pixel_weight = 0
396 | l1_pixel_loss = F.l1_loss(img_pred, img)
397 | total_loss = add_loss(total_loss, l1_pixel_loss, losses, 'L1_pixel_loss',
398 | l1_pixel_weight)
399 | loss_bbox = F.mse_loss(bbox_pred, bbox)
400 | total_loss = add_loss(total_loss, loss_bbox, losses, 'bbox_pred',
401 | args.bbox_pred_loss_weight)
402 |
403 | if args.predicate_pred_loss_weight > 0:
404 | loss_predicate = F.cross_entropy(predicate_scores, predicates)
405 | total_loss = add_loss(total_loss, loss_predicate, losses, 'predicate_pred',
406 | args.predicate_pred_loss_weight)
407 |
408 | if args.mask_loss_weight > 0 and masks is not None and masks_pred is not None:
409 | mask_loss = F.binary_cross_entropy(masks_pred, masks.float())
410 | total_loss = add_loss(total_loss, mask_loss, losses, 'mask_loss',
411 | args.mask_loss_weight)
412 | return total_loss, losses
413 |
414 |
415 | def main(args):
416 | print(args)
417 | check_args(args)
418 | float_dtype = torch.cuda.FloatTensor
419 | long_dtype = torch.cuda.LongTensor
420 |
421 | vocab, train_loader, val_loader = build_loaders(args)
422 | model, model_kwargs = build_model(args, vocab)
423 | model.type(float_dtype)
424 | print(model)
425 |
426 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
427 |
428 | obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
429 | img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
430 | gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)
431 |
432 | if obj_discriminator is not None:
433 | obj_discriminator.type(float_dtype)
434 | obj_discriminator.train()
435 | print(obj_discriminator)
436 | optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
437 | lr=args.learning_rate)
438 |
439 | if img_discriminator is not None:
440 | img_discriminator.type(float_dtype)
441 | img_discriminator.train()
442 | print(img_discriminator)
443 | optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
444 | lr=args.learning_rate)
445 |
446 | restore_path = None
447 | if args.restore_from_checkpoint:
448 | restore_path = '%s_with_model.pt' % args.checkpoint_name
449 | restore_path = os.path.join(args.output_dir, restore_path)
450 | if restore_path is not None and os.path.isfile(restore_path):
451 | print('Restoring from checkpoint:')
452 | print(restore_path)
453 | checkpoint = torch.load(restore_path)
454 | model.load_state_dict(checkpoint['model_state'])
455 | optimizer.load_state_dict(checkpoint['optim_state'])
456 |
457 | if obj_discriminator is not None:
458 | obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
459 | optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])
460 |
461 | if img_discriminator is not None:
462 | img_discriminator.load_state_dict(checkpoint['d_img_state'])
463 | optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])
464 |
465 | t = checkpoint['counters']['t']
466 | if 0 <= args.eval_mode_after <= t:
467 | model.eval()
468 | else:
469 | model.train()
470 | epoch = checkpoint['counters']['epoch']
471 | else:
472 | t, epoch = 0, 0
473 | checkpoint = {
474 | 'args': args.__dict__,
475 | 'vocab': vocab,
476 | 'model_kwargs': model_kwargs,
477 | 'd_obj_kwargs': d_obj_kwargs,
478 | 'd_img_kwargs': d_img_kwargs,
479 | 'losses_ts': [],
480 | 'losses': defaultdict(list),
481 | 'd_losses': defaultdict(list),
482 | 'checkpoint_ts': [],
483 | 'train_batch_data': [],
484 | 'train_samples': [],
485 | 'train_iou': [],
486 | 'val_batch_data': [],
487 | 'val_samples': [],
488 | 'val_losses': defaultdict(list),
489 | 'val_iou': [],
490 | 'norm_d': [],
491 | 'norm_g': [],
492 | 'counters': {
493 | 't': None,
494 | 'epoch': None,
495 | },
496 | 'model_state': None, 'model_best_state': None, 'optim_state': None,
497 | 'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
498 | 'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
499 | 'best_t': [],
500 | }
501 |
502 | while True:
503 | if t >= args.num_iterations:
504 | break
505 | epoch += 1
506 | print('Starting epoch %d' % epoch)
507 |
508 | for batch in train_loader:
509 | if t == args.eval_mode_after:
510 | print('switching to eval mode')
511 | model.eval()
512 |         optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
513 | t += 1
514 | batch = [tensor.cuda() for tensor in batch]
515 | masks = None
516 | if len(batch) == 6:
517 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
518 | elif len(batch) == 7:
519 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
520 | else:
521 | assert False
522 | predicates = triples[:, 1]
523 |
524 | with timeit('forward', args.timing):
525 | model_boxes = boxes
526 | model_masks = masks
527 | model_out = model(objs, triples, obj_to_img,
528 | boxes_gt=model_boxes, masks_gt=model_masks)
529 | imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
530 | with timeit('loss', args.timing):
531 | # Skip the pixel loss if using GT boxes
532 | skip_pixel_loss = (model_boxes is None)
533 | total_loss, losses = calculate_model_losses(
534 | args, skip_pixel_loss, model, imgs, imgs_pred,
535 | boxes, boxes_pred, masks, masks_pred,
536 | predicates, predicate_scores)
537 |
538 | if obj_discriminator is not None:
539 | scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes, obj_to_img)
540 | total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
541 | args.ac_loss_weight)
542 | weight = args.discriminator_loss_weight * args.d_obj_weight
543 | total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
544 | 'g_gan_obj_loss', weight)
545 |
546 | if img_discriminator is not None:
547 | scores_fake = img_discriminator(imgs_pred)
548 | weight = args.discriminator_loss_weight * args.d_img_weight
549 | total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
550 | 'g_gan_img_loss', weight)
551 |
552 | losses['total_loss'] = total_loss.item()
553 | if not math.isfinite(losses['total_loss']):
554 |         print('WARNING: Got non-finite loss, not backpropping')
555 | continue
556 |
557 | optimizer.zero_grad()
558 | with timeit('backward', args.timing):
559 | total_loss.backward()
560 | optimizer.step()
561 | total_loss_d = None
562 | ac_loss_real = None
563 | ac_loss_fake = None
564 | d_losses = {}
565 |
566 | if obj_discriminator is not None:
567 | d_obj_losses = LossManager()
568 | imgs_fake = imgs_pred.detach()
569 | scores_fake, ac_loss_fake = obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
570 | scores_real, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img)
571 |
572 | d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
573 | d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
574 | d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
575 | d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')
576 |
577 | optimizer_d_obj.zero_grad()
578 | d_obj_losses.total_loss.backward()
579 | optimizer_d_obj.step()
580 |
581 | if img_discriminator is not None:
582 | d_img_losses = LossManager()
583 | imgs_fake = imgs_pred.detach()
584 | scores_fake = img_discriminator(imgs_fake)
585 | scores_real = img_discriminator(imgs)
586 |
587 | d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
588 | d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
589 |
590 | optimizer_d_img.zero_grad()
591 | d_img_losses.total_loss.backward()
592 | optimizer_d_img.step()
593 |
594 | if t % args.print_every == 0:
595 | print('t = %d / %d' % (t, args.num_iterations))
596 | for name, val in losses.items():
597 | print(' G [%s]: %.4f' % (name, val))
598 | checkpoint['losses'][name].append(val)
599 | checkpoint['losses_ts'].append(t)
600 |
601 | if obj_discriminator is not None:
602 | for name, val in d_obj_losses.items():
603 | print(' D_obj [%s]: %.4f' % (name, val))
604 | checkpoint['d_losses'][name].append(val)
605 |
606 | if img_discriminator is not None:
607 | for name, val in d_img_losses.items():
608 | print(' D_img [%s]: %.4f' % (name, val))
609 | checkpoint['d_losses'][name].append(val)
610 |
611 | if t % args.checkpoint_every == 0:
612 | print('checking on train')
613 | train_results = check_model(args, t, train_loader, model)
614 | t_losses, t_samples, t_batch_data, t_avg_iou = train_results
615 |
616 | checkpoint['train_batch_data'].append(t_batch_data)
617 | checkpoint['train_samples'].append(t_samples)
618 | checkpoint['checkpoint_ts'].append(t)
619 | checkpoint['train_iou'].append(t_avg_iou)
620 |
621 | print('checking on val')
622 | val_results = check_model(args, t, val_loader, model)
623 | val_losses, val_samples, val_batch_data, val_avg_iou = val_results
624 | checkpoint['val_samples'].append(val_samples)
625 | checkpoint['val_batch_data'].append(val_batch_data)
626 | checkpoint['val_iou'].append(val_avg_iou)
627 |
628 | print('train iou: ', t_avg_iou)
629 | print('val iou: ', val_avg_iou)
630 |
631 | for k, v in val_losses.items():
632 | checkpoint['val_losses'][k].append(v)
633 | checkpoint['model_state'] = model.state_dict()
634 |
635 | if obj_discriminator is not None:
636 | checkpoint['d_obj_state'] = obj_discriminator.state_dict()
637 | checkpoint['d_obj_optim_state'] = optimizer_d_obj.state_dict()
638 |
639 | if img_discriminator is not None:
640 | checkpoint['d_img_state'] = img_discriminator.state_dict()
641 | checkpoint['d_img_optim_state'] = optimizer_d_img.state_dict()
642 |
643 | checkpoint['optim_state'] = optimizer.state_dict()
644 | checkpoint['counters']['t'] = t
645 | checkpoint['counters']['epoch'] = epoch
646 | checkpoint_path = os.path.join(args.output_dir,
647 | '%s_with_model.pt' % args.checkpoint_name)
648 | print('Saving checkpoint to ', checkpoint_path)
649 | torch.save(checkpoint, checkpoint_path)
650 |
651 | # Save another checkpoint without any model or optim state
652 | checkpoint_path = os.path.join(args.output_dir,
653 | '%s_no_model.pt' % args.checkpoint_name)
654 | key_blacklist = ['model_state', 'optim_state', 'model_best_state',
655 | 'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
656 | 'd_img_state', 'd_img_optim_state', 'd_img_best_state']
657 | small_checkpoint = {}
658 | for k, v in checkpoint.items():
659 | if k not in key_blacklist:
660 | small_checkpoint[k] = v
661 | torch.save(small_checkpoint, checkpoint_path)
662 |
663 |
664 | if __name__ == '__main__':
665 | args = parser.parse_args()
666 | main(args)
667 |
668 |
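
A minimal sketch (not part of the script above) of how the lightweight checkpoint written by the loop above, the one saved without model or optimizer state, could be inspected afterwards. It assumes only the keys that this script populates; the path is a placeholder for whatever args.output_dir and args.checkpoint_name were used.

# Minimal sketch (not in the repo): inspect the "<checkpoint_name>_no_model.pt" file
# saved above. The path is a placeholder; substitute your own output_dir / name.
import torch

ckpt = torch.load('outputs/checkpoint_no_model.pt', map_location='cpu')
print('trained for %d iterations over %d epochs' %
      (ckpt['counters']['t'], ckpt['counters']['epoch']))
for name, values in ckpt['losses'].items():
  print('G loss %s: %d logged values, last = %.4f' % (name, len(values), values[-1]))
print('validation IoU history:', ckpt['val_iou'])
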
--------------------------------------------------------------------------------
/sg2im/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/sg2im/bilinear.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | from sg2im.utils import timeit
20 |
21 |
22 | """
23 | Functions for performing differentiable bilinear cropping of images, for use in
24 | the object discriminator
25 | """
26 |
27 |
28 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'):
29 | """
30 | Inputs:
31 | - feats: FloatTensor of shape (N, C, H, W)
32 | - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates
33 | - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps;
34 | each element is in the range [0, N) and bbox_to_feats[b] = i means that
35 | bbox[b] will be cropped from feats[i].
36 | - HH, WW: Size of the output crops
37 |
38 | Returns:
39 | - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to
40 | crop from feats[bbox_to_feats[i]].
41 | """
42 | if backend == 'cudnn':
43 | return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW)
44 | N, C, H, W = feats.size()
45 | B = bbox.size(0)
46 | if WW is None: WW = HH
47 | dtype, device = feats.dtype, feats.device
48 | crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device)
49 | for i in range(N):
50 | idx = (bbox_to_feats.data == i).nonzero()
51 |     if idx.numel() == 0:
52 | continue
53 | idx = idx.view(-1)
54 | n = idx.size(0)
55 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
56 | cur_bbox = bbox[idx]
57 | cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW)
58 | crops[idx] = cur_crops
59 | return crops
60 |
61 |
62 | def _invperm(p):
63 | N = p.size(0)
64 | eye = torch.arange(0, N).type_as(p)
65 | pp = (eye[:, None] == p).nonzero()[:, 1]
66 | return pp
67 |
68 |
69 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None):
70 | N, C, H, W = feats.size()
71 | B = bbox.size(0)
72 | if WW is None: WW = HH
73 | dtype = feats.data.type()
74 |
75 | feats_flat, bbox_flat, all_idx = [], [], []
76 | for i in range(N):
77 | idx = (bbox_to_feats.data == i).nonzero()
78 |     if idx.numel() == 0:
79 | continue
80 | idx = idx.view(-1)
81 | n = idx.size(0)
82 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
83 | cur_bbox = bbox[idx]
84 |
85 | feats_flat.append(cur_feats)
86 | bbox_flat.append(cur_bbox)
87 | all_idx.append(idx)
88 |
89 | feats_flat = torch.cat(feats_flat, dim=0)
90 | bbox_flat = torch.cat(bbox_flat, dim=0)
91 | crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn')
92 |
93 | # If the crops were sequential (all_idx is identity permutation) then we can
94 | # simply return them; otherwise we need to permute crops by the inverse
95 | # permutation from all_idx.
96 | all_idx = torch.cat(all_idx, dim=0)
97 | eye = torch.arange(0, B).type_as(all_idx)
98 | if (all_idx == eye).all():
99 | return crops
100 | return crops[_invperm(all_idx)]
101 |
102 |
103 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'):
104 | """
105 | Take differentiable crops of feats specified by bbox.
106 |
107 | Inputs:
108 | - feats: Tensor of shape (N, C, H, W)
109 | - bbox: Bounding box coordinates of shape (N, 4) in the format
110 | [x0, y0, x1, y1] in the [0, 1] coordinate space.
111 | - HH, WW: Size of the output crops.
112 |
113 | Returns:
114 | - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of
115 | feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling.
116 | """
117 | N = feats.size(0)
118 | assert bbox.size(0) == N
119 | assert bbox.size(1) == 4
120 | if WW is None: WW = HH
121 | if backend == 'cudnn':
122 | # Change box from [0, 1] to [-1, 1] coordinate system
123 | bbox = 2 * bbox - 1
124 | x0, y0 = bbox[:, 0], bbox[:, 1]
125 | x1, y1 = bbox[:, 2], bbox[:, 3]
126 | X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW)
127 | Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW)
128 | if backend == 'jj':
129 | return bilinear_sample(feats, X, Y)
130 | elif backend == 'cudnn':
131 | grid = torch.stack([X, Y], dim=3)
132 | return F.grid_sample(feats, grid)
133 |
134 |
135 |
136 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0):
137 | """
138 | Inverse operation to crop_bbox; construct output images where the feature maps
139 | from feats have been reshaped and placed into the positions specified by bbox.
140 |
141 | Inputs:
142 | - feats: Tensor of shape (N, C, HH, WW)
143 | - bbox: Bounding box coordinates of shape (N, 4) in the format
144 | [x0, y0, x1, y1] in the [0, 1] coordinate space.
145 | - H, W: Size of output.
146 | - fill_value: Portions of the output image that are outside the bounding box
147 | will be filled with this value.
148 |
149 | Returns:
150 | - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by
151 | bbox[i] contains feats[i], reshaped using bilinear sampling.
152 | """
153 | N, C = feats.size(0), feats.size(1)
154 | assert bbox.size(0) == N
155 | assert bbox.size(1) == 4
156 |   if W is None: W = H
157 |
158 | x0, y0 = bbox[:, 0], bbox[:, 1]
159 | x1, y1 = bbox[:, 2], bbox[:, 3]
160 | ww = x1 - x0
161 | hh = y1 - y0
162 |
163 | x0 = x0.contiguous().view(N, 1).expand(N, H)
164 | x1 = x1.contiguous().view(N, 1).expand(N, H)
165 | ww = ww.view(N, 1).expand(N, H)
166 |
167 | y0 = y0.contiguous().view(N, 1).expand(N, W)
168 | y1 = y1.contiguous().view(N, 1).expand(N, W)
169 | hh = hh.view(N, 1).expand(N, W)
170 |
171 | X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats)
172 | Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats)
173 |
174 | X = (X - x0) / ww
175 | Y = (Y - y0) / hh
176 |
177 | # For ByteTensors, (x + y).clamp(max=1) gives logical_or
178 | X_out_mask = ((X < 0) + (X > 1)).view(N, 1, W).expand(N, H, W)
179 | Y_out_mask = ((Y < 0) + (Y > 1)).view(N, H, 1).expand(N, H, W)
180 | out_mask = (X_out_mask + Y_out_mask).clamp(max=1)
181 | out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W)
182 |
183 | X = X.view(N, 1, W).expand(N, H, W)
184 | Y = Y.view(N, H, 1).expand(N, H, W)
185 |
186 | out = bilinear_sample(feats, X, Y)
187 | out[out_mask] = fill_value
188 | return out
189 |
190 |
191 | def bilinear_sample(feats, X, Y):
192 | """
193 | Perform bilinear sampling on the features in feats using the sampling grid
194 | given by X and Y.
195 |
196 | Inputs:
197 | - feats: Tensor holding input feature map, of shape (N, C, H, W)
198 | - X, Y: Tensors holding x and y coordinates of the sampling
199 |     grids; both have shape (N, HH, WW) and have elements in the range [0, 1].
200 |   Returns:
201 |   - out: Tensor of shape (N, C, HH, WW) where out[i] is computed
202 |     by sampling from feats[i] using the sampling grid (X[i], Y[i]).
203 | """
204 | N, C, H, W = feats.size()
205 | assert X.size() == Y.size()
206 | assert X.size(0) == N
207 | _, HH, WW = X.size()
208 |
209 | X = X.mul(W)
210 | Y = Y.mul(H)
211 |
212 | # Get the x and y coordinates for the four samples
213 | x0 = X.floor().clamp(min=0, max=W-1)
214 | x1 = (x0 + 1).clamp(min=0, max=W-1)
215 | y0 = Y.floor().clamp(min=0, max=H-1)
216 | y1 = (y0 + 1).clamp(min=0, max=H-1)
217 |
218 | # In numpy we could do something like feats[i, :, y0, x0] to pull out
219 | # the elements of feats at coordinates y0 and x0, but PyTorch doesn't
220 | # yet support this style of indexing. Instead we have to use the gather
221 | # method, which only allows us to index along one dimension at a time;
222 | # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W)
223 | # and index along the last dimension. Below we generate linear indices into
224 | # the collapsed last dimension for each of the four combinations we need.
225 | y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
226 | y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
227 | y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
228 | y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
229 |
230 | # Actually use gather to pull out the values from feats corresponding
231 | # to our four samples, then reshape them to (BB, C, HH, WW)
232 | feats_flat = feats.view(N, C, H * W)
233 | v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW)
234 | v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW)
235 | v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW)
236 | v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW)
237 |
238 | # Compute the weights for the four samples
239 | w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
240 | w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
241 | w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
242 | w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
243 |
244 | # Multiply the samples by the weights to give our interpolated results.
245 | out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
246 | return out
247 |
248 |
249 | def tensor_linspace(start, end, steps=10):
250 | """
251 | Vectorized version of torch.linspace.
252 |
253 | Inputs:
254 | - start: Tensor of any shape
255 | - end: Tensor of the same shape as start
256 | - steps: Integer
257 |
258 | Returns:
259 | - out: Tensor of shape start.size() + (steps,), such that
260 | out.select(-1, 0) == start, out.select(-1, -1) == end,
261 | and the other elements of out linearly interpolate between
262 | start and end.
263 | """
264 | assert start.size() == end.size()
265 | view_size = start.size() + (1,)
266 | w_size = (1,) * start.dim() + (steps,)
267 | out_size = start.size() + (steps,)
268 |
269 | start_w = torch.linspace(1, 0, steps=steps).to(start)
270 | start_w = start_w.view(w_size).expand(out_size)
271 | end_w = torch.linspace(0, 1, steps=steps).to(start)
272 | end_w = end_w.view(w_size).expand(out_size)
273 |
274 | start = start.contiguous().view(view_size).expand(out_size)
275 | end = end.contiguous().view(view_size).expand(out_size)
276 |
277 | out = start_w * start + end_w * end
278 | return out
279 |
280 |
281 | if __name__ == '__main__':
282 | import numpy as np
283 | from scipy.misc import imread, imsave, imresize
284 |
285 | cat = imresize(imread('cat.jpg'), (256, 256))
286 | dog = imresize(imread('dog.jpg'), (256, 256))
287 | feats = torch.stack([
288 | torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)),
289 | torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))],
290 | dim=0)
291 |
292 | boxes = torch.FloatTensor([
293 | [0, 0, 1, 1],
294 | [0.25, 0.25, 0.75, 0.75],
295 | [0, 0, 0.5, 0.5],
296 | ])
297 |
298 | box_to_feats = torch.LongTensor([1, 0, 1]).cuda()
299 |
300 | feats, boxes = feats.cuda(), boxes.cuda()
301 | crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128)
302 | for i in range(crops.size(0)):
303 | crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
304 | imsave('out%d.png' % i, crop_np)
305 |
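
A minimal usage sketch (not part of the file above) showing the tensor shapes that crop_bbox and tensor_linspace expect; the dummy tensors are purely illustrative, and since the default 'cudnn' backend just calls F.grid_sample, this runs on CPU as well.

# Minimal sketch (not in the repo): exercise crop_bbox and tensor_linspace on
# small dummy tensors. Shapes follow the docstrings above.
import torch
from sg2im.bilinear import crop_bbox, tensor_linspace

feats = torch.randn(2, 3, 64, 64)                # (N, C, H, W)
bbox = torch.tensor([[0.00, 0.00, 0.50, 0.50],   # (N, 4) boxes [x0, y0, x1, y1] in [0, 1]
                     [0.25, 0.25, 1.00, 1.00]])
crops = crop_bbox(feats, bbox, HH=32)            # bilinear crops of shape (2, 3, 32, 32)
print(crops.shape)

start, end = torch.zeros(2, 4), torch.ones(2, 4)
print(tensor_linspace(start, end, steps=5).shape)  # start.size() + (steps,) = (2, 4, 5)
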
--------------------------------------------------------------------------------
/sg2im/box_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 |
19 | """
20 | Utilities for dealing with bounding boxes
21 | """
22 |
23 |
24 | def apply_box_transform(anchors, transforms):
25 | """
26 | Apply box transforms to a set of anchor boxes.
27 |
28 | Inputs:
29 | - anchors: Anchor boxes of shape (N, 4), where each anchor is specified
30 | in the form [xc, yc, w, h]
31 | - transforms: Box transforms of shape (N, 4) where each transform is
32 | specified as [tx, ty, tw, th]
33 |
34 | Returns:
35 | - boxes: Transformed boxes of shape (N, 4) where each box is in the
36 | format [xc, yc, w, h]
37 | """
38 | # Unpack anchors
39 | xa, ya = anchors[:, 0], anchors[:, 1]
40 | wa, ha = anchors[:, 2], anchors[:, 3]
41 |
42 | # Unpack transforms
43 | tx, ty = transforms[:, 0], transforms[:, 1]
44 | tw, th = transforms[:, 2], transforms[:, 3]
45 |
46 | x = xa + tx * wa
47 | y = ya + ty * ha
48 | w = wa * tw.exp()
49 | h = ha * th.exp()
50 |
51 | boxes = torch.stack([x, y, w, h], dim=1)
52 | return boxes
53 |
54 |
55 | def invert_box_transform(anchors, boxes):
56 | """
57 | Compute the box transform that, when applied to anchors, would give boxes.
58 |
59 | Inputs:
60 | - anchors: Box anchors of shape (N, 4) in the format [xc, yc, w, h]
61 | - boxes: Target boxes of shape (N, 4) in the format [xc, yc, w, h]
62 |
63 | Returns:
64 | - transforms: Box transforms of shape (N, 4) in the format [tx, ty, tw, th]
65 | """
66 | # Unpack anchors
67 | xa, ya = anchors[:, 0], anchors[:, 1]
68 | wa, ha = anchors[:, 2], anchors[:, 3]
69 |
70 | # Unpack boxes
71 | x, y = boxes[:, 0], boxes[:, 1]
72 | w, h = boxes[:, 2], boxes[:, 3]
73 |
74 | tx = (x - xa) / wa
75 | ty = (y - ya) / ha
76 | tw = w.log() - wa.log()
77 | th = h.log() - ha.log()
78 |
79 | transforms = torch.stack([tx, ty, tw, th], dim=1)
80 | return transforms
81 |
82 |
83 | def centers_to_extents(boxes):
84 | """
85 | Convert boxes from [xc, yc, w, h] format to [x0, y0, x1, y1] format
86 |
87 | Input:
88 | - boxes: Input boxes of shape (N, 4) in [xc, yc, w, h] format
89 |
90 | Returns:
91 | - boxes: Output boxes of shape (N, 4) in [x0, y0, x1, y1] format
92 | """
93 | xc, yc = boxes[:, 0], boxes[:, 1]
94 | w, h = boxes[:, 2], boxes[:, 3]
95 |
96 | x0 = xc - w / 2
97 | x1 = x0 + w
98 | y0 = yc - h / 2
99 | y1 = y0 + h
100 |
101 | boxes_out = torch.stack([x0, y0, x1, y1], dim=1)
102 | return boxes_out
103 |
104 |
105 | def extents_to_centers(boxes):
106 | """
107 | Convert boxes from [x0, y0, x1, y1] format to [xc, yc, w, h] format
108 |
109 | Input:
110 | - boxes: Input boxes of shape (N, 4) in [x0, y0, x1, y1] format
111 |
112 | Returns:
113 | - boxes: Output boxes of shape (N, 4) in [xc, yc, w, h] format
114 | """
115 | x0, y0 = boxes[:, 0], boxes[:, 1]
116 | x1, y1 = boxes[:, 2], boxes[:, 3]
117 |
118 | xc = 0.5 * (x0 + x1)
119 | yc = 0.5 * (y0 + y1)
120 | w = x1 - x0
121 | h = y1 - y0
122 |
123 | boxes_out = torch.stack([xc, yc, w, h], dim=1)
124 | return boxes_out
125 |
126 |
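
A minimal round-trip sketch (not part of the file above): the helpers come in inverse pairs, so invert_box_transform followed by apply_box_transform recovers the target boxes, and the two format conversions undo each other. The box values below are arbitrary.

# Minimal sketch (not in the repo): round-trip the box utilities on dummy boxes.
import torch
from sg2im.box_utils import (apply_box_transform, invert_box_transform,
                             centers_to_extents, extents_to_centers)

anchors = torch.tensor([[0.5, 0.5, 0.4, 0.2]])   # [xc, yc, w, h]
boxes = torch.tensor([[0.6, 0.4, 0.2, 0.1]])
t = invert_box_transform(anchors, boxes)         # [tx, ty, tw, th]
print(apply_box_transform(anchors, t))           # ~= boxes

print(extents_to_centers(centers_to_extents(boxes)))  # ~= boxes
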
--------------------------------------------------------------------------------
/sg2im/crn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from sg2im.layers import get_normalization_2d
22 | from sg2im.layers import get_activation
23 | from sg2im.utils import timeit, lineno, get_gpu_memory
24 |
25 |
26 | """
27 | Cascaded refinement network architecture, as described in:
28 |
29 | Qifeng Chen and Vladlen Koltun,
30 | "Photographic Image Synthesis with Cascaded Refinement Networks",
31 | ICCV 2017
32 | """
33 |
34 |
35 | class RefinementModule(nn.Module):
36 | def __init__(self, layout_dim, input_dim, output_dim,
37 | normalization='instance', activation='leakyrelu'):
38 | super(RefinementModule, self).__init__()
39 |
40 | layers = []
41 | layers.append(nn.Conv2d(layout_dim + input_dim, output_dim,
42 | kernel_size=3, padding=1))
43 | layers.append(get_normalization_2d(output_dim, normalization))
44 | layers.append(get_activation(activation))
45 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1))
46 | layers.append(get_normalization_2d(output_dim, normalization))
47 | layers.append(get_activation(activation))
48 | layers = [layer for layer in layers if layer is not None]
49 | for layer in layers:
50 | if isinstance(layer, nn.Conv2d):
51 | nn.init.kaiming_normal_(layer.weight)
52 | self.net = nn.Sequential(*layers)
53 |
54 | def forward(self, layout, feats):
55 | _, _, HH, WW = layout.size()
56 | _, _, H, W = feats.size()
57 | assert HH >= H
58 | if HH > H:
59 | factor = round(HH // H)
60 | assert HH % factor == 0
61 | assert WW % factor == 0 and WW // factor == W
62 | layout = F.avg_pool2d(layout, kernel_size=factor, stride=factor)
63 | net_input = torch.cat([layout, feats], dim=1)
64 | out = self.net(net_input)
65 | return out
66 |
67 |
68 | class RefinementNetwork(nn.Module):
69 | def __init__(self, dims, normalization='instance', activation='leakyrelu'):
70 | super(RefinementNetwork, self).__init__()
71 | layout_dim = dims[0]
72 | self.refinement_modules = nn.ModuleList()
73 | for i in range(1, len(dims)):
74 | input_dim = 1 if i == 1 else dims[i - 1]
75 | output_dim = dims[i]
76 | mod = RefinementModule(layout_dim, input_dim, output_dim,
77 | normalization=normalization, activation=activation)
78 | self.refinement_modules.append(mod)
79 | output_conv_layers = [
80 | nn.Conv2d(dims[-1], dims[-1], kernel_size=3, padding=1),
81 | get_activation(activation),
82 | nn.Conv2d(dims[-1], 3, kernel_size=1, padding=0)
83 | ]
84 | nn.init.kaiming_normal_(output_conv_layers[0].weight)
85 | nn.init.kaiming_normal_(output_conv_layers[2].weight)
86 | self.output_conv = nn.Sequential(*output_conv_layers)
87 |
88 | def forward(self, layout):
89 | """
90 | Output will have same size as layout
91 | """
92 | # H, W = self.output_size
93 | N, _, H, W = layout.size()
94 | self.layout = layout
95 |
96 | # Figure out size of input
97 | input_H, input_W = H, W
98 | for _ in range(len(self.refinement_modules)):
99 | input_H //= 2
100 | input_W //= 2
101 |
102 | assert input_H != 0
103 | assert input_W != 0
104 |
105 | feats = torch.zeros(N, 1, input_H, input_W).to(layout)
106 | for mod in self.refinement_modules:
107 | feats = F.upsample(feats, scale_factor=2, mode='nearest')
108 | feats = mod(layout, feats)
109 |
110 | out = self.output_conv(feats)
111 | return out
112 |
113 |
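
A minimal sketch (not part of the file above) of driving RefinementNetwork with a layout tensor. The channel widths in dims are illustrative only, not the configuration used for the released models; the constraints visible in the code above are that dims[0] matches the layout channel count and that H and W stay divisible by 2 for every refinement module.

# Minimal sketch (not in the repo): run the CRN on a random layout tensor.
# The channel widths below are assumptions for illustration.
import torch
from sg2im.crn import RefinementNetwork

dims = (32, 128, 64, 32)              # layout_dim, then output dims of 3 modules
net = RefinementNetwork(dims)
layout = torch.randn(2, 32, 64, 64)   # (N, layout_dim, H, W); 64 is divisible by 2**3
img = net(layout)                     # coarse-to-fine refinement up to layout size
print(img.shape)                      # torch.Size([2, 3, 64, 64])
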
--------------------------------------------------------------------------------
/sg2im/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .utils import imagenet_preprocess, imagenet_deprocess
18 | from .utils import imagenet_deprocess_batch
19 |
--------------------------------------------------------------------------------
/sg2im/data/coco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import json, os, random, math
18 | from collections import defaultdict
19 |
20 | import torch
21 | from torch.utils.data import Dataset
22 | import torchvision.transforms as T
23 |
24 | import numpy as np
25 | import PIL
26 | from skimage.transform import resize as imresize
27 | import pycocotools.mask as mask_utils
28 |
29 | from .utils import imagenet_preprocess, Resize
30 |
31 |
32 | class CocoSceneGraphDataset(Dataset):
33 | def __init__(self, image_dir, instances_json, stuff_json=None,
34 | stuff_only=True, image_size=(64, 64), mask_size=16,
35 | normalize_images=True, max_samples=None,
36 | include_relationships=True, min_object_size=0.02,
37 | min_objects_per_image=3, max_objects_per_image=8,
38 | include_other=False, instance_whitelist=None, stuff_whitelist=None):
39 | """
40 | A PyTorch Dataset for loading Coco and Coco-Stuff annotations and converting
41 | them to scene graphs on the fly.
42 |
43 | Inputs:
44 | - image_dir: Path to a directory where images are held
45 | - instances_json: Path to a JSON file giving COCO annotations
46 | - stuff_json: (optional) Path to a JSON file giving COCO-Stuff annotations
47 | - stuff_only: (optional, default True) If True then only iterate over
48 | images which appear in stuff_json; if False then iterate over all images
49 | in instances_json.
50 | - image_size: Size (H, W) at which to load images. Default (64, 64).
51 | - mask_size: Size M for object segmentation masks; default 16.
52 |     - normalize_images: If True then normalize images by subtracting the ImageNet
53 |       mean pixel and dividing by the ImageNet std pixel.
54 |     - max_samples: If None use all images. Otherwise only use images in the
55 | range [0, max_samples). Default None.
56 | - include_relationships: If True then include spatial relationships; if
57 | False then only include the trivial __in_image__ relationship.
58 | - min_object_size: Ignore objects whose bounding box takes up less than
59 | this fraction of the image.
60 | - min_objects_per_image: Ignore images which have fewer than this many
61 | object annotations.
62 | - max_objects_per_image: Ignore images which have more than this many
63 | object annotations.
64 | - include_other: If True, include COCO-Stuff annotations which have category
65 | "other". Default is False, because I found that these were really noisy
66 | and pretty much impossible for the system to model.
67 | - instance_whitelist: None means use all instance categories. Otherwise a
68 | list giving a whitelist of instance category names to use.
69 | - stuff_whitelist: None means use all stuff categories. Otherwise a list
70 | giving a whitelist of stuff category names to use.
71 | """
72 |     super(CocoSceneGraphDataset, self).__init__()
73 |
74 | if stuff_only and stuff_json is None:
75 | print('WARNING: Got stuff_only=True but stuff_json=None.')
76 | print('Falling back to stuff_only=False.')
77 |
78 | self.image_dir = image_dir
79 | self.mask_size = mask_size
80 | self.max_samples = max_samples
81 | self.normalize_images = normalize_images
82 | self.include_relationships = include_relationships
83 | self.set_image_size(image_size)
84 |
85 | with open(instances_json, 'r') as f:
86 | instances_data = json.load(f)
87 |
88 | stuff_data = None
89 | if stuff_json is not None and stuff_json != '':
90 | with open(stuff_json, 'r') as f:
91 | stuff_data = json.load(f)
92 |
93 | self.image_ids = []
94 | self.image_id_to_filename = {}
95 | self.image_id_to_size = {}
96 | for image_data in instances_data['images']:
97 | image_id = image_data['id']
98 | filename = image_data['file_name']
99 | width = image_data['width']
100 | height = image_data['height']
101 | self.image_ids.append(image_id)
102 | self.image_id_to_filename[image_id] = filename
103 | self.image_id_to_size[image_id] = (width, height)
104 |
105 | self.vocab = {
106 | 'object_name_to_idx': {},
107 | 'pred_name_to_idx': {},
108 | }
109 | object_idx_to_name = {}
110 | all_instance_categories = []
111 | for category_data in instances_data['categories']:
112 | category_id = category_data['id']
113 | category_name = category_data['name']
114 | all_instance_categories.append(category_name)
115 | object_idx_to_name[category_id] = category_name
116 | self.vocab['object_name_to_idx'][category_name] = category_id
117 | all_stuff_categories = []
118 | if stuff_data:
119 | for category_data in stuff_data['categories']:
120 | category_name = category_data['name']
121 | category_id = category_data['id']
122 | all_stuff_categories.append(category_name)
123 | object_idx_to_name[category_id] = category_name
124 | self.vocab['object_name_to_idx'][category_name] = category_id
125 |
126 | if instance_whitelist is None:
127 | instance_whitelist = all_instance_categories
128 | if stuff_whitelist is None:
129 | stuff_whitelist = all_stuff_categories
130 | category_whitelist = set(instance_whitelist) | set(stuff_whitelist)
131 |
132 | # Add object data from instances
133 | self.image_id_to_objects = defaultdict(list)
134 | for object_data in instances_data['annotations']:
135 | image_id = object_data['image_id']
136 | _, _, w, h = object_data['bbox']
137 | W, H = self.image_id_to_size[image_id]
138 | box_area = (w * h) / (W * H)
139 | box_ok = box_area > min_object_size
140 | object_name = object_idx_to_name[object_data['category_id']]
141 | category_ok = object_name in category_whitelist
142 | other_ok = object_name != 'other' or include_other
143 | if box_ok and category_ok and other_ok:
144 | self.image_id_to_objects[image_id].append(object_data)
145 |
146 | # Add object data from stuff
147 | if stuff_data:
148 | image_ids_with_stuff = set()
149 | for object_data in stuff_data['annotations']:
150 | image_id = object_data['image_id']
151 | image_ids_with_stuff.add(image_id)
152 | _, _, w, h = object_data['bbox']
153 | W, H = self.image_id_to_size[image_id]
154 | box_area = (w * h) / (W * H)
155 | box_ok = box_area > min_object_size
156 | object_name = object_idx_to_name[object_data['category_id']]
157 | category_ok = object_name in category_whitelist
158 | other_ok = object_name != 'other' or include_other
159 | if box_ok and category_ok and other_ok:
160 | self.image_id_to_objects[image_id].append(object_data)
161 | if stuff_only:
162 | new_image_ids = []
163 | for image_id in self.image_ids:
164 | if image_id in image_ids_with_stuff:
165 | new_image_ids.append(image_id)
166 | self.image_ids = new_image_ids
167 |
168 | all_image_ids = set(self.image_id_to_filename.keys())
169 | image_ids_to_remove = all_image_ids - image_ids_with_stuff
170 | for image_id in image_ids_to_remove:
171 | self.image_id_to_filename.pop(image_id, None)
172 | self.image_id_to_size.pop(image_id, None)
173 | self.image_id_to_objects.pop(image_id, None)
174 |
175 | # COCO category labels start at 1, so use 0 for __image__
176 | self.vocab['object_name_to_idx']['__image__'] = 0
177 |
178 | # Build object_idx_to_name
179 | name_to_idx = self.vocab['object_name_to_idx']
180 | assert len(name_to_idx) == len(set(name_to_idx.values()))
181 | max_object_idx = max(name_to_idx.values())
182 | idx_to_name = ['NONE'] * (1 + max_object_idx)
183 | for name, idx in self.vocab['object_name_to_idx'].items():
184 | idx_to_name[idx] = name
185 | self.vocab['object_idx_to_name'] = idx_to_name
186 |
187 | # Prune images that have too few or too many objects
188 | new_image_ids = []
189 | total_objs = 0
190 | for image_id in self.image_ids:
191 | num_objs = len(self.image_id_to_objects[image_id])
192 | total_objs += num_objs
193 | if min_objects_per_image <= num_objs <= max_objects_per_image:
194 | new_image_ids.append(image_id)
195 | self.image_ids = new_image_ids
196 |
197 | self.vocab['pred_idx_to_name'] = [
198 | '__in_image__',
199 | 'left of',
200 | 'right of',
201 | 'above',
202 | 'below',
203 | 'inside',
204 | 'surrounding',
205 | ]
206 | self.vocab['pred_name_to_idx'] = {}
207 | for idx, name in enumerate(self.vocab['pred_idx_to_name']):
208 | self.vocab['pred_name_to_idx'][name] = idx
209 |
210 | def set_image_size(self, image_size):
211 | print('called set_image_size', image_size)
212 | transform = [Resize(image_size), T.ToTensor()]
213 | if self.normalize_images:
214 | transform.append(imagenet_preprocess())
215 | self.transform = T.Compose(transform)
216 | self.image_size = image_size
217 |
218 | def total_objects(self):
219 | total_objs = 0
220 | for i, image_id in enumerate(self.image_ids):
221 | if self.max_samples and i >= self.max_samples:
222 | break
223 | num_objs = len(self.image_id_to_objects[image_id])
224 | total_objs += num_objs
225 | return total_objs
226 |
227 | def __len__(self):
228 | if self.max_samples is None:
229 | return len(self.image_ids)
230 | return min(len(self.image_ids), self.max_samples)
231 |
232 | def __getitem__(self, index):
233 | """
234 | Get the pixels of an image, and a random synthetic scene graph for that
235 | image constructed on-the-fly from its COCO object annotations. We assume
236 | that the image will have height H, width W, C channels; there will be O
237 | object annotations, each of which will have both a bounding box and a
238 | segmentation mask of shape (M, M). There will be T triples in the scene
239 | graph.
240 |
241 | Returns a tuple of:
242 | - image: FloatTensor of shape (C, H, W)
243 | - objs: LongTensor of shape (O,)
244 | - boxes: FloatTensor of shape (O, 4) giving boxes for objects in
245 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system
246 | - masks: LongTensor of shape (O, M, M) giving segmentation masks for
247 | objects, where 0 is background and 1 is object.
248 | - triples: LongTensor of shape (T, 3) where triples[t] = [i, p, j]
249 | means that (objs[i], p, objs[j]) is a triple.
250 | """
251 | image_id = self.image_ids[index]
252 |
253 | filename = self.image_id_to_filename[image_id]
254 | image_path = os.path.join(self.image_dir, filename)
255 | with open(image_path, 'rb') as f:
256 | with PIL.Image.open(f) as image:
257 | WW, HH = image.size
258 | image = self.transform(image.convert('RGB'))
259 |
260 | H, W = self.image_size
261 | objs, boxes, masks = [], [], []
262 | for object_data in self.image_id_to_objects[image_id]:
263 | objs.append(object_data['category_id'])
264 | x, y, w, h = object_data['bbox']
265 | x0 = x / WW
266 | y0 = y / HH
267 | x1 = (x + w) / WW
268 | y1 = (y + h) / HH
269 | boxes.append(torch.FloatTensor([x0, y0, x1, y1]))
270 |
271 | # This will give a numpy array of shape (HH, WW)
272 | mask = seg_to_mask(object_data['segmentation'], WW, HH)
273 |
274 | # Crop the mask according to the bounding box, being careful to
275 | # ensure that we don't crop a zero-area region
276 | mx0, mx1 = int(round(x)), int(round(x + w))
277 | my0, my1 = int(round(y)), int(round(y + h))
278 | mx1 = max(mx0 + 1, mx1)
279 | my1 = max(my0 + 1, my1)
280 | mask = mask[my0:my1, mx0:mx1]
281 | mask = imresize(255.0 * mask, (self.mask_size, self.mask_size),
282 | mode='constant')
283 | mask = torch.from_numpy((mask > 128).astype(np.int64))
284 | masks.append(mask)
285 |
286 | # Add dummy __image__ object
287 | objs.append(self.vocab['object_name_to_idx']['__image__'])
288 | boxes.append(torch.FloatTensor([0, 0, 1, 1]))
289 | masks.append(torch.ones(self.mask_size, self.mask_size).long())
290 |
291 | objs = torch.LongTensor(objs)
292 | boxes = torch.stack(boxes, dim=0)
293 | masks = torch.stack(masks, dim=0)
294 |
295 | box_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
296 |
297 | # Compute centers of all objects
298 | obj_centers = []
299 | _, MH, MW = masks.size()
300 | for i, obj_idx in enumerate(objs):
301 | x0, y0, x1, y1 = boxes[i]
302 | mask = (masks[i] == 1)
303 | xs = torch.linspace(x0, x1, MW).view(1, MW).expand(MH, MW)
304 | ys = torch.linspace(y0, y1, MH).view(MH, 1).expand(MH, MW)
305 | if mask.sum() == 0:
306 | mean_x = 0.5 * (x0 + x1)
307 | mean_y = 0.5 * (y0 + y1)
308 | else:
309 | mean_x = xs[mask].mean()
310 | mean_y = ys[mask].mean()
311 | obj_centers.append([mean_x, mean_y])
312 | obj_centers = torch.FloatTensor(obj_centers)
313 |
314 | # Add triples
315 | triples = []
316 | num_objs = objs.size(0)
317 | __image__ = self.vocab['object_name_to_idx']['__image__']
318 | real_objs = []
319 | if num_objs > 1:
320 | real_objs = (objs != __image__).nonzero().squeeze(1)
321 | for cur in real_objs:
322 | choices = [obj for obj in real_objs if obj != cur]
323 | if len(choices) == 0 or not self.include_relationships:
324 | break
325 | other = random.choice(choices)
326 | if random.random() > 0.5:
327 | s, o = cur, other
328 | else:
329 | s, o = other, cur
330 |
331 | # Check for inside / surrounding
332 | sx0, sy0, sx1, sy1 = boxes[s]
333 | ox0, oy0, ox1, oy1 = boxes[o]
334 | d = obj_centers[s] - obj_centers[o]
335 | theta = math.atan2(d[1], d[0])
336 |
337 | if sx0 < ox0 and sx1 > ox1 and sy0 < oy0 and sy1 > oy1:
338 | p = 'surrounding'
339 | elif sx0 > ox0 and sx1 < ox1 and sy0 > oy0 and sy1 < oy1:
340 | p = 'inside'
341 | elif theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4:
342 | p = 'left of'
343 | elif -3 * math.pi / 4 <= theta < -math.pi / 4:
344 | p = 'above'
345 | elif -math.pi / 4 <= theta < math.pi / 4:
346 | p = 'right of'
347 | elif math.pi / 4 <= theta < 3 * math.pi / 4:
348 | p = 'below'
349 | p = self.vocab['pred_name_to_idx'][p]
350 | triples.append([s, p, o])
351 |
352 | # Add __in_image__ triples
353 | O = objs.size(0)
354 | in_image = self.vocab['pred_name_to_idx']['__in_image__']
355 | for i in range(O - 1):
356 | triples.append([i, in_image, O - 1])
357 |
358 | triples = torch.LongTensor(triples)
359 | return image, objs, boxes, masks, triples
360 |
361 |
362 | def seg_to_mask(seg, width=1.0, height=1.0):
363 | """
364 | Tiny utility for decoding segmentation masks using the pycocotools API.
365 | """
366 | if type(seg) == list:
367 | rles = mask_utils.frPyObjects(seg, height, width)
368 | rle = mask_utils.merge(rles)
369 | elif type(seg['counts']) == list:
370 | rle = mask_utils.frPyObjects(seg, height, width)
371 | else:
372 | rle = seg
373 | return mask_utils.decode(rle)
374 |
375 |
376 | def coco_collate_fn(batch):
377 | """
378 | Collate function to be used when wrapping CocoSceneGraphDataset in a
379 | DataLoader. Returns a tuple of the following:
380 |
381 | - imgs: FloatTensor of shape (N, C, H, W)
382 | - objs: LongTensor of shape (O,) giving object categories
383 | - boxes: FloatTensor of shape (O, 4)
384 |   - masks: LongTensor of shape (O, M, M)
385 | - triples: LongTensor of shape (T, 3) giving triples
386 | - obj_to_img: LongTensor of shape (O,) mapping objects to images
387 | - triple_to_img: LongTensor of shape (T,) mapping triples to images
388 | """
389 | all_imgs, all_objs, all_boxes, all_masks, all_triples = [], [], [], [], []
390 | all_obj_to_img, all_triple_to_img = [], []
391 | obj_offset = 0
392 | for i, (img, objs, boxes, masks, triples) in enumerate(batch):
393 | all_imgs.append(img[None])
394 | if objs.dim() == 0 or triples.dim() == 0:
395 | continue
396 | O, T = objs.size(0), triples.size(0)
397 | all_objs.append(objs)
398 | all_boxes.append(boxes)
399 | all_masks.append(masks)
400 | triples = triples.clone()
401 | triples[:, 0] += obj_offset
402 | triples[:, 2] += obj_offset
403 | all_triples.append(triples)
404 |
405 | all_obj_to_img.append(torch.LongTensor(O).fill_(i))
406 | all_triple_to_img.append(torch.LongTensor(T).fill_(i))
407 | obj_offset += O
408 |
409 | all_imgs = torch.cat(all_imgs)
410 | all_objs = torch.cat(all_objs)
411 | all_boxes = torch.cat(all_boxes)
412 | all_masks = torch.cat(all_masks)
413 | all_triples = torch.cat(all_triples)
414 | all_obj_to_img = torch.cat(all_obj_to_img)
415 | all_triple_to_img = torch.cat(all_triple_to_img)
416 |
417 | out = (all_imgs, all_objs, all_boxes, all_masks, all_triples,
418 | all_obj_to_img, all_triple_to_img)
419 | return out
420 |
421 |
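
A minimal sketch (not part of the file above) wiring CocoSceneGraphDataset and coco_collate_fn into a DataLoader. It assumes pycocotools is installed and a local COCO / COCO-Stuff download exists; the annotation and image paths are placeholders, not paths the repo guarantees.

# Minimal sketch (not in the repo): build a loader over COCO scene graphs.
# All paths below are placeholders for a local COCO / COCO-Stuff download.
from torch.utils.data import DataLoader
from sg2im.data.coco import CocoSceneGraphDataset, coco_collate_fn

dset = CocoSceneGraphDataset(
  image_dir='datasets/coco/images/val2017',                           # placeholder
  instances_json='datasets/coco/annotations/instances_val2017.json',  # placeholder
  stuff_json='datasets/coco/annotations/stuff_val2017.json',          # placeholder
  image_size=(64, 64), mask_size=16)
loader = DataLoader(dset, batch_size=8, collate_fn=coco_collate_fn)
imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = next(iter(loader))
print(imgs.shape, objs.shape, masks.shape, triples.shape)
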
--------------------------------------------------------------------------------
/sg2im/data/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import PIL
18 | import torch
19 | import torchvision.transforms as T
20 |
21 |
22 | IMAGENET_MEAN = [0.485, 0.456, 0.406]
23 | IMAGENET_STD = [0.229, 0.224, 0.225]
24 |
25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN]
26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD]
27 |
28 |
29 | def imagenet_preprocess():
30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
31 |
32 |
33 | def rescale(x):
34 | lo, hi = x.min(), x.max()
35 | return x.sub(lo).div(hi - lo)
36 |
37 |
38 | def imagenet_deprocess(rescale_image=True):
39 | transforms = [
40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD),
41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]),
42 | ]
43 | if rescale_image:
44 | transforms.append(rescale)
45 | return T.Compose(transforms)
46 |
47 |
48 | def imagenet_deprocess_batch(imgs, rescale=True):
49 | """
50 | Input:
51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
52 |
53 | Output:
54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images
55 | in the range [0, 255]
56 | """
57 | if isinstance(imgs, torch.autograd.Variable):
58 | imgs = imgs.data
59 | imgs = imgs.cpu().clone()
60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale)
61 | imgs_de = []
62 | for i in range(imgs.size(0)):
63 | img_de = deprocess_fn(imgs[i])[None]
64 | img_de = img_de.mul(255).clamp(0, 255).byte()
65 | imgs_de.append(img_de)
66 | imgs_de = torch.cat(imgs_de, dim=0)
67 | return imgs_de
68 |
69 |
70 | class Resize(object):
71 | def __init__(self, size, interp=PIL.Image.BILINEAR):
72 | if isinstance(size, tuple):
73 | H, W = size
74 | self.size = (W, H)
75 | else:
76 | self.size = (size, size)
77 | self.interp = interp
78 |
79 | def __call__(self, img):
80 | return img.resize(self.size, self.interp)
81 |
82 |
83 | def unpack_var(v):
84 | if isinstance(v, torch.autograd.Variable):
85 | return v.data
86 | return v
87 |
88 |
89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
90 | triples = unpack_var(triples)
91 | obj_data = [unpack_var(o) for o in obj_data]
92 | obj_to_img = unpack_var(obj_to_img)
93 | triple_to_img = unpack_var(triple_to_img)
94 |
95 | triples_out = []
96 | obj_data_out = [[] for _ in obj_data]
97 | obj_offset = 0
98 | N = obj_to_img.max() + 1
99 | for i in range(N):
100 | o_idxs = (obj_to_img == i).nonzero().view(-1)
101 | t_idxs = (triple_to_img == i).nonzero().view(-1)
102 |
103 | cur_triples = triples[t_idxs].clone()
104 | cur_triples[:, 0] -= obj_offset
105 | cur_triples[:, 2] -= obj_offset
106 | triples_out.append(cur_triples)
107 |
108 | for j, o_data in enumerate(obj_data):
109 | cur_o_data = None
110 | if o_data is not None:
111 | cur_o_data = o_data[o_idxs]
112 | obj_data_out[j].append(cur_o_data)
113 |
114 | obj_offset += o_idxs.size(0)
115 |
116 | return triples_out, obj_data_out
117 |
118 |
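
A minimal sketch (not part of the file above) showing that imagenet_preprocess and imagenet_deprocess_batch act as approximate inverses, mapping a normalized batch back to uint8 images.

# Minimal sketch (not in the repo): normalize a [0, 1] image tensor, then map a
# one-image batch back to uint8 with imagenet_deprocess_batch.
import torch
from sg2im.data.utils import imagenet_preprocess, imagenet_deprocess_batch

img = torch.rand(3, 64, 64)               # as produced by T.ToTensor()
norm = imagenet_preprocess()(img)         # subtract ImageNet mean, divide by std
out = imagenet_deprocess_batch(norm[None], rescale=True)
print(out.shape, out.dtype)               # torch.Size([1, 3, 64, 64]) torch.uint8
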
--------------------------------------------------------------------------------
/sg2im/data/vg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import random
19 | from collections import defaultdict
20 |
21 | import torch
22 | from torch.utils.data import Dataset
23 | import torchvision.transforms as T
24 |
25 | import numpy as np
26 | import h5py
27 | import PIL
28 |
29 | from .utils import imagenet_preprocess, Resize
30 |
31 |
32 | class VgSceneGraphDataset(Dataset):
33 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256),
34 | normalize_images=True, max_objects=10, max_samples=None,
35 | include_relationships=True, use_orphaned_objects=True):
36 | super(VgSceneGraphDataset, self).__init__()
37 |
38 | self.image_dir = image_dir
39 | self.image_size = image_size
40 | self.vocab = vocab
41 | self.num_objects = len(vocab['object_idx_to_name'])
42 | self.use_orphaned_objects = use_orphaned_objects
43 | self.max_objects = max_objects
44 | self.max_samples = max_samples
45 | self.include_relationships = include_relationships
46 |
47 | transform = [Resize(image_size), T.ToTensor()]
48 | if normalize_images:
49 | transform.append(imagenet_preprocess())
50 | self.transform = T.Compose(transform)
51 |
52 | self.data = {}
53 | with h5py.File(h5_path, 'r') as f:
54 | for k, v in f.items():
55 | if k == 'image_paths':
56 | self.image_paths = list(v)
57 | else:
58 | self.data[k] = torch.IntTensor(np.asarray(v))
59 |
60 | def __len__(self):
61 | num = self.data['object_names'].size(0)
62 | if self.max_samples is not None:
63 | return min(self.max_samples, num)
64 | return num
65 |
66 | def __getitem__(self, index):
67 | """
68 | Returns a tuple of:
69 | - image: FloatTensor of shape (C, H, W)
70 | - objs: LongTensor of shape (O,)
71 | - boxes: FloatTensor of shape (O, 4) giving boxes for objects in
72 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system.
73 | - triples: LongTensor of shape (T, 3) where triples[t] = [i, p, j]
74 | means that (objs[i], p, objs[j]) is a triple.
75 | """
76 | img_path = os.path.join(self.image_dir, self.image_paths[index])
77 |
78 | with open(img_path, 'rb') as f:
79 | with PIL.Image.open(f) as image:
80 | WW, HH = image.size
81 | image = self.transform(image.convert('RGB'))
82 |
83 | H, W = self.image_size
84 |
85 | # Figure out which objects appear in relationships and which don't
86 | obj_idxs_with_rels = set()
87 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item()))
88 |     for r_idx in range(self.data['relationships_per_image'][index].item()):
89 | s = self.data['relationship_subjects'][index, r_idx].item()
90 | o = self.data['relationship_objects'][index, r_idx].item()
91 | obj_idxs_with_rels.add(s)
92 | obj_idxs_with_rels.add(o)
93 | obj_idxs_without_rels.discard(s)
94 | obj_idxs_without_rels.discard(o)
95 |
96 | obj_idxs = list(obj_idxs_with_rels)
97 | obj_idxs_without_rels = list(obj_idxs_without_rels)
98 | if len(obj_idxs) > self.max_objects - 1:
99 |       obj_idxs = random.sample(obj_idxs, self.max_objects - 1)
100 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects:
101 | num_to_add = self.max_objects - 1 - len(obj_idxs)
102 | num_to_add = min(num_to_add, len(obj_idxs_without_rels))
103 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add)
104 | O = len(obj_idxs) + 1
105 |
106 | objs = torch.LongTensor(O).fill_(-1)
107 |
108 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(O, 1)
109 | obj_idx_mapping = {}
110 | for i, obj_idx in enumerate(obj_idxs):
111 | objs[i] = self.data['object_names'][index, obj_idx].item()
112 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist()
113 | x0 = float(x) / WW
114 | y0 = float(y) / HH
115 | x1 = float(x + w) / WW
116 | y1 = float(y + h) / HH
117 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1])
118 | obj_idx_mapping[obj_idx] = i
119 |
120 | # The last object will be the special __image__ object
121 | objs[O - 1] = self.vocab['object_name_to_idx']['__image__']
122 |
123 | triples = []
124 | for r_idx in range(self.data['relationships_per_image'][index].item()):
125 | if not self.include_relationships:
126 | break
127 | s = self.data['relationship_subjects'][index, r_idx].item()
128 | p = self.data['relationship_predicates'][index, r_idx].item()
129 | o = self.data['relationship_objects'][index, r_idx].item()
130 | s = obj_idx_mapping.get(s, None)
131 | o = obj_idx_mapping.get(o, None)
132 | if s is not None and o is not None:
133 | triples.append([s, p, o])
134 |
135 | # Add dummy __in_image__ relationships for all objects
136 | in_image = self.vocab['pred_name_to_idx']['__in_image__']
137 | for i in range(O - 1):
138 | triples.append([i, in_image, O - 1])
139 |
140 | triples = torch.LongTensor(triples)
141 | return image, objs, boxes, triples
142 |
143 |
144 | def vg_collate_fn(batch):
145 | """
146 | Collate function to be used when wrapping a VgSceneGraphDataset in a
147 | DataLoader. Returns a tuple of the following:
148 |
149 | - imgs: FloatTensor of shape (N, C, H, W)
150 | - objs: LongTensor of shape (O,) giving categories for all objects
151 | - boxes: FloatTensor of shape (O, 4) giving boxes for all objects
152 |   - triples: LongTensor of shape (T, 3) giving all triples, where
153 | triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple
154 | - obj_to_img: LongTensor of shape (O,) mapping objects to images;
155 | obj_to_img[i] = n means that objs[i] belongs to imgs[n]
156 | - triple_to_img: LongTensor of shape (T,) mapping triples to images;
157 | triple_to_img[t] = n means that triples[t] belongs to imgs[n].
158 | """
159 | # batch is a list, and each element is (image, objs, boxes, triples)
160 | all_imgs, all_objs, all_boxes, all_triples = [], [], [], []
161 | all_obj_to_img, all_triple_to_img = [], []
162 | obj_offset = 0
163 | for i, (img, objs, boxes, triples) in enumerate(batch):
164 | all_imgs.append(img[None])
165 | O, T = objs.size(0), triples.size(0)
166 | all_objs.append(objs)
167 | all_boxes.append(boxes)
168 | triples = triples.clone()
169 | triples[:, 0] += obj_offset
170 | triples[:, 2] += obj_offset
171 | all_triples.append(triples)
172 |
173 | all_obj_to_img.append(torch.LongTensor(O).fill_(i))
174 | all_triple_to_img.append(torch.LongTensor(T).fill_(i))
175 | obj_offset += O
176 |
177 | all_imgs = torch.cat(all_imgs)
178 | all_objs = torch.cat(all_objs)
179 | all_boxes = torch.cat(all_boxes)
180 | all_triples = torch.cat(all_triples)
181 | all_obj_to_img = torch.cat(all_obj_to_img)
182 | all_triple_to_img = torch.cat(all_triple_to_img)
183 |
184 | out = (all_imgs, all_objs, all_boxes, all_triples,
185 | all_obj_to_img, all_triple_to_img)
186 | return out
187 |
188 |
189 | def vg_uncollate_fn(batch):
190 | """
191 | Inverse operation to the above.
192 | """
193 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
194 | out = []
195 | obj_offset = 0
196 | for i in range(imgs.size(0)):
197 | cur_img = imgs[i]
198 | o_idxs = (obj_to_img == i).nonzero().view(-1)
199 | t_idxs = (triple_to_img == i).nonzero().view(-1)
200 | cur_objs = objs[o_idxs]
201 | cur_boxes = boxes[o_idxs]
202 | cur_triples = triples[t_idxs].clone()
203 | cur_triples[:, 0] -= obj_offset
204 | cur_triples[:, 2] -= obj_offset
205 | obj_offset += cur_objs.size(0)
206 | out.append((cur_img, cur_objs, cur_boxes, cur_triples))
207 | return out
208 |
209 |
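A minimal usage sketch of wiring VgSceneGraphDataset into a DataLoader with vg_collate_fn; the vocab, HDF5, and image paths below are placeholders for whatever the Visual Genome preprocessing step produced locally.

```python
import json
from torch.utils.data import DataLoader
from sg2im.data.vg import VgSceneGraphDataset, vg_collate_fn

with open('datasets/vg/vocab.json') as f:      # placeholder path
  vocab = json.load(f)

dset = VgSceneGraphDataset(vocab,
                           h5_path='datasets/vg/train.h5',  # placeholder path
                           image_dir='datasets/vg/images',  # placeholder path
                           image_size=(64, 64), max_objects=10)
loader = DataLoader(dset, batch_size=32, shuffle=True, collate_fn=vg_collate_fn)

imgs, objs, boxes, triples, obj_to_img, triple_to_img = next(iter(loader))
# imgs: (32, 3, 64, 64); objs/boxes/triples are concatenated across the batch
# and mapped back to their images via obj_to_img / triple_to_img.
```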
--------------------------------------------------------------------------------
/sg2im/discriminators.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from sg2im.bilinear import crop_bbox_batch
22 | from sg2im.layers import GlobalAvgPool, Flatten, get_activation, build_cnn
23 |
24 |
25 | class PatchDiscriminator(nn.Module):
26 | def __init__(self, arch, normalization='batch', activation='leakyrelu-0.2',
27 | padding='same', pooling='avg', input_size=(128,128),
28 | layout_dim=0):
29 | super(PatchDiscriminator, self).__init__()
30 | input_dim = 3 + layout_dim
31 | arch = 'I%d,%s' % (input_dim, arch)
32 | cnn_kwargs = {
33 | 'arch': arch,
34 | 'normalization': normalization,
35 | 'activation': activation,
36 | 'pooling': pooling,
37 | 'padding': padding,
38 | }
39 | self.cnn, output_dim = build_cnn(**cnn_kwargs)
40 | self.classifier = nn.Conv2d(output_dim, 1, kernel_size=1, stride=1)
41 |
42 | def forward(self, x, layout=None):
43 | if layout is not None:
44 | x = torch.cat([x, layout], dim=1)
45 | return self.cnn(x)
46 |
47 |
48 | class AcDiscriminator(nn.Module):
49 | def __init__(self, vocab, arch, normalization='none', activation='relu',
50 | padding='same', pooling='avg'):
51 | super(AcDiscriminator, self).__init__()
52 | self.vocab = vocab
53 |
54 | cnn_kwargs = {
55 | 'arch': arch,
56 | 'normalization': normalization,
57 | 'activation': activation,
58 | 'pooling': pooling,
59 | 'padding': padding,
60 | }
61 | cnn, D = build_cnn(**cnn_kwargs)
62 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(D, 1024))
63 | num_objects = len(vocab['object_idx_to_name'])
64 |
65 | self.real_classifier = nn.Linear(1024, 1)
66 | self.obj_classifier = nn.Linear(1024, num_objects)
67 |
68 | def forward(self, x, y):
69 | if x.dim() == 3:
70 | x = x[:, None]
71 | vecs = self.cnn(x)
72 | real_scores = self.real_classifier(vecs)
73 | obj_scores = self.obj_classifier(vecs)
74 | ac_loss = F.cross_entropy(obj_scores, y)
75 | return real_scores, ac_loss
76 |
77 |
78 | class AcCropDiscriminator(nn.Module):
79 | def __init__(self, vocab, arch, normalization='none', activation='relu',
80 | object_size=64, padding='same', pooling='avg'):
81 | super(AcCropDiscriminator, self).__init__()
82 | self.vocab = vocab
83 | self.discriminator = AcDiscriminator(vocab, arch, normalization,
84 | activation, padding, pooling)
85 | self.object_size = object_size
86 |
87 | def forward(self, imgs, objs, boxes, obj_to_img):
88 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
89 | real_scores, ac_loss = self.discriminator(crops, objs)
90 | return real_scores, ac_loss
91 |
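A brief sketch of PatchDiscriminator with an illustrative architecture string (not necessarily the configuration used for training); the output is a patch-level map that the GAN losses flatten into per-element scores.

```python
import torch
from sg2im.discriminators import PatchDiscriminator

d_img = PatchDiscriminator(arch='C3-64-2,C3-128-2,C3-256-2',  # illustrative arch
                           normalization='batch',
                           activation='leakyrelu-0.2')
imgs = torch.randn(4, 3, 64, 64)   # a batch of real or generated images
scores = d_img(imgs)               # (4, 256, 8, 8) patch-level score map
```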
--------------------------------------------------------------------------------
/sg2im/graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 | from sg2im.layers import build_mlp
20 |
21 | """
22 | PyTorch modules for dealing with graphs.
23 | """
24 |
25 |
26 | def _init_weights(module):
27 | if hasattr(module, 'weight'):
28 | if isinstance(module, nn.Linear):
29 | nn.init.kaiming_normal_(module.weight)
30 |
31 |
32 | class GraphTripleConv(nn.Module):
33 | """
34 | A single layer of scene graph convolution.
35 | """
36 | def __init__(self, input_dim, output_dim=None, hidden_dim=512,
37 | pooling='avg', mlp_normalization='none'):
38 | super(GraphTripleConv, self).__init__()
39 | if output_dim is None:
40 | output_dim = input_dim
41 | self.input_dim = input_dim
42 | self.output_dim = output_dim
43 | self.hidden_dim = hidden_dim
44 |
45 | assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling
46 | self.pooling = pooling
47 | net1_layers = [3 * input_dim, hidden_dim, 2 * hidden_dim + output_dim]
48 | net1_layers = [l for l in net1_layers if l is not None]
49 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization)
50 | self.net1.apply(_init_weights)
51 |
52 | net2_layers = [hidden_dim, hidden_dim, output_dim]
53 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization)
54 | self.net2.apply(_init_weights)
55 |
56 | def forward(self, obj_vecs, pred_vecs, edges):
57 | """
58 | Inputs:
59 | - obj_vecs: FloatTensor of shape (O, D) giving vectors for all objects
60 | - pred_vecs: FloatTensor of shape (T, D) giving vectors for all predicates
61 | - edges: LongTensor of shape (T, 2) where edges[k] = [i, j] indicates the
62 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]]
63 |
64 | Outputs:
65 | - new_obj_vecs: FloatTensor of shape (O, D) giving new vectors for objects
66 | - new_pred_vecs: FloatTensor of shape (T, D) giving new vectors for predicates
67 | """
68 | dtype, device = obj_vecs.dtype, obj_vecs.device
69 | O, T = obj_vecs.size(0), pred_vecs.size(0)
70 | Din, H, Dout = self.input_dim, self.hidden_dim, self.output_dim
71 |
72 | # Break apart indices for subjects and objects; these have shape (T,)
73 | s_idx = edges[:, 0].contiguous()
74 | o_idx = edges[:, 1].contiguous()
75 |
76 | # Get current vectors for subjects and objects; these have shape (T, Din)
77 | cur_s_vecs = obj_vecs[s_idx]
78 | cur_o_vecs = obj_vecs[o_idx]
79 |
80 | # Get current vectors for triples; shape is (T, 3 * Din)
81 | # Pass through net1 to get new triple vecs; shape is (T, 2 * H + Dout)
82 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1)
83 | new_t_vecs = self.net1(cur_t_vecs)
84 |
85 | # Break apart into new s, p, and o vecs; s and o vecs have shape (T, H) and
86 | # p vecs have shape (T, Dout)
87 | new_s_vecs = new_t_vecs[:, :H]
88 | new_p_vecs = new_t_vecs[:, H:(H+Dout)]
89 | new_o_vecs = new_t_vecs[:, (H+Dout):(2 * H + Dout)]
90 |
91 | # Allocate space for pooled object vectors of shape (O, H)
92 | pooled_obj_vecs = torch.zeros(O, H, dtype=dtype, device=device)
93 |
94 | # Use scatter_add to sum vectors for objects that appear in multiple triples;
95 | # we first need to expand the indices to have shape (T, D)
96 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs)
97 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs)
98 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs)
99 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs)
100 |
101 | if self.pooling == 'avg':
102 | # Figure out how many times each object has appeared, again using
103 | # some scatter_add trickery.
104 | obj_counts = torch.zeros(O, dtype=dtype, device=device)
105 | ones = torch.ones(T, dtype=dtype, device=device)
106 | obj_counts = obj_counts.scatter_add(0, s_idx, ones)
107 | obj_counts = obj_counts.scatter_add(0, o_idx, ones)
108 |
109 | # Divide the new object vectors by the number of times they
110 | # appeared, but first clamp at 1 to avoid dividing by zero;
111 | # objects that appear in no triples will have output vector 0
112 | # so this will not affect them.
113 | obj_counts = obj_counts.clamp(min=1)
114 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1)
115 |
116 | # Send pooled object vectors through net2 to get output object vectors,
117 | # of shape (O, Dout)
118 | new_obj_vecs = self.net2(pooled_obj_vecs)
119 |
120 | return new_obj_vecs, new_p_vecs
121 |
122 |
123 | class GraphTripleConvNet(nn.Module):
124 | """ A sequence of scene graph convolution layers """
125 | def __init__(self, input_dim, num_layers=5, hidden_dim=512, pooling='avg',
126 | mlp_normalization='none'):
127 | super(GraphTripleConvNet, self).__init__()
128 |
129 | self.num_layers = num_layers
130 | self.gconvs = nn.ModuleList()
131 | gconv_kwargs = {
132 | 'input_dim': input_dim,
133 | 'hidden_dim': hidden_dim,
134 | 'pooling': pooling,
135 | 'mlp_normalization': mlp_normalization,
136 | }
137 | for _ in range(self.num_layers):
138 | self.gconvs.append(GraphTripleConv(**gconv_kwargs))
139 |
140 | def forward(self, obj_vecs, pred_vecs, edges):
141 | for i in range(self.num_layers):
142 | gconv = self.gconvs[i]
143 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges)
144 | return obj_vecs, pred_vecs
145 |
146 |
147 |
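A minimal sketch of a single GraphTripleConv layer on a toy graph (3 objects, 2 triples); shapes follow the forward() docstring, and all dimensions here are arbitrary.

```python
import torch
from sg2im.graph import GraphTripleConv

gconv = GraphTripleConv(input_dim=64, output_dim=64, hidden_dim=128)
obj_vecs = torch.randn(3, 64)                # one vector per object
pred_vecs = torch.randn(2, 64)               # one vector per predicate
edges = torch.LongTensor([[0, 1], [1, 2]])   # (subject index, object index) pairs
new_obj_vecs, new_pred_vecs = gconv(obj_vecs, pred_vecs, edges)
# new_obj_vecs: (3, 64); new_pred_vecs: (2, 64)
```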
--------------------------------------------------------------------------------
/sg2im/layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | def get_normalization_2d(channels, normalization):
23 | if normalization == 'instance':
24 | return nn.InstanceNorm2d(channels)
25 | elif normalization == 'batch':
26 | return nn.BatchNorm2d(channels)
27 | elif normalization == 'none':
28 | return None
29 | else:
30 | raise ValueError('Unrecognized normalization type "%s"' % normalization)
31 |
32 |
33 | def get_activation(name):
34 | kwargs = {}
35 | if name.lower().startswith('leakyrelu'):
36 | if '-' in name:
37 | slope = float(name.split('-')[1])
38 | kwargs = {'negative_slope': slope}
39 | name = 'leakyrelu'
40 | activations = {
41 | 'relu': nn.ReLU,
42 | 'leakyrelu': nn.LeakyReLU,
43 | }
44 | if name.lower() not in activations:
45 | raise ValueError('Invalid activation "%s"' % name)
46 | return activations[name.lower()](**kwargs)
47 |
48 |
49 |
50 |
51 | def _init_conv(layer, method):
52 | if not isinstance(layer, nn.Conv2d):
53 | return
54 | if method == 'default':
55 | return
56 |   elif method == 'kaiming-normal':
57 |     nn.init.kaiming_normal_(layer.weight)
58 |   elif method == 'kaiming-uniform':
59 |     nn.init.kaiming_uniform_(layer.weight)
60 |
61 |
62 | class Flatten(nn.Module):
63 | def forward(self, x):
64 | return x.view(x.size(0), -1)
65 |
66 | def __repr__(self):
67 | return 'Flatten()'
68 |
69 |
70 | class Unflatten(nn.Module):
71 | def __init__(self, size):
72 | super(Unflatten, self).__init__()
73 | self.size = size
74 |
75 | def forward(self, x):
76 | return x.view(*self.size)
77 |
78 | def __repr__(self):
79 | size_str = ', '.join('%d' % d for d in self.size)
80 | return 'Unflatten(%s)' % size_str
81 |
82 |
83 | class GlobalAvgPool(nn.Module):
84 | def forward(self, x):
85 | N, C = x.size(0), x.size(1)
86 | return x.view(N, C, -1).mean(dim=2)
87 |
88 |
89 | class ResidualBlock(nn.Module):
90 | def __init__(self, channels, normalization='batch', activation='relu',
91 | padding='same', kernel_size=3, init='default'):
92 | super(ResidualBlock, self).__init__()
93 |
94 | K = kernel_size
95 | P = _get_padding(K, padding)
96 | C = channels
97 | self.padding = P
98 | layers = [
99 | get_normalization_2d(C, normalization),
100 | get_activation(activation),
101 | nn.Conv2d(C, C, kernel_size=K, padding=P),
102 | get_normalization_2d(C, normalization),
103 | get_activation(activation),
104 | nn.Conv2d(C, C, kernel_size=K, padding=P),
105 | ]
106 | layers = [layer for layer in layers if layer is not None]
107 | for layer in layers:
108 | _init_conv(layer, method=init)
109 | self.net = nn.Sequential(*layers)
110 |
111 |   def forward(self, x):
112 |     y = self.net(x)
113 |     shortcut = x
114 |     if self.padding == 0:
115 |       crop = (x.size(2) - y.size(2)) // 2  # 'valid' convs shrink the output
116 |       shortcut = x[:, :, crop:-crop, crop:-crop]
117 |     return shortcut + y
118 |
119 |
120 | def _get_padding(K, mode):
121 | """ Helper method to compute padding size """
122 | if mode == 'valid':
123 | return 0
124 | elif mode == 'same':
125 | assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K
126 | return (K - 1) // 2
127 |
128 |
129 | def build_cnn(arch, normalization='batch', activation='relu', padding='same',
130 | pooling='max', init='default'):
131 | """
132 | Build a CNN from an architecture string, which is a list of layer
133 | specification strings. The overall architecture can be given as a list or as
134 | a comma-separated string.
135 |
136 |   All convolutions *except for the first* are preceded by normalization and
137 | nonlinearity.
138 |
139 |   The architecture string supports the following layer types:
140 | - IX: Indicates that the number of input channels to the network is X.
141 | Can only be used at the first layer; if not present then we assume
142 | 3 input channels.
143 | - CK-X: KxK convolution with X output channels
144 | - CK-X-S: KxK convolution with X output channels and stride S
145 | - R: Residual block keeping the same number of channels
146 | - UX: Nearest-neighbor upsampling with factor X
147 | - PX: Spatial pooling with factor X
148 |   - FC-X-Y: Flatten followed by a fully-connected layer with X inputs and Y outputs
149 |
150 | Returns a tuple of:
151 | - cnn: An nn.Sequential
152 | - channels: Number of output channels
153 | """
154 | if isinstance(arch, str):
155 | arch = arch.split(',')
156 | cur_C = 3
157 | if len(arch) > 0 and arch[0][0] == 'I':
158 | cur_C = int(arch[0][1:])
159 | arch = arch[1:]
160 |
161 | first_conv = True
162 | flat = False
163 | layers = []
164 | for i, s in enumerate(arch):
165 | if s[0] == 'C':
166 | if not first_conv:
167 | layers.append(get_normalization_2d(cur_C, normalization))
168 | layers.append(get_activation(activation))
169 | first_conv = False
170 | vals = [int(i) for i in s[1:].split('-')]
171 | if len(vals) == 2:
172 | K, next_C = vals
173 | stride = 1
174 | elif len(vals) == 3:
175 | K, next_C, stride = vals
176 | # K, next_C = (int(i) for i in s[1:].split('-'))
177 | P = _get_padding(K, padding)
178 | conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride)
179 | layers.append(conv)
180 | _init_conv(layers[-1], init)
181 | cur_C = next_C
182 | elif s[0] == 'R':
183 | norm = 'none' if first_conv else normalization
184 | res = ResidualBlock(cur_C, normalization=norm, activation=activation,
185 | padding=padding, init=init)
186 | layers.append(res)
187 | first_conv = False
188 | elif s[0] == 'U':
189 | factor = int(s[1:])
190 | layers.append(nn.Upsample(scale_factor=factor, mode='nearest'))
191 | elif s[0] == 'P':
192 | factor = int(s[1:])
193 | if pooling == 'max':
194 | pool = nn.MaxPool2d(kernel_size=factor, stride=factor)
195 | elif pooling == 'avg':
196 | pool = nn.AvgPool2d(kernel_size=factor, stride=factor)
197 | layers.append(pool)
198 | elif s[:2] == 'FC':
199 | _, Din, Dout = s.split('-')
200 | Din, Dout = int(Din), int(Dout)
201 | if not flat:
202 | layers.append(Flatten())
203 | flat = True
204 | layers.append(nn.Linear(Din, Dout))
205 | if i + 1 < len(arch):
206 | layers.append(get_activation(activation))
207 | cur_C = Dout
208 | else:
209 | raise ValueError('Invalid layer "%s"' % s)
210 | layers = [layer for layer in layers if layer is not None]
211 | for layer in layers:
212 | print(layer)
213 | return nn.Sequential(*layers), cur_C
214 |
215 |
216 | def build_mlp(dim_list, activation='relu', batch_norm='none',
217 | dropout=0, final_nonlinearity=True):
218 | layers = []
219 | for i in range(len(dim_list) - 1):
220 | dim_in, dim_out = dim_list[i], dim_list[i + 1]
221 | layers.append(nn.Linear(dim_in, dim_out))
222 | final_layer = (i == len(dim_list) - 2)
223 | if not final_layer or final_nonlinearity:
224 | if batch_norm == 'batch':
225 | layers.append(nn.BatchNorm1d(dim_out))
226 | if activation == 'relu':
227 | layers.append(nn.ReLU())
228 | elif activation == 'leakyrelu':
229 | layers.append(nn.LeakyReLU())
230 | if dropout > 0:
231 | layers.append(nn.Dropout(p=dropout))
232 | return nn.Sequential(*layers)
233 |
234 |
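A short sketch of the architecture-string syntax accepted by build_cnn, together with build_mlp and GlobalAvgPool; the arch string and sizes are illustrative only.

```python
import torch
from sg2im.layers import build_cnn, build_mlp, GlobalAvgPool

cnn, out_channels = build_cnn('I3,C3-64,P2,C3-128,P2', normalization='batch')
x = torch.randn(2, 3, 32, 32)
feats = cnn(x)                       # (2, 128, 8, 8); out_channels == 128

mlp = build_mlp([128, 256, 4], activation='relu', final_nonlinearity=False)
y = mlp(GlobalAvgPool()(feats))      # (2, 4)
```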
--------------------------------------------------------------------------------
/sg2im/layout.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | from sg2im.utils import timeit, get_gpu_memory, lineno
21 |
22 |
23 | """
24 | Functions for computing image layouts from object vectors, bounding boxes,
25 | and segmentation masks. These are used to compute coarse scene layouts which
26 | are then fed as input to the cascaded refinement network.
27 | """
28 |
29 |
30 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'):
31 | """
32 | Inputs:
33 | - vecs: Tensor of shape (O, D) giving vectors
34 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
35 | [x0, y0, x1, y1] in the [0, 1] coordinate space
36 | - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to
37 | an image, where each element is in the range [0, N). If obj_to_img[i] = j
38 | then vecs[i] belongs to image j.
39 | - H, W: Size of the output
40 |
41 | Returns:
42 | - out: Tensor of shape (N, D, H, W)
43 | """
44 | O, D = vecs.size()
45 | if W is None:
46 | W = H
47 |
48 | grid = _boxes_to_grid(boxes, H, W)
49 |
50 | # If we don't add extra spatial dimensions here then out-of-bounds
51 | # elements won't be automatically set to 0
52 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8)
53 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W)
54 |
55 | # Explicitly masking makes everything quite a bit slower.
56 | # If we rely on implicit masking the interpolated boxes end up
57 | # blurred around the edges, but it should be fine.
58 | # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1)
59 | # sampled[mask[:, None]] = 0
60 |
61 | out = _pool_samples(sampled, obj_to_img, pooling=pooling)
62 |
63 | return out
64 |
65 |
66 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum'):
67 | """
68 | Inputs:
69 | - vecs: Tensor of shape (O, D) giving vectors
70 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
71 | [x0, y0, x1, y1] in the [0, 1] coordinate space
72 | - masks: Tensor of shape (O, M, M) giving binary masks for each object
73 | - obj_to_img: LongTensor of shape (O,) mapping objects to images
74 | - H, W: Size of the output image.
75 |
76 | Returns:
77 | - out: Tensor of shape (N, D, H, W)
78 | """
79 | O, D = vecs.size()
80 | M = masks.size(1)
81 | assert masks.size() == (O, M, M)
82 | if W is None:
83 | W = H
84 |
85 | grid = _boxes_to_grid(boxes, H, W)
86 |
87 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M)
88 | sampled = F.grid_sample(img_in, grid)
89 |
90 | out = _pool_samples(sampled, obj_to_img, pooling=pooling)
91 | return out
92 |
93 |
94 | def _boxes_to_grid(boxes, H, W):
95 | """
96 | Input:
97 | - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1]
98 | format in the [0, 1] coordinate space
99 | - H, W: Scalars giving size of output
100 |
101 | Returns:
102 | - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample
103 | """
104 | O = boxes.size(0)
105 |
106 | boxes = boxes.view(O, 4, 1, 1)
107 |
108 | # All these are (O, 1, 1)
109 | x0, y0 = boxes[:, 0], boxes[:, 1]
110 | x1, y1 = boxes[:, 2], boxes[:, 3]
111 | ww = x1 - x0
112 | hh = y1 - y0
113 |
114 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes)
115 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes)
116 |
117 | X = (X - x0) / ww # (O, 1, W)
118 | Y = (Y - y0) / hh # (O, H, 1)
119 |
120 | # Stack does not broadcast its arguments so we need to expand explicitly
121 | X = X.expand(O, H, W)
122 | Y = Y.expand(O, H, W)
123 | grid = torch.stack([X, Y], dim=3) # (O, H, W, 2)
124 |
125 | # Right now grid is in [0, 1] space; transform to [-1, 1]
126 | grid = grid.mul(2).sub(1)
127 |
128 | return grid
129 |
130 |
131 | def _pool_samples(samples, obj_to_img, pooling='sum'):
132 | """
133 | Input:
134 | - samples: FloatTensor of shape (O, D, H, W)
135 | - obj_to_img: LongTensor of shape (O,) with each element in the range
136 | [0, N) mapping elements of samples to output images
137 |
138 | Output:
139 | - pooled: FloatTensor of shape (N, D, H, W)
140 | """
141 | dtype, device = samples.dtype, samples.device
142 | O, D, H, W = samples.size()
143 | N = obj_to_img.data.max().item() + 1
144 |
145 | # Use scatter_add to sum the sampled outputs for each image
146 | out = torch.zeros(N, D, H, W, dtype=dtype, device=device)
147 | idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W)
148 | out = out.scatter_add(0, idx, samples)
149 |
150 | if pooling == 'avg':
151 | # Divide each output mask by the number of objects; use scatter_add again
152 | # to count the number of objects per image.
153 | ones = torch.ones(O, dtype=dtype, device=device)
154 | obj_counts = torch.zeros(N, dtype=dtype, device=device)
155 |     obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
156 |     # Clamp to avoid dividing by zero for images that contain no objects.
157 |     obj_counts = obj_counts.clamp(min=1)
158 | out = out / obj_counts.view(N, 1, 1, 1)
159 | elif pooling != 'sum':
160 | raise ValueError('Invalid pooling "%s"' % pooling)
161 |
162 | return out
163 |
164 |
165 | if __name__ == '__main__':
166 | vecs = torch.FloatTensor([
167 | [1, 0, 0], [0, 1, 0], [0, 0, 1],
168 | [1, 0, 0], [0, 1, 0], [0, 0, 1],
169 | ])
170 | boxes = torch.FloatTensor([
171 | [0.25, 0.125, 0.5, 0.875],
172 | [0, 0, 1, 0.25],
173 | [0.6125, 0, 0.875, 1],
174 | [0, 0.8, 1, 1.0],
175 | [0.25, 0.125, 0.5, 0.875],
176 | [0.6125, 0, 0.875, 1],
177 | ])
178 | obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1])
179 | # vecs = torch.FloatTensor([[[1]]])
180 | # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]])
181 | vecs, boxes = vecs.cuda(), boxes.cuda()
182 | obj_to_img = obj_to_img.cuda()
183 | out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum')
184 |
185 | from torchvision.utils import save_image
186 | save_image(out.data, 'out.png')
187 |
188 |
189 | masks = torch.FloatTensor([
190 | [
191 | [0, 0, 1, 0, 0],
192 | [0, 1, 1, 1, 0],
193 | [1, 1, 1, 1, 1],
194 | [0, 1, 1, 1, 0],
195 | [0, 0, 1, 0, 0],
196 | ],
197 | [
198 | [0, 0, 1, 0, 0],
199 | [0, 1, 0, 1, 0],
200 | [1, 0, 0, 0, 1],
201 | [0, 1, 0, 1, 0],
202 | [0, 0, 1, 0, 0],
203 | ],
204 | [
205 | [0, 0, 1, 0, 0],
206 | [0, 1, 1, 1, 0],
207 | [1, 1, 1, 1, 1],
208 | [0, 1, 1, 1, 0],
209 | [0, 0, 1, 0, 0],
210 | ],
211 | [
212 | [0, 0, 1, 0, 0],
213 | [0, 1, 1, 1, 0],
214 | [1, 1, 1, 1, 1],
215 | [0, 1, 1, 1, 0],
216 | [0, 0, 1, 0, 0],
217 | ],
218 | [
219 | [0, 0, 1, 0, 0],
220 | [0, 1, 1, 1, 0],
221 | [1, 1, 1, 1, 1],
222 | [0, 1, 1, 1, 0],
223 | [0, 0, 1, 0, 0],
224 | ],
225 | [
226 | [0, 0, 1, 0, 0],
227 | [0, 1, 1, 1, 0],
228 | [1, 1, 1, 1, 1],
229 | [0, 1, 1, 1, 0],
230 | [0, 0, 1, 0, 0],
231 | ]
232 | ])
233 | masks = masks.cuda()
234 | out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256)
235 | save_image(out.data, 'out_masks.png')
236 |
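The __main__ demo above assumes a CUDA device; the following is a CPU-only sketch of boxes_to_layout with two objects in a single image.

```python
import torch
from sg2im.layout import boxes_to_layout

vecs = torch.randn(2, 16)                           # D = 16 features per object
boxes = torch.FloatTensor([[0.1, 0.1, 0.5, 0.5],
                           [0.4, 0.4, 0.9, 0.9]])   # (x0, y0, x1, y1) in [0, 1]
obj_to_img = torch.LongTensor([0, 0])               # both objects belong to image 0
layout = boxes_to_layout(vecs, boxes, obj_to_img, H=64, W=64)
# layout: (1, 16, 64, 64); each object's vector is painted into its box and summed.
```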
--------------------------------------------------------------------------------
/sg2im/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn.functional as F
19 |
20 |
21 | def get_gan_losses(gan_type):
22 | """
23 | Returns the generator and discriminator loss for a particular GAN type.
24 |
25 | The returned functions have the following API:
26 | loss_g = g_loss(scores_fake)
27 | loss_d = d_loss(scores_real, scores_fake)
28 | """
29 | if gan_type == 'gan':
30 | return gan_g_loss, gan_d_loss
31 | elif gan_type == 'wgan':
32 | return wgan_g_loss, wgan_d_loss
33 | elif gan_type == 'lsgan':
34 | return lsgan_g_loss, lsgan_d_loss
35 | else:
36 | raise ValueError('Unrecognized GAN type "%s"' % gan_type)
37 |
38 |
39 | def bce_loss(input, target):
40 | """
41 | Numerically stable version of the binary cross-entropy loss function.
42 |
43 | As per https://github.com/pytorch/pytorch/issues/751
44 | See the TensorFlow docs for a derivation of this formula:
45 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
46 |
47 | Inputs:
48 | - input: PyTorch Tensor of shape (N, ) giving scores.
49 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.
50 |
51 | Returns:
52 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of
53 | input data.
54 | """
55 | neg_abs = -input.abs()
56 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
57 | return loss.mean()
58 |
59 |
60 | def _make_targets(x, y):
61 | """
62 | Inputs:
63 | - x: PyTorch Tensor
64 | - y: Python scalar
65 |
66 | Outputs:
67 | - out: PyTorch Variable with same shape and dtype as x, but filled with y
68 | """
69 | return torch.full_like(x, y)
70 |
71 |
72 | def gan_g_loss(scores_fake):
73 | """
74 | Input:
75 | - scores_fake: Tensor of shape (N,) containing scores for fake samples
76 |
77 | Output:
78 | - loss: Variable of shape (,) giving GAN generator loss
79 | """
80 | if scores_fake.dim() > 1:
81 | scores_fake = scores_fake.view(-1)
82 | y_fake = _make_targets(scores_fake, 1)
83 | return bce_loss(scores_fake, y_fake)
84 |
85 |
86 | def gan_d_loss(scores_real, scores_fake):
87 | """
88 | Input:
89 | - scores_real: Tensor of shape (N,) giving scores for real samples
90 | - scores_fake: Tensor of shape (N,) giving scores for fake samples
91 |
92 | Output:
93 | - loss: Tensor of shape (,) giving GAN discriminator loss
94 | """
95 | assert scores_real.size() == scores_fake.size()
96 | if scores_real.dim() > 1:
97 | scores_real = scores_real.view(-1)
98 | scores_fake = scores_fake.view(-1)
99 | y_real = _make_targets(scores_real, 1)
100 | y_fake = _make_targets(scores_fake, 0)
101 | loss_real = bce_loss(scores_real, y_real)
102 | loss_fake = bce_loss(scores_fake, y_fake)
103 | return loss_real + loss_fake
104 |
105 |
106 | def wgan_g_loss(scores_fake):
107 | """
108 | Input:
109 | - scores_fake: Tensor of shape (N,) containing scores for fake samples
110 |
111 | Output:
112 | - loss: Tensor of shape (,) giving WGAN generator loss
113 | """
114 | return -scores_fake.mean()
115 |
116 |
117 | def wgan_d_loss(scores_real, scores_fake):
118 | """
119 | Input:
120 | - scores_real: Tensor of shape (N,) giving scores for real samples
121 | - scores_fake: Tensor of shape (N,) giving scores for fake samples
122 |
123 | Output:
124 | - loss: Tensor of shape (,) giving WGAN discriminator loss
125 | """
126 | return scores_fake.mean() - scores_real.mean()
127 |
128 |
129 | def lsgan_g_loss(scores_fake):
130 | if scores_fake.dim() > 1:
131 | scores_fake = scores_fake.view(-1)
132 | y_fake = _make_targets(scores_fake, 1)
133 | return F.mse_loss(scores_fake.sigmoid(), y_fake)
134 |
135 |
136 | def lsgan_d_loss(scores_real, scores_fake):
137 | assert scores_real.size() == scores_fake.size()
138 | if scores_real.dim() > 1:
139 | scores_real = scores_real.view(-1)
140 | scores_fake = scores_fake.view(-1)
141 | y_real = _make_targets(scores_real, 1)
142 | y_fake = _make_targets(scores_fake, 0)
143 | loss_real = F.mse_loss(scores_real.sigmoid(), y_real)
144 | loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake)
145 | return loss_real + loss_fake
146 |
147 |
148 | def gradient_penalty(x_real, x_fake, f, gamma=1.0):
149 | N = x_real.size(0)
150 | device, dtype = x_real.device, x_real.dtype
151 | eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype)
152 | x_hat = eps * x_real + (1 - eps) * x_fake
153 | x_hat_score = f(x_hat)
154 | if x_hat_score.dim() > 1:
155 | x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1)
156 | x_hat_score = x_hat_score.sum()
157 | grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
158 | grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1)
159 | gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
160 | return gp_loss
161 |
162 |
163 |
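A minimal sketch of the loss-selection API on random scores; in training these scores would come from the discriminators defined earlier.

```python
import torch
from sg2im.losses import get_gan_losses

g_loss_fn, d_loss_fn = get_gan_losses('gan')
scores_real = torch.randn(8)                  # discriminator scores for real samples
scores_fake = torch.randn(8)                  # discriminator scores for fake samples
loss_g = g_loss_fn(scores_fake)               # generator loss
loss_d = d_loss_fn(scores_real, scores_fake)  # discriminator loss
```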
--------------------------------------------------------------------------------
/sg2im/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 |
19 |
20 | def intersection(bbox_pred, bbox_gt):
21 | max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:])
22 | min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2])
23 | inter = torch.clamp((max_xy - min_xy), min=0)
24 | return inter[:, 0] * inter[:, 1]
25 |
26 |
27 | def jaccard(bbox_pred, bbox_gt):
28 | inter = intersection(bbox_pred, bbox_gt)
29 | area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] -
30 | bbox_pred[:, 1])
31 | area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] -
32 | bbox_gt[:, 1])
33 | union = area_pred + area_gt - inter
34 | iou = torch.div(inter, union)
35 | return torch.sum(iou)
36 |
37 | def get_total_norm(parameters, norm_type=2):
38 |   if norm_type == float('inf'):
39 |     total_norm = max(p.grad.data.abs().max() for p in parameters)
40 |   else:
41 |     total_norm = 0
42 |     for p in parameters:
43 |       # Skip parameters that did not receive gradients.
44 |       if p.grad is None:
45 |         continue
46 |       param_norm = p.grad.data.norm(norm_type)
47 |       total_norm += param_norm ** norm_type
48 |     total_norm = total_norm ** (1. / norm_type)
49 |   return total_norm
50 |
51 |
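A worked example for jaccard; note that it returns the sum of per-box IoUs over the batch rather than the mean.

```python
import torch
from sg2im.metrics import jaccard

bbox_pred = torch.FloatTensor([[0.0, 0.0, 1.0, 1.0],
                               [0.0, 0.0, 0.5, 0.5]])
bbox_gt = torch.FloatTensor([[0.0, 0.0, 1.0, 1.0],
                             [0.0, 0.0, 1.0, 1.0]])
print(jaccard(bbox_pred, bbox_gt))   # IoUs are 1.0 and 0.25, so this prints tensor(1.2500)
```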
--------------------------------------------------------------------------------
/sg2im/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 | import sg2im.box_utils as box_utils
23 | from sg2im.graph import GraphTripleConv, GraphTripleConvNet
24 | from sg2im.crn import RefinementNetwork
25 | from sg2im.layout import boxes_to_layout, masks_to_layout
26 | from sg2im.layers import build_mlp
27 |
28 |
29 | class Sg2ImModel(nn.Module):
30 | def __init__(self, vocab, image_size=(64, 64), embedding_dim=64,
31 | gconv_dim=128, gconv_hidden_dim=512,
32 | gconv_pooling='avg', gconv_num_layers=5,
33 | refinement_dims=(1024, 512, 256, 128, 64),
34 | normalization='batch', activation='leakyrelu-0.2',
35 | mask_size=None, mlp_normalization='none', layout_noise_dim=0,
36 | **kwargs):
37 | super(Sg2ImModel, self).__init__()
38 |
39 | # We used to have some additional arguments:
40 | # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions
41 | if len(kwargs) > 0:
42 | print('WARNING: Model got unexpected kwargs ', kwargs)
43 |
44 | self.vocab = vocab
45 | self.image_size = image_size
46 | self.layout_noise_dim = layout_noise_dim
47 |
48 | num_objs = len(vocab['object_idx_to_name'])
49 | num_preds = len(vocab['pred_idx_to_name'])
50 | self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
51 | self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)
52 |
53 | if gconv_num_layers == 0:
54 | self.gconv = nn.Linear(embedding_dim, gconv_dim)
55 | elif gconv_num_layers > 0:
56 | gconv_kwargs = {
57 | 'input_dim': embedding_dim,
58 | 'output_dim': gconv_dim,
59 | 'hidden_dim': gconv_hidden_dim,
60 | 'pooling': gconv_pooling,
61 | 'mlp_normalization': mlp_normalization,
62 | }
63 | self.gconv = GraphTripleConv(**gconv_kwargs)
64 |
65 | self.gconv_net = None
66 | if gconv_num_layers > 1:
67 | gconv_kwargs = {
68 | 'input_dim': gconv_dim,
69 | 'hidden_dim': gconv_hidden_dim,
70 | 'pooling': gconv_pooling,
71 | 'num_layers': gconv_num_layers - 1,
72 | 'mlp_normalization': mlp_normalization,
73 | }
74 | self.gconv_net = GraphTripleConvNet(**gconv_kwargs)
75 |
76 | box_net_dim = 4
77 | box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
78 | self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)
79 |
80 | self.mask_net = None
81 | if mask_size is not None and mask_size > 0:
82 | self.mask_net = self._build_mask_net(num_objs, gconv_dim, mask_size)
83 |
84 | rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
85 | self.rel_aux_net = build_mlp(rel_aux_layers, batch_norm=mlp_normalization)
86 |
87 | refinement_kwargs = {
88 | 'dims': (gconv_dim + layout_noise_dim,) + refinement_dims,
89 | 'normalization': normalization,
90 | 'activation': activation,
91 | }
92 | self.refinement_net = RefinementNetwork(**refinement_kwargs)
93 |
94 | def _build_mask_net(self, num_objs, dim, mask_size):
95 | output_dim = 1
96 | layers, cur_size = [], 1
97 | while cur_size < mask_size:
98 | layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
99 | layers.append(nn.BatchNorm2d(dim))
100 | layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=1))
101 | layers.append(nn.ReLU())
102 | cur_size *= 2
103 | if cur_size != mask_size:
104 | raise ValueError('Mask size must be a power of 2')
105 | layers.append(nn.Conv2d(dim, output_dim, kernel_size=1))
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, objs, triples, obj_to_img=None,
109 | boxes_gt=None, masks_gt=None):
110 | """
111 | Required Inputs:
112 | - objs: LongTensor of shape (O,) giving categories for all objects
113 | - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
114 | means that there is a triple (objs[s], p, objs[o])
115 |
116 | Optional Inputs:
117 | - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
118 | means that objects[o] is an object in image i. If not given then
119 | all objects are assumed to belong to the same image.
120 | - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
121 | the spatial layout; if not given then use predicted boxes.
122 | """
123 | O, T = objs.size(0), triples.size(0)
124 | s, p, o = triples.chunk(3, dim=1) # All have shape (T, 1)
125 | s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,)
126 | edges = torch.stack([s, o], dim=1) # Shape is (T, 2)
127 |
128 | if obj_to_img is None:
129 | obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)
130 |
131 | obj_vecs = self.obj_embeddings(objs)
132 | obj_vecs_orig = obj_vecs
133 | pred_vecs = self.pred_embeddings(p)
134 |
135 | if isinstance(self.gconv, nn.Linear):
136 | obj_vecs = self.gconv(obj_vecs)
137 | else:
138 | obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
139 | if self.gconv_net is not None:
140 | obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)
141 |
142 | boxes_pred = self.box_net(obj_vecs)
143 |
144 | masks_pred = None
145 | if self.mask_net is not None:
146 | mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
147 | masks_pred = mask_scores.squeeze(1).sigmoid()
148 |
149 | s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
150 | s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
151 | rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
152 | rel_scores = self.rel_aux_net(rel_aux_input)
153 |
154 | H, W = self.image_size
155 | layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
156 |
157 | if masks_pred is None:
158 | layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
159 | else:
160 | layout_masks = masks_pred if masks_gt is None else masks_gt
161 | layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
162 | obj_to_img, H, W)
163 |
164 | if self.layout_noise_dim > 0:
165 | N, C, H, W = layout.size()
166 | noise_shape = (N, self.layout_noise_dim, H, W)
167 | layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
168 | device=layout.device)
169 | layout = torch.cat([layout, layout_noise], dim=1)
170 | img = self.refinement_net(layout)
171 | return img, boxes_pred, masks_pred, rel_scores
172 |
173 | def encode_scene_graphs(self, scene_graphs):
174 | """
175 | Encode one or more scene graphs using this model's vocabulary. Inputs to
176 | this method are scene graphs represented as dictionaries like the following:
177 |
178 | {
179 | "objects": ["cat", "dog", "sky"],
180 | "relationships": [
181 | [0, "next to", 1],
182 | [0, "beneath", 2],
183 | [2, "above", 1],
184 | ]
185 | }
186 |
187 |     This scene graph has three relationships: cat next to dog, cat beneath sky,
188 | and sky above dog.
189 |
190 | Inputs:
191 | - scene_graphs: A dictionary giving a single scene graph, or a list of
192 | dictionaries giving a sequence of scene graphs.
193 |
194 | Returns a tuple of LongTensors (objs, triples, obj_to_img) that have the
195 | same semantics as self.forward. The returned LongTensors will be on the
196 | same device as the model parameters.
197 | """
198 | if isinstance(scene_graphs, dict):
199 | # We just got a single scene graph, so promote it to a list
200 | scene_graphs = [scene_graphs]
201 |
202 | objs, triples, obj_to_img = [], [], []
203 | obj_offset = 0
204 | for i, sg in enumerate(scene_graphs):
205 | # Insert dummy __image__ object and __in_image__ relationships
206 | sg['objects'].append('__image__')
207 | image_idx = len(sg['objects']) - 1
208 | for j in range(image_idx):
209 | sg['relationships'].append([j, '__in_image__', image_idx])
210 |
211 | for obj in sg['objects']:
212 | obj_idx = self.vocab['object_name_to_idx'].get(obj, None)
213 | if obj_idx is None:
214 | raise ValueError('Object "%s" not in vocab' % obj)
215 | objs.append(obj_idx)
216 | obj_to_img.append(i)
217 | for s, p, o in sg['relationships']:
218 | pred_idx = self.vocab['pred_name_to_idx'].get(p, None)
219 | if pred_idx is None:
220 | raise ValueError('Relationship "%s" not in vocab' % p)
221 | triples.append([s + obj_offset, pred_idx, o + obj_offset])
222 | obj_offset += len(sg['objects'])
223 | device = next(self.parameters()).device
224 | objs = torch.tensor(objs, dtype=torch.int64, device=device)
225 | triples = torch.tensor(triples, dtype=torch.int64, device=device)
226 | obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64, device=device)
227 | return objs, triples, obj_to_img
228 |
229 | def forward_json(self, scene_graphs):
230 | """ Convenience method that combines encode_scene_graphs and forward. """
231 | objs, triples, obj_to_img = self.encode_scene_graphs(scene_graphs)
232 | return self.forward(objs, triples, obj_to_img)
233 |
234 |
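An untested sketch of the end-to-end API: build a small Sg2ImModel around a made-up vocabulary and call forward_json on the scene-graph format shown in the docstring above. Real vocabularies come from the dataset preprocessing, and the model also relies on sg2im.crn.RefinementNetwork with its default dimensions.

```python
import torch
from sg2im.model import Sg2ImModel

# Toy vocabulary, invented for illustration only.
vocab = {
  'object_idx_to_name': ['__image__', 'cat', 'dog', 'sky'],
  'object_name_to_idx': {'__image__': 0, 'cat': 1, 'dog': 2, 'sky': 3},
  'pred_idx_to_name': ['__in_image__', 'next to', 'beneath', 'above'],
  'pred_name_to_idx': {'__in_image__': 0, 'next to': 1, 'beneath': 2, 'above': 3},
}
model = Sg2ImModel(vocab, image_size=(64, 64))

sg = {
  'objects': ['cat', 'dog', 'sky'],
  'relationships': [[0, 'next to', 1], [0, 'beneath', 2], [2, 'above', 1]],
}
with torch.no_grad():
  img, boxes_pred, masks_pred, rel_scores = model.forward_json(sg)
# Expected shapes: img (1, 3, 64, 64); boxes_pred (4, 4) including the __image__
# box; masks_pred is None because mask_size was not set.
```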
--------------------------------------------------------------------------------
/sg2im/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 | import inspect
19 | import subprocess
20 | from contextlib import contextmanager
21 |
22 | import torch
23 |
24 |
25 | def int_tuple(s):
26 | return tuple(int(i) for i in s.split(','))
27 |
28 |
29 | def float_tuple(s):
30 | return tuple(float(i) for i in s.split(','))
31 |
32 |
33 | def str_tuple(s):
34 | return tuple(s.split(','))
35 |
36 |
37 | def bool_flag(s):
38 | if s == '1':
39 | return True
40 | elif s == '0':
41 | return False
42 | msg = 'Invalid value "%s" for bool flag (should be 0 or 1)'
43 | raise ValueError(msg % s)
44 |
45 |
46 | def lineno():
47 | return inspect.currentframe().f_back.f_lineno
48 |
49 |
50 | def get_gpu_memory():
51 | torch.cuda.synchronize()
52 | opts = [
53 | 'nvidia-smi', '-q', '--gpu=' + str(0), '|', 'grep', '"Used GPU Memory"'
54 | ]
55 | cmd = str.join(' ', opts)
56 | ps = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
57 | output = ps.communicate()[0].decode('utf-8')
58 | output = output.split("\n")[1].split(":")
59 | consumed_mem = int(output[1].strip().split(" ")[0])
60 | return consumed_mem
61 |
62 |
63 | @contextmanager
64 | def timeit(msg, should_time=True):
65 | if should_time:
66 | torch.cuda.synchronize()
67 | t0 = time.time()
68 | yield
69 | if should_time:
70 | torch.cuda.synchronize()
71 | t1 = time.time()
72 | duration = (t1 - t0) * 1000.0
73 | print('%s: %.2f ms' % (msg, duration))
74 |
75 |
76 | class LossManager(object):
77 | def __init__(self):
78 | self.total_loss = None
79 | self.all_losses = {}
80 |
81 | def add_loss(self, loss, name, weight=1.0):
82 | cur_loss = loss * weight
83 | if self.total_loss is not None:
84 | self.total_loss += cur_loss
85 | else:
86 | self.total_loss = cur_loss
87 |
88 | self.all_losses[name] = cur_loss.data.cpu().item()
89 |
90 | def items(self):
91 | return self.all_losses.items()
92 |
93 |
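A small sketch of LossManager, which accumulates weighted loss terms into a single total for backprop while keeping their scalar values for logging; the values here are dummies.

```python
import torch
from sg2im.utils import LossManager

losses = LossManager()
losses.add_loss(torch.tensor(0.5), 'bbox_pred', weight=10.0)
losses.add_loss(torch.tensor(0.25), 'img_l1', weight=1.0)
print(dict(losses.items()))   # {'bbox_pred': 5.0, 'img_l1': 0.25}
# With real, differentiable loss tensors one would call losses.total_loss.backward().
```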
--------------------------------------------------------------------------------
/sg2im/vis.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import tempfile, os
18 | import torch
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | from matplotlib.patches import Rectangle
22 | from imageio import imread
23 |
24 |
25 | """
26 | Utilities for making visualizations.
27 | """
28 |
29 |
30 | def draw_layout(vocab, objs, boxes, masks=None, size=256,
31 | show_boxes=False, bgcolor=(0, 0, 0)):
32 | if bgcolor == 'white':
33 | bgcolor = (255, 255, 255)
34 |
35 | cmap = plt.get_cmap('rainbow')
36 | colors = cmap(np.linspace(0, 1, len(objs)))
37 |
38 | with torch.no_grad():
39 | objs = objs.cpu().clone()
40 | boxes = boxes.cpu().clone()
41 | boxes *= size
42 |
43 | if masks is not None:
44 | masks = masks.cpu().clone()
45 |
46 | bgcolor = np.asarray(bgcolor)
47 | bg = np.ones((size, size, 1)) * bgcolor
48 | plt.imshow(bg.astype(np.uint8))
49 |
50 | plt.gca().set_xlim(0, size)
51 | plt.gca().set_ylim(size, 0)
52 | plt.gca().set_aspect(1.0, adjustable='box')
53 |
54 | for i, obj in enumerate(objs):
55 | name = vocab['object_idx_to_name'][obj]
56 | if name == '__image__':
57 | continue
58 | box = boxes[i]
59 |
60 | if masks is None:
61 | continue
62 | mask = masks[i].numpy()
63 | mask /= mask.max()
64 |
65 | r, g, b, a = colors[i]
66 | colored_mask = mask[:, :, None] * np.asarray(colors[i])
67 |
68 | x0, y0, x1, y1 = box
69 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0),
70 | interpolation='bicubic', alpha=1.0)
71 |
72 | if show_boxes:
73 | for i, obj in enumerate(objs):
74 | name = vocab['object_idx_to_name'][obj]
75 | if name == '__image__':
76 | continue
77 | box = boxes[i]
78 |
79 | draw_box(box, colors[i], name)
80 |
81 |
82 | def draw_box(box, color, text=None):
83 | """
84 | Draw a bounding box using pyplot, optionally with a text box label.
85 |
86 | Inputs:
87 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H]
88 | coordinate system.
89 | - color: pyplot color to use for the box.
90 | - text: (Optional) String; if provided then draw a label for this box.
91 | """
92 | TEXT_BOX_HEIGHT = 10
93 | if torch.is_tensor(box) and box.dim() == 2:
94 | box = box.view(-1)
95 | assert box.size(0) == 4
96 | x0, y0, x1, y1 = box
97 | assert y1 > y0, box
98 | assert x1 > x0, box
99 | w, h = x1 - x0, y1 - y0
100 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color)
101 | plt.gca().add_patch(rect)
102 | if text is not None:
103 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5)
104 | plt.gca().add_patch(text_rect)
105 | tx = 0.5 * (x0 + x1)
106 | ty = y0 + TEXT_BOX_HEIGHT / 2.0
107 | plt.text(tx, ty, text, va='center', ha='center')
108 |
109 |
110 | def draw_scene_graph(objs, triples, vocab=None, **kwargs):
111 | """
112 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume
113 | that objs and triples are python lists containing strings for object and
114 | relationship names.
115 |
116 | Using this requires that GraphViz is installed. On Ubuntu 16.04 this is easy:
117 | sudo apt-get install graphviz
118 | """
119 | output_filename = kwargs.pop('output_filename', 'graph.png')
120 | orientation = kwargs.pop('orientation', 'V')
121 | edge_width = kwargs.pop('edge_width', 6)
122 | arrow_size = kwargs.pop('arrow_size', 1.5)
123 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2)
124 | ignore_dummies = kwargs.pop('ignore_dummies', True)
125 |
126 | if orientation not in ['V', 'H']:
127 | raise ValueError('Invalid orientation "%s"' % orientation)
128 |   rankdir = {'H': 'LR', 'V': 'TB'}[orientation]
129 |
130 | if vocab is not None:
131 | # Decode object and relationship names
132 | assert torch.is_tensor(objs)
133 | assert torch.is_tensor(triples)
134 | objs_list, triples_list = [], []
135 | for i in range(objs.size(0)):
136 | objs_list.append(vocab['object_idx_to_name'][objs[i].item()])
137 | for i in range(triples.size(0)):
138 | s = triples[i, 0].item()
139 |       p = vocab['pred_idx_to_name'][triples[i, 1].item()]
140 | o = triples[i, 2].item()
141 | triples_list.append([s, p, o])
142 | objs, triples = objs_list, triples_list
143 |
144 | # General setup, and style for object nodes
145 | lines = [
146 | 'digraph{',
147 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="transparent"]',
148 | 'rankdir=%s' % rankdir,
149 | 'nodesep="0.5"',
150 | 'ranksep="0.5"',
151 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]',
152 | 'node [fillcolor="lightpink1"]',
153 | ]
154 | # Output nodes for objects
155 | for i, obj in enumerate(objs):
156 | if ignore_dummies and obj == '__image__':
157 | continue
158 | lines.append('%d [label="%s"]' % (i, obj))
159 |
160 | # Output relationships
161 | next_node_id = len(objs)
162 | lines.append('node [fillcolor="lightblue1"]')
163 | for s, p, o in triples:
164 | if ignore_dummies and p == '__in_image__':
165 | continue
166 | lines += [
167 | '%d [label="%s"]' % (next_node_id, p),
168 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
169 | s, next_node_id, edge_width, arrow_size, binary_edge_weight),
170 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
171 | next_node_id, o, edge_width, arrow_size, binary_edge_weight)
172 | ]
173 | next_node_id += 1
174 | lines.append('}')
175 |
176 | # Now it gets slightly hacky. Write the graphviz spec to a temporary
177 | # text file
178 | ff, dot_filename = tempfile.mkstemp()
179 | with open(dot_filename, 'w') as f:
180 | for line in lines:
181 | f.write('%s\n' % line)
182 | os.close(ff)
183 |
184 | # Shell out to invoke graphviz; this will save the resulting image to disk,
185 | # so we read it, delete it, then return it.
186 | output_format = os.path.splitext(output_filename)[1][1:]
187 | os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename))
188 | os.remove(dot_filename)
189 | img = imread(output_filename)
190 | os.remove(output_filename)
191 |
192 | return img
193 |
194 |
195 | if __name__ == '__main__':
196 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard']
197 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above']
198 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)}
199 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)}
200 | vocab = {
201 | 'object_idx_to_name': o_idx_to_name,
202 | 'object_name_to_idx': o_name_to_idx,
203 | 'pred_idx_to_name': p_idx_to_name,
204 | 'pred_name_to_idx': p_name_to_idx,
205 | }
206 |
207 | objs = [
208 | 'cat',
209 | 'cat',
210 | 'skateboard',
211 | 'hat',
212 | ]
213 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs])
214 | triples = [
215 | [0, 'next to', 1],
216 | [0, 'riding', 2],
217 | [1, 'wearing', 3],
218 | [3, 'above', 2],
219 | ]
220 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples]
221 | triples = torch.LongTensor(triples)
222 |
223 | draw_scene_graph(objs, triples, vocab, orientation='V')
224 |
225 |
--------------------------------------------------------------------------------