├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── args_64_crn_clevr.yaml
├── args_64_crn_vg.yaml
├── args_64_spade_clevr.yaml
├── args_64_spade_vg.yaml
├── requirements.txt
├── scripts
    ├── SIMSG_gui.py
    ├── download_vg.sh
    ├── eval_utils.py
    ├── evaluate_changes_clevr.py
    ├── evaluate_changes_vg.py
    ├── evaluate_reconstruction.py
    ├── preprocess_vg.py
    ├── print_args.py
    ├── run_train.py
    ├── train.py
    └── train_utils.py
└── simsg
    ├── SPADE
        ├── __init__.py
        ├── architectures.py
        ├── base_network.py
        └── normalization.py
    ├── __init__.py
    ├── bilinear.py
    ├── data
        ├── __init__.py
        ├── clevr.py
        ├── clevr_gen
        │   ├── 1_gen_data.sh
        │   ├── 2_arrange_data.sh
        │   ├── 3_clevrToVG.py
        │   ├── README_CLEVR.md
        │   ├── collect_scenes.py
        │   ├── data
        │   │   ├── CoGenT_A.json
        │   │   ├── CoGenT_B.json
        │   │   ├── base_scene.blend
        │   │   ├── materials
        │   │   │   ├── MyMetal.blend
        │   │   │   └── Rubber.blend
        │   │   ├── properties.json
        │   │   └── shapes
        │   │   │   ├── SmoothCube_v2.blend
        │   │   │   ├── SmoothCylinder.blend
        │   │   │   └── Sphere.blend
        │   ├── render_clevr.py
        │   └── utils.py
        ├── utils.py
        ├── vg.py
        └── vg_splits.json
    ├── decoder.py
    ├── discriminators.py
    ├── feats_statistics.py
    ├── graph.py
    ├── layers.py
    ├── layout.py
    ├── loader_utils.py
    ├── losses.py
    ├── metrics.py
    ├── model.py
    ├── utils.py
    └── vis.py
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 | 
 30 | # PyInstaller
 31 | #  Usually these files are written by a python script from a template
 32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | *.py,cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 | 
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 | 
 68 | # Scrapy stuff:
 69 | .scrapy
 70 | 
 71 | # Sphinx documentation
 72 | docs/_build/
 73 | 
 74 | # PyBuilder
 75 | target/
 76 | 
 77 | # Jupyter Notebook
 78 | .ipynb_checkpoints
 79 | 
 80 | # IPython
 81 | profile_default/
 82 | ipython_config.py
 83 | 
 84 | # pyenv
 85 | .python-version
 86 | 
 87 | # pipenv
 88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 91 | #   install all needed dependencies.
 92 | #Pipfile.lock
 93 | 
 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 | 
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "PerceptualSimilarity"]
2 | 	path = PerceptualSimilarity
3 | 	url = https://github.com/richzhang/PerceptualSimilarity
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
  1 |                                  Apache License
  2 |                            Version 2.0, January 2004
  3 |                         http://www.apache.org/licenses/
  4 | 
  5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 | 
  7 |    1. Definitions.
  8 | 
  9 |       "License" shall mean the terms and conditions for use, reproduction,
 10 |       and distribution as defined by Sections 1 through 9 of this document.
 11 | 
 12 |       "Licensor" shall mean the copyright owner or entity authorized by
 13 |       the copyright owner that is granting the License.
 14 | 
 15 |       "Legal Entity" shall mean the union of the acting entity and all
 16 |       other entities that control, are controlled by, or are under common
 17 |       control with that entity. For the purposes of this definition,
 18 |       "control" means (i) the power, direct or indirect, to cause the
 19 |       direction or management of such entity, whether by contract or
 20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 21 |       outstanding shares, or (iii) beneficial ownership of such entity.
 22 | 
 23 |       "You" (or "Your") shall mean an individual or Legal Entity
 24 |       exercising permissions granted by this License.
 25 | 
 26 |       "Source" form shall mean the preferred form for making modifications,
 27 |       including but not limited to software source code, documentation
 28 |       source, and configuration files.
 29 | 
 30 |       "Object" form shall mean any form resulting from mechanical
 31 |       transformation or translation of a Source form, including but
 32 |       not limited to compiled object code, generated documentation,
 33 |       and conversions to other media types.
 34 | 
 35 |       "Work" shall mean the work of authorship, whether in Source or
 36 |       Object form, made available under the License, as indicated by a
 37 |       copyright notice that is included in or attached to the work
 38 |       (an example is provided in the Appendix below).
 39 | 
 40 |       "Derivative Works" shall mean any work, whether in Source or Object
 41 |       form, that is based on (or derived from) the Work and for which the
 42 |       editorial revisions, annotations, elaborations, or other modifications
 43 |       represent, as a whole, an original work of authorship. For the purposes
 44 |       of this License, Derivative Works shall not include works that remain
 45 |       separable from, or merely link (or bind by name) to the interfaces of,
 46 |       the Work and Derivative Works thereof.
 47 | 
 48 |       "Contribution" shall mean any work of authorship, including
 49 |       the original version of the Work and any modifications or additions
 50 |       to that Work or Derivative Works thereof, that is intentionally
 51 |       submitted to Licensor for inclusion in the Work by the copyright owner
 52 |       or by an individual or Legal Entity authorized to submit on behalf of
 53 |       the copyright owner. For the purposes of this definition, "submitted"
 54 |       means any form of electronic, verbal, or written communication sent
 55 |       to the Licensor or its representatives, including but not limited to
 56 |       communication on electronic mailing lists, source code control systems,
 57 |       and issue tracking systems that are managed by, or on behalf of, the
 58 |       Licensor for the purpose of discussing and improving the Work, but
 59 |       excluding communication that is conspicuously marked or otherwise
 60 |       designated in writing by the copyright owner as "Not a Contribution."
 61 | 
 62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
 63 |       on behalf of whom a Contribution has been received by Licensor and
 64 |       subsequently incorporated within the Work.
 65 | 
 66 |    2. Grant of Copyright License. Subject to the terms and conditions of
 67 |       this License, each Contributor hereby grants to You a perpetual,
 68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 69 |       copyright license to reproduce, prepare Derivative Works of,
 70 |       publicly display, publicly perform, sublicense, and distribute the
 71 |       Work and such Derivative Works in Source or Object form.
 72 | 
 73 |    3. Grant of Patent License. Subject to the terms and conditions of
 74 |       this License, each Contributor hereby grants to You a perpetual,
 75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 76 |       (except as stated in this section) patent license to make, have made,
 77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
 78 |       where such license applies only to those patent claims licensable
 79 |       by such Contributor that are necessarily infringed by their
 80 |       Contribution(s) alone or by combination of their Contribution(s)
 81 |       with the Work to which such Contribution(s) was submitted. If You
 82 |       institute patent litigation against any entity (including a
 83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
 84 |       or a Contribution incorporated within the Work constitutes direct
 85 |       or contributory patent infringement, then any patent licenses
 86 |       granted to You under this License for that Work shall terminate
 87 |       as of the date such litigation is filed.
 88 | 
 89 |    4. Redistribution. You may reproduce and distribute copies of the
 90 |       Work or Derivative Works thereof in any medium, with or without
 91 |       modifications, and in Source or Object form, provided that You
 92 |       meet the following conditions:
 93 | 
 94 |       (a) You must give any other recipients of the Work or
 95 |           Derivative Works a copy of this License; and
 96 | 
 97 |       (b) You must cause any modified files to carry prominent notices
 98 |           stating that You changed the files; and
 99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!)  The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # SIMSG
  2 | 
  3 | This is the code accompanying the paper
  4 | 
  5 | **Semantic Image Manipulation Using Scene Graphs** | arXiv 
  6 | Helisa Dhamo*, Azade Farshad*, Iro Laina, Nassir Navab, Gregory D. Hager, Federico Tombari, Christian Rupprecht 
  7 | **CVPR 2020**
  8 | 
  9 | The work for this paper was done at the Technical University of Munich.
 10 | 
 11 | In our work, we address the novel problem of image manipulation from scene graphs, in which a user can edit images by 
 12 | merely applying changes in the nodes or edges of a semantic graph that is generated from the image.
 13 | 
 14 | We introduce a spatio-semantic scene graph network that does not require direct supervision for constellation changes or 
 15 | image edits. This makes it possible to train the system from existing real-world datasets with no additional annotation 
 16 | effort.
 17 | 
 18 | If you find this code useful in your research, please cite
 19 | ```
 20 | @inproceedings{dhamo2020_SIMSG,
 21 |   title={Semantic Image Manipulation Using Scene Graphs},
 22 |   author={Dhamo, Helisa and Farshad, Azade and Laina, Iro and Navab, Nassir and
 23 |           Hager, Gregory D. and Tombari, Federico and Rupprecht, Christian},
 24 |   booktitle={CVPR},
 25 |   year={2020}
 26 | }
 27 | ```
 27 | 
 28 | **Note:** The project page has been updated with new links for model checkpoints and data.
 29 | 
 30 | 
 31 | ## Setup
 32 | 
 33 | We have tested it on Ubuntu 16.04 with Python 3.7 and PyTorch 1.2.
 34 | 
 35 | ### Setup code
 36 | You can set up a conda environment to run the code as follows:
 37 | 
 38 | ```bash
 39 | # clone this repository and move there
 40 | git clone --recurse-submodules https://github.com/he-dhamo/simsg.git
 41 | cd simsg
 42 | # create a conda environment and install the requirements
 43 | conda create --name simsg_env python=3.7 --file requirements.txt 
 44 | conda activate simsg_env          # activate virtual environment
 45 | # install pytorch and cuda version as tested in our work
 46 | conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
 47 | # more installations
 48 | pip install opencv-python tensorboardx grave addict
 49 | # to add the current directory to the python path, run
 50 | echo $PWD > <path_to_env>/lib/python3.7/site-packages/simsg.pth
 51 | ```
 52 | `path_to_env` can be found using `which python` (the resulting path minus `/bin/python`) while the conda env is active.
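
For example, if `which python` prints `/home/user/miniconda3/envs/simsg_env/bin/python` (an example path, yours will differ), the last step becomes:

```bash
# write the repository path into the environment's site-packages (example path shown)
echo $PWD > /home/user/miniconda3/envs/simsg_env/lib/python3.7/site-packages/simsg.pth
# quick sanity check: the simsg package should now be importable from anywhere
python -c "import simsg; print(simsg.__file__)"
```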
 53 | 
 54 | ### Setup Visual Genome data
 55 | (instructions from the sg2im repository)
 56 | 
 57 | Run the following script to download and unpack the relevant parts of the Visual Genome dataset:
 58 | 
 59 | ```bash
 60 | bash scripts/download_vg.sh
 61 | ```
 62 | 
 63 | This will create the directory `datasets/vg` and will download about 15 GB of data to this directory; after unpacking it 
 64 | will take about 30 GB of disk space.
 65 | 
 66 | After downloading the Visual Genome dataset, we need to preprocess it. This will split the data into train / val / test 
 67 | splits, consolidate all scene graphs into HDF5 files, and apply several heuristics to clean the data. In particular we 
 68 | ignore images that are too small, and only consider object and attribute categories that appear some number of times in 
 69 | the training set; we also ignore objects that are too small, and set minimum and maximum values on the number of objects 
 70 | and relationships that appear per image.
 71 | 
 72 | ```bash
 73 | python scripts/preprocess_vg.py
 74 | ```
 75 | 
 76 | This will create files `train.h5`, `val.h5`, `test.h5`, and `vocab.json` in the directory `datasets/vg`.
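
After preprocessing, a quick way to check that the expected files were created (paths as stated above):

```bash
ls datasets/vg/train.h5 datasets/vg/val.h5 datasets/vg/test.h5 datasets/vg/vocab.json
```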
 77 | 
 78 | 
 79 | ## Training
 80 | 
 81 | To train the model, you can either set the right options in a `args.yaml` file and then run the `run_train.py` script 
 82 | (**highly recommended**):
 83 | ```
 84 | python scripts/run_train.py args.yaml
 85 | ``` 
 86 | Or, alternatively run the `train.py` script with the respective arguments. For example:
 87 | ```
 88 | python scripts/train.py --checkpoint_name myckpt --dataset vg --spade_gen_blocks True
 89 | ```
 90 | 
 91 | We provide the configuration files for the experiments presented in the paper, named in the format 
 92 | `args_{resolution}_{decoder}_{dataset}.yaml`.
 93 | 
 94 | Please set the dataset path `DATA_DIR` in `train.py` before running.
 95 | 
 96 | Most relevant arguments:
 97 | 
 98 | `--checkpoint_name`: Base filename for saved checkpoints; default is 'checkpoint', so the filename for the checkpoint 
 99 | with model parameters will be 'checkpoint_model.pt' 
100 | `--selective_discr_obj`: If True, apply the object discriminator only to the reconstructed RoIs (masked image areas). 
101 | 
102 | `--feats_in_gcn`: If True, feed the RoI visual features into the scene graph network (SGN). 
103 | `--feats_out_gcn`: If True, concatenate the RoI visual features with the node features (output of the SGN). 
104 | `--is_baseline`: If True, run cond-sg2im baseline. 
105 | `--is_supervised`: If True, run fully-supervised baseline for CLEVR. 
106 | `--spade_gen_blocks`: If True, use SPADE blocks in the decoder architecture. Otherwise, use CRN blocks.
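
For example, a SPADE-based training run on Visual Genome that combines the flags above might look like this (the checkpoint name is only a placeholder):
```
python scripts/train.py --checkpoint_name spade_vg_example --dataset vg --feats_in_gcn True --feats_out_gcn True --spade_gen_blocks True
```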
107 | 
108 | ## Evaluate reconstruction
109 | 
110 | To evaluate the model in reconstruction mode, run the script ```evaluate_reconstruction.py```. The MAE, SSIM 
111 | and LPIPS are printed to the terminal every ```--print_every``` iterations and saved to pickle files every 
112 | ```--save_every``` iterations. When ```--save_images``` is set to ```True```, the code also saves one reconstructed sample per 
113 | image, which can be used for the computation of FID and 
114 | Inception score.
115 | 
116 | Other relevant arguments are:
117 | 
118 | `--exp_dir`: path to folder where all experiments are saved 
119 | `--experiment`: name of experiment, e.g. spade_vg_64 
120 | `--checkpoint`: checkpoint path 
121 | `--with_feats`: (bool) using RoI visual features 
122 | `--generative`: (bool) fully-generative mode, i.e. the whole input image is masked out
123 | 
124 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```<exp_dir>/<experiment>_model.pt```.
125 | 
126 | To reproduce the results from Table 2, please run:
127 | 1. for the fully generative task:
128 |  ```
129 |  python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats True 
130 | --generative True
131 | ```
132 |  2. for the RoI reconstruction with visual features:
133 |  ```
134 |  python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats True 
135 | --generative False
136 | ```
137 |  3. for the RoI reconstruction without visual features:
138 |   ```
139 |   python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats False 
140 | --generative False
141 |   ```
142 | 
143 | To evaluate the model with ground truth (GT) or predicted (PRED) scene graphs, set `--predgraphs` to 
144 | ```False``` or ```True```, respectively. 
145 | Before running with predicted graphs (PRED), make sure you have downloaded the respective predicted graphs 
146 | `test_predgraphs.h5` and placed it in the same directory as ```test.h5```.
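
For example, to repeat the RoI reconstruction with visual features on predicted graphs (assuming `test_predgraphs.h5` is in place):
```
python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats True --generative False --predgraphs True
```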
147 | 
148 | Please set the dataset path `DATA_DIR` before running.
149 | 
150 | For the LPIPS computation, we use PerceptualSimilarity cloned from the 
151 | official repo as a git submodule. If you did not clone our
152 | repository recursively, run:
153 | ```
154 | git submodule update --init --recursive
155 | ```
156 | 
157 | ## Evaluate changes on Visual Genome
158 | 
159 | To evaluate the model in semantic editing mode on Visual Genome, you need to run the script 
160 | ```evaluate_changes_vg.py```. The script automatically generates edited images without a user interface.
161 | 
162 | The most relevant arguments are:
163 | 
164 | `--exp_dir`: path to folder where all experiments are saved 
165 | `--experiment`: name of experiment, e.g. spade_vg_64 
166 | `--checkpoint`: checkpoint path 
167 | `--mode`: choose from ['replace', 'reposition', 'remove', 'auto_withfeats', 'auto_nofeats'], corresponding to 
168 | object replacement, relationship change, object removal, RoI reconstruction with visual features, and RoI 
169 | reconstruction without features, respectively. 
170 | `--with_query_image`: (bool) set this to use visual features from another image (query image). Used in 
171 | combination with `mode='auto_nofeats'` (see the second example below). 
172 | 
173 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```<exp_dir>/<experiment>_model.pt```.
174 | 
175 | Example run of object replacement changes using the spade model:
176 | ```
177 | python scripts/evaluate_changes_vg.py --exp_dir /path/to/experiment/dir --experiment spade --mode replace
178 | ```
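
To instead borrow visual features from a query image, you could combine the no-features mode with the query-image flag described above, for example:
```
python scripts/evaluate_changes_vg.py --exp_dir /path/to/experiment/dir --experiment spade --mode auto_nofeats --with_query_image True
```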
179 | 
180 | Please set the dataset path `VG_DIR` before running.
181 | 
182 | ## Evaluate changes on CLEVR
183 | 
184 | To evaluate the model in semantic editing mode on CLEVR, you need to run the script ```evaluate_changes_clevr.py```. 
185 | The script automatically generates edited images without a user interface.
186 | 
187 | The most relevant arguments are:
188 | 
189 | `--exp_dir`: path to folder where all experiments are saved 
190 | `--experiment`: name of experiment, e.g. spade_clevr_64 
191 | `--checkpoint`: checkpoint path 
192 | `--image_size`: size of the input image, can be (64,64) or (128,128) based on the size used in training 
193 | 
194 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```<exp_dir>/<experiment>_model.pt```.
195 | 
196 | Example run of object replacement changes using the spade model:
197 | ```
198 | python scripts/evaluate_changes_clevr.py --exp_dir /path/to/experiment/dir --experiment spade
199 | ```
200 | 
201 | Please set the dataset path `CLEVR_DIR` before running.
202 | 
203 | ## User Interface
204 | 
205 | To run a simple user interface that supports different manipulation types, such as object addition, removal, replacement 
206 | and relationship change, run:
207 | ```
208 | python scripts/SIMSG_gui.py 
209 | ```
210 | 
211 | Relevant options:
212 | 
213 | `--checkpoint`: path to checkpoint file 
214 | `--dataset`: visual genome or clevr 
215 | `--predgraphs`: (bool) specifies whether to load ground truth or predicted graphs. So far, predicted graphs are only 
216 | available for Visual Genome. 
217 | `--data_h5`: path of the h5 file used to load a certain data split. If not explicitly set, it uses `test.h5` for GT graphs 
218 | and `test_predgraphs.h5` for predicted graphs. 
219 | `--update_input`: (bool) used to control a sequence of changes in the same image. If `True`, it sets the input as the 
220 | output of the previous generation. Otherwise, all consecutive changes are applied on the original image.
221 | Please set the value of DATA_DIR in SIMSG_gui.py to point to your dataset.
222 | 
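For example, to start the GUI on Visual Genome with ground-truth graphs (the checkpoint path is only a placeholder, and the dataset value mirrors the training options):
```
python scripts/SIMSG_gui.py --checkpoint /path/to/experiment/dir/spade_64_vg_model.pt --dataset vg --predgraphs False
```
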
223 | Once you click on "Load image" and "Get graph", a new image and its corresponding scene graph will appear. The current 
224 | implementation loads either GT or predicted graphs. Then you can apply one of the following manipulations:
225 | 
226 | - For **object replacement** or **relationship change**:
227 |   1. Select the node you want to change 
228 |   2. Choose a new category from the "Replace object" or "Change Relationship" menu, respectively.
229 | - For **object addition**: 
230 |   1. Choose the node you want to connect the new object to (from "Connect to object"). 
231 |   2. Select the category of the new object and relationship ("Add new node", "Add relationship"). 
232 |   3. Specify the direction of the connection. Click on "Add as subject" for a `new_node -> predicate -> existing_node` 
233 |   direction, and click "Add as object" for `existing_node -> predicate -> new_node`.
234 | - For **object removal**: 
235 |   1. Select the node you want to remove
236 |   2. Click on "Remove node" 
237 |   Note that the current implementation only supports object removal (relationships cannot be removed), though the model 
238 |   supports this and the GUI implementation can be extended accordingly.
239 | 
240 | After you have completed the change, click on "Generate image". Alternatively you can save the image. 
241 | 
242 | ## Download
243 | 
244 | Visit the project page to download model checkpoints, 
245 | predicted scene graphs and the CLEVR data with edit pairs.
246 | 
247 | We also provide the code used to generate the CLEVR data with change pairs. Please follow the instructions from 
248 | [here](simsg/data/clevr_gen/README_CLEVR.md).
249 | 
250 | ## Acknowledgement
251 | 
252 | This code is based on the sg2im repository. 
253 | 
254 | The following directory is taken from the SPADE repository:
255 | - simsg/SPADE/
256 | 
--------------------------------------------------------------------------------
/args_64_crn_clevr.yaml:
--------------------------------------------------------------------------------
 1 | # arguments for easier parsing
 2 | seed: 1
 3 | gpu: 0
 4 | 
 5 | # DATA OPTIONS
 6 | dataset: clevr
 7 | vg_image_dir: ./datasets/clevr/target
 8 | output_dir: experiments/clevr
 9 | checkpoint_name: crn_64_clevr
10 | log_dir: experiments/clevr/logs/crn_64_clevr
11 | 
12 | # ARCHITECTURE OPTIONS
13 | image_size: !!python/tuple [64, 64]
14 | crop_size: 32
15 | batch_size: 32
16 | mask_size: 16
17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2
18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2
19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64]
20 | layout_pooling: sum
21 | 
22 | # CRN weights
23 | percept_weight: 0
24 | weight_gan_feat: 0
25 | discriminator_loss_weight: 0.01
26 | d_obj_weight: 1
27 | ac_loss_weight: 0.1
28 | d_img_weight: 1
29 | l1_pixel_loss_weight: 1
30 | bbox_pred_loss_weight: 10
31 | 
32 | # EXTRA OPTIONS
33 | feats_in_gcn: True
34 | feats_out_gcn: True
35 | is_baseline: False
36 | is_supervised: False
37 | num_iterations: 50000
38 | 
39 | # LOGGING OPTIONS
40 | print_every: 500
41 | checkpoint_every: 2000
42 | max_num_imgs: 32
43 | 
--------------------------------------------------------------------------------
/args_64_crn_vg.yaml:
--------------------------------------------------------------------------------
 1 | # arguments for easier parsing
 2 | seed: 1
 3 | gpu: 0
 4 | 
 5 | # DATA OPTIONS
 6 | dataset: vg
 7 | vg_image_dir: ./datasets/vg/images
 8 | output_dir: experiments/vg
 9 | checkpoint_name: crn_64_vg
10 | log_dir: experiments/vg/logs/crn_64_vg
11 | 
12 | # ARCHITECTURE OPTIONS
13 | image_size: !!python/tuple [64, 64]
14 | crop_size: 32
15 | batch_size: 32
16 | mask_size: 16
17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2
18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2
19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64]
20 | layout_pooling: sum
21 | 
22 | # CRN weights
23 | percept_weight: 0
24 | weight_gan_feat: 0
25 | discriminator_loss_weight: 0.01
26 | d_obj_weight: 1
27 | ac_loss_weight: 0.1
28 | d_img_weight: 1
29 | l1_pixel_loss_weight: 1
30 | bbox_pred_loss_weight: 10
31 | 
32 | # EXTRA OPTIONS
33 | feats_in_gcn: True
34 | feats_out_gcn: True
35 | is_baseline: False
36 | is_supervised: False
37 | 
38 | # LOGGING OPTIONS
39 | print_every: 500
40 | checkpoint_every: 2000
41 | max_num_imgs: 32
42 | 
--------------------------------------------------------------------------------
/args_64_spade_clevr.yaml:
--------------------------------------------------------------------------------
 1 | # arguments for easier parsing
 2 | seed: 1
 3 | gpu: 0
 4 | 
 5 | # DATA OPTIONS
 6 | dataset: clevr
 7 | vg_image_dir: ./datasets/clevr/target
 8 | output_dir: experiments/clevr
 9 | checkpoint_name: spade_64_clevr
10 | log_dir: experiments/clevr/logs/spade_64_clevr
11 | 
12 | # ARCHITECTURE OPTIONS
13 | image_size: !!python/tuple [64, 64]
14 | crop_size: 32
15 | batch_size: 32
16 | mask_size: 16
17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2
18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2
19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64]
20 | layout_pooling: sum
21 | 
22 | # spade weights + options
23 | percept_weight: 5
24 | weight_gan_feat: 5
25 | discriminator_loss_weight: 1
26 | d_obj_weight: 0.1
27 | ac_loss_weight: 0.1
28 | d_img_weight: 1
29 | l1_pixel_loss_weight: 1
30 | bbox_pred_loss_weight: 50
31 | multi_discriminator: True
32 | spade_gen_blocks: True
33 | 
34 | # EXTRA OPTIONS
35 | feats_in_gcn: True
36 | feats_out_gcn: True
37 | is_baseline: False
38 | is_supervised: True
39 | num_iterations: 50000
40 | 
41 | # LOGGING OPTIONS
42 | print_every: 500
43 | checkpoint_every: 2000
44 | max_num_imgs: 32
45 | 
--------------------------------------------------------------------------------
/args_64_spade_vg.yaml:
--------------------------------------------------------------------------------
 1 | # arguments for easier parsing
 2 | seed: 1
 3 | gpu: 0
 4 | 
 5 | # DATA OPTIONS
 6 | dataset: vg
 7 | vg_image_dir: ./datasets/vg/images
 8 | output_dir: experiments/vg
 9 | checkpoint_name: spade_64_vg
10 | log_dir: experiments/vg/logs/spade_64_vg
11 | 
12 | # ARCHITECTURE OPTIONS
13 | image_size: !!python/tuple [64, 64]
14 | crop_size: 32
15 | batch_size: 32
16 | mask_size: 16
17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2
18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2
19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64]
20 | layout_pooling: sum
21 | 
22 | # spade weights + options
23 | percept_weight: 5
24 | weight_gan_feat: 5
25 | discriminator_loss_weight: 1
26 | d_obj_weight: 0.1
27 | ac_loss_weight: 0.1
28 | d_img_weight: 1
29 | l1_pixel_loss_weight: 1
30 | bbox_pred_loss_weight: 50
31 | multi_discriminator: True
32 | spade_gen_blocks: True
33 | 
34 | # EXTRA OPTIONS
35 | feats_in_gcn: True
36 | feats_out_gcn: True
37 | is_baseline: False
38 | is_supervised: False
39 | 
40 | # LOGGING OPTIONS
41 | print_every: 500
42 | checkpoint_every: 2000
43 | max_num_imgs: 32
44 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | # This file may be used to create an environment using:
 2 | # $ conda create --name <env> --file <this file>
 3 | cloudpickle=0.5.3
 4 | cycler=0.10.0
 5 | cython=0.28.3
 6 | decorator=4.3.0
 7 | h5py=2.10.0
 8 | imageio=2.3.0
 9 | kiwisolver=1.1.0
10 | matplotlib=3.1.1
11 | networkx=2.1
12 | ninja=1.9.0
13 | numpy=1.16
14 | pillow>=6.2.2
15 | pyqt=5.9.2
16 | pyyaml=5.1.2
17 | pywavelets=0.5.2
18 | qt=5.9.7
19 | scikit-image=0.15.0
20 | scikit-learn=0.21.3
21 | scipy=1.3.1
22 | six=1.12.0
23 | sqlite=3.29.0=h7b6447c_0
24 | toolz=0.9.0
25 | tqdm=4.31.1
26 | wheel=0.33.6
27 | yaml=0.1.7
28 | 
--------------------------------------------------------------------------------
/scripts/download_vg.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash -eu
 2 | #
 3 | # Copyright 2018 Google LLC
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | VG_DIR=datasets/vg
18 | mkdir -p $VG_DIR
19 | 
20 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip
21 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip
22 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip
23 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt
24 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt
25 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip
26 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip
27 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip
28 | 
29 | unzip $VG_DIR/objects.json.zip -d $VG_DIR
30 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR
31 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR
32 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR
33 | unzip $VG_DIR/images.zip -d $VG_DIR/images
34 | unzip $VG_DIR/images2.zip -d $VG_DIR/images
35 | 
--------------------------------------------------------------------------------
/scripts/eval_utils.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2020 Helisa Dhamo
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import os
 18 | import json
 19 | import numpy as np
 20 | from matplotlib import pyplot as plt
 21 | import matplotlib.patches as patches
 22 | import cv2
 23 | import torch
 24 | from simsg.vis import draw_scene_graph
 25 | from simsg.data import imagenet_deprocess_batch
 26 | from imageio import imsave
 27 | 
 28 | from simsg.model import mask_image_in_bbox
 29 | 
 30 | 
 31 | def remove_node(objs, triples, boxes, imgs, idx, obj_to_img, triple_to_img):
 32 |   '''
 33 |   removes nodes and all related edges in case of object removal
 34 |   image is also masked in the respective area
 35 |   idx: list of object ids to be removed
 36 |   Returns:
 37 |     updated objs, triples, boxes, imgs, obj_to_img, triple_to_img
 38 |   '''
 39 | 
 40 |   # object nodes
 41 |   idlist = list(range(objs.shape[0]))
 42 |   keeps = [i for i in idlist if i not in idx]
 43 |   objs_reduced = objs[keeps]
 44 |   boxes_reduced = boxes[keeps]
 45 | 
 46 |   offset = torch.zeros_like(objs)
 47 |   for i in range(objs.shape[0]):
 48 |     for j in idx:
 49 |       if j < i:
 50 |         offset[i] += 1
 51 | 
 52 |   # edges connected to removed object
 53 |   keeps_t = []
 54 |   triples_reduced = []
 55 |   for i in range(triples.shape[0]):
 56 |     if not(triples[i,0] in idx or triples[i, 2] in idx):
 57 |       keeps_t.append(i)
 58 |       triples_reduced.append(torch.tensor([triples[i,0] - offset[triples[i,0]], triples[i,1],
 59 |                                            triples[i,2] - offset[triples[i,2]]], device=triples.device))
 60 |   triples_reduced = torch.stack(triples_reduced, dim=0)
 61 | 
 62 |   # update indexing arrays
 63 |   obj_to_img_reduced = obj_to_img[keeps]
 64 |   triple_to_img_reduced = triple_to_img[keeps_t]
 65 | 
 66 |   # mask RoI of removed objects from image
 67 |   for i in idx:
 68 |     imgs = mask_image_in_bbox(imgs, boxes, i, obj_to_img, mode='removal')
 69 | 
 70 |   return objs_reduced, triples_reduced, boxes_reduced, imgs, obj_to_img_reduced, triple_to_img_reduced
 71 | 
 72 | 
 73 | def bbox_coordinates_with_margin(bbox, margin, img):
 74 |     # extract bounding box with a margin
 75 | 
 76 |     left = max(0, bbox[0] * img.shape[3] - margin)
 77 |     top = max(0, bbox[1] * img.shape[2] - margin)
 78 |     right = min(img.shape[3], bbox[2] * img.shape[3] + margin)
 79 |     bottom = min(img.shape[2], bbox[3] * img.shape[2] + margin)
 80 | 
 81 |     return int(left), int(right), int(top), int(bottom)
 82 | 
 83 | 
 84 | def save_image_from_tensor(img, img_dir, filename):
 85 | 
 86 |     img = imagenet_deprocess_batch(img)
 87 |     img_np = img[0].numpy().transpose(1, 2, 0)
 88 |     img_path = os.path.join(img_dir, filename)
 89 |     imsave(img_path, img_np)
 90 | 
 91 | 
 92 | def save_image_with_label(img_pred, img_gt, img_dir, filename, txt_str):
 93 |     # saves gt and generated image, concatenated
 94 |     # together with text label describing the change
 95 |     # used for easier visualization of results
 96 | 
 97 |     img_pred = imagenet_deprocess_batch(img_pred)
 98 |     img_gt = imagenet_deprocess_batch(img_gt)
 99 | 
100 |     img_pred_np = img_pred[0].numpy().transpose(1, 2, 0)
101 |     img_gt_np = img_gt[0].numpy().transpose(1, 2, 0)
102 | 
103 |     img_pred_np = cv2.resize(img_pred_np, (128, 128))
104 |     img_gt_np = cv2.resize(img_gt_np, (128, 128))
105 | 
106 |     wspace = np.zeros([img_pred_np.shape[0], 10, 3])
107 |     text = np.zeros([30, img_pred_np.shape[1] * 2 + 10, 3])
108 |     text = cv2.putText(text, txt_str, (0,20), cv2.FONT_HERSHEY_SIMPLEX,
109 |                      0.5, (255, 255, 255), lineType=cv2.LINE_AA)
110 | 
111 |     img_pred_gt = np.concatenate([img_gt_np, wspace, img_pred_np], axis=1).astype('uint8')
112 |     img_pred_gt = np.concatenate([text, img_pred_gt], axis=0).astype('uint8')
113 |     img_path = os.path.join(img_dir, filename)
114 |     imsave(img_path, img_pred_gt)
115 | 
116 | 
117 | def makedir(base, name, flag=True):
118 |     dir_name = None
119 |     if flag:
120 |         dir_name = os.path.join(base, name)
121 |         if not os.path.isdir(dir_name):
122 |             os.makedirs(dir_name)
123 |     return dir_name
124 | 
125 | 
126 | def save_graph_json(objs, triples, boxes, beforeafter, dir, idx):
127 |     # save scene graph in json form
128 | 
129 |     data = {}
130 |     objs = objs.cpu().numpy()
131 |     triples = triples.cpu().numpy()
132 |     data['objs'] = objs.tolist()
133 |     data['triples'] = triples.tolist()
134 |     data['boxes'] = boxes.tolist()
135 |     with open(dir + '/' + beforeafter + '_' + str(idx) + '.json', 'w') as outfile:
136 |         json.dump(data, outfile)
137 | 
138 | 
139 | def query_image_by_semantic_id(obj_id, curr_img_id, loader, num_samples=7):
140 |     # used to replace objects with an object of the same category and different appearance
141 |     # return list of images and bboxes, that contain object of category obj_id
142 | 
143 |     query_imgs, query_boxes = [], []
144 |     loader_id = 0
145 |     counter = 0
146 | 
147 |     for l in loader:
148 |         # load images
149 |         imgs, objs, boxes, _, _, _, _ = [x.cuda() for x in l]
150 |         if loader_id == curr_img_id:
151 |             loader_id += 1
152 |             continue
153 | 
154 |         for i, ob in enumerate(objs):
155 |             if obj_id[0] == ob:
156 |                 query_imgs.append(imgs)
157 |                 query_boxes.append(boxes[i])
158 |                 counter += 1
159 |             if counter == num_samples:
160 |                 return query_imgs, query_boxes
161 |         loader_id += 1
162 | 
163 |     return 0, 0
164 | 
165 | 
166 | def draw_image_box(img, box):
167 | 
168 |     left, right, top, bottom = int(round(box[0] * img.shape[1])), int(round(box[2] * img.shape[1])), \
169 |                                int(round(box[1] * img.shape[0])), int(round(box[3] * img.shape[0]))
170 | 
171 |     cv2.rectangle(img, (left, top), (right, bottom), (255,0,0), 1)
172 |     return img
173 | 
174 | 
175 | def draw_image_edge(img, box1, box2):
176 |     # draw arrow that connects two objects centroids
177 |     left1, right1, top1, bottom1 = int(round(box1[0] * img.shape[1])), int(round(box1[2] * img.shape[1])), \
178 |                                int(round(box1[1] * img.shape[0])), int(round(box1[3] * img.shape[0]))
179 |     left2, right2, top2, bottom2 = int(round(box2[0] * img.shape[1])), int(round(box2[2] * img.shape[1])), \
180 |                                int(round(box2[1] * img.shape[0])), int(round(box2[3] * img.shape[0]))
181 | 
182 |     cv2.arrowedLine(img, (int((left1+right1)/2), int((top1+bottom1)/2)),
183 |              (int((left2+right2)/2), int((top2+bottom2)/2)), (255,0,0), 1)
184 | 
185 |     return img
186 | 
187 | 
188 | def visualize_imgs_boxes(imgs, imgs_pred, boxes, boxes_pred):
189 | 
190 |     nrows = imgs.size(0)
191 |     imgs = imgs.detach().cpu().numpy()
192 |     imgs_pred = imgs_pred.detach().cpu().numpy()
193 |     boxes = boxes.detach().cpu().numpy()
194 |     boxes_pred = boxes_pred.detach().cpu().numpy()
195 |     plt.figure()
196 | 
197 |     for i in range(0, nrows):
198 |         # i = j//2
199 |         ax1 = plt.subplot(2, nrows, i+1)
200 |         img = np.transpose(imgs[i, :, :, :], (1, 2, 0)) / 255.
201 |         plt.imshow(img)
202 | 
203 |         left, right, top, bottom = bbox_coordinates_with_margin(boxes[i, :], 0, imgs[i:i+1, :, :, :])
204 |         bbox_gt = patches.Rectangle((left, top),
205 |                                     width=right-left,
206 |                                     height=bottom-top,
207 |                                     linewidth=1, edgecolor='r', facecolor='none')
208 |         # Add the patch to the Axes
209 |         ax1.add_patch(bbox_gt)
210 |         plt.axis('off')
211 | 
212 |         ax2 = plt.subplot(2, nrows, i+nrows+1)
213 |         pred = np.transpose(imgs_pred[i, :, :, :], (1, 2, 0)) / 255.
214 |         plt.imshow(pred)
215 | 
216 |         left, right, top, bottom = bbox_coordinates_with_margin(boxes_pred[i, :], 0, imgs[i:i+1, :, :, :])
217 |         bbox_pr = patches.Rectangle((left, top),
218 |                                     width=right-left,
219 |                                     height=bottom-top,
220 |                                     linewidth=1, edgecolor='r', facecolor='none')
221 |         # ax2.add_patch(bbox_gt)
222 |         ax2.add_patch(bbox_pr)
223 |         plt.axis('off')
224 | 
225 |     plt.show()
226 | 
227 | 
228 | def visualize_scene_graphs(obj_to_img, objs, triples, vocab, device):
229 |     offset = 0
230 |     for i in range(1):#imgs_in.size(0)):
231 |         curr_obj_idx = (obj_to_img == i).nonzero()
232 | 
233 |         objs_vis = objs[curr_obj_idx]
234 |         triples_vis = []
235 |         for j in range(triples.size(0)):
236 |             if triples[j, 0] in curr_obj_idx or triples[j, 2] in curr_obj_idx:
237 |                 triples_vis.append(triples[j].to(device) - torch.tensor([offset, 0, offset]).to(device))
238 |         offset += curr_obj_idx.size(0)
239 |         triples_vis = torch.stack(triples_vis, 0)
240 | 
241 |         print(objs_vis, triples_vis)
242 |         graph_img = draw_scene_graph(objs_vis, triples_vis, vocab)
243 | 
244 |         cv2.imshow('graph' + str(i), graph_img)
245 |     cv2.waitKey(10000)
246 | 
247 | 
248 | def remove_duplicates(triples, triple_to_img, indexes):
249 |     # removes duplicates in relationship triples
250 | 
251 |     triples_new = []
252 |     triple_to_img_new = []
253 | 
254 |     for i in range(triples.size(0)):
255 |         if i not in indexes:
256 |             triples_new.append(triples[i])
257 |             triple_to_img_new.append(triple_to_img[i])
258 | 
259 |     triples_new = torch.stack(triples_new, 0)
260 |     triple_to_img_new = torch.stack(triple_to_img_new, 0)
261 | 
262 |     return triples_new, triple_to_img_new
263 | 
264 | 
265 | def parse_bool(pred_graphs, generative, use_gt_boxes, use_feats):
266 |     # returns name of output directory depending on arguments
267 | 
268 |     if pred_graphs:
269 |         name = "pred/"
270 |     else:
271 |         name = ""
272 |     if generative: # fully generative mode
273 |         return name + "generative"
274 |     else:
275 |         if use_gt_boxes:
276 |             b = "withbox"
277 |         else:
278 |             b = "nobox"
279 |         if use_feats:
280 |             f = "withfeats"
281 |         else:
282 |             f = "nofeats"
283 | 
284 |         return name + b + "_" + f
285 | 
286 | 
287 | def is_background(label_id):
288 | 
289 |     if label_id in [169, 60, 61, 49, 141, 8, 11, 52, 66]:
290 |         return True
291 |     else:
292 |         return False
293 | 
294 | 
295 | def get_selected_objects():
296 | 
297 |   objs = ["", "apple", "ball", "banana", "beach", "bike", "bird", "bus", "bush", "cat", "car", "chair", "cloud", "dog",
298 |           "elephant", "field", "giraffe", "man", "motorcycle", "ocean", "person", "plane", "sheep", "tree", "zebra"]
299 | 
300 |   return objs
301 | 
--------------------------------------------------------------------------------
/scripts/print_args.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | #
 3 | # Copyright 2018 Google LLC
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import argparse
18 | import torch
19 | 
20 | 
21 | """
22 | Tiny utility to print the command-line args used for a checkpoint
23 | """
24 | 
25 | 
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('checkpoint')
28 | 
29 | 
30 | def main(args):
31 |   checkpoint = torch.load(args.checkpoint, map_location='cpu')
32 |   for k, v in checkpoint['args'].items():
33 |     print(k, v)
34 | 
35 | 
36 | if __name__ == '__main__':
37 |   args = parser.parse_args()
38 |   main(args)
39 | 
40 | 
--------------------------------------------------------------------------------
/scripts/run_train.py:
--------------------------------------------------------------------------------
 1 | import yaml
 2 | from addict import Dict
 3 | from scripts.train import argument_parser
 4 | from scripts.train import main as train_main
 5 | import torch
 6 | import shutil
 7 | import os
 8 | import random
 9 | import sys
10 | 
11 | 
12 | # needed to read python tuples (e.g. image_size: !!python/tuple [64, 64]) from the yaml config files
13 | class PrettySafeLoader(yaml.SafeLoader):
14 |     def construct_python_tuple(self, node):
15 |         return tuple(self.construct_sequence(node))
16 | 
17 | PrettySafeLoader.add_constructor(
18 |     u'tag:yaml.org,2002:python/tuple',
19 |     PrettySafeLoader.construct_python_tuple)
20 | 
21 | 
22 | def main():
23 |     # argument: name of yaml config file
24 |     try:
25 |         filename = sys.argv[1]
26 |     except IndexError as error:
27 |         print("provide yaml file as command-line argument!")
28 |         exit()
29 | 
30 |     config = Dict(yaml.load(open(filename), Loader=PrettySafeLoader))
31 |     os.makedirs(config.log_dir, exist_ok=True)
32 |     # save a copy in the experiment dir
33 |     shutil.copyfile(filename, os.path.join(config.log_dir, 'args.yaml'))
34 | 
35 |     torch.cuda.set_device(config.gpu)
36 |     torch.manual_seed(config.seed)
37 |     random.seed(config.seed)
38 | 
39 |     args = yaml_to_parser(config)
40 |     train_main(args)
41 | 
42 | 
43 | def yaml_to_parser(config):
44 |     parser = argument_parser()
45 |     args, unknown = parser.parse_known_args()
46 | 
47 |     args_dict = vars(args)
48 |     for key, value in config.items():
49 |         # dict assignment never raises KeyError, so check membership explicitly to warn on unknown keys
50 |         if key not in args_dict:
51 |             print(key, 'was not found in arguments')
52 |         args_dict[key] = value
53 |     return args
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
58 | 
--------------------------------------------------------------------------------
/scripts/train_utils.py:
--------------------------------------------------------------------------------
 1 | from tensorboardX import SummaryWriter
 2 | import torch
 3 | from torch.functional import F
 4 | from collections import defaultdict
 5 | 
 6 | 
 7 | def check_args(args):
 8 |   H, W = args.image_size
 9 |   for _ in args.decoder_network_dims[1:]:
10 |     H = H // 2
11 | 
12 |   if H == 0:
13 |     raise ValueError("Too many layers in decoder network")
14 | 
15 | 
16 | def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1):
17 |   curr_loss = curr_loss * weight
18 |   loss_dict[loss_name] = curr_loss.item()
19 |   if total_loss is not None:
20 |     total_loss += curr_loss
21 |   else:
22 |     total_loss = curr_loss
23 |   return total_loss
24 | 
25 | 
26 | def calculate_model_losses(args, skip_pixel_loss, img, img_pred, bbox, bbox_pred):
27 | 
28 |   total_loss = torch.zeros(1).to(img)
29 |   losses = {}
30 | 
31 |   l1_pixel_weight = args.l1_pixel_loss_weight
32 | 
33 |   if skip_pixel_loss:
34 |     l1_pixel_weight = 0
35 | 
36 |   l1_pixel_loss = F.l1_loss(img_pred, img)
37 | 
38 |   total_loss = add_loss(total_loss, l1_pixel_loss, losses, 'L1_pixel_loss',
39 |                         l1_pixel_weight)
40 | 
41 |   loss_bbox = F.mse_loss(bbox_pred, bbox)
42 |   total_loss = add_loss(total_loss, loss_bbox, losses, 'bbox_pred',
43 |                         args.bbox_pred_loss_weight)
44 | 
45 |   return total_loss, losses
46 | 
47 | 
48 | def init_checkpoint_dict(args, vocab, model_kwargs, d_obj_kwargs, d_img_kwargs):
49 | 
50 |   ckpt = {
51 |         'args': args.__dict__, 'vocab': vocab, 'model_kwargs': model_kwargs,
52 |         'd_obj_kwargs': d_obj_kwargs, 'd_img_kwargs': d_img_kwargs,
53 |         'losses_ts': [], 'losses': defaultdict(list), 'd_losses': defaultdict(list),
54 |         'checkpoint_ts': [], 'train_iou': [], 'val_losses': defaultdict(list),
55 |         'val_iou': [], 'counters': {'t': None, 'epoch': None},
56 |         'model_state': None, 'model_best_state': None, 'optim_state': None,
57 |         'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
58 |         'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
59 |         'best_t': [],
60 |       }
61 |   return ckpt
62 | 
63 | 
64 | def print_G_state(args, t, losses, writer, checkpoint):
65 |   # print generator losses on terminal and save on tensorboard
66 | 
67 |   print('t = %d / %d' % (t, args.num_iterations))
68 |   for name, val in losses.items():
69 |     print('G [%s]: %.4f' % (name, val))
70 |     writer.add_scalar('G {}'.format(name), val, global_step=t)
71 |     checkpoint['losses'][name].append(val)
72 |   checkpoint['losses_ts'].append(t)
73 | 
74 | 
75 | def print_D_obj_state(args, t, writer, checkpoint, d_obj_losses):
76 |   # print D_obj losses on terminal and save on tensorboard
77 | 
78 |   for name, val in d_obj_losses.items():
79 |     print('D_obj [%s]: %.4f' % (name, val))
80 |     writer.add_scalar('D_obj {}'.format(name), val, global_step=t)
81 |     checkpoint['d_losses'][name].append(val)
82 | 
83 | 
84 | def print_D_img_state(args, t, writer, checkpoint, d_img_losses):
85 |   # print D_img losses on terminal and save on tensorboard
86 | 
87 |   for name, val in d_img_losses.items():
88 |     print('D_img [%s]: %.4f' % (name, val))
89 |     writer.add_scalar('D_img {}'.format(name), val, global_step=t)
90 |     checkpoint['d_losses'][name].append(val)
91 | 
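92 | 
93 | if __name__ == '__main__':
94 |   # Minimal usage sketch (not part of the training pipeline; the argument
95 |   # values below are hypothetical): calculate_model_losses combines the
96 |   # weighted L1 pixel loss and the bbox regression loss via add_loss.
97 |   from argparse import Namespace
98 |   args = Namespace(l1_pixel_loss_weight=1.0, bbox_pred_loss_weight=10)
99 |   img, img_pred = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
100 |   bbox, bbox_pred = torch.rand(4, 4), torch.rand(4, 4)
101 |   total, losses = calculate_model_losses(args, False, img, img_pred, bbox, bbox_pred)
102 |   print(total.item(), losses)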
--------------------------------------------------------------------------------
/simsg/SPADE/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/SPADE/__init__.py
--------------------------------------------------------------------------------
/simsg/SPADE/architectures.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
 4 | """
 5 | 
 6 | import torch
 7 | import torch.nn as nn
 8 | import torch.nn.functional as F
 9 | import torchvision
10 | #import torch.nn.utils.spectral_norm as spectral_norm
11 | #from models.networks.normalization import SPADE
12 | 
13 | 
14 | # VGG architecture, used for the perceptual loss with a pretrained VGG network
15 | class VGG19(torch.nn.Module):
16 |     def __init__(self, requires_grad=False):
17 |         super().__init__()
18 |         vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
19 |         self.slice1 = torch.nn.Sequential()
20 |         self.slice2 = torch.nn.Sequential()
21 |         self.slice3 = torch.nn.Sequential()
22 |         self.slice4 = torch.nn.Sequential()
23 |         self.slice5 = torch.nn.Sequential()
24 |         for x in range(2):
25 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
26 |         for x in range(2, 7):
27 |             self.slice2.add_module(str(x), vgg_pretrained_features[x])
28 |         for x in range(7, 12):
29 |             self.slice3.add_module(str(x), vgg_pretrained_features[x])
30 |         for x in range(12, 21):
31 |             self.slice4.add_module(str(x), vgg_pretrained_features[x])
32 |         for x in range(21, 30):
33 |             self.slice5.add_module(str(x), vgg_pretrained_features[x])
34 |         if not requires_grad:
35 |             for param in self.parameters():
36 |                 param.requires_grad = False
37 | 
38 |     def forward(self, X):
39 |         h_relu1 = self.slice1(X)
40 |         h_relu2 = self.slice2(h_relu1)
41 |         h_relu3 = self.slice3(h_relu2)
42 |         h_relu4 = self.slice4(h_relu3)
43 |         h_relu5 = self.slice5(h_relu4)
44 |         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
45 |         return out
46 | 
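47 | 
48 | if __name__ == '__main__':
49 |     # Minimal sketch (hypothetical usage, not called elsewhere in this file):
50 |     # the five returned relu feature maps are typically compared between a
51 |     # generated and a real image to form a VGG perceptual loss.
52 |     vgg = VGG19(requires_grad=False).eval()
53 |     real, fake = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
54 |     weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
55 |     loss = sum(w * F.l1_loss(f, r.detach())
56 |                for w, f, r in zip(weights, vgg(fake), vgg(real)))
57 |     print(loss.item())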
--------------------------------------------------------------------------------
/simsg/SPADE/base_network.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
 4 | """
 5 | 
 6 | import torch.nn as nn
 7 | from torch.nn import init
 8 | 
 9 | 
10 | class BaseNetwork(nn.Module):
11 |     def __init__(self):
12 |         super(BaseNetwork, self).__init__()
13 | 
14 |     @staticmethod
15 |     def modify_commandline_options(parser, is_train):
16 |         return parser
17 | 
18 |     def print_network(self):
19 |         if isinstance(self, list):
20 |             self = self[0]
21 |         num_params = 0
22 |         for param in self.parameters():
23 |             num_params += param.numel()
24 |         print('Network [%s] was created. Total number of parameters: %.1f million. '
25 |               'To see the architecture, do print(network).'
26 |               % (type(self).__name__, num_params / 1000000))
27 | 
28 |     def init_weights(self, init_type='normal', gain=0.02):
29 |         def init_func(m):
30 |             classname = m.__class__.__name__
31 |             if classname.find('BatchNorm2d') != -1:
32 |                 if hasattr(m, 'weight') and m.weight is not None:
33 |                     init.normal_(m.weight.data, 1.0, gain)
34 |                 if hasattr(m, 'bias') and m.bias is not None:
35 |                     init.constant_(m.bias.data, 0.0)
36 |             elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
37 |                 if init_type == 'normal':
38 |                     init.normal_(m.weight.data, 0.0, gain)
39 |                 elif init_type == 'xavier':
40 |                     init.xavier_normal_(m.weight.data, gain=gain)
41 |                 elif init_type == 'xavier_uniform':
42 |                     init.xavier_uniform_(m.weight.data, gain=1.0)
43 |                 elif init_type == 'kaiming':
44 |                     init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45 |                 elif init_type == 'orthogonal':
46 |                     init.orthogonal_(m.weight.data, gain=gain)
47 |                 elif init_type == 'none':  # uses pytorch's default init method
48 |                     m.reset_parameters()
49 |                 else:
50 |                     raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 |                 if hasattr(m, 'bias') and m.bias is not None:
52 |                     init.constant_(m.bias.data, 0.0)
53 | 
54 |         self.apply(init_func)
55 | 
56 |         # propagate to children
57 |         for m in self.children():
58 |             if hasattr(m, 'init_weights'):
59 |                 m.init_weights(init_type, gain)
60 | 
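61 | 
62 | if __name__ == '__main__':
63 |     # Minimal sketch (hypothetical module): any nn.Module that inherits from
64 |     # BaseNetwork gets print_network() and the recursive init_weights() above.
65 |     class TinyNet(BaseNetwork):
66 |         def __init__(self):
67 |             super().__init__()
68 |             self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
69 | 
70 |     net = TinyNet()
71 |     net.init_weights(init_type='xavier', gain=0.02)
72 |     net.print_network()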
--------------------------------------------------------------------------------
/simsg/SPADE/normalization.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
  3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4 | """
  5 | 
  6 | import re
  7 | import torch
  8 | import torch.nn as nn
  9 | import torch.nn.functional as F
 10 | #from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
 11 | import torch.nn.utils.spectral_norm as spectral_norm
 12 | 
 13 | 
 14 | # Returns a function that creates a normalization function
 15 | # that does not condition on semantic map
 16 | def get_nonspade_norm_layer(opt, norm_type='instance'):
 17 |     # helper function to get # output channels of the previous layer
 18 |     def get_out_channel(layer):
 19 |         if hasattr(layer, 'out_channels'):
 20 |             return getattr(layer, 'out_channels')
 21 |         return layer.weight.size(0)
 22 | 
 23 |     # this function will be returned
 24 |     def add_norm_layer(layer):
 25 |         nonlocal norm_type
 26 |         subnorm_type = norm_type  # fallback when there is no 'spectral' prefix
 27 |         if norm_type.startswith('spectral'):
 28 |             layer = spectral_norm(layer)
 29 |             subnorm_type = norm_type[len('spectral'):]
 30 |         if subnorm_type == 'none' or len(subnorm_type) == 0:
 31 |             return layer
 32 | 
 33 |         # remove bias in the previous layer, which is meaningless
 34 |         # since it has no effect after normalization
 35 |         if getattr(layer, 'bias', None) is not None:
 36 |             delattr(layer, 'bias')
 37 |             layer.register_parameter('bias', None)
 38 | 
 39 |         if subnorm_type == 'batch':
 40 |             norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
 41 |         #elif subnorm_type == 'sync_batch':
 42 |         #    norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
 43 |         elif subnorm_type == 'instance':
 44 |             norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
 45 |         else:
 46 |             raise ValueError('normalization layer %s is not recognized' % subnorm_type)
 47 | 
 48 |         return nn.Sequential(layer, norm_layer)
 49 | 
 50 |     return add_norm_layer
 51 | 
 52 | 
 53 | # Creates SPADE normalization layer based on the given configuration
 54 | # SPADE consists of two steps. First, it normalizes the activations using
 55 | # your favorite normalization method, such as Batch Norm or Instance Norm.
 56 | # Second, it applies scale and bias to the normalized output, conditioned on
 57 | # the segmentation map.
 58 | # The format of |config_text| is spade(norm)(ks), where
 59 | # (norm) specifies the type of parameter-free normalization.
 60 | #       (e.g. syncbatch, batch, instance)
 61 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
 62 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
 63 | # Also, the other arguments are
 64 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
 65 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
 66 | class SPADE(nn.Module):
 67 |     def __init__(self, config_text, norm_nc, label_nc):
 68 |         super().__init__()
 69 | 
 70 |         assert config_text.startswith('spade')
 71 |         parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
 72 |         param_free_norm_type = str(parsed.group(1))
 73 |         ks = int(parsed.group(2))
 74 | 
 75 |         if param_free_norm_type == 'instance':
 76 |             self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
 77 |         elif param_free_norm_type == 'syncbatch':
 78 |             self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
 79 |         elif param_free_norm_type == 'batch':
 80 |             self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
 81 |         else:
 82 |             raise ValueError('%s is not a recognized param-free norm type in SPADE'
 83 |                              % param_free_norm_type)
 84 | 
 85 |         # The dimension of the intermediate embedding space. Yes, hardcoded.
 86 |         nhidden = 128
 87 | 
 88 |         pw = ks // 2
 89 |         self.mlp_shared = nn.Sequential(
 90 |             nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
 91 |             nn.ReLU()
 92 |         )
 93 |         self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
 94 |         self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
 95 | 
 96 |     def forward(self, x, segmap):
 97 | 
 98 |         # Part 1. generate parameter-free normalized activations
 99 |         normalized = self.param_free_norm(x)
100 | 
101 |         # Part 2. produce scaling and bias conditioned on semantic map
102 |         #segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
103 |         actv = self.mlp_shared(segmap)
104 |         gamma = self.mlp_gamma(actv)
105 |         beta = self.mlp_beta(actv)
106 |         #print(normalized.shape, x.shape, segmap.shape, gamma.shape, beta.shape)
107 |         # apply scale and bias
108 |         out = normalized * (1 + gamma) + beta
109 | 
110 |         return out
111 | 
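112 | 
113 | if __name__ == '__main__':
114 |     # Minimal sketch (hypothetical shapes): SPADE modulates a (N, C, H, W)
115 |     # activation with per-pixel gamma/beta predicted from a conditioning map.
116 |     # Since the F.interpolate call above is commented out, the conditioning
117 |     # map must already match the spatial size of x.
118 |     spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=32)
119 |     x = torch.rand(2, 64, 16, 16)
120 |     segmap = torch.rand(2, 32, 16, 16)
121 |     print(spade(x, segmap).shape)  # torch.Size([2, 64, 16, 16])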
--------------------------------------------------------------------------------
/simsg/__init__.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | #
 3 | # Copyright 2018 Google LLC
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
--------------------------------------------------------------------------------
/simsg/bilinear.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import torch.nn.functional as F
 19 | from simsg.utils import timeit
 20 | 
 21 | 
 22 | """
 23 | Functions for performing differentiable bilinear cropping of images, for use in
 24 | the object discriminator
 25 | """
 26 | 
 27 | 
 28 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'):
 29 |   """
 30 |   Inputs:
 31 |   - feats: FloatTensor of shape (N, C, H, W)
 32 |   - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates
 33 |   - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps;
 34 |     each element is in the range [0, N) and bbox_to_feats[b] = i means that
 35 |     bbox[b] will be cropped from feats[i].
 36 |   - HH, WW: Size of the output crops
 37 | 
 38 |   Returns:
 39 |   - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to
 40 |     crop from feats[bbox_to_feats[i]].
 41 |   """
 42 |   if backend == 'cudnn':
 43 |     return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW)
 44 |   #print("here ========================================" ,feats.size())
 45 |   N, C, H, W = feats.size()
 46 |   B = bbox.size(0)
 47 |   if WW is None: WW = HH
 48 |   dtype, device = feats.dtype, feats.device
 49 |   crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device)
 50 |   for i in range(N):
 51 |     idx = (bbox_to_feats.data == i).nonzero()
 52 |     if idx.numel() == 0:
 53 |       continue
 54 |     idx = idx.view(-1)
 55 |     n = idx.size(0)
 56 |     cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
 57 |     cur_bbox = bbox[idx]
 58 |     cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW)
 59 |     crops[idx] = cur_crops
 60 |   return crops
 61 | 
 62 | 
 63 | def _invperm(p):
 64 |   N = p.size(0)
 65 |   eye = torch.arange(0, N).type_as(p)
 66 |   pp = (eye[:, None] == p).nonzero()[:, 1]
 67 |   return pp
 68 | 
 69 | 
 70 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None):
 71 |   #print("here ========================================" ,feats.size())
 72 |   N, C, H, W = feats.size()
 73 |   B = bbox.size(0)
 74 |   if WW is None: WW = HH
 75 |   dtype = feats.data.type()
 76 | 
 77 |   feats_flat, bbox_flat, all_idx = [], [], []
 78 |   for i in range(N):
 79 |     idx = (bbox_to_feats.data == i).nonzero()
 80 |     if idx.numel() == 0:
 81 |       continue
 82 |     idx = idx.view(-1)
 83 |     n = idx.size(0)
 84 |     cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
 85 |     cur_bbox = bbox[idx]
 86 | 
 87 |     feats_flat.append(cur_feats)
 88 |     bbox_flat.append(cur_bbox)
 89 |     all_idx.append(idx)
 90 | 
 91 |   feats_flat = torch.cat(feats_flat, dim=0)
 92 |   bbox_flat = torch.cat(bbox_flat, dim=0)
 93 |   crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn')
 94 | 
 95 |   # If the crops were sequential (all_idx is identity permutation) then we can
 96 |   # simply return them; otherwise we need to permute crops by the inverse
 97 |   # permutation from all_idx.
 98 |   all_idx = torch.cat(all_idx, dim=0)
 99 |   eye = torch.arange(0, B).type_as(all_idx)
100 |   if (all_idx == eye).all():
101 |     return crops
102 |   return crops[_invperm(all_idx)]
103 | 
104 | 
105 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'):
106 |   """
107 |   Take differentiable crops of feats specified by bbox.
108 | 
109 |   Inputs:
110 |   - feats: Tensor of shape (N, C, H, W)
111 |   - bbox: Bounding box coordinates of shape (N, 4) in the format
112 |     [x0, y0, x1, y1] in the [0, 1] coordinate space.
113 |   - HH, WW: Size of the output crops.
114 | 
115 |   Returns:
116 |   - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of
117 |     feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling.
118 |   """
119 |   N = feats.size(0)
120 |   assert bbox.size(0) == N
121 |   #print(bbox.shape)
122 |   assert bbox.size(1) == 4
123 |   if WW is None: WW = HH
124 |   if backend == 'cudnn':
125 |     # Change box from [0, 1] to [-1, 1] coordinate system
126 |     bbox = 2 * bbox - 1
127 |   x0, y0 = bbox[:, 0], bbox[:, 1]
128 |   x1, y1 = bbox[:, 2], bbox[:, 3]
129 |   X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW)
130 |   Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW)
131 |   if backend == 'jj':
132 |     return bilinear_sample(feats, X, Y)
133 |   elif backend == 'cudnn':
134 |     grid = torch.stack([X, Y], dim=3)
135 |     return F.grid_sample(feats, grid)
136 | 
137 | 
138 | 
139 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0):
140 |   """
141 |   Inverse operation to crop_bbox; construct output images where the feature maps
142 |   from feats have been reshaped and placed into the positions specified by bbox.
143 | 
144 |   Inputs:
145 |   - feats: Tensor of shape (N, C, HH, WW)
146 |   - bbox: Bounding box coordinates of shape (N, 4) in the format
147 |     [x0, y0, x1, y1] in the [0, 1] coordinate space.
148 |   - H, W: Size of output.
149 |   - fill_value: Portions of the output image that are outside the bounding box
150 |     will be filled with this value.
151 | 
152 |   Returns:
153 |   - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by
154 |     bbox[i] contains feats[i], reshaped using bilinear sampling.
155 |   """
156 |   N, C = feats.size(0), feats.size(1)
157 |   assert bbox.size(0) == N
158 |   assert bbox.size(1) == 4
159 |   if W is None: W = H
160 | 
161 |   x0, y0 = bbox[:, 0], bbox[:, 1]
162 |   x1, y1 = bbox[:, 2], bbox[:, 3]
163 |   ww = x1 - x0
164 |   hh = y1 - y0
165 | 
166 |   x0 = x0.contiguous().view(N, 1).expand(N, H)
167 |   x1 = x1.contiguous().view(N, 1).expand(N, H)
168 |   ww = ww.view(N, 1).expand(N, H)
169 | 
170 |   y0 = y0.contiguous().view(N, 1).expand(N, W)
171 |   y1 = y1.contiguous().view(N, 1).expand(N, W)
172 |   hh = hh.view(N, 1).expand(N, W)
173 |   
174 |   X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats)
175 |   Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats)
176 | 
177 |   X = (X - x0) / ww
178 |   Y = (Y - y0) / hh
179 | 
180 |   # Build the out-of-bounds mask with logical OR
181 |   X_out_mask = ((X < 0) | (X > 1)).view(N, 1, W).expand(N, H, W)
182 |   Y_out_mask = ((Y < 0) | (Y > 1)).view(N, H, 1).expand(N, H, W)
183 |   out_mask = X_out_mask | Y_out_mask
184 |   out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W)
185 | 
186 |   X = X.view(N, 1, W).expand(N, H, W)
187 |   Y = Y.view(N, H, 1).expand(N, H, W)
188 | 
189 |   out = bilinear_sample(feats, X, Y)
190 |   out[out_mask] = fill_value
191 |   return out
192 | 
193 | 
194 | def bilinear_sample(feats, X, Y):
195 |   """
196 |   Perform bilinear sampling on the features in feats using the sampling grid
197 |   given by X and Y.
198 | 
199 |   Inputs:
200 |   - feats: Tensor holding input feature map, of shape (N, C, H, W)
201 |   - X, Y: Tensors holding x and y coordinates of the sampling
202 |     grids; both have shape (N, HH, WW) and have elements in the range [0, 1].
203 |   Returns:
204 |   - out: Tensor of shape (N, C, HH, WW) where out[i] is computed
205 |     by sampling from feats[i] using the sampling grid (X[i], Y[i]).
206 |   """
207 |   N, C, H, W = feats.size()
208 |   assert X.size() == Y.size()
209 |   assert X.size(0) == N
210 |   _, HH, WW = X.size()
211 | 
212 |   X = X.mul(W)
213 |   Y = Y.mul(H)
214 | 
215 |   # Get the x and y coordinates for the four samples
216 |   x0 = X.floor().clamp(min=0, max=W-1)
217 |   x1 = (x0 + 1).clamp(min=0, max=W-1)
218 |   y0 = Y.floor().clamp(min=0, max=H-1)
219 |   y1 = (y0 + 1).clamp(min=0, max=H-1)
220 | 
221 |   # In numpy we could do something like feats[i, :, y0, x0] to pull out
222 |   # the elements of feats at coordinates y0 and x0, but PyTorch doesn't
223 |   # yet support this style of indexing. Instead we have to use the gather
224 |   # method, which only allows us to index along one dimension at a time;
225 |   # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W)
226 |   # and index along the last dimension. Below we generate linear indices into
227 |   # the collapsed last dimension for each of the four combinations we need.
228 |   y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
229 |   y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
230 |   y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
231 |   y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
232 | 
233 |   # Actually use gather to pull out the values from feats corresponding
234 |   # to our four samples, then reshape them to (BB, C, HH, WW)
235 |   feats_flat = feats.view(N, C, H * W)
236 |   v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW)
237 |   v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW)
238 |   v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW)
239 |   v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW)
240 | 
241 |   # Compute the weights for the four samples
242 |   w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
243 |   w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
244 |   w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
245 |   w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
246 | 
247 |   # Multiply the samples by the weights to give our interpolated results.
248 |   out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
249 |   return out
250 | 
251 | 
252 | def tensor_linspace(start, end, steps=10):
253 |   """
254 |   Vectorized version of torch.linspace.
255 | 
256 |   Inputs:
257 |   - start: Tensor of any shape
258 |   - end: Tensor of the same shape as start
259 |   - steps: Integer
260 | 
261 |   Returns:
262 |   - out: Tensor of shape start.size() + (steps,), such that
263 |     out.select(-1, 0) == start, out.select(-1, -1) == end,
264 |     and the other elements of out linearly interpolate between
265 |     start and end.
266 |   """
267 |   assert start.size() == end.size()
268 |   view_size = start.size() + (1,)
269 |   w_size = (1,) * start.dim() + (steps,)
270 |   out_size = start.size() + (steps,)
271 | 
272 |   start_w = torch.linspace(1, 0, steps=steps).to(start)
273 |   start_w = start_w.view(w_size).expand(out_size)
274 |   end_w = torch.linspace(0, 1, steps=steps).to(start)
275 |   end_w = end_w.view(w_size).expand(out_size)
276 | 
277 |   start = start.contiguous().view(view_size).expand(out_size)
278 |   end = end.contiguous().view(view_size).expand(out_size)
279 | 
280 |   out = start_w * start + end_w * end
281 |   return out
282 | 
283 | 
284 | if __name__ == '__main__':
285 |   import numpy as np
286 |   from PIL import Image  # scipy.misc image I/O has been removed from SciPy; use Pillow instead
287 | 
288 |   cat = np.array(Image.open('cat.jpg').convert('RGB').resize((256, 256)))
289 |   dog = np.array(Image.open('dog.jpg').convert('RGB').resize((256, 256)))
290 |   feats = torch.stack([
291 |       torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)),
292 |       torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))],
293 |             dim=0)
294 | 
295 |   boxes = torch.FloatTensor([
296 |             [0, 0, 1, 1],
297 |             [0.25, 0.25, 0.75, 0.75],
298 |             [0, 0, 0.5, 0.5],
299 |           ])
300 | 
301 |   box_to_feats = torch.LongTensor([1, 0, 1]).cuda()
302 | 
303 |   feats, boxes = feats.cuda(), boxes.cuda()
304 |   crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128)
305 |   for i in range(crops.size(0)):
306 |     crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
307 |     Image.fromarray(crop_np).save('out%d.png' % i)
308 | 
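309 | 
310 | # tensor_linspace example (illustrative, CPU-only, no image files needed):
311 | #   >>> tensor_linspace(torch.zeros(2), torch.ones(2), steps=5)
312 | #   tensor([[0.0000, 0.2500, 0.5000, 0.7500, 1.0000],
313 | #           [0.0000, 0.2500, 0.5000, 0.7500, 1.0000]])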
--------------------------------------------------------------------------------
/simsg/data/__init__.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | #
 3 | # Copyright 2018 Google LLC
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | from .utils import imagenet_preprocess, imagenet_deprocess
18 | from .utils import imagenet_deprocess_batch
19 | 
--------------------------------------------------------------------------------
/simsg/data/clevr.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2020 Azade Farshad
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import os
 18 | 
 19 | import torch
 20 | from torch.utils.data import Dataset
 21 | 
 22 | import torchvision.transforms as T
 23 | 
 24 | import numpy as np
 25 | import h5py, json
 26 | import PIL
 27 | 
 28 | from .utils import imagenet_preprocess, Resize
 29 | sg_task = True
 30 | 
 31 | def conv_src_to_target(voc_s, voc_t):
 32 |   dic = {}
 33 |   for k, val in voc_s['object_name_to_idx'].items():
 34 |     dic[val] = voc_t['object_name_to_idx'][k]
 35 |   return dic
 36 | 
 37 | 
 38 | class SceneGraphWithPairsDataset(Dataset):
 39 |   def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256),
 40 |                normalize_images=True, max_objects=10, max_samples=None,
 41 |                include_relationships=True, use_orphaned_objects=True,
 42 |                mode='train', clean_repeats=True):
 43 |     super(SceneGraphWithPairsDataset, self).__init__()
 44 | 
 45 |     assert mode in ["train", "eval", "auto_withfeats", "auto_nofeats", "reposition", "remove", "replace"]
 46 | 
 47 |     CLEVR_target_dir = os.path.split(h5_path)[0]
 48 |     CLEVR_SRC_DIR = os.path.join(os.path.split(CLEVR_target_dir)[0], 'source')
 49 | 
 50 |     vocab_json_s = os.path.join(CLEVR_SRC_DIR, "vocab.json")
 51 |     vocab_json_t = os.path.join(CLEVR_target_dir, "vocab.json")
 52 | 
 53 |     with open(vocab_json_s, 'r') as f:
 54 |       vocab_src = json.load(f)
 55 | 
 56 |     with open(vocab_json_t, 'r') as f:
 57 |       vocab_t = json.load(f)
 58 | 
 59 |     self.mode = mode
 60 | 
 61 |     self.image_dir = image_dir
 62 |     self.image_source_dir = os.path.join(os.path.split(image_dir)[0], 'source') #Azade
 63 | 
 64 |     src_h5_path = os.path.join(self.image_source_dir, os.path.split(h5_path)[-1])
 65 |     print(self.image_dir, src_h5_path)
 66 | 
 67 |     self.image_size = image_size
 68 |     self.vocab = vocab
 69 |     self.vocab_src = vocab_src
 70 |     self.vocab_t = vocab_t
 71 |     self.num_objects = len(vocab['object_idx_to_name'])
 72 |     self.use_orphaned_objects = use_orphaned_objects
 73 |     self.max_objects = max_objects
 74 |     self.max_samples = max_samples
 75 |     self.include_relationships = include_relationships
 76 | 
 77 |     self.evaluating = mode != 'train'
 78 | 
 79 |     self.clean_repeats = clean_repeats
 80 | 
 81 |     transform = [Resize(image_size), T.ToTensor()]
 82 |     if normalize_images:
 83 |       transform.append(imagenet_preprocess())
 84 |     self.transform = T.Compose(transform)
 85 | 
 86 |     self.data = {}
 87 |     with h5py.File(h5_path, 'r') as f:
 88 |       for k, v in f.items():
 89 |         if k == 'image_paths':
 90 |           self.image_paths = list(v)
 91 |         else:
 92 |           self.data[k] = torch.IntTensor(np.asarray(v))
 93 | 
 94 |     self.data_src = {}
 95 |     with h5py.File(src_h5_path, 'r') as f:
 96 |       for k, v in f.items():
 97 |         if k == 'image_paths':
 98 |           self.image_paths_src = list(v)
 99 |         else:
100 |           self.data_src[k] = torch.IntTensor(np.asarray(v))
101 | 
102 |   def __len__(self):
103 |     num = self.data['object_names'].size(0)
104 |     if self.max_samples is not None:
105 |       return min(self.max_samples, num)
106 |     return num
107 | 
108 |   def __getitem__(self, index):
109 |     """
110 |     Returns a tuple of:
111 |     - image: FloatTensor of shape (C, H, W)
112 |     - objs: LongTensor of shape (num_objs,)
113 |     - boxes: FloatTensor of shape (num_objs, 4) giving boxes for objects in
114 |       (x0, y0, x1, y1) format, in a [0, 1] coordinate system.
115 |     - triples: LongTensor of shape (num_triples, 3) where triples[t] = [i, p, j]
116 |       means that (objs[i], p, objs[j]) is a triple.
117 |     """
118 |     img_path = os.path.join(self.image_dir, self.image_paths[index])
119 |     img_source_path = os.path.join(self.image_source_dir, self.image_paths[index])
120 | 
121 |     src_to_target_obj = conv_src_to_target(self.vocab_src, self.vocab_t)
122 | 
123 |     with open(img_path, 'rb') as f:
124 |       with PIL.Image.open(f) as image:
125 |         WW, HH = image.size
126 |         image = self.transform(image.convert('RGB'))
127 | 
128 |     with open(img_source_path, 'rb') as f:
129 |       with PIL.Image.open(f) as image_src:
130 |         #WW, HH = image.size
131 |         image_src = self.transform(image_src.convert('RGB'))
132 | 
133 |     H, W = self.image_size
134 | 
135 |     # Figure out which objects appear in relationships and which don't
136 |     obj_idxs_with_rels = set()
137 |     obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item()))
138 |     for r_idx in range(self.data['relationships_per_image'][index]):
139 |       s = self.data['relationship_subjects'][index, r_idx].item()
140 |       o = self.data['relationship_objects'][index, r_idx].item()
141 |       obj_idxs_with_rels.add(s)
142 |       obj_idxs_with_rels.add(o)
143 |       obj_idxs_without_rels.discard(s)
144 |       obj_idxs_without_rels.discard(o)
145 | 
146 |     obj_idxs = list(obj_idxs_with_rels)
147 |     obj_idxs_without_rels = list(obj_idxs_without_rels)
148 |     if len(obj_idxs) > self.max_objects - 1:
149 |       obj_idxs = obj_idxs[:self.max_objects]
150 |     if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects:
151 |       num_to_add = self.max_objects - 1 - len(obj_idxs)
152 |       num_to_add = min(num_to_add, len(obj_idxs_without_rels))
153 |       obj_idxs += obj_idxs_without_rels[:num_to_add]
154 | 
155 |     num_objs = len(obj_idxs) + 1
156 | 
157 |     objs = torch.LongTensor(num_objs).fill_(-1)
158 | 
159 |     boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(num_objs, 1)
160 |     obj_idx_mapping = {}
161 |     for i, obj_idx in enumerate(obj_idxs):
162 |       objs[i] = self.data['object_names'][index, obj_idx].item()
163 |       x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist()
164 |       x0 = float(x) / WW
165 |       y0 = float(y) / HH
166 |       x1 = float(x + w) / WW
167 |       y1 = float(y + h) / HH
168 |       boxes[i] = torch.FloatTensor([x0, y0, x1, y1])
169 |       obj_idx_mapping[obj_idx] = i
170 | 
171 |     # The last object will be the special __image__ object
172 |     objs[num_objs - 1] = self.vocab['object_name_to_idx']['__image__']
173 | 
174 |     triples = []
175 |     for r_idx in range(self.data['relationships_per_image'][index].item()):
176 |       if not self.include_relationships:
177 |         break
178 |       s = self.data['relationship_subjects'][index, r_idx].item()
179 |       p = self.data['relationship_predicates'][index, r_idx].item()
180 |       o = self.data['relationship_objects'][index, r_idx].item()
181 |       s = obj_idx_mapping.get(s, None)
182 |       o = obj_idx_mapping.get(o, None)
183 |       if s is not None and o is not None:
184 |         if self.clean_repeats and [s, p, o] in triples:
185 |           continue
186 |         triples.append([s, p, o])
187 | 
188 |     # Add dummy __in_image__ relationships for all objects
189 |     in_image = self.vocab['pred_name_to_idx']['__in_image__']
190 |     for i in range(num_objs - 1):
191 |       triples.append([i, in_image, num_objs - 1])
192 | 
193 |     triples = torch.LongTensor(triples)
194 | 
195 |     #Source image
196 | 
197 |     # Figure out which objects appear in relationships and which don't
198 |     obj_idxs_with_rels_src = set()
199 |     obj_idxs_without_rels_src = set(range(self.data_src['objects_per_image'][index].item()))
200 |     for r_idx in range(self.data_src['relationships_per_image'][index]):
201 |       s = self.data_src['relationship_subjects'][index, r_idx].item()
202 |       o = self.data_src['relationship_objects'][index, r_idx].item()
203 |       obj_idxs_with_rels_src.add(s)
204 |       obj_idxs_with_rels_src.add(o)
205 |       obj_idxs_without_rels_src.discard(s)
206 |       obj_idxs_without_rels_src.discard(o)
207 | 
208 |     obj_idxs_src = list(obj_idxs_with_rels_src)
209 |     obj_idxs_without_rels_src = list(obj_idxs_without_rels_src)
210 |     if len(obj_idxs_src) > self.max_objects - 1:
211 |       obj_idxs_src = obj_idxs_src[:self.max_objects]
212 |     if len(obj_idxs_src) < self.max_objects - 1 and self.use_orphaned_objects:
213 |       num_to_add = self.max_objects - 1 - len(obj_idxs_src)
214 |       num_to_add = min(num_to_add, len(obj_idxs_without_rels_src))
215 |       obj_idxs_src += obj_idxs_without_rels_src[:num_to_add]
216 | 
217 |     num_objs_src = len(obj_idxs_src) + 1
218 | 
219 |     objs_src = torch.LongTensor(num_objs_src).fill_(-1)
220 | 
221 |     boxes_src = torch.FloatTensor([[0, 0, 1, 1]]).repeat(num_objs_src, 1)
222 |     obj_idx_mapping_src = {}
223 |     for i, obj_idx in enumerate(obj_idxs_src):
224 |       objs_src[i] = src_to_target_obj[self.data_src['object_names'][index, obj_idx].item()]
225 |       x, y, w, h = self.data_src['object_boxes'][index, obj_idx].tolist()
226 |       x0 = float(x) / WW
227 |       y0 = float(y) / HH
228 |       x1 = float(x + w) / WW
229 |       y1 = float(y + h) / HH
230 |       boxes_src[i] = torch.FloatTensor([x0, y0, x1, y1])
231 |       obj_idx_mapping_src[obj_idx] = i
232 | 
233 |     # The last object will be the special __image__ object
234 |     objs_src[num_objs_src - 1] = self.vocab_src['object_name_to_idx']['__image__']
235 | 
236 |     triples_src = []
237 |     for r_idx in range(self.data_src['relationships_per_image'][index].item()):
238 |       if not self.include_relationships:
239 |         break
240 |       s = self.data_src['relationship_subjects'][index, r_idx].item()
241 |       p = self.data_src['relationship_predicates'][index, r_idx].item()
242 |       o = self.data_src['relationship_objects'][index, r_idx].item()
243 |       s = obj_idx_mapping_src.get(s, None)
244 |       o = obj_idx_mapping_src.get(o, None)
245 |       if s is not None and o is not None:
246 |         if self.clean_repeats and [s, p, o] in triples_src:
247 |           continue
248 |         triples_src.append([s, p, o])
249 | 
250 |     # Add dummy __in_image__ relationships for all objects
251 |     in_image = self.vocab_src['pred_name_to_idx']['__in_image__']
252 |     for i in range(num_objs_src - 1):
253 |       triples_src.append([i, in_image, num_objs_src - 1])
254 | 
255 |     triples_src = torch.LongTensor(triples_src)
256 | 
257 |     return image, image_src, objs, objs_src, boxes, boxes_src, triples, triples_src
258 | 
259 | 
260 | def collate_fn_withpairs(batch):
261 |   """
262 |   Collate function to be used when wrapping a SceneGraphWithPairsDataset in a
263 |   DataLoader. Returns a tuple of the following:
264 | 
265 |   - imgs, imgs_src: target and source FloatTensors of shape (N, C, H, W)
266 |   - objs, objs_src: target and source LongTensors of shape (num_objs,) giving categories for all objects
267 |   - boxes, boxes_src: target and source FloatTensors of shape (num_objs, 4) giving boxes for all objects
268 |   - triples, triples_src: target and source LongTensors of shape (num_triples, 3) giving all triples, where
269 |     triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple
270 |   - obj_to_img: LongTensor of shape (num_objs,) mapping objects to images;
271 |     obj_to_img[i] = n means that objs[i] belongs to imgs[n]
272 |   - triple_to_img: LongTensor of shape (num_triples,) mapping triples to images;
273 |     triple_to_img[t] = n means that triples[t] belongs to imgs[n]
274 |   - imgs_masked: FloatTensor of shape (N, 4, H, W)
275 |   """
276 |   # batch is a list, and each element is (image, image_src, objs, objs_src, boxes, boxes_src, triples, triples_src)
277 |   all_imgs, all_imgs_src, all_objs, all_objs_src, all_boxes, all_boxes_src, all_triples, all_triples_src = [], [], [], [], [], [], [], []
278 |   all_obj_to_img, all_triple_to_img = [], []
279 |   all_imgs_masked = []
280 | 
281 |   obj_offset = 0
282 | 
283 |   for i, (img, image_src, objs, objs_src, boxes, boxes_src, triples, triples_src) in enumerate(batch):
284 | 
285 |     all_imgs.append(img[None])
286 |     all_imgs_src.append(image_src[None])
287 |     num_objs, num_triples = objs.size(0), triples.size(0)
288 |     all_objs.append(objs)
289 |     all_objs_src.append(objs_src)
290 |     all_boxes.append(boxes)
291 |     all_boxes_src.append(boxes_src)
292 |     triples = triples.clone()
293 |     triples_src = triples_src.clone()
294 | 
295 |     triples[:, 0] += obj_offset
296 |     triples[:, 2] += obj_offset
297 |     all_triples.append(triples)
298 | 
299 |     triples_src[:, 0] += obj_offset
300 |     triples_src[:, 2] += obj_offset
301 |     all_triples_src.append(triples_src)
302 | 
303 |     all_obj_to_img.append(torch.LongTensor(num_objs).fill_(i))
304 |     all_triple_to_img.append(torch.LongTensor(num_triples).fill_(i))
305 | 
306 |     # prepare input 4-channel image
307 |     # initialize mask channel with zeros
308 |     masked_img = image_src.clone()
309 |     mask = torch.zeros_like(masked_img)
310 |     mask = mask[0:1,:,:]
311 |     masked_img = torch.cat([masked_img, mask], 0)
312 |     all_imgs_masked.append(masked_img[None])
313 | 
314 |     obj_offset += num_objs
315 | 
316 |   all_imgs_masked = torch.cat(all_imgs_masked)
317 | 
318 |   all_imgs = torch.cat(all_imgs)
319 |   all_imgs_src = torch.cat(all_imgs_src)
320 |   all_objs = torch.cat(all_objs)
321 |   all_objs_src = torch.cat(all_objs_src)
322 |   all_boxes = torch.cat(all_boxes)
323 |   all_boxes_src = torch.cat(all_boxes_src)
324 |   all_triples = torch.cat(all_triples)
325 |   all_triples_src = torch.cat(all_triples_src)
326 |   all_obj_to_img = torch.cat(all_obj_to_img)
327 |   all_triple_to_img = torch.cat(all_triple_to_img)
328 | 
329 |   out = (all_imgs, all_imgs_src, all_objs, all_objs_src, all_boxes, all_boxes_src, all_triples, all_triples_src,
330 |          all_obj_to_img, all_triple_to_img, all_imgs_masked)
331 | 
332 |   return out
333 | 
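334 | 
335 | if __name__ == '__main__':
336 |   # Usage sketch with hypothetical paths: wrap the paired dataset in a
337 |   # DataLoader together with the collate function defined above. The vocab is
338 |   # normally the target-domain vocab.json produced by preprocessing.
339 |   from torch.utils.data import DataLoader
340 | 
341 |   with open('CLEVR/target/vocab.json', 'r') as f:
342 |     vocab = json.load(f)
343 |   dset = SceneGraphWithPairsDataset(vocab, 'CLEVR/target/train.h5',
344 |                                     'CLEVR/target/images', mode='eval')
345 |   loader = DataLoader(dset, batch_size=8, shuffle=False,
346 |                       collate_fn=collate_fn_withpairs)
347 |   imgs, imgs_src = next(iter(loader))[:2]
348 |   print(imgs.shape, imgs_src.shape)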
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/1_gen_data.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | #
 3 | # Copyright 2020 Azade Farshad
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | mode="$1" #removal, replacement, addition, rel_change
18 | blender_path="$2" #"/home/azadef/Downloads/blender-2.79-cuda10-x86_64"
19 | merge="$3"
20 | if [ $merge ]
21 | then 
22 | out_folder="./output";
23 | else
24 | out_folder="./output_$mode";
25 | fi
26 | for i in {1..6}
27 | do 
28 | if [ $merge ]
29 | then 
30 | start=`expr $(ls -1 "$out_folder/images/" | wc -l)`;
31 | else
32 | start=`expr $(ls -1 "$out_folder/../output_rel_change/images/" | wc -l) + $(ls -1 "$out_folder/../output_remove/images/" | wc -l) + $(ls -1 "$out_folder/../output_replacement/images/" | wc -l) + $(ls -1 "$out_folder/../output_addition/images/" | wc -l)`;
33 | start=$((start/2));
34 | fi
35 | echo $start
36 | "$blender_path"/blender --background --python render_clevr.py -- --num_images 800 --output_image_dir "$out_folder/images/" --output_scene_dir "$out_folder/scenes/" --output_scene_file "$out_folder/CLEVR_scenes.json" --start_idx $start --use_gpu 1 --mode "$mode"
37 | done
38 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/2_arrange_data.sh:
--------------------------------------------------------------------------------
 1 | # Copyright 2020 Azade Farshad
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #      http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | num=0;
16 | res_dir='./MyClevr';
17 | for j in */; do 
18 | find $(pwd)/$j"images" -iname "*_target.png" | sort -u | while read p; do
19 |   cp $p $res_dir"/target/images/$num.png";
20 |   q=$(pwd)/$j"scenes/${p##*/}"
21 |   q="${q%%.*}.json"
22 |   cp $q $res_dir"/target/scenes/$num.json";
23 |   ((num++));
24 |   echo $p;
25 |   echo $q;
26 |   echo $num;
27 | done
28 | num=`expr $(ls -1 "$res_dir/target/images/" | wc -l)`;
29 | done
30 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/3_clevrToVG.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2020 Azade Farshad
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | """
 18 | Stores the generated CLEVR data in the same format as Visual Genome, for easy loading
 19 | """
 20 | 
 21 | import math, sys, random, argparse, json, os
 22 | import numpy as np
 23 | #import cv2
 24 | import itertools
 25 | 
 26 | 
 27 | def get_rel(argument):
 28 |     switcher = {
 29 |         3: "left of",
 30 |         2: "behind",
 31 |         1: "right of",
 32 |         0: "front of"
 33 |     }
 34 |     return switcher.get(argument, "Invalid label")
 35 | 
 36 | def get_xy_coords(obj):
 37 |     obj_x, obj_y, obj_w, obj_h = obj['bbox']
 38 | 
 39 |     return obj_x, obj_y, obj_w, obj_h
 40 | 
 41 | def getLabel(obj, subj):
 42 |     diff = np.subtract(obj[:-1], subj[:-1])
 43 |     label = int(np.argmax(np.abs(diff)))
 44 |     if diff[label] > 0:
 45 |         label += 2
 46 |     return label
 47 | 
 48 | 
 49 | np.random.seed(seed=0)
 50 | parser = argparse.ArgumentParser()
 51 | 
 52 | # Input data
 53 | parser.add_argument('--data_path', default='/media/azadef/MyHDD/Data/MyClevr_postcvpr/source')
 54 | 
 55 | args = parser.parse_args()
 56 | 
 57 | data_path = args.data_path
 58 | im_names = sorted(os.listdir(os.path.join(data_path,'images')))
 59 | json_names = sorted(os.listdir(os.path.join(data_path,'scenes')))
 60 | 
 61 | labels_file_path = os.path.join(data_path, 'image_data.json')
 62 | objs_file_path = os.path.join(data_path, 'objects.json')
 63 | rels_file_path = os.path.join(data_path, 'relationships.json')
 64 | attrs_file_path = os.path.join(data_path, 'attributes.json')
 65 | split_file_path = os.path.join(data_path, 'vg_splits.json')
 66 | 
 67 | rel_alias_path = os.path.join(data_path, 'relationship_alias.txt')
 68 | obj_alias_path = os.path.join(data_path, 'object_alias.txt')
 69 | 
 70 | im_id = 0
 71 | 
 72 | num_objs = 0
 73 | num_rels = 0
 74 | obj_data = []
 75 | rel_data = []
 76 | scenedata = []
 77 | attr_data = []
 78 | 
 79 | obj_data_file = open(objs_file_path, 'w')
 80 | rel_data_file = open(rels_file_path, 'w')
 81 | attr_data_file = open(attrs_file_path, 'w')
 82 | split_data_file = open(split_file_path, 'w')
 83 | 
 84 | with open(rel_alias_path, 'w'):
 85 |   print("Created relationship alias file!")
 86 | 
 87 | with open(obj_alias_path, 'w'):
 88 |   print("Created object alias file!")
 89 | 
 90 | num_ims = len(json_names)
 91 | 
 92 | split_dic = {}
 93 | rand_perm = np.random.choice(num_ims, num_ims, replace=False)
 94 | train_split = int(num_ims * 0.8)
 95 | val_split = int(num_ims * 0.9)
 96 | 
 97 | train_list = [int(num) for num in rand_perm[0:train_split]]
 98 | val_list = [int(num) for num in rand_perm[train_split:val_split]]
 99 | test_list = [int(num) for num in rand_perm[val_split:num_ims]]
100 | 
101 | split_dic['train'] = train_list
102 | split_dic['val'] = val_list
103 | split_dic['test'] = test_list
104 | 
105 | json.dump(split_dic, split_data_file, indent=2)
106 | split_data_file.close()
107 | 
108 | with open(labels_file_path, 'w') as labels_file:
109 |     for json_file in json_names:
110 |         file_name = os.path.join(data_path, 'scenes',json_file)
111 |         new_obj = {}
112 |         new_rel = {}
113 |         new_scene = {}
114 |         new_attr = {}
115 |         
116 |         with open(file_name, 'r') as f:
117 |             properties = json.load(f)
118 | 
119 |             new_obj['image_id'] = im_id
120 |             new_rel['image_id'] = im_id
121 |             new_scene['image_id'] = im_id
122 |             new_attr['image_id'] = im_id
123 | 
124 |             rels = []
125 |             objs = properties["objects"]
126 |             attr_objs = []
127 | 
128 |             for j, obj in enumerate(objs):
129 |                 obj['object_id'] = num_objs + j
130 |                 obj['x'], obj['y'], obj['w'], obj['h'] = get_xy_coords(obj)
131 |                 obj.pop('bbox')
132 |                 
133 |                 obj['names'] = [obj["color"] + " " + obj['shape']]
134 |                 obj.pop('shape')
135 |                 obj.pop('rotation')
136 |                 attr_obj = obj
137 |                 attr_obj["attributes"] = [obj["color"]]
138 |                 obj.pop('size')
139 |                 attr_objs.append(attr_obj)
140 | 
141 | 
142 |             new_obj['objects'] = objs
143 |             new_scene['objects'] = objs
144 |             new_attr['attributes'] = attr_objs
145 | 
146 |             pairs = list(itertools.combinations(objs, 2))
147 |             indices = list((i,j) for ((i,_),(j,_)) in itertools.combinations(enumerate(objs), 2))
148 |             for ii, (obj, subj) in enumerate(pairs):
149 |                 label = getLabel(obj['3d_coords'], subj['3d_coords'])
150 |                 predicate = get_rel(label)
151 |                 if predicate == 'Invalid label':
152 |                     print("Invalid label!")
153 |                 rel = {}
154 |                 rel['predicate'] = predicate
155 |                 rel['name'] = predicate
156 |                 rel['object'] = obj
157 |                 rel['subject'] = subj
158 |                 rel['relationship_id'] = num_rels + ii
159 |                 rel['object_id'] = obj['object_id']
160 |                 rel['subject_id'] = subj['object_id']
161 |                 rels.append(rel)
162 | 
163 |             new_rel['relationships'] = rels
164 |             new_scene['relationships'] = rels
165 |             im_path = os.path.join(data_path, 'images', os.path.splitext(json_file)[0] + '.png')
166 | 
167 | 
168 |             new_obj['url'] = im_path
169 | 
170 |             new_scene['url'] = im_path
171 |             new_scene['width'] = 320
172 |             new_scene['height'] = 240
173 |             im_id += 1
174 |             num_objs += len(objs)
175 |             num_rels += len(rels)
176 |             
177 |         scenedata.append(new_scene)
178 |         obj_data.append(new_obj)
179 |         rel_data.append(new_rel)
180 |         attr_data.append(new_attr)
181 | 
182 |     json.dump(scenedata, labels_file, indent=2)
183 | 
184 | json.dump(obj_data, obj_data_file, indent=2)
185 | json.dump(rel_data, rel_data_file, indent=2)
186 | json.dump(attr_data, attr_data_file, indent=2)
187 | 
188 | obj_data_file.close()
189 | rel_data_file.close()
190 | attr_data_file.close()
191 | print("Done processing!")
192 | 
193 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/README_CLEVR.md:
--------------------------------------------------------------------------------
 1 | ## Setup
 2 | You will need to download and install [Blender](https://www.blender.org/); the code has been developed and tested with Blender version 2.78c, but other versions may work as well.
 3 | 
 4 | ## Data Generation
 5 | Run 1_gen_data.sh to generate source/target pairs for each of the available modes. The first argument is the change mode, the second is the path to the Blender directory, and the third is the option to merge all of the modes into one output folder. The manipulation modes are: removal, replacement, addition, rel_change.
 6 | Sample command:
 7 | ```bash
 8 | sh 1_gen_data.sh removal /home/user/blender2.79-cuda10-x86_64/ 1
 9 | ```
10 | Run the script once for each mode to generate all variations of change.
11 | 
12 | ## Merging all the folders
13 | If you set the merge argument to 1, you can skip this step.
14 | 
15 | Otherwise, move the generated images and scenes for each mode into a single folder, so that all images end up in one "images" folder and all scene JSONs in one "scenes" folder (a minimal shell sketch is given below).
16 | 
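A minimal shell sketch for this merge, assuming the default `output_<mode>` folder names produced by 1_gen_data.sh:

```bash
mkdir -p output/images output/scenes
for mode in removal replacement addition rel_change; do
  cp output_$mode/images/* output/images/
  cp output_$mode/scenes/* output/scenes/
done
```
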
17 | ## Arranging the generated images into separate folders
18 | To arrange the data, run 2_arrange_data.sh. The script copies the previously generated data into the MyClevr directory in the required format.
19 | 
20 | ```bash
21 | sh 2_arrange_data.sh
22 | ```
23 | 
24 | ## Converting CLEVR format to VG format
25 | This step converts the generated data into the format required by SIMSG, which mirrors the Visual Genome dataset. The final scene graphs and the required files are generated in this step.
26 | 
27 | ```bash
28 | python 3_clevrToVG.py
29 | ```
30 | 
31 | ## Preprocessing
32 | 
33 | Run scripts/preprocess_vg.py on the generated CLEVR data for both the source and target folders to preprocess it for SIMSG. Set VG_DIR inside the file to point to the CLEVR source or target directory.
34 | 
35 | To make sure no image is removed by the preprocessing filter step (so that corresponding source/target pairs stay aligned), set the following arguments:
36 | 
37 | ```bash
38 | python scripts/preprocess_vg.py --min_object_instances 0 --min_attribute_instances 0 --min_object_size 0 --min_objects_per_image 1 --min_relationship_instances 1 --max_relationships_per_image 50
39 | ```
40 | ## Acknowledgment
41 | 
42 | Our CLEVR generation code is based on this repo.
43 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/collect_scenes.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017-present, Facebook, Inc.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the BSD-style license found in the
 5 | # LICENSE file in the root directory of this source tree. An additional grant
 6 | # of patent rights can be found in the PATENTS file in the same directory.
 7 | 
 8 | import argparse, json, os
 9 | 
10 | """
11 | During rendering, each CLEVR scene file is dumped to disk as a separate JSON
12 | file; this is convenient for distributing rendering across multiple machines.
13 | This script collects all CLEVR scene files stored in a directory and combines
14 | them into a single JSON file. This script also adds the version number, date,
15 | and license to the output file.
16 | """
17 | 
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--input_dir', default='output/scenes')
20 | parser.add_argument('--output_file', default='output/CLEVR_misc_scenes.json')
21 | parser.add_argument('--version', default='1.0')
22 | parser.add_argument('--date', default='7/8/2017')
23 | parser.add_argument('--license',
24 |            default='Creative Commons Attribution (CC-BY 4.0)')
25 | 
26 | 
27 | def main(args):
28 |   input_files = os.listdir(args.input_dir)
29 |   scenes = []
30 |   split = None
31 |   for filename in os.listdir(args.input_dir):
32 |     if not filename.endswith('.json'):
33 |       continue
34 |     path = os.path.join(args.input_dir, filename)
35 |     with open(path, 'r') as f:
36 |       scene = json.load(f)
37 |     scenes.append(scene)
38 |     if split is not None:
39 |       msg = 'Input directory contains scenes from multiple splits'
40 |       assert scene['split'] == split, msg
41 |     else:
42 |       split = scene['split']
43 |   scenes.sort(key=lambda s: s['image_index'])
44 |   for s in scenes:
45 |     print(s['image_filename'])
46 |   output = {
47 |     'info': {
48 |       'date': args.date,
49 |       'version': args.version,
50 |       'split': split,
51 |       'license': args.license,
52 |     },
53 |     'scenes': scenes
54 |   }
55 |   with open(args.output_file, 'w') as f:
56 |     json.dump(output, f)
57 | 
58 | 
59 | if __name__ == '__main__':
60 |   args = parser.parse_args()
61 |   main(args)
62 | 
63 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/CoGenT_A.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "cube": [
 3 |     "gray", "blue", "brown", "yellow"
 4 |   ],
 5 |   "cylinder": [
 6 |     "red", "green", "purple", "cyan"
 7 |   ],
 8 |   "sphere": [
 9 |     "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"
10 |   ]
11 | }
12 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/CoGenT_B.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "cube": [
 3 |     "red", "green", "purple", "cyan"
 4 |   ],
 5 |   "cylinder": [
 6 |     "gray", "blue", "brown", "yellow"
 7 |   ],
 8 |   "sphere": [
 9 |     "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"  
10 |   ]
11 | }
12 | 
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/base_scene.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/base_scene.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/materials/MyMetal.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/materials/MyMetal.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/materials/Rubber.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/materials/Rubber.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/properties.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "shapes": {
 3 |     "cube": "SmoothCube_v2",
 4 |     "sphere": "Sphere",
 5 |     "cylinder": "SmoothCylinder"
 6 |   },
 7 |   "colors": {
 8 |     "gray": [87, 87, 87],
 9 |     "red": [173, 35, 35],
10 |     "blue": [42, 75, 215],
11 |     "green": [29, 105, 20],
12 |     "brown": [129, 74, 25],
13 |     "purple": [129, 38, 192],
14 |     "cyan": [41, 208, 208],
15 |     "yellow": [255, 238, 51]
16 |   },
17 |   "materials": {
18 |     "rubber": "Rubber",
19 |     "metal": "MyMetal"
20 |   },
21 |   "sizes": {
22 |     "large": 0.7,
23 |     "small": 0.35
24 |   }
25 | }
26 | 
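A short sketch of how a render script might consume this file, converting the [0, 255] RGB colors into the [0, 1] RGBA values that Blender material inputs (e.g. the "Color" input used by add_material in clevr_gen/utils.py) expect; the load path is an assumption:

    import json

    with open('properties.json', 'r') as f:   # assumed to be run from clevr_gen/data
        properties = json.load(f)

    # Blender materials take RGBA in [0, 1]; the file stores RGB in [0, 255].
    rgba = {name: [c / 255.0 for c in rgb] + [1.0]
            for name, rgb in properties['colors'].items()}
    print(properties['shapes']['cube'], properties['sizes']['large'], rgba['red'])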
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/shapes/SmoothCube_v2.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/SmoothCube_v2.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/shapes/SmoothCylinder.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/SmoothCylinder.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/data/shapes/Sphere.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/Sphere.blend
--------------------------------------------------------------------------------
/simsg/data/clevr_gen/utils.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017-present, Facebook, Inc.
  2 | # All rights reserved.
  3 | #
  4 | # This source code is licensed under the BSD-style license found in the
  5 | # LICENSE file in the root directory of this source tree. An additional grant
  6 | # of patent rights can be found in the PATENTS file in the same directory.
  7 | 
  8 | import sys, random, os
  9 | import bpy, bpy_extras
 10 | 
 11 | 
 12 | """
 13 | Some utility functions for interacting with Blender
 14 | """
 15 | 
 16 | 
 17 | def extract_args(input_argv=None):
 18 |   """
 19 |   Pull out command-line arguments after "--". Blender ignores command-line flags
 20 |   after --, so this lets us forward command line arguments from the blender
 21 |   invocation to our own script.
 22 |   """
 23 |   if input_argv is None:
 24 |     input_argv = sys.argv
 25 |   output_argv = []
 26 |   if '--' in input_argv:
 27 |     idx = input_argv.index('--')
 28 |     output_argv = input_argv[(idx + 1):]
 29 |   return output_argv
 30 | 
 31 | 
 32 | def parse_args(parser, argv=None):
 33 |   return parser.parse_args(extract_args(argv))
 34 | 
 35 | 
 36 | # I wonder if there's a better way to do this?
 37 | def delete_object(obj):
 38 |   """ Delete a specified blender object """
 39 |   for o in bpy.data.objects:
 40 |     o.select = False
 41 |   obj.select = True
 42 |   bpy.ops.object.delete()
 43 | 
 44 | 
 45 | def get_camera_coords(cam, pos):
 46 |   """
 47 |   For a specified point, get both the 3D coordinates and 2D pixel-space
 48 |   coordinates of the point from the perspective of the camera.
 49 | 
 50 |   Inputs:
 51 |   - cam: Camera object
 52 |   - pos: Vector giving 3D world-space position
 53 | 
 54 |   Returns a tuple of:
 55 |   - (px, py, pz): px and py give 2D image-space coordinates; pz gives depth
 56 |     in the range [-1, 1]
 57 |   """
 58 |   scene = bpy.context.scene
 59 |   x, y, z = bpy_extras.object_utils.world_to_camera_view(scene, cam, pos)
 60 |   scale = scene.render.resolution_percentage / 100.0
 61 |   w = int(scale * scene.render.resolution_x)
 62 |   h = int(scale * scene.render.resolution_y)
 63 |   px = int(round(x * w))
 64 |   py = int(round(h - y * h))
 65 |   return (px, py, z)
 66 | 
 67 | 
 68 | def set_layer(obj, layer_idx):
 69 |   """ Move an object to a particular layer """
 70 |   # Set the target layer to True first because an object must always be on
 71 |   # at least one layer.
 72 |   obj.layers[layer_idx] = True
 73 |   for i in range(len(obj.layers)):
 74 |     obj.layers[i] = (i == layer_idx)
 75 | 
 76 | 
 77 | def add_object(object_dir, name, scale, loc, theta=0):
 78 |   """
 79 |   Load an object from a file. We assume that in the directory object_dir, there
 80 |   is a file named "$name.blend" which contains a single object named "$name"
 81 |   that has unit size and is centered at the origin.
 82 | 
 83 |   - scale: scalar giving the size that the object should be in the scene
 84 |   - loc: tuple (x, y) giving the coordinates on the ground plane where the
 85 |     object should be placed.
 86 |   """
 87 |   # First figure out how many of this object are already in the scene so we can
 88 |   # give the new object a unique name
 89 |   count = 0
 90 |   for obj in bpy.data.objects:
 91 |     if obj.name.startswith(name):
 92 |       count += 1
 93 | 
 94 |   filename = os.path.join(object_dir, '%s.blend' % name, 'Object', name)
 95 |   bpy.ops.wm.append(filename=filename)
 96 | 
 97 |   # Give it a new name to avoid conflicts
 98 |   new_name = '%s_%d' % (name, count)
 99 |   bpy.data.objects[name].name = new_name
100 | 
101 |   # Set the new object as active, then rotate, scale, and translate it
102 |   x, y = loc
103 |   bpy.context.scene.objects.active = bpy.data.objects[new_name]
104 |   bpy.context.object.rotation_euler[2] = theta
105 |   bpy.ops.transform.resize(value=(scale, scale, scale))
106 |   bpy.ops.transform.translate(value=(x, y, scale))
107 | 
108 | 
109 | def load_materials(material_dir):
110 |   """
111 |   Load materials from a directory. We assume that the directory contains .blend
112 |   files with one material each. The file X.blend has a single NodeTree item named
113 |   X; this NodeTree item must have a "Color" input that accepts an RGBA value.
114 |   """
115 |   for fn in os.listdir(material_dir):
116 |     if not fn.endswith('.blend'): continue
117 |     name = os.path.splitext(fn)[0]
118 |     filepath = os.path.join(material_dir, fn, 'NodeTree', name)
119 |     bpy.ops.wm.append(filename=filepath)
120 | 
121 | 
122 | def add_material(name, relchange=False, **properties):
123 |   """
124 |   Create a new material and assign it to the active object. "name" should be the
125 |   name of a material that has been previously loaded using load_materials.
126 |   """
127 |   # Figure out how many materials are already in the scene
128 |   mat_count = len(bpy.data.materials)
129 | 
130 |   # Create a new material; it is not attached to anything and
131 |   # it will be called "Material"
132 |   bpy.ops.material.new()
133 | 
134 |   # Get a reference to the material we just created and rename it;
135 |   # then the next time we make a new material it will still be called
136 |   # "Material" and we will still be able to look it up by name
137 |   mat = bpy.data.materials['Material']
138 |   mat.name = 'Material_%d' % mat_count
139 | 
140 |   # Attach the new material to the active object
141 |   # Make sure it doesn't already have materials
142 |   obj = bpy.context.active_object
143 |   if not relchange:
144 |     assert len(obj.data.materials) == 0
145 |   obj.data.materials.append(mat)
146 | 
147 |   # Find the output node of the new material
148 |   output_node = None
149 |   for n in mat.node_tree.nodes:
150 |     if n.name == 'Material Output':
151 |       output_node = n
152 |       break
153 | 
154 |   # Add a new GroupNode to the node tree of the active material,
155 |   # and copy the node tree from the preloaded node group to the
156 |   # new group node. This copying seems to happen by-value, so
157 |   # we can create multiple materials of the same type without them
158 |   # clobbering each other
159 |   group_node = mat.node_tree.nodes.new('ShaderNodeGroup')
160 |   group_node.node_tree = bpy.data.node_groups[name]
161 | 
162 |   # Find and set the "Color" input of the new group node
163 |   for inp in group_node.inputs:
164 |     if inp.name in properties:
165 |       inp.default_value = properties[inp.name]
166 | 
167 |   # Wire the output of the new group node to the input of
168 |   # the MaterialOutput node
169 |   mat.node_tree.links.new(
170 |       group_node.outputs['Shader'],
171 |       output_node.inputs['Surface'],
172 |   )
173 | 
174 | 
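A self-contained illustration of the '--' convention that extract_args/parse_args rely on; the --num_images flag and the argv list are made up for the example:

    import argparse

    def extract_args(argv):
        # same idea as extract_args() above: keep only what follows '--'
        return argv[argv.index('--') + 1:] if '--' in argv else []

    blender_argv = ['blender', '--background', '--python', 'render_clevr.py',
                    '--', '--num_images', '10']
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_images', type=int, default=1)   # hypothetical flag
    args = parser.parse_args(extract_args(blender_argv))
    assert args.num_images == 10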
--------------------------------------------------------------------------------
/simsg/data/utils.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import PIL
 18 | import torch
 19 | import torchvision.transforms as T
 20 | 
 21 | 
 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406]
 23 | IMAGENET_STD = [0.229, 0.224, 0.225]
 24 | 
 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN]
 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD]
 27 | 
 28 | 
 29 | def imagenet_preprocess():
 30 |   return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
 31 | 
 32 | 
 33 | def rescale(x):
 34 |   lo, hi = x.min(), x.max()
 35 |   return x.sub(lo).div(hi - lo)
 36 | 
 37 | 
 38 | def imagenet_deprocess(rescale_image=True):
 39 |   transforms = [
 40 |     T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD),
 41 |     T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]),
 42 |   ]
 43 |   if rescale_image:
 44 |     transforms.append(rescale)
 45 |   return T.Compose(transforms)
 46 | 
 47 | 
 48 | def imagenet_deprocess_batch(imgs, rescale=True):
 49 |   """
 50 |   Input:
 51 |   - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
 52 | 
 53 |   Output:
 54 |   - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images
 55 |     in the range [0, 255]
 56 |   """
 57 |   if isinstance(imgs, torch.autograd.Variable):
 58 |     imgs = imgs.data
 59 |   imgs = imgs.cpu().clone()
 60 |   deprocess_fn = imagenet_deprocess(rescale_image=rescale)
 61 |   imgs_de = []
 62 |   for i in range(imgs.size(0)):
 63 |     img_de = deprocess_fn(imgs[i])[None]
 64 |     img_de = img_de.mul(255).clamp(0, 255).byte()
 65 |     imgs_de.append(img_de)
 66 |   imgs_de = torch.cat(imgs_de, dim=0)
 67 |   return imgs_de
 68 | 
 69 | 
 70 | class Resize(object):
 71 |   def __init__(self, size, interp=PIL.Image.BILINEAR):
 72 |     if isinstance(size, tuple):
 73 |       H, W = size
 74 |       self.size = (W, H)
 75 |     else:
 76 |       self.size = (size, size)
 77 |     self.interp = interp
 78 | 
 79 |   def __call__(self, img):
 80 |     return img.resize(self.size, self.interp)
 81 | 
 82 | 
 83 | def unpack_var(v):
 84 |   if isinstance(v, torch.autograd.Variable):
 85 |     return v.data
 86 |   return v
 87 | 
 88 | 
 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
 90 |   triples = unpack_var(triples)
 91 |   obj_data = [unpack_var(o) for o in obj_data]
 92 |   obj_to_img = unpack_var(obj_to_img)
 93 |   triple_to_img = unpack_var(triple_to_img)
 94 | 
 95 |   triples_out = []
 96 |   obj_data_out = [[] for _ in obj_data]
 97 |   obj_offset = 0
 98 |   N = obj_to_img.max() + 1
 99 |   for i in range(N):
100 |     o_idxs = (obj_to_img == i).nonzero().view(-1)
101 |     t_idxs = (triple_to_img == i).nonzero().view(-1)
102 | 
103 |     cur_triples = triples[t_idxs].clone()
104 |     cur_triples[:, 0] -= obj_offset
105 |     cur_triples[:, 2] -= obj_offset
106 |     triples_out.append(cur_triples)
107 | 
108 |     for j, o_data in enumerate(obj_data):
109 |       cur_o_data = None
110 |       if o_data is not None:
111 |         cur_o_data = o_data[o_idxs]
112 |       obj_data_out[j].append(cur_o_data)
113 | 
114 |     obj_offset += o_idxs.size(0)
115 | 
116 |   return triples_out, obj_data_out
117 | 
118 | 
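A quick round-trip sketch for the (de)normalization helpers above; the batch of random images is arbitrary:

    import torch
    from simsg.data.utils import imagenet_preprocess, imagenet_deprocess_batch

    imgs = torch.rand(2, 3, 64, 64)                          # fake RGB batch in [0, 1]
    pre = torch.stack([imagenet_preprocess()(im) for im in imgs])
    out = imagenet_deprocess_batch(pre, rescale=True)        # ByteTensor in [0, 255]
    print(out.shape, out.dtype)                              # (2, 3, 64, 64) torch.uint8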
--------------------------------------------------------------------------------
/simsg/data/vg.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | # Modification copyright 2020 Helisa Dhamo
  5 | #
  6 | # Licensed under the Apache License, Version 2.0 (the "License");
  7 | # you may not use this file except in compliance with the License.
  8 | # You may obtain a copy of the License at
  9 | #
 10 | #      http://www.apache.org/licenses/LICENSE-2.0
 11 | #
 12 | # Unless required by applicable law or agreed to in writing, software
 13 | # distributed under the License is distributed on an "AS IS" BASIS,
 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 | # See the License for the specific language governing permissions and
 16 | # limitations under the License.
 17 | 
 18 | import os
 19 | import random
 20 | from collections import defaultdict
 21 | 
 22 | import torch
 23 | from torch.utils.data import Dataset
 24 | 
 25 | import torchvision.transforms as T
 26 | 
 27 | import numpy as np
 28 | import h5py
 29 | import PIL
 30 | 
 31 | from .utils import imagenet_preprocess, Resize
 32 | 
 33 | 
 34 | class SceneGraphNoPairsDataset(Dataset):
 35 |   def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256),
 36 |                normalize_images=True, max_objects=10, max_samples=None,
 37 |                include_relationships=True, use_orphaned_objects=True,
 38 |                mode='train', clean_repeats=True, predgraphs=False):
 39 |     super(SceneGraphNoPairsDataset, self).__init__()
 40 | 
 41 |     assert mode in ["train", "eval", "auto_withfeats", "auto_nofeats", "reposition", "remove", "replace"]
 42 | 
 43 |     self.mode = mode
 44 | 
 45 |     self.image_dir = image_dir
 46 |     self.image_size = image_size
 47 |     self.vocab = vocab
 48 |     self.num_objects = len(vocab['object_idx_to_name'])
 49 |     self.use_orphaned_objects = use_orphaned_objects
 50 |     self.max_objects = max_objects
 51 |     self.max_samples = max_samples
 52 |     self.include_relationships = include_relationships
 53 | 
 54 |     self.evaluating = mode != 'train'
 55 |     self.predgraphs = predgraphs
 56 | 
 57 |     if self.mode == 'reposition':
 58 |       self.use_orphaned_objects = False
 59 | 
 60 |     self.clean_repeats = clean_repeats
 61 | 
 62 |     transform = [Resize(image_size), T.ToTensor()]
 63 |     if normalize_images:
 64 |       transform.append(imagenet_preprocess())
 65 |     self.transform = T.Compose(transform)
 66 | 
 67 |     self.data = {}
 68 |     with h5py.File(h5_path, 'r') as f:
 69 |       for k, v in f.items():
 70 |         if k == 'image_paths':
 71 |           self.image_paths = list(v)
 72 |         else:
 73 |           self.data[k] = torch.IntTensor(np.asarray(v))
 74 | 
 75 |   def __len__(self):
 76 |     num = self.data['object_names'].size(0)
 77 |     if self.max_samples is not None:
 78 |       return min(self.max_samples, num)
 79 |     return num
 80 | 
 81 |   def __getitem__(self, index):
 82 |     """
 83 |     Returns a tuple of:
 84 |     - image: FloatTensor of shape (C, H, W)
 85 |     - objs: LongTensor of shape (num_objs,)
 86 |     - boxes: FloatTensor of shape (num_objs, 4) giving boxes for objects in
 87 |       (x0, y0, x1, y1) format, in a [0, 1] coordinate system.
 88 |     - triples: LongTensor of shape (num_triples, 3) where triples[t] = [i, p, j]
 89 |       means that (objs[i], p, objs[j]) is a triple.
 90 |     """
 91 |     img_path = os.path.join(self.image_dir, self.image_paths[index])
 92 | 
 93 |     # use this instead if image paths are stored as bytes (fixes the mixed str/bytes error)
 94 |     #img_path = os.path.join(self.image_dir, self.image_paths[index].decode("utf-8"))
 95 | 
 96 |     with open(img_path, 'rb') as f:
 97 |       with PIL.Image.open(f) as image:
 98 |         WW, HH = image.size
 99 |         #print(WW, HH)
100 |         image = self.transform(image.convert('RGB'))
101 | 
102 |     H, W = self.image_size
103 | 
104 |     # Figure out which objects appear in relationships and which don't
105 |     obj_idxs_with_rels = set()
106 |     obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item()))
107 |     for r_idx in range(self.data['relationships_per_image'][index]):
108 |       s = self.data['relationship_subjects'][index, r_idx].item()
109 |       o = self.data['relationship_objects'][index, r_idx].item()
110 |       obj_idxs_with_rels.add(s)
111 |       obj_idxs_with_rels.add(o)
112 |       obj_idxs_without_rels.discard(s)
113 |       obj_idxs_without_rels.discard(o)
114 | 
115 |     obj_idxs = list(obj_idxs_with_rels)
116 |     obj_idxs_without_rels = list(obj_idxs_without_rels)
117 |     if len(obj_idxs) > self.max_objects - 1:
118 |       if self.evaluating:
119 |         obj_idxs = obj_idxs[:self.max_objects]
120 |       else:
121 |         obj_idxs = random.sample(obj_idxs, self.max_objects)
122 |     if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects:
123 |       num_to_add = self.max_objects - 1 - len(obj_idxs)
124 |       num_to_add = min(num_to_add, len(obj_idxs_without_rels))
125 |       if self.evaluating:
126 |         obj_idxs += obj_idxs_without_rels[:num_to_add]
127 |       else:
128 |         obj_idxs += random.sample(obj_idxs_without_rels, num_to_add)
129 |     if len(obj_idxs) == 0 and not self.use_orphaned_objects:
130 |       # avoid empty list of objects
131 |       obj_idxs += obj_idxs_without_rels[:1]
132 |     map_overlapping_obj = {}
133 | 
134 |     objs = []
135 |     boxes = []
136 | 
137 |     obj_idx_mapping = {}
138 |     counter = 0
139 |     for i, obj_idx in enumerate(obj_idxs):
140 | 
141 |       curr_obj = self.data['object_names'][index, obj_idx].item()
142 |       x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist()
143 | 
144 |       x0 = float(x) / WW
145 |       y0 = float(y) / HH
146 |       if self.predgraphs:
147 |         x1 = float(w) / WW
148 |         y1 = float(h) / HH
149 |       else:
150 |         x1 = float(x + w) / WW
151 |         y1 = float(y + h) / HH
152 | 
153 |       curr_box = torch.FloatTensor([x0, y0, x1, y1])
154 | 
155 |       found_overlap = False
156 |       if self.predgraphs:
157 |         for prev_idx in range(counter):
158 |           if overlapping_nodes(objs[prev_idx], curr_obj, boxes[prev_idx], curr_box):
159 |             map_overlapping_obj[i] = prev_idx
160 |             found_overlap = True
161 |             break
162 |       if not found_overlap:
163 | 
164 |         objs.append(curr_obj)
165 |         boxes.append(curr_box)
166 |         map_overlapping_obj[i] = counter
167 |         counter += 1
168 | 
169 |       obj_idx_mapping[obj_idx] = map_overlapping_obj[i]
170 | 
171 |     # The last object will be the special __image__ object
172 |     objs.append(self.vocab['object_name_to_idx']['__image__'])
173 |     boxes.append(torch.FloatTensor([0, 0, 1, 1]))
174 | 
175 |     boxes = torch.stack(boxes)
176 |     objs = torch.LongTensor(objs)
177 |     num_objs = counter + 1
178 | 
179 |     triples = []
180 |     for r_idx in range(self.data['relationships_per_image'][index].item()):
181 |       if not self.include_relationships:
182 |         break
183 |       s = self.data['relationship_subjects'][index, r_idx].item()
184 |       p = self.data['relationship_predicates'][index, r_idx].item()
185 |       o = self.data['relationship_objects'][index, r_idx].item()
186 |       s = obj_idx_mapping.get(s, None)
187 |       o = obj_idx_mapping.get(o, None)
188 |       if s is not None and o is not None:
189 |         if self.clean_repeats and [s, p, o] in triples:
190 |           continue
191 |         if self.predgraphs and s == o:
192 |           continue
193 |         triples.append([s, p, o])
194 | 
195 |     # Add dummy __in_image__ relationships for all objects
196 |     in_image = self.vocab['pred_name_to_idx']['__in_image__']
197 |     for i in range(num_objs - 1):
198 |       triples.append([i, in_image, num_objs - 1])
199 | 
200 |     triples = torch.LongTensor(triples)
201 |     return image, objs, boxes, triples
202 | 
203 | 
204 | def collate_fn_nopairs(batch):
205 |   """
206 |   Collate function to be used when wrapping a SceneGraphNoPairsDataset in a
207 |   DataLoader. Returns a tuple of the following:
208 | 
209 |   - imgs: FloatTensor of shape (N, 3, H, W)
210 |   - objs: LongTensor of shape (num_objs,) giving categories for all objects
211 |   - boxes: FloatTensor of shape (num_objs, 4) giving boxes for all objects
212 |   - triples: LongTensor of shape (num_triples, 3) giving all triples, where
213 |     triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple
214 |   - obj_to_img: LongTensor of shape (num_objs,) mapping objects to images;
215 |     obj_to_img[i] = n means that objs[i] belongs to imgs[n]
216 |   - triple_to_img: LongTensor of shape (num_triples,) mapping triples to images;
217 |     triple_to_img[t] = n means that triples[t] belongs to imgs[n]
218 |   - imgs_masked: FloatTensor of shape (N, 4, H, W)
219 |   """
220 |   # batch is a list, and each element is (image, objs, boxes, triples)
221 |   all_imgs, all_objs, all_boxes, all_triples = [], [], [], []
222 |   all_obj_to_img, all_triple_to_img = [], []
223 | 
224 |   all_imgs_masked = []
225 | 
226 |   obj_offset = 0
227 | 
228 |   for i, (img, objs, boxes, triples) in enumerate(batch):
229 | 
230 |     all_imgs.append(img[None])
231 |     num_objs, num_triples = objs.size(0), triples.size(0)
232 | 
233 |     all_objs.append(objs)
234 |     all_boxes.append(boxes)
235 |     triples = triples.clone()
236 | 
237 |     triples[:, 0] += obj_offset
238 |     triples[:, 2] += obj_offset
239 | 
240 |     all_triples.append(triples)
241 | 
242 |     all_obj_to_img.append(torch.LongTensor(num_objs).fill_(i))
243 |     all_triple_to_img.append(torch.LongTensor(num_triples).fill_(i))
244 | 
245 |     # prepare input 4-channel image
246 |     # initialize mask channel with zeros
247 |     masked_img = img.clone()
248 |     mask = torch.zeros_like(masked_img)
249 |     mask = mask[0:1,:,:]
250 |     masked_img = torch.cat([masked_img, mask], 0)
251 |     all_imgs_masked.append(masked_img[None])
252 | 
253 |     obj_offset += num_objs
254 | 
255 |   all_imgs_masked = torch.cat(all_imgs_masked)
256 | 
257 |   all_imgs = torch.cat(all_imgs)
258 |   all_objs = torch.cat(all_objs)
259 |   all_boxes = torch.cat(all_boxes)
260 |   all_triples = torch.cat(all_triples)
261 |   all_obj_to_img = torch.cat(all_obj_to_img)
262 |   all_triple_to_img = torch.cat(all_triple_to_img)
263 | 
264 |   return all_imgs, all_objs, all_boxes, all_triples, \
265 |          all_obj_to_img, all_triple_to_img, all_imgs_masked
266 | 
267 | 
268 | from simsg.model import get_left_right_top_bottom
269 | 
270 | 
271 | def overlapping_nodes(obj1, obj2, box1, box2, criteria=0.7):
272 |   # used to clean predicted graphs - merge nodes with overlapping boxes
273 |   # are these two objects overlapping?
274 |   # boxes given as [left, top, right, bottom]
275 |   res = 100 # used to project box representation in 2D for iou computation
276 |   epsilon = 0.001
277 |   if obj1 == obj2:
278 |     spatial_box1 = np.zeros([res, res])
279 |     left, right, top, bottom = get_left_right_top_bottom(box1, res, res)
280 |     spatial_box1[top:bottom, left:right] = 1
281 |     spatial_box2 = np.zeros([res, res])
282 |     left, right, top, bottom = get_left_right_top_bottom(box2, res, res)
283 |     spatial_box2[top:bottom, left:right] = 1
284 |     iou = np.sum(spatial_box1 * spatial_box2) / \
285 |           (np.sum((spatial_box1 + spatial_box2 > 0).astype(np.float32)) + epsilon)
286 |     return iou >= criteria
287 |   else:
288 |     return False
289 | 
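A wiring sketch showing how the dataset and collate function above are typically combined; the vocab, H5, and image paths are placeholders, not files shipped with the repository:

    import json
    from torch.utils.data import DataLoader
    from simsg.data.vg import SceneGraphNoPairsDataset, collate_fn_nopairs

    with open('vocab.json', 'r') as f:                        # placeholder path
        vocab = json.load(f)
    dset = SceneGraphNoPairsDataset(vocab, h5_path='train.h5',    # placeholder paths
                                    image_dir='images/', image_size=(64, 64),
                                    mode='train')
    loader = DataLoader(dset, batch_size=32, shuffle=True,
                        collate_fn=collate_fn_nopairs)
    imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_masked = next(iter(loader))
    print(imgs.shape, imgs_masked.shape)                      # (32, 3, 64, 64), (32, 4, 64, 64)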
--------------------------------------------------------------------------------
/simsg/decoder.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | # Modification copyright 2020 Helisa Dhamo
  5 | #
  6 | # Licensed under the Apache License, Version 2.0 (the "License");
  7 | # you may not use this file except in compliance with the License.
  8 | # You may obtain a copy of the License at
  9 | #
 10 | #      http://www.apache.org/licenses/LICENSE-2.0
 11 | #
 12 | # Unless required by applicable law or agreed to in writing, software
 13 | # distributed under the License is distributed on an "AS IS" BASIS,
 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 | # See the License for the specific language governing permissions and
 16 | # limitations under the License.
 17 | 
 18 | import torch
 19 | import torch.nn as nn
 20 | import torch.nn.functional as F
 21 | 
 22 | from simsg.layers import get_normalization_2d
 23 | from simsg.layers import get_activation
 24 | 
 25 | from simsg.SPADE.normalization import SPADE
 26 | import torch.nn.utils.spectral_norm as spectral_norm
 27 | 
 28 | 
 29 | class DecoderNetwork(nn.Module):
 30 |   """
 31 |   Decoder network that generates a target image from a masked source image and a layout.
 32 |   Implemented in two variants: with CRN blocks or with SPADE blocks.
 33 |   """
 34 | 
 35 |   def __init__(self, dims, normalization='instance', activation='leakyrelu', spade_blocks=False, source_image_dims=32):
 36 |     super(DecoderNetwork, self).__init__()
 37 | 
 38 |     self.spade_block = spade_blocks
 39 |     self.source_image_dims = source_image_dims
 40 | 
 41 |     layout_dim = dims[0]
 42 |     self.decoder_modules = nn.ModuleList()
 43 |     for i in range(1, len(dims)):
 44 |       input_dim = 1 if i == 1 else dims[i - 1]
 45 |       output_dim = dims[i]
 46 | 
 47 |       if self.spade_block:
 48 |         # Resnet SPADE block
 49 |         mod = SPADEResnetBlock(input_dim, output_dim, layout_dim-self.source_image_dims, self.source_image_dims)
 50 | 
 51 |       else:
 52 |         # CRN block
 53 |         mod = CRNBlock(layout_dim, input_dim, output_dim,
 54 |                        normalization=normalization, activation=activation)
 55 | 
 56 |       self.decoder_modules.append(mod)
 57 | 
 58 |     output_conv_layers = [
 59 |       nn.Conv2d(dims[-1], dims[-1], kernel_size=3, padding=1),
 60 |       get_activation(activation),
 61 |       nn.Conv2d(dims[-1], 3, kernel_size=1, padding=0)
 62 |     ]
 63 |     nn.init.kaiming_normal_(output_conv_layers[0].weight)
 64 |     nn.init.kaiming_normal_(output_conv_layers[2].weight)
 65 |     self.output_conv = nn.Sequential(*output_conv_layers)
 66 | 
 67 |   def forward(self, layout):
 68 |     """
 69 |     Output will have same size as layout
 70 |     """
 71 |     # H, W = self.output_size
 72 |     N, _, H, W = layout.size()
 73 |     self.layout = layout
 74 | 
 75 |     # Figure out size of input
 76 |     input_H, input_W = H, W
 77 |     for _ in range(len(self.decoder_modules)):
 78 |       input_H //= 2
 79 |       input_W //= 2
 80 | 
 81 |     assert input_H != 0
 82 |     assert input_W != 0
 83 | 
 84 |     feats = torch.zeros(N, 1, input_H, input_W).to(layout)
 85 |     for mod in self.decoder_modules:
 86 |       feats = F.interpolate(feats, scale_factor=2, mode='nearest')
 87 |       #print(layout.shape)
 88 |       feats = mod(layout, feats)
 89 | 
 90 |     out = self.output_conv(feats)
 91 | 
 92 |     return out
 93 | 
 94 | 
 95 | class CRNBlock(nn.Module):
 96 |   """
 97 |   Cascaded refinement network (CRN) block, as described in:
 98 |   Qifeng Chen and Vladlen Koltun,
 99 |   "Photographic Image Synthesis with Cascaded Refinement Networks",
100 |   ICCV 2017
101 |   """
102 | 
103 |   def __init__(self, layout_dim, input_dim, output_dim,
104 |                normalization='instance', activation='leakyrelu'):
105 |     super(CRNBlock, self).__init__()
106 | 
107 |     layers = []
108 | 
109 |     layers.append(nn.Conv2d(layout_dim + input_dim, output_dim,
110 |                             kernel_size=3, padding=1))
111 |     layers.append(get_normalization_2d(output_dim, normalization))
112 |     layers.append(get_activation(activation))
113 |     layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1))
114 |     layers.append(get_normalization_2d(output_dim, normalization))
115 |     layers.append(get_activation(activation))
116 |     layers = [layer for layer in layers if layer is not None]
117 |     for layer in layers:
118 |       if isinstance(layer, nn.Conv2d):
119 |         nn.init.kaiming_normal_(layer.weight)
120 |     self.net = nn.Sequential(*layers)
121 | 
122 |   def forward(self, layout, feats):
123 |     _, CC, HH, WW = layout.size()
124 |     _, _, H, W = feats.size()
125 |     assert HH >= H
126 |     if HH > H:
127 |       factor = HH // H
128 |       assert HH % factor == 0
129 |       assert WW % factor == 0 and WW // factor == W
130 |       layout = F.avg_pool2d(layout, kernel_size=factor, stride=factor)
131 | 
132 |     net_input = torch.cat([layout, feats], dim=1)
133 |     out = self.net(net_input)
134 |     return out
135 | 
136 | 
137 | class SPADEResnetBlock(nn.Module):
138 |     """
139 |     ResNet block used in SPADE.
140 |     It differs from the ResNet block of pix2pixHD in that
141 |     it takes in the segmentation map as input, learns the skip connection if necessary,
142 |     and applies normalization first and then convolution.
143 |     This is a fairly standard residual-block design for unconditional or
144 |     class-conditional GANs.
145 |     The code was inspired by https://github.com/LMescheder/GAN_stability.
146 |     """
147 | 
148 |     def __init__(self, fin, fout, seg_nc, src_nc, spade_config_str='spadebatch3x3', spectral=True):
149 |         super().__init__()
150 |         # Attributes
151 |         self.learned_shortcut = (fin != fout)
152 |         fmiddle = min(fin, fout)
153 |         self.src_nc = src_nc
154 | 
155 |         # create conv layers
156 |         self.conv_0 = nn.Conv2d(fin+self.src_nc, fmiddle, kernel_size=3, padding=1)
157 |         self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
158 |         if self.learned_shortcut:
159 |             self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
160 | 
161 |         # apply spectral norm if specified
162 |         if spectral:
163 |             self.conv_0 = spectral_norm(self.conv_0)
164 |             self.conv_1 = spectral_norm(self.conv_1)
165 |             if self.learned_shortcut:
166 |                 self.conv_s = spectral_norm(self.conv_s)
167 | 
168 |         # define normalization layers
169 |         self.norm_0 = SPADE(spade_config_str, fin, seg_nc)
170 |         self.norm_1 = SPADE(spade_config_str, fmiddle, seg_nc)
171 |         if self.learned_shortcut:
172 |             self.norm_s = SPADE(spade_config_str, fin, seg_nc)
173 | 
174 |     # note the resnet block with SPADE also takes in |seg|,
175 |     # the semantic segmentation map as input
176 |     def forward(self, seg_, x):
177 | 
178 |         seg_ = F.interpolate(seg_, size=x.size()[2:], mode='nearest')
179 | 
180 |         # only use the layout map as input to SPADE norm (not the source image channels)
181 |         layout_only_dim = seg_.size(1) - self.src_nc
182 |         in_img = seg_[:, layout_only_dim:, :, :]
183 |         seg = seg_[:, :layout_only_dim, :, :]
184 | 
185 |         x_s = self.shortcut(x, seg)
186 |         dx = torch.cat([self.norm_0(x, seg), in_img],1)
187 |         dx = self.conv_0(self.actvn(dx))
188 |         dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
189 | 
190 |         out = x_s + dx
191 | 
192 |         return out
193 | 
194 |     def shortcut(self, x, seg):
195 |         if self.learned_shortcut:
196 |             x_s = self.conv_s(self.norm_s(x, seg))
197 |         else:
198 |             x_s = x
199 |         return x_s
200 | 
201 |     def actvn(self, x):
202 |         return F.leaky_relu(x, 2e-1)
203 | 
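A shape-check sketch for the CRN variant of the decoder; the channel sizes are illustrative and chosen so a 64x64 layout can be halved once per decoder stage:

    import torch
    from simsg.decoder import DecoderNetwork

    dec = DecoderNetwork(dims=(160, 512, 256, 128, 64), spade_blocks=False)
    layout = torch.randn(2, 160, 64, 64)     # (N, layout_dim, H, W)
    img = dec(layout)                        # four CRN stages refine from 4x4 up to 64x64
    print(img.shape)                         # torch.Size([2, 3, 64, 64])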
--------------------------------------------------------------------------------
/simsg/discriminators.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | # Modification copyright 2020 Helisa Dhamo
  5 | #
  6 | # Licensed under the Apache License, Version 2.0 (the "License");
  7 | # you may not use this file except in compliance with the License.
  8 | # You may obtain a copy of the License at
  9 | #
 10 | #      http://www.apache.org/licenses/LICENSE-2.0
 11 | #
 12 | # Unless required by applicable law or agreed to in writing, software
 13 | # distributed under the License is distributed on an "AS IS" BASIS,
 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 | # See the License for the specific language governing permissions and
 16 | # limitations under the License.
 17 | 
 18 | import torch
 19 | import torch.nn as nn
 20 | import torch.nn.functional as F
 21 | 
 22 | from simsg.bilinear import crop_bbox_batch
 23 | from simsg.layers import GlobalAvgPool, Flatten, get_activation, build_cnn
 24 | import numpy as np
 25 | 
 26 | '''
 27 | visualizations for debugging
 28 | import cv2
 29 | import numpy as np
 30 | from simsg.data import imagenet_deprocess_batch
 31 | '''
 32 | 
 33 | class PatchDiscriminator(nn.Module):
 34 |   def __init__(self, arch, normalization='batch', activation='leakyrelu-0.2',
 35 |                padding='same', pooling='avg', input_size=(128,128),
 36 |                layout_dim=0):
 37 |     super(PatchDiscriminator, self).__init__()
 38 |     input_dim = 3 + layout_dim
 39 |     arch = 'I%d,%s' % (input_dim, arch)
 40 |     cnn_kwargs = {
 41 |       'arch': arch,
 42 |       'normalization': normalization,
 43 |       'activation': activation,
 44 |       'pooling': pooling,
 45 |       'padding': padding,
 46 |     }
 47 |     self.cnn, output_dim = build_cnn(**cnn_kwargs)
 48 |     self.classifier = nn.Conv2d(output_dim, 1, kernel_size=1, stride=1)
 49 | 
 50 |   def forward(self, x, layout=None):
 51 |     if layout is not None:
 52 |       x = torch.cat([x, layout], dim=1)
 53 | 
 54 |     out = self.cnn(x)
 55 | 
 56 |     discr_layers = []
 57 |     i = 0
 58 |     for l in self.cnn.children():
 59 |       x = l(x)
 60 | 
 61 |       if i % 3 == 0:
 62 |         discr_layers.append(x)
 63 |         #print("shape: ", x.shape, l)
 64 |       i += 1
 65 | 
 66 |     return out, discr_layers
 67 | 
 68 | 
 69 | class AcDiscriminator(nn.Module):
 70 |   def __init__(self, vocab, arch, normalization='none', activation='relu',
 71 |                padding='same', pooling='avg'):
 72 |     super(AcDiscriminator, self).__init__()
 73 |     self.vocab = vocab
 74 | 
 75 |     cnn_kwargs = {
 76 |       'arch': arch,
 77 |       'normalization': normalization,
 78 |       'activation': activation,
 79 |       'pooling': pooling, 
 80 |       'padding': padding,
 81 |     }
 82 |     self.cnn_body, D = build_cnn(**cnn_kwargs)
 83 |     #print(D)
 84 |     self.cnn = nn.Sequential(self.cnn_body, GlobalAvgPool(), nn.Linear(D, 1024))
 85 |     num_objects = len(vocab['object_idx_to_name'])
 86 | 
 87 |     self.real_classifier = nn.Linear(1024, 1)
 88 |     self.obj_classifier = nn.Linear(1024, num_objects)
 89 | 
 90 |   def forward(self, x, y):
 91 |     if x.dim() == 3:
 92 |       x = x[:, None]
 93 |     vecs = self.cnn(x)
 94 |     real_scores = self.real_classifier(vecs)
 95 |     obj_scores = self.obj_classifier(vecs)
 96 |     ac_loss = F.cross_entropy(obj_scores, y)
 97 | 
 98 |     discr_layers = []
 99 |     i = 0
100 |     for l in self.cnn_body.children():
101 |       x = l(x)
102 | 
103 |       if i % 3 == 0:
104 |         discr_layers.append(x)
105 |         #print("shape: ", x.shape, l)
106 |       i += 1
107 |     #print(len(discr_layers))
108 |     return real_scores, ac_loss, discr_layers
109 | 
110 | 
111 | class AcCropDiscriminator(nn.Module):
112 |   def __init__(self, vocab, arch, normalization='none', activation='relu',
113 |                object_size=64, padding='same', pooling='avg'):
114 |     super(AcCropDiscriminator, self).__init__()
115 |     self.vocab = vocab
116 |     self.discriminator = AcDiscriminator(vocab, arch, normalization,
117 |                                          activation, padding, pooling)
118 |     self.object_size = object_size
119 | 
120 |   def forward(self, imgs, objs, boxes, obj_to_img):
121 |     crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
122 |     real_scores, ac_loss, discr_layers = self.discriminator(crops, objs)
123 |     return real_scores, ac_loss, discr_layers
124 | 
125 | 
126 | # Multi-scale discriminator as in pix2pixHD or SPADE
127 | class MultiscaleDiscriminator(nn.Module):
128 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
129 |                  use_sigmoid=False, num_D=3):
130 |         super(MultiscaleDiscriminator, self).__init__()
131 |         self.num_D = num_D
132 |         self.n_layers = n_layers
133 | 
134 |         for i in range(num_D):
135 |             netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid)
136 |             for j in range(n_layers + 2):
137 |                 setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
138 | 
139 |         self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
140 | 
141 |     def singleD_forward(self, model, input):
142 |         result = [input]
143 |         for i in range(len(model)):
144 |             result.append(model[i](result[-1]))
145 |         return result[1:]
146 | 
147 |     def forward(self, input):
148 |         num_D = self.num_D
149 |         result = []
150 |         input_downsampled = input
151 |         for i in range(num_D):
152 |             model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
153 |                      range(self.n_layers + 2)]
154 |             result.append(self.singleD_forward(model, input_downsampled))
155 |             if i != (num_D - 1):
156 |                 input_downsampled = self.downsample(input_downsampled)
157 |         return result
158 | 
159 | # Defines the PatchGAN discriminator with the specified arguments.
160 | class NLayerDiscriminator(nn.Module):
161 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
162 |         super(NLayerDiscriminator, self).__init__()
163 |         self.n_layers = n_layers
164 | 
165 |         kw = 4
166 |         padw = int(np.ceil((kw - 1.0) / 2))
167 |         sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
168 | 
169 |         nf = ndf
170 |         for n in range(1, n_layers):
171 |             nf_prev = nf
172 |             nf = min(nf * 2, 512)
173 |             sequence += [[
174 |                 nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
175 |                 norm_layer(nf), nn.LeakyReLU(0.2, True)
176 |             ]]
177 | 
178 |         nf_prev = nf
179 |         nf = min(nf * 2, 512)
180 |         sequence += [[
181 |             nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
182 |             norm_layer(nf),
183 |             nn.LeakyReLU(0.2, True)
184 |         ]]
185 | 
186 |         sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
187 | 
188 |         if use_sigmoid:
189 |             sequence += [[nn.Sigmoid()]]
190 | 
191 |         for n in range(len(sequence)):
192 |             setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
193 | 
194 |     def forward(self, input):
195 |         res = [input]
196 |         for n in range(self.n_layers + 2):
197 |             model = getattr(self, 'model' + str(n))
198 |             res.append(model(res[-1]))
199 |         return res[1:]
200 | 
201 | 
202 | def divide_pred(pred):
203 |   # the prediction contains the intermediate outputs of multiscale GAN,
204 |   # so it's usually a list
205 |   if type(pred) == list:
206 |       fake = []
207 |       real = []
208 |       for p in pred:
209 |           fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
210 |           real.append([tensor[tensor.size(0) // 2:] for tensor in p])
211 |   else:
212 |       fake = pred[:pred.size(0) // 2]
213 |       real = pred[pred.size(0) // 2:]
214 | 
215 |   return fake, real
216 | 
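An illustrative pairing of the multi-scale discriminator with divide_pred, assuming (as the comment in divide_pred suggests) that fake and real images are stacked in one batch; sizes are arbitrary:

    import torch
    from simsg.discriminators import MultiscaleDiscriminator, divide_pred

    netD = MultiscaleDiscriminator(input_nc=3, num_D=2, n_layers=3)
    fake = torch.randn(4, 3, 64, 64)
    real = torch.randn(4, 3, 64, 64)
    pred = netD(torch.cat([fake, real], dim=0))    # list (one entry per scale) of per-layer features
    pred_fake, pred_real = divide_pred(pred)
    print(len(pred_fake), pred_fake[0][-1].shape)  # 2 scales; final map of the first scale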
--------------------------------------------------------------------------------
/simsg/feats_statistics.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | def get_mean(spade=True):
 4 |   if spade:
 5 |     return torch.tensor([-1.6150e-02, -2.0104e-01, -1.8967e-02,  8.6773e-01,  1.1788e-01,
 6 |         -3.7955e-02, -2.3504e-01, -7.6040e-02,  8.8244e-01,  7.9869e-02,
 7 |          4.8099e-01, -5.5241e-02, -9.3631e-02,  6.2165e-01,  3.5968e-01,
 8 |         -2.2090e-01,  1.5325e-02, -2.5600e-01,  6.4510e-01, -3.2229e-01,
 9 |         -9.3363e-03, -2.3452e-01, -2.9875e-01, -5.9847e-02,  5.0070e-01,
10 |         -1.3385e-02, -4.1650e-02,  4.7214e-02,  1.5526e+00, -1.1106e+00,
11 |          4.0959e-01, -2.2214e-01,  5.8195e-02,  3.5044e-01, -4.1516e-01,
12 |         -2.6164e-01,  3.7791e-01,  1.6697e-01,  2.4359e-01,  3.7218e-01,
13 |          1.3322e-01,  2.2361e-01, -4.0649e-01,  1.4112e+00,  6.2109e-01,
14 |         -8.9414e-01,  2.4960e-01, -2.2291e-02,  2.5344e-01, -1.1063e-01,
15 |          2.0111e-01,  3.0083e-01, -4.5993e-01,  1.1597e+00,  6.1391e-02,
16 |          3.8579e-01, -1.8961e-02,  4.3253e-01,  3.1550e-01,  5.1039e-02,
17 |         -2.0387e-02,  3.7300e-01,  1.3172e-01,  3.1559e-01,  7.0767e-02,
18 |          4.1030e-01,  2.6682e-01,  3.7454e-01,  1.5960e-01,  2.3767e-01,
19 |          2.1569e-01,  4.0779e-01,  1.8256e-01,  2.3073e-01,  3.6593e-01,
20 |         -1.2173e-02,  3.2893e-01,  1.8276e-01, -2.5898e-01,  6.1171e-01,
21 |          5.7514e-01, -3.9560e-02,  1.3200e-01,  4.2561e-01,  1.0145e-01,
22 |          6.3877e-01,  1.5947e-01, -4.3896e-01, -5.0542e-01,  4.7463e-01,
23 |         -1.2649e-01,  2.4283e-01,  2.0448e-02,  1.4343e-01,  3.9534e-04,
24 |         -5.2192e-02,  8.2538e-01,  2.8621e-01,  1.7567e-01,  9.0932e-02,
25 |         -5.3764e-01,  3.6047e-01,  2.3840e-01,  4.0529e-01, -2.9391e-02,
26 |         -3.1827e-01,  1.7440e-01,  8.4958e-01, -7.6857e-02,  1.1439e-01,
27 |          3.8213e-01, -1.9179e-01,  6.5607e-01,  1.1372e+00,  2.9348e-01,
28 |          2.1763e-02,  4.9819e-01, -2.3046e-02, -1.5268e-01, -2.3093e-01,
29 |          1.5162e-02,  7.3744e-02,  1.4607e-01,  1.4677e-01,  1.0958e-02,
30 |          1.6972e-01,  2.8153e-01, -4.3338e+00])
31 |   else:
32 |     return torch.tensor([ 7.7237e-02,  1.4599e-01, -3.2086e-02,  1.2272e-02,  1.3585e-01,
33 |          2.2647e-01,  1.5775e-01,  8.7155e-02,  2.8425e-01,  9.3309e-01,
34 |         -2.8251e+00,  4.8397e-01,  1.6812e-01, -4.3176e-01,  3.5545e-01,
35 |          1.8081e-01,  1.2170e+00,  3.7712e-01,  3.6601e-01,  2.4548e-02,
36 |         -8.9017e-02,  2.1669e-01,  5.5919e-01,  5.3623e-01,  1.8667e-01,
37 |          7.7680e-01,  1.1192e-01,  9.0081e-02,  1.8276e-01,  5.6268e-01,
38 |          3.5646e-01, -1.8525e-02,  4.7276e-01,  1.3664e-01, -2.4958e-01,
39 |          1.3064e-01,  1.0058e-01,  2.3417e-01,  2.4552e-01,  1.3849e-01,
40 |          1.0150e+00,  2.3932e-01, -7.0545e-01,  3.6638e-01,  1.8530e-02,
41 |          5.1259e-01, -1.0165e-01,  1.1289e-01, -1.1383e-02,  7.0450e-02,
42 |          1.5806e-01, -1.4862e-01,  4.2865e-02,  2.7360e-01,  2.3326e-01,
43 |          2.2674e-01, -1.0260e-02,  1.8777e-01,  2.2641e-01, -2.4887e-01,
44 |          2.9636e-01,  4.9020e-01, -3.6557e-01,  2.2225e-01,  6.5004e-02,
45 |         -4.7235e-02,  1.6490e-02,  3.0804e-02,  2.5871e-01,  3.1599e-05,
46 |          2.9948e-01, -1.2801e+00,  2.6655e-03,  3.6675e-01,  2.4675e-01,
47 |          2.0338e-01,  4.9576e-01,  1.9530e-01, -5.4172e-02,  1.8044e-01,
48 |          2.9627e-01,  8.3459e-02,  1.0668e-01,  1.0449e-01,  1.8324e-01,
49 |         -1.3553e-01, -7.5838e-02, -3.9355e-01,  5.1458e-02,  4.3815e-01,
50 |          5.8406e-02,  2.2365e-01, -2.3227e-02,  2.6278e-01, -1.5512e-01,
51 |          3.9810e-01,  3.7780e-01,  4.3622e-01,  2.1492e-01, -3.2092e-02,
52 |          1.4565e-01,  5.7963e-02,  2.4176e-01, -8.1265e-02,  2.5032e-01,
53 |         -6.4307e-04,  5.2379e-02,  3.4459e-01,  1.3226e-01,  1.3777e-01,
54 |          3.3404e-02,  2.8588e-01, -1.7163e-01, -3.7610e-02,  8.4848e-02,
55 |          5.9351e-01, -4.0149e-02,  1.2884e-01,  3.8397e-02, -2.2867e-01,
56 |          3.7894e-02,  2.0033e-01, -6.8478e-02, -1.3748e-01, -2.1313e-02,
57 |          1.4798e-01, -7.1153e-02,  2.5109e-01])
58 | 
59 | def get_std(spade=True):
60 | 
61 |   if spade:
62 |     return torch.tensor([0.7377, 0.8760, 0.5966, 0.9396, 0.5330, 0.6746, 0.5873, 0.8002, 0.8464,
63 |         0.6066, 0.7404, 0.6880, 0.8290, 0.8517, 1.0168, 0.6587, 0.6910, 0.7041,
64 |         0.7083, 0.8597, 0.6041, 0.7032, 0.4664, 0.5939, 0.9299, 0.6339, 0.6201,
65 |         0.6464, 1.3569, 0.9664, 0.8113, 0.7645, 0.7036, 0.6485, 0.8178, 0.5965,
66 |         0.5853, 0.9413, 0.5311, 0.6869, 0.7178, 0.4459, 0.6768, 1.0432, 0.5735,
67 |         1.4332, 0.7651, 0.5793, 0.5602, 0.5846, 0.6134, 1.0111, 1.0382, 0.8546,
68 |         0.8659, 0.6131, 0.5885, 0.5515, 0.6286, 0.6191, 0.7734, 0.6184, 0.6307,
69 |         0.6496, 0.8436, 0.6631, 0.5839, 0.6096, 0.8167, 0.6743, 0.5774, 0.5412,
70 |         0.5770, 0.6273, 0.5946, 0.5786, 0.6149, 0.7487, 0.7289, 0.5467, 0.9170,
71 |         0.7468, 0.7206, 0.6468, 0.6711, 0.6553, 0.5945, 0.7166, 0.6544, 0.7168,
72 |         0.6667, 0.6720, 0.8059, 0.6146, 0.6778, 0.7143, 0.9330, 0.5723, 0.6583,
73 |         0.5033, 0.8432, 0.6200, 0.5862, 0.6239, 0.6245, 0.6621, 0.7231, 0.8397,
74 |         0.6194, 0.5876, 0.6404, 0.5165, 0.8618, 1.1337, 0.7873, 0.6999, 0.7077,
75 |         0.6124, 0.7833, 0.6646, 0.6600, 0.6348, 0.5936, 0.5906, 0.5752, 0.7991,
76 |         0.7337, 3.8573])
77 | 
78 |   else:
79 |     return torch.tensor([0.5562, 0.4463, 0.6239, 0.5942, 0.5989, 0.5937, 0.5813, 0.6217, 0.6885,
80 |         0.8446, 2.4218, 1.5264, 0.6207, 0.5945, 0.6178, 0.5321, 0.9981, 0.6412,
81 |         0.5612, 0.6799, 0.5956, 0.6331, 0.6156, 0.7085, 0.5387, 0.7024, 0.5395,
82 |         0.6156, 0.5126, 0.5616, 0.5330, 0.6228, 0.5880, 0.6474, 0.5386, 0.5761,
83 |         0.5934, 0.5818, 0.5054, 0.6030, 0.8313, 0.5675, 0.5970, 0.5890, 0.5287,
84 |         0.5437, 0.5739, 0.5770, 0.5374, 0.5381, 0.6655, 0.5938, 0.6502, 0.5945,
85 |         0.6889, 0.5691, 0.6278, 0.5376, 0.5919, 0.6407, 0.6559, 0.5942, 0.5599,
86 |         0.6238, 0.5494, 0.5596, 0.5861, 0.5741, 0.6071, 0.5206, 0.5780, 0.9384,
87 |         0.4894, 0.6108, 0.6441, 0.6012, 0.6952, 0.6326, 0.4971, 0.5562, 0.5494,
88 |         0.5879, 0.5013, 0.5992, 0.5527, 0.6322, 0.5842, 0.5900, 0.6353, 0.5606,
89 |         0.6369, 0.5970, 0.5347, 0.6015, 0.5133, 0.6209, 0.6077, 0.7290, 0.5833,
90 |         0.5555, 0.5780, 0.6566, 0.5696, 0.5394, 0.5386, 0.5731, 0.5225, 0.5397,
91 |         0.5517, 0.6082, 0.6007, 0.5728, 0.6639, 0.5972, 0.6115, 0.6083, 0.5304,
92 |         0.5828, 0.6301, 0.5566, 0.6096, 0.6493, 0.5196, 0.6479, 0.6541, 0.5875,
93 |         0.5701, 0.6249])
94 | 
--------------------------------------------------------------------------------
/simsg/graph.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import torch.nn as nn
 19 | from simsg.layers import build_mlp
 20 | 
 21 | """
 22 | PyTorch modules for dealing with graphs.
 23 | """
 24 | 
 25 | def _init_weights(module):
 26 |   if hasattr(module, 'weight'):
 27 |     if isinstance(module, nn.Linear):
 28 |       nn.init.kaiming_normal_(module.weight)
 29 | 
 30 | class GraphTripleConv(nn.Module):
 31 |   """
 32 |   A single layer of scene graph convolution.
 33 |   """
 34 |   def __init__(self, input_dim_obj, input_dim_pred, output_dim=None, hidden_dim=512,
 35 |                pooling='avg', mlp_normalization='none'):
 36 |     super(GraphTripleConv, self).__init__()
 37 |     if output_dim is None:
 38 |       output_dim = input_dim_obj
 39 |     self.input_dim_obj = input_dim_obj
 40 |     self.input_dim_pred = input_dim_pred
 41 |     self.output_dim = output_dim
 42 |     self.hidden_dim = hidden_dim
 43 | 
 44 |     assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling
 45 | 
 46 |     self.pooling = pooling
 47 |     net1_layers = [2 * input_dim_obj + input_dim_pred, hidden_dim, 2 * hidden_dim + output_dim]
 48 |     net1_layers = [l for l in net1_layers if l is not None]
 49 |     self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization)
 50 |     self.net1.apply(_init_weights)
 51 |     
 52 |     net2_layers = [hidden_dim, hidden_dim, output_dim]
 53 |     self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization)
 54 |     self.net2.apply(_init_weights)
 55 | 
 56 | 
 57 |   def forward(self, obj_vecs, pred_vecs, edges):
 58 |     """
 59 |     Inputs:
 60 |     - obj_vecs: FloatTensor of shape (num_objs, D) giving vectors for all objects
 61 |     - pred_vecs: FloatTensor of shape (num_triples, D) giving vectors for all predicates
 62 |     - edges: LongTensor of shape (num_triples, 2) where edges[k] = [i, j] indicates the
 63 |       presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]]
 64 |     
 65 |     Outputs:
 66 |     - new_obj_vecs: FloatTensor of shape (num_objs, D) giving new vectors for objects
 67 |     - new_pred_vecs: FloatTensor of shape (num_triples, D) giving new vectors for predicates
 68 |     """
 69 |     dtype, device = obj_vecs.dtype, obj_vecs.device
 70 |     num_objs, num_triples = obj_vecs.size(0), pred_vecs.size(0)
 71 |     Din_obj, Din_pred, H, Dout = self.input_dim_obj, self.input_dim_pred, self.hidden_dim, self.output_dim
 72 |     
 73 |     # Break apart indices for subjects and objects; these have shape (num_triples,)
 74 |     s_idx = edges[:, 0].contiguous()
 75 |     o_idx = edges[:, 1].contiguous()
 76 |     
 77 |     # Get current vectors for subjects and objects; these have shape (num_triples, Din)
 78 |     cur_s_vecs = obj_vecs[s_idx]
 79 |     cur_o_vecs = obj_vecs[o_idx]
 80 |     
 81 |     # Get current vectors for triples; shape is (num_triples, 3 * Din)
 82 |     # Pass through net1 to get new triple vecs; shape is (num_triples, 2 * H + Dout)
 83 |     cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1)
 84 |     new_t_vecs = self.net1(cur_t_vecs)
 85 | 
 86 |     # Break apart into new s, p, and o vecs; s and o vecs have shape (num_triples, H) and
 87 |     # p vecs have shape (num_triples, Dout)
 88 |     new_s_vecs = new_t_vecs[:, :H]
 89 |     new_p_vecs = new_t_vecs[:, H:(H+Dout)]
 90 |     new_o_vecs = new_t_vecs[:, (H+Dout):(2 * H + Dout)]
 91 |  
 92 |     # Allocate space for pooled object vectors of shape (num_objs, H)
 93 |     pooled_obj_vecs = torch.zeros(num_objs, H, dtype=dtype, device=device)
 94 | 
 95 |     # Use scatter_add to sum vectors for objects that appear in multiple triples;
 96 |     # we first need to expand the indices to have shape (num_triples, D)
 97 |     s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs)
 98 |     o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs)
 99 |     # print(pooled_obj_vecs.shape, o_idx_exp.shape)
100 |     pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs)
101 |     pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs)
102 | 
103 |     if self.pooling == 'avg':
104 |       #print("here i am, would you send me an angel")
105 |       # Figure out how many times each object has appeared, again using
106 |       # some scatter_add trickery.
107 |       obj_counts = torch.zeros(num_objs, dtype=dtype, device=device)
108 |       ones = torch.ones(num_triples, dtype=dtype, device=device)
109 |       obj_counts = obj_counts.scatter_add(0, s_idx, ones)
110 |       obj_counts = obj_counts.scatter_add(0, o_idx, ones)
111 |   
112 |       # Divide the new object vectors by the number of times they
113 |       # appeared, but first clamp at 1 to avoid dividing by zero;
114 |       # objects that appear in no triples will have output vector 0
115 |       # so this will not affect them.
116 |       obj_counts = obj_counts.clamp(min=1)
117 |       pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1)
118 | 
119 |     # Send pooled object vectors through net2 to get output object vectors,
120 |     # of shape (num_objs, Dout)
121 |     new_obj_vecs = self.net2(pooled_obj_vecs)
122 | 
123 |     return new_obj_vecs, new_p_vecs
124 | 
125 | 
126 | class GraphTripleConvNet(nn.Module):
127 |   """ A sequence of scene graph convolution layers  """
128 |   def __init__(self, input_dim_obj, input_dim_pred, num_layers=5, hidden_dim=512, pooling='avg',
129 |                mlp_normalization='none'):
130 |     super(GraphTripleConvNet, self).__init__()
131 | 
132 |     self.num_layers = num_layers
133 |     self.gconvs = nn.ModuleList()
134 |     gconv_kwargs = {
135 |       'input_dim_obj': input_dim_obj,
136 |       'input_dim_pred': input_dim_pred,
137 |       'hidden_dim': hidden_dim,
138 |       'pooling': pooling,
139 |       'mlp_normalization': mlp_normalization,
140 |     }
141 |     for _ in range(self.num_layers):
142 |       self.gconvs.append(GraphTripleConv(**gconv_kwargs))
143 | 
144 |   def forward(self, obj_vecs, pred_vecs, edges):
145 |     for i in range(self.num_layers):
146 |       gconv = self.gconvs[i]
147 |       obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges)
148 |     return obj_vecs, pred_vecs
149 | 
150 | 
151 | 
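A toy forward pass through the graph convolution stack above; the embedding size and the three-node graph are arbitrary:

    import torch
    from simsg.graph import GraphTripleConvNet

    gconv = GraphTripleConvNet(input_dim_obj=128, input_dim_pred=128, num_layers=2)
    obj_vecs = torch.randn(3, 128)                  # 3 object embeddings
    pred_vecs = torch.randn(2, 128)                 # 2 predicate embeddings
    edges = torch.LongTensor([[0, 1], [1, 2]])      # (subject_idx, object_idx) per triple
    new_obj, new_pred = gconv(obj_vecs, pred_vecs, edges)
    print(new_obj.shape, new_pred.shape)            # (3, 128), (2, 128)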
--------------------------------------------------------------------------------
/simsg/layers.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import torch.nn as nn
 19 | import torch.nn.functional as F
 20 | 
 21 | 
 22 | def get_normalization_2d(channels, normalization):
 23 |   if normalization == 'instance':
 24 |     return nn.InstanceNorm2d(channels)
 25 |   elif normalization == 'batch':
 26 |     return nn.BatchNorm2d(channels)
 27 |   elif normalization == 'none':
 28 |     return None
 29 |   else:
 30 |     raise ValueError('Unrecognized normalization type "%s"' % normalization)
 31 | 
 32 | 
 33 | def get_activation(name):
 34 |   kwargs = {}
 35 |   if name.lower().startswith('leakyrelu'):
 36 |     if '-' in name:
 37 |       slope = float(name.split('-')[1])
 38 |       kwargs = {'negative_slope': slope}
 39 |     name = 'leakyrelu'
 40 |   activations = {
 41 |     'relu': nn.ReLU,
 42 |     'leakyrelu': nn.LeakyReLU,
 43 |   }
 44 |   if name.lower() not in activations:
 45 |     raise ValueError('Invalid activation "%s"' % name)
 46 |   return activations[name.lower()](**kwargs)
 47 | 
 48 | 
 49 | def _init_conv(layer, method):
 50 |   if not isinstance(layer, nn.Conv2d):
 51 |     return
 52 |   if method == 'default':
 53 |     return
 54 |   elif method == 'kaiming-normal':
 55 |     nn.init.kaiming_normal_(layer.weight)
 56 |   elif method == 'kaiming-uniform':
 57 |     nn.init.kaiming_uniform_(layer.weight)
 58 | 
 59 | 
 60 | class Flatten(nn.Module):
 61 |   def forward(self, x):
 62 |     return x.view(x.size(0), -1)
 63 | 
 64 |   def __repr__(self):
 65 |     return 'Flatten()'
 66 | 
 67 | 
 68 | class Unflatten(nn.Module):
 69 |   def __init__(self, size):
 70 |     super(Unflatten, self).__init__()
 71 |     self.size = size
 72 | 
 73 |   def forward(self, x):
 74 |     return x.view(*self.size)
 75 | 
 76 |   def __repr__(self):
 77 |     size_str = ', '.join('%d' % d for d in self.size)
 78 |     return 'Unflatten(%s)' % size_str
 79 | 
 80 | 
 81 | class GlobalAvgPool(nn.Module):
 82 |   def forward(self, x):
 83 |     N, C = x.size(0), x.size(1)
 84 |     return x.view(N, C, -1).mean(dim=2)
 85 | 
 86 | 
 87 | class ResidualBlock(nn.Module):
 88 |   def __init__(self, channels, normalization='batch', activation='relu',
 89 |                padding='same', kernel_size=3, init='default'):
 90 |     super(ResidualBlock, self).__init__()
 91 | 
 92 |     K = kernel_size
 93 |     P = _get_padding(K, padding)
 94 |     C = channels
 95 |     self.padding = P
 96 |     layers = [
 97 |       get_normalization_2d(C, normalization),
 98 |       get_activation(activation),
 99 |       nn.Conv2d(C, C, kernel_size=K, padding=P),
100 |       get_normalization_2d(C, normalization),
101 |       get_activation(activation),
102 |       nn.Conv2d(C, C, kernel_size=K, padding=P),
103 |     ]
104 |     layers = [layer for layer in layers if layer is not None]
105 |     for layer in layers:
106 |       _init_conv(layer, method=init)
107 |     self.net = nn.Sequential(*layers)
108 | 
109 |   def forward(self, x):
110 |     y = self.net(x)  # run the residual branch only once
111 |     if self.padding > 0:
112 |       return x + y
113 |     # 'valid' padding shrinks the feature map; crop the shortcut to match.
114 |     c = (x.size(2) - y.size(2)) // 2
115 |     return x[:, :, c:-c, c:-c] + y
116 | 
117 | 
118 | def _get_padding(K, mode):
119 |   """ Helper method to compute padding size """
120 |   if mode == 'valid':
121 |     return 0
122 |   elif mode == 'same':
123 |     assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K
124 |     return (K - 1) // 2
125 | 
126 | 
127 | def build_cnn(arch, normalization='batch', activation='relu', padding='same',
128 |               pooling='max', init='default'):
129 |   """
130 |   Build a CNN from an architecture string, which is a list of layer
131 |   specification strings. The overall architecture can be given as a list or as
132 |   a comma-separated string.
133 | 
134 |   All convolutions *except for the first* are preceded by normalization and
135 |   nonlinearity.
136 | 
137 |   The following layer specification strings are supported:
138 |   - IX: Indicates that the number of input channels to the network is X.
139 |         Can only be used at the first layer; if not present then we assume
140 |         3 input channels.
141 |   - CK-X: KxK convolution with X output channels
142 |   - CK-X-S: KxK convolution with X output channels and stride S
143 |   - R: Residual block keeping the same number of channels
144 |   - UX: Nearest-neighbor upsampling with factor X
145 |   - PX: Spatial pooling with factor X
146 |   - FC-X-Y: Flatten (if not already flat) followed by a fully-connected layer with X inputs and Y outputs
147 | 
148 |   Returns a tuple of:
149 |   - cnn: An nn.Sequential
150 |   - channels: Number of output channels
151 |   """
152 |   #print(arch)
153 |   if isinstance(arch, str):
154 |     arch = arch.split(',')
155 |   cur_C = 3
156 |   if len(arch) > 0 and arch[0][0] == 'I':
157 |     cur_C = int(arch[0][1:])
158 |     arch = arch[1:]
159 |   #print(arch)
160 |   #if len(arch) > 0 and arch[0][0] == 'I':
161 |   #  arch = arch[1:]
162 |   first_conv = True
163 |   flat = False
164 |   layers = []
165 |   for i, s in enumerate(arch):
166 |     #if s[0] == 'I':
167 |     #  continue
168 |     if s[0] == 'C':
169 |       if not first_conv:
170 |         layers.append(get_normalization_2d(cur_C, normalization))
171 |         layers.append(get_activation(activation))
172 |       first_conv = False
173 |       vals = [int(i) for i in s[1:].split('-')]
174 |       if len(vals) == 2:
175 |         K, next_C = vals
176 |         stride = 1
177 |       elif len(vals) == 3:
178 |         K, next_C, stride = vals
179 |       # K, next_C = (int(i) for i in s[1:].split('-'))
180 |       P = _get_padding(K, padding)
181 |       conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride)
182 |       layers.append(conv)
183 |       _init_conv(layers[-1], init)
184 |       cur_C = next_C
185 |     elif s[0] == 'R':
186 |       norm = 'none' if first_conv else normalization
187 |       res = ResidualBlock(cur_C, normalization=norm, activation=activation,
188 |                           padding=padding, init=init)
189 |       layers.append(res)
190 |       first_conv = False
191 |     elif s[0] == 'U':
192 |       factor = int(s[1:])
193 |       layers.append(nn.Upsample(scale_factor=factor, mode='nearest'))
194 |     elif s[0] == 'P':
195 |       factor = int(s[1:])
196 |       if pooling == 'max':
197 |         pool = nn.MaxPool2d(kernel_size=factor, stride=factor)
198 |       elif pooling == 'avg':
199 |         pool = nn.AvgPool2d(kernel_size=factor, stride=factor)
200 |       layers.append(pool)
201 |     elif s[:2] == 'FC':
202 |       _, Din, Dout = s.split('-')
203 |       Din, Dout = int(Din), int(Dout)
204 |       if not flat:
205 |         layers.append(Flatten())
206 |       flat = True
207 |       layers.append(nn.Linear(Din, Dout))
208 |       if i + 1 < len(arch):
209 |         layers.append(get_activation(activation))
210 |       cur_C = Dout
211 |     else:
212 |       raise ValueError('Invalid layer "%s"' % s)
213 |   layers = [layer for layer in layers if layer is not None]
214 |   for layer in layers:
215 |     print(layer)
216 |   return nn.Sequential(*layers), cur_C
217 | 
218 | 
219 | def build_mlp(dim_list, activation='relu', batch_norm='none',
220 |               dropout=0, final_nonlinearity=True):
221 |   layers = []
222 |   for i in range(len(dim_list) - 1):
223 |     dim_in, dim_out = dim_list[i], dim_list[i + 1]
224 |     layers.append(nn.Linear(dim_in, dim_out))
225 |     final_layer = (i == len(dim_list) - 2)
226 |     if not final_layer or final_nonlinearity:
227 |       if batch_norm == 'batch':
228 |         layers.append(nn.BatchNorm1d(dim_out))
229 |       if activation == 'relu':
230 |         layers.append(nn.ReLU())
231 |       elif activation == 'leakyrelu':
232 |         layers.append(nn.LeakyReLU())
233 |     if dropout > 0:
234 |       layers.append(nn.Dropout(p=dropout))
235 |   return nn.Sequential(*layers)
236 | 
237 | 
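A short sketch of the architecture-string convention that build_cnn parses, together with build_mlp; the arch string and dimensions are illustrative rather than taken from the training configs.

import torch
from simsg.layers import build_cnn, build_mlp

# 3 input channels, two 3x3 convs with 64 channels, 2x2 max-pool, one residual block.
cnn, out_channels = build_cnn('I3,C3-64,C3-64,P2,R', normalization='batch', pooling='max')
feats = cnn(torch.randn(2, 3, 32, 32))
print(feats.shape, out_channels)        # (2, 64, 16, 16) and 64

# MLP with dims 64 -> 128 -> 10 and no nonlinearity after the last layer.
mlp = build_mlp([64, 128, 10], activation='relu', final_nonlinearity=False)
print(mlp(torch.randn(5, 64)).shape)    # (5, 10)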
--------------------------------------------------------------------------------
/simsg/layout.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import torch.nn.functional as F
 19 | 
 20 | """
 21 | Functions for computing image layouts from object vectors, bounding boxes,
 22 | and segmentation masks. These are used to compute coarse scene layouts which
 23 | are then fed as input to the cascaded refinement network.
 24 | """
 25 | 
 26 | 
 27 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'):
 28 |   """
 29 |   Inputs:
 30 |   - vecs: Tensor of shape (O, D) giving vectors
 31 |   - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
 32 |     [x0, y0, x1, y1] in the [0, 1] coordinate space
 33 |   - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to
 34 |     an image, where each element is in the range [0, N). If obj_to_img[i] = j
 35 |     then vecs[i] belongs to image j.
 36 |   - H, W: Size of the output
 37 | 
 38 |   Returns:
 39 |   - out: Tensor of shape (N, D, H, W)
 40 |   """
 41 |   O, D = vecs.size()
 42 |   if W is None:
 43 |     W = H
 44 | 
 45 |   grid = _boxes_to_grid(boxes, H, W)
 46 | 
 47 |   # If we don't add extra spatial dimensions here then out-of-bounds
 48 |   # elements won't be automatically set to 0
 49 |   img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8)
 50 |   sampled = F.grid_sample(img_in, grid)   # (O, D, H, W)
 51 | 
 52 |   # Explicitly masking makes everything quite a bit slower.
 53 |   # If we rely on implicit masking the interpolated boxes end up
 54 |   # blurred around the edges, but it should be fine.
 55 |   # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1)
 56 |   # sampled[mask[:, None]] = 0
 57 | 
 58 |   out = _pool_samples(sampled, obj_to_img, pooling=pooling)
 59 | 
 60 |   return out
 61 | 
 62 | 
 63 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum', front_idx=None):
 64 |   """
 65 |   Inputs:
 66 |   - vecs: Tensor of shape (O, D) giving vectors
 67 |   - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
 68 |     [x0, y0, x1, y1] in the [0, 1] coordinate space
 69 |   - masks: Tensor of shape (O, M, M) giving binary masks for each object
 70 |   - obj_to_img: LongTensor of shape (O,) mapping objects to images
 71 |   - H, W: Size of the output image.
 72 | 
 73 |   Returns:
 74 |   - out: Tensor of shape (N, D, H, W)
 75 |   """
 76 |   O, D = vecs.size()
 77 |   M = masks.size(1)
 78 |   assert masks.size() == (O, M, M)
 79 |   if W is None:
 80 |     W = H
 81 | 
 82 |   grid = _boxes_to_grid(boxes, H, W)
 83 | 
 84 |   img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M)
 85 | 
 86 |   if pooling == 'max':
 87 |     out = []
 88 | 
 89 |     # group by image, extract feature with maximum mask confidence
 90 |     for i in range(obj_to_img.max()+1):
 91 | 
 92 |       curr_projected_mask = F.grid_sample(masks.view(O, 1, M, M)[(obj_to_img==i).nonzero().view(-1)],
 93 |                                           grid[(obj_to_img==i).nonzero().view(-1)])
 94 |       curr_sampled = F.grid_sample(img_in[(obj_to_img==i).nonzero().view(-1)],
 95 |                                    grid[(obj_to_img==i).nonzero().view(-1)])
 96 | 
 97 |       argmax_mask = torch.argmax(curr_projected_mask, dim=0)
 98 |       argmax_mask = argmax_mask.repeat(1, D, 1, 1)
 99 |       out.append(torch.gather(curr_sampled, 0, argmax_mask))
100 | 
101 |     out = torch.cat(out, dim=0)
102 | 
103 |   else:
104 |     sampled = F.grid_sample(img_in, grid)
105 |     out = _pool_samples(sampled, obj_to_img, pooling=pooling)
106 | 
107 |   #print(out.shape)
108 |   return out
109 | 
110 | 
111 | def _boxes_to_grid(boxes, H, W):
112 |   """
113 |   Input:
114 |   - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1]
115 |     format in the [0, 1] coordinate space
116 |   - H, W: Scalars giving size of output
117 | 
118 |   Returns:
119 |   - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample
120 |   """
121 |   O = boxes.size(0)
122 | 
123 |   boxes = boxes.view(O, 4, 1, 1)
124 | 
125 |   # All these are (O, 1, 1)
126 |   x0, y0 = boxes[:, 0], boxes[:, 1]
127 |   x1, y1 = boxes[:, 2], boxes[:, 3]
128 |   ww = x1 - x0
129 |   hh = y1 - y0
130 | 
131 |   X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes)
132 |   Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes)
133 |   
134 |   X = (X - x0) / ww   # (O, 1, W)
135 |   Y = (Y - y0) / hh   # (O, H, 1)
136 |   
137 |   # Stack does not broadcast its arguments so we need to expand explicitly
138 |   X = X.expand(O, H, W)
139 |   Y = Y.expand(O, H, W)
140 |   grid = torch.stack([X, Y], dim=3)  # (O, H, W, 2)
141 | 
142 |   # Right now grid is in [0, 1] space; transform to [-1, 1]
143 |   grid = grid.mul(2).sub(1)
144 | 
145 |   return grid
146 | 
147 | 
148 | def _pool_samples(samples, obj_to_img, pooling='sum'):
149 |   """
150 |   Input:
151 |   - samples: FloatTensor of shape (O, D, H, W)
152 |   - obj_to_img: LongTensor of shape (O,) with each element in the range
153 |     [0, N) mapping elements of samples to output images
154 | 
155 |   Output:
156 |   - pooled: FloatTensor of shape (N, D, H, W)
157 |   """
158 |   dtype, device = samples.dtype, samples.device
159 |   O, D, H, W = samples.size()
160 |   N = obj_to_img.data.max().item() + 1
161 | 
162 |   # Use scatter_add to sum the sampled outputs for each image
163 |   out = torch.zeros(N, D, H, W, dtype=dtype, device=device)
164 |   idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W)
165 |   out = out.scatter_add(0, idx, samples)
166 | 
167 |   if pooling == 'avg':
168 |     # Divide each output mask by the number of objects; use scatter_add again
169 |     # to count the number of objects per image.
170 |     ones = torch.ones(O, dtype=dtype, device=device)
171 |     obj_counts = torch.zeros(N, dtype=dtype, device=device)
172 |     obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
173 |     # print(obj_counts)
174 |     obj_counts = obj_counts.clamp(min=1)
175 |     out = out / obj_counts.view(N, 1, 1, 1)
176 | 
177 |   elif pooling != 'sum':
178 |     raise ValueError('Invalid pooling "%s"' % pooling)
179 | 
180 |   return out
181 | 
182 | 
183 | if __name__ == '__main__':
184 |   vecs = torch.FloatTensor([
185 |             [1, 0, 0], [0, 1, 0], [0, 0, 1],
186 |             [1, 0, 0], [0, 1, 0], [0, 0, 1],
187 |          ])
188 |   boxes = torch.FloatTensor([
189 |             [0.25, 0.125, 0.5, 0.875],
190 |             [0, 0, 1, 0.25],
191 |             [0.6125, 0, 0.875, 1],
192 |             [0, 0.8, 1, 1.0],
193 |             [0.25, 0.125, 0.5, 0.875],
194 |             [0.6125, 0, 0.875, 1],
195 |           ])
196 |   obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1])
197 |   # vecs = torch.FloatTensor([[[1]]])
198 |   # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]])
199 |   vecs, boxes = vecs.cuda(), boxes.cuda()
200 |   obj_to_img = obj_to_img.cuda()
201 |   out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum')
202 |   
203 |   from torchvision.utils import save_image
204 |   save_image(out.data, 'out.png')
205 | 
206 | 
207 |   masks = torch.FloatTensor([
208 |             [
209 |               [0, 0, 1, 0, 0],
210 |               [0, 1, 1, 1, 0],
211 |               [1, 1, 1, 1, 1],
212 |               [0, 1, 1, 1, 0],
213 |               [0, 0, 1, 0, 0],
214 |             ],
215 |             [
216 |               [0, 0, 1, 0, 0],
217 |               [0, 1, 0, 1, 0],
218 |               [1, 0, 0, 0, 1],
219 |               [0, 1, 0, 1, 0],
220 |               [0, 0, 1, 0, 0],
221 |             ],
222 |             [
223 |               [0, 0, 1, 0, 0],
224 |               [0, 1, 1, 1, 0],
225 |               [1, 1, 1, 1, 1],
226 |               [0, 1, 1, 1, 0],
227 |               [0, 0, 1, 0, 0],
228 |             ],
229 |             [
230 |               [0, 0, 1, 0, 0],
231 |               [0, 1, 1, 1, 0],
232 |               [1, 1, 1, 1, 1],
233 |               [0, 1, 1, 1, 0],
234 |               [0, 0, 1, 0, 0],
235 |             ],
236 |             [
237 |               [0, 0, 1, 0, 0],
238 |               [0, 1, 1, 1, 0],
239 |               [1, 1, 1, 1, 1],
240 |               [0, 1, 1, 1, 0],
241 |               [0, 0, 1, 0, 0],
242 |             ],
243 |             [
244 |               [0, 0, 1, 0, 0],
245 |               [0, 1, 1, 1, 0],
246 |               [1, 1, 1, 1, 1],
247 |               [0, 1, 1, 1, 0],
248 |               [0, 0, 1, 0, 0],
249 |             ]
250 |           ])
251 |   masks = masks.cuda()
252 |   out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256, pooling="max")
253 |   save_image(out.data, 'out_masks.png')
254 | 
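A CPU-only sketch of boxes_to_layout to complement the CUDA demo in the __main__ block above; the feature vectors and boxes are arbitrary.

import torch
from simsg.layout import boxes_to_layout

vecs = torch.eye(3)                                   # (O=3, D=3) one-hot object features
boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],       # [x0, y0, x1, y1] in [0, 1]
                      [0.40, 0.40, 0.90, 0.90],
                      [0.00, 0.60, 1.00, 1.00]])
obj_to_img = torch.LongTensor([0, 0, 1])              # first two objects -> image 0, last -> image 1
layout = boxes_to_layout(vecs, boxes, obj_to_img, H=64, pooling='sum')
print(layout.shape)                                   # (N=2, D=3, 64, 64)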
--------------------------------------------------------------------------------
/simsg/loader_utils.py:
--------------------------------------------------------------------------------
  1 | from simsg.data.vg import SceneGraphNoPairsDataset, collate_fn_nopairs
  2 | from simsg.data.clevr import SceneGraphWithPairsDataset, collate_fn_withpairs
  3 | 
  4 | import json
  5 | from torch.utils.data import DataLoader
  6 | 
  7 | 
  8 | def build_clevr_supervised_train_dsets(args):
  9 |   print("building fully supervised %s dataset" % args.dataset)
 10 |   with open(args.vocab_json, 'r') as f:
 11 |     vocab = json.load(f)
 12 |   dset_kwargs = {
 13 |     'vocab': vocab,
 14 |     'h5_path': args.train_h5,
 15 |     'image_dir': args.vg_image_dir,
 16 |     'image_size': args.image_size,
 17 |     'max_samples': args.num_train_samples,
 18 |     'max_objects': args.max_objects_per_image,
 19 |     'use_orphaned_objects': args.vg_use_orphaned_objects,
 20 |     'include_relationships': args.include_relationships,
 21 |   }
 22 |   train_dset = SceneGraphWithPairsDataset(**dset_kwargs)
 23 |   iter_per_epoch = len(train_dset) // args.batch_size
 24 |   print('There are %d iterations per epoch' % iter_per_epoch)
 25 | 
 26 |   dset_kwargs['h5_path'] = args.val_h5
 27 |   del dset_kwargs['max_samples']
 28 |   val_dset = SceneGraphWithPairsDataset(**dset_kwargs)
 29 | 
 30 |   dset_kwargs['h5_path'] = args.test_h5
 31 |   test_dset = SceneGraphWithPairsDataset(**dset_kwargs)
 32 | 
 33 |   return vocab, train_dset, val_dset, test_dset
 34 | 
 35 | 
 36 | def build_dset_nopairs(args, checkpoint):
 37 | 
 38 |   vocab = checkpoint['model_kwargs']['vocab']
 39 |   dset_kwargs = {
 40 |     'vocab': vocab,
 41 |     'h5_path': args.data_h5,
 42 |     'image_dir': args.data_image_dir,
 43 |     'image_size': args.image_size,
 44 |     'max_objects': checkpoint['args']['max_objects_per_image'],
 45 |     'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'],
 46 |     'mode': args.mode,
 47 |     'predgraphs': args.predgraphs
 48 |   }
 49 |   dset = SceneGraphNoPairsDataset(**dset_kwargs)
 50 | 
 51 |   return dset
 52 | 
 53 | 
 54 | def build_dset_withpairs(args, checkpoint, vocab_t):
 55 | 
 56 |   vocab = vocab_t
 57 |   dset_kwargs = {
 58 |     'vocab': vocab,
 59 |     'h5_path': args.data_h5,
 60 |     'image_dir': args.data_image_dir,
 61 |     'image_size': args.image_size,
 62 |     'max_objects': checkpoint['args']['max_objects_per_image'],
 63 |     'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'],
 64 |     'mode': args.mode
 65 |   }
 66 |   dset = SceneGraphWithPairsDataset(**dset_kwargs)
 67 | 
 68 |   return dset
 69 | 
 70 | 
 71 | def build_eval_loader(args, checkpoint, vocab_t=None, no_gt=False):
 72 | 
 73 |   if args.dataset == 'vg' or (no_gt and args.dataset == 'clevr'):
 74 |     dset = build_dset_nopairs(args, checkpoint)
 75 |     collate_fn = collate_fn_nopairs
 76 |   elif args.dataset == 'clevr':
 77 |     dset = build_dset_withpairs(args, checkpoint, vocab_t)
 78 |     collate_fn = collate_fn_withpairs
 79 | 
 80 |   loader_kwargs = {
 81 |     'batch_size': 1,
 82 |     'num_workers': args.loader_num_workers,
 83 |     'shuffle': args.shuffle,
 84 |     'collate_fn': collate_fn,
 85 |   }
 86 |   loader = DataLoader(dset, **loader_kwargs)
 87 | 
 88 |   return loader
 89 | 
 90 | 
 91 | def build_train_dsets(args):
 92 |   print("building unpaired %s dataset" % args.dataset)
 93 |   with open(args.vocab_json, 'r') as f:
 94 |     vocab = json.load(f)
 95 |   dset_kwargs = {
 96 |     'vocab': vocab,
 97 |     'h5_path': args.train_h5,
 98 |     'image_dir': args.vg_image_dir,
 99 |     'image_size': args.image_size,
100 |     'max_samples': args.num_train_samples,
101 |     'max_objects': args.max_objects_per_image,
102 |     'use_orphaned_objects': args.vg_use_orphaned_objects,
103 |     'include_relationships': args.include_relationships,
104 |   }
105 |   train_dset = SceneGraphNoPairsDataset(**dset_kwargs)
106 |   iter_per_epoch = len(train_dset) // args.batch_size
107 |   print('There are %d iterations per epoch' % iter_per_epoch)
108 | 
109 |   dset_kwargs['h5_path'] = args.val_h5
110 |   del dset_kwargs['max_samples']
111 |   val_dset = SceneGraphNoPairsDataset(**dset_kwargs)
112 | 
113 |   return vocab, train_dset, val_dset
114 | 
115 | 
116 | def build_train_loaders(args):
117 | 
118 |   print(args.dataset)
119 |   if args.dataset == 'vg' or (args.dataset == "clevr" and not args.is_supervised):
120 |     vocab, train_dset, val_dset = build_train_dsets(args)
121 |     collate_fn = collate_fn_nopairs
122 |   elif args.dataset == 'clevr':
123 |     vocab, train_dset, val_dset, test_dset = build_clevr_supervised_train_dsets(args)
124 |     collate_fn = collate_fn_withpairs
125 | 
126 |   loader_kwargs = {
127 |     'batch_size': args.batch_size,
128 |     'num_workers': args.loader_num_workers,
129 |     'shuffle': True,
130 |     'collate_fn': collate_fn,
131 |   }
132 |   train_loader = DataLoader(train_dset, **loader_kwargs)
133 | 
134 |   loader_kwargs['shuffle'] = args.shuffle_val
135 |   val_loader = DataLoader(val_dset, **loader_kwargs)
136 | 
137 |   return vocab, train_loader, val_loader
138 | 
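A hypothetical sketch of the args namespace that build_train_loaders expects. The attribute names come from the dict literals above; the paths and values are placeholders, not the project's actual configuration.

from argparse import Namespace
from simsg.loader_utils import build_train_loaders

args = Namespace(
    dataset='vg', is_supervised=False,
    vocab_json='./datasets/vg/vocab.json',            # placeholder paths
    train_h5='./datasets/vg/train.h5',
    val_h5='./datasets/vg/val.h5',
    vg_image_dir='./datasets/vg/images',
    image_size=(64, 64), num_train_samples=None,
    max_objects_per_image=10, vg_use_orphaned_objects=True,
    include_relationships=True, batch_size=32,
    loader_num_workers=4, shuffle_val=False,
)
vocab, train_loader, val_loader = build_train_loaders(args)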
--------------------------------------------------------------------------------
/simsg/losses.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import torch.nn.functional as F
 19 | 
 20 | 
 21 | def get_gan_losses(gan_type):
 22 |   """
 23 |   Returns the generator and discriminator loss for a particular GAN type.
 24 | 
 25 |   The returned functions have the following API:
 26 |   loss_g = g_loss(scores_fake)
 27 |   loss_d = d_loss(scores_real, scores_fake)
 28 |   """
 29 |   if gan_type == 'gan':
 30 |     return gan_g_loss, gan_d_loss
 31 |   elif gan_type == 'wgan':
 32 |     return wgan_g_loss, wgan_d_loss
 33 |   elif gan_type == 'lsgan':
 34 |     return lsgan_g_loss, lsgan_d_loss
 35 |   else:
 36 |     raise ValueError('Unrecognized GAN type "%s"' % gan_type)
 37 | 
 38 | def gan_percept_loss(real, fake):
 39 | 
 40 |   '''
 41 |   Inputs:
 42 |   - real: discriminator feat maps for every layer, when x=real image
 43 |   - fake: discriminator feat maps for every layer, when x=pred image
 44 |   Returns:
 45 |     - loss: mean L1 distance between real and fake feature maps, averaged over discriminator layers
 46 |   '''
 47 | 
 48 |   loss = 0
 49 | 
 50 |   for i in range(len(real)):
 51 |     loss += (real[i] - fake[i]).abs().mean()
 52 | 
 53 |   return loss / len(real)
 54 | 
 55 | 
 56 | def bce_loss(input, target):
 57 |     """
 58 |     Numerically stable version of the binary cross-entropy loss function.
 59 | 
 60 |     As per https://github.com/pytorch/pytorch/issues/751
 61 |     See the TensorFlow docs for a derivation of this formula:
 62 |     https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
 63 | 
 64 |     Inputs:
 65 |     - input: PyTorch Tensor of shape (N, ) giving scores.
 66 |     - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.
 67 | 
 68 |     Returns:
 69 |     - A PyTorch Tensor containing the mean BCE loss over the minibatch of
 70 |       input data.
 71 |     """
 72 |     neg_abs = -input.abs()
 73 |     loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
 74 |     return loss.mean()
 75 | 
 76 | 
 77 | def _make_targets(x, y):
 78 |   """
 79 |   Inputs:
 80 |   - x: PyTorch Tensor
 81 |   - y: Python scalar
 82 | 
 83 |   Outputs:
 84 |   - out: PyTorch Tensor with the same shape and dtype as x, but filled with y
 85 |   """
 86 |   return torch.full_like(x, y)
 87 | 
 88 | 
 89 | def gan_g_loss(scores_fake):
 90 |   """
 91 |   Input:
 92 |   - scores_fake: Tensor of shape (N,) containing scores for fake samples
 93 | 
 94 |   Output:
 95 |   - loss: Tensor of shape (,) giving GAN generator loss
 96 |   """
 97 |   if scores_fake.dim() > 1:
 98 |     scores_fake = scores_fake.view(-1)
 99 |   y_fake = _make_targets(scores_fake, 1)
100 |   return bce_loss(scores_fake, y_fake)
101 | 
102 | 
103 | def gan_d_loss(scores_real, scores_fake):
104 |   """
105 |   Input:
106 |   - scores_real: Tensor of shape (N,) giving scores for real samples
107 |   - scores_fake: Tensor of shape (N,) giving scores for fake samples
108 | 
109 |   Output:
110 |   - loss: Tensor of shape (,) giving GAN discriminator loss
111 |   """
112 |   assert scores_real.size() == scores_fake.size()
113 |   if scores_real.dim() > 1:
114 |     scores_real = scores_real.view(-1)
115 |     scores_fake = scores_fake.view(-1)
116 |   y_real = _make_targets(scores_real, 1)
117 |   y_fake = _make_targets(scores_fake, 0)
118 |   loss_real = bce_loss(scores_real, y_real)
119 |   loss_fake = bce_loss(scores_fake, y_fake)
120 |   return loss_real + loss_fake
121 | 
122 | 
123 | def wgan_g_loss(scores_fake):
124 |   """
125 |   Input:
126 |   - scores_fake: Tensor of shape (N,) containing scores for fake samples
127 | 
128 |   Output:
129 |   - loss: Tensor of shape (,) giving WGAN generator loss
130 |   """
131 |   return -scores_fake.mean()
132 | 
133 | 
134 | def wgan_d_loss(scores_real, scores_fake):
135 |   """
136 |   Input:
137 |   - scores_real: Tensor of shape (N,) giving scores for real samples
138 |   - scores_fake: Tensor of shape (N,) giving scores for fake samples
139 | 
140 |   Output:
141 |   - loss: Tensor of shape (,) giving WGAN discriminator loss
142 |   """
143 |   return scores_fake.mean() - scores_real.mean()
144 | 
145 | 
146 | def lsgan_g_loss(scores_fake):
147 |   if scores_fake.dim() > 1:
148 |     scores_fake = scores_fake.view(-1)
149 |   y_fake = _make_targets(scores_fake, 1)
150 |   return F.mse_loss(scores_fake.sigmoid(), y_fake)
151 | 
152 | 
153 | def lsgan_d_loss(scores_real, scores_fake):
154 |   assert scores_real.size() == scores_fake.size()
155 |   if scores_real.dim() > 1:
156 |     scores_real = scores_real.view(-1)
157 |     scores_fake = scores_fake.view(-1)
158 |   y_real = _make_targets(scores_real, 1)
159 |   y_fake = _make_targets(scores_fake, 0)
160 |   loss_real = F.mse_loss(scores_real.sigmoid(), y_real)
161 |   loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake)
162 |   return loss_real + loss_fake
163 | 
164 | 
165 | def gradient_penalty(x_real, x_fake, f, gamma=1.0):
166 |   N = x_real.size(0)
167 |   device, dtype = x_real.device, x_real.dtype
168 |   eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype)
169 |   x_hat = eps * x_real + (1 - eps) * x_fake
170 |   x_hat_score = f(x_hat)
171 |   if x_hat_score.dim() > 1:
172 |     x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1)
173 |   x_hat_score = x_hat_score.sum()
174 |   grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
175 |   grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1)
176 |   gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
177 |   return gp_loss
178 | 
179 | """
180 | Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
181 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
182 | """
183 | 
184 | import torch
185 | import torch.nn as nn
186 | #import torch.nn.functional as F
187 | from simsg.SPADE.architectures import VGG19
188 | 
189 | 
190 | #                            SPADE losses!                      #
191 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
192 | # When LSGAN is used, it is basically the same as MSELoss,
193 | # but it abstracts away the need to create the target label tensor
194 | # that has the same size as the input
195 | class GANLoss(nn.Module):
196 |     def __init__(self, gan_mode='hinge', target_real_label=1.0, target_fake_label=0.0,
197 |                  tensor=torch.cuda.FloatTensor, opt=None):
198 |         super(GANLoss, self).__init__()
199 |         self.real_label = target_real_label
200 |         self.fake_label = target_fake_label
201 |         self.real_label_tensor = None
202 |         self.fake_label_tensor = None
203 |         self.zero_tensor = None
204 |         self.Tensor = tensor
205 |         self.gan_mode = gan_mode
206 |         self.opt = opt
207 |         if gan_mode == 'ls':
208 |             pass
209 |         elif gan_mode == 'original':
210 |             pass
211 |         elif gan_mode == 'w':
212 |             pass
213 |         elif gan_mode == 'hinge':
214 |             pass
215 |         else:
216 |             raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
217 | 
218 |     def get_target_tensor(self, input, target_is_real):
219 |         if target_is_real:
220 |             if self.real_label_tensor is None:
221 |                 self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
222 |                 self.real_label_tensor.requires_grad_(False)
223 |             return self.real_label_tensor.expand_as(input)
224 |         else:
225 |             if self.fake_label_tensor is None:
226 |                 self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
227 |                 self.fake_label_tensor.requires_grad_(False)
228 |             return self.fake_label_tensor.expand_as(input)
229 | 
230 |     def get_zero_tensor(self, input):
231 |         if self.zero_tensor is None:
232 |             self.zero_tensor = self.Tensor(1).fill_(0)
233 |             self.zero_tensor.requires_grad_(False)
234 |         return self.zero_tensor.expand_as(input)
235 | 
236 |     def loss(self, input, target_is_real, for_discriminator=True):
237 |         if self.gan_mode == 'original':  # cross entropy loss
238 |             target_tensor = self.get_target_tensor(input, target_is_real)
239 |             loss = F.binary_cross_entropy_with_logits(input, target_tensor)
240 |             return loss
241 |         elif self.gan_mode == 'ls':
242 |             target_tensor = self.get_target_tensor(input, target_is_real)
243 |             return F.mse_loss(input, target_tensor)
244 |         elif self.gan_mode == 'hinge':
245 |             if for_discriminator:
246 |                 if target_is_real:
247 |                     minval = torch.min(input - 1, self.get_zero_tensor(input))
248 |                     loss = -torch.mean(minval)
249 |                 else:
250 |                     minval = torch.min(-input - 1, self.get_zero_tensor(input))
251 |                     loss = -torch.mean(minval)
252 |             else:
253 |                 assert target_is_real, "The generator's hinge loss must be aiming for real"
254 |                 loss = -torch.mean(input)
255 |             return loss
256 |         else:
257 |             # wgan
258 |             if target_is_real:
259 |                 return -input.mean()
260 |             else:
261 |                 return input.mean()
262 | 
263 |     def __call__(self, input, target_is_real, for_discriminator=True):
264 |         # computing the loss is a bit complicated because |input| may not be
265 |         # a tensor, but a list of tensors in the case of a multiscale discriminator
266 |         if isinstance(input, list):
267 |             loss = 0
268 |             for pred_i in input:
269 |                 if isinstance(pred_i, list):
270 |                     pred_i = pred_i[-1]
271 |                 loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
272 |                 bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
273 |                 new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
274 |                 loss += new_loss
275 |             return loss / len(input)
276 |         else:
277 |             return self.loss(input, target_is_real, for_discriminator)
278 | 
279 | 
280 | # Perceptual loss that uses a pretrained VGG network
281 | class VGGLoss(nn.Module):
282 |     def __init__(self):
283 |         super(VGGLoss, self).__init__()
284 |         self.vgg = VGG19().cuda()
285 |         self.criterion = nn.L1Loss()
286 |         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
287 | 
288 |     def forward(self, x, y):
289 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
290 |         loss = 0
291 |         for i in range(len(x_vgg)):
292 |             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
293 |         return loss
294 | 
295 | 
296 | # KL Divergence loss used in VAE with an image encoder
297 | class KLDLoss(nn.Module):
298 |     def forward(self, mu, logvar):
299 |         return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
300 | 
301 | 
302 | 
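A minimal sketch of the functional API returned by get_gan_losses; the score tensors are random stand-ins for discriminator outputs.

import torch
from simsg.losses import get_gan_losses

g_loss, d_loss = get_gan_losses('gan')          # also accepts 'wgan' or 'lsgan'
scores_real = torch.randn(8)                    # discriminator scores on real images
scores_fake = torch.randn(8)                    # discriminator scores on generated images
print('D loss:', d_loss(scores_real, scores_fake).item())
print('G loss:', g_loss(scores_fake).item())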
--------------------------------------------------------------------------------
/simsg/metrics.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import torch
 18 | import numpy as np
 19 | from scipy import signal
 20 | from scipy.ndimage import convolve
 21 | from PIL import Image
 22 | 
 23 | 
 24 | def intersection(bbox_pred, bbox_gt):
 25 |   max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:])
 26 |   min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2])
 27 |   inter = torch.clamp((max_xy - min_xy), min=0)
 28 |   return inter[:, 0] * inter[:, 1]
 29 | 
 30 | 
 31 | def jaccard(bbox_pred, bbox_gt):
 32 |   inter = intersection(bbox_pred, bbox_gt)
 33 |   area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] -
 34 |       bbox_pred[:, 1])
 35 |   area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] -
 36 |       bbox_gt[:, 1])
 37 |   union = area_pred + area_gt - inter
 38 |   iou = torch.div(inter, union)
 39 |   return torch.sum(iou)
 40 | 
 41 | def get_total_norm(parameters, norm_type=2):
 42 |   if norm_type == float('inf'):
 43 |     total_norm = max(p.grad.data.abs().max() for p in parameters)
 44 |   else:
 45 |     total_norm = 0
 46 |     for p in parameters:
 47 |       try:
 48 |         param_norm = p.grad.data.norm(norm_type)
 49 |         total_norm += param_norm ** norm_type
 50 |       except:  # skip parameters that have no gradient
 51 |         continue
 52 |     total_norm = total_norm ** (1. / norm_type)  # take the root once, after summing
 53 |   return total_norm
 54 | 
 55 | 
 56 | def _FSpecialGauss(size, sigma):
 57 |     """Function to mimic the 'fspecial' gaussian MATLAB function."""
 58 |     radius = size // 2
 59 |     offset = 0.0
 60 |     start, stop = -radius, radius + 1
 61 |     if size % 2 == 0:
 62 |         offset = 0.5
 63 |         stop -= 1
 64 |     x, y = np.mgrid[offset + start:stop, offset + start:stop]
 65 |     assert len(x) == size
 66 |     g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
 67 |     return g / g.sum()
 68 | 
 69 | 
 70 | def _SSIMForMultiScale(img1,
 71 |                        img2,
 72 |                        max_val=255,
 73 |                        filter_size=11,
 74 |                        filter_sigma=1.5,
 75 |                        k1=0.01,
 76 |                        k2=0.03):
 77 |     """Return the Structural Similarity Map between `img1` and `img2`.
 78 |   This function attempts to match the functionality of ssim_index_new.m by
 79 |   Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
 80 |   Arguments:
 81 |     img1: Numpy array holding the first RGB image batch.
 82 |     img2: Numpy array holding the second RGB image batch.
 83 |     max_val: the dynamic range of the images (i.e., the difference between the
 84 |       maximum and the minimum allowed values).
 85 |     filter_size: Size of blur kernel to use (will be reduced for small images).
 86 |     filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
 87 |       for small images).
 88 |     k1: Constant used to maintain stability in the SSIM calculation (0.01 in
 89 |       the original paper).
 90 |     k2: Constant used to maintain stability in the SSIM calculation (0.03 in
 91 |       the original paper).
 92 |   Returns:
 93 |     Pair containing the mean SSIM and contrast sensitivity between `img1` and
 94 |     `img2`.
 95 |   Raises:
 96 |     RuntimeError: If input images don't have the same shape or don't have four
 97 |       dimensions: [batch_size, height, width, depth].
 98 |   """
 99 |     if img1.shape != img2.shape:
100 |         raise RuntimeError(
101 |             'Input images must have the same shape (%s vs. %s).' %
102 |             (img1.shape, img2.shape))
103 |     if img1.ndim != 4:
104 |         raise RuntimeError('Input images must have four dimensions, not %d' %
105 |                            img1.ndim)
106 | 
107 |     img1 = img1.astype(np.float64)
108 |     img2 = img2.astype(np.float64)
109 |     _, height, width, _ = img1.shape
110 | 
111 |     # Filter size can't be larger than height or width of images.
112 |     size = min(filter_size, height, width)
113 | 
114 |     # Scale down sigma if a smaller filter size is used.
115 |     sigma = size * filter_sigma / filter_size if filter_size else 0
116 | 
117 |     if filter_size:
118 |         window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
119 |         mu1 = signal.fftconvolve(img1, window, mode='valid')
120 |         mu2 = signal.fftconvolve(img2, window, mode='valid')
121 |         sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
122 |         sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
123 |         sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
124 |     else:
125 |         # Empty blur kernel so no need to convolve.
126 |         mu1, mu2 = img1, img2
127 |         sigma11 = img1 * img1
128 |         sigma22 = img2 * img2
129 |         sigma12 = img1 * img2
130 | 
131 |     mu11 = mu1 * mu1
132 |     mu22 = mu2 * mu2
133 |     mu12 = mu1 * mu2
134 |     sigma11 -= mu11
135 |     sigma22 -= mu22
136 |     sigma12 -= mu12
137 | 
138 |     # Calculate intermediate values used by both ssim and cs_map.
139 |     c1 = (k1 * max_val)**2
140 |     c2 = (k2 * max_val)**2
141 |     v1 = 2.0 * sigma12 + c2
142 |     v2 = sigma11 + sigma22 + c2
143 |     ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
144 |     cs = np.mean(v1 / v2)
145 |     return ssim, cs
146 | 
147 | 
148 | def MultiScaleSSIM(img1,
149 |                    img2,
150 |                    max_val=255,
151 |                    filter_size=11,
152 |                    filter_sigma=1.5,
153 |                    k1=0.01,
154 |                    k2=0.03,
155 |                    weights=None):
156 |     """Return the MS-SSIM score between `img1` and `img2`.
157 |   This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
158 |   Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
159 |   similarity for image quality assessment" (2003).
160 |   Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
161 |   Author's MATLAB implementation:
162 |   http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
163 |   Arguments:
164 |     img1: Numpy array holding the first RGB image batch.
165 |     img2: Numpy array holding the second RGB image batch.
166 |     max_val: the dynamic range of the images (i.e., the difference between the
167 |       maximum and the minimum allowed values).
168 |     filter_size: Size of blur kernel to use (will be reduced for small images).
169 |     filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
170 |       for small images).
171 |     k1: Constant used to maintain stability in the SSIM calculation (0.01 in
172 |       the original paper).
173 |     k2: Constant used to maintain stability in the SSIM calculation (0.03 in
174 |       the original paper).
175 |     weights: List of weights for each level; if none, use five levels and the
176 |       weights from the original paper.
177 |   Returns:
178 |     MS-SSIM score between `img1` and `img2`.
179 |   Raises:
180 |     RuntimeError: If input images don't have the same shape or don't have four
181 |       dimensions: [batch_size, height, width, depth].
182 |   """
183 |     if img1.shape != img2.shape:
184 |         raise RuntimeError(
185 |             'Input images must have the same shape (%s vs. %s).' %
186 |             (img1.shape, img2.shape))
187 |     if img1.ndim != 4:
188 |         raise RuntimeError('Input images must have four dimensions, not %d' %
189 |                            img1.ndim)
190 | 
191 |     # Note: default weights don't sum to 1.0 but do match the paper / matlab code.
192 |     weights = np.array(weights if weights else
193 |                        [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
194 |     levels = weights.size
195 |     downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
196 |     im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
197 |     mssim = np.array([])
198 |     mcs = np.array([])
199 |     for _ in range(levels):
200 |         ssim, cs = _SSIMForMultiScale(
201 |             im1,
202 |             im2,
203 |             max_val=max_val,
204 |             filter_size=filter_size,
205 |             filter_sigma=filter_sigma,
206 |             k1=k1,
207 |             k2=k2)
208 |         mssim = np.append(mssim, ssim)
209 |         mcs = np.append(mcs, cs)
210 |         filtered = [
211 |             convolve(im, downsample_filter, mode='reflect')
212 |             for im in [im1, im2]
213 |         ]
214 |         im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
215 |     return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) *
216 |             (mssim[levels - 1]**weights[levels - 1]))
217 | 
218 | 
219 | def msssim(original, compared):
220 |     if isinstance(original, str):
221 |         original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)
222 |     if isinstance(compared, str):
223 |         compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)
224 | 
225 |     original = original[None, ...] if original.ndim == 3 else original
226 |     compared = compared[None, ...] if compared.ndim == 3 else compared
227 | 
228 |     return MultiScaleSSIM(original, compared, max_val=255)
229 | 
230 | 
231 | def psnr(original, compared):
232 |     if isinstance(original, str):
233 |         original = np.array(Image.open(original).convert('RGB'), dtype=np.float32)
234 |     if isinstance(compared, str):
235 |         compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32)
236 | 
237 |     mse = np.mean(np.square(original - compared))
238 |     psnr = np.clip(
239 |         np.multiply(np.log10(255. * 255. / mse[mse > 0.]), 10.), 0., 99.99)[0]
240 |     return psnr
241 | 
242 | 
243 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
244 |     """Match each prior box with the ground truth box of the highest jaccard
245 |     overlap, encode the bounding boxes, then return the matched indices
246 |     corresponding to both confidence and location preds.
247 |     Args:
248 |         threshold: (float) The overlap threshold used when matching boxes.
249 |         truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
250 |         priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
251 |         variances: (tensor) Variances corresponding to each prior coord,
252 |             Shape: [num_priors, 4].
253 |         labels: (tensor) All the class labels for the image, Shape: [num_obj].
254 |         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
255 |         conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
256 |         idx: (int) current batch index
257 |     Return:
258 |         The matched indices corresponding to 1)location and 2)confidence preds.
259 |     """
260 |     # jaccard index
261 |     overlaps = jaccard(
262 |         truths,
263 |         point_form(priors)
264 |     )
265 |     # (Bipartite Matching)
266 |     # [1,num_objects] best prior for each ground truth
267 |     best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
268 |     # [1,num_priors] best ground truth for each prior
269 |     best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
270 |     best_truth_idx.squeeze_(0)
271 |     best_truth_overlap.squeeze_(0)
272 |     best_prior_idx.squeeze_(1)
273 |     best_prior_overlap.squeeze_(1)
274 |     best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
275 |     # TODO refactor: index  best_prior_idx with long tensor
276 |     # ensure every gt matches with its prior of max overlap
277 |     for j in range(best_prior_idx.size(0)):
278 |         best_truth_idx[best_prior_idx[j]] = j
279 |     matches = truths[best_truth_idx]          # Shape: [num_priors,4]
280 |     conf = labels[best_truth_idx] + 1         # Shape: [num_priors]
281 |     conf[best_truth_overlap < threshold] = 0  # label as background
282 |     loc = encode(matches, priors, variances)
283 |     loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
284 |     conf_t[idx] = conf  # [num_priors] top class label for each prior
285 | 
286 | def point_form(boxes):
287 |     """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
288 |     representation for comparison to point form ground truth data.
289 |     Args:
290 |         boxes: (tensor) center-size default boxes from priorbox layers.
291 |     Return:
292 |         boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
293 |     """
294 |     return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
295 |                      boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax
296 | 
297 | def encode(matched, priors, variances):
298 |     """Encode the variances from the priorbox layers into the ground truth boxes
299 |     we have matched (based on jaccard overlap) with the prior boxes.
300 |     Args:
301 |         matched: (tensor) Coords of ground truth for each prior in point-form
302 |             Shape: [num_priors, 4].
303 |         priors: (tensor) Prior boxes in center-offset form
304 |             Shape: [num_priors,4].
305 |         variances: (list[float]) Variances of priorboxes
306 |     Return:
307 |         encoded boxes (tensor), Shape: [num_priors, 4]
308 |     """
309 | 
310 |     # dist b/t match center and prior's center
311 |     g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
312 |     # encode variance
313 |     g_cxcy /= (variances[0] * priors[:, 2:])
314 |     # match wh / prior wh
315 |     g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
316 |     g_wh = torch.log(g_wh) / variances[1]
317 |     # return target for smooth_l1_loss
318 |     return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
319 | 
320 | def find_intersection(set_1, set_2):
321 |     """
322 |     Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.
323 |     :param set_1: set 1, a tensor of dimensions (n1, 4)
324 |     :param set_2: set 2, a tensor of dimensions (n2, 4)
325 |     :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
326 |     """
327 | 
328 |     # PyTorch auto-broadcasts singleton dimensions
329 |     lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
330 |     upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
331 |     intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
332 |     return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
333 | 
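A small illustrative sketch of the box-IoU and MS-SSIM helpers above, using random boxes and images.

import numpy as np
import torch
from simsg.metrics import jaccard, MultiScaleSSIM

bbox_pred = torch.tensor([[0.1, 0.1, 0.6, 0.6]])      # [x0, y0, x1, y1]
bbox_gt = torch.tensor([[0.2, 0.2, 0.7, 0.7]])
print('IoU (summed over the batch):', jaccard(bbox_pred, bbox_gt).item())

img1 = np.random.randint(0, 256, (1, 256, 256, 3)).astype(np.float32)
img2 = np.clip(img1 + np.random.normal(0, 5, img1.shape), 0, 255).astype(np.float32)
print('MS-SSIM:', MultiScaleSSIM(img1, img2, max_val=255))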
--------------------------------------------------------------------------------
/simsg/utils.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | #
 3 | # Copyright 2018 Google LLC
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import time
18 | import inspect
19 | import subprocess
20 | from contextlib import contextmanager
21 | 
22 | import torch
23 | 
24 | 
25 | def int_tuple(s):
26 |   return tuple(int(i) for i in s.split(','))
27 | 
28 | 
29 | def float_tuple(s):
30 |   return tuple(float(i) for i in s.split(','))
31 | 
32 | 
33 | def str_tuple(s):
34 |   return tuple(s.split(','))
35 | 
36 | 
37 | def bool_flag(s):
38 |   if s == '1' or s == 'True' or s == 'true':
39 |     return True
40 |   elif s == '0' or s == 'False' or s == 'false':
41 |     return False
42 |   msg = 'Invalid value "%s" for bool flag (should be 0/1 or True/False or true/false)'
43 |   raise ValueError(msg % s)
44 | 
45 | 
46 | def lineno():
47 |   return inspect.currentframe().f_back.f_lineno
48 | 
49 | 
50 | def get_gpu_memory():
51 |   torch.cuda.synchronize()
52 |   opts = [
53 |       'nvidia-smi', '-q', '--gpu=' + str(0), '|', 'grep', '"Used GPU Memory"'
54 |   ]
55 |   cmd = str.join(' ', opts)
56 |   ps = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
57 |   output = ps.communicate()[0].decode('utf-8')
58 |   output = output.split("\n")[1].split(":")
59 |   consumed_mem = int(output[1].strip().split(" ")[0])
60 |   return consumed_mem
61 | 
62 | 
63 | @contextmanager
64 | def timeit(msg, should_time=True):
65 |   if should_time:
66 |     torch.cuda.synchronize()
67 |     t0 = time.time()
68 |   yield
69 |   if should_time:
70 |     torch.cuda.synchronize()
71 |     t1 = time.time()
72 |     duration = (t1 - t0) * 1000.0
73 |     print('%s: %.2f ms' % (msg, duration))
74 | 
75 | 
76 | class LossManager(object):
77 |   def __init__(self):
78 |     self.total_loss = None
79 |     self.all_losses = {}
80 | 
81 |   def add_loss(self, loss, name, weight=1.0):
82 |     cur_loss = loss * weight
83 |     if self.total_loss is not None:
84 |       self.total_loss += cur_loss
85 |     else:
86 |       self.total_loss = cur_loss
87 | 
88 |     self.all_losses[name] = cur_loss.data.cpu().item()
89 | 
90 |   def items(self):
91 |     return self.all_losses.items()
92 | 
93 | 
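A quick sketch of LossManager, which accumulates weighted loss terms into total_loss and keeps a Python-float copy of each term for logging; the loss names and values are illustrative.

import torch
from simsg.utils import LossManager

losses = LossManager()
losses.add_loss(torch.tensor(0.5), 'L1_pixel', weight=1.0)
losses.add_loss(torch.tensor(2.0), 'bbox_pred', weight=10.0)
for name, value in losses.items():
  print(name, value)
print('total:', losses.total_loss.item())   # 0.5 * 1.0 + 2.0 * 10.0 = 20.5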
--------------------------------------------------------------------------------
/simsg/vis.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | #
  3 | # Copyright 2018 Google LLC
  4 | #
  5 | # Licensed under the Apache License, Version 2.0 (the "License");
  6 | # you may not use this file except in compliance with the License.
  7 | # You may obtain a copy of the License at
  8 | #
  9 | #      http://www.apache.org/licenses/LICENSE-2.0
 10 | #
 11 | # Unless required by applicable law or agreed to in writing, software
 12 | # distributed under the License is distributed on an "AS IS" BASIS,
 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 | # See the License for the specific language governing permissions and
 15 | # limitations under the License.
 16 | 
 17 | import tempfile, os
 18 | import torch
 19 | import numpy as np
 20 | import matplotlib.pyplot as plt
 21 | from matplotlib.patches import Rectangle
 22 | from imageio import imread
 23 | 
 24 | 
 25 | """
 26 | Utilities for making visualizations.
 27 | """
 28 | 
 29 | 
 30 | def draw_layout(vocab, objs, boxes, masks=None, size=256,
 31 |                 show_boxes=False, bgcolor=(0, 0, 0)):
 32 |   if bgcolor == 'white':
 33 |     bgcolor = (255, 255, 255)
 34 | 
 35 |   cmap = plt.get_cmap('rainbow')
 36 |   colors = cmap(np.linspace(0, 1, len(objs)))
 37 | 
 38 |   with torch.no_grad():
 39 |     objs = objs.cpu().clone()
 40 |     boxes = boxes.cpu().clone()
 41 |     boxes *= size
 42 |     
 43 |     if masks is not None:
 44 |       masks = masks.cpu().clone()
 45 |     
 46 |     bgcolor = np.asarray(bgcolor)
 47 |     bg = np.ones((size, size, 1)) * bgcolor
 48 |     plt.imshow(bg.astype(np.uint8))
 49 | 
 50 |     plt.gca().set_xlim(0, size)
 51 |     plt.gca().set_ylim(size, 0)
 52 |     plt.gca().set_aspect(1.0, adjustable='box')
 53 |     
 54 |     for i, obj in enumerate(objs):
 55 |       name = vocab['object_idx_to_name'][obj]
 56 |       if name == '__image__':
 57 |         continue
 58 |       box = boxes[i]
 59 |       
 60 |       if masks is None:
 61 |         continue
 62 |       mask = masks[i].numpy()
 63 |       mask /= mask.max()
 64 | 
 65 |       r, g, b, a = colors[i]
 66 |       colored_mask = mask[:, :, None] * np.asarray(colors[i])
 67 |       
 68 |       x0, y0, x1, y1 = box
 69 |       plt.imshow(colored_mask, extent=(x0, x1, y1, y0),
 70 |                  interpolation='bicubic', alpha=1.0)
 71 | 
 72 |     if show_boxes:
 73 |       for i, obj in enumerate(objs):
 74 |         name = vocab['object_idx_to_name'][obj]
 75 |         if name == '__image__':
 76 |           continue
 77 |         box = boxes[i]
 78 | 
 79 |         draw_box(box, colors[i], name)
 80 | 
 81 | 
 82 | def draw_box(box, color, text=None):
 83 |   """
 84 |   Draw a bounding box using pyplot, optionally with a text box label.
 85 | 
 86 |   Inputs:
 87 |   - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H]
 88 |          coordinate system.
 89 |   - color: pyplot color to use for the box.
 90 |   - text: (Optional) String; if provided then draw a label for this box.
 91 |   """
 92 |   TEXT_BOX_HEIGHT = 10
 93 |   if torch.is_tensor(box) and box.dim() == 2:
 94 |     box = box.view(-1)
 95 |     assert box.size(0) == 4
 96 |   x0, y0, x1, y1 = box
 97 |   assert y1 > y0, box
 98 |   assert x1 > x0, box
 99 |   w, h = x1 - x0, y1 - y0
100 |   rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color)
101 |   plt.gca().add_patch(rect)
102 |   if text is not None:
103 |     text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5)
104 |     plt.gca().add_patch(text_rect)
105 |     tx = 0.5 * (x0 + x1)
106 |     ty = y0 + TEXT_BOX_HEIGHT / 2.0
107 |     plt.text(tx, ty, text, va='center', ha='center')
108 | 
109 | 
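# Illustrative sketch (editor addition, not referenced anywhere else in the
# codebase): draw_box annotates whatever pyplot axes are currently active, so
# it composes with plt.imshow. The coordinates, colors and labels are made up.
def _draw_box_example(size=256):
  plt.figure()
  plt.imshow(np.zeros((size, size, 3), dtype=np.uint8))
  draw_box([30, 40, 120, 200], 'red', text='cat')
  draw_box([140, 60, 230, 180], 'cyan', text='hat')
  plt.show()

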
110 | def draw_scene_graph(objs, triples, vocab=None, **kwargs):
111 |   """
 112 |   Use GraphViz to draw a scene graph. If vocab is not passed then we assume
 113 |   that objs is a list of object-name strings and triples is a list of
 114 |   [s, p, o] triples where s, o are integer indices into objs and p is a string.
115 | 
116 |   Using this requires that GraphViz is installed. On Ubuntu 16.04 this is easy:
117 |   sudo apt-get install graphviz
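  Rendering also shells out to ccomps, gvpack and neato, all of which ship
  with the standard graphviz distribution.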
118 |   """
119 |   output_filename = kwargs.pop('output_filename', 'graph.png')
120 |   orientation = kwargs.pop('orientation', 'LR')
121 |   edge_width = kwargs.pop('edge_width', 6)
122 |   arrow_size = kwargs.pop('arrow_size', 1.5)
123 |   binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2)
124 |   ignore_dummies = kwargs.pop('ignore_dummies', True)
125 |   
 126 |   # Accept either a GraphViz rankdir value directly (e.g. 'LR') or the
 127 |   # shorthands 'H' (horizontal) and 'V' (vertical).
 128 |   rankdir = {'H': 'LR', 'V': 'TB'}.get(orientation, orientation)
129 | 
130 |   if vocab is not None:
131 |     # Decode object and relationship names
132 |     assert torch.is_tensor(objs)
133 |     assert torch.is_tensor(triples)
134 |     objs_list, triples_list = [], []
135 |     for i in range(objs.size(0)):
136 |       objs_list.append(vocab['object_idx_to_name'][objs[i].item()])
137 |     for i in range(triples.size(0)):
138 |       s = triples[i, 0].item()
 139 |       # s and o stay integer node indices; only the predicate is decoded to a name
140 |       p = vocab['pred_idx_to_name'][triples[i, 1].item()]
141 |       o = triples[i, 2].item()
142 |       triples_list.append([s, p, o])
143 |     objs, triples = objs_list, triples_list
144 | 
145 |   # General setup, and style for object nodes
146 |   lines = [
147 |     'digraph{',
148 |     'graph [size="5,3",ratio="compress",dpi="300",bgcolor="white"]',
149 |     'rankdir=%s' % rankdir,
150 |     'nodesep="0.5"',
151 |     'ranksep="0.5"',
152 |     'node [shape="box",style="rounded,filled",fontsize="48",color="none"]',
153 |     'node [fillcolor="lightpink1"]',
154 |   ]
155 |   # Output nodes for objects
156 |   for i, obj in enumerate(objs):
157 |     if ignore_dummies and obj == '__image__':
158 |       continue
159 |     lines.append('%d [label="%s"]' % (i, obj))
160 | 
161 |   # Output relationships
162 |   next_node_id = len(objs)
163 |   lines.append('node [fillcolor="lightblue1"]')
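  # Each (s, p, o) triple becomes its own predicate node, connected s -> p -> o.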
164 |   for s, p, o in triples:
165 |     if ignore_dummies and p == '__in_image__':
166 |       continue
167 |     lines += [
168 |       '%d [label="%s"]' % (next_node_id, p),
169 |       '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
170 |         s, next_node_id, edge_width, arrow_size, binary_edge_weight),
171 |       '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
172 |         next_node_id, o, edge_width, arrow_size, binary_edge_weight)
173 |     ]
174 |     next_node_id += 1
175 |   lines.append('}')
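  # With the toy inputs in __main__ below, the generated DOT spec looks
  # roughly like:
  #   digraph{
  #     ...style lines...
  #     0 [label="cat"]
  #     1 [label="cat"]
  #     2 [label="skateboard"]
  #     3 [label="hat"]
  #     node [fillcolor="lightblue1"]
  #     4 [label="next to"]
  #     0->4 [...]
  #     4->1 [...]
  #     ...
  #   }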
176 | 
177 |   # Now it gets slightly hacky. Write the graphviz spec to a temporary
178 |   # text file
 179 |   fd, dot_filename = tempfile.mkstemp()
 180 |   # We only need the path from mkstemp; close the low-level fd right away.
 181 |   os.close(fd)
 182 |   with open(dot_filename, 'w') as f:
 183 |     for line in lines:
 184 |       f.write('%s\n' % line)
185 | 
186 |   # Shell out to invoke graphviz; this will save the resulting image to disk,
187 |   # so we read it, delete it, then return it.
188 |   output_format = os.path.splitext(output_filename)[1][1:]
 189 |   # ccomps/gvpack/neato lay the connected components out separately and pack them into a grid.
 190 |   os.system('ccomps -x %s | dot | gvpack -array3 | neato -T%s -n2 -o %s' % (dot_filename, output_format, output_filename))
191 |   os.remove(dot_filename)
192 |   img = imread(output_filename)
193 |   os.remove(output_filename)
194 | 
195 |   return img
196 | 
197 | 
198 | if __name__ == '__main__':
199 |   o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard']
200 |   p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above']
201 |   o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)}
202 |   p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)}
203 |   vocab = {
204 |     'object_idx_to_name': o_idx_to_name,
205 |     'object_name_to_idx': o_name_to_idx,
206 |     'pred_idx_to_name': p_idx_to_name,
207 |     'pred_name_to_idx': p_name_to_idx,
208 |   }
209 | 
210 |   objs = [
211 |     'cat',
212 |     'cat',
213 |     'skateboard',
214 |     'hat',
215 |   ]
216 |   objs = torch.LongTensor([o_name_to_idx[o] for o in objs])
217 |   triples = [
218 |     [0, 'next to', 1],
219 |     [0, 'riding', 2],
220 |     [1, 'wearing', 3],
221 |     [3, 'above', 2],
222 |   ]
223 |   triples = [[s, p_name_to_idx[p], o] for s, p, o in triples]
224 |   triples = torch.LongTensor(triples)
225 | 
 226 |   img = draw_scene_graph(objs, triples, vocab, orientation='V')
 227 |   plt.imshow(img)
 228 |   plt.axis('off')
 229 |   plt.show()
 230 | 
--------------------------------------------------------------------------------