├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── args_64_crn_clevr.yaml ├── args_64_crn_vg.yaml ├── args_64_spade_clevr.yaml ├── args_64_spade_vg.yaml ├── requirements.txt ├── scripts ├── SIMSG_gui.py ├── download_vg.sh ├── eval_utils.py ├── evaluate_changes_clevr.py ├── evaluate_changes_vg.py ├── evaluate_reconstruction.py ├── preprocess_vg.py ├── print_args.py ├── run_train.py ├── train.py └── train_utils.py └── simsg ├── SPADE ├── __init__.py ├── architectures.py ├── base_network.py └── normalization.py ├── __init__.py ├── bilinear.py ├── data ├── __init__.py ├── clevr.py ├── clevr_gen │ ├── 1_gen_data.sh │ ├── 2_arrange_data.sh │ ├── 3_clevrToVG.py │ ├── README_CLEVR.md │ ├── collect_scenes.py │ ├── data │ │ ├── CoGenT_A.json │ │ ├── CoGenT_B.json │ │ ├── base_scene.blend │ │ ├── materials │ │ │ ├── MyMetal.blend │ │ │ └── Rubber.blend │ │ ├── properties.json │ │ └── shapes │ │ │ ├── SmoothCube_v2.blend │ │ │ ├── SmoothCylinder.blend │ │ │ └── Sphere.blend │ ├── render_clevr.py │ └── utils.py ├── utils.py ├── vg.py └── vg_splits.json ├── decoder.py ├── discriminators.py ├── feats_statistics.py ├── graph.py ├── layers.py ├── layout.py ├── loader_utils.py ├── losses.py ├── metrics.py ├── model.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "PerceptualSimilarity"] 2 | path = PerceptualSimilarity 3 | url = https://github.com/richzhang/PerceptualSimilarity 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIMSG 2 | 3 | This is the code accompanying the paper 4 | 5 | **Semantic Image Manipulation Using Scene Graphs | arxiv**
6 | Helisa Dhamo*, Azade Farshad*, Iro Laina, Nassir Navab, Gregory D. Hager, Federico Tombari, Christian Rupprecht
7 | **CVPR 2020**
8 |
9 | The work for this paper was done at the Technical University of Munich.
10 |
11 | In our work, we address the novel problem of image manipulation from scene graphs, in which a user can edit images by
12 | merely applying changes to the nodes or edges of a semantic graph that is generated from the image.
13 |
14 | We introduce a spatio-semantic scene graph network that does not require direct supervision for constellation changes or
15 | image edits. This makes it possible to train the system from existing real-world datasets with no additional annotation
16 | effort.
17 |
18 | If you find this code useful in your research, please cite
19 | ```
20 | @inproceedings{dhamo2020_SIMSG,
21 |   title={Semantic Image Manipulation Using Scene Graphs},
22 |   author={Dhamo, Helisa and Farshad, Azade and Laina, Iro and Navab, Nassir and
23 |           Hager, Gregory D. and Tombari, Federico and Rupprecht, Christian},
24 |   booktitle={CVPR},
25 |   year={2020}}
26 | ```
27 |
28 | **Note:** The project page has been updated with new links for model checkpoints and data.
29 |
30 |
31 | ## Setup
32 |
33 | We have tested the code on Ubuntu 16.04 with Python 3.7 and PyTorch 1.2.
34 |
35 | ### Setup code
36 | You can set up a conda environment to run the code like this:
37 |
38 | ```bash
39 | # clone this repository and move there
40 | git clone --recurse-submodules https://github.com/he-dhamo/simsg.git
41 | cd simsg
42 | # create a conda environment and install the requirements
43 | conda create --name simsg_env python=3.7 --file requirements.txt
44 | conda activate simsg_env # activate virtual environment
45 | # install pytorch and cuda version as tested in our work
46 | conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
47 | # more installations
48 | pip install opencv-python tensorboardx grave addict
49 | # to add current directory to python path run
50 | echo $PWD > <path_to_env>/lib/python3.7/site-packages/simsg.pth
51 | ```
52 | `path_to_env` can be found using `which python` (the resulting path minus `/bin/python`) while the conda env is active.
53 |
54 | ### Setup Visual Genome data
55 | (instructions from the sg2im repository)
56 |
57 | Run the following script to download and unpack the relevant parts of the Visual Genome dataset:
58 |
59 | ```bash
60 | bash scripts/download_vg.sh
61 | ```
62 |
63 | This will create the directory `datasets/vg` and will download about 15 GB of data to this directory; after unpacking it
64 | will take about 30 GB of disk space.
65 |
66 | After downloading the Visual Genome dataset, we need to preprocess it. This will split the data into train / val / test
67 | splits, consolidate all scene graphs into HDF5 files, and apply several heuristics to clean the data. In particular, we
68 | ignore images that are too small, and only consider object and attribute categories that appear some number of times in
69 | the training set; we also ignore objects that are too small, and set minimum and maximum values on the number of objects
70 | and relationships that appear per image.
71 |
72 | ```bash
73 | python scripts/preprocess_vg.py
74 | ```
75 |
76 | This will create files `train.h5`, `val.h5`, `test.h5`, and `vocab.json` in the directory `datasets/vg`.
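To quickly sanity-check the preprocessed data before training, a minimal sketch like the one below lists the HDF5 contents and the vocabulary sizes. The vocabulary key names used here (`object_idx_to_name`, `pred_idx_to_name`) are assumptions; inspect `vocab.json` if yours differ.

```python
# Minimal sanity check for the preprocessed Visual Genome files (hedged sketch).
import json
import h5py

# list the datasets stored in the training split
with h5py.File("datasets/vg/train.h5", "r") as f:
    print("train.h5 datasets:", list(f.keys()))

# load the vocabulary and report category counts
with open("datasets/vg/vocab.json", "r") as f:
    vocab = json.load(f)

# key names below are assumptions; adjust to the actual structure of vocab.json
print("object categories:", len(vocab.get("object_idx_to_name", [])))
print("predicate categories:", len(vocab.get("pred_idx_to_name", [])))
```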
77 |
78 |
79 | ## Training
80 |
81 | To train the model, you can either set the right options in an `args.yaml` file and then run the `run_train.py` script
82 | (**highly recommended**):
83 | ```
84 | python scripts/run_train.py args.yaml
85 | ```
86 | Or, alternatively, run the `train.py` script with the respective arguments. For example:
87 | ```
88 | python scripts/train.py --checkpoint_name myckpt --dataset vg --spade_gen_blocks True
89 | ```
90 |
91 | We provide the configuration files for the experiments presented in the paper, named in the format
92 | `args_{resolution}_{decoder}_{dataset}.yaml`.
93 |
94 | Please set the dataset path `DATA_DIR` in `train.py` before running.
95 |
96 | Most relevant arguments:
97 |
98 | `--checkpoint_name`: Base filename for saved checkpoints; default is 'checkpoint', so the filename for the checkpoint
99 | with model parameters will be 'checkpoint_model.pt'<br/>
100 | `--selective_discr_obj`: If True, apply the object discriminator only to the reconstructed RoIs (masked image areas).
101 |
102 | `--feats_in_gcn`: If True, feed the RoI visual features into the scene graph network (SGN).<br/>
103 | `--feats_out_gcn`: If True, concatenate the RoI visual features with the node features (the output of the SGN).<br/>
104 | `--is_baseline`: If True, run the cond-sg2im baseline.<br/>
105 | `--is_supervised`: If True, run the fully-supervised baseline for CLEVR.<br/>
106 | `--spade_gen_blocks`: If True, use SPADE blocks in the decoder architecture. Otherwise, use CRN blocks.
107 |
108 | ## Evaluate reconstruction
109 |
110 | To evaluate the model in reconstruction mode, you need to run the script ```evaluate_reconstruction.py```. The MAE, SSIM
111 | and LPIPS are printed to the terminal every ```--print_every``` iterations and saved to pickle files every
112 | ```--save_every``` iterations. When ```--save_images``` is set to ```True```, the code also saves one reconstructed sample per
113 | image, which can be used to compute the FID and
114 | Inception scores.
115 |
116 | Other relevant arguments are:
117 |
118 | `--exp_dir`: path to folder where all experiments are saved<br/>
119 | `--experiment`: name of experiment, e.g. spade_vg_64
120 | `--checkpoint`: checkpoint path
121 | `--with_feats`: (bool) using RoI visual features
122 | `--generative`: (bool) fully-generative mode, i.e. the whole input image is masked out
123 |
124 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```<exp_dir>/<experiment>_model.pt```.
125 |
126 | To reproduce the results from Table 2, please run:
127 | 1. for the fully generative task:
128 | ```
129 | python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats True \
130 |     --generative True
131 | ```
132 | 2. for the RoI reconstruction with visual features:
133 | ```
134 | python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats True \
135 |     --generative False
136 | ```
137 | 3. for the RoI reconstruction without visual features:
138 | ```
139 | python scripts/evaluate_reconstruction.py --exp_dir /path/to/experiment/dir --experiment spade --with_feats False \
140 |     --generative False
141 | ```
142 |
143 | To evaluate the model with ground truth (GT) and predicted (PRED) scene graphs, please set `--predgraphs` to
144 | ```False``` or ```True```, respectively.
145 | Before running with predicted graphs (PRED), make sure you have downloaded the predicted-graphs file
146 | `test_predgraphs.h5` and placed it in the same directory as ```test.h5```.
147 |
148 | Please set the dataset path `DATA_DIR` before running.
149 |
150 | For the LPIPS computation, we use PerceptualSimilarity cloned from the
151 | official repo (https://github.com/richzhang/PerceptualSimilarity) as a git submodule. If you did not clone our
152 | repository recursively, run:
153 | ```
154 | git submodule update --init --recursive
155 | ```
156 |
157 | ## Evaluate changes on Visual Genome
158 |
159 | To evaluate the model in semantic editing mode on Visual Genome, you need to run the script
160 | ```evaluate_changes_vg.py```. The script automatically generates edited images, without a user interface.
161 |
162 | The most relevant arguments are:
163 |
164 | `--exp_dir`: path to folder where all experiments are saved<br/>
165 | `--experiment`: name of experiment, e.g. spade_vg_64
166 | `--checkpoint`: checkpoint path
167 | `--mode`: choose from the options ['replace', 'reposition', 'remove', 'auto_withfeats', 'auto_nofeats'], corresponding respectively to
168 | object replacement, relationship change, object removal, RoI reconstruction with visual features, and RoI reconstruction
169 | without features.<br/>
170 | `--with_query_image`: (bool) use visual features from another image (a query image); used in
171 | combination with `mode='auto_nofeats'`.<br/>
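For example, a hedged invocation of the query-image mode (assuming the boolean flag is passed the same way as the other flags documented above; the query images themselves are sampled by the script) could look like:
```
python scripts/evaluate_changes_vg.py --exp_dir /path/to/experiment/dir --experiment spade \
    --mode auto_nofeats --with_query_image True
```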
172 |
173 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```<exp_dir>/<experiment>_model.pt```.
174 |
175 | Example run of object replacement using the spade model:
176 | ```
177 | python scripts/evaluate_changes_vg.py --exp_dir /path/to/experiment/dir --experiment spade --mode replace
178 | ```
179 |
180 | Please set the dataset path `VG_DIR` before running.
181 |
182 | ## Evaluate changes on CLEVR
183 |
184 | To evaluate the model in semantic editing mode on CLEVR, you need to run the script ```evaluate_changes_clevr.py```.
185 | The script automatically generates edited images, without a user interface.
186 |
187 | The most relevant arguments are:
188 |
189 | `--exp_dir`: path to folder where all experiments are saved<br/>
190 | `--experiment`: name of experiment, e.g. spade_clevr_64
191 | `--checkpoint`: checkpoint path
192 | `--image_size`: size of the input image; must match the size used in training, i.e. (64,64) or (128,128)<br/>
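The training resolution itself is set through the `image_size` entry of the experiment's yaml config. As a hedged sketch, a 128x128 run would at minimum change this entry relative to the provided `args_64_spade_clevr.yaml`; other options (e.g. `crop_size` or the decoder dimensions) may also need adjusting:
```yaml
# hypothetical 128x128 variant; only the changed entry is shown,
# all other options as in args_64_spade_clevr.yaml
image_size: !!python/tuple [128, 128]
```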
193 | 194 | If `--checkpoint` is not specified, the checkpoint path is automatically set to ```/_model.pt```. 195 | 196 | Example run of object replacement changes using the spade model: 197 | ``` 198 | python scripts/evaluate_changes_clevr.py --exp_dir /path/to/experiment/dir --experiment spade 199 | ``` 200 | 201 | Please set the dataset path `CLEVR_DIR` before running. 202 | 203 | ## User Interface 204 | 205 | To run a simple user interface that supports different manipulation types, such as object addition, removal, replacement 206 | and relationship change run: 207 | ``` 208 | python scripts/SIMSG_gui.py 209 | ``` 210 | 211 | Relevant options: 212 | 213 | `--checkpoint`: path to checkpoint file
214 | `--dataset`: visual genome or clevr
215 | `--predgraphs`: (bool) specifies whether to load ground truth or predicted graphs. So far, predicted graphs are only
216 | available for Visual Genome.<br/>
217 | `--data_h5`: path of the h5 file used to load a certain data split. If not explicitly set, it uses `test.h5` for GT graphs
218 | and `test_predgraphs.h5` for predicted graphs.<br/>
219 | `--update_input`: (bool) used to control a sequence of changes in the same image. If `True`, it sets the input as the
220 | output of the previous generation. Otherwise, all consecutive changes are applied to the original image.
221 | Please set the value of `DATA_DIR` in SIMSG_gui.py to point to your dataset.
222 |
223 | Once you click on "Load image" and "Get graph", a new image and its corresponding scene graph will appear. The current
224 | implementation loads either the GT or the predicted graph. Then you can apply one of the following manipulations:
225 |
226 | - For **object replacement** or **relationship change**:
227 |   1. Select the node you want to change
228 |   2. Choose a new category from the "Replace object" or "Change Relationship" menu, respectively.
229 | - For **object addition**:
230 |   1. Choose the node you want to connect the new object to (from "Connect to object")
231 |   2. Select the category of the new object and relationship ("Add new node", "Add relationship").
232 |   3. Specify the direction of the connection. Click on "Add as subject" for a `new_node -> predicate -> existing_node`
233 |      direction, and click "Add as object" for `existing_node -> predicate -> new_node`.
234 | - For **object removal**:
235 |   1. Select the node you want to remove
236 |   2. Click on "Remove node"<br/>
237 | Note that the current implementation only supports object removal (cannot remove relationships); though the model 238 | supports this and the GUI implementation can be extended accordingly. 239 | 240 | After you have completed the change, click on "Generate image". Alternatively you can save the image. 241 | 242 | ## Download 243 | 244 | Visit the project page to download model checkpoints, 245 | predicted scene graphs and the CLEVR data with edit pairs. 246 | 247 | We also provide the code used to generate the CLEVR data with change pairs. Please follow the instructions from 248 | [here](simsg/data/clevr_gen/README_CLEVR.md). 249 | 250 | ## Acknoledgement 251 | 252 | This code is based on the sg2im repository . 253 | 254 | The following directory is taken from the SPADE repository: 255 | - simsg/SPADE/ 256 | -------------------------------------------------------------------------------- /args_64_crn_clevr.yaml: -------------------------------------------------------------------------------- 1 | # arguments for easier parsing 2 | seed: 1 3 | gpu: 0 4 | 5 | # DATA OPTIONS 6 | dataset: clevr 7 | vg_image_dir: ./datasets/clevr/target 8 | output_dir: experiments/clevr 9 | checkpoint_name: crn_64_clevr 10 | log_dir: experiments/clevr/logs/crn_64_clevr 11 | 12 | # ARCHITECTURE OPTIONS 13 | image_size: !!python/tuple [64, 64] 14 | crop_size: 32 15 | batch_size: 32 16 | mask_size: 16 17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2 18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2 19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64] 20 | layout_pooling: sum 21 | 22 | # CRN weights 23 | percept_weight: 0 24 | weight_gan_feat: 0 25 | discriminator_loss_weight: 0.01 26 | d_obj_weight: 1 27 | ac_loss_weight: 0.1 28 | d_img_weight: 1 29 | l1_pixel_loss_weight: 1 30 | bbox_pred_loss_weight: 10 31 | 32 | # EXTRA OPTIONS 33 | feats_in_gcn: True 34 | feats_out_gcn: True 35 | is_baseline: False 36 | is_supervised: False 37 | num_iterations: 50000 38 | 39 | # LOGGING OPTIONS 40 | print_every: 500 41 | checkpoint_every: 2000 42 | max_num_imgs: 32 43 | -------------------------------------------------------------------------------- /args_64_crn_vg.yaml: -------------------------------------------------------------------------------- 1 | # arguments for easier parsing 2 | seed: 1 3 | gpu: 0 4 | 5 | # DATA OPTIONS 6 | dataset: vg 7 | vg_image_dir: ./datasets/vg/images 8 | output_dir: experiments/vg 9 | checkpoint_name: crn_64_vg 10 | log_dir: experiments/vg/logs/crn_64_vg 11 | 12 | # ARCHITECTURE OPTIONS 13 | image_size: !!python/tuple [64, 64] 14 | crop_size: 32 15 | batch_size: 32 16 | mask_size: 16 17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2 18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2 19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64] 20 | layout_pooling: sum 21 | 22 | # CRN weights 23 | percept_weight: 0 24 | weight_gan_feat: 0 25 | discriminator_loss_weight: 0.01 26 | d_obj_weight: 1 27 | ac_loss_weight: 0.1 28 | d_img_weight: 1 29 | l1_pixel_loss_weight: 1 30 | bbox_pred_loss_weight: 10 31 | 32 | # EXTRA OPTIONS 33 | feats_in_gcn: True 34 | feats_out_gcn: True 35 | is_baseline: False 36 | is_supervised: False 37 | 38 | # LOGGING OPTIONS 39 | print_every: 500 40 | checkpoint_every: 2000 41 | max_num_imgs: 32 42 | -------------------------------------------------------------------------------- /args_64_spade_clevr.yaml: -------------------------------------------------------------------------------- 1 | # arguments for easier parsing 2 | seed: 1 3 | gpu: 0 4 | 5 | # DATA OPTIONS 6 | dataset: 
clevr 7 | vg_image_dir: ./datasets/clevr/target 8 | output_dir: experiments/clevr 9 | checkpoint_name: spade_64_clevr 10 | log_dir: experiments/clevr/logs/spade_64_clevr 11 | 12 | # ARCHITECTURE OPTIONS 13 | image_size: !!python/tuple [64, 64] 14 | crop_size: 32 15 | batch_size: 32 16 | mask_size: 16 17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2 18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2 19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64] 20 | layout_pooling: sum 21 | 22 | # spade weights + options 23 | percept_weight: 5 24 | weight_gan_feat: 5 25 | discriminator_loss_weight: 1 26 | d_obj_weight: 0.1 27 | ac_loss_weight: 0.1 28 | d_img_weight: 1 29 | l1_pixel_loss_weight: 1 30 | bbox_pred_loss_weight: 50 31 | multi_discriminator: True 32 | spade_gen_blocks: True 33 | 34 | # EXTRA OPTIONS 35 | feats_in_gcn: True 36 | feats_out_gcn: True 37 | is_baseline: False 38 | is_supervised: True 39 | num_iterations: 50000 40 | 41 | # LOGGING OPTIONS 42 | print_every: 500 43 | checkpoint_every: 2000 44 | max_num_imgs: 32 45 | -------------------------------------------------------------------------------- /args_64_spade_vg.yaml: -------------------------------------------------------------------------------- 1 | # arguments for easier parsing 2 | seed: 1 3 | gpu: 0 4 | 5 | # DATA OPTIONS 6 | dataset: vg 7 | vg_image_dir: ./datasets/vg/images 8 | output_dir: experiments/vg 9 | checkpoint_name: spade_64_vg 10 | log_dir: experiments/vg/logs/spade_64_vg 11 | 12 | # ARCHITECTURE OPTIONS 13 | image_size: !!python/tuple [64, 64] 14 | crop_size: 32 15 | batch_size: 32 16 | mask_size: 16 17 | d_obj_arch: C4-64-2,C4-128-2,C4-256-2 18 | d_img_arch: C4-64-2,C4-128-2,C4-256-2 19 | decoder_network_dims: !!python/tuple [1024,512,256,128,64] 20 | layout_pooling: sum 21 | 22 | # spade weights + options 23 | percept_weight: 5 24 | weight_gan_feat: 5 25 | discriminator_loss_weight: 1 26 | d_obj_weight: 0.1 27 | ac_loss_weight: 0.1 28 | d_img_weight: 1 29 | l1_pixel_loss_weight: 1 30 | bbox_pred_loss_weight: 50 31 | multi_discriminator: True 32 | spade_gen_blocks: True 33 | 34 | # EXTRA OPTIONS 35 | feats_in_gcn: True 36 | feats_out_gcn: True 37 | is_baseline: False 38 | is_supervised: False 39 | 40 | # LOGGING OPTIONS 41 | print_every: 500 42 | checkpoint_every: 2000 43 | max_num_imgs: 32 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | cloudpickle=0.5.3 4 | cycler=0.10.0 5 | cython=0.28.3 6 | decorator=4.3.0 7 | h5py=2.10.0 8 | imageio=2.3.0 9 | kiwisolver=1.1.0 10 | matplotlib=3.1.1 11 | networkx=2.1 12 | ninja=1.9.0 13 | numpy=1.16 14 | pillow>=6.2.2 15 | pyqt=5.9.2 16 | pyyaml=5.1.2 17 | pywavelets=0.5.2 18 | qt=5.9.7 19 | scikit-image=0.15.0 20 | scikit-learn=0.21.3 21 | scipy=1.3.1 22 | six=1.12.0 23 | sqlite=3.29.0=h7b6447c_0 24 | toolz=0.9.0 25 | tqdm=4.31.1 26 | wheel=0.33.6 27 | yaml=0.1.7 28 | -------------------------------------------------------------------------------- /scripts/download_vg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | VG_DIR=datasets/vg 18 | mkdir -p $VG_DIR 19 | 20 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip 21 | wget https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip 22 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip 23 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt 24 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt 25 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip 26 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip 27 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip 28 | 29 | unzip $VG_DIR/objects.json.zip -d $VG_DIR 30 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR 31 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR 32 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR 33 | unzip $VG_DIR/images.zip -d $VG_DIR/images 34 | unzip $VG_DIR/images2.zip -d $VG_DIR/images 35 | -------------------------------------------------------------------------------- /scripts/eval_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Helisa Dhamo 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import os 18 | import json 19 | import numpy as np 20 | from matplotlib import pyplot as plt 21 | import matplotlib.patches as patches 22 | import cv2 23 | import torch 24 | from simsg.vis import draw_scene_graph 25 | from simsg.data import imagenet_deprocess_batch 26 | from imageio import imsave 27 | 28 | from simsg.model import mask_image_in_bbox 29 | 30 | 31 | def remove_node(objs, triples, boxes, imgs, idx, obj_to_img, triple_to_img): 32 | ''' 33 | removes nodes and all related edges in case of object removal 34 | image is also masked in the respective area 35 | idx: list of object ids to be removed 36 | Returns: 37 | updated objs, triples, boxes, imgs, obj_to_img, triple_to_img 38 | ''' 39 | 40 | # object nodes 41 | idlist = list(range(objs.shape[0])) 42 | keeps = [i for i in idlist if i not in idx] 43 | objs_reduced = objs[keeps] 44 | boxes_reduced = boxes[keeps] 45 | 46 | offset = torch.zeros_like(objs) 47 | for i in range(objs.shape[0]): 48 | for j in idx: 49 | if j < i: 50 | offset[i] += 1 51 | 52 | # edges connected to removed object 53 | keeps_t = [] 54 | triples_reduced = [] 55 | for i in range(triples.shape[0]): 56 | if not(triples[i,0] in idx or triples[i, 2] in idx): 57 | keeps_t.append(i) 58 | triples_reduced.append(torch.tensor([triples[i,0] - offset[triples[i,0]], triples[i,1], 59 | triples[i,2] - offset[triples[i,2]]], device=triples.device)) 60 | triples_reduced = torch.stack(triples_reduced, dim=0) 61 | 62 | # update indexing arrays 63 | obj_to_img_reduced = obj_to_img[keeps] 64 | triple_to_img_reduced = triple_to_img[keeps_t] 65 | 66 | # mask RoI of removed objects from image 67 | for i in idx: 68 | imgs = mask_image_in_bbox(imgs, boxes, i, obj_to_img, mode='removal') 69 | 70 | return objs_reduced, triples_reduced, boxes_reduced, imgs, obj_to_img_reduced, triple_to_img_reduced 71 | 72 | 73 | def bbox_coordinates_with_margin(bbox, margin, img): 74 | # extract bounding box with a margin 75 | 76 | left = max(0, bbox[0] * img.shape[3] - margin) 77 | top = max(0, bbox[1] * img.shape[2] - margin) 78 | right = min(img.shape[3], bbox[2] * img.shape[3] + margin) 79 | bottom = min(img.shape[2], bbox[3] * img.shape[2] + margin) 80 | 81 | return int(left), int(right), int(top), int(bottom) 82 | 83 | 84 | def save_image_from_tensor(img, img_dir, filename): 85 | 86 | img = imagenet_deprocess_batch(img) 87 | img_np = img[0].numpy().transpose(1, 2, 0) 88 | img_path = os.path.join(img_dir, filename) 89 | imsave(img_path, img_np) 90 | 91 | 92 | def save_image_with_label(img_pred, img_gt, img_dir, filename, txt_str): 93 | # saves gt and generated image, concatenated 94 | # together with text label describing the change 95 | # used for easier visualization of results 96 | 97 | img_pred = imagenet_deprocess_batch(img_pred) 98 | img_gt = imagenet_deprocess_batch(img_gt) 99 | 100 | img_pred_np = img_pred[0].numpy().transpose(1, 2, 0) 101 | img_gt_np = img_gt[0].numpy().transpose(1, 2, 0) 102 | 103 | img_pred_np = cv2.resize(img_pred_np, (128, 128)) 104 | img_gt_np = cv2.resize(img_gt_np, (128, 128)) 105 | 106 | wspace = np.zeros([img_pred_np.shape[0], 10, 3]) 107 | text = np.zeros([30, img_pred_np.shape[1] * 2 + 10, 3]) 108 | text = cv2.putText(text, txt_str, (0,20), cv2.FONT_HERSHEY_SIMPLEX, 109 | 0.5, (255, 255, 255), lineType=cv2.LINE_AA) 110 | 111 | img_pred_gt = np.concatenate([img_gt_np, wspace, img_pred_np], axis=1).astype('uint8') 112 | img_pred_gt = np.concatenate([text, img_pred_gt], axis=0).astype('uint8') 113 | img_path = os.path.join(img_dir, filename) 114 | 
imsave(img_path, img_pred_gt) 115 | 116 | 117 | def makedir(base, name, flag=True): 118 | dir_name = None 119 | if flag: 120 | dir_name = os.path.join(base, name) 121 | if not os.path.isdir(dir_name): 122 | os.makedirs(dir_name) 123 | return dir_name 124 | 125 | 126 | def save_graph_json(objs, triples, boxes, beforeafter, dir, idx): 127 | # save scene graph in json form 128 | 129 | data = {} 130 | objs = objs.cpu().numpy() 131 | triples = triples.cpu().numpy() 132 | data['objs'] = objs.tolist() 133 | data['triples'] = triples.tolist() 134 | data['boxes'] = boxes.tolist() 135 | with open(dir + '/' + beforeafter + '_' + str(idx) + '.json', 'w') as outfile: 136 | json.dump(data, outfile) 137 | 138 | 139 | def query_image_by_semantic_id(obj_id, curr_img_id, loader, num_samples=7): 140 | # used to replace objects with an object of the same category and different appearance 141 | # return list of images and bboxes, that contain object of category obj_id 142 | 143 | query_imgs, query_boxes = [], [] 144 | loader_id = 0 145 | counter = 0 146 | 147 | for l in loader: 148 | # load images 149 | imgs, objs, boxes, _, _, _, _ = [x.cuda() for x in l] 150 | if loader_id == curr_img_id: 151 | loader_id += 1 152 | continue 153 | 154 | for i, ob in enumerate(objs): 155 | if obj_id[0] == ob: 156 | query_imgs.append(imgs) 157 | query_boxes.append(boxes[i]) 158 | counter += 1 159 | if counter == num_samples: 160 | return query_imgs, query_boxes 161 | loader_id += 1 162 | 163 | return 0, 0 164 | 165 | 166 | def draw_image_box(img, box): 167 | 168 | left, right, top, bottom = int(round(box[0] * img.shape[1])), int(round(box[2] * img.shape[1])), \ 169 | int(round(box[1] * img.shape[0])), int(round(box[3] * img.shape[0])) 170 | 171 | cv2.rectangle(img, (left, top), (right, bottom), (255,0,0), 1) 172 | return img 173 | 174 | 175 | def draw_image_edge(img, box1, box2): 176 | # draw arrow that connects two objects centroids 177 | left1, right1, top1, bottom1 = int(round(box1[0] * img.shape[1])), int(round(box1[2] * img.shape[1])), \ 178 | int(round(box1[1] * img.shape[0])), int(round(box1[3] * img.shape[0])) 179 | left2, right2, top2, bottom2 = int(round(box2[0] * img.shape[1])), int(round(box2[2] * img.shape[1])), \ 180 | int(round(box2[1] * img.shape[0])), int(round(box2[3] * img.shape[0])) 181 | 182 | cv2.arrowedLine(img, (int((left1+right1)/2), int((top1+bottom1)/2)), 183 | (int((left2+right2)/2), int((top2+bottom2)/2)), (255,0,0), 1) 184 | 185 | return img 186 | 187 | 188 | def visualize_imgs_boxes(imgs, imgs_pred, boxes, boxes_pred): 189 | 190 | nrows = imgs.size(0) 191 | imgs = imgs.detach().cpu().numpy() 192 | imgs_pred = imgs_pred.detach().cpu().numpy() 193 | boxes = boxes.detach().cpu().numpy() 194 | boxes_pred = boxes_pred.detach().cpu().numpy() 195 | plt.figure() 196 | 197 | for i in range(0, nrows): 198 | # i = j//2 199 | ax1 = plt.subplot(2, nrows, i+1) 200 | img = np.transpose(imgs[i, :, :, :], (1, 2, 0)) / 255. 201 | plt.imshow(img) 202 | 203 | left, right, top, bottom = bbox_coordinates_with_margin(boxes[i, :], 0, imgs[i:i+1, :, :, :]) 204 | bbox_gt = patches.Rectangle((left, top), 205 | width=right-left, 206 | height=bottom-top, 207 | linewidth=1, edgecolor='r', facecolor='none') 208 | # Add the patch to the Axes 209 | ax1.add_patch(bbox_gt) 210 | plt.axis('off') 211 | 212 | ax2 = plt.subplot(2, nrows, i+nrows+1) 213 | pred = np.transpose(imgs_pred[i, :, :, :], (1, 2, 0)) / 255. 
214 | plt.imshow(pred) 215 | 216 | left, right, top, bottom = bbox_coordinates_with_margin(boxes_pred[i, :], 0, imgs[i:i+1, :, :, :]) 217 | bbox_pr = patches.Rectangle((left, top), 218 | width=right-left, 219 | height=bottom-top, 220 | linewidth=1, edgecolor='r', facecolor='none') 221 | # ax2.add_patch(bbox_gt) 222 | ax2.add_patch(bbox_pr) 223 | plt.axis('off') 224 | 225 | plt.show() 226 | 227 | 228 | def visualize_scene_graphs(obj_to_img, objs, triples, vocab, device): 229 | offset = 0 230 | for i in range(1):#imgs_in.size(0)): 231 | curr_obj_idx = (obj_to_img == i).nonzero() 232 | 233 | objs_vis = objs[curr_obj_idx] 234 | triples_vis = [] 235 | for j in range(triples.size(0)): 236 | if triples[j, 0] in curr_obj_idx or triples[j, 2] in curr_obj_idx: 237 | triples_vis.append(triples[j].to(device) - torch.tensor([offset, 0, offset]).to(device)) 238 | offset += curr_obj_idx.size(0) 239 | triples_vis = torch.stack(triples_vis, 0) 240 | 241 | print(objs_vis, triples_vis) 242 | graph_img = draw_scene_graph(objs_vis, triples_vis, vocab) 243 | 244 | cv2.imshow('graph' + str(i), graph_img) 245 | cv2.waitKey(10000) 246 | 247 | 248 | def remove_duplicates(triples, triple_to_img, indexes): 249 | # removes duplicates in relationship triples 250 | 251 | triples_new = [] 252 | triple_to_img_new = [] 253 | 254 | for i in range(triples.size(0)): 255 | if i not in indexes: 256 | triples_new.append(triples[i]) 257 | triple_to_img_new.append(triple_to_img[i]) 258 | 259 | triples_new = torch.stack(triples_new, 0) 260 | triple_to_img_new = torch.stack(triple_to_img_new, 0) 261 | 262 | return triples_new, triple_to_img_new 263 | 264 | 265 | def parse_bool(pred_graphs, generative, use_gt_boxes, use_feats): 266 | # returns name of output directory depending on arguments 267 | 268 | if pred_graphs: 269 | name = "pred/" 270 | else: 271 | name = "" 272 | if generative: # fully generative mode 273 | return name + "generative" 274 | else: 275 | if use_gt_boxes: 276 | b = "withbox" 277 | else: 278 | b = "nobox" 279 | if use_feats: 280 | f = "withfeats" 281 | else: 282 | f = "nofeats" 283 | 284 | return name + b + "_" + f 285 | 286 | 287 | def is_background(label_id): 288 | 289 | if label_id in [169, 60, 61, 49, 141, 8, 11, 52, 66]: 290 | return True 291 | else: 292 | return False 293 | 294 | 295 | def get_selected_objects(): 296 | 297 | objs = ["", "apple", "ball", "banana", "beach", "bike", "bird", "bus", "bush", "cat", "car", "chair", "cloud", "dog", 298 | "elephant", "field", "giraffe", "man", "motorcycle", "ocean", "person", "plane", "sheep", "tree", "zebra"] 299 | 300 | return objs 301 | -------------------------------------------------------------------------------- /scripts/print_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import argparse 18 | import torch 19 | 20 | 21 | """ 22 | Tiny utility to print the command-line args used for a checkpoint 23 | """ 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('checkpoint') 28 | 29 | 30 | def main(args): 31 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 32 | for k, v in checkpoint['args'].items(): 33 | print(k, v) 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | main(args) 39 | 40 | -------------------------------------------------------------------------------- /scripts/run_train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from addict import Dict 3 | from scripts.train import argument_parser 4 | from scripts.train import main as train_main 5 | import torch 6 | import shutil 7 | import os 8 | import random 9 | import sys 10 | 11 | 12 | # needed to solve tuple reading issue 13 | class PrettySafeLoader(yaml.SafeLoader): 14 | def construct_python_tuple(self, node): 15 | return tuple(self.construct_sequence(node)) 16 | 17 | PrettySafeLoader.add_constructor( 18 | u'tag:yaml.org,2002:python/tuple', 19 | PrettySafeLoader.construct_python_tuple) 20 | 21 | 22 | def main(): 23 | # argument: name of yaml config file 24 | try: 25 | filename = sys.argv[1] 26 | except IndexError as error: 27 | print("provide yaml file as command-line argument!") 28 | exit() 29 | 30 | config = Dict(yaml.load(open(filename), Loader=PrettySafeLoader)) 31 | os.makedirs(config.log_dir, exist_ok=True) 32 | # save a copy in the experiment dir 33 | shutil.copyfile(filename, os.path.join(config.log_dir, 'args.yaml')) 34 | 35 | torch.cuda.set_device(config.gpu) 36 | torch.manual_seed(config.seed) 37 | random.seed(config.seed) 38 | 39 | args = yaml_to_parser(config) 40 | train_main(args) 41 | 42 | 43 | def yaml_to_parser(config): 44 | parser = argument_parser() 45 | args, unknown = parser.parse_known_args() 46 | 47 | args_dict = vars(args) 48 | for key, value in config.items(): 49 | try: 50 | args_dict[key] = value 51 | except KeyError: 52 | print(key, ' was not found in arguments') 53 | return args 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/train_utils.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | import torch 3 | from torch.functional import F 4 | from collections import defaultdict 5 | 6 | 7 | def check_args(args): 8 | H, W = args.image_size 9 | for _ in args.decoder_network_dims[1:]: 10 | H = H // 2 11 | 12 | if H == 0: 13 | raise ValueError("Too many layers in decoder network") 14 | 15 | 16 | def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1): 17 | curr_loss = curr_loss * weight 18 | loss_dict[loss_name] = curr_loss.item() 19 | if total_loss is not None: 20 | total_loss += curr_loss 21 | else: 22 | total_loss = curr_loss 23 | return total_loss 24 | 25 | 26 | def calculate_model_losses(args, skip_pixel_loss, img, img_pred, bbox, bbox_pred): 27 | 28 | total_loss = torch.zeros(1).to(img) 29 | losses = {} 30 | 31 | l1_pixel_weight = args.l1_pixel_loss_weight 32 | 33 | if skip_pixel_loss: 34 | l1_pixel_weight = 0 35 | 36 | l1_pixel_loss = F.l1_loss(img_pred, img) 37 | 38 | total_loss = add_loss(total_loss, l1_pixel_loss, losses, 'L1_pixel_loss', 39 | l1_pixel_weight) 40 | 41 | loss_bbox = F.mse_loss(bbox_pred, bbox) 42 | total_loss = add_loss(total_loss, 
loss_bbox, losses, 'bbox_pred', 43 | args.bbox_pred_loss_weight) 44 | 45 | return total_loss, losses 46 | 47 | 48 | def init_checkpoint_dict(args, vocab, model_kwargs, d_obj_kwargs, d_img_kwargs): 49 | 50 | ckpt = { 51 | 'args': args.__dict__, 'vocab': vocab, 'model_kwargs': model_kwargs, 52 | 'd_obj_kwargs': d_obj_kwargs, 'd_img_kwargs': d_img_kwargs, 53 | 'losses_ts': [], 'losses': defaultdict(list), 'd_losses': defaultdict(list), 54 | 'checkpoint_ts': [], 'train_iou': [], 'val_losses': defaultdict(list), 55 | 'val_iou': [], 'counters': {'t': None, 'epoch': None}, 56 | 'model_state': None, 'model_best_state': None, 'optim_state': None, 57 | 'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None, 58 | 'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None, 59 | 'best_t': [], 60 | } 61 | return ckpt 62 | 63 | 64 | def print_G_state(args, t, losses, writer, checkpoint): 65 | # print generator losses on terminal and save on tensorboard 66 | 67 | print('t = %d / %d' % (t, args.num_iterations)) 68 | for name, val in losses.items(): 69 | print('G [%s]: %.4f' % (name, val)) 70 | writer.add_scalar('G {}'.format(name), val, global_step=t) 71 | checkpoint['losses'][name].append(val) 72 | checkpoint['losses_ts'].append(t) 73 | 74 | 75 | def print_D_obj_state(args, t, writer, checkpoint, d_obj_losses): 76 | # print D_obj losses on terminal and save on tensorboard 77 | 78 | for name, val in d_obj_losses.items(): 79 | print('D_obj [%s]: %.4f' % (name, val)) 80 | writer.add_scalar('D_obj {}'.format(name), val, global_step=t) 81 | checkpoint['d_losses'][name].append(val) 82 | 83 | 84 | def print_D_img_state(args, t, writer, checkpoint, d_img_losses): 85 | # print D_img losses on terminal and save on tensorboard 86 | 87 | for name, val in d_img_losses.items(): 88 | print('D_img [%s]: %.4f' % (name, val)) 89 | writer.add_scalar('D_img {}'.format(name), val, global_step=t) 90 | checkpoint['d_losses'][name].append(val) 91 | -------------------------------------------------------------------------------- /simsg/SPADE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/SPADE/__init__.py -------------------------------------------------------------------------------- /simsg/SPADE/architectures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | #import torch.nn.utils.spectral_norm as spectral_norm 11 | #from models.networks.normalization import SPADE 12 | 13 | 14 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 15 | class VGG19(torch.nn.Module): 16 | def __init__(self, requires_grad=False): 17 | super().__init__() 18 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 19 | self.slice1 = torch.nn.Sequential() 20 | self.slice2 = torch.nn.Sequential() 21 | self.slice3 = torch.nn.Sequential() 22 | self.slice4 = torch.nn.Sequential() 23 | self.slice5 = torch.nn.Sequential() 24 | for x in range(2): 25 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(2, 7): 27 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 28 | for x in range(7, 12): 29 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 30 | for x in range(12, 21): 31 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 32 | for x in range(21, 30): 33 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 34 | if not requires_grad: 35 | for param in self.parameters(): 36 | param.requires_grad = False 37 | 38 | def forward(self, X): 39 | h_relu1 = self.slice1(X) 40 | h_relu2 = self.slice2(h_relu1) 41 | h_relu3 = self.slice3(h_relu2) 42 | h_relu4 = self.slice4(h_relu3) 43 | h_relu5 = self.slice5(h_relu4) 44 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 45 | return out 46 | -------------------------------------------------------------------------------- /simsg/SPADE/base_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 
26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | -------------------------------------------------------------------------------- /simsg/SPADE/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | #from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 11 | import torch.nn.utils.spectral_norm as spectral_norm 12 | 13 | 14 | # Returns a function that creates a normalization function 15 | # that does not condition on semantic map 16 | def get_nonspade_norm_layer(opt, norm_type='instance'): 17 | # helper function to get # output channels of the previous layer 18 | def get_out_channel(layer): 19 | if hasattr(layer, 'out_channels'): 20 | return getattr(layer, 'out_channels') 21 | return layer.weight.size(0) 22 | 23 | # this function will be returned 24 | def add_norm_layer(layer): 25 | nonlocal norm_type 26 | if norm_type.startswith('spectral'): 27 | layer = spectral_norm(layer) 28 | subnorm_type = norm_type[len('spectral'):] 29 | 30 | if subnorm_type == 'none' or len(subnorm_type) == 0: 31 | return layer 32 | 33 | # remove bias in the previous layer, which is meaningless 34 | # since it has no effect after normalization 35 | if getattr(layer, 'bias', None) is not None: 36 | delattr(layer, 'bias') 37 | layer.register_parameter('bias', None) 38 | 39 | if subnorm_type == 'batch': 40 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 41 | #elif subnorm_type == 'sync_batch': 42 | # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 43 | elif subnorm_type == 'instance': 44 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 45 | else: 46 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 47 | 48 | return nn.Sequential(layer, norm_layer) 49 | 50 | return add_norm_layer 51 | 52 | 53 | # Creates SPADE normalization layer based on the given configuration 54 | # SPADE consists of two steps. First, it normalizes the activations using 55 | # your favorite normalization method, such as Batch Norm or Instance Norm. 56 | # Second, it applies scale and bias to the normalized output, conditioned on 57 | # the segmentation map. 58 | # The format of |config_text| is spade(norm)(ks), where 59 | # (norm) specifies the type of parameter-free normalization. 60 | # (e.g. syncbatch, batch, instance) 61 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) 62 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 63 | # Also, the other arguments are 64 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE 65 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE 66 | class SPADE(nn.Module): 67 | def __init__(self, config_text, norm_nc, label_nc): 68 | super().__init__() 69 | 70 | assert config_text.startswith('spade') 71 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 72 | param_free_norm_type = str(parsed.group(1)) 73 | ks = int(parsed.group(2)) 74 | 75 | if param_free_norm_type == 'instance': 76 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 77 | elif param_free_norm_type == 'syncbatch': 78 | self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 79 | elif param_free_norm_type == 'batch': 80 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 81 | else: 82 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 83 | % param_free_norm_type) 84 | 85 | # The dimension of the intermediate embedding space. Yes, hardcoded. 
86 | nhidden = 128 87 | 88 | pw = ks // 2 89 | self.mlp_shared = nn.Sequential( 90 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 91 | nn.ReLU() 92 | ) 93 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 94 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 95 | 96 | def forward(self, x, segmap): 97 | 98 | # Part 1. generate parameter-free normalized activations 99 | normalized = self.param_free_norm(x) 100 | 101 | # Part 2. produce scaling and bias conditioned on semantic map 102 | #segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 103 | actv = self.mlp_shared(segmap) 104 | gamma = self.mlp_gamma(actv) 105 | beta = self.mlp_beta(actv) 106 | #print(normalized.shape, x.shape, segmap.shape, gamma.shape, beta.shape) 107 | # apply scale and bias 108 | out = normalized * (1 + gamma) + beta 109 | 110 | return out 111 | -------------------------------------------------------------------------------- /simsg/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /simsg/bilinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from simsg.utils import timeit 20 | 21 | 22 | """ 23 | Functions for performing differentiable bilinear cropping of images, for use in 24 | the object discriminator 25 | """ 26 | 27 | 28 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'): 29 | """ 30 | Inputs: 31 | - feats: FloatTensor of shape (N, C, H, W) 32 | - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates 33 | - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps; 34 | each element is in the range [0, N) and bbox_to_feats[b] = i means that 35 | bbox[b] will be cropped from feats[i]. 36 | - HH, WW: Size of the output crops 37 | 38 | Returns: 39 | - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to 40 | crop from feats[bbox_to_feats[i]]. 
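A quick shape check for crop_bbox_batch with small hypothetical tensors (a sketch, not part of the library; the default 'cudnn' backend goes through F.grid_sample and should run on CPU as well):

```python
import torch
from simsg.bilinear import crop_bbox_batch

feats = torch.randn(2, 3, 64, 64)                 # N = 2 feature maps
bbox = torch.tensor([[0.00, 0.00, 0.50, 0.50],    # B = 3 boxes, [x0, y0, x1, y1] in [0, 1]
                     [0.25, 0.25, 0.75, 0.75],
                     [0.10, 0.20, 0.90, 0.80]])
bbox_to_feats = torch.tensor([0, 1, 1])           # which feature map each box is cropped from
crops = crop_bbox_batch(feats, bbox, bbox_to_feats, HH=32)
print(crops.shape)                                # torch.Size([3, 3, 32, 32])
```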
41 | """ 42 | if backend == 'cudnn': 43 | return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW) 44 | #print("here ========================================" ,feats.size()) 45 | N, C, H, W = feats.size() 46 | B = bbox.size(0) 47 | if WW is None: WW = HH 48 | dtype, device = feats.dtype, feats.device 49 | crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device) 50 | for i in range(N): 51 | idx = (bbox_to_feats.data == i).nonzero() 52 | if idx.dim() == 0: 53 | continue 54 | idx = idx.view(-1) 55 | n = idx.size(0) 56 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 57 | cur_bbox = bbox[idx] 58 | cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW) 59 | crops[idx] = cur_crops 60 | return crops 61 | 62 | 63 | def _invperm(p): 64 | N = p.size(0) 65 | eye = torch.arange(0, N).type_as(p) 66 | pp = (eye[:, None] == p).nonzero()[:, 1] 67 | return pp 68 | 69 | 70 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None): 71 | #print("here ========================================" ,feats.size()) 72 | N, C, H, W = feats.size() 73 | B = bbox.size(0) 74 | if WW is None: WW = HH 75 | dtype = feats.data.type() 76 | 77 | feats_flat, bbox_flat, all_idx = [], [], [] 78 | for i in range(N): 79 | idx = (bbox_to_feats.data == i).nonzero() 80 | if idx.dim() == 0: 81 | continue 82 | idx = idx.view(-1) 83 | n = idx.size(0) 84 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 85 | cur_bbox = bbox[idx] 86 | 87 | feats_flat.append(cur_feats) 88 | bbox_flat.append(cur_bbox) 89 | all_idx.append(idx) 90 | 91 | feats_flat = torch.cat(feats_flat, dim=0) 92 | bbox_flat = torch.cat(bbox_flat, dim=0) 93 | crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn') 94 | 95 | # If the crops were sequential (all_idx is identity permutation) then we can 96 | # simply return them; otherwise we need to permute crops by the inverse 97 | # permutation from all_idx. 98 | all_idx = torch.cat(all_idx, dim=0) 99 | eye = torch.arange(0, B).type_as(all_idx) 100 | if (all_idx == eye).all(): 101 | return crops 102 | return crops[_invperm(all_idx)] 103 | 104 | 105 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'): 106 | """ 107 | Take differentiable crops of feats specified by bbox. 108 | 109 | Inputs: 110 | - feats: Tensor of shape (N, C, H, W) 111 | - bbox: Bounding box coordinates of shape (N, 4) in the format 112 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 113 | - HH, WW: Size of the output crops. 114 | 115 | Returns: 116 | - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of 117 | feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling. 
118 | """ 119 | N = feats.size(0) 120 | assert bbox.size(0) == N 121 | #print(bbox.shape) 122 | assert bbox.size(1) == 4 123 | if WW is None: WW = HH 124 | if backend == 'cudnn': 125 | # Change box from [0, 1] to [-1, 1] coordinate system 126 | bbox = 2 * bbox - 1 127 | x0, y0 = bbox[:, 0], bbox[:, 1] 128 | x1, y1 = bbox[:, 2], bbox[:, 3] 129 | X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW) 130 | Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW) 131 | if backend == 'jj': 132 | return bilinear_sample(feats, X, Y) 133 | elif backend == 'cudnn': 134 | grid = torch.stack([X, Y], dim=3) 135 | return F.grid_sample(feats, grid) 136 | 137 | 138 | 139 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0): 140 | """ 141 | Inverse operation to crop_bbox; construct output images where the feature maps 142 | from feats have been reshaped and placed into the positions specified by bbox. 143 | 144 | Inputs: 145 | - feats: Tensor of shape (N, C, HH, WW) 146 | - bbox: Bounding box coordinates of shape (N, 4) in the format 147 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 148 | - H, W: Size of output. 149 | - fill_value: Portions of the output image that are outside the bounding box 150 | will be filled with this value. 151 | 152 | Returns: 153 | - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by 154 | bbox[i] contains feats[i], reshaped using bilinear sampling. 155 | """ 156 | N, C = feats.size(0), feats.size(1) 157 | assert bbox.size(0) == N 158 | assert bbox.size(1) == 4 159 | if W is None: H = W 160 | 161 | x0, y0 = bbox[:, 0], bbox[:, 1] 162 | x1, y1 = bbox[:, 2], bbox[:, 3] 163 | ww = x1 - x0 164 | hh = y1 - y0 165 | 166 | x0 = x0.contiguous().view(N, 1).expand(N, H) 167 | x1 = x1.contiguous().view(N, 1).expand(N, H) 168 | ww = ww.view(N, 1).expand(N, H) 169 | 170 | y0 = y0.contiguous().view(N, 1).expand(N, W) 171 | y1 = y1.contiguous().view(N, 1).expand(N, W) 172 | hh = hh.view(N, 1).expand(N, W) 173 | 174 | X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats) 175 | Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats) 176 | 177 | X = (X - x0) / ww 178 | Y = (Y - y0) / hh 179 | 180 | # For ByteTensors, (x + y).clamp(max=1) gives logical_or 181 | X_out_mask = ((X < 0) + (X > 1)).view(N, 1, W).expand(N, H, W) 182 | Y_out_mask = ((Y < 0) + (Y > 1)).view(N, H, 1).expand(N, H, W) 183 | out_mask = (X_out_mask + Y_out_mask).clamp(max=1) 184 | out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W) 185 | 186 | X = X.view(N, 1, W).expand(N, H, W) 187 | Y = Y.view(N, H, 1).expand(N, H, W) 188 | 189 | out = bilinear_sample(feats, X, Y) 190 | out[out_mask] = fill_value 191 | return out 192 | 193 | 194 | def bilinear_sample(feats, X, Y): 195 | """ 196 | Perform bilinear sampling on the features in feats using the sampling grid 197 | given by X and Y. 198 | 199 | Inputs: 200 | - feats: Tensor holding input feature map, of shape (N, C, H, W) 201 | - X, Y: Tensors holding x and y coordinates of the sampling 202 | grids; both have shape shape (N, HH, WW) and have elements in the range [0, 1]. 203 | Returns: 204 | - out: Tensor of shape (B, C, HH, WW) where out[i] is computed 205 | by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]). 
206 | """ 207 | N, C, H, W = feats.size() 208 | assert X.size() == Y.size() 209 | assert X.size(0) == N 210 | _, HH, WW = X.size() 211 | 212 | X = X.mul(W) 213 | Y = Y.mul(H) 214 | 215 | # Get the x and y coordinates for the four samples 216 | x0 = X.floor().clamp(min=0, max=W-1) 217 | x1 = (x0 + 1).clamp(min=0, max=W-1) 218 | y0 = Y.floor().clamp(min=0, max=H-1) 219 | y1 = (y0 + 1).clamp(min=0, max=H-1) 220 | 221 | # In numpy we could do something like feats[i, :, y0, x0] to pull out 222 | # the elements of feats at coordinates y0 and x0, but PyTorch doesn't 223 | # yet support this style of indexing. Instead we have to use the gather 224 | # method, which only allows us to index along one dimension at a time; 225 | # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W) 226 | # and index along the last dimension. Below we generate linear indices into 227 | # the collapsed last dimension for each of the four combinations we need. 228 | y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 229 | y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 230 | y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 231 | y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 232 | 233 | # Actually use gather to pull out the values from feats corresponding 234 | # to our four samples, then reshape them to (BB, C, HH, WW) 235 | feats_flat = feats.view(N, C, H * W) 236 | v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW) 237 | v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW) 238 | v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW) 239 | v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW) 240 | 241 | # Compute the weights for the four samples 242 | w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 243 | w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 244 | w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 245 | w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 246 | 247 | # Multiply the samples by the weights to give our interpolated results. 248 | out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 249 | return out 250 | 251 | 252 | def tensor_linspace(start, end, steps=10): 253 | """ 254 | Vectorized version of torch.linspace. 255 | 256 | Inputs: 257 | - start: Tensor of any shape 258 | - end: Tensor of the same shape as start 259 | - steps: Integer 260 | 261 | Returns: 262 | - out: Tensor of shape start.size() + (steps,), such that 263 | out.select(-1, 0) == start, out.select(-1, -1) == end, 264 | and the other elements of out linearly interpolate between 265 | start and end. 
266 | """ 267 | assert start.size() == end.size() 268 | view_size = start.size() + (1,) 269 | w_size = (1,) * start.dim() + (steps,) 270 | out_size = start.size() + (steps,) 271 | 272 | start_w = torch.linspace(1, 0, steps=steps).to(start) 273 | start_w = start_w.view(w_size).expand(out_size) 274 | end_w = torch.linspace(0, 1, steps=steps).to(start) 275 | end_w = end_w.view(w_size).expand(out_size) 276 | 277 | start = start.contiguous().view(view_size).expand(out_size) 278 | end = end.contiguous().view(view_size).expand(out_size) 279 | 280 | out = start_w * start + end_w * end 281 | return out 282 | 283 | 284 | if __name__ == '__main__': 285 | import numpy as np 286 | from scipy.misc import imread, imsave, imresize 287 | 288 | cat = imresize(imread('cat.jpg'), (256, 256)) 289 | dog = imresize(imread('dog.jpg'), (256, 256)) 290 | feats = torch.stack([ 291 | torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)), 292 | torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))], 293 | dim=0) 294 | 295 | boxes = torch.FloatTensor([ 296 | [0, 0, 1, 1], 297 | [0.25, 0.25, 0.75, 0.75], 298 | [0, 0, 0.5, 0.5], 299 | ]) 300 | 301 | box_to_feats = torch.LongTensor([1, 0, 1]).cuda() 302 | 303 | feats, boxes = feats.cuda(), boxes.cuda() 304 | crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128) 305 | for i in range(crops.size(0)): 306 | crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8) 307 | imsave('out%d.png' % i, crop_np) 308 | -------------------------------------------------------------------------------- /simsg/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .utils import imagenet_preprocess, imagenet_deprocess 18 | from .utils import imagenet_deprocess_batch 19 | -------------------------------------------------------------------------------- /simsg/data/clevr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Azade Farshad 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import os 18 | 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | import torchvision.transforms as T 23 | 24 | import numpy as np 25 | import h5py, json 26 | import PIL 27 | 28 | from .utils import imagenet_preprocess, Resize 29 | sg_task = True 30 | 31 | def conv_src_to_target(voc_s, voc_t): 32 | dic = {} 33 | for k, val in voc_s['object_name_to_idx'].items(): 34 | dic[val] = voc_t['object_name_to_idx'][k] 35 | return dic 36 | 37 | 38 | class SceneGraphWithPairsDataset(Dataset): 39 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256), 40 | normalize_images=True, max_objects=10, max_samples=None, 41 | include_relationships=True, use_orphaned_objects=True, 42 | mode='train', clean_repeats=True): 43 | super(SceneGraphWithPairsDataset, self).__init__() 44 | 45 | assert mode in ["train", "eval", "auto_withfeats", "auto_nofeats", "reposition", "remove", "replace"] 46 | 47 | CLEVR_target_dir = os.path.split(h5_path)[0] 48 | CLEVR_SRC_DIR = os.path.join(os.path.split(CLEVR_target_dir)[0], 'source') 49 | 50 | vocab_json_s = os.path.join(CLEVR_SRC_DIR, "vocab.json") 51 | vocab_json_t = os.path.join(CLEVR_target_dir, "vocab.json") 52 | 53 | with open(vocab_json_s, 'r') as f: 54 | vocab_src = json.load(f) 55 | 56 | with open(vocab_json_t, 'r') as f: 57 | vocab_t = json.load(f) 58 | 59 | self.mode = mode 60 | 61 | self.image_dir = image_dir 62 | self.image_source_dir = os.path.join(os.path.split(image_dir)[0], 'source') #Azade 63 | 64 | src_h5_path = os.path.join(self.image_source_dir, os.path.split(h5_path)[-1]) 65 | print(self.image_dir, src_h5_path) 66 | 67 | self.image_size = image_size 68 | self.vocab = vocab 69 | self.vocab_src = vocab_src 70 | self.vocab_t = vocab_t 71 | self.num_objects = len(vocab['object_idx_to_name']) 72 | self.use_orphaned_objects = use_orphaned_objects 73 | self.max_objects = max_objects 74 | self.max_samples = max_samples 75 | self.include_relationships = include_relationships 76 | 77 | self.evaluating = mode != 'train' 78 | 79 | self.clean_repeats = clean_repeats 80 | 81 | transform = [Resize(image_size), T.ToTensor()] 82 | if normalize_images: 83 | transform.append(imagenet_preprocess()) 84 | self.transform = T.Compose(transform) 85 | 86 | self.data = {} 87 | with h5py.File(h5_path, 'r') as f: 88 | for k, v in f.items(): 89 | if k == 'image_paths': 90 | self.image_paths = list(v) 91 | else: 92 | self.data[k] = torch.IntTensor(np.asarray(v)) 93 | 94 | self.data_src = {} 95 | with h5py.File(src_h5_path, 'r') as f: 96 | for k, v in f.items(): 97 | if k == 'image_paths': 98 | self.image_paths_src = list(v) 99 | else: 100 | self.data_src[k] = torch.IntTensor(np.asarray(v)) 101 | 102 | def __len__(self): 103 | num = self.data['object_names'].size(0) 104 | if self.max_samples is not None: 105 | return min(self.max_samples, num) 106 | return num 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Returns a tuple of: 111 | - image: FloatTensor of shape (C, H, W) 112 | - objs: LongTensor of shape (num_objs,) 113 | - boxes: FloatTensor of shape (num_objs, 4) giving boxes for objects in 114 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system. 115 | - triples: LongTensor of shape (num_triples, 3) where triples[t] = [i, p, j] 116 | means that (objs[i], p, objs[j]) is a triple. 
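In practice the dataset is wrapped in a DataLoader together with collate_fn_withpairs, defined at the end of this file. A hedged sketch with hypothetical paths, assuming the preprocessed CLEVR layout with sibling source/ and target/ directories:

```python
import json
from torch.utils.data import DataLoader
from simsg.data.clevr import SceneGraphWithPairsDataset, collate_fn_withpairs

with open('datasets/clevr/target/vocab.json') as f:        # hypothetical path
    vocab = json.load(f)

dset = SceneGraphWithPairsDataset(
    vocab,
    h5_path='datasets/clevr/target/train.h5',              # hypothetical path
    image_dir='datasets/clevr/target',                     # hypothetical path
    image_size=(64, 64),
    mode='train')
loader = DataLoader(dset, batch_size=8, collate_fn=collate_fn_withpairs)

(imgs, imgs_src, objs, objs_src, boxes, boxes_src,
 triples, triples_src, obj_to_img, triple_to_img, imgs_masked) = next(iter(loader))
```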
117 | """ 118 | img_path = os.path.join(self.image_dir, self.image_paths[index]) 119 | img_source_path = os.path.join(self.image_source_dir, self.image_paths[index]) 120 | 121 | src_to_target_obj = conv_src_to_target(self.vocab_src, self.vocab_t) 122 | 123 | with open(img_path, 'rb') as f: 124 | with PIL.Image.open(f) as image: 125 | WW, HH = image.size 126 | image = self.transform(image.convert('RGB')) 127 | 128 | with open(img_source_path, 'rb') as f: 129 | with PIL.Image.open(f) as image_src: 130 | #WW, HH = image.size 131 | image_src = self.transform(image_src.convert('RGB')) 132 | 133 | H, W = self.image_size 134 | 135 | # Figure out which objects appear in relationships and which don't 136 | obj_idxs_with_rels = set() 137 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 138 | for r_idx in range(self.data['relationships_per_image'][index]): 139 | s = self.data['relationship_subjects'][index, r_idx].item() 140 | o = self.data['relationship_objects'][index, r_idx].item() 141 | obj_idxs_with_rels.add(s) 142 | obj_idxs_with_rels.add(o) 143 | obj_idxs_without_rels.discard(s) 144 | obj_idxs_without_rels.discard(o) 145 | 146 | obj_idxs = list(obj_idxs_with_rels) 147 | obj_idxs_without_rels = list(obj_idxs_without_rels) 148 | if len(obj_idxs) > self.max_objects - 1: 149 | obj_idxs = obj_idxs[:self.max_objects] 150 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 151 | num_to_add = self.max_objects - 1 - len(obj_idxs) 152 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 153 | obj_idxs += obj_idxs_without_rels[:num_to_add] 154 | 155 | num_objs = len(obj_idxs) + 1 156 | 157 | objs = torch.LongTensor(num_objs).fill_(-1) 158 | 159 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(num_objs, 1) 160 | obj_idx_mapping = {} 161 | for i, obj_idx in enumerate(obj_idxs): 162 | objs[i] = self.data['object_names'][index, obj_idx].item() 163 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 164 | x0 = float(x) / WW 165 | y0 = float(y) / HH 166 | x1 = float(x + w) / WW 167 | y1 = float(y + h) / HH 168 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1]) 169 | obj_idx_mapping[obj_idx] = i 170 | 171 | # The last object will be the special __image__ object 172 | objs[num_objs - 1] = self.vocab['object_name_to_idx']['__image__'] 173 | 174 | triples = [] 175 | for r_idx in range(self.data['relationships_per_image'][index].item()): 176 | if not self.include_relationships: 177 | break 178 | s = self.data['relationship_subjects'][index, r_idx].item() 179 | p = self.data['relationship_predicates'][index, r_idx].item() 180 | o = self.data['relationship_objects'][index, r_idx].item() 181 | s = obj_idx_mapping.get(s, None) 182 | o = obj_idx_mapping.get(o, None) 183 | if s is not None and o is not None: 184 | if self.clean_repeats and [s, p, o] in triples: 185 | continue 186 | triples.append([s, p, o]) 187 | 188 | # Add dummy __in_image__ relationships for all objects 189 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 190 | for i in range(num_objs - 1): 191 | triples.append([i, in_image, num_objs - 1]) 192 | 193 | triples = torch.LongTensor(triples) 194 | 195 | #Source image 196 | 197 | # Figure out which objects appear in relationships and which don't 198 | obj_idxs_with_rels_src = set() 199 | obj_idxs_without_rels_src = set(range(self.data_src['objects_per_image'][index].item())) 200 | for r_idx in range(self.data_src['relationships_per_image'][index]): 201 | s = self.data_src['relationship_subjects'][index, r_idx].item() 202 | 
o = self.data_src['relationship_objects'][index, r_idx].item() 203 | obj_idxs_with_rels_src.add(s) 204 | obj_idxs_with_rels_src.add(o) 205 | obj_idxs_without_rels_src.discard(s) 206 | obj_idxs_without_rels_src.discard(o) 207 | 208 | obj_idxs_src = list(obj_idxs_with_rels_src) 209 | obj_idxs_without_rels_src = list(obj_idxs_without_rels_src) 210 | if len(obj_idxs_src) > self.max_objects - 1: 211 | obj_idxs_src = obj_idxs_src[:self.max_objects] 212 | if len(obj_idxs_src) < self.max_objects - 1 and self.use_orphaned_objects: 213 | num_to_add = self.max_objects - 1 - len(obj_idxs_src) 214 | num_to_add = min(num_to_add, len(obj_idxs_without_rels_src)) 215 | obj_idxs_src += obj_idxs_without_rels_src[:num_to_add] 216 | 217 | num_objs_src = len(obj_idxs_src) + 1 218 | 219 | objs_src = torch.LongTensor(num_objs_src).fill_(-1) 220 | 221 | boxes_src = torch.FloatTensor([[0, 0, 1, 1]]).repeat(num_objs_src, 1) 222 | obj_idx_mapping_src = {} 223 | for i, obj_idx in enumerate(obj_idxs_src): 224 | objs_src[i] = src_to_target_obj[self.data_src['object_names'][index, obj_idx].item()] 225 | x, y, w, h = self.data_src['object_boxes'][index, obj_idx].tolist() 226 | x0 = float(x) / WW 227 | y0 = float(y) / HH 228 | x1 = float(x + w) / WW 229 | y1 = float(y + h) / HH 230 | boxes_src[i] = torch.FloatTensor([x0, y0, x1, y1]) 231 | obj_idx_mapping_src[obj_idx] = i 232 | 233 | # The last object will be the special __image__ object 234 | objs_src[num_objs_src - 1] = self.vocab_src['object_name_to_idx']['__image__'] 235 | 236 | triples_src = [] 237 | for r_idx in range(self.data_src['relationships_per_image'][index].item()): 238 | if not self.include_relationships: 239 | break 240 | s = self.data_src['relationship_subjects'][index, r_idx].item() 241 | p = self.data_src['relationship_predicates'][index, r_idx].item() 242 | o = self.data_src['relationship_objects'][index, r_idx].item() 243 | s = obj_idx_mapping_src.get(s, None) 244 | o = obj_idx_mapping_src.get(o, None) 245 | if s is not None and o is not None: 246 | if self.clean_repeats and [s, p, o] in triples_src: 247 | continue 248 | triples_src.append([s, p, o]) 249 | 250 | # Add dummy __in_image__ relationships for all objects 251 | in_image = self.vocab_src['pred_name_to_idx']['__in_image__'] 252 | for i in range(num_objs_src - 1): 253 | triples_src.append([i, in_image, num_objs_src - 1]) 254 | 255 | triples_src = torch.LongTensor(triples_src) 256 | 257 | return image, image_src, objs, objs_src, boxes, boxes_src, triples, triples_src 258 | 259 | 260 | def collate_fn_withpairs(batch): 261 | """ 262 | Collate function to be used when wrapping a SceneGraphWithPairsDataset in a 263 | DataLoader. 
Returns a tuple of the following: 264 | 265 | - imgs, imgs_src: target and source FloatTensors of shape (N, C, H, W) 266 | - objs, objs_src: target and source LongTensors of shape (num_objs,) giving categories for all objects 267 | - boxes, boxes_src: target and source FloatTensors of shape (num_objs, 4) giving boxes for all objects 268 | - triples, triples_src: target and source FloatTensors of shape (num_triples, 3) giving all triples, where 269 | triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple 270 | - obj_to_img: LongTensor of shape (num_objs,) mapping objects to images; 271 | obj_to_img[i] = n means that objs[i] belongs to imgs[n] 272 | - triple_to_img: LongTensor of shape (num_triples,) mapping triples to images; 273 | triple_to_img[t] = n means that triples[t] belongs to imgs[n] 274 | - imgs_masked: FloatTensor of shape (N, 4, H, W) 275 | """ 276 | # batch is a list, and each element is (image, objs, boxes, triples) 277 | all_imgs, all_imgs_src, all_objs, all_objs_src, all_boxes, all_boxes_src, all_triples, all_triples_src = [], [], [], [], [], [], [], [] 278 | all_obj_to_img, all_triple_to_img = [], [] 279 | all_imgs_masked = [] 280 | 281 | obj_offset = 0 282 | 283 | for i, (img, image_src, objs, objs_src, boxes, boxes_src, triples, triples_src) in enumerate(batch): 284 | 285 | all_imgs.append(img[None]) 286 | all_imgs_src.append(image_src[None]) 287 | num_objs, num_triples = objs.size(0), triples.size(0) 288 | all_objs.append(objs) 289 | all_objs_src.append(objs_src) 290 | all_boxes.append(boxes) 291 | all_boxes_src.append(boxes_src) 292 | triples = triples.clone() 293 | triples_src = triples_src.clone() 294 | 295 | triples[:, 0] += obj_offset 296 | triples[:, 2] += obj_offset 297 | all_triples.append(triples) 298 | 299 | triples_src[:, 0] += obj_offset 300 | triples_src[:, 2] += obj_offset 301 | all_triples_src.append(triples_src) 302 | 303 | all_obj_to_img.append(torch.LongTensor(num_objs).fill_(i)) 304 | all_triple_to_img.append(torch.LongTensor(num_triples).fill_(i)) 305 | 306 | # prepare input 4-channel image 307 | # initialize mask channel with zeros 308 | masked_img = image_src.clone() 309 | mask = torch.zeros_like(masked_img) 310 | mask = mask[0:1,:,:] 311 | masked_img = torch.cat([masked_img, mask], 0) 312 | all_imgs_masked.append(masked_img[None]) 313 | 314 | obj_offset += num_objs 315 | 316 | all_imgs_masked = torch.cat(all_imgs_masked) 317 | 318 | all_imgs = torch.cat(all_imgs) 319 | all_imgs_src = torch.cat(all_imgs_src) 320 | all_objs = torch.cat(all_objs) 321 | all_objs_src = torch.cat(all_objs_src) 322 | all_boxes = torch.cat(all_boxes) 323 | all_boxes_src = torch.cat(all_boxes_src) 324 | all_triples = torch.cat(all_triples) 325 | all_triples_src = torch.cat(all_triples_src) 326 | all_obj_to_img = torch.cat(all_obj_to_img) 327 | all_triple_to_img = torch.cat(all_triple_to_img) 328 | 329 | out = (all_imgs, all_imgs_src, all_objs, all_objs_src, all_boxes, all_boxes_src, all_triples, all_triples_src, 330 | all_obj_to_img, all_triple_to_img, all_imgs_masked) 331 | 332 | return out 333 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/1_gen_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2020 Azade Farshad 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | mode="$1" #removal, replacement, addition, rel_change 18 | blender_path="$2" #"/home/azadef/Downloads/blender-2.79-cuda10-x86_64" 19 | merge="$3" 20 | if [ $merge ] 21 | then 22 | out_folder="./output"; 23 | else 24 | out_folder="./output_$mode"; 25 | fi 26 | for i in {1..6} 27 | do 28 | if [ $merge ] 29 | then 30 | start=`expr $(ls -1 "$out_folder/images/" | wc -l)`; 31 | else 32 | start=`expr $(ls -1 "$out_folder/../output_rel_change/images/" | wc -l) + $(ls -1 "$out_folder/../output_remove/images/" | wc -l) + $(ls -1 "$out_folder/../output_replacement/images/" | wc -l) + $(ls -1 "$out_folder/../output_addition/images/" | wc -l)`; 33 | start=$((start/2)); 34 | fi 35 | echo $start 36 | "$blender_path"/blender --background --python render_clevr.py -- --num_images 800 --output_image_dir "$out_folder/images/" --output_scene_dir "$out_folder/scenes/" --output_scene_file "$out_folder/CLEVR_scenes.json" --start_idx $start --use_gpu 1 --mode "$mode" 37 | done 38 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/2_arrange_data.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Azade Farshad 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | num=0; 16 | res_dir='./MyClevr'; 17 | for j in */; do 18 | find $(pwd)/$j"images" -iname "*_target.png" | sort -u | while read p; do 19 | cp $p $res_dir"/target/images/$num.png"; 20 | q=$(pwd)/$j"scenes/${p##*/}" 21 | q="${q%%.*}.json" 22 | cp $q $res_dir"/target/scenes/$num.json"; 23 | ((num++)); 24 | echo $p; 25 | echo $q; 26 | echo $num; 27 | done 28 | num=`expr $(ls -1 "$res_dir/target/images/" | wc -l)`; 29 | done 30 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/3_clevrToVG.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Azade Farshad 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | Stores the generated CLEVR data in the same format as Visual Genome, for easy loading 19 | """ 20 | 21 | import math, sys, random, argparse, json, os 22 | import numpy as np 23 | #import cv2 24 | import itertools 25 | 26 | 27 | def get_rel(argument): 28 | switcher = { 29 | 3: "left of", 30 | 2: "behind", 31 | 1: "right of", 32 | 0: "front of" 33 | } 34 | return switcher.get(argument, "Invalid label") 35 | 36 | def get_xy_coords(obj): 37 | obj_x, obj_y, obj_w, obj_h = obj['bbox'] 38 | 39 | return obj_x, obj_y, obj_w, obj_h 40 | 41 | def getLabel(obj, subj): 42 | diff = np.subtract(obj[:-1], subj[:-1]) 43 | label = int(np.argmax(np.abs(diff))) 44 | if diff[label] > 0: 45 | label += 2 46 | return label 47 | 48 | 49 | np.random.seed(seed=0) 50 | parser = argparse.ArgumentParser() 51 | 52 | # Input data 53 | parser.add_argument('--data_path', default='/media/azadef/MyHDD/Data/MyClevr_postcvpr/source') 54 | 55 | args = parser.parse_args() 56 | 57 | data_path = args.data_path 58 | im_names = sorted(os.listdir(os.path.join(data_path,'images'))) 59 | json_names = sorted(os.listdir(os.path.join(data_path,'scenes'))) 60 | 61 | labels_file_path = os.path.join(data_path, 'image_data.json') 62 | objs_file_path = os.path.join(data_path, 'objects.json') 63 | rels_file_path = os.path.join(data_path, 'relationships.json') 64 | attrs_file_path = os.path.join(data_path, 'attributes.json') 65 | split_file_path = os.path.join(data_path, 'vg_splits.json') 66 | 67 | rel_alias_path = os.path.join(data_path, 'relationship_alias.txt') 68 | obj_alias_path = os.path.join(data_path, 'object_alias.txt') 69 | 70 | im_id = 0 71 | 72 | num_objs = 0 73 | num_rels = 0 74 | obj_data = [] 75 | rel_data = [] 76 | scenedata = [] 77 | attr_data = [] 78 | 79 | obj_data_file = open(objs_file_path, 'w') 80 | rel_data_file = open(rels_file_path, 'w') 81 | attr_data_file = open(attrs_file_path, 'w') 82 | split_data_file = open(split_file_path, 'w') 83 | 84 | with open(rel_alias_path, 'w'): 85 | print("Created relationship alias file!") 86 | 87 | with open(obj_alias_path, 'w'): 88 | print("Created object alias file!") 89 | 90 | num_ims = len(json_names) 91 | 92 | split_dic = {} 93 | rand_perm = np.random.choice(num_ims, num_ims, replace=False) 94 | train_split = int(num_ims * 0.8) 95 | val_split = int(num_ims * 0.9) 96 | 97 | train_list = [int(num) for num in rand_perm[0:train_split]] 98 | val_list = [int(num) for num in rand_perm[train_split+1:val_split]] 99 | test_list = [int(num) for num in rand_perm[val_split+1:num_ims]] 100 | 101 | split_dic['train'] = train_list 102 | split_dic['val'] = val_list 103 | split_dic['test'] = test_list 104 | 105 | json.dump(split_dic, split_data_file, indent=2) 106 | split_data_file.close() 107 | 108 | with open(labels_file_path, 'w') as labels_file: 109 | for json_file in json_names: 110 | file_name = os.path.join(data_path, 'scenes',json_file) 111 | new_obj = {} 112 | new_rel = {} 113 | new_scene = {} 114 | new_attr = {} 115 | 116 | with open(file_name, 'r') as f: 117 | properties = json.load(f) 118 | 119 | new_obj['image_id'] = im_id 120 | new_rel['image_id'] = im_id 121 | new_scene['image_id'] = im_id 122 | new_attr['image_id'] = im_id 123 | 124 | rels = [] 125 | objs = properties["objects"] 126 | attr_objs = [] 127 | 128 | for j, obj in enumerate(objs): 129 | obj['object_id'] = num_objs + j 130 | obj['x'], obj['y'], obj['w'], obj['h'] = get_xy_coords(obj) 131 | 
obj.pop('bbox') 132 | 133 | obj['names'] = [obj["color"] + " " + obj['shape']] 134 | obj.pop('shape') 135 | obj.pop('rotation') 136 | attr_obj = obj 137 | attr_obj["attributes"] = [obj["color"]] 138 | obj.pop('size') 139 | attr_objs.append(attr_obj) 140 | 141 | 142 | new_obj['objects'] = objs 143 | new_scene['objects'] = objs 144 | new_attr['attributes'] = attr_objs 145 | 146 | pairs = list(itertools.combinations(objs, 2)) 147 | indices = list((i,j) for ((i,_),(j,_)) in itertools.combinations(enumerate(objs), 2)) 148 | for ii, (obj, subj) in enumerate(pairs): 149 | label = getLabel(obj['3d_coords'], subj['3d_coords']) 150 | predicate = get_rel(label) 151 | if predicate == 'Invalid label': 152 | print("Invalid label!") 153 | rel = {} 154 | rel['predicate'] = predicate 155 | rel['name'] = predicate 156 | rel['object'] = obj 157 | rel['subject'] = subj 158 | rel['relationship_id'] = num_rels + ii 159 | rel['object_id'] = obj['object_id'] 160 | rel['subject_id'] = subj['object_id'] 161 | rels.append(rel) 162 | 163 | new_rel['relationships'] = rels 164 | new_scene['relationships'] = rels 165 | im_path = os.path.join(data_path,'images',json_file.strip('.json') + '.png') 166 | 167 | 168 | new_obj['url'] = im_path 169 | 170 | new_scene['url'] = im_path 171 | new_scene['width'] = 320 172 | new_scene['height'] = 240 173 | im_id += 1 174 | num_objs += len(objs) 175 | num_rels += len(rels) 176 | 177 | scenedata.append(new_scene) 178 | obj_data.append(new_obj) 179 | rel_data.append(new_rel) 180 | attr_data.append(new_attr) 181 | 182 | json.dump(scenedata, labels_file, indent=2) 183 | 184 | json.dump(obj_data, obj_data_file, indent=2) 185 | json.dump(rel_data, rel_data_file, indent=2) 186 | json.dump(attr_data, attr_data_file, indent=2) 187 | 188 | obj_data_file.close() 189 | rel_data_file.close() 190 | attr_data_file.close() 191 | print("Done processing!") 192 | 193 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/README_CLEVR.md: -------------------------------------------------------------------------------- 1 | ## Setup 2 | You will need to download and install [Blender](https://www.blender.org/); the code has been developed and tested with Blender version 2.78c, but other versions may work as well. 3 | 4 | ## Data Generation 5 | Run 1_gen_data.sh to generate source/target pairs for each of the available modes. The first argument is the change mode, the second argument is the path to the Blender directory, and the third argument is the option to merge all of the modes into one folder. The manipulation modes are: removal, replacement, addition, rel_change 6 | Sample command: 7 | ```bash 8 | sh 1_gen_data.sh removal /home/user/blender2.79-cuda10-x86_64/ 1 9 | ``` 10 | Run this script for each of the modes to generate all variations of change. 11 | 12 | ## Merging all the folders 13 | If you set the merge argument to 1, you can skip this step. 14 | 15 | Otherwise, move all of the generated images for each mode into a single folder, so that all images are in the "images" folder and all scenes in the "scenes" folder. 16 | 17 | ## Arranging the generated images into separate folders 18 | To arrange the data, run 2_arrange_data.sh. The script will move the previously generated data to the MyClevr directory in the required format. 19 | 20 | ```bash 21 | sh 2_arrange_data.sh 22 | ``` 23 | 24 | ## Converting CLEVR format to VG format 25 | This step converts the generated data to the format required by SIMSG, similar to the Visual Genome dataset.
The final scene graphs and the required files are generated in this step. 26 | 27 | ```bash 28 | python 3_clevrToVG.py 29 | ``` 30 | 31 | ## Preprocessing 32 | 33 | Run scripts/preprocess_vg.py on the generated CLEVR data for both source and target folders to preprocess the data for SIMSG. Set VG_DIR inside the file to point to CLEVR source or target directory. 34 | 35 | To make sure no image is removed by the preprocess filtering step so that we have corresponding source/target pairs, set the following arguments: 36 | 37 | ```bash 38 | python scripts/preprocess_vg.py --min_object_instances 0 --min_attribute_instances 0 --min_object_size 0 --min_objects_per_image 1 --min_relationship_instances 1 --max_relationships_per_image 50 39 | ``` 40 | ## Acknowledgment 41 | 42 | Our CLEVR generation code is based on this repo. 43 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/collect_scenes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. An additional grant 6 | # of patent rights can be found in the PATENTS file in the same directory. 7 | 8 | import argparse, json, os 9 | 10 | """ 11 | During rendering, each CLEVR scene file is dumped to disk as a separate JSON 12 | file; this is convenient for distributing rendering across multiple machines. 13 | This script collects all CLEVR scene files stored in a directory and combines 14 | them into a single JSON file. This script also adds the version number, date, 15 | and license to the output file. 16 | """ 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--input_dir', default='output/scenes') 20 | parser.add_argument('--output_file', default='output/CLEVR_misc_scenes.json') 21 | parser.add_argument('--version', default='1.0') 22 | parser.add_argument('--date', default='7/8/2017') 23 | parser.add_argument('--license', 24 | default='Creative Commons Attribution (CC-BY 4.0') 25 | 26 | 27 | def main(args): 28 | input_files = os.listdir(args.input_dir) 29 | scenes = [] 30 | split = None 31 | for filename in os.listdir(args.input_dir): 32 | if not filename.endswith('.json'): 33 | continue 34 | path = os.path.join(args.input_dir, filename) 35 | with open(path, 'r') as f: 36 | scene = json.load(f) 37 | scenes.append(scene) 38 | if split is not None: 39 | msg = 'Input directory contains scenes from multiple splits' 40 | assert scene['split'] == split, msg 41 | else: 42 | split = scene['split'] 43 | scenes.sort(key=lambda s: s['image_index']) 44 | for s in scenes: 45 | print(s['image_filename']) 46 | output = { 47 | 'info': { 48 | 'date': args.date, 49 | 'version': args.version, 50 | 'split': split, 51 | 'license': args.license, 52 | }, 53 | 'scenes': scenes 54 | } 55 | with open(args.output_file, 'w') as f: 56 | json.dump(output, f) 57 | 58 | 59 | if __name__ == '__main__': 60 | args = parser.parse_args() 61 | main(args) 62 | 63 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/CoGenT_A.json: -------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "gray", "blue", "brown", "yellow" 4 | ], 5 | "cylinder": [ 6 | "red", "green", "purple", "cyan" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", 
"yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/CoGenT_B.json: -------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "red", "green", "purple", "cyan" 4 | ], 5 | "cylinder": [ 6 | "gray", "blue", "brown", "yellow" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/base_scene.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/base_scene.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/materials/MyMetal.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/materials/MyMetal.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/materials/Rubber.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/materials/Rubber.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "shapes": { 3 | "cube": "SmoothCube_v2", 4 | "sphere": "Sphere", 5 | "cylinder": "SmoothCylinder" 6 | }, 7 | "colors": { 8 | "gray": [87, 87, 87], 9 | "red": [173, 35, 35], 10 | "blue": [42, 75, 215], 11 | "green": [29, 105, 20], 12 | "brown": [129, 74, 25], 13 | "purple": [129, 38, 192], 14 | "cyan": [41, 208, 208], 15 | "yellow": [255, 238, 51] 16 | }, 17 | "materials": { 18 | "rubber": "Rubber", 19 | "metal": "MyMetal" 20 | }, 21 | "sizes": { 22 | "large": 0.7, 23 | "small": 0.35 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/shapes/SmoothCube_v2.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/SmoothCube_v2.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/shapes/SmoothCylinder.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/SmoothCylinder.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/data/shapes/Sphere.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-dhamo/simsg/a1decd989a53a329c82eacf8124c597f8f2bc308/simsg/data/clevr_gen/data/shapes/Sphere.blend -------------------------------------------------------------------------------- /simsg/data/clevr_gen/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. An additional grant 6 | # of patent rights can be found in the PATENTS file in the same directory. 7 | 8 | import sys, random, os 9 | import bpy, bpy_extras 10 | 11 | 12 | """ 13 | Some utility functions for interacting with Blender 14 | """ 15 | 16 | 17 | def extract_args(input_argv=None): 18 | """ 19 | Pull out command-line arguments after "--". Blender ignores command-line flags 20 | after --, so this lets us forward command line arguments from the blender 21 | invocation to our own script. 22 | """ 23 | if input_argv is None: 24 | input_argv = sys.argv 25 | output_argv = [] 26 | if '--' in input_argv: 27 | idx = input_argv.index('--') 28 | output_argv = input_argv[(idx + 1):] 29 | return output_argv 30 | 31 | 32 | def parse_args(parser, argv=None): 33 | return parser.parse_args(extract_args(argv)) 34 | 35 | 36 | # I wonder if there's a better way to do this? 37 | def delete_object(obj): 38 | """ Delete a specified blender object """ 39 | for o in bpy.data.objects: 40 | o.select = False 41 | obj.select = True 42 | bpy.ops.object.delete() 43 | 44 | 45 | def get_camera_coords(cam, pos): 46 | """ 47 | For a specified point, get both the 3D coordinates and 2D pixel-space 48 | coordinates of the point from the perspective of the camera. 49 | 50 | Inputs: 51 | - cam: Camera object 52 | - pos: Vector giving 3D world-space position 53 | 54 | Returns a tuple of: 55 | - (px, py, pz): px and py give 2D image-space coordinates; pz gives depth 56 | in the range [-1, 1] 57 | """ 58 | scene = bpy.context.scene 59 | x, y, z = bpy_extras.object_utils.world_to_camera_view(scene, cam, pos) 60 | scale = scene.render.resolution_percentage / 100.0 61 | w = int(scale * scene.render.resolution_x) 62 | h = int(scale * scene.render.resolution_y) 63 | px = int(round(x * w)) 64 | py = int(round(h - y * h)) 65 | return (px, py, z) 66 | 67 | 68 | def set_layer(obj, layer_idx): 69 | """ Move an object to a particular layer """ 70 | # Set the target layer to True first because an object must always be on 71 | # at least one layer. 72 | obj.layers[layer_idx] = True 73 | for i in range(len(obj.layers)): 74 | obj.layers[i] = (i == layer_idx) 75 | 76 | 77 | def add_object(object_dir, name, scale, loc, theta=0): 78 | """ 79 | Load an object from a file. We assume that in the directory object_dir, there 80 | is a file named "$name.blend" which contains a single object named "$name" 81 | that has unit size and is centered at the origin. 82 | 83 | - scale: scalar giving the size that the object should be in the scene 84 | - loc: tuple (x, y) giving the coordinates on the ground plane where the 85 | object should be placed. 
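A hypothetical call of the helper defined here, only meaningful inside Blender's bundled Python once the shape .blend files from data/shapes have been set up:

```python
# Place a large cube (scale 0.7, the "large" size in data/properties.json)
# at ground-plane position (1.0, -1.5), rotated by 0.5 radians.
add_object('data/shapes', 'SmoothCube_v2', scale=0.7, loc=(1.0, -1.5), theta=0.5)
```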
86 | """ 87 | # First figure out how many of this object are already in the scene so we can 88 | # give the new object a unique name 89 | count = 0 90 | for obj in bpy.data.objects: 91 | if obj.name.startswith(name): 92 | count += 1 93 | 94 | filename = os.path.join(object_dir, '%s.blend' % name, 'Object', name) 95 | bpy.ops.wm.append(filename=filename) 96 | 97 | # Give it a new name to avoid conflicts 98 | new_name = '%s_%d' % (name, count) 99 | bpy.data.objects[name].name = new_name 100 | 101 | # Set the new object as active, then rotate, scale, and translate it 102 | x, y = loc 103 | bpy.context.scene.objects.active = bpy.data.objects[new_name] 104 | bpy.context.object.rotation_euler[2] = theta 105 | bpy.ops.transform.resize(value=(scale, scale, scale)) 106 | bpy.ops.transform.translate(value=(x, y, scale)) 107 | 108 | 109 | def load_materials(material_dir): 110 | """ 111 | Load materials from a directory. We assume that the directory contains .blend 112 | files with one material each. The file X.blend has a single NodeTree item named 113 | X; this NodeTree item must have a "Color" input that accepts an RGBA value. 114 | """ 115 | for fn in os.listdir(material_dir): 116 | if not fn.endswith('.blend'): continue 117 | name = os.path.splitext(fn)[0] 118 | filepath = os.path.join(material_dir, fn, 'NodeTree', name) 119 | bpy.ops.wm.append(filename=filepath) 120 | 121 | 122 | def add_material(name, relchange=False, **properties): 123 | """ 124 | Create a new material and assign it to the active object. "name" should be the 125 | name of a material that has been previously loaded using load_materials. 126 | """ 127 | # Figure out how many materials are already in the scene 128 | mat_count = len(bpy.data.materials) 129 | 130 | # Create a new material; it is not attached to anything and 131 | # it will be called "Material" 132 | bpy.ops.material.new() 133 | 134 | # Get a reference to the material we just created and rename it; 135 | # then the next time we make a new material it will still be called 136 | # "Material" and we will still be able to look it up by name 137 | mat = bpy.data.materials['Material'] 138 | mat.name = 'Material_%d' % mat_count 139 | 140 | # Attach the new material to the active object 141 | # Make sure it doesn't already have materials 142 | obj = bpy.context.active_object 143 | if not relchange: 144 | assert len(obj.data.materials) == 0 145 | obj.data.materials.append(mat) 146 | 147 | # Find the output node of the new material 148 | output_node = None 149 | for n in mat.node_tree.nodes: 150 | if n.name == 'Material Output': 151 | output_node = n 152 | break 153 | 154 | # Add a new GroupNode to the node tree of the active material, 155 | # and copy the node tree from the preloaded node group to the 156 | # new group node. 
This copying seems to happen by-value, so 157 | # we can create multiple materials of the same type without them 158 | # clobbering each other 159 | group_node = mat.node_tree.nodes.new('ShaderNodeGroup') 160 | group_node.node_tree = bpy.data.node_groups[name] 161 | 162 | # Find and set the "Color" input of the new group node 163 | for inp in group_node.inputs: 164 | if inp.name in properties: 165 | inp.default_value = properties[inp.name] 166 | 167 | # Wire the output of the new group node to the input of 168 | # the MaterialOutput node 169 | mat.node_tree.links.new( 170 | group_node.outputs['Shader'], 171 | output_node.inputs['Surface'], 172 | ) 173 | 174 | -------------------------------------------------------------------------------- /simsg/data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import PIL 18 | import torch 19 | import torchvision.transforms as T 20 | 21 | 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 23 | IMAGENET_STD = [0.229, 0.224, 0.225] 24 | 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 27 | 28 | 29 | def imagenet_preprocess(): 30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | 32 | 33 | def rescale(x): 34 | lo, hi = x.min(), x.max() 35 | return x.sub(lo).div(hi - lo) 36 | 37 | 38 | def imagenet_deprocess(rescale_image=True): 39 | transforms = [ 40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 42 | ] 43 | if rescale_image: 44 | transforms.append(rescale) 45 | return T.Compose(transforms) 46 | 47 | 48 | def imagenet_deprocess_batch(imgs, rescale=True): 49 | """ 50 | Input: 51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 52 | 53 | Output: 54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images 55 | in the range [0, 255] 56 | """ 57 | if isinstance(imgs, torch.autograd.Variable): 58 | imgs = imgs.data 59 | imgs = imgs.cpu().clone() 60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 61 | imgs_de = [] 62 | for i in range(imgs.size(0)): 63 | img_de = deprocess_fn(imgs[i])[None] 64 | img_de = img_de.mul(255).clamp(0, 255).byte() 65 | imgs_de.append(img_de) 66 | imgs_de = torch.cat(imgs_de, dim=0) 67 | return imgs_de 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, size, interp=PIL.Image.BILINEAR): 72 | if isinstance(size, tuple): 73 | H, W = size 74 | self.size = (W, H) 75 | else: 76 | self.size = (size, size) 77 | self.interp = interp 78 | 79 | def __call__(self, img): 80 | return img.resize(self.size, self.interp) 81 | 82 | 83 | def unpack_var(v): 84 | if isinstance(v, torch.autograd.Variable): 85 | return v.data 86 | return v 87 | 88 | 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 90 | triples = unpack_var(triples) 91 | obj_data = 
[unpack_var(o) for o in obj_data] 92 | obj_to_img = unpack_var(obj_to_img) 93 | triple_to_img = unpack_var(triple_to_img) 94 | 95 | triples_out = [] 96 | obj_data_out = [[] for _ in obj_data] 97 | obj_offset = 0 98 | N = obj_to_img.max() + 1 99 | for i in range(N): 100 | o_idxs = (obj_to_img == i).nonzero().view(-1) 101 | t_idxs = (triple_to_img == i).nonzero().view(-1) 102 | 103 | cur_triples = triples[t_idxs].clone() 104 | cur_triples[:, 0] -= obj_offset 105 | cur_triples[:, 2] -= obj_offset 106 | triples_out.append(cur_triples) 107 | 108 | for j, o_data in enumerate(obj_data): 109 | cur_o_data = None 110 | if o_data is not None: 111 | cur_o_data = o_data[o_idxs] 112 | obj_data_out[j].append(cur_o_data) 113 | 114 | obj_offset += o_idxs.size(0) 115 | 116 | return triples_out, obj_data_out 117 | 118 | -------------------------------------------------------------------------------- /simsg/data/vg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # Modification copyright 2020 Helisa Dhamo 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import os 19 | import random 20 | from collections import defaultdict 21 | 22 | import torch 23 | from torch.utils.data import Dataset 24 | 25 | import torchvision.transforms as T 26 | 27 | import numpy as np 28 | import h5py 29 | import PIL 30 | 31 | from .utils import imagenet_preprocess, Resize 32 | 33 | 34 | class SceneGraphNoPairsDataset(Dataset): 35 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256), 36 | normalize_images=True, max_objects=10, max_samples=None, 37 | include_relationships=True, use_orphaned_objects=True, 38 | mode='train', clean_repeats=True, predgraphs=False): 39 | super(SceneGraphNoPairsDataset, self).__init__() 40 | 41 | assert mode in ["train", "eval", "auto_withfeats", "auto_nofeats", "reposition", "remove", "replace"] 42 | 43 | self.mode = mode 44 | 45 | self.image_dir = image_dir 46 | self.image_size = image_size 47 | self.vocab = vocab 48 | self.num_objects = len(vocab['object_idx_to_name']) 49 | self.use_orphaned_objects = use_orphaned_objects 50 | self.max_objects = max_objects 51 | self.max_samples = max_samples 52 | self.include_relationships = include_relationships 53 | 54 | self.evaluating = mode != 'train' 55 | self.predgraphs = predgraphs 56 | 57 | if self.mode == 'reposition': 58 | self.use_orphaned_objects = False 59 | 60 | self.clean_repeats = clean_repeats 61 | 62 | transform = [Resize(image_size), T.ToTensor()] 63 | if normalize_images: 64 | transform.append(imagenet_preprocess()) 65 | self.transform = T.Compose(transform) 66 | 67 | self.data = {} 68 | with h5py.File(h5_path, 'r') as f: 69 | for k, v in f.items(): 70 | if k == 'image_paths': 71 | self.image_paths = list(v) 72 | else: 73 | self.data[k] = torch.IntTensor(np.asarray(v)) 74 | 75 | def __len__(self): 76 | num = self.data['object_names'].size(0) 77 | if self.max_samples is not None: 78 | 
return min(self.max_samples, num) 79 | return num 80 | 81 | def __getitem__(self, index): 82 | """ 83 | Returns a tuple of: 84 | - image: FloatTensor of shape (C, H, W) 85 | - objs: LongTensor of shape (num_objs,) 86 | - boxes: FloatTensor of shape (num_objs, 4) giving boxes for objects in 87 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system. 88 | - triples: LongTensor of shape (num_triples, 3) where triples[t] = [i, p, j] 89 | means that (objs[i], p, objs[j]) is a triple. 90 | """ 91 | img_path = os.path.join(self.image_dir, self.image_paths[index]) 92 | 93 | # use for the mix strings and bytes error 94 | #img_path = os.path.join(self.image_dir, self.image_paths[index].decode("utf-8")) 95 | 96 | with open(img_path, 'rb') as f: 97 | with PIL.Image.open(f) as image: 98 | WW, HH = image.size 99 | #print(WW, HH) 100 | image = self.transform(image.convert('RGB')) 101 | 102 | H, W = self.image_size 103 | 104 | # Figure out which objects appear in relationships and which don't 105 | obj_idxs_with_rels = set() 106 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 107 | for r_idx in range(self.data['relationships_per_image'][index]): 108 | s = self.data['relationship_subjects'][index, r_idx].item() 109 | o = self.data['relationship_objects'][index, r_idx].item() 110 | obj_idxs_with_rels.add(s) 111 | obj_idxs_with_rels.add(o) 112 | obj_idxs_without_rels.discard(s) 113 | obj_idxs_without_rels.discard(o) 114 | 115 | obj_idxs = list(obj_idxs_with_rels) 116 | obj_idxs_without_rels = list(obj_idxs_without_rels) 117 | if len(obj_idxs) > self.max_objects - 1: 118 | if self.evaluating: 119 | obj_idxs = obj_idxs[:self.max_objects] 120 | else: 121 | obj_idxs = random.sample(obj_idxs, self.max_objects) 122 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 123 | num_to_add = self.max_objects - 1 - len(obj_idxs) 124 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 125 | if self.evaluating: 126 | obj_idxs += obj_idxs_without_rels[:num_to_add] 127 | else: 128 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add) 129 | if len(obj_idxs) == 0 and not self.use_orphaned_objects: 130 | # avoid empty list of objects 131 | obj_idxs += obj_idxs_without_rels[:1] 132 | map_overlapping_obj = {} 133 | 134 | objs = [] 135 | boxes = [] 136 | 137 | obj_idx_mapping = {} 138 | counter = 0 139 | for i, obj_idx in enumerate(obj_idxs): 140 | 141 | curr_obj = self.data['object_names'][index, obj_idx].item() 142 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 143 | 144 | x0 = float(x) / WW 145 | y0 = float(y) / HH 146 | if self.predgraphs: 147 | x1 = float(w) / WW 148 | y1 = float(h) / HH 149 | else: 150 | x1 = float(x + w) / WW 151 | y1 = float(y + h) / HH 152 | 153 | curr_box = torch.FloatTensor([x0, y0, x1, y1]) 154 | 155 | found_overlap = False 156 | if self.predgraphs: 157 | for prev_idx in range(counter): 158 | if overlapping_nodes(objs[prev_idx], curr_obj, boxes[prev_idx], curr_box): 159 | map_overlapping_obj[i] = prev_idx 160 | found_overlap = True 161 | break 162 | if not found_overlap: 163 | 164 | objs.append(curr_obj) 165 | boxes.append(curr_box) 166 | map_overlapping_obj[i] = counter 167 | counter += 1 168 | 169 | obj_idx_mapping[obj_idx] = map_overlapping_obj[i] 170 | 171 | # The last object will be the special __image__ object 172 | objs.append(self.vocab['object_name_to_idx']['__image__']) 173 | boxes.append(torch.FloatTensor([0, 0, 1, 1])) 174 | 175 | boxes = torch.stack(boxes) 176 | objs = torch.LongTensor(objs) 177 | 
num_objs = counter + 1 178 | 179 | triples = [] 180 | for r_idx in range(self.data['relationships_per_image'][index].item()): 181 | if not self.include_relationships: 182 | break 183 | s = self.data['relationship_subjects'][index, r_idx].item() 184 | p = self.data['relationship_predicates'][index, r_idx].item() 185 | o = self.data['relationship_objects'][index, r_idx].item() 186 | s = obj_idx_mapping.get(s, None) 187 | o = obj_idx_mapping.get(o, None) 188 | if s is not None and o is not None: 189 | if self.clean_repeats and [s, p, o] in triples: 190 | continue 191 | if self.predgraphs and s == o: 192 | continue 193 | triples.append([s, p, o]) 194 | 195 | # Add dummy __in_image__ relationships for all objects 196 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 197 | for i in range(num_objs - 1): 198 | triples.append([i, in_image, num_objs - 1]) 199 | 200 | triples = torch.LongTensor(triples) 201 | return image, objs, boxes, triples 202 | 203 | 204 | def collate_fn_nopairs(batch): 205 | """ 206 | Collate function to be used when wrapping a SceneGraphNoPairsDataset in a 207 | DataLoader. Returns a tuple of the following: 208 | 209 | - imgs: FloatTensor of shape (N, 3, H, W) 210 | - objs: LongTensor of shape (num_objs,) giving categories for all objects 211 | - boxes: FloatTensor of shape (num_objs, 4) giving boxes for all objects 212 | - triples: FloatTensor of shape (num_triples, 3) giving all triples, where 213 | triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple 214 | - obj_to_img: LongTensor of shape (num_objs,) mapping objects to images; 215 | obj_to_img[i] = n means that objs[i] belongs to imgs[n] 216 | - triple_to_img: LongTensor of shape (num_triples,) mapping triples to images; 217 | triple_to_img[t] = n means that triples[t] belongs to imgs[n] 218 | - imgs_masked: FloatTensor of shape (N, 4, H, W) 219 | """ 220 | # batch is a list, and each element is (image, objs, boxes, triples) 221 | all_imgs, all_objs, all_boxes, all_triples = [], [], [], [] 222 | all_obj_to_img, all_triple_to_img = [], [] 223 | 224 | all_imgs_masked = [] 225 | 226 | obj_offset = 0 227 | 228 | for i, (img, objs, boxes, triples) in enumerate(batch): 229 | 230 | all_imgs.append(img[None]) 231 | num_objs, num_triples = objs.size(0), triples.size(0) 232 | 233 | all_objs.append(objs) 234 | all_boxes.append(boxes) 235 | triples = triples.clone() 236 | 237 | triples[:, 0] += obj_offset 238 | triples[:, 2] += obj_offset 239 | 240 | all_triples.append(triples) 241 | 242 | all_obj_to_img.append(torch.LongTensor(num_objs).fill_(i)) 243 | all_triple_to_img.append(torch.LongTensor(num_triples).fill_(i)) 244 | 245 | # prepare input 4-channel image 246 | # initialize mask channel with zeros 247 | masked_img = img.clone() 248 | mask = torch.zeros_like(masked_img) 249 | mask = mask[0:1,:,:] 250 | masked_img = torch.cat([masked_img, mask], 0) 251 | all_imgs_masked.append(masked_img[None]) 252 | 253 | obj_offset += num_objs 254 | 255 | all_imgs_masked = torch.cat(all_imgs_masked) 256 | 257 | all_imgs = torch.cat(all_imgs) 258 | all_objs = torch.cat(all_objs) 259 | all_boxes = torch.cat(all_boxes) 260 | all_triples = torch.cat(all_triples) 261 | all_obj_to_img = torch.cat(all_obj_to_img) 262 | all_triple_to_img = torch.cat(all_triple_to_img) 263 | 264 | return all_imgs, all_objs, all_boxes, all_triples, \ 265 | all_obj_to_img, all_triple_to_img, all_imgs_masked 266 | 267 | 268 | from simsg.model import get_left_right_top_bottom 269 | 270 | 271 | def overlapping_nodes(obj1, obj2, box1, box2, 
criteria=0.7): 272 | # used to clean predicted graphs - merge nodes with overlapping boxes 273 | # are these two objects overplapping? 274 | # boxes given as [left, top, right, bottom] 275 | res = 100 # used to project box representation in 2D for iou computation 276 | epsilon = 0.001 277 | if obj1 == obj2: 278 | spatial_box1 = np.zeros([res, res]) 279 | left, right, top, bottom = get_left_right_top_bottom(box1, res, res) 280 | spatial_box1[top:bottom, left:right] = 1 281 | spatial_box2 = np.zeros([res, res]) 282 | left, right, top, bottom = get_left_right_top_bottom(box2, res, res) 283 | spatial_box2[top:bottom, left:right] = 1 284 | iou = np.sum(spatial_box1 * spatial_box2) / \ 285 | (np.sum((spatial_box1 + spatial_box2 > 0).astype(np.float32)) + epsilon) 286 | return iou >= criteria 287 | else: 288 | return False 289 | -------------------------------------------------------------------------------- /simsg/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # Modification copyright 2020 Helisa Dhamo 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from simsg.layers import get_normalization_2d 23 | from simsg.layers import get_activation 24 | 25 | from simsg.SPADE.normalization import SPADE 26 | import torch.nn.utils.spectral_norm as spectral_norm 27 | 28 | 29 | class DecoderNetwork(nn.Module): 30 | """ 31 | Decoder Network that generates a target image from a pair of masked source image and layout 32 | Implemented in two options: with a CRN block or a SPADE block 33 | """ 34 | 35 | def __init__(self, dims, normalization='instance', activation='leakyrelu', spade_blocks=False, source_image_dims=32): 36 | super(DecoderNetwork, self).__init__() 37 | 38 | self.spade_block = spade_blocks 39 | self.source_image_dims = source_image_dims 40 | 41 | layout_dim = dims[0] 42 | self.decoder_modules = nn.ModuleList() 43 | for i in range(1, len(dims)): 44 | input_dim = 1 if i == 1 else dims[i - 1] 45 | output_dim = dims[i] 46 | 47 | if self.spade_block: 48 | # Resnet SPADE block 49 | mod = SPADEResnetBlock(input_dim, output_dim, layout_dim-self.source_image_dims, self.source_image_dims) 50 | 51 | else: 52 | # CRN block 53 | mod = CRNBlock(layout_dim, input_dim, output_dim, 54 | normalization=normalization, activation=activation) 55 | 56 | self.decoder_modules.append(mod) 57 | 58 | output_conv_layers = [ 59 | nn.Conv2d(dims[-1], dims[-1], kernel_size=3, padding=1), 60 | get_activation(activation), 61 | nn.Conv2d(dims[-1], 3, kernel_size=1, padding=0) 62 | ] 63 | nn.init.kaiming_normal_(output_conv_layers[0].weight) 64 | nn.init.kaiming_normal_(output_conv_layers[2].weight) 65 | self.output_conv = nn.Sequential(*output_conv_layers) 66 | 67 | def forward(self, layout): 68 | """ 69 | Output will have same size as layout 70 | """ 71 | # H, W = self.output_size 72 | N, _, 
H, W = layout.size() 73 | self.layout = layout 74 | 75 | # Figure out size of input 76 | input_H, input_W = H, W 77 | for _ in range(len(self.decoder_modules)): 78 | input_H //= 2 79 | input_W //= 2 80 | 81 | assert input_H != 0 82 | assert input_W != 0 83 | 84 | feats = torch.zeros(N, 1, input_H, input_W).to(layout) 85 | for mod in self.decoder_modules: 86 | feats = F.upsample(feats, scale_factor=2, mode='nearest') 87 | #print(layout.shape) 88 | feats = mod(layout, feats) 89 | 90 | out = self.output_conv(feats) 91 | 92 | return out 93 | 94 | 95 | class CRNBlock(nn.Module): 96 | """ 97 | Cascaded refinement network (CRN) block, as described in: 98 | Qifeng Chen and Vladlen Koltun, 99 | "Photographic Image Synthesis with Cascaded Refinement Networks", 100 | ICCV 2017 101 | """ 102 | 103 | def __init__(self, layout_dim, input_dim, output_dim, 104 | normalization='instance', activation='leakyrelu'): 105 | super(CRNBlock, self).__init__() 106 | 107 | layers = [] 108 | 109 | layers.append(nn.Conv2d(layout_dim + input_dim, output_dim, 110 | kernel_size=3, padding=1)) 111 | layers.append(get_normalization_2d(output_dim, normalization)) 112 | layers.append(get_activation(activation)) 113 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1)) 114 | layers.append(get_normalization_2d(output_dim, normalization)) 115 | layers.append(get_activation(activation)) 116 | layers = [layer for layer in layers if layer is not None] 117 | for layer in layers: 118 | if isinstance(layer, nn.Conv2d): 119 | nn.init.kaiming_normal_(layer.weight) 120 | self.net = nn.Sequential(*layers) 121 | 122 | def forward(self, layout, feats): 123 | _, CC, HH, WW = layout.size() 124 | _, _, H, W = feats.size() 125 | assert HH >= H 126 | if HH > H: 127 | factor = round(HH // H) 128 | assert HH % factor == 0 129 | assert WW % factor == 0 and WW // factor == W 130 | layout = F.avg_pool2d(layout, kernel_size=factor, stride=factor) 131 | 132 | net_input = torch.cat([layout, feats], dim=1) 133 | out = self.net(net_input) 134 | return out 135 | 136 | 137 | class SPADEResnetBlock(nn.Module): 138 | """ 139 | ResNet block used in SPADE. 140 | It differs from the ResNet block of pix2pixHD in that 141 | it takes in the segmentation map as input, learns the skip connection if necessary, 142 | and applies normalization first and then convolution. 143 | This architecture seemed like a standard architecture for unconditional or 144 | class-conditional GAN architecture using residual block. 145 | The code was inspired from https://github.com/LMescheder/GAN_stability. 
146 | """ 147 | 148 | def __init__(self, fin, fout, seg_nc, src_nc, spade_config_str='spadebatch3x3', spectral=True): 149 | super().__init__() 150 | # Attributes 151 | self.learned_shortcut = (fin != fout) 152 | fmiddle = min(fin, fout) 153 | self.src_nc = src_nc 154 | 155 | # create conv layers 156 | self.conv_0 = nn.Conv2d(fin+self.src_nc, fmiddle, kernel_size=3, padding=1) 157 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 158 | if self.learned_shortcut: 159 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 160 | 161 | # apply spectral norm if specified 162 | if spectral: 163 | self.conv_0 = spectral_norm(self.conv_0) 164 | self.conv_1 = spectral_norm(self.conv_1) 165 | if self.learned_shortcut: 166 | self.conv_s = spectral_norm(self.conv_s) 167 | 168 | # define normalization layers 169 | self.norm_0 = SPADE(spade_config_str, fin, seg_nc) 170 | self.norm_1 = SPADE(spade_config_str, fmiddle, seg_nc) 171 | if self.learned_shortcut: 172 | self.norm_s = SPADE(spade_config_str, fin, seg_nc) 173 | 174 | # note the resnet block with SPADE also takes in |seg|, 175 | # the semantic segmentation map as input 176 | def forward(self, seg_, x): 177 | 178 | seg_ = F.interpolate(seg_, size=x.size()[2:], mode='nearest') 179 | 180 | # only use the layout map as input to SPADE norm (not the source image channels) 181 | layout_only_dim = seg_.size(1) - self.src_nc 182 | in_img = seg_[:, layout_only_dim:, :, :] 183 | seg = seg_[:, :layout_only_dim, :, :] 184 | 185 | x_s = self.shortcut(x, seg) 186 | dx = torch.cat([self.norm_0(x, seg), in_img],1) 187 | dx = self.conv_0(self.actvn(dx)) 188 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) 189 | 190 | out = x_s + dx 191 | 192 | return out 193 | 194 | def shortcut(self, x, seg): 195 | if self.learned_shortcut: 196 | x_s = self.conv_s(self.norm_s(x, seg)) 197 | else: 198 | x_s = x 199 | return x_s 200 | 201 | def actvn(self, x): 202 | return F.leaky_relu(x, 2e-1) 203 | -------------------------------------------------------------------------------- /simsg/discriminators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # Modification copyright 2020 Helisa Dhamo 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from simsg.bilinear import crop_bbox_batch 23 | from simsg.layers import GlobalAvgPool, Flatten, get_activation, build_cnn 24 | import numpy as np 25 | 26 | ''' 27 | visualizations for debugging 28 | import cv2 29 | import numpy as np 30 | from simsg.data import imagenet_deprocess_batch 31 | ''' 32 | 33 | class PatchDiscriminator(nn.Module): 34 | def __init__(self, arch, normalization='batch', activation='leakyrelu-0.2', 35 | padding='same', pooling='avg', input_size=(128,128), 36 | layout_dim=0): 37 | super(PatchDiscriminator, self).__init__() 38 | input_dim = 3 + layout_dim 39 | arch = 'I%d,%s' % (input_dim, arch) 40 | cnn_kwargs = { 41 | 'arch': arch, 42 | 'normalization': normalization, 43 | 'activation': activation, 44 | 'pooling': pooling, 45 | 'padding': padding, 46 | } 47 | self.cnn, output_dim = build_cnn(**cnn_kwargs) 48 | self.classifier = nn.Conv2d(output_dim, 1, kernel_size=1, stride=1) 49 | 50 | def forward(self, x, layout=None): 51 | if layout is not None: 52 | x = torch.cat([x, layout], dim=1) 53 | 54 | out = self.cnn(x) 55 | 56 | discr_layers = [] 57 | i = 0 58 | for l in self.cnn.children(): 59 | x = l(x) 60 | 61 | if i % 3 == 0: 62 | discr_layers.append(x) 63 | #print("shape: ", x.shape, l) 64 | i += 1 65 | 66 | return out, discr_layers 67 | 68 | 69 | class AcDiscriminator(nn.Module): 70 | def __init__(self, vocab, arch, normalization='none', activation='relu', 71 | padding='same', pooling='avg'): 72 | super(AcDiscriminator, self).__init__() 73 | self.vocab = vocab 74 | 75 | cnn_kwargs = { 76 | 'arch': arch, 77 | 'normalization': normalization, 78 | 'activation': activation, 79 | 'pooling': pooling, 80 | 'padding': padding, 81 | } 82 | self.cnn_body, D = build_cnn(**cnn_kwargs) 83 | #print(D) 84 | self.cnn = nn.Sequential(self.cnn_body, GlobalAvgPool(), nn.Linear(D, 1024)) 85 | num_objects = len(vocab['object_idx_to_name']) 86 | 87 | self.real_classifier = nn.Linear(1024, 1) 88 | self.obj_classifier = nn.Linear(1024, num_objects) 89 | 90 | def forward(self, x, y): 91 | if x.dim() == 3: 92 | x = x[:, None] 93 | vecs = self.cnn(x) 94 | real_scores = self.real_classifier(vecs) 95 | obj_scores = self.obj_classifier(vecs) 96 | ac_loss = F.cross_entropy(obj_scores, y) 97 | 98 | discr_layers = [] 99 | i = 0 100 | for l in self.cnn_body.children(): 101 | x = l(x) 102 | 103 | if i % 3 == 0: 104 | discr_layers.append(x) 105 | #print("shape: ", x.shape, l) 106 | i += 1 107 | #print(len(discr_layers)) 108 | return real_scores, ac_loss, discr_layers 109 | 110 | 111 | class AcCropDiscriminator(nn.Module): 112 | def __init__(self, vocab, arch, normalization='none', activation='relu', 113 | object_size=64, padding='same', pooling='avg'): 114 | super(AcCropDiscriminator, self).__init__() 115 | self.vocab = vocab 116 | self.discriminator = AcDiscriminator(vocab, arch, normalization, 117 | activation, padding, pooling) 118 | self.object_size = object_size 119 | 120 | def forward(self, imgs, objs, boxes, obj_to_img): 121 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size) 122 | real_scores, ac_loss, discr_layers = self.discriminator(crops, objs) 123 | return real_scores, ac_loss, discr_layers 124 | 125 | 126 | # Multi-scale discriminator as in pix2pixHD or SPADE 127 | class MultiscaleDiscriminator(nn.Module): 128 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 129 | use_sigmoid=False, num_D=3): 130 | super(MultiscaleDiscriminator, self).__init__() 
131 | self.num_D = num_D 132 | self.n_layers = n_layers 133 | 134 | for i in range(num_D): 135 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid) 136 | for j in range(n_layers + 2): 137 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))) 138 | 139 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 140 | 141 | def singleD_forward(self, model, input): 142 | result = [input] 143 | for i in range(len(model)): 144 | result.append(model[i](result[-1])) 145 | return result[1:] 146 | 147 | def forward(self, input): 148 | num_D = self.num_D 149 | result = [] 150 | input_downsampled = input 151 | for i in range(num_D): 152 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in 153 | range(self.n_layers + 2)] 154 | result.append(self.singleD_forward(model, input_downsampled)) 155 | if i != (num_D - 1): 156 | input_downsampled = self.downsample(input_downsampled) 157 | return result 158 | 159 | # Defines the PatchGAN discriminator with the specified arguments. 160 | class NLayerDiscriminator(nn.Module): 161 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 162 | super(NLayerDiscriminator, self).__init__() 163 | self.n_layers = n_layers 164 | 165 | kw = 4 166 | padw = int(np.ceil((kw - 1.0) / 2)) 167 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 168 | 169 | nf = ndf 170 | for n in range(1, n_layers): 171 | nf_prev = nf 172 | nf = min(nf * 2, 512) 173 | sequence += [[ 174 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 175 | norm_layer(nf), nn.LeakyReLU(0.2, True) 176 | ]] 177 | 178 | nf_prev = nf 179 | nf = min(nf * 2, 512) 180 | sequence += [[ 181 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 182 | norm_layer(nf), 183 | nn.LeakyReLU(0.2, True) 184 | ]] 185 | 186 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 187 | 188 | if use_sigmoid: 189 | sequence += [[nn.Sigmoid()]] 190 | 191 | for n in range(len(sequence)): 192 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 193 | 194 | def forward(self, input): 195 | res = [input] 196 | for n in range(self.n_layers + 2): 197 | model = getattr(self, 'model' + str(n)) 198 | res.append(model(res[-1])) 199 | return res[1:] 200 | 201 | 202 | def divide_pred(pred): 203 | # the prediction contains the intermediate outputs of multiscale GAN, 204 | # so it's usually a list 205 | if type(pred) == list: 206 | fake = [] 207 | real = [] 208 | for p in pred: 209 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) 210 | real.append([tensor[tensor.size(0) // 2:] for tensor in p]) 211 | else: 212 | fake = pred[:pred.size(0) // 2] 213 | real = pred[pred.size(0) // 2:] 214 | 215 | return fake, real 216 | -------------------------------------------------------------------------------- /simsg/feats_statistics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_mean(spade=True): 4 | if spade: 5 | return torch.tensor([-1.6150e-02, -2.0104e-01, -1.8967e-02, 8.6773e-01, 1.1788e-01, 6 | -3.7955e-02, -2.3504e-01, -7.6040e-02, 8.8244e-01, 7.9869e-02, 7 | 4.8099e-01, -5.5241e-02, -9.3631e-02, 6.2165e-01, 3.5968e-01, 8 | -2.2090e-01, 1.5325e-02, -2.5600e-01, 6.4510e-01, -3.2229e-01, 9 | -9.3363e-03, -2.3452e-01, -2.9875e-01, -5.9847e-02, 5.0070e-01, 10 | -1.3385e-02, -4.1650e-02, 4.7214e-02, 1.5526e+00, 
-1.1106e+00, 11 | 4.0959e-01, -2.2214e-01, 5.8195e-02, 3.5044e-01, -4.1516e-01, 12 | -2.6164e-01, 3.7791e-01, 1.6697e-01, 2.4359e-01, 3.7218e-01, 13 | 1.3322e-01, 2.2361e-01, -4.0649e-01, 1.4112e+00, 6.2109e-01, 14 | -8.9414e-01, 2.4960e-01, -2.2291e-02, 2.5344e-01, -1.1063e-01, 15 | 2.0111e-01, 3.0083e-01, -4.5993e-01, 1.1597e+00, 6.1391e-02, 16 | 3.8579e-01, -1.8961e-02, 4.3253e-01, 3.1550e-01, 5.1039e-02, 17 | -2.0387e-02, 3.7300e-01, 1.3172e-01, 3.1559e-01, 7.0767e-02, 18 | 4.1030e-01, 2.6682e-01, 3.7454e-01, 1.5960e-01, 2.3767e-01, 19 | 2.1569e-01, 4.0779e-01, 1.8256e-01, 2.3073e-01, 3.6593e-01, 20 | -1.2173e-02, 3.2893e-01, 1.8276e-01, -2.5898e-01, 6.1171e-01, 21 | 5.7514e-01, -3.9560e-02, 1.3200e-01, 4.2561e-01, 1.0145e-01, 22 | 6.3877e-01, 1.5947e-01, -4.3896e-01, -5.0542e-01, 4.7463e-01, 23 | -1.2649e-01, 2.4283e-01, 2.0448e-02, 1.4343e-01, 3.9534e-04, 24 | -5.2192e-02, 8.2538e-01, 2.8621e-01, 1.7567e-01, 9.0932e-02, 25 | -5.3764e-01, 3.6047e-01, 2.3840e-01, 4.0529e-01, -2.9391e-02, 26 | -3.1827e-01, 1.7440e-01, 8.4958e-01, -7.6857e-02, 1.1439e-01, 27 | 3.8213e-01, -1.9179e-01, 6.5607e-01, 1.1372e+00, 2.9348e-01, 28 | 2.1763e-02, 4.9819e-01, -2.3046e-02, -1.5268e-01, -2.3093e-01, 29 | 1.5162e-02, 7.3744e-02, 1.4607e-01, 1.4677e-01, 1.0958e-02, 30 | 1.6972e-01, 2.8153e-01, -4.3338e+00]) 31 | else: 32 | return torch.tensor([ 7.7237e-02, 1.4599e-01, -3.2086e-02, 1.2272e-02, 1.3585e-01, 33 | 2.2647e-01, 1.5775e-01, 8.7155e-02, 2.8425e-01, 9.3309e-01, 34 | -2.8251e+00, 4.8397e-01, 1.6812e-01, -4.3176e-01, 3.5545e-01, 35 | 1.8081e-01, 1.2170e+00, 3.7712e-01, 3.6601e-01, 2.4548e-02, 36 | -8.9017e-02, 2.1669e-01, 5.5919e-01, 5.3623e-01, 1.8667e-01, 37 | 7.7680e-01, 1.1192e-01, 9.0081e-02, 1.8276e-01, 5.6268e-01, 38 | 3.5646e-01, -1.8525e-02, 4.7276e-01, 1.3664e-01, -2.4958e-01, 39 | 1.3064e-01, 1.0058e-01, 2.3417e-01, 2.4552e-01, 1.3849e-01, 40 | 1.0150e+00, 2.3932e-01, -7.0545e-01, 3.6638e-01, 1.8530e-02, 41 | 5.1259e-01, -1.0165e-01, 1.1289e-01, -1.1383e-02, 7.0450e-02, 42 | 1.5806e-01, -1.4862e-01, 4.2865e-02, 2.7360e-01, 2.3326e-01, 43 | 2.2674e-01, -1.0260e-02, 1.8777e-01, 2.2641e-01, -2.4887e-01, 44 | 2.9636e-01, 4.9020e-01, -3.6557e-01, 2.2225e-01, 6.5004e-02, 45 | -4.7235e-02, 1.6490e-02, 3.0804e-02, 2.5871e-01, 3.1599e-05, 46 | 2.9948e-01, -1.2801e+00, 2.6655e-03, 3.6675e-01, 2.4675e-01, 47 | 2.0338e-01, 4.9576e-01, 1.9530e-01, -5.4172e-02, 1.8044e-01, 48 | 2.9627e-01, 8.3459e-02, 1.0668e-01, 1.0449e-01, 1.8324e-01, 49 | -1.3553e-01, -7.5838e-02, -3.9355e-01, 5.1458e-02, 4.3815e-01, 50 | 5.8406e-02, 2.2365e-01, -2.3227e-02, 2.6278e-01, -1.5512e-01, 51 | 3.9810e-01, 3.7780e-01, 4.3622e-01, 2.1492e-01, -3.2092e-02, 52 | 1.4565e-01, 5.7963e-02, 2.4176e-01, -8.1265e-02, 2.5032e-01, 53 | -6.4307e-04, 5.2379e-02, 3.4459e-01, 1.3226e-01, 1.3777e-01, 54 | 3.3404e-02, 2.8588e-01, -1.7163e-01, -3.7610e-02, 8.4848e-02, 55 | 5.9351e-01, -4.0149e-02, 1.2884e-01, 3.8397e-02, -2.2867e-01, 56 | 3.7894e-02, 2.0033e-01, -6.8478e-02, -1.3748e-01, -2.1313e-02, 57 | 1.4798e-01, -7.1153e-02, 2.5109e-01]) 58 | 59 | def get_std(spade=True): 60 | 61 | if spade: 62 | return torch.tensor([0.7377, 0.8760, 0.5966, 0.9396, 0.5330, 0.6746, 0.5873, 0.8002, 0.8464, 63 | 0.6066, 0.7404, 0.6880, 0.8290, 0.8517, 1.0168, 0.6587, 0.6910, 0.7041, 64 | 0.7083, 0.8597, 0.6041, 0.7032, 0.4664, 0.5939, 0.9299, 0.6339, 0.6201, 65 | 0.6464, 1.3569, 0.9664, 0.8113, 0.7645, 0.7036, 0.6485, 0.8178, 0.5965, 66 | 0.5853, 0.9413, 0.5311, 0.6869, 0.7178, 0.4459, 0.6768, 1.0432, 0.5735, 67 | 1.4332, 0.7651, 0.5793, 0.5602, 0.5846, 
0.6134, 1.0111, 1.0382, 0.8546, 68 | 0.8659, 0.6131, 0.5885, 0.5515, 0.6286, 0.6191, 0.7734, 0.6184, 0.6307, 69 | 0.6496, 0.8436, 0.6631, 0.5839, 0.6096, 0.8167, 0.6743, 0.5774, 0.5412, 70 | 0.5770, 0.6273, 0.5946, 0.5786, 0.6149, 0.7487, 0.7289, 0.5467, 0.9170, 71 | 0.7468, 0.7206, 0.6468, 0.6711, 0.6553, 0.5945, 0.7166, 0.6544, 0.7168, 72 | 0.6667, 0.6720, 0.8059, 0.6146, 0.6778, 0.7143, 0.9330, 0.5723, 0.6583, 73 | 0.5033, 0.8432, 0.6200, 0.5862, 0.6239, 0.6245, 0.6621, 0.7231, 0.8397, 74 | 0.6194, 0.5876, 0.6404, 0.5165, 0.8618, 1.1337, 0.7873, 0.6999, 0.7077, 75 | 0.6124, 0.7833, 0.6646, 0.6600, 0.6348, 0.5936, 0.5906, 0.5752, 0.7991, 76 | 0.7337, 3.8573]) 77 | 78 | else: 79 | return torch.tensor([0.5562, 0.4463, 0.6239, 0.5942, 0.5989, 0.5937, 0.5813, 0.6217, 0.6885, 80 | 0.8446, 2.4218, 1.5264, 0.6207, 0.5945, 0.6178, 0.5321, 0.9981, 0.6412, 81 | 0.5612, 0.6799, 0.5956, 0.6331, 0.6156, 0.7085, 0.5387, 0.7024, 0.5395, 82 | 0.6156, 0.5126, 0.5616, 0.5330, 0.6228, 0.5880, 0.6474, 0.5386, 0.5761, 83 | 0.5934, 0.5818, 0.5054, 0.6030, 0.8313, 0.5675, 0.5970, 0.5890, 0.5287, 84 | 0.5437, 0.5739, 0.5770, 0.5374, 0.5381, 0.6655, 0.5938, 0.6502, 0.5945, 85 | 0.6889, 0.5691, 0.6278, 0.5376, 0.5919, 0.6407, 0.6559, 0.5942, 0.5599, 86 | 0.6238, 0.5494, 0.5596, 0.5861, 0.5741, 0.6071, 0.5206, 0.5780, 0.9384, 87 | 0.4894, 0.6108, 0.6441, 0.6012, 0.6952, 0.6326, 0.4971, 0.5562, 0.5494, 88 | 0.5879, 0.5013, 0.5992, 0.5527, 0.6322, 0.5842, 0.5900, 0.6353, 0.5606, 89 | 0.6369, 0.5970, 0.5347, 0.6015, 0.5133, 0.6209, 0.6077, 0.7290, 0.5833, 90 | 0.5555, 0.5780, 0.6566, 0.5696, 0.5394, 0.5386, 0.5731, 0.5225, 0.5397, 91 | 0.5517, 0.6082, 0.6007, 0.5728, 0.6639, 0.5972, 0.6115, 0.6083, 0.5304, 92 | 0.5828, 0.6301, 0.5566, 0.6096, 0.6493, 0.5196, 0.6479, 0.6541, 0.5875, 93 | 0.5701, 0.6249]) 94 | -------------------------------------------------------------------------------- /simsg/graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | from simsg.layers import build_mlp 20 | 21 | """ 22 | PyTorch modules for dealing with graphs. 23 | """ 24 | 25 | def _init_weights(module): 26 | if hasattr(module, 'weight'): 27 | if isinstance(module, nn.Linear): 28 | nn.init.kaiming_normal_(module.weight) 29 | 30 | class GraphTripleConv(nn.Module): 31 | """ 32 | A single layer of scene graph convolution. 
33 | """ 34 | def __init__(self, input_dim_obj, input_dim_pred, output_dim=None, hidden_dim=512, 35 | pooling='avg', mlp_normalization='none'): 36 | super(GraphTripleConv, self).__init__() 37 | if output_dim is None: 38 | output_dim = input_dim_obj 39 | self.input_dim_obj = input_dim_obj 40 | self.input_dim_pred = input_dim_pred 41 | self.output_dim = output_dim 42 | self.hidden_dim = hidden_dim 43 | 44 | assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling 45 | 46 | self.pooling = pooling 47 | net1_layers = [2 * input_dim_obj + input_dim_pred, hidden_dim, 2 * hidden_dim + output_dim] 48 | net1_layers = [l for l in net1_layers if l is not None] 49 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization) 50 | self.net1.apply(_init_weights) 51 | 52 | net2_layers = [hidden_dim, hidden_dim, output_dim] 53 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization) 54 | self.net2.apply(_init_weights) 55 | 56 | 57 | def forward(self, obj_vecs, pred_vecs, edges): 58 | """ 59 | Inputs: 60 | - obj_vecs: FloatTensor of shape (num_objs, D) giving vectors for all objects 61 | - pred_vecs: FloatTensor of shape (num_triples, D) giving vectors for all predicates 62 | - edges: LongTensor of shape (num_triples, 2) where edges[k] = [i, j] indicates the 63 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]] 64 | 65 | Outputs: 66 | - new_obj_vecs: FloatTensor of shape (num_objs, D) giving new vectors for objects 67 | - new_pred_vecs: FloatTensor of shape (num_triples, D) giving new vectors for predicates 68 | """ 69 | dtype, device = obj_vecs.dtype, obj_vecs.device 70 | num_objs, num_triples = obj_vecs.size(0), pred_vecs.size(0) 71 | Din_obj, Din_pred, H, Dout = self.input_dim_obj, self.input_dim_pred, self.hidden_dim, self.output_dim 72 | 73 | # Break apart indices for subjects and objects; these have shape (num_triples,) 74 | s_idx = edges[:, 0].contiguous() 75 | o_idx = edges[:, 1].contiguous() 76 | 77 | # Get current vectors for subjects and objects; these have shape (num_triples, Din) 78 | cur_s_vecs = obj_vecs[s_idx] 79 | cur_o_vecs = obj_vecs[o_idx] 80 | 81 | # Get current vectors for triples; shape is (num_triples, 3 * Din) 82 | # Pass through net1 to get new triple vecs; shape is (num_triples, 2 * H + Dout) 83 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1) 84 | new_t_vecs = self.net1(cur_t_vecs) 85 | 86 | # Break apart into new s, p, and o vecs; s and o vecs have shape (num_triples, H) and 87 | # p vecs have shape (num_triples, Dout) 88 | new_s_vecs = new_t_vecs[:, :H] 89 | new_p_vecs = new_t_vecs[:, H:(H+Dout)] 90 | new_o_vecs = new_t_vecs[:, (H+Dout):(2 * H + Dout)] 91 | 92 | # Allocate space for pooled object vectors of shape (num_objs, H) 93 | pooled_obj_vecs = torch.zeros(num_objs, H, dtype=dtype, device=device) 94 | 95 | # Use scatter_add to sum vectors for objects that appear in multiple triples; 96 | # we first need to expand the indices to have shape (num_triples, D) 97 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs) 98 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs) 99 | # print(pooled_obj_vecs.shape, o_idx_exp.shape) 100 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs) 101 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs) 102 | 103 | if self.pooling == 'avg': 104 | #print("here i am, would you send me an angel") 105 | # Figure out how many times each object has appeared, again using 106 | # some scatter_add trickery. 
107 | obj_counts = torch.zeros(num_objs, dtype=dtype, device=device) 108 | ones = torch.ones(num_triples, dtype=dtype, device=device) 109 | obj_counts = obj_counts.scatter_add(0, s_idx, ones) 110 | obj_counts = obj_counts.scatter_add(0, o_idx, ones) 111 | 112 | # Divide the new object vectors by the number of times they 113 | # appeared, but first clamp at 1 to avoid dividing by zero; 114 | # objects that appear in no triples will have output vector 0 115 | # so this will not affect them. 116 | obj_counts = obj_counts.clamp(min=1) 117 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1) 118 | 119 | # Send pooled object vectors through net2 to get output object vectors, 120 | # of shape (num_objs, Dout) 121 | new_obj_vecs = self.net2(pooled_obj_vecs) 122 | 123 | return new_obj_vecs, new_p_vecs 124 | 125 | 126 | class GraphTripleConvNet(nn.Module): 127 | """ A sequence of scene graph convolution layers """ 128 | def __init__(self, input_dim_obj, input_dim_pred, num_layers=5, hidden_dim=512, pooling='avg', 129 | mlp_normalization='none'): 130 | super(GraphTripleConvNet, self).__init__() 131 | 132 | self.num_layers = num_layers 133 | self.gconvs = nn.ModuleList() 134 | gconv_kwargs = { 135 | 'input_dim_obj': input_dim_obj, 136 | 'input_dim_pred': input_dim_pred, 137 | 'hidden_dim': hidden_dim, 138 | 'pooling': pooling, 139 | 'mlp_normalization': mlp_normalization, 140 | } 141 | for _ in range(self.num_layers): 142 | self.gconvs.append(GraphTripleConv(**gconv_kwargs)) 143 | 144 | def forward(self, obj_vecs, pred_vecs, edges): 145 | for i in range(self.num_layers): 146 | gconv = self.gconvs[i] 147 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges) 148 | return obj_vecs, pred_vecs 149 | 150 | 151 | -------------------------------------------------------------------------------- /simsg/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def get_normalization_2d(channels, normalization): 23 | if normalization == 'instance': 24 | return nn.InstanceNorm2d(channels) 25 | elif normalization == 'batch': 26 | return nn.BatchNorm2d(channels) 27 | elif normalization == 'none': 28 | return None 29 | else: 30 | raise ValueError('Unrecognized normalization type "%s"' % normalization) 31 | 32 | 33 | def get_activation(name): 34 | kwargs = {} 35 | if name.lower().startswith('leakyrelu'): 36 | if '-' in name: 37 | slope = float(name.split('-')[1]) 38 | kwargs = {'negative_slope': slope} 39 | name = 'leakyrelu' 40 | activations = { 41 | 'relu': nn.ReLU, 42 | 'leakyrelu': nn.LeakyReLU, 43 | } 44 | if name.lower() not in activations: 45 | raise ValueError('Invalid activation "%s"' % name) 46 | return activations[name.lower()](**kwargs) 47 | 48 | 49 | def _init_conv(layer, method): 50 | if not isinstance(layer, nn.Conv2d): 51 | return 52 | if method == 'default': 53 | return 54 | elif method == 'kaiming-normal': 55 | nn.init.kaiming_normal(layer.weight) 56 | elif method == 'kaiming-uniform': 57 | nn.init.kaiming_uniform(layer.weight) 58 | 59 | 60 | class Flatten(nn.Module): 61 | def forward(self, x): 62 | return x.view(x.size(0), -1) 63 | 64 | def __repr__(self): 65 | return 'Flatten()' 66 | 67 | 68 | class Unflatten(nn.Module): 69 | def __init__(self, size): 70 | super(Unflatten, self).__init__() 71 | self.size = size 72 | 73 | def forward(self, x): 74 | return x.view(*self.size) 75 | 76 | def __repr__(self): 77 | size_str = ', '.join('%d' % d for d in self.size) 78 | return 'Unflatten(%s)' % size_str 79 | 80 | 81 | class GlobalAvgPool(nn.Module): 82 | def forward(self, x): 83 | N, C = x.size(0), x.size(1) 84 | return x.view(N, C, -1).mean(dim=2) 85 | 86 | 87 | class ResidualBlock(nn.Module): 88 | def __init__(self, channels, normalization='batch', activation='relu', 89 | padding='same', kernel_size=3, init='default'): 90 | super(ResidualBlock, self).__init__() 91 | 92 | K = kernel_size 93 | P = _get_padding(K, padding) 94 | C = channels 95 | self.padding = P 96 | layers = [ 97 | get_normalization_2d(C, normalization), 98 | get_activation(activation), 99 | nn.Conv2d(C, C, kernel_size=K, padding=P), 100 | get_normalization_2d(C, normalization), 101 | get_activation(activation), 102 | nn.Conv2d(C, C, kernel_size=K, padding=P), 103 | ] 104 | layers = [layer for layer in layers if layer is not None] 105 | for layer in layers: 106 | _init_conv(layer, method=init) 107 | self.net = nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | P = self.padding 111 | shortcut = x 112 | if P == 0: 113 | shortcut = x[:, :, P:-P, P:-P] 114 | y = self.net(x) 115 | return shortcut + self.net(x) 116 | 117 | 118 | def _get_padding(K, mode): 119 | """ Helper method to compute padding size """ 120 | if mode == 'valid': 121 | return 0 122 | elif mode == 'same': 123 | assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K 124 | return (K - 1) // 2 125 | 126 | 127 | def build_cnn(arch, normalization='batch', activation='relu', padding='same', 128 | pooling='max', init='default'): 129 | """ 130 | Build a CNN from an architecture string, which is a list of layer 131 | specification strings. The overall architecture can be given as a list or as 132 | a comma-separated string. 133 | 134 | All convolutions *except for the first* are preceeded by normalization and 135 | nonlinearity. 
136 | 137 | All other layers support the following: 138 | - IX: Indicates that the number of input channels to the network is X. 139 | Can only be used at the first layer; if not present then we assume 140 | 3 input channels. 141 | - CK-X: KxK convolution with X output channels 142 | - CK-X-S: KxK convolution with X output channels and stride S 143 | - R: Residual block keeping the same number of channels 144 | - UX: Nearest-neighbor upsampling with factor X 145 | - PX: Spatial pooling with factor X 146 | - FC-X-Y: Flatten followed by fully-connected layer 147 | 148 | Returns a tuple of: 149 | - cnn: An nn.Sequential 150 | - channels: Number of output channels 151 | """ 152 | #print(arch) 153 | if isinstance(arch, str): 154 | arch = arch.split(',') 155 | cur_C = 3 156 | if len(arch) > 0 and arch[0][0] == 'I': 157 | cur_C = int(arch[0][1:]) 158 | arch = arch[1:] 159 | #print(arch) 160 | #if len(arch) > 0 and arch[0][0] == 'I': 161 | # arch = arch[1:] 162 | first_conv = True 163 | flat = False 164 | layers = [] 165 | for i, s in enumerate(arch): 166 | #if s[0] == 'I': 167 | # continue 168 | if s[0] == 'C': 169 | if not first_conv: 170 | layers.append(get_normalization_2d(cur_C, normalization)) 171 | layers.append(get_activation(activation)) 172 | first_conv = False 173 | vals = [int(i) for i in s[1:].split('-')] 174 | if len(vals) == 2: 175 | K, next_C = vals 176 | stride = 1 177 | elif len(vals) == 3: 178 | K, next_C, stride = vals 179 | # K, next_C = (int(i) for i in s[1:].split('-')) 180 | P = _get_padding(K, padding) 181 | conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride) 182 | layers.append(conv) 183 | _init_conv(layers[-1], init) 184 | cur_C = next_C 185 | elif s[0] == 'R': 186 | norm = 'none' if first_conv else normalization 187 | res = ResidualBlock(cur_C, normalization=norm, activation=activation, 188 | padding=padding, init=init) 189 | layers.append(res) 190 | first_conv = False 191 | elif s[0] == 'U': 192 | factor = int(s[1:]) 193 | layers.append(nn.Upsample(scale_factor=factor, mode='nearest')) 194 | elif s[0] == 'P': 195 | factor = int(s[1:]) 196 | if pooling == 'max': 197 | pool = nn.MaxPool2d(kernel_size=factor, stride=factor) 198 | elif pooling == 'avg': 199 | pool = nn.AvgPool2d(kernel_size=factor, stride=factor) 200 | layers.append(pool) 201 | elif s[:2] == 'FC': 202 | _, Din, Dout = s.split('-') 203 | Din, Dout = int(Din), int(Dout) 204 | if not flat: 205 | layers.append(Flatten()) 206 | flat = True 207 | layers.append(nn.Linear(Din, Dout)) 208 | if i + 1 < len(arch): 209 | layers.append(get_activation(activation)) 210 | cur_C = Dout 211 | else: 212 | raise ValueError('Invalid layer "%s"' % s) 213 | layers = [layer for layer in layers if layer is not None] 214 | for layer in layers: 215 | print(layer) 216 | return nn.Sequential(*layers), cur_C 217 | 218 | 219 | def build_mlp(dim_list, activation='relu', batch_norm='none', 220 | dropout=0, final_nonlinearity=True): 221 | layers = [] 222 | for i in range(len(dim_list) - 1): 223 | dim_in, dim_out = dim_list[i], dim_list[i + 1] 224 | layers.append(nn.Linear(dim_in, dim_out)) 225 | final_layer = (i == len(dim_list) - 2) 226 | if not final_layer or final_nonlinearity: 227 | if batch_norm == 'batch': 228 | layers.append(nn.BatchNorm1d(dim_out)) 229 | if activation == 'relu': 230 | layers.append(nn.ReLU()) 231 | elif activation == 'leakyrelu': 232 | layers.append(nn.LeakyReLU()) 233 | if dropout > 0: 234 | layers.append(nn.Dropout(p=dropout)) 235 | return nn.Sequential(*layers) 236 | 237 | 
-------------------------------------------------------------------------------- /simsg/layout.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | """ 21 | Functions for computing image layouts from object vectors, bounding boxes, 22 | and segmentation masks. These are used to compute course scene layouts which 23 | are then fed as input to the cascaded refinement network. 24 | """ 25 | 26 | 27 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'): 28 | """ 29 | Inputs: 30 | - vecs: Tensor of shape (O, D) giving vectors 31 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 32 | [x0, y0, x1, y1] in the [0, 1] coordinate space 33 | - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to 34 | an image, where each element is in the range [0, N). If obj_to_img[i] = j 35 | then vecs[i] belongs to image j. 36 | - H, W: Size of the output 37 | 38 | Returns: 39 | - out: Tensor of shape (N, D, H, W) 40 | """ 41 | O, D = vecs.size() 42 | if W is None: 43 | W = H 44 | 45 | grid = _boxes_to_grid(boxes, H, W) 46 | 47 | # If we don't add extra spatial dimensions here then out-of-bounds 48 | # elements won't be automatically set to 0 49 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8) 50 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W) 51 | 52 | # Explicitly masking makes everything quite a bit slower. 53 | # If we rely on implicit masking the interpolated boxes end up 54 | # blurred around the edges, but it should be fine. 55 | # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1) 56 | # sampled[mask[:, None]] = 0 57 | 58 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 59 | 60 | return out 61 | 62 | 63 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum', front_idx=None): 64 | """ 65 | Inputs: 66 | - vecs: Tensor of shape (O, D) giving vectors 67 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 68 | [x0, y0, x1, y1] in the [0, 1] coordinate space 69 | - masks: Tensor of shape (O, M, M) giving binary masks for each object 70 | - obj_to_img: LongTensor of shape (O,) mapping objects to images 71 | - H, W: Size of the output image. 
72 | 73 | Returns: 74 | - out: Tensor of shape (N, D, H, W) 75 | """ 76 | O, D = vecs.size() 77 | M = masks.size(1) 78 | assert masks.size() == (O, M, M) 79 | if W is None: 80 | W = H 81 | 82 | grid = _boxes_to_grid(boxes, H, W) 83 | 84 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M) 85 | 86 | if pooling == 'max': 87 | out = [] 88 | 89 | # group by image, extract feature with maximum mask confidence 90 | for i in range(obj_to_img.max()+1): 91 | 92 | curr_projected_mask = F.grid_sample(masks.view(O, 1, M, M)[(obj_to_img==i).nonzero().view(-1)], 93 | grid[(obj_to_img==i).nonzero().view(-1)]) 94 | curr_sampled = F.grid_sample(img_in[(obj_to_img==i).nonzero().view(-1)], 95 | grid[(obj_to_img==i).nonzero().view(-1)]) 96 | 97 | argmax_mask = torch.argmax(curr_projected_mask, dim=0) 98 | argmax_mask = argmax_mask.repeat(1, D, 1, 1) 99 | out.append(torch.gather(curr_sampled, 0, argmax_mask)) 100 | 101 | out = torch.cat(out, dim=0) 102 | 103 | else: 104 | sampled = F.grid_sample(img_in, grid) 105 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 106 | 107 | #print(out.shape) 108 | return out 109 | 110 | 111 | def _boxes_to_grid(boxes, H, W): 112 | """ 113 | Input: 114 | - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1] 115 | format in the [0, 1] coordinate space 116 | - H, W: Scalars giving size of output 117 | 118 | Returns: 119 | - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample 120 | """ 121 | O = boxes.size(0) 122 | 123 | boxes = boxes.view(O, 4, 1, 1) 124 | 125 | # All these are (O, 1, 1) 126 | x0, y0 = boxes[:, 0], boxes[:, 1] 127 | x1, y1 = boxes[:, 2], boxes[:, 3] 128 | ww = x1 - x0 129 | hh = y1 - y0 130 | 131 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes) 132 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes) 133 | 134 | X = (X - x0) / ww # (O, 1, W) 135 | Y = (Y - y0) / hh # (O, H, 1) 136 | 137 | # Stack does not broadcast its arguments so we need to expand explicitly 138 | X = X.expand(O, H, W) 139 | Y = Y.expand(O, H, W) 140 | grid = torch.stack([X, Y], dim=3) # (O, H, W, 2) 141 | 142 | # Right now grid is in [0, 1] space; transform to [-1, 1] 143 | grid = grid.mul(2).sub(1) 144 | 145 | return grid 146 | 147 | 148 | def _pool_samples(samples, obj_to_img, pooling='sum'): 149 | """ 150 | Input: 151 | - samples: FloatTensor of shape (O, D, H, W) 152 | - obj_to_img: LongTensor of shape (O,) with each element in the range 153 | [0, N) mapping elements of samples to output images 154 | 155 | Output: 156 | - pooled: FloatTensor of shape (N, D, H, W) 157 | """ 158 | dtype, device = samples.dtype, samples.device 159 | O, D, H, W = samples.size() 160 | N = obj_to_img.data.max().item() + 1 161 | 162 | # Use scatter_add to sum the sampled outputs for each image 163 | out = torch.zeros(N, D, H, W, dtype=dtype, device=device) 164 | idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W) 165 | out = out.scatter_add(0, idx, samples) 166 | 167 | if pooling == 'avg': 168 | # Divide each output mask by the number of objects; use scatter_add again 169 | # to count the number of objects per image. 
170 | ones = torch.ones(O, dtype=dtype, device=device) 171 | obj_counts = torch.zeros(N, dtype=dtype, device=device) 172 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones) 173 | print(obj_counts) 174 | obj_counts = obj_counts.clamp(min=1) 175 | out = out / obj_counts.view(N, 1, 1, 1) 176 | 177 | elif pooling != 'sum': 178 | raise ValueError('Invalid pooling "%s"' % pooling) 179 | 180 | return out 181 | 182 | 183 | if __name__ == '__main__': 184 | vecs = torch.FloatTensor([ 185 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 186 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 187 | ]) 188 | boxes = torch.FloatTensor([ 189 | [0.25, 0.125, 0.5, 0.875], 190 | [0, 0, 1, 0.25], 191 | [0.6125, 0, 0.875, 1], 192 | [0, 0.8, 1, 1.0], 193 | [0.25, 0.125, 0.5, 0.875], 194 | [0.6125, 0, 0.875, 1], 195 | ]) 196 | obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1]) 197 | # vecs = torch.FloatTensor([[[1]]]) 198 | # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]]) 199 | vecs, boxes = vecs.cuda(), boxes.cuda() 200 | obj_to_img = obj_to_img.cuda() 201 | out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum') 202 | 203 | from torchvision.utils import save_image 204 | save_image(out.data, 'out.png') 205 | 206 | 207 | masks = torch.FloatTensor([ 208 | [ 209 | [0, 0, 1, 0, 0], 210 | [0, 1, 1, 1, 0], 211 | [1, 1, 1, 1, 1], 212 | [0, 1, 1, 1, 0], 213 | [0, 0, 1, 0, 0], 214 | ], 215 | [ 216 | [0, 0, 1, 0, 0], 217 | [0, 1, 0, 1, 0], 218 | [1, 0, 0, 0, 1], 219 | [0, 1, 0, 1, 0], 220 | [0, 0, 1, 0, 0], 221 | ], 222 | [ 223 | [0, 0, 1, 0, 0], 224 | [0, 1, 1, 1, 0], 225 | [1, 1, 1, 1, 1], 226 | [0, 1, 1, 1, 0], 227 | [0, 0, 1, 0, 0], 228 | ], 229 | [ 230 | [0, 0, 1, 0, 0], 231 | [0, 1, 1, 1, 0], 232 | [1, 1, 1, 1, 1], 233 | [0, 1, 1, 1, 0], 234 | [0, 0, 1, 0, 0], 235 | ], 236 | [ 237 | [0, 0, 1, 0, 0], 238 | [0, 1, 1, 1, 0], 239 | [1, 1, 1, 1, 1], 240 | [0, 1, 1, 1, 0], 241 | [0, 0, 1, 0, 0], 242 | ], 243 | [ 244 | [0, 0, 1, 0, 0], 245 | [0, 1, 1, 1, 0], 246 | [1, 1, 1, 1, 1], 247 | [0, 1, 1, 1, 0], 248 | [0, 0, 1, 0, 0], 249 | ] 250 | ]) 251 | masks = masks.cuda() 252 | out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256, pooling="max") 253 | save_image(out.data, 'out_masks.png') 254 | -------------------------------------------------------------------------------- /simsg/loader_utils.py: -------------------------------------------------------------------------------- 1 | from simsg.data.vg import SceneGraphNoPairsDataset, collate_fn_nopairs 2 | from simsg.data.clevr import SceneGraphWithPairsDataset, collate_fn_withpairs 3 | 4 | import json 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | def build_clevr_supervised_train_dsets(args): 9 | print("building fully supervised %s dataset" % args.dataset) 10 | with open(args.vocab_json, 'r') as f: 11 | vocab = json.load(f) 12 | dset_kwargs = { 13 | 'vocab': vocab, 14 | 'h5_path': args.train_h5, 15 | 'image_dir': args.vg_image_dir, 16 | 'image_size': args.image_size, 17 | 'max_samples': args.num_train_samples, 18 | 'max_objects': args.max_objects_per_image, 19 | 'use_orphaned_objects': args.vg_use_orphaned_objects, 20 | 'include_relationships': args.include_relationships, 21 | } 22 | train_dset = SceneGraphWithPairsDataset(**dset_kwargs) 23 | iter_per_epoch = len(train_dset) // args.batch_size 24 | print('There are %d iterations per epoch' % iter_per_epoch) 25 | 26 | dset_kwargs['h5_path'] = args.val_h5 27 | del dset_kwargs['max_samples'] 28 | val_dset = SceneGraphWithPairsDataset(**dset_kwargs) 29 | 30 | dset_kwargs['h5_path'] = args.test_h5 31 | test_dset = 
SceneGraphWithPairsDataset(**dset_kwargs) 32 | 33 | return vocab, train_dset, val_dset, test_dset 34 | 35 | 36 | def build_dset_nopairs(args, checkpoint): 37 | 38 | vocab = checkpoint['model_kwargs']['vocab'] 39 | dset_kwargs = { 40 | 'vocab': vocab, 41 | 'h5_path': args.data_h5, 42 | 'image_dir': args.data_image_dir, 43 | 'image_size': args.image_size, 44 | 'max_objects': checkpoint['args']['max_objects_per_image'], 45 | 'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'], 46 | 'mode': args.mode, 47 | 'predgraphs': args.predgraphs 48 | } 49 | dset = SceneGraphNoPairsDataset(**dset_kwargs) 50 | 51 | return dset 52 | 53 | 54 | def build_dset_withpairs(args, checkpoint, vocab_t): 55 | 56 | vocab = vocab_t 57 | dset_kwargs = { 58 | 'vocab': vocab, 59 | 'h5_path': args.data_h5, 60 | 'image_dir': args.data_image_dir, 61 | 'image_size': args.image_size, 62 | 'max_objects': checkpoint['args']['max_objects_per_image'], 63 | 'use_orphaned_objects': checkpoint['args']['vg_use_orphaned_objects'], 64 | 'mode': args.mode 65 | } 66 | dset = SceneGraphWithPairsDataset(**dset_kwargs) 67 | 68 | return dset 69 | 70 | 71 | def build_eval_loader(args, checkpoint, vocab_t=None, no_gt=False): 72 | 73 | if args.dataset == 'vg' or (no_gt and args.dataset == 'clevr'): 74 | dset = build_dset_nopairs(args, checkpoint) 75 | collate_fn = collate_fn_nopairs 76 | elif args.dataset == 'clevr': 77 | dset = build_dset_withpairs(args, checkpoint, vocab_t) 78 | collate_fn = collate_fn_withpairs 79 | 80 | loader_kwargs = { 81 | 'batch_size': 1, 82 | 'num_workers': args.loader_num_workers, 83 | 'shuffle': args.shuffle, 84 | 'collate_fn': collate_fn, 85 | } 86 | loader = DataLoader(dset, **loader_kwargs) 87 | 88 | return loader 89 | 90 | 91 | def build_train_dsets(args): 92 | print("building unpaired %s dataset" % args.dataset) 93 | with open(args.vocab_json, 'r') as f: 94 | vocab = json.load(f) 95 | dset_kwargs = { 96 | 'vocab': vocab, 97 | 'h5_path': args.train_h5, 98 | 'image_dir': args.vg_image_dir, 99 | 'image_size': args.image_size, 100 | 'max_samples': args.num_train_samples, 101 | 'max_objects': args.max_objects_per_image, 102 | 'use_orphaned_objects': args.vg_use_orphaned_objects, 103 | 'include_relationships': args.include_relationships, 104 | } 105 | train_dset = SceneGraphNoPairsDataset(**dset_kwargs) 106 | iter_per_epoch = len(train_dset) // args.batch_size 107 | print('There are %d iterations per epoch' % iter_per_epoch) 108 | 109 | dset_kwargs['h5_path'] = args.val_h5 110 | del dset_kwargs['max_samples'] 111 | val_dset = SceneGraphNoPairsDataset(**dset_kwargs) 112 | 113 | return vocab, train_dset, val_dset 114 | 115 | 116 | def build_train_loaders(args): 117 | 118 | print(args.dataset) 119 | if args.dataset == 'vg' or (args.dataset == "clevr" and not args.is_supervised): 120 | vocab, train_dset, val_dset = build_train_dsets(args) 121 | collate_fn = collate_fn_nopairs 122 | elif args.dataset == 'clevr': 123 | vocab, train_dset, val_dset, test_dset = build_clevr_supervised_train_dsets(args) 124 | collate_fn = collate_fn_withpairs 125 | 126 | loader_kwargs = { 127 | 'batch_size': args.batch_size, 128 | 'num_workers': args.loader_num_workers, 129 | 'shuffle': True, 130 | 'collate_fn': collate_fn, 131 | } 132 | train_loader = DataLoader(train_dset, **loader_kwargs) 133 | 134 | loader_kwargs['shuffle'] = args.shuffle_val 135 | val_loader = DataLoader(val_dset, **loader_kwargs) 136 | 137 | return vocab, train_loader, val_loader 138 | 
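A minimal sketch of how the loader builders above are typically driven from an argparse-style namespace; the attribute names mirror the ones read by build_train_loaders, while every concrete value and path below is a placeholder assumption:

# Illustrative sketch only: all paths and hyperparameter values are placeholders.
from argparse import Namespace
from simsg.loader_utils import build_train_loaders

args = Namespace(
    dataset='vg',
    vocab_json='datasets/vg/vocab.json',   # placeholder path
    train_h5='datasets/vg/train.h5',       # placeholder path
    val_h5='datasets/vg/val.h5',           # placeholder path
    vg_image_dir='datasets/vg/images',     # placeholder path
    image_size=(64, 64),
    num_train_samples=None,
    max_objects_per_image=10,
    vg_use_orphaned_objects=True,
    include_relationships=True,
    batch_size=32,
    loader_num_workers=4,
    shuffle_val=False,
    is_supervised=False,
)

vocab, train_loader, val_loader = build_train_loaders(args)

# Each batch follows the tuple layout produced by collate_fn_nopairs.
imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_masked = next(iter(train_loader))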
-------------------------------------------------------------------------------- /simsg/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | 21 | def get_gan_losses(gan_type): 22 | """ 23 | Returns the generator and discriminator loss for a particular GAN type. 24 | 25 | The returned functions have the following API: 26 | loss_g = g_loss(scores_fake) 27 | loss_d = d_loss(scores_real, scores_fake) 28 | """ 29 | if gan_type == 'gan': 30 | return gan_g_loss, gan_d_loss 31 | elif gan_type == 'wgan': 32 | return wgan_g_loss, wgan_d_loss 33 | elif gan_type == 'lsgan': 34 | return lsgan_g_loss, lsgan_d_loss 35 | else: 36 | raise ValueError('Unrecognized GAN type "%s"' % gan_type) 37 | 38 | def gan_percept_loss(real, fake): 39 | 40 | ''' 41 | Inputs: 42 | - real: discriminator feat maps for every layer, when x=real image 43 | - fake: discriminator feat maps for every layer, when x=pred image 44 | Returns: 45 | perceptual loss in all discriminator layers 46 | ''' 47 | 48 | loss = 0 49 | 50 | for i in range(len(real)): 51 | loss += (real[i] - fake[i]).abs().mean() 52 | 53 | return loss / len(real) 54 | 55 | 56 | def bce_loss(input, target): 57 | """ 58 | Numerically stable version of the binary cross-entropy loss function. 59 | 60 | As per https://github.com/pytorch/pytorch/issues/751 61 | See the TensorFlow docs for a derivation of this formula: 62 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits 63 | 64 | Inputs: 65 | - input: PyTorch Tensor of shape (N, ) giving scores. 66 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets. 67 | 68 | Returns: 69 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of 70 | input data. 
71 | """ 72 | neg_abs = -input.abs() 73 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 74 | return loss.mean() 75 | 76 | 77 | def _make_targets(x, y): 78 | """ 79 | Inputs: 80 | - x: PyTorch Tensor 81 | - y: Python scalar 82 | 83 | Outputs: 84 | - out: PyTorch Variable with same shape and dtype as x, but filled with y 85 | """ 86 | return torch.full_like(x, y) 87 | 88 | 89 | def gan_g_loss(scores_fake): 90 | """ 91 | Input: 92 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 93 | 94 | Output: 95 | - loss: Variable of shape (,) giving GAN generator loss 96 | """ 97 | if scores_fake.dim() > 1: 98 | scores_fake = scores_fake.view(-1) 99 | y_fake = _make_targets(scores_fake, 1) 100 | return bce_loss(scores_fake, y_fake) 101 | 102 | 103 | def gan_d_loss(scores_real, scores_fake): 104 | """ 105 | Input: 106 | - scores_real: Tensor of shape (N,) giving scores for real samples 107 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 108 | 109 | Output: 110 | - loss: Tensor of shape (,) giving GAN discriminator loss 111 | """ 112 | assert scores_real.size() == scores_fake.size() 113 | if scores_real.dim() > 1: 114 | scores_real = scores_real.view(-1) 115 | scores_fake = scores_fake.view(-1) 116 | y_real = _make_targets(scores_real, 1) 117 | y_fake = _make_targets(scores_fake, 0) 118 | loss_real = bce_loss(scores_real, y_real) 119 | loss_fake = bce_loss(scores_fake, y_fake) 120 | return loss_real + loss_fake 121 | 122 | 123 | def wgan_g_loss(scores_fake): 124 | """ 125 | Input: 126 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 127 | 128 | Output: 129 | - loss: Tensor of shape (,) giving WGAN generator loss 130 | """ 131 | return -scores_fake.mean() 132 | 133 | 134 | def wgan_d_loss(scores_real, scores_fake): 135 | """ 136 | Input: 137 | - scores_real: Tensor of shape (N,) giving scores for real samples 138 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 139 | 140 | Output: 141 | - loss: Tensor of shape (,) giving WGAN discriminator loss 142 | """ 143 | return scores_fake.mean() - scores_real.mean() 144 | 145 | 146 | def lsgan_g_loss(scores_fake): 147 | if scores_fake.dim() > 1: 148 | scores_fake = scores_fake.view(-1) 149 | y_fake = _make_targets(scores_fake, 1) 150 | return F.mse_loss(scores_fake.sigmoid(), y_fake) 151 | 152 | 153 | def lsgan_d_loss(scores_real, scores_fake): 154 | assert scores_real.size() == scores_fake.size() 155 | if scores_real.dim() > 1: 156 | scores_real = scores_real.view(-1) 157 | scores_fake = scores_fake.view(-1) 158 | y_real = _make_targets(scores_real, 1) 159 | y_fake = _make_targets(scores_fake, 0) 160 | loss_real = F.mse_loss(scores_real.sigmoid(), y_real) 161 | loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake) 162 | return loss_real + loss_fake 163 | 164 | 165 | def gradient_penalty(x_real, x_fake, f, gamma=1.0): 166 | N = x_real.size(0) 167 | device, dtype = x_real.device, x_real.dtype 168 | eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype) 169 | x_hat = eps * x_real + (1 - eps) * x_fake 170 | x_hat_score = f(x_hat) 171 | if x_hat_score.dim() > 1: 172 | x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1) 173 | x_hat_score = x_hat_score.sum() 174 | grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True) 175 | grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1) 176 | gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean() 177 | return gp_loss 178 | 179 | """ 180 | Copyright (C) 
2019 NVIDIA Corporation. All rights reserved. 181 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 182 | """ 183 | 184 | import torch 185 | import torch.nn as nn 186 | #import torch.nn.functional as F 187 | from simsg.SPADE.architectures import VGG19 188 | 189 | 190 | # SPADE losses! # 191 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 192 | # When LSGAN is used, it is basically same as MSELoss, 193 | # but it abstracts away the need to create the target label tensor 194 | # that has the same size as the input 195 | class GANLoss(nn.Module): 196 | def __init__(self, gan_mode='hinge', target_real_label=1.0, target_fake_label=0.0, 197 | tensor=torch.cuda.FloatTensor, opt=None): 198 | super(GANLoss, self).__init__() 199 | self.real_label = target_real_label 200 | self.fake_label = target_fake_label 201 | self.real_label_tensor = None 202 | self.fake_label_tensor = None 203 | self.zero_tensor = None 204 | self.Tensor = tensor 205 | self.gan_mode = gan_mode 206 | self.opt = opt 207 | if gan_mode == 'ls': 208 | pass 209 | elif gan_mode == 'original': 210 | pass 211 | elif gan_mode == 'w': 212 | pass 213 | elif gan_mode == 'hinge': 214 | pass 215 | else: 216 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 217 | 218 | def get_target_tensor(self, input, target_is_real): 219 | if target_is_real: 220 | if self.real_label_tensor is None: 221 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 222 | self.real_label_tensor.requires_grad_(False) 223 | return self.real_label_tensor.expand_as(input) 224 | else: 225 | if self.fake_label_tensor is None: 226 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 227 | self.fake_label_tensor.requires_grad_(False) 228 | return self.fake_label_tensor.expand_as(input) 229 | 230 | def get_zero_tensor(self, input): 231 | if self.zero_tensor is None: 232 | self.zero_tensor = self.Tensor(1).fill_(0) 233 | self.zero_tensor.requires_grad_(False) 234 | return self.zero_tensor.expand_as(input) 235 | 236 | def loss(self, input, target_is_real, for_discriminator=True): 237 | if self.gan_mode == 'original': # cross entropy loss 238 | target_tensor = self.get_target_tensor(input, target_is_real) 239 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 240 | return loss 241 | elif self.gan_mode == 'ls': 242 | target_tensor = self.get_target_tensor(input, target_is_real) 243 | return F.mse_loss(input, target_tensor) 244 | elif self.gan_mode == 'hinge': 245 | if for_discriminator: 246 | if target_is_real: 247 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 248 | loss = -torch.mean(minval) 249 | else: 250 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 251 | loss = -torch.mean(minval) 252 | else: 253 | assert target_is_real, "The generator's hinge loss must be aiming for real" 254 | loss = -torch.mean(input) 255 | return loss 256 | else: 257 | # wgan 258 | if target_is_real: 259 | return -input.mean() 260 | else: 261 | return input.mean() 262 | 263 | def __call__(self, input, target_is_real, for_discriminator=True): 264 | # computing loss is a bit complicated because |input| may not be 265 | # a tensor, but list of tensors in case of multiscale discriminator 266 | if isinstance(input, list): 267 | loss = 0 268 | for pred_i in input: 269 | if isinstance(pred_i, list): 270 | pred_i = pred_i[-1] 271 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 272 | bs = 1 if len(loss_tensor.size()) == 0 else 
loss_tensor.size(0) 273 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 274 | loss += new_loss 275 | return loss / len(input) 276 | else: 277 | return self.loss(input, target_is_real, for_discriminator) 278 | 279 | 280 | # Perceptual loss that uses a pretrained VGG network 281 | class VGGLoss(nn.Module): 282 | def __init__(self): 283 | super(VGGLoss, self).__init__() 284 | self.vgg = VGG19().cuda() 285 | self.criterion = nn.L1Loss() 286 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 287 | 288 | def forward(self, x, y): 289 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 290 | loss = 0 291 | for i in range(len(x_vgg)): 292 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 293 | return loss 294 | 295 | 296 | # KL Divergence loss used in VAE with an image encoder 297 | class KLDLoss(nn.Module): 298 | def forward(self, mu, logvar): 299 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 300 | 301 | 302 | -------------------------------------------------------------------------------- /simsg/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import numpy as np 19 | from scipy import signal 20 | from scipy.ndimage.filters import convolve 21 | from PIL import Image 22 | 23 | 24 | def intersection(bbox_pred, bbox_gt): 25 | max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:]) 26 | min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2]) 27 | inter = torch.clamp((max_xy - min_xy), min=0) 28 | return inter[:, 0] * inter[:, 1] 29 | 30 | 31 | def jaccard(bbox_pred, bbox_gt): 32 | inter = intersection(bbox_pred, bbox_gt) 33 | area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] - 34 | bbox_pred[:, 1]) 35 | area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] - 36 | bbox_gt[:, 1]) 37 | union = area_pred + area_gt - inter 38 | iou = torch.div(inter, union) 39 | return torch.sum(iou) 40 | 41 | def get_total_norm(parameters, norm_type=2): 42 | if norm_type == float('inf'): 43 | total_norm = max(p.grad.data.abs().max() for p in parameters) 44 | else: 45 | total_norm = 0 46 | for p in parameters: 47 | try: 48 | param_norm = p.grad.data.norm(norm_type) 49 | total_norm += param_norm ** norm_type 50 | total_norm = total_norm ** (1. 
/ norm_type) 51 | except: 52 | continue 53 | return total_norm 54 | 55 | 56 | def _FSpecialGauss(size, sigma): 57 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 58 | radius = size // 2 59 | offset = 0.0 60 | start, stop = -radius, radius + 1 61 | if size % 2 == 0: 62 | offset = 0.5 63 | stop -= 1 64 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 65 | assert len(x) == size 66 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) 67 | return g / g.sum() 68 | 69 | 70 | def _SSIMForMultiScale(img1, 71 | img2, 72 | max_val=255, 73 | filter_size=11, 74 | filter_sigma=1.5, 75 | k1=0.01, 76 | k2=0.03): 77 | """Return the Structural Similarity Map between `img1` and `img2`. 78 | This function attempts to match the functionality of ssim_index_new.m by 79 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 80 | Arguments: 81 | img1: Numpy array holding the first RGB image batch. 82 | img2: Numpy array holding the second RGB image batch. 83 | max_val: the dynamic range of the images (i.e., the difference between the 84 | maximum the and minimum allowed values). 85 | filter_size: Size of blur kernel to use (will be reduced for small images). 86 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 87 | for small images). 88 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 89 | the original paper). 90 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 91 | the original paper). 92 | Returns: 93 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 94 | `img2`. 95 | Raises: 96 | RuntimeError: If input images don't have the same shape or don't have four 97 | dimensions: [batch_size, height, width, depth]. 98 | """ 99 | if img1.shape != img2.shape: 100 | raise RuntimeError( 101 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 102 | img2.shape) 103 | if img1.ndim != 4: 104 | raise RuntimeError('Input images must have four dimensions, not %d', 105 | img1.ndim) 106 | 107 | img1 = img1.astype(np.float64) 108 | img2 = img2.astype(np.float64) 109 | _, height, width, _ = img1.shape 110 | 111 | # Filter size can't be larger than height or width of images. 112 | size = min(filter_size, height, width) 113 | 114 | # Scale down sigma if a smaller filter size is used. 115 | sigma = size * filter_sigma / filter_size if filter_size else 0 116 | 117 | if filter_size: 118 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 119 | mu1 = signal.fftconvolve(img1, window, mode='valid') 120 | mu2 = signal.fftconvolve(img2, window, mode='valid') 121 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 122 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 123 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 124 | else: 125 | # Empty blur kernel so no need to convolve. 126 | mu1, mu2 = img1, img2 127 | sigma11 = img1 * img1 128 | sigma22 = img2 * img2 129 | sigma12 = img1 * img2 130 | 131 | mu11 = mu1 * mu1 132 | mu22 = mu2 * mu2 133 | mu12 = mu1 * mu2 134 | sigma11 -= mu11 135 | sigma22 -= mu22 136 | sigma12 -= mu12 137 | 138 | # Calculate intermediate values used by both ssim and cs_map. 
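  # For reference, the lines below implement the standard SSIM expression
  #   SSIM = ((2*mu1*mu2 + C1) * (2*sigma12 + C2))
  #          / ((mu1^2 + mu2^2 + C1) * (sigma1^2 + sigma2^2 + C2))
  # with C1 = (k1*max_val)^2, C2 = (k2*max_val)^2, and the sigmas already
  # mean-centered above; cs keeps only the contrast/structure factor
  # (2*sigma12 + C2) / (sigma1^2 + sigma2^2 + C2), which is what MS-SSIM
  # accumulates at the coarser scales.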
139 | c1 = (k1 * max_val)**2 140 | c2 = (k2 * max_val)**2 141 | v1 = 2.0 * sigma12 + c2 142 | v2 = sigma11 + sigma22 + c2 143 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 144 | cs = np.mean(v1 / v2) 145 | return ssim, cs 146 | 147 | 148 | def MultiScaleSSIM(img1, 149 | img2, 150 | max_val=255, 151 | filter_size=11, 152 | filter_sigma=1.5, 153 | k1=0.01, 154 | k2=0.03, 155 | weights=None): 156 | """Return the MS-SSIM score between `img1` and `img2`. 157 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 158 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 159 | similarity for image quality assessment" (2003). 160 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 161 | Author's MATLAB implementation: 162 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 163 | Arguments: 164 | img1: Numpy array holding the first RGB image batch. 165 | img2: Numpy array holding the second RGB image batch. 166 | max_val: the dynamic range of the images (i.e., the difference between the 167 | maximum the and minimum allowed values). 168 | filter_size: Size of blur kernel to use (will be reduced for small images). 169 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 170 | for small images). 171 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 172 | the original paper). 173 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 174 | the original paper). 175 | weights: List of weights for each level; if none, use five levels and the 176 | weights from the original paper. 177 | Returns: 178 | MS-SSIM score between `img1` and `img2`. 179 | Raises: 180 | RuntimeError: If input images don't have the same shape or don't have four 181 | dimensions: [batch_size, height, width, depth]. 182 | """ 183 | if img1.shape != img2.shape: 184 | raise RuntimeError( 185 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 186 | img2.shape) 187 | if img1.ndim != 4: 188 | raise RuntimeError('Input images must have four dimensions, not %d', 189 | img1.ndim) 190 | 191 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 192 | weights = np.array(weights if weights else 193 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 194 | levels = weights.size 195 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 196 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 197 | mssim = np.array([]) 198 | mcs = np.array([]) 199 | for _ in range(levels): 200 | ssim, cs = _SSIMForMultiScale( 201 | im1, 202 | im2, 203 | max_val=max_val, 204 | filter_size=filter_size, 205 | filter_sigma=filter_sigma, 206 | k1=k1, 207 | k2=k2) 208 | mssim = np.append(mssim, ssim) 209 | mcs = np.append(mcs, cs) 210 | filtered = [ 211 | convolve(im, downsample_filter, mode='reflect') 212 | for im in [im1, im2] 213 | ] 214 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 215 | return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) * 216 | (mssim[levels - 1]**weights[levels - 1])) 217 | 218 | 219 | def msssim(original, compared): 220 | if isinstance(original, str): 221 | original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) 222 | if isinstance(compared, str): 223 | compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) 224 | 225 | original = original[None, ...] if original.ndim == 3 else original 226 | compared = compared[None, ...] 
if compared.ndim == 3 else compared 227 | 228 | return MultiScaleSSIM(original, compared, max_val=255) 229 | 230 | 231 | def psnr(original, compared): 232 | if isinstance(original, str): 233 | original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) 234 | if isinstance(compared, str): 235 | compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) 236 | 237 | mse = np.mean(np.square(original - compared)) 238 | psnr = np.clip( 239 | np.multiply(np.log10(255. * 255. / mse[mse > 0.]), 10.), 0., 99.99)[0] 240 | return psnr 241 | 242 | 243 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 244 | """Match each prior box with the ground truth box of the highest jaccard 245 | overlap, encode the bounding boxes, then return the matched indices 246 | corresponding to both confidence and location preds. 247 | Args: 248 | threshold: (float) The overlap threshold used when mathing boxes. 249 | truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. 250 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 251 | variances: (tensor) Variances corresponding to each prior coord, 252 | Shape: [num_priors, 4]. 253 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 254 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 255 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 256 | idx: (int) current batch index 257 | Return: 258 | The matched indices corresponding to 1)location and 2)confidence preds. 259 | """ 260 | # jaccard index 261 | overlaps = jaccard( 262 | truths, 263 | point_form(priors) 264 | ) 265 | # (Bipartite Matching) 266 | # [1,num_objects] best prior for each ground truth 267 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 268 | # [1,num_priors] best ground truth for each prior 269 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 270 | best_truth_idx.squeeze_(0) 271 | best_truth_overlap.squeeze_(0) 272 | best_prior_idx.squeeze_(1) 273 | best_prior_overlap.squeeze_(1) 274 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 275 | # TODO refactor: index best_prior_idx with long tensor 276 | # ensure every gt matches with its prior of max overlap 277 | for j in range(best_prior_idx.size(0)): 278 | best_truth_idx[best_prior_idx[j]] = j 279 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 280 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors] 281 | conf[best_truth_overlap < threshold] = 0 # label as background 282 | loc = encode(matches, priors, variances) 283 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 284 | conf_t[idx] = conf # [num_priors] top class label for each prior 285 | 286 | def point_form(boxes): 287 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 288 | representation for comparison to point form ground truth data. 289 | Args: 290 | boxes: (tensor) center-size default boxes from priorbox layers. 291 | Return: 292 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 293 | """ 294 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 295 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 296 | 297 | def encode(matched, priors, variances): 298 | """Encode the variances from the priorbox layers into the ground truth boxes 299 | we have matched (based on jaccard overlap) with the prior boxes. 300 | Args: 301 | matched: (tensor) Coords of ground truth for each prior in point-form 302 | Shape: [num_priors, 4]. 
303 | priors: (tensor) Prior boxes in center-offset form 304 | Shape: [num_priors,4]. 305 | variances: (list[float]) Variances of priorboxes 306 | Return: 307 | encoded boxes (tensor), Shape: [num_priors, 4] 308 | """ 309 | 310 | # dist b/t match center and prior's center 311 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 312 | # encode variance 313 | g_cxcy /= (variances[0] * priors[:, 2:]) 314 | # match wh / prior wh 315 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 316 | g_wh = torch.log(g_wh) / variances[1] 317 | # return target for smooth_l1_loss 318 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 319 | 320 | def find_intersection(set_1, set_2): 321 | """ 322 | Find the intersection of every box combination between two sets of boxes that are in boundary coordinates. 323 | :param set_1: set 1, a tensor of dimensions (n1, 4) 324 | :param set_2: set 2, a tensor of dimensions (n2, 4) 325 | :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) 326 | """ 327 | 328 | # PyTorch auto-broadcasts singleton dimensions 329 | lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0)) # (n1, n2, 2) 330 | upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0)) # (n1, n2, 2) 331 | intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2) 332 | return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2) 333 | -------------------------------------------------------------------------------- /simsg/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
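# Module overview: argparse-style type converters (int_tuple, float_tuple,
# str_tuple, bool_flag), a GPU memory query that shells out to nvidia-smi,
# a CUDA-synchronized timing context manager (timeit), and a LossManager
# that accumulates weighted loss terms while recording each term's value.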
16 | 17 | import time 18 | import inspect 19 | import subprocess 20 | from contextlib import contextmanager 21 | 22 | import torch 23 | 24 | 25 | def int_tuple(s): 26 | return tuple(int(i) for i in s.split(',')) 27 | 28 | 29 | def float_tuple(s): 30 | return tuple(float(i) for i in s.split(',')) 31 | 32 | 33 | def str_tuple(s): 34 | return tuple(s.split(',')) 35 | 36 | 37 | def bool_flag(s): 38 | if s == '1' or s == 'True' or s == 'true': 39 | return True 40 | elif s == '0' or s == 'False' or s == 'false': 41 | return False 42 | msg = 'Invalid value "%s" for bool flag (should be 0/1 or True/False or true/false)' 43 | raise ValueError(msg % s) 44 | 45 | 46 | def lineno(): 47 | return inspect.currentframe().f_back.f_lineno 48 | 49 | 50 | def get_gpu_memory(): 51 | torch.cuda.synchronize() 52 | opts = [ 53 | 'nvidia-smi', '-q', '--gpu=' + str(0), '|', 'grep', '"Used GPU Memory"' 54 | ] 55 | cmd = str.join(' ', opts) 56 | ps = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) 57 | output = ps.communicate()[0].decode('utf-8') 58 | output = output.split("\n")[1].split(":") 59 | consumed_mem = int(output[1].strip().split(" ")[0]) 60 | return consumed_mem 61 | 62 | 63 | @contextmanager 64 | def timeit(msg, should_time=True): 65 | if should_time: 66 | torch.cuda.synchronize() 67 | t0 = time.time() 68 | yield 69 | if should_time: 70 | torch.cuda.synchronize() 71 | t1 = time.time() 72 | duration = (t1 - t0) * 1000.0 73 | print('%s: %.2f ms' % (msg, duration)) 74 | 75 | 76 | class LossManager(object): 77 | def __init__(self): 78 | self.total_loss = None 79 | self.all_losses = {} 80 | 81 | def add_loss(self, loss, name, weight=1.0): 82 | cur_loss = loss * weight 83 | if self.total_loss is not None: 84 | self.total_loss += cur_loss 85 | else: 86 | self.total_loss = cur_loss 87 | 88 | self.all_losses[name] = cur_loss.data.cpu().item() 89 | 90 | def items(self): 91 | return self.all_losses.items() 92 | 93 | -------------------------------------------------------------------------------- /simsg/vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import tempfile, os 18 | import torch 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | from matplotlib.patches import Rectangle 22 | from imageio import imread 23 | 24 | 25 | """ 26 | Utilities for making visualizations. 
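The main entry points are draw_layout (renders per-object masks and boxes onto
a blank canvas with matplotlib), draw_box (draws a single labelled bounding
box), and draw_scene_graph (renders a scene graph to an image by shelling out
to GraphViz).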
27 | """ 28 | 29 | 30 | def draw_layout(vocab, objs, boxes, masks=None, size=256, 31 | show_boxes=False, bgcolor=(0, 0, 0)): 32 | if bgcolor == 'white': 33 | bgcolor = (255, 255, 255) 34 | 35 | cmap = plt.get_cmap('rainbow') 36 | colors = cmap(np.linspace(0, 1, len(objs))) 37 | 38 | with torch.no_grad(): 39 | objs = objs.cpu().clone() 40 | boxes = boxes.cpu().clone() 41 | boxes *= size 42 | 43 | if masks is not None: 44 | masks = masks.cpu().clone() 45 | 46 | bgcolor = np.asarray(bgcolor) 47 | bg = np.ones((size, size, 1)) * bgcolor 48 | plt.imshow(bg.astype(np.uint8)) 49 | 50 | plt.gca().set_xlim(0, size) 51 | plt.gca().set_ylim(size, 0) 52 | plt.gca().set_aspect(1.0, adjustable='box') 53 | 54 | for i, obj in enumerate(objs): 55 | name = vocab['object_idx_to_name'][obj] 56 | if name == '__image__': 57 | continue 58 | box = boxes[i] 59 | 60 | if masks is None: 61 | continue 62 | mask = masks[i].numpy() 63 | mask /= mask.max() 64 | 65 | r, g, b, a = colors[i] 66 | colored_mask = mask[:, :, None] * np.asarray(colors[i]) 67 | 68 | x0, y0, x1, y1 = box 69 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0), 70 | interpolation='bicubic', alpha=1.0) 71 | 72 | if show_boxes: 73 | for i, obj in enumerate(objs): 74 | name = vocab['object_idx_to_name'][obj] 75 | if name == '__image__': 76 | continue 77 | box = boxes[i] 78 | 79 | draw_box(box, colors[i], name) 80 | 81 | 82 | def draw_box(box, color, text=None): 83 | """ 84 | Draw a bounding box using pyplot, optionally with a text box label. 85 | 86 | Inputs: 87 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H] 88 | coordinate system. 89 | - color: pyplot color to use for the box. 90 | - text: (Optional) String; if provided then draw a label for this box. 91 | """ 92 | TEXT_BOX_HEIGHT = 10 93 | if torch.is_tensor(box) and box.dim() == 2: 94 | box = box.view(-1) 95 | assert box.size(0) == 4 96 | x0, y0, x1, y1 = box 97 | assert y1 > y0, box 98 | assert x1 > x0, box 99 | w, h = x1 - x0, y1 - y0 100 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color) 101 | plt.gca().add_patch(rect) 102 | if text is not None: 103 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5) 104 | plt.gca().add_patch(text_rect) 105 | tx = 0.5 * (x0 + x1) 106 | ty = y0 + TEXT_BOX_HEIGHT / 2.0 107 | plt.text(tx, ty, text, va='center', ha='center') 108 | 109 | 110 | def draw_scene_graph(objs, triples, vocab=None, **kwargs): 111 | """ 112 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume 113 | that objs and triples are python lists containing strings for object and 114 | relationship names. 115 | 116 | Using this requires that GraphViz is installed. 
On Ubuntu 16.04 this is easy: 117 | sudo apt-get install graphviz 118 | """ 119 | output_filename = kwargs.pop('output_filename', 'graph.png') 120 | orientation = kwargs.pop('orientation', 'LR') 121 | edge_width = kwargs.pop('edge_width', 6) 122 | arrow_size = kwargs.pop('arrow_size', 1.5) 123 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2) 124 | ignore_dummies = kwargs.pop('ignore_dummies', True) 125 | 126 | #if orientation not in ['V', 'H']: 127 | # raise ValueError('Invalid orientation "%s"' % orientation) 128 | rankdir = orientation # {'H': 'LR', 'V': 'TD'}[orientation] 129 | 130 | if vocab is not None: 131 | # Decode object and relationship names 132 | assert torch.is_tensor(objs) 133 | assert torch.is_tensor(triples) 134 | objs_list, triples_list = [], [] 135 | for i in range(objs.size(0)): 136 | objs_list.append(vocab['object_idx_to_name'][objs[i].item()]) 137 | for i in range(triples.size(0)): 138 | s = triples[i, 0].item() 139 | #print(vocab['pred_idx_to_name'], triples[i, 1].item()) 140 | p = vocab['pred_idx_to_name'][triples[i, 1].item()] 141 | o = triples[i, 2].item() 142 | triples_list.append([s, p, o]) 143 | objs, triples = objs_list, triples_list 144 | 145 | # General setup, and style for object nodes 146 | lines = [ 147 | 'digraph{', 148 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="white"]', 149 | 'rankdir=%s' % rankdir, 150 | 'nodesep="0.5"', 151 | 'ranksep="0.5"', 152 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]', 153 | 'node [fillcolor="lightpink1"]', 154 | ] 155 | # Output nodes for objects 156 | for i, obj in enumerate(objs): 157 | if ignore_dummies and obj == '__image__': 158 | continue 159 | lines.append('%d [label="%s"]' % (i, obj)) 160 | 161 | # Output relationships 162 | next_node_id = len(objs) 163 | lines.append('node [fillcolor="lightblue1"]') 164 | for s, p, o in triples: 165 | if ignore_dummies and p == '__in_image__': 166 | continue 167 | lines += [ 168 | '%d [label="%s"]' % (next_node_id, p), 169 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 170 | s, next_node_id, edge_width, arrow_size, binary_edge_weight), 171 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 172 | next_node_id, o, edge_width, arrow_size, binary_edge_weight) 173 | ] 174 | next_node_id += 1 175 | lines.append('}') 176 | 177 | # Now it gets slightly hacky. Write the graphviz spec to a temporary 178 | # text file 179 | ff, dot_filename = tempfile.mkstemp() 180 | #dot_filename = "graph.dot" 181 | with open(dot_filename, 'w') as f: 182 | for line in lines: 183 | f.write('%s\n' % line) 184 | os.close(ff) 185 | 186 | # Shell out to invoke graphviz; this will save the resulting image to disk, 187 | # so we read it, delete it, then return it. 
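  # Note: the command below pipes through several GraphViz tools (ccomps, dot,
  # gvpack, neato), so all of them must be available on PATH, not just `dot`.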
188 | output_format = os.path.splitext(output_filename)[1][1:] 189 | #os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename)) 190 | os.system('ccomps -x %s | dot | gvpack -array3 | neato -Tpng -n2 -o %s' % (dot_filename, output_filename)) 191 | os.remove(dot_filename) 192 | img = imread(output_filename) 193 | os.remove(output_filename) 194 | 195 | return img 196 | 197 | 198 | if __name__ == '__main__': 199 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard'] 200 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above'] 201 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)} 202 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)} 203 | vocab = { 204 | 'object_idx_to_name': o_idx_to_name, 205 | 'object_name_to_idx': o_name_to_idx, 206 | 'pred_idx_to_name': p_idx_to_name, 207 | 'pred_name_to_idx': p_name_to_idx, 208 | } 209 | 210 | objs = [ 211 | 'cat', 212 | 'cat', 213 | 'skateboard', 214 | 'hat', 215 | ] 216 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs]) 217 | triples = [ 218 | [0, 'next to', 1], 219 | [0, 'riding', 2], 220 | [1, 'wearing', 3], 221 | [3, 'above', 2], 222 | ] 223 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples] 224 | triples = torch.LongTensor(triples) 225 | 226 | draw_scene_graph(objs, triples, vocab, orientation='V') 227 | 228 | --------------------------------------------------------------------------------
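For orientation, a minimal sketch of how the GAN loss helpers from simsg/losses.py and the box IoU metric from simsg/metrics.py are called; the tensors below are random stand-ins for real discriminator scores and predicted boxes:

import torch
from simsg.losses import get_gan_losses
from simsg.metrics import jaccard

g_loss_fn, d_loss_fn = get_gan_losses('gan')

scores_real = torch.randn(8)                   # stand-in scores for real images
scores_fake = torch.randn(8)                   # stand-in scores for generated images
d_loss = d_loss_fn(scores_real, scores_fake)   # pushes real -> 1, fake -> 0
g_loss = g_loss_fn(scores_fake)                # pushes fake -> 1

# IoU between boxes in (x0, y0, x1, y1) format, summed over the batch.
bbox_pred = torch.tensor([[0.10, 0.10, 0.60, 0.60]])
bbox_gt = torch.tensor([[0.20, 0.20, 0.70, 0.70]])
iou_sum = jaccard(bbox_pred, bbox_gt)

print(d_loss.item(), g_loss.item(), iou_sum.item())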