├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── data_scripts ├── clevr_with_masks.py ├── download_clevr.sh ├── download_clevrtex.sh ├── preprocess_clevr_with_masks.py ├── preprocess_tetrominoes.py └── tetrominoes.py ├── media ├── gnm_clevr6.png ├── gnm_clevrtex6.png ├── sa_clevr10_with_masks.png ├── sa_clevr6_example.png ├── sa_clevrtex6.png ├── sa_sketchy_with_masks.png └── slate_clevr6.png ├── object_discovery ├── __init__.py ├── data.py ├── gnm │ ├── __init__.py │ ├── config.py │ ├── gnm_model.py │ ├── logging.py │ ├── metrics.py │ ├── module.py │ ├── submodule.py │ └── utils.py ├── method.py ├── params.py ├── segmentation_metrics.py ├── slate_model.py ├── slot_attention_model.py ├── train.py ├── transformer.py └── utils.py ├── poetry.lock ├── predict.py ├── pyproject.toml ├── run.sh └── small_logo.png /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | checkpoint_details.json 3 | .idea/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Custom 136 | wandb/ 137 | slot-attention-clevr/ 138 | .vscode/ 139 | data/ 140 | debug/ 141 | saved_checkpoints/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object Discovery PyTorch 2 | 3 | 4 | 5 | > This is an implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch. 
6 | 7 | [![GitHub license](https://img.shields.io/github/license/HHousen/object-discovery-pytorch.svg)](https://github.com/HHousen/object-discovery-pytorch/blob/master/LICENSE) [![Github commits](https://img.shields.io/github/last-commit/HHousen/object-discovery-pytorch.svg)](https://github.com/HHousen/object-discovery-pytorch/commits/master) [![made-with-python](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org/) [![GitHub issues](https://img.shields.io/github/issues/HHousen/object-discovery-pytorch.svg)](https://GitHub.com/HHousen/object-discovery-pytorch/issues/) [![GitHub pull-requests](https://img.shields.io/github/issues-pr/HHousen/object-discovery-pytorch.svg)](https://GitHub.com/HHousen/object-discovery-pytorch/pull/) 8 | 9 | The initial code for this repo was forked from [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention). 10 | 11 | ![Visualization of a slot attention model trained on CLEVR6. This image demonstrates the model's ability to divide objects into slots.](./media/sa_clevr6_example.png) 12 | 13 | ## Setup 14 | 15 | ### Requirements 16 | 17 | - [Poetry](https://python-poetry.org/docs/) 18 | - Python >= 3.9 19 | - CUDA-enabled computing device 20 | 21 | ### Getting Started 22 | 23 | 1. Clone the repo: `git clone https://github.com/HHousen/object-discovery-pytorch/ && cd object-discovery-pytorch`. 24 | 2. Install requirements and activate environment: `poetry install` then `poetry shell`. 25 | 3. Download the [CLEVR (with masks)](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/clevr_with_masks.h5) dataset (or the original [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/) dataset by running `./data_scripts/download_clevr.sh /tmp/CLEVR`). More details about the datasets are below. 26 | 4. Modify the hyperparameters in [object_discovery/params.py](object_discovery/params.py) to fit your needs. Make sure to change `data_root` to the location of your dataset. 27 | 5. Train a model: `python -m object_discovery.train`. 28 | 29 | ## Pre-trained Models 30 | 31 | Code to load these models can be adapted from [predict.py](./predict.py). 32 | 33 | | Model | Dataset | Download | 34 | |---|---|---| 35 | | Slot Attention | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_masks-epoch=673-step=275666-r4nbi6n7.ckpt) | 36 | | Slot Attention | Sketchy | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/sketchy_sa-epoch=59-step=316440-3nofluv3.ckpt) | 37 | | GNM | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_gnm_0.4std-epoch=546-step=223723-3fi81x4s.ckpt) | 38 | | Slot Attention | ClevrTex6 | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevrtex_sa-epoch=843-step=263328-p0hcfm7j.ckpt) | 39 | | GNM | ClevrTex6 | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevrtex_gnm_0.7std-epoch=890-step=277992-35w3i9mg.ckpt) | 40 | | SLATE | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_slate_no_rescale-1pvy8x3x/epoch=708-step=371516.ckpt) | 41 | 42 | ## Usage 43 | 44 | Train a model by running `python -m object_discovery.train`. 45 | 46 | Hyperparameters can be changed in [object_discovery/params.py](object_discovery/params.py). 
`training_params` has global parameters that apply to all model types. These parameters can be overridden if the same key is present in `slot_attention_params` or `slate_params`. Change the global parameter `model_type` to `sa` to use Slot Attention (`SlotAttentionModel` in slot_attention_model.py) or `slate` to use SLATE (`SLATE` in slate_model.py). This will determine which model's set of parameters will be merged with `training_params`. 47 | 48 | Perform inference by modifying and running the [predict.py](./predict.py) script. 49 | 50 | ### Models 51 | 52 | Our implementations are based on several open-source repositories. 53 | 54 | 1. Slot Attention (["Object-Centric Learning with Slot Attention"](https://arxiv.org/abs/2006.15055)): [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention) & [Official](https://github.com/google-research/google-research/tree/master/slot_attention) 55 | 2. SLATE (["Illiterate DALL-E Learns to Compose"](https://arxiv.org/abs/2110.11405)): [Official](https://github.com/singhgautam/slate) 56 | 3. GNM (["Generative Neurosymbolic Machines"](https://arxiv.org/abs/2010.12152)): [karazijal/clevrtex](https://github.com/karazijal/clevrtex) & [Official](https://github.com/JindongJiang/GNM) 57 | 58 | ### Datasets 59 | 60 | Select a dataset by changing the `dataset` parameter in [object_discovery/params.py](object_discovery/params.py) to the name of the dataset: `clevr`, `shapes3d`, or `ravens`. Then, set the `data_root` parameter to the location of the data. The code for loading supported datasets is in [object_discovery/data.py](object_discovery/data.py). 61 | 62 | 1. [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/): Download by executing [download_clevr.sh](./data_scripts/download_clevr.sh). 63 | 2. [CLEVR (with masks)](https://github.com/deepmind/multi_object_datasets#clevr-with-masks): [Original TFRecords Download](https://console.cloud.google.com/storage/browser/multi-object-datasets/clevr_with_masks) / [Our HDF5 PyTorch Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/clevr_with_masks.h5). 64 | - This dataset is a regenerated version of CLEVR but with ground-truth segmentation masks. This enables the training script to calculate Adjusted Rand Index (ARI) during validation runs. 65 | - The dataset contains 100,000 images with a resolution of 240x320 pixels. The dataloader splits them 70K train, 15K validation, 15k test. Test images are not used by the [object_discovery/train.py](object_discovery/train.py) script. 66 | - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the [data_scripts/preprocess_clevr_with_masks.py](./data_scripts/preprocess_clevr_with_masks.py) script, which takes approximately 2 hours to execute depending on your machine. 67 | 3. [3D Shapes](https://github.com/deepmind/3d-shapes): [Official Google Cloud Bucket](https://console.cloud.google.com/storage/browser/3d-shapes) 68 | 4. RAVENS Robot Data: [Train](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/RAVENS%20Robot%20Dataset/ravens_robot_data_train.h5) & [Test](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/RAVENS%20Robot%20Dataset/ravens_robot_data_test.h5) 69 | - We generated a dataset similar in structure to CLEVR (with masks) but of robotic images using [RAVENS](https://github.com/google-research/ravens). Our modified version of RAVENS used to generate the dataset is [HHousen/ravens](https://github.com/HHousen/ravens). 
70 | - The dataset contains 85,002 images split 70,002 train and 15K validation/test. 71 | 5. Sketchy: Download and process by following directions in [applied-ai-lab/genesis](https://github.com/applied-ai-lab/genesis#sketchy) / [Download Our Processed Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/genesis_sketchy_processed.tar.gz) 72 | - Dataset details are in the paper [Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200). 73 | 6. ClevrTex: Download by executing [download_clevrtex.sh](./data_scripts/download_clevrtex.sh). Our dataloader needs to index the entire dataset before training can begin. This can take around 2 hours. Thus, it is recommended to download our pre-made index from [this Hugging Face folder](https://huggingface.co/HHousen/object-discovery-pytorch/tree/main/Datasets/ClevrTex%20Cache) and put it in `./data/cache/`. 74 | 7. [Tetrominoes](https://github.com/deepmind/multi_object_datasets#tetrominoes): [Original TFRecords Download](https://console.cloud.google.com/storage/browser/multi-object-datasets/tetrominoes) / [Our HDF5 PyTorch Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/tetrominoes.h5). 75 | - There are 1,000,000 samples in the dataset. However, following [the Slot Attention paper](https://arxiv.org/abs/2006.15055), we only use the first 60K samples for training. 76 | - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the [data_scripts/preprocess_tetrominoes.py](./data_scripts/preprocess_tetrominoes.py) script, which takes approximately 2 hours to execute depending on your machine. 77 | 78 | ### Logging 79 | 80 | To log outputs to [wandb](https://wandb.ai/home), run `wandb login YOUR_API_KEY` and set `is_logging_enabled=True` in `SlotAttentionParams`. 81 | 82 | If you use a dataset with ground-truth segmentation masks, then the Adjusted Rand Index (ARI), a clustering similarity score, will be logged for each validation loop. We convert the implementation from [deepmind/multi_object_datasets](https://github.com/deepmind/multi_object_datasets) to PyTorch in [object_discovery/segmentation_metrics.py](object_discovery/segmentation_metrics.py). 83 | 84 | ## More Visualizations 85 | 86 | Slot Attention CLEVR10 | Slot Attention Sketchy 87 | :-----------------------:|:--------------------: 88 | ![](./media/sa_clevr10_with_masks.png) | ![](./media/sa_sketchy_with_masks.png) 89 | 90 | Visualizations (above) for a model trained on CLEVR6 predicting on CLEVR10 (with no increase in number of slots) and a model trained and predicting on Sketchy. The order from left to right of the images is original, reconstruction, raw predicted segmentation mask, processed segmentation mask, and then the slots. 91 | 92 | Slot Attention ClevrTex6 | GNM ClevrTex6 93 | :-----------------------:|:--------------------: 94 | ![](./media/sa_clevrtex6.png) | ![](./media/gnm_clevrtex6.png) 95 | 96 | The Slot Attention visualization image order is the same as in the above visualizations. For GNM, the order is original, reconstruction, ground truth segmentation mask, prediction segmentation mask (repeated 4 times). 97 | 98 | SLATE CLEVR6 | GNM CLEVR6 99 | :-----------------------:|:--------------------: 100 | ![](./media/slate_clevr6.png) | ![](./media/gnm_clevr6.png) 101 | 102 | For SLATE, the image order is original, dVAE reconstruction, autoregressive reconstruction, and then the pixels each slot pays attention to. 
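For intuition on the ARI metric mentioned in the Logging section above, here is a minimal illustrative sketch. It is not the repo's implementation (that is the batched PyTorch port in [object_discovery/segmentation_metrics.py](object_discovery/segmentation_metrics.py)); it only assumes scikit-learn is available and computes ARI on toy per-pixel object ids:

```python
# Illustration only: ARI compares two segmentations as clusterings of pixels.
import numpy as np
from sklearn.metrics import adjusted_rand_score

# Toy 4x4 "masks": each pixel holds the id of the object it belongs to.
ground_truth = np.array([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [2, 2, 3, 3],
    [2, 2, 3, 3],
])
# A prediction that recovers the same objects but with permuted ids.
prediction = np.array([
    [3, 3, 0, 0],
    [3, 3, 0, 0],
    [1, 1, 2, 2],
    [1, 1, 2, 2],
])

# ARI is invariant to label permutation: identical groupings score 1.0,
# while random groupings score around 0.0.
print(adjusted_rand_score(ground_truth.ravel(), prediction.ravel()))  # 1.0
```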
103 | 104 | ## References 105 | 106 | 1. [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention): An unofficial implementation of Slot Attention from which this repo was forked. 107 | 2. Slot Attention: [Official Code](https://github.com/google-research/google-research/tree/master/slot_attention) / ["Object-Centric Learning with Slot Attention"](https://arxiv.org/abs/2006.15055). 108 | 3. SLATE: [Official Code](https://github.com/singhgautam/slate) / ["Illiterate DALL-E Learns to Compose"](https://arxiv.org/abs/2110.11405). 109 | 4. IODINE: [Official Code](https://github.com/deepmind/deepmind-research/tree/master/iodine) / ["Multi-Object Representation Learning with Iterative Variational Inference"](https://arxiv.org/abs/1903.00450). In the Slot Attention paper, IODINE was frequently used for comparison. The IODINE code was helpful to create this repo. 110 | 5. Multi-Object Datasets: [deepmind/multi_object_datasets](https://github.com/deepmind/multi_object_datasets). This is the original source of the [CLEVR (with masks)](https://github.com/deepmind/multi_object_datasets#clevr-with-masks) dataset. 111 | 6. Implicit Slot Attention: ["Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation"](https://arxiv.org/abs/2207.00787). This paper explains a one-line change that improves the optimization of Slot Attention while simultaneously making backpropagation have constant space and time complexity. 112 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/__init__.py -------------------------------------------------------------------------------- /data_scripts/clevr_with_masks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """CLEVR (with masks) dataset reader.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP") 21 | IMAGE_SIZE = [240, 320] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 11 26 | BYTE_FEATURES = ["mask", "image", "color", "material", "shape", "size"] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 
30 | features = { 31 | "image": tf.FixedLenFeature(IMAGE_SIZE + [3], tf.string), 32 | "mask": tf.FixedLenFeature([MAX_NUM_ENTITIES] + IMAGE_SIZE + [1], tf.string), 33 | "x": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | "y": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | "z": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | "pixel_coords": tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | "rotation": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | "size": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 39 | "material": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 40 | "shape": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 41 | "color": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 42 | "visibility": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 43 | } 44 | 45 | 46 | def _decode(example_proto): 47 | # Parse the input `tf.Example` proto using the feature description dict above. 48 | single_example = tf.parse_single_example(example_proto, features) 49 | for k in BYTE_FEATURES: 50 | single_example[k] = tf.squeeze( 51 | tf.decode_raw(single_example[k], tf.uint8), axis=-1 52 | ) 53 | return single_example 54 | 55 | 56 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 57 | """Read, decompress, and parse the TFRecords file. 58 | 59 | Args: 60 | tfrecords_path: str. Path to the dataset file. 61 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 62 | for `tf.data.TFRecordDataset.__init__`. 63 | map_parallel_calls: int. Number of elements decoded asynchronously in 64 | parallel. See documentation for `tf.data.Dataset.map`. 65 | 66 | Returns: 67 | An unbatched `tf.data.TFRecordDataset`. 68 | """ 69 | raw_dataset = tf.data.TFRecordDataset( 70 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size 71 | ) 72 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 73 | -------------------------------------------------------------------------------- /data_scripts/download_clevr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | DATA_DIR=$1 5 | 6 | if [ ! -d $DATA_DIR ]; then 7 | mkdir $DATA_DIR 8 | fi 9 | 10 | cd $DATA_DIR 11 | 12 | if [ ! -f "$DATA_DIR/CLEVR_v1.0.zip.*" ]; then 13 | wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip 14 | echo "CLEVR_v1 downloaded to $DATA_DIR/CLEVR_v1.0.zip" 15 | else 16 | echo "$DATA_DIR/CLEVR_v1.0.zip already exists, skipping download" 17 | fi 18 | 19 | echo "unzipping CLEVR_v1 to $DATA_DIR/CLEVR_v1.0" 20 | rm -rf CLEVR_v1.0 21 | unzip -q CLEVR_v1.0.zip 22 | rm CLEVR_v1.0.zip 23 | -------------------------------------------------------------------------------- /data_scripts/download_clevrtex.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | DATA_DIR=$1 5 | 6 | if [ ! 
-d $DATA_DIR ]; then 7 | mkdir $DATA_DIR 8 | fi 9 | 10 | cd $DATA_DIR 11 | 12 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part1.tar.gz 13 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part2.tar.gz 14 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part3.tar.gz 15 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part4.tar.gz 16 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part5.tar.gz 17 | echo "ClevrTex downloaded to $DATA_DIR" 18 | 19 | 20 | echo "Unzipping ClevrTex to $DATA_DIR/ClevrTex" 21 | rm -rf ClevrTex 22 | for file in *.tar.gz; do tar -zxf "$file"; done 23 | for file in *.tar.gz; do rm "$file"; done 24 | -------------------------------------------------------------------------------- /data_scripts/preprocess_clevr_with_masks.py: -------------------------------------------------------------------------------- 1 | # Script to convert CLEVR (with masks) tfrecords file to h5py. 2 | # CLEVR (with masks) information: https://github.com/deepmind/multi_object_datasets#clevr-with-masks 3 | # Download: https://console.cloud.google.com/storage/browser/multi-object-datasets/clevr_with_masks 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | import clevr_with_masks 8 | import h5py 9 | 10 | 11 | dataset = clevr_with_masks.dataset( 12 | "clevr_with_masks_train.tfrecords" 13 | ).as_numpy_iterator() 14 | 15 | with h5py.File("clevr_with_masks.h5", "w") as f: 16 | for idx, entry in tqdm(enumerate(dataset), total=100_000): 17 | for key, value in entry.items(): 18 | value = value[np.newaxis, ...] 19 | if idx == 0: 20 | f.create_dataset( 21 | key, 22 | data=value, 23 | dtype=value.dtype, 24 | maxshape=(None, *value.shape[1:]), 25 | compression="gzip", 26 | chunks=True, 27 | ) 28 | else: 29 | f[key].resize((f[key].shape[0] + value.shape[0]), axis=0) 30 | f[key][-value.shape[0] :] = value 31 | -------------------------------------------------------------------------------- /data_scripts/preprocess_tetrominoes.py: -------------------------------------------------------------------------------- 1 | # Script to convert Tetrominoes tfrecords file to h5py. 2 | # Tetrominoes information: https://github.com/deepmind/multi_object_datasets#tetrominoes 3 | # Download: https://console.cloud.google.com/storage/browser/multi-object-datasets/tetrominoes 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | import tetrominoes 8 | import h5py 9 | 10 | 11 | dataset = tetrominoes.dataset("tetrominoes_train.tfrecords").as_numpy_iterator() 12 | 13 | with h5py.File("tetrominoes.h5", "w") as f: 14 | for idx, entry in tqdm(enumerate(dataset), total=1_000_000): 15 | for key, value in entry.items(): 16 | value = value[np.newaxis, ...] 17 | if idx == 0: 18 | f.create_dataset( 19 | key, 20 | data=value, 21 | dtype=value.dtype, 22 | maxshape=(None, *value.shape[1:]), 23 | compression="gzip", 24 | chunks=True, 25 | ) 26 | else: 27 | f[key].resize((f[key].shape[0] + value.shape[0]), axis=0) 28 | f[key][-value.shape[0] :] = value 29 | -------------------------------------------------------------------------------- /data_scripts/tetrominoes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tetrominoes dataset reader.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP") 21 | IMAGE_SIZE = [35, 35] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 4 26 | BYTE_FEATURES = ["mask", "image"] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 30 | features = { 31 | "image": tf.FixedLenFeature(IMAGE_SIZE + [3], tf.string), 32 | "mask": tf.FixedLenFeature([MAX_NUM_ENTITIES] + IMAGE_SIZE + [1], tf.string), 33 | "x": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | "y": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | "shape": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | "color": tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | "visibility": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | } 39 | 40 | 41 | def _decode(example_proto): 42 | # Parse the input `tf.Example` proto using the feature description dict above. 43 | single_example = tf.parse_single_example(example_proto, features) 44 | for k in BYTE_FEATURES: 45 | single_example[k] = tf.squeeze( 46 | tf.decode_raw(single_example[k], tf.uint8), axis=-1 47 | ) 48 | return single_example 49 | 50 | 51 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 52 | """Read, decompress, and parse the TFRecords file. 53 | 54 | Args: 55 | tfrecords_path: str. Path to the dataset file. 56 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 57 | for `tf.data.TFRecordDataset.__init__`. 58 | map_parallel_calls: int. Number of elements decoded asynchronously in 59 | parallel. See documentation for `tf.data.Dataset.map`. 60 | 61 | Returns: 62 | An unbatched `tf.data.TFRecordDataset`. 
63 | """ 64 | raw_dataset = tf.data.TFRecordDataset( 65 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size 66 | ) 67 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 68 | -------------------------------------------------------------------------------- /media/gnm_clevr6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/gnm_clevr6.png -------------------------------------------------------------------------------- /media/gnm_clevrtex6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/gnm_clevrtex6.png -------------------------------------------------------------------------------- /media/sa_clevr10_with_masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevr10_with_masks.png -------------------------------------------------------------------------------- /media/sa_clevr6_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevr6_example.png -------------------------------------------------------------------------------- /media/sa_clevrtex6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevrtex6.png -------------------------------------------------------------------------------- /media/sa_sketchy_with_masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_sketchy_with_masks.png -------------------------------------------------------------------------------- /media/slate_clevr6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/slate_clevr6.png -------------------------------------------------------------------------------- /object_discovery/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/object_discovery/__init__.py -------------------------------------------------------------------------------- /object_discovery/gnm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/object_discovery/gnm/__init__.py -------------------------------------------------------------------------------- /object_discovery/gnm/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import io 4 | from argparse import Namespace 5 | 6 | 7 | def dict_to_ns(d): 8 | return Namespace(**d) 9 | 10 | 11 | CONFIG_YAML = """exp_name: '' 12 | 
data_dir: '' 13 | summary_dir: '' 14 | model_dir: '' 15 | last_ckpt: '' 16 | data: 17 | img_w: 128 18 | img_h: 128 19 | inp_channel: 3 20 | blender_dir_list_train: [] 21 | blender_dir_list_test: [] 22 | dataset: 'mnist' 23 | z: 24 | z_global_dim: 32 25 | z_what_dim: 64 26 | z_where_scale_dim: 2 27 | z_where_shift_dim: 2 28 | z_where_dim: 4 29 | z_pres_dim: 1 30 | z_depth_dim: 1 31 | z_local_dim: 64 32 | z_bg_dim: 10 33 | arch: 34 | glimpse_size: 64 35 | num_cell: 4 36 | phase_overlap: True 37 | phase_background: True 38 | img_enc_dim: 128 39 | p_global_decoder_type: 'MLP' 40 | draw_step: 4 41 | phase_graph_net_on_global_decoder: False 42 | phase_graph_net_on_global_encoder: False 43 | conv: 44 | img_encoder_filters: [16, 16, 32, 32, 64, 64, 128, 128, 128] 45 | img_encoder_groups: [1, 1, 1, 1, 1, 1, 1, 1, 1] 46 | img_encoder_strides: [2, 1, 2, 1, 2, 1, 2, 1, 2] 47 | img_encoder_kernel_sizes: [4, 3, 4, 3, 4, 3, 4, 3, 4] 48 | p_what_decoder_filters: [128, 64, 32, 16, 8, 4] 49 | p_what_decoder_kernel_sizes: [3, 3, 3, 3, 3, 3] 50 | p_what_decoder_upscales: [2, 2, 2, 2, 2, 2] 51 | p_what_decoder_groups: [1, 1, 1, 1, 1, 1] 52 | p_bg_decoder_filters: [128, 64, 32, 16, 8, 3] 53 | p_bg_decoder_kernel_sizes: [1, 1, 1, 1, 1, 3] 54 | p_bg_decoder_upscales: [4, 2, 4, 2, 2, 1] 55 | p_bg_decoder_groups: [1, 1, 1, 1, 1, 1] 56 | deconv: 57 | p_global_decoder_filters: [128, 128, 128] 58 | p_global_decoder_kernel_sizes: [1, 1, 1] 59 | p_global_decoder_upscales: [2, 1, 2] 60 | p_global_decoder_groups: [1, 1, 1] 61 | mlp: 62 | p_global_decoder_filters: [512, 1024, 2048] 63 | q_global_encoder_filters: [512, 512, 64] 64 | p_global_encoder_filters: [512, 512, 64] 65 | p_bg_generator_filters: [128, 64, 20] 66 | q_bg_encoder_filters: [512, 256, 20] 67 | pwdw: 68 | pwdw_filters: [128, 128] 69 | pwdw_kernel_sizes: [1, 1] 70 | pwdw_strides: [1, 1] 71 | pwdw_groups: [1, 1] 72 | structdraw: 73 | kernel_size: 1 74 | rnn_decoder_hid_dim: 128 75 | rnn_encoder_hid_dim: 128 76 | hid_to_dec_filters: [128] 77 | hid_to_dec_kernel_sizes: [3] 78 | hid_to_dec_strides: [1] 79 | hid_to_dec_groups: [1] 80 | log: 81 | num_summary_img: 15 82 | num_img_per_row: 5 83 | save_epoch_freq: 10 84 | print_step_freq: 2000 85 | num_sample: 50 86 | compute_nll_freq: 20 87 | phase_nll: False 88 | nll_num_sample: 30 89 | phase_log: True 90 | const: 91 | pres_logit_scale: 8.8 92 | scale_mean: -1.5 93 | scale_std: 0.1 94 | ratio_mean: 0 95 | ratio_std: 0.3 96 | shift_std: 1 97 | eps: 0.000000000000001 98 | likelihood_sigma: 0.2 99 | bg_likelihood_sigma: 0.3 100 | train: 101 | start_epoch: 0 102 | epoch: 600 103 | batch_size: 32 104 | lr: 0.0001 105 | cp: 1.0 106 | beta_global_anneal_start_step: 0 107 | beta_global_anneal_end_step: 100000 108 | beta_global_anneal_start_value: 0. 109 | beta_global_anneal_end_value: 1. 110 | beta_pres_anneal_start_step: 0 111 | beta_pres_anneal_end_step: 0 112 | beta_pres_anneal_start_value: 1. 113 | beta_pres_anneal_end_value: 0. 114 | beta_where_anneal_start_step: 0 115 | beta_where_anneal_end_step: 0 116 | beta_where_anneal_start_value: 1. 117 | beta_where_anneal_end_value: 0. 118 | beta_what_anneal_start_step: 0 119 | beta_what_anneal_end_step: 0 120 | beta_what_anneal_start_value: 1. 121 | beta_what_anneal_end_value: 0. 122 | beta_depth_anneal_start_step: 0 123 | beta_depth_anneal_end_step: 0 124 | beta_depth_anneal_start_value: 1. 125 | beta_depth_anneal_end_value: 0. 126 | beta_bg_anneal_start_step: 1000 127 | beta_bg_anneal_end_step: 0 128 | beta_bg_anneal_start_value: 1. 129 | beta_bg_anneal_end_value: 0. 
130 | beta_aux_pres_anneal_start_step: 1000 131 | beta_aux_pres_anneal_end_step: 0 132 | beta_aux_pres_anneal_start_value: 1. 133 | beta_aux_pres_anneal_end_value: 0. 134 | beta_aux_where_anneal_start_step: 0 135 | beta_aux_where_anneal_end_step: 500 136 | beta_aux_where_anneal_start_value: 10. 137 | beta_aux_where_anneal_end_value: 1. 138 | beta_aux_what_anneal_start_step: 1000 139 | beta_aux_what_anneal_end_step: 0 140 | beta_aux_what_anneal_start_value: 1. 141 | beta_aux_what_anneal_end_value: 0. 142 | beta_aux_depth_anneal_start_step: 1000 143 | beta_aux_depth_anneal_end_step: 0 144 | beta_aux_depth_anneal_start_value: 1. 145 | beta_aux_depth_anneal_end_value: 0. 146 | beta_aux_global_anneal_start_step: 0 147 | beta_aux_global_anneal_end_step: 100000 148 | beta_aux_global_anneal_start_value: 0. 149 | beta_aux_global_anneal_end_value: 1. 150 | beta_aux_bg_anneal_start_step: 0 151 | beta_aux_bg_anneal_end_step: 50000 152 | beta_aux_bg_anneal_start_value: 50. 153 | beta_aux_bg_anneal_end_value: 1. 154 | tau_pres_anneal_start_step: 1000 155 | tau_pres_anneal_end_step: 20000 156 | tau_pres_anneal_start_value: 1. 157 | tau_pres_anneal_end_value: 0.5 158 | tau_pres: 1. 159 | p_pres_anneal_start_step: 0 160 | p_pres_anneal_end_step: 4000 161 | p_pres_anneal_start_value: 0.1 162 | p_pres_anneal_end_value: 0.001 163 | aux_p_scale_anneal_start_step: 0 164 | aux_p_scale_anneal_end_step: 0 165 | aux_p_scale_anneal_start_value: -1.5 166 | aux_p_scale_anneal_end_value: -1.5 167 | phase_bg_alpha_curriculum: True 168 | bg_alpha_curriculum_period: [0, 500] 169 | bg_alpha_curriculum_value: 0.9 170 | seed: 666 171 | """ 172 | 173 | 174 | def get_arrow_args(): 175 | arrow_args = json.loads( 176 | json.dumps(yaml.safe_load(io.StringIO(CONFIG_YAML))), object_hook=dict_to_ns 177 | ) 178 | return arrow_args 179 | -------------------------------------------------------------------------------- /object_discovery/gnm/gnm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .module import ( 4 | LocalLatentDecoder, 5 | LocalSampler, 6 | StructDRAW, 7 | BgDecoder, 8 | BgGenerator, 9 | BgEncoder, 10 | ) 11 | from .submodule import StackConvNorm 12 | import torch.distributions as dist 13 | from .utils import ( 14 | linear_schedule_tensor, 15 | spatial_transform, 16 | kl_divergence_bern_bern, 17 | linear_schedule, 18 | visualize, 19 | ) 20 | from typing import List, Tuple 21 | 22 | 23 | class GNM(nn.Module): 24 | shortname = "gnm" 25 | 26 | def __init__(self, args): 27 | super(GNM, self).__init__() 28 | self.args = args 29 | 30 | self.img_encoder = StackConvNorm( 31 | self.args.data.inp_channel, 32 | self.args.arch.conv.img_encoder_filters, 33 | self.args.arch.conv.img_encoder_kernel_sizes, 34 | self.args.arch.conv.img_encoder_strides, 35 | self.args.arch.conv.img_encoder_groups, 36 | norm_act_final=True, 37 | ) 38 | self.global_struct_draw = StructDRAW(self.args) 39 | self.p_z_given_x_or_g_net = LocalLatentDecoder(self.args) 40 | # Share latent decoder for p and q 41 | self.local_latent_sampler = LocalSampler(self.args) 42 | 43 | if self.args.arch.phase_background: 44 | self.p_bg_decoder = BgDecoder(self.args) 45 | self.p_bg_given_g_net = BgGenerator(self.args) 46 | self.q_bg_given_x_net = BgEncoder(self.args) 47 | 48 | self.register_buffer("aux_p_what_mean", torch.zeros(1)) 49 | self.register_buffer("aux_p_what_std", torch.ones(1)) 50 | self.register_buffer("aux_p_bg_mean", torch.zeros(1)) 51 | 
self.register_buffer("aux_p_bg_std", torch.ones(1)) 52 | self.register_buffer("aux_p_depth_mean", torch.zeros(1)) 53 | self.register_buffer("aux_p_depth_std", torch.ones(1)) 54 | self.register_buffer( 55 | "aux_p_where_mean", 56 | torch.tensor( 57 | [self.args.const.scale_mean, self.args.const.ratio_mean, 0, 0] 58 | )[None, :], 59 | ) 60 | # self.register_buffer('auxiliary_where_std', torch.ones(1)) 61 | self.register_buffer( 62 | "aux_p_where_std", 63 | torch.tensor( 64 | [ 65 | self.args.const.scale_std, 66 | self.args.const.ratio_std, 67 | self.args.const.shift_std, 68 | self.args.const.shift_std, 69 | ] 70 | )[None, :], 71 | ) 72 | self.register_buffer( 73 | "aux_p_pres_probs", torch.tensor(self.args.train.p_pres_anneal_start_value) 74 | ) 75 | self.register_buffer( 76 | "background", 77 | torch.zeros( 78 | 1, 79 | self.args.data.inp_channel, 80 | self.args.data.img_h, 81 | self.args.data.img_w, 82 | ), 83 | ) 84 | 85 | @property 86 | def aux_p_what(self): 87 | return dist.Normal(self.aux_p_what_mean, self.aux_p_what_std) 88 | 89 | @property 90 | def aux_p_bg(self): 91 | return dist.Normal(self.aux_p_bg_mean, self.aux_p_bg_std) 92 | 93 | @property 94 | def aux_p_depth(self): 95 | return dist.Normal(self.aux_p_depth_mean, self.aux_p_depth_std) 96 | 97 | @property 98 | def aux_p_where(self): 99 | return dist.Normal(self.aux_p_where_mean, self.aux_p_where_std) 100 | 101 | def forward(self, x: torch.Tensor, global_step, generate_bbox=False): 102 | self.args = hyperparam_anneal(self.args, global_step) 103 | bs = x.size(0) 104 | 105 | img_enc = self.img_encoder(x) 106 | if self.args.arch.phase_background: 107 | lv_q_bg, ss_q_bg = self.q_bg_given_x_net(img_enc) 108 | q_bg_mean, q_bg_std = ss_q_bg 109 | else: 110 | lv_q_bg = [self.background.new_zeros(1, 1)] 111 | q_bg_mean = self.background.new_zeros(1, 1) 112 | q_bg_std = self.background.new_ones(1, 1) 113 | ss_q_bg = [q_bg_mean, q_bg_std] 114 | 115 | q_bg = dist.Normal(q_bg_mean, q_bg_std) 116 | 117 | pa_g, lv_g, ss_g = self.global_struct_draw(img_enc) 118 | 119 | global_dec = pa_g[0] 120 | 121 | p_global_mean_all, p_global_std_all, q_global_mean_all, q_global_std_all = ss_g 122 | 123 | p_global_all = dist.Normal(p_global_mean_all, p_global_std_all) 124 | 125 | q_global_all = dist.Normal(q_global_mean_all, q_global_std_all) 126 | 127 | ss_p_z = self.p_z_given_x_or_g_net(global_dec) 128 | 129 | # (bs, dim, num_cell, num_cell) 130 | ( 131 | p_pres_logits, 132 | p_where_mean, 133 | p_where_std, 134 | p_depth_mean, 135 | p_depth_std, 136 | p_what_mean, 137 | p_what_std, 138 | ) = ss_p_z 139 | 140 | p_pres_given_g_probs_reshaped = torch.sigmoid( 141 | p_pres_logits.permute(0, 2, 3, 1).reshape( 142 | bs * self.args.arch.num_cell ** 2, -1 143 | ) 144 | ) 145 | 146 | p_where_given_g = dist.Normal( 147 | p_where_mean.permute(0, 2, 3, 1).reshape( 148 | bs * self.args.arch.num_cell ** 2, -1 149 | ), 150 | p_where_std.permute(0, 2, 3, 1).reshape( 151 | bs * self.args.arch.num_cell ** 2, -1 152 | ), 153 | ) 154 | p_depth_given_g = dist.Normal( 155 | p_depth_mean.permute(0, 2, 3, 1).reshape( 156 | bs * self.args.arch.num_cell ** 2, -1 157 | ), 158 | p_depth_std.permute(0, 2, 3, 1).reshape( 159 | bs * self.args.arch.num_cell ** 2, -1 160 | ), 161 | ) 162 | p_what_given_g = dist.Normal( 163 | p_what_mean.permute(0, 2, 3, 1).reshape( 164 | bs * self.args.arch.num_cell ** 2, -1 165 | ), 166 | p_what_std.permute(0, 2, 3, 1).reshape( 167 | bs * self.args.arch.num_cell ** 2, -1 168 | ), 169 | ) 170 | 171 | ss_q_z = self.p_z_given_x_or_g_net(img_enc, 
ss_p_z=ss_p_z) 172 | 173 | # (bs, dim, num_cell, num_cell) 174 | ( 175 | q_pres_logits, 176 | q_where_mean, 177 | q_where_std, 178 | q_depth_mean, 179 | q_depth_std, 180 | q_what_mean, 181 | q_what_std, 182 | ) = ss_q_z 183 | 184 | q_pres_given_x_and_g_probs_reshaped = torch.sigmoid( 185 | q_pres_logits.permute(0, 2, 3, 1).reshape( 186 | bs * self.args.arch.num_cell ** 2, -1 187 | ) 188 | ) 189 | 190 | q_where_given_x_and_g = dist.Normal( 191 | q_where_mean.permute(0, 2, 3, 1).reshape( 192 | bs * self.args.arch.num_cell ** 2, -1 193 | ), 194 | q_where_std.permute(0, 2, 3, 1).reshape( 195 | bs * self.args.arch.num_cell ** 2, -1 196 | ), 197 | ) 198 | q_depth_given_x_and_g = dist.Normal( 199 | q_depth_mean.permute(0, 2, 3, 1).reshape( 200 | bs * self.args.arch.num_cell ** 2, -1 201 | ), 202 | q_depth_std.permute(0, 2, 3, 1).reshape( 203 | bs * self.args.arch.num_cell ** 2, -1 204 | ), 205 | ) 206 | q_what_given_x_and_g = dist.Normal( 207 | q_what_mean.permute(0, 2, 3, 1).reshape( 208 | bs * self.args.arch.num_cell ** 2, -1 209 | ), 210 | q_what_std.permute(0, 2, 3, 1).reshape( 211 | bs * self.args.arch.num_cell ** 2, -1 212 | ), 213 | ) 214 | 215 | if self.args.arch.phase_background: 216 | # lv_p_bg, ss_p_bg = self.ss_p_bg_given_g(lv_g) 217 | lv_p_bg, ss_p_bg = self.p_bg_given_g_net(lv_g[0], phase_use_mode=False) 218 | p_bg_mean, p_bg_std = ss_p_bg 219 | else: 220 | lv_p_bg = [self.background.new_zeros(1, 1)] 221 | p_bg_mean = self.background.new_zeros(1, 1) 222 | p_bg_std = self.background.new_ones(1, 1) 223 | 224 | p_bg = dist.Normal(p_bg_mean, p_bg_std) 225 | 226 | pa_recon, lv_z = self.lv_p_x_given_z_and_bg(ss_q_z, lv_q_bg, global_step) 227 | *pa_recon, patches, masks = pa_recon 228 | canvas = pa_recon[0] 229 | background = pa_recon[-1] 230 | 231 | z_pres, z_where, z_depth, z_what, z_where_origin = lv_z 232 | 233 | p_dists = [ 234 | p_global_all, 235 | p_pres_given_g_probs_reshaped, 236 | p_where_given_g, 237 | p_depth_given_g, 238 | p_what_given_g, 239 | p_bg, 240 | ] 241 | 242 | q_dists = [ 243 | q_global_all, 244 | q_pres_given_x_and_g_probs_reshaped, 245 | q_where_given_x_and_g, 246 | q_depth_given_x_and_g, 247 | q_what_given_x_and_g, 248 | q_bg, 249 | ] 250 | 251 | log_like, kl, log_imp = self.elbo( 252 | x, p_dists, q_dists, lv_z, lv_g, lv_q_bg, pa_recon, global_step 253 | ) 254 | 255 | self.log = {} 256 | 257 | if self.args.log.phase_log: 258 | pa_recon_from_q_g, _ = self.get_recon_from_q_g( 259 | global_step, global_dec=global_dec, lv_g=lv_g 260 | ) 261 | 262 | z_pres_permute = z_pres.permute(0, 2, 3, 1) 263 | self.log = { 264 | "z_what": z_what.permute(0, 2, 3, 1).reshape( 265 | -1, self.args.z.z_what_dim 266 | ), 267 | "z_where_scale": z_where.permute(0, 2, 3, 1).reshape( 268 | -1, self.args.z.z_where_dim 269 | )[:, : self.args.z.z_where_scale_dim], 270 | "z_where_shift": z_where.permute(0, 2, 3, 1).reshape( 271 | -1, self.args.z.z_where_dim 272 | )[:, self.args.z.z_where_scale_dim :], 273 | "z_where_origin": z_where_origin.permute(0, 2, 3, 1).reshape( 274 | -1, self.args.z.z_where_dim 275 | ), 276 | "z_pres": z_pres_permute, 277 | "p_pres_probs": p_pres_given_g_probs_reshaped, 278 | "p_pres_logits": p_pres_logits, 279 | "p_what_std": p_what_std.permute(0, 2, 3, 1).reshape( 280 | -1, self.args.z.z_what_dim 281 | )[z_pres_permute.view(-1) > 0.05], 282 | "p_what_mean": p_what_mean.permute(0, 2, 3, 1).reshape( 283 | -1, self.args.z.z_what_dim 284 | )[z_pres_permute.view(-1) > 0.05], 285 | "p_where_scale_std": p_where_std.permute(0, 2, 3, 1).reshape( 286 | -1, 
self.args.z.z_where_dim 287 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim], 288 | "p_where_scale_mean": p_where_mean.permute(0, 2, 3, 1).reshape( 289 | -1, self.args.z.z_where_dim 290 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim], 291 | "p_where_shift_std": p_where_std.permute(0, 2, 3, 1).reshape( 292 | -1, self.args.z.z_where_dim 293 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :], 294 | "p_where_shift_mean": p_where_mean.permute(0, 2, 3, 1).reshape( 295 | -1, self.args.z.z_where_dim 296 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :], 297 | "q_pres_probs": q_pres_given_x_and_g_probs_reshaped, 298 | "q_pres_logits": q_pres_logits, 299 | "q_what_std": q_what_std.permute(0, 2, 3, 1).reshape( 300 | -1, self.args.z.z_what_dim 301 | )[z_pres_permute.view(-1) > 0.05], 302 | "q_what_mean": q_what_mean.permute(0, 2, 3, 1).reshape( 303 | -1, self.args.z.z_what_dim 304 | )[z_pres_permute.view(-1) > 0.05], 305 | "q_where_scale_std": q_where_std.permute(0, 2, 3, 1).reshape( 306 | -1, self.args.z.z_where_dim 307 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim], 308 | "q_where_scale_mean": q_where_mean.permute(0, 2, 3, 1).reshape( 309 | -1, self.args.z.z_where_dim 310 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim], 311 | "q_where_shift_std": q_where_std.permute(0, 2, 3, 1).reshape( 312 | -1, self.args.z.z_where_dim 313 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :], 314 | "q_where_shift_mean": q_where_mean.permute(0, 2, 3, 1).reshape( 315 | -1, self.args.z.z_where_dim 316 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :], 317 | "z_depth": z_depth.permute(0, 2, 3, 1).reshape( 318 | -1, self.args.z.z_depth_dim 319 | ), 320 | "p_depth_std": p_depth_std.permute(0, 2, 3, 1).reshape( 321 | -1, self.args.z.z_depth_dim 322 | )[z_pres_permute.view(-1) > 0.05], 323 | "p_depth_mean": p_depth_mean.permute(0, 2, 3, 1).reshape( 324 | -1, self.args.z.z_depth_dim 325 | )[z_pres_permute.view(-1) > 0.05], 326 | "q_depth_std": q_depth_std.permute(0, 2, 3, 1).reshape( 327 | -1, self.args.z.z_depth_dim 328 | )[z_pres_permute.view(-1) > 0.05], 329 | "q_depth_mean": q_depth_mean.permute(0, 2, 3, 1).reshape( 330 | -1, self.args.z.z_depth_dim 331 | )[z_pres_permute.view(-1) > 0.05], 332 | "recon": pa_recon[0], 333 | "recon_from_q_g": pa_recon_from_q_g[0], 334 | "log_prob_x_given_g": dist.Normal( 335 | pa_recon_from_q_g[0], self.args.const.likelihood_sigma 336 | ) 337 | .log_prob(x) 338 | .flatten(start_dim=1) 339 | .sum(1), 340 | "global_dec": global_dec, 341 | } 342 | z_global_all = lv_g[0] 343 | for i in range(self.args.arch.draw_step): 344 | self.log[f"z_global_step_{i}"] = z_global_all[:, i] 345 | self.log[f"q_global_mean_step_{i}"] = q_global_mean_all[:, i] 346 | self.log[f"q_global_std_step_{i}"] = q_global_std_all[:, i] 347 | self.log[f"p_global_mean_step_{i}"] = p_global_mean_all[:, i] 348 | self.log[f"p_global_std_step_{i}"] = p_global_std_all[:, i] 349 | if self.args.arch.phase_background: 350 | self.log["z_bg"] = lv_q_bg[0] 351 | self.log["p_bg_mean"] = p_bg_mean 352 | self.log["p_bg_std"] = p_bg_std 353 | self.log["q_bg_mean"] = q_bg_mean 354 | self.log["q_bg_std"] = q_bg_std 355 | self.log["recon_from_q_g_bg"] = pa_recon_from_q_g[-1] 356 | self.log["recon_from_q_g_fg"] = pa_recon_from_q_g[1] 357 | self.log["recon_from_q_g_alpha"] = pa_recon_from_q_g[2] 358 | self.log["recon_bg"] = pa_recon[-1] 359 | self.log["recon_fg"] = 
pa_recon[1] 360 | self.log["recon_alpha"] = pa_recon[2] 361 | 362 | ss = [ss_q_z, ss_q_bg, ss_g[2:]] 363 | ( 364 | aux_kl_pres, 365 | aux_kl_where, 366 | aux_kl_depth, 367 | aux_kl_what, 368 | aux_kl_bg, 369 | kl_pres, 370 | kl_where, 371 | kl_depth, 372 | kl_what, 373 | kl_global_all, 374 | kl_bg, 375 | ) = kl 376 | 377 | aux_kl_pres_raw = aux_kl_pres.mean(dim=0) 378 | aux_kl_where_raw = aux_kl_where.mean(dim=0) 379 | aux_kl_depth_raw = aux_kl_depth.mean(dim=0) 380 | aux_kl_what_raw = aux_kl_what.mean(dim=0) 381 | aux_kl_bg_raw = aux_kl_bg.mean(dim=0) 382 | kl_pres_raw = kl_pres.mean(dim=0) 383 | kl_where_raw = kl_where.mean(dim=0) 384 | kl_depth_raw = kl_depth.mean(dim=0) 385 | kl_what_raw = kl_what.mean(dim=0) 386 | kl_bg_raw = kl_bg.mean(dim=0) 387 | 388 | log_like = log_like.mean(dim=0) 389 | 390 | aux_kl_pres = aux_kl_pres_raw * self.args.train.beta_aux_pres 391 | aux_kl_where = aux_kl_where_raw * self.args.train.beta_aux_where 392 | aux_kl_depth = aux_kl_depth_raw * self.args.train.beta_aux_depth 393 | aux_kl_what = aux_kl_what_raw * self.args.train.beta_aux_what 394 | aux_kl_bg = aux_kl_bg_raw * self.args.train.beta_aux_bg 395 | kl_pres = kl_pres_raw * self.args.train.beta_pres 396 | kl_where = kl_where_raw * self.args.train.beta_where 397 | kl_depth = kl_depth_raw * self.args.train.beta_depth 398 | kl_what = kl_what_raw * self.args.train.beta_what 399 | kl_bg = kl_bg_raw * self.args.train.beta_bg 400 | 401 | kl_global_raw = kl_global_all.sum(dim=-1).mean(dim=0) 402 | kl_global = kl_global_raw * self.args.train.beta_global 403 | 404 | recon_loss = log_like 405 | kl = ( 406 | kl_pres 407 | + kl_where 408 | + kl_depth 409 | + kl_what 410 | + kl_bg 411 | + kl_global 412 | + aux_kl_pres 413 | + aux_kl_where 414 | + aux_kl_depth 415 | + aux_kl_what 416 | + aux_kl_bg 417 | ) 418 | elbo = recon_loss - kl 419 | loss = -elbo 420 | 421 | bbox = None 422 | if (not self.training) and generate_bbox: 423 | bbox = visualize( 424 | x, 425 | self.log["z_pres"].view(bs, self.args.arch.num_cell ** 2, -1), 426 | self.log["z_where_scale"].view(bs, self.args.arch.num_cell ** 2, -1), 427 | self.log["z_where_shift"].view(bs, self.args.arch.num_cell ** 2, -1), 428 | ) 429 | 430 | bbox = ( 431 | bbox.view(x.shape[0], -1, 3, self.args.data.img_h, self.args.data.img_w) 432 | .sum(1) 433 | .clamp(0.0, 1.0) 434 | ) 435 | # bbox_img = x.cpu().expand(-1, 3, -1, -1).contiguous() 436 | # bbox_img[bbox.sum(dim=1, keepdim=True).expand(-1, 3, -1, -1) > 0.5] = \ 437 | # bbox[bbox.sum(dim=1, keepdim=True).expand(-1, 3, -1, -1) > 0.5] 438 | ret = { 439 | "canvas": canvas, 440 | "canvas_with_bbox": bbox, 441 | "background": background, 442 | "steps": { 443 | "patch": patches, 444 | "mask": masks, 445 | "z_pres": z_pres.view(bs, self.args.arch.num_cell ** 2, -1), 446 | }, 447 | "counts": torch.round(z_pres).flatten(1).sum(-1), 448 | "loss": loss, 449 | "elbo": elbo, 450 | "kl": kl, 451 | "rec_loss": recon_loss, 452 | "kl_pres": kl_pres, 453 | "kl_aux_pres": aux_kl_pres, 454 | "kl_where": kl_where, 455 | "kl_aux_where": aux_kl_where, 456 | "kl_what": kl_what, 457 | "kl_aux_what": aux_kl_what, 458 | "kl_depth": kl_depth, 459 | "kl_aux_depth": aux_kl_depth, 460 | "kl_bg": kl_bg, 461 | "kl_aux_bg": aux_kl_bg, 462 | "kl_global": kl_global, 463 | } 464 | 465 | # return pa_recon, log_like, kl, log_imp, lv_z + lv_g + lv_q_bg, ss, self.log 466 | return ret 467 | 468 | def get_recon_from_q_g( 469 | self, 470 | global_step, 471 | img: torch.Tensor = None, 472 | global_dec: torch.Tensor = None, 473 | lv_g: List = None, 474 | 
phase_use_mode: bool = False, 475 | ) -> Tuple: 476 | 477 | assert img is not None or ( 478 | global_dec is not None and lv_g is not None 479 | ), "Provide either image or p_l_given_g" 480 | if img is not None: 481 | img_enc = self.img_encoder(img) 482 | pa_g, lv_g, ss_g = self.global_struct_draw(img_enc) 483 | 484 | global_dec = pa_g[0] 485 | 486 | if self.args.arch.phase_background: 487 | lv_p_bg, _ = self.p_bg_given_g_net(lv_g[0], phase_use_mode=phase_use_mode) 488 | else: 489 | lv_p_bg = [self.background.new_zeros(1, 1)] 490 | 491 | ss_z = self.p_z_given_x_or_g_net(global_dec) 492 | 493 | pa, lv = self.lv_p_x_given_z_and_bg( 494 | ss_z, lv_p_bg, global_step, phase_use_mode=phase_use_mode 495 | ) 496 | 497 | lv = lv + lv_p_bg 498 | 499 | return pa, lv 500 | 501 | def elbo( 502 | self, 503 | x: torch.Tensor, 504 | p_dists: List, 505 | q_dists: List, 506 | lv_z: List, 507 | lv_g: List, 508 | lv_bg: List, 509 | pa_recon: List, 510 | global_step, 511 | ) -> Tuple: 512 | 513 | bs = x.size(0) 514 | 515 | ( 516 | p_global_all, 517 | p_pres_given_g_probs_reshaped, 518 | p_where_given_g, 519 | p_depth_given_g, 520 | p_what_given_g, 521 | p_bg, 522 | ) = p_dists 523 | 524 | ( 525 | q_global_all, 526 | q_pres_given_x_and_g_probs_reshaped, 527 | q_where_given_x_and_g, 528 | q_depth_given_x_and_g, 529 | q_what_given_x_and_g, 530 | q_bg, 531 | ) = q_dists 532 | 533 | y, y_nobg, alpha_map, bg = pa_recon 534 | 535 | if self.args.log.phase_nll: 536 | # (bs, dim, num_cell, num_cell) 537 | z_pres, _, z_depth, z_what, z_where_origin = lv_z 538 | # (bs * num_cell * num_cell, dim) 539 | z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape( 540 | -1, self.args.z.z_pres_dim 541 | ) 542 | z_depth_reshape = z_depth.permute(0, 2, 3, 1).reshape( 543 | -1, self.args.z.z_depth_dim 544 | ) 545 | z_what_reshape = z_what.permute(0, 2, 3, 1).reshape( 546 | -1, self.args.z.z_what_dim 547 | ) 548 | z_where_origin_reshape = z_where_origin.permute(0, 2, 3, 1).reshape( 549 | -1, self.args.z.z_where_dim 550 | ) 551 | # (bs, dim, 1, 1) 552 | z_bg = lv_bg[0] 553 | # (bs, step, dim, 1, 1) 554 | z_g = lv_g[0] 555 | else: 556 | z_pres, _, _, _, z_where_origin = lv_z 557 | 558 | z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape( 559 | -1, self.args.z.z_pres_dim 560 | ) 561 | 562 | if self.args.train.p_pres_anneal_end_step != 0: 563 | self.aux_p_pres_probs = linear_schedule_tensor( 564 | global_step, 565 | self.args.train.p_pres_anneal_start_step, 566 | self.args.train.p_pres_anneal_end_step, 567 | self.args.train.p_pres_anneal_start_value, 568 | self.args.train.p_pres_anneal_end_value, 569 | self.aux_p_pres_probs.device, 570 | ) 571 | 572 | if self.args.train.aux_p_scale_anneal_end_step != 0: 573 | aux_p_scale_mean = linear_schedule_tensor( 574 | global_step, 575 | self.args.train.aux_p_scale_anneal_start_step, 576 | self.args.train.aux_p_scale_anneal_end_step, 577 | self.args.train.aux_p_scale_anneal_start_value, 578 | self.args.train.aux_p_scale_anneal_end_value, 579 | self.aux_p_where_mean.device, 580 | ) 581 | self.aux_p_where_mean[:, 0] = aux_p_scale_mean 582 | 583 | auxiliary_prior_z_pres_probs = self.aux_p_pres_probs[None][None, :].expand( 584 | bs * self.args.arch.num_cell ** 2, -1 585 | ) 586 | 587 | aux_kl_pres = kl_divergence_bern_bern( 588 | q_pres_given_x_and_g_probs_reshaped, auxiliary_prior_z_pres_probs 589 | ) 590 | aux_kl_where = dist.kl_divergence( 591 | q_where_given_x_and_g, self.aux_p_where 592 | ) * z_pres_reshape.clamp(min=1e-5) 593 | aux_kl_depth = dist.kl_divergence( 594 | q_depth_given_x_and_g, 
self.aux_p_depth 595 | ) * z_pres_reshape.clamp(min=1e-5) 596 | aux_kl_what = dist.kl_divergence( 597 | q_what_given_x_and_g, self.aux_p_what 598 | ) * z_pres_reshape.clamp(min=1e-5) 599 | 600 | kl_pres = kl_divergence_bern_bern( 601 | q_pres_given_x_and_g_probs_reshaped, p_pres_given_g_probs_reshaped 602 | ) 603 | 604 | kl_where = dist.kl_divergence(q_where_given_x_and_g, p_where_given_g) 605 | kl_depth = dist.kl_divergence(q_depth_given_x_and_g, p_depth_given_g) 606 | kl_what = dist.kl_divergence(q_what_given_x_and_g, p_what_given_g) 607 | 608 | kl_global_all = dist.kl_divergence(q_global_all, p_global_all) 609 | 610 | if self.args.arch.phase_background: 611 | kl_bg = dist.kl_divergence(q_bg, p_bg) 612 | aux_kl_bg = dist.kl_divergence(q_bg, self.aux_p_bg) 613 | else: 614 | kl_bg = self.background.new_zeros(bs, 1) 615 | aux_kl_bg = self.background.new_zeros(bs, 1) 616 | 617 | log_like = dist.Normal(y, self.args.const.likelihood_sigma).log_prob(x) 618 | 619 | log_imp_list = [] 620 | if self.args.log.phase_nll: 621 | log_pres_prior = z_pres_reshape * torch.log( 622 | p_pres_given_g_probs_reshaped + self.args.const.eps 623 | ) + (1 - z_pres_reshape) * torch.log( 624 | 1 - p_pres_given_g_probs_reshaped + self.args.const.eps 625 | ) 626 | log_pres_pos = z_pres_reshape * torch.log( 627 | q_pres_given_x_and_g_probs_reshaped + self.args.const.eps 628 | ) + (1 - z_pres_reshape) * torch.log( 629 | 1 - q_pres_given_x_and_g_probs_reshaped + self.args.const.eps 630 | ) 631 | 632 | log_imp_pres = log_pres_prior - log_pres_pos 633 | 634 | log_imp_depth = p_depth_given_g.log_prob( 635 | z_depth_reshape 636 | ) - q_depth_given_x_and_g.log_prob(z_depth_reshape) 637 | 638 | log_imp_what = p_what_given_g.log_prob( 639 | z_what_reshape 640 | ) - q_what_given_x_and_g.log_prob(z_what_reshape) 641 | 642 | log_imp_where = p_where_given_g.log_prob( 643 | z_where_origin_reshape 644 | ) - q_where_given_x_and_g.log_prob(z_where_origin_reshape) 645 | 646 | if self.args.arch.phase_background: 647 | log_imp_bg = p_bg.log_prob(z_bg) - q_bg.log_prob(z_bg) 648 | else: 649 | log_imp_bg = x.new_zeros(bs, 1) 650 | 651 | log_imp_g = p_global_all.log_prob(z_g) - q_global_all.log_prob(z_g) 652 | 653 | log_imp_list = [ 654 | log_imp_pres.view( 655 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 656 | ) 657 | .flatten(start_dim=1) 658 | .sum(1), 659 | log_imp_depth.view( 660 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 661 | ) 662 | .flatten(start_dim=1) 663 | .sum(1), 664 | log_imp_what.view( 665 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 666 | ) 667 | .flatten(start_dim=1) 668 | .sum(1), 669 | log_imp_where.view( 670 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 671 | ) 672 | .flatten(start_dim=1) 673 | .sum(1), 674 | log_imp_bg.flatten(start_dim=1).sum(1), 675 | log_imp_g.flatten(start_dim=1).sum(1), 676 | ] 677 | 678 | return ( 679 | log_like.flatten(start_dim=1).sum(1), 680 | [ 681 | aux_kl_pres.view( 682 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 683 | ) 684 | .flatten(start_dim=1) 685 | .sum(-1), 686 | aux_kl_where.view( 687 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 688 | ) 689 | .flatten(start_dim=1) 690 | .sum(-1), 691 | aux_kl_depth.view( 692 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 693 | ) 694 | .flatten(start_dim=1) 695 | .sum(-1), 696 | aux_kl_what.view( 697 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1 698 | ) 699 | .flatten(start_dim=1) 700 | .sum(-1), 701 | aux_kl_bg.flatten(start_dim=1).sum(-1), 
702 | kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1) 703 | .flatten(start_dim=1) 704 | .sum(-1), 705 | kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1) 706 | .flatten(start_dim=1) 707 | .sum(-1), 708 | kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1) 709 | .flatten(start_dim=1) 710 | .sum(-1), 711 | kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1) 712 | .flatten(start_dim=1) 713 | .sum(-1), 714 | kl_global_all.flatten(start_dim=2).sum(-1), 715 | kl_bg.flatten(start_dim=1).sum(-1), 716 | ], 717 | log_imp_list, 718 | ) 719 | 720 | # def get_img_enc(self, x: torch.Tensor) -> torch.Tensor: 721 | # """ 722 | # :param x: (bs, inp_channel, img_h, img_w) 723 | # :return: img_enc: (bs, dim, num_cell, num_cell) 724 | # """ 725 | # 726 | # img_enc = self.img_encoder(x) 727 | # 728 | # return img_enc 729 | 730 | # def ss_p_z_given_g(self, global_dec: torch.Tensor) -> List: 731 | # """ 732 | # :param x: sample of z_global variable (bs, dim, 1, 1) 733 | # :return: 734 | # """ 735 | # ss_z = self.p_z_given_g_net(global_dec) 736 | # 737 | # return ss_z 738 | 739 | # def ss_q_z_given_x(self, img_enc: torch.Tensor, global_dec: torch.Tensor, ss_p_z: List) -> List: 740 | # """ 741 | # :param x: sample of z_global variable (bs, dim, 1, 1) 742 | # :return: 743 | # """ 744 | # ss_z = self.p_z_given_x_or_g_net(img_enc, ss_p_z=ss_p_z) 745 | # 746 | # return ss_z 747 | 748 | # def ss_q_bg_given_x(self, x: torch.Tensor) -> Tuple: 749 | # """ 750 | # :param x: (bs, dim, img_h, img_w) 751 | # :return: 752 | # """ 753 | # lv_q_bg, ss_q_bg = self.q_bg_given_x_net(x) 754 | # 755 | # return lv_q_bg, ss_q_bg 756 | 757 | # def ss_p_bg_given_g(self, lv_g: List, phase_use_mode: bool = False) -> Tuple: 758 | # """ 759 | # :param x: (bs, dim, img_h, img_w) 760 | # :return: 761 | # """ 762 | # z_global_all = lv_g[0] 763 | # lv_p_bg, ss_p_bg = self.p_bg_given_g_net(z_global_all, phase_use_mode=phase_use_mode) 764 | # 765 | # return lv_p_bg, ss_p_bg 766 | 767 | def lv_p_x_given_z_and_bg( 768 | self, ss: List, lv_bg: List, global_step, phase_use_mode: bool = False 769 | ) -> Tuple: 770 | """ 771 | :param z: (bs, z_what_dim) 772 | :return: 773 | """ 774 | # x: (bs, inp_channel, img_h, img_w) 775 | pa, lv_z = self.local_latent_sampler(ss, phase_use_mode=phase_use_mode) 776 | 777 | o_att, a_att, *_ = pa 778 | z_pres, z_where, z_depth, *_ = lv_z 779 | 780 | if self.args.arch.phase_background: 781 | z_bg = lv_bg[0] 782 | pa_bg = self.p_bg_decoder(z_bg) 783 | y_bg = pa_bg[0] 784 | else: 785 | # pa_bg = [self.background.expand(lv_z[0].size(0), -1, -1, -1)] 786 | y_bg = self.background.expand(lv_z[0].size(0), -1, -1, -1) 787 | 788 | # pa = pa + pa_bg 789 | 790 | y, y_fg, alpha_map, patches, masks = self.render( 791 | o_att, a_att, y_bg, z_pres, z_where, z_depth, global_step 792 | ) 793 | 794 | return [y, y_fg, alpha_map, y_bg, patches, masks], lv_z 795 | 796 | # def pa_bg_given_z_bg(self, lv_bg: List) -> List: 797 | # """ 798 | # :param lv_bg[0]: (bs, z_bg_dim, 1, 1) 799 | # :return: 800 | # """ 801 | # z_bg = lv_bg[0] 802 | # pa = self.p_bg_decoder(z_bg) 803 | # 804 | # return pa 805 | 806 | def render(self, o_att, a_att, bg, z_pres, z_where, z_depth, global_step) -> List: 807 | """ 808 | :param pa: variables with size (bs, dim, num_cell, num_cell) 809 | :param lv_z: o and a with size (bs * num_cell * num_cell, dim) 810 | :return: 811 | """ 812 | 813 | bs = z_pres.size(0) 814 | 815 | z_pres = z_pres.permute(0, 2, 3, 1).reshape( 816 | bs * 
self.args.arch.num_cell ** 2, -1 817 | ) 818 | z_where = z_where.permute(0, 2, 3, 1).reshape( 819 | bs * self.args.arch.num_cell ** 2, -1 820 | ) 821 | z_depth = z_depth.permute(0, 2, 3, 1).reshape( 822 | bs * self.args.arch.num_cell ** 2, -1 823 | ) 824 | 825 | if self.args.arch.phase_overlap == True: 826 | if ( 827 | self.args.train.phase_bg_alpha_curriculum 828 | and self.args.train.bg_alpha_curriculum_period[0] 829 | < global_step 830 | < self.args.train.bg_alpha_curriculum_period[1] 831 | ): 832 | z_pres = z_pres.clamp(max=0.99) 833 | a_att_hat = a_att * z_pres.view(-1, 1, 1, 1) 834 | y_att = a_att_hat * o_att 835 | 836 | # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 3, img_h, img_w) 837 | y_att_full_res = spatial_transform( 838 | y_att, 839 | z_where, 840 | ( 841 | bs * self.args.arch.num_cell ** 2, 842 | self.args.data.inp_channel, 843 | self.args.data.img_h, 844 | self.args.data.img_w, 845 | ), 846 | inverse=True, 847 | ).view( 848 | -1, 849 | self.args.arch.num_cell * self.args.arch.num_cell, 850 | self.args.data.inp_channel, 851 | self.args.data.img_h, 852 | self.args.data.img_w, 853 | ) 854 | o_att_full_res = spatial_transform( 855 | o_att, 856 | z_where, 857 | ( 858 | bs * self.args.arch.num_cell ** 2, 859 | self.args.data.inp_channel, 860 | self.args.data.img_h, 861 | self.args.data.img_w, 862 | ), 863 | inverse=True, 864 | ).view( 865 | -1, 866 | self.args.arch.num_cell * self.args.arch.num_cell, 867 | self.args.data.inp_channel, 868 | self.args.data.img_h, 869 | self.args.data.img_w, 870 | ) 871 | 872 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, glimpse_size, glimpse_size) 873 | importance_map = a_att_hat * torch.sigmoid(-z_depth).view(-1, 1, 1, 1) 874 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, img_h, img_w) 875 | importance_map_full_res = spatial_transform( 876 | importance_map, 877 | z_where, 878 | ( 879 | self.args.arch.num_cell * self.args.arch.num_cell * bs, 880 | 1, 881 | self.args.data.img_h, 882 | self.args.data.img_w, 883 | ), 884 | inverse=True, 885 | ) 886 | # # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w) 887 | importance_map_full_res = importance_map_full_res.view( 888 | -1, 889 | self.args.arch.num_cell * self.args.arch.num_cell, 890 | 1, 891 | self.args.data.img_h, 892 | self.args.data.img_w, 893 | ) 894 | importance_map_full_res_norm = importance_map_full_res / ( 895 | importance_map_full_res.sum(dim=1, keepdim=True) + self.args.const.eps 896 | ) 897 | 898 | # (bs, 3, img_h, img_w) 899 | y_nobg = (y_att_full_res * importance_map_full_res_norm).sum(dim=1) 900 | 901 | # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w) 902 | a_att_hat_full_res = spatial_transform( 903 | a_att_hat, 904 | z_where, 905 | ( 906 | self.args.arch.num_cell * self.args.arch.num_cell * bs, 907 | 1, 908 | self.args.data.img_h, 909 | self.args.data.img_w, 910 | ), 911 | inverse=True, 912 | ).view( 913 | -1, 914 | self.args.arch.num_cell * self.args.arch.num_cell, 915 | 1, 916 | self.args.data.img_h, 917 | self.args.data.img_w, 918 | ) 919 | alpha_map = a_att_hat_full_res.sum(dim=1) 920 | # (bs, 1, img_h, img_w) 921 | alpha_map = ( 922 | alpha_map 923 | + ( 924 | alpha_map.clamp(self.args.const.eps, 1 - self.args.const.eps) 925 | - alpha_map 926 | ).detach() 927 | ) 928 | 929 | if self.args.train.phase_bg_alpha_curriculum: 930 | if ( 931 | self.args.train.bg_alpha_curriculum_period[0] 932 | < global_step 933 | < self.args.train.bg_alpha_curriculum_period[1] 934 | ): 935 | alpha_map = ( 
936 | alpha_map.new_ones(alpha_map.size()) 937 | * self.args.train.bg_alpha_curriculum_value 938 | ) 939 | # y_nobg = alpha_map * y_nobg 940 | y = y_nobg + (1.0 - alpha_map) * bg 941 | else: 942 | y_att = a_att * o_att 943 | 944 | o_att_full_res = spatial_transform( 945 | o_att, 946 | z_where, 947 | ( 948 | bs * self.args.arch.num_cell ** 2, 949 | self.args.data.inp_channel, 950 | self.args.data.img_h, 951 | self.args.data.img_w, 952 | ), 953 | inverse=True, 954 | ).view( 955 | -1, 956 | self.args.arch.num_cell * self.args.arch.num_cell, 957 | self.args.data.inp_channel, 958 | self.args.data.img_h, 959 | self.args.data.img_w, 960 | ) 961 | a_att_hat_full_res = spatial_transform( 962 | a_att * z_pres.view(bs * self.args.arch.num_cell ** 2, 1, 1, 1), 963 | z_where, 964 | ( 965 | self.args.arch.num_cell * self.args.arch.num_cell * bs, 966 | 1, 967 | self.args.data.img_h, 968 | self.args.data.img_w, 969 | ), 970 | inverse=True, 971 | ).view( 972 | -1, 973 | self.args.arch.num_cell * self.args.arch.num_cell, 974 | 1, 975 | self.args.data.img_h, 976 | self.args.data.img_w, 977 | ) 978 | 979 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 3, img_h, img_w) 980 | y_att_full_res = spatial_transform( 981 | y_att, 982 | z_where, 983 | ( 984 | bs * self.args.arch.num_cell ** 2, 985 | self.args.data.inp_channel, 986 | self.args.data.img_h, 987 | self.args.data.img_w, 988 | ), 989 | inverse=True, 990 | ) 991 | y = ( 992 | ( 993 | y_att_full_res 994 | * z_pres.view(bs * self.args.arch.num_cell ** 2, 1, 1, 1) 995 | ) 996 | .view( 997 | bs, 998 | -1, 999 | self.args.data.inp_channel, 1000 | self.args.data.img_h, 1001 | self.args.data.img_w, 1002 | ) 1003 | .sum(dim=1) 1004 | ) 1005 | y_nobg = y 1006 | alpha_map = y.new_ones(y.size(0), 1, y.size(2), y.size(3)) 1007 | 1008 | return y, y_nobg, alpha_map, o_att_full_res, a_att_hat_full_res 1009 | 1010 | def loss_function(self, x, global_step, generate_bbox=False): 1011 | return self.forward(x, global_step, generate_bbox) 1012 | 1013 | 1014 | def hyperparam_anneal(args, global_step): 1015 | if args.train.beta_aux_pres_anneal_end_step == 0: 1016 | args.train.beta_aux_pres = args.train.beta_aux_pres_anneal_start_value 1017 | else: 1018 | args.train.beta_aux_pres = linear_schedule( 1019 | global_step, 1020 | args.train.beta_aux_pres_anneal_start_step, 1021 | args.train.beta_aux_pres_anneal_end_step, 1022 | args.train.beta_aux_pres_anneal_start_value, 1023 | args.train.beta_aux_pres_anneal_end_value, 1024 | ) 1025 | 1026 | if args.train.beta_aux_where_anneal_end_step == 0: 1027 | args.train.beta_aux_where = args.train.beta_aux_where_anneal_start_value 1028 | else: 1029 | args.train.beta_aux_where = linear_schedule( 1030 | global_step, 1031 | args.train.beta_aux_where_anneal_start_step, 1032 | args.train.beta_aux_where_anneal_end_step, 1033 | args.train.beta_aux_where_anneal_start_value, 1034 | args.train.beta_aux_where_anneal_end_value, 1035 | ) 1036 | 1037 | if args.train.beta_aux_what_anneal_end_step == 0: 1038 | args.train.beta_aux_what = args.train.beta_aux_what_anneal_start_value 1039 | else: 1040 | args.train.beta_aux_what = linear_schedule( 1041 | global_step, 1042 | args.train.beta_aux_what_anneal_start_step, 1043 | args.train.beta_aux_what_anneal_end_step, 1044 | args.train.beta_aux_what_anneal_start_value, 1045 | args.train.beta_aux_what_anneal_end_value, 1046 | ) 1047 | 1048 | if args.train.beta_aux_depth_anneal_end_step == 0: 1049 | args.train.beta_aux_depth = args.train.beta_aux_depth_anneal_start_value 1050 | else: 1051 | 
args.train.beta_aux_depth = linear_schedule( 1052 | global_step, 1053 | args.train.beta_aux_depth_anneal_start_step, 1054 | args.train.beta_aux_depth_anneal_end_step, 1055 | args.train.beta_aux_depth_anneal_start_value, 1056 | args.train.beta_aux_depth_anneal_end_value, 1057 | ) 1058 | 1059 | if args.train.beta_aux_global_anneal_end_step == 0: 1060 | args.train.beta_aux_global = args.train.beta_aux_global_anneal_start_value 1061 | else: 1062 | args.train.beta_aux_global = linear_schedule( 1063 | global_step, 1064 | args.train.beta_aux_global_anneal_start_step, 1065 | args.train.beta_aux_global_anneal_end_step, 1066 | args.train.beta_aux_global_anneal_start_value, 1067 | args.train.beta_aux_global_anneal_end_value, 1068 | ) 1069 | 1070 | if args.train.beta_aux_bg_anneal_end_step == 0: 1071 | args.train.beta_aux_bg = args.train.beta_aux_bg_anneal_start_value 1072 | else: 1073 | args.train.beta_aux_bg = linear_schedule( 1074 | global_step, 1075 | args.train.beta_aux_bg_anneal_start_step, 1076 | args.train.beta_aux_bg_anneal_end_step, 1077 | args.train.beta_aux_bg_anneal_start_value, 1078 | args.train.beta_aux_bg_anneal_end_value, 1079 | ) 1080 | 1081 | ########################### split here ########################### 1082 | if args.train.beta_pres_anneal_end_step == 0: 1083 | args.train.beta_pres = args.train.beta_pres_anneal_start_value 1084 | else: 1085 | args.train.beta_pres = linear_schedule( 1086 | global_step, 1087 | args.train.beta_pres_anneal_start_step, 1088 | args.train.beta_pres_anneal_end_step, 1089 | args.train.beta_pres_anneal_start_value, 1090 | args.train.beta_pres_anneal_end_value, 1091 | ) 1092 | 1093 | if args.train.beta_where_anneal_end_step == 0: 1094 | args.train.beta_where = args.train.beta_where_anneal_start_value 1095 | else: 1096 | args.train.beta_where = linear_schedule( 1097 | global_step, 1098 | args.train.beta_where_anneal_start_step, 1099 | args.train.beta_where_anneal_end_step, 1100 | args.train.beta_where_anneal_start_value, 1101 | args.train.beta_where_anneal_end_value, 1102 | ) 1103 | 1104 | if args.train.beta_what_anneal_end_step == 0: 1105 | args.train.beta_what = args.train.beta_what_anneal_start_value 1106 | else: 1107 | args.train.beta_what = linear_schedule( 1108 | global_step, 1109 | args.train.beta_what_anneal_start_step, 1110 | args.train.beta_what_anneal_end_step, 1111 | args.train.beta_what_anneal_start_value, 1112 | args.train.beta_what_anneal_end_value, 1113 | ) 1114 | 1115 | if args.train.beta_depth_anneal_end_step == 0: 1116 | args.train.beta_depth = args.train.beta_depth_anneal_start_value 1117 | else: 1118 | args.train.beta_depth = linear_schedule( 1119 | global_step, 1120 | args.train.beta_depth_anneal_start_step, 1121 | args.train.beta_depth_anneal_end_step, 1122 | args.train.beta_depth_anneal_start_value, 1123 | args.train.beta_depth_anneal_end_value, 1124 | ) 1125 | 1126 | if args.train.beta_global_anneal_end_step == 0: 1127 | args.train.beta_global = args.train.beta_global_anneal_start_value 1128 | else: 1129 | args.train.beta_global = linear_schedule( 1130 | global_step, 1131 | args.train.beta_global_anneal_start_step, 1132 | args.train.beta_global_anneal_end_step, 1133 | args.train.beta_global_anneal_start_value, 1134 | args.train.beta_global_anneal_end_value, 1135 | ) 1136 | 1137 | if args.train.tau_pres_anneal_end_step == 0: 1138 | args.train.tau_pres = args.train.tau_pres_anneal_start_value 1139 | else: 1140 | args.train.tau_pres = linear_schedule( 1141 | global_step, 1142 | args.train.tau_pres_anneal_start_step, 1143 | 
args.train.tau_pres_anneal_end_step, 1144 | args.train.tau_pres_anneal_start_value, 1145 | args.train.tau_pres_anneal_end_value, 1146 | ) 1147 | 1148 | if args.train.beta_bg_anneal_end_step == 0: 1149 | args.train.beta_bg = args.train.beta_bg_anneal_start_value 1150 | else: 1151 | args.train.beta_bg = linear_schedule( 1152 | global_step, 1153 | args.train.beta_bg_anneal_start_step, 1154 | args.train.beta_bg_anneal_end_step, 1155 | args.train.beta_bg_anneal_start_value, 1156 | args.train.beta_bg_anneal_end_value, 1157 | ) 1158 | 1159 | return args 1160 | -------------------------------------------------------------------------------- /object_discovery/gnm/logging.py: -------------------------------------------------------------------------------- 1 | # Partially based on https://github.com/karazijal/clevrtex/blob/fe982ab224689526f5e2f83a2f542ba958d88abd/experiments/framework/vis_mixin.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from .metrics import align_masks_iou, dices 6 | import itertools 7 | 8 | import torch 9 | import torchvision as tv 10 | from ..segmentation_metrics import adjusted_rand_index 11 | from ..utils import cmap_tensor 12 | 13 | @torch.no_grad() 14 | def _to_img(img, lim=None, dim=-3): 15 | if lim: 16 | img = img[:lim] 17 | img = (img.clamp(0, 1) * 255).to(torch.uint8) 18 | if img.shape[dim] < 3: 19 | epd_dims = [-1 for _ in range(len(img.shape))] 20 | epd_dims[dim] = 3 21 | img = img.expand(*epd_dims) 22 | return img 23 | 24 | 25 | @torch.no_grad() 26 | def log_recons(input, output, patches, masks, background=None, pres=None): 27 | vis_imgs = [] 28 | img = _to_img(input) 29 | vis_imgs.extend(img) 30 | omg = _to_img(output, lim=len(img)) 31 | vis_imgs.extend(omg) 32 | 33 | if background is not None and not torch.all(background == 0.0): 34 | bg = _to_img(background, lim=len(img)) 35 | vis_imgs.extend(bg) 36 | 37 | masks = masks[: len(img)] 38 | patches = patches[: len(img)] 39 | ms = masks * patches 40 | for sid in range(patches.size(1)): 41 | # p = (patches[:len(img), sid].clamp(0., 1.) 
* 255.).to(torch.uint8).detach().cpu() 42 | # if p.shape[1] == 3: 43 | # vis_imgs.extend(p) 44 | m = _to_img(ms[:, sid]) 45 | m_hat = [] 46 | if pres is not None: 47 | for i in range(0, len(img)): 48 | if pres[i, sid][0] == 1: 49 | m[i, 0, :2, :2] = 0 50 | m[i, 1, :2, :2] = 255 51 | m[i, 2, :2, :2] = 0 52 | m_hat.append(m[i]) 53 | else: 54 | m_hat.append(m[i]) 55 | else: 56 | m_hat.extend(m) 57 | vis_imgs.extend(m_hat) 58 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=len(img), padding=1) 59 | return grid.permute([1, 2, 0]).cpu().numpy() 60 | 61 | 62 | @torch.no_grad() 63 | def log_images(input, output): 64 | vis_imgs = [] 65 | img = _to_img(input) 66 | omg = _to_img(output, lim=len(img)) 67 | 68 | for i, (i_img, o_img) in enumerate(zip(img, omg)): 69 | vis_imgs.append(i_img) 70 | vis_imgs.append(o_img) 71 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=16) 72 | return grid.permute([1, 2, 0]).cpu().numpy() 73 | 74 | 75 | @torch.no_grad() 76 | def log_semantic_images(input, output, true_masks, pred_masks): 77 | assert len(true_masks.shape) == 5 and len(pred_masks.shape) == 5 78 | img = _to_img(input) 79 | omg = _to_img(output, lim=len(img)) 80 | true_masks = true_masks[: len(img)].to(torch.float).argmax(1).squeeze(1) 81 | pred_masks = pred_masks[: len(img)].to(torch.float).argmax(1).squeeze(1) 82 | tms = (cmap_tensor(true_masks) * 255.0).to(torch.uint8) 83 | pms = (cmap_tensor(pred_masks) * 255.0).to(torch.uint8) 84 | vis_imgs = list(itertools.chain.from_iterable(zip(img, omg, tms, pms))) 85 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=16) 86 | return grid.permute([1, 2, 0]).cpu().numpy() 87 | 88 | 89 | def gnm_log_validation_outputs(batch, batch_idx, output, is_global_zero): 90 | logs = {} 91 | img, masks, vis = batch 92 | masks = masks.transpose(-1, -2).transpose(-2, -3).to(torch.float) 93 | 94 | mse = F.mse_loss(output["canvas"], img, reduction="none").sum((1, 2, 3)) 95 | logs["mse"] = mse 96 | 97 | ali_pmasks = None 98 | ali_tmasks = None 99 | 100 | # Transforms might have changed this. 101 | # cnts = torch.sum(vis, dim=-1) - 1 # -1 for discounting the background from visibility 102 | # estimate from masks 103 | cnts = ( 104 | torch.round(masks.to(torch.float)).flatten(2).any(-1).to(torch.float).sum(-1) 105 | - 1 106 | ) 107 | 108 | if "steps" in output: 109 | pred_masks = output["steps"]["mask"] 110 | pred_vis = output["steps"]["z_pres"].squeeze(-1) 111 | 112 | # `align_masks_iou` adds the background pixels to `ali_pmasks` (aligned 113 | # predicted masks). 
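# Shapes at this point: `pred_masks` holds the per-step soft masks with shape
# (B, S, 1, H, W) and `masks` has been transposed above to (B, K, 1, H, W).
# With `has_bg=False`, `align_masks_iou` prepends a background channel
# (1 - sum of slots) to the predictions, one-hot binarises both tensors to a
# common class count, and Hungarian-matches slots by IoU, returning the masks,
# per-slot IoUs, and visibility flags in the matched order.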
114 | ali_pmasks, ali_tmasks, ious, ali_pvis, ali_tvis = align_masks_iou( 115 | pred_masks, masks, pred_mask_vis=pred_vis, true_mask_vis=vis, has_bg=False 116 | ) 117 | 118 | ali_cvis = ali_pvis | ali_tvis 119 | num_paired_slots = ali_cvis.sum(-1) - 1 120 | mses = F.mse_loss( 121 | output["canvas"][:, None] * ali_pmasks, 122 | img[:, None] * ali_tmasks, 123 | reduction="none", 124 | ).sum((-1, -2, -3)) 125 | 126 | bg_mse = mses[:, 0] 127 | logs["bg_mse"] = bg_mse 128 | 129 | slot_mse = mses[:, 1:].sum(-1) / num_paired_slots 130 | logs["slot_mse"] = slot_mse 131 | 132 | mious = ious[:, 1:].sum(-1) / num_paired_slots 133 | # mious = torch.where(zero_mask, 0., mious) 134 | logs["miou"] = mious 135 | 136 | dice = dices(ali_pmasks, ali_tmasks) 137 | mdice = dice[:, 1:].sum(-1) / num_paired_slots 138 | logs["dice"] = mdice 139 | 140 | # aris = ari(ali_pmasks, ali_tmasks) 141 | batch_size, num_entries, channels, height, width = ali_tmasks.shape 142 | ali_tmasks_reshaped = ( 143 | torch.reshape( 144 | ali_tmasks.squeeze(), [batch_size, num_entries, height * width] 145 | ) 146 | .permute([0, 2, 1]) 147 | .to(torch.float) 148 | ) 149 | batch_size, num_entries, channels, height, width = ali_pmasks.shape 150 | ali_pmasks_reshaped = ( 151 | torch.reshape( 152 | ali_pmasks.squeeze(), [batch_size, num_entries, height * width] 153 | ) 154 | .permute([0, 2, 1]) 155 | .to(torch.float) 156 | ) 157 | logs["ari_with_background"] = adjusted_rand_index( 158 | ali_tmasks_reshaped, ali_pmasks_reshaped 159 | ) 160 | 161 | # aris_fg = ari(ali_pmasks, ali_tmasks, True).mean().detach() 162 | # `[..., 1:]` removes the background pixels group from the true mask. 163 | logs["ari"] = adjusted_rand_index( 164 | ali_tmasks_reshaped[..., 1:], ali_pmasks_reshaped 165 | ) 166 | 167 | # Can also calculate ari (same as above line) without using the aligned 168 | # masks by directly adding the background to the predicted masks. The 169 | # background is everything that is not predicted as an object. This is 170 | # done automatically when aligning the masks in `align_masks_iou`. 
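# Convention reminder: `adjusted_rand_index` is fed two
# (batch, n_points, n_groups) tensors that are one-hot over the last axis
# (n_points = H * W flattened pixels); slicing `[..., 1:]` off the true masks
# drops the background column so only foreground pixels are scored.
# The commented-out block below sketches the unaligned variant of the same
# computation.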
171 | # pred_masks_with_background = torch.cat([1 - pred_masks.sum(1, keepdim=True), pred_masks], 1) 172 | # batch_size, num_entries, channels, height, width = masks.shape 173 | # masks_reshaped = torch.reshape(masks, [batch_size, num_entries, height * width]).permute([0, 2, 1]).to(torch.float) 174 | # batch_size, num_entries, channels, height, width = pred_masks_with_background.shape 175 | # pred_masks_with_background_reshaped = torch.reshape(pred_masks_with_background, [batch_size, num_entries, height * width]).permute([0, 2, 1]).to(torch.float) 176 | # logs["ari4"] = adjusted_rand_index(masks_reshaped[..., 1:], pred_masks_with_background_reshaped) 177 | 178 | batch_size, num_entries, channels, height, width = masks.shape 179 | masks_reshaped = ( 180 | torch.reshape(masks, [batch_size, num_entries, height * width]) 181 | .permute([0, 2, 1]) 182 | .to(torch.float) 183 | ) 184 | batch_size, num_entries, channels, height, width = pred_masks.shape 185 | pred_masks_reshaped = ( 186 | torch.reshape(pred_masks, [batch_size, num_entries, height * width]) 187 | .permute([0, 2, 1]) 188 | .to(torch.float) 189 | ) 190 | logs["ari_no_background"] = adjusted_rand_index( 191 | masks_reshaped[..., 1:], pred_masks_reshaped 192 | ) 193 | 194 | pred_counts = output["counts"].detach().to(torch.int) 195 | logs["acc"] = (pred_counts == cnts).to(float) 196 | logs["cnt"] = pred_counts.to(float) 197 | 198 | images = {} 199 | if batch_idx == 0 and is_global_zero: 200 | images["output"] = log_images(img, output["canvas_with_bbox"]) 201 | if "steps" in output: 202 | images["recon"] = log_recons( 203 | img[:32], 204 | output["canvas"], 205 | output["steps"]["patch"], 206 | output["steps"]["mask"], 207 | output.get("background", None), 208 | pres=output["steps"]["z_pres"], 209 | ) 210 | 211 | # If masks have been aligned; log semantic map 212 | if ali_pmasks is not None and ali_tmasks is not None: 213 | images["segmentation"] = log_semantic_images( 214 | img[:32], output["canvas"], ali_tmasks, ali_pmasks 215 | ) 216 | 217 | return logs, images 218 | -------------------------------------------------------------------------------- /object_discovery/gnm/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import one_hot 3 | 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def binarise_masks(pred_masks, num_classes=None): 8 | assert len(pred_masks.shape) == 5 9 | assert pred_masks.shape[2] == 1 10 | num_classes = num_classes or pred_masks.shape[1] 11 | return ( 12 | one_hot(pred_masks.argmax(axis=1), num_classes=num_classes) 13 | .transpose(-1, -2) 14 | .transpose(-2, -3) 15 | .transpose(-3, -4) 16 | .to(bool) 17 | ) 18 | 19 | 20 | def convert_masks(pred_masks, true_masks, correct_bg=True): 21 | if correct_bg: 22 | pred_masks = torch.cat( 23 | [1.0 - pred_masks.sum(axis=1, keepdims=True), pred_masks], dim=1 24 | ) 25 | num_classes = max(pred_masks.shape[1], true_masks.shape[1]) 26 | pred_masks = binarise_masks(pred_masks, num_classes=num_classes) 27 | true_masks = binarise_masks(true_masks, num_classes=num_classes) 28 | # if torch.any(torch.isnan(pred_masks)) or torch.any(torch.isinf(pred_masks)): import ipdb; ipdb.set_trace() 29 | # if torch.any(torch.isnan(true_masks)) or torch.any(torch.isinf(true_masks)): import ipdb; ipdb.set_trace() 30 | return pred_masks, true_masks 31 | 32 | 33 | def iou_matching(pred_masks, true_masks, threshold=1e-2): 34 | """The order of true_masks is preserved up to potentially missing 
background dim (in shic """ 35 | assert pred_masks.shape[0] == true_masks.shape[0], "Batchsize mismatch" 36 | # true_masks = true_masks.to(torch.bool) 37 | assert ( 38 | pred_masks.dtype == true_masks.dtype 39 | ), f"Dtype mismatch ({pred_masks.dtype}!={true_masks.dtype})" 40 | 41 | if pred_masks.dtype != torch.bool: 42 | pred_masks, true_masks = convert_masks(pred_masks, true_masks, correct_bg=False) 43 | assert pred_masks.shape[-3:] == true_masks.shape[-3:], "Mask shape mismatch" 44 | 45 | pred_masks = pred_masks.to(float) 46 | true_masks = true_masks.to(float) 47 | 48 | tspec = dict(device=pred_masks.device) 49 | iou_matrix = torch.zeros( 50 | pred_masks.shape[0], pred_masks.shape[1], true_masks.shape[1], **tspec 51 | ) 52 | true_masks_sums = true_masks.sum((-1, -2, -3)) 53 | pred_masks_sums = pred_masks.sum((-1, -2, -3)) 54 | 55 | for pi in range(pred_masks.shape[1]): 56 | pandt = (pred_masks[:, pi : pi + 1] * true_masks).sum((-1, -2, -3)) 57 | port = pred_masks_sums[:, pi : pi + 1] + true_masks_sums 58 | iou_matrix[:, pi] = (pandt + 1e-2) / (port + 1e-2) 59 | iou_matrix[pred_masks_sums[:, pi] == 0.0, pi] = 0.0 60 | 61 | for ti in range(true_masks.shape[1]): 62 | iou_matrix[true_masks_sums[:, ti] == 0.0, :, ti] = 0.0 63 | 64 | cost_matrix = iou_matrix.cpu().detach().numpy() 65 | inds = torch.zeros( 66 | pred_masks.shape[0], 67 | 2, 68 | min(pred_masks.shape[1], true_masks.shape[1]), 69 | dtype=torch.int64, 70 | **tspec, 71 | ) 72 | ious = torch.zeros( 73 | pred_masks.shape[0], 74 | min(pred_masks.shape[1], true_masks.shape[1]), 75 | dtype=float, 76 | **tspec, 77 | ) 78 | for bi in range(cost_matrix.shape[0]): 79 | col_ind, row_ind = linear_sum_assignment(cost_matrix[bi].T, maximize=True) 80 | inds[bi, 0] = torch.tensor(row_ind, **tspec) 81 | inds[bi, 1] = torch.tensor(col_ind, **tspec) 82 | ious[bi] = torch.tensor(cost_matrix[bi, row_ind, col_ind], **tspec) 83 | # if torch.any(torch.isnan(inds)) or torch.any(torch.isinf(inds)): import ipdb; ipdb.set_trace() 84 | # if torch.any(torch.isnan(ious)) or torch.any(torch.isinf(ious)): import ipdb; ipdb.set_trace() 85 | return inds, ious 86 | 87 | 88 | def align_masks_iou( 89 | pred_mask, true_mask, pred_mask_vis=None, true_mask_vis=None, has_bg=False 90 | ): 91 | pred_mask, true_mask = convert_masks(pred_mask, true_mask, correct_bg=not has_bg) 92 | inds, ious = iou_matching(pred_mask, true_mask) 93 | 94 | # Reindex the masks into alighned order 95 | B, S, *s = pred_mask.shape 96 | bias = S * torch.arange(B, device=pred_mask.device)[:, None] 97 | pred_mask = pred_mask.reshape(B * S, *s)[(bias + inds[:, 0]).flatten()].view( 98 | B, S, *s 99 | ) 100 | true_mask = true_mask.reshape(B * S, *s)[(bias + inds[:, 1]).flatten()].view( 101 | B, S, *s 102 | ) 103 | 104 | ret = pred_mask, true_mask, ious 105 | 106 | if pred_mask_vis is not None: 107 | if pred_mask_vis.dtype != torch.bool: 108 | pred_mask_vis = pred_mask_vis > 0.5 109 | 110 | if has_bg: 111 | pred_mask_vis = torch.cat( 112 | [ 113 | pred_mask_vis, 114 | torch.zeros( 115 | B, 116 | S - pred_mask_vis.shape[1], 117 | dtype=pred_mask_vis.dtype, 118 | device=pred_mask_vis.device, 119 | ), 120 | ], 121 | axis=1, 122 | ) 123 | else: 124 | pred_mask_vis = torch.cat( 125 | [ 126 | torch.ones( 127 | B, 1, dtype=pred_mask_vis.dtype, device=pred_mask_vis.device 128 | ), 129 | pred_mask_vis, 130 | torch.zeros( 131 | B, 132 | S - pred_mask_vis.shape[1] - 1, 133 | dtype=pred_mask_vis.dtype, 134 | device=pred_mask_vis.device, 135 | ), 136 | ], 137 | axis=1, 138 | ) 139 | pred_mask_vis = 
pred_mask_vis.reshape(B * S)[ 140 | (bias + inds[:, 0]).flatten() 141 | ].view(B, S) 142 | ret += (pred_mask_vis,) 143 | 144 | if true_mask_vis is not None: 145 | if true_mask_vis.dtype != torch.bool: 146 | true_mask_vis = true_mask_vis > 0.5 147 | 148 | true_mask_vis = torch.cat( 149 | [ 150 | true_mask_vis, 151 | torch.zeros( 152 | B, 153 | S - true_mask_vis.shape[1], 154 | dtype=true_mask_vis.dtype, 155 | device=true_mask_vis.device, 156 | ), 157 | ], 158 | axis=1, 159 | ) 160 | true_mask_vis = true_mask_vis.reshape(B * S)[ 161 | (bias + inds[:, 1]).flatten() 162 | ].view(B, S) 163 | ret += (true_mask_vis,) 164 | # for i,r in enumerate(ret): 165 | # if torch.any(torch.isnan(r)) or torch.any(torch.isinf(r)): 166 | # print('/tStopcon:', i) 167 | # import ipdb; ipdb.set_trace() 168 | return ret 169 | 170 | 171 | def dices(pred_mask, true_mask): 172 | dice = ( 173 | 2 174 | * (pred_mask * true_mask).sum((-3, -2, -1)) 175 | / (pred_mask.sum((-3, -2, -1)) + true_mask.sum((-3, -2, -1))) 176 | ) 177 | dice = torch.where(torch.isnan(dice) | torch.isinf(dice), 0.0, dice.to(float)) 178 | # if torch.any(torch.isnan(dice)) or torch.any(torch.isinf(dice)): import ipdb; ipdb.set_trace() 179 | return dice 180 | 181 | 182 | # def ari(pred_mask, true_mask, skip_0=False): 183 | # B = pred_mask.shape[0] 184 | # pm = pred_mask.to(int).argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy() 185 | # tm = true_mask.to(int).argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy() 186 | # aris = [] 187 | # for bi in range(B): 188 | # t = tm[bi] 189 | # p = pm[bi] 190 | # if skip_0: 191 | # p = p[t > 0] 192 | # t = t[t > 0] 193 | # ari_score = adjusted_rand_score(t, p) 194 | # aris.append(ari_score) 195 | # aris = torch.tensor(np.array(aris), device=pred_mask.device) 196 | # # if torch.any(torch.isnan(aris)) or torch.any(torch.isinf(aris)): import ipdb; ipdb.set_trace() 197 | # return aris 198 | -------------------------------------------------------------------------------- /object_discovery/gnm/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from typing import Any, List, Tuple 5 | from .submodule import StackConvNorm, StackSubPixelNorm, StackMLP, ConvLSTMCell 6 | from torch.distributions import RelaxedBernoulli, Normal 7 | 8 | 9 | class LocalLatentDecoder(nn.Module): 10 | def __init__(self, args: Any): 11 | super(LocalLatentDecoder, self).__init__() 12 | self.args = args 13 | 14 | pwdw_net_inp_dim = self.args.arch.img_enc_dim 15 | 16 | self.pwdw_net = StackConvNorm( 17 | pwdw_net_inp_dim, 18 | self.args.arch.pwdw.pwdw_filters, 19 | self.args.arch.pwdw.pwdw_kernel_sizes, 20 | self.args.arch.pwdw.pwdw_strides, 21 | self.args.arch.pwdw.pwdw_groups, 22 | norm_act_final=True, 23 | ) 24 | 25 | self.q_depth_net = nn.Conv2d( 26 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_depth_dim * 2, 1 27 | ) 28 | self.q_where_net = nn.Conv2d( 29 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_where_dim * 2, 1 30 | ) 31 | self.q_what_net = nn.Conv2d( 32 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_what_dim * 2, 1 33 | ) 34 | self.q_pres_net = nn.Conv2d( 35 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_pres_dim, 1 36 | ) 37 | 38 | torch.nn.init.uniform_(self.q_where_net.weight.data, -0.01, 0.01) 39 | # scale 40 | torch.nn.init.constant_(self.q_where_net.bias.data[0], -1.0) 41 | # ratio, x, y, std 42 | torch.nn.init.constant_(self.q_where_net.bias.data[1:], 0) 43 | 44 | 
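# forward() below maps the image encoding to the per-cell posterior
# sufficient statistics [q_pres_logits, q_where_mean, q_where_std,
# q_depth_mean, q_depth_std, q_what_mean, q_what_std], each shaped
# (bs, dim, num_cell, num_cell): the stds go through a softplus and the
# presence logits are squashed with tanh and scaled by pres_logit_scale.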
def forward(self, img_enc: torch.Tensor, ss_p_z: List = None) -> List: 45 | """ 46 | 47 | :param img_enc: (bs, dim, 4, 4) 48 | :param global_dec: (bs, dim, 4, 4) 49 | :return: 50 | """ 51 | 52 | if ss_p_z is not None: 53 | ( 54 | p_pres_logits, 55 | p_where_mean, 56 | p_where_std, 57 | p_depth_mean, 58 | p_depth_std, 59 | p_what_mean, 60 | p_what_std, 61 | ) = ss_p_z 62 | 63 | pwdw_inp = img_enc 64 | 65 | pwdw_ss = self.pwdw_net(pwdw_inp) 66 | 67 | q_pres_logits = ( 68 | self.q_pres_net(pwdw_ss).tanh() * self.args.const.pres_logit_scale 69 | ) 70 | 71 | # q_where_mean, q_where_std: (bs, dim, num_cell, num_cell) 72 | q_where_mean, q_where_std = self.q_where_net(pwdw_ss).chunk(2, 1) 73 | q_where_std = F.softplus(q_where_std) 74 | 75 | # q_depth_mean, q_depth_std: (bs, dim, num_cell, num_cell) 76 | q_depth_mean, q_depth_std = self.q_depth_net(pwdw_ss).chunk(2, 1) 77 | q_depth_std = F.softplus(q_depth_std) 78 | 79 | q_what_mean, q_what_std = self.q_what_net(pwdw_ss).chunk(2, 1) 80 | q_what_std = F.softplus(q_what_std) 81 | 82 | ss = [ 83 | q_pres_logits, 84 | q_where_mean, 85 | q_where_std, 86 | q_depth_mean, 87 | q_depth_std, 88 | q_what_mean, 89 | q_what_std, 90 | ] 91 | 92 | return ss 93 | 94 | 95 | class LocalSampler(nn.Module): 96 | def __init__(self, args: Any): 97 | super(LocalSampler, self).__init__() 98 | self.args = args 99 | 100 | self.z_what_decoder_net = StackSubPixelNorm( 101 | self.args.z.z_what_dim, 102 | self.args.arch.conv.p_what_decoder_filters, 103 | self.args.arch.conv.p_what_decoder_kernel_sizes, 104 | self.args.arch.conv.p_what_decoder_upscales, 105 | self.args.arch.conv.p_what_decoder_groups, 106 | norm_act_final=False, 107 | ) 108 | 109 | self.register_buffer( 110 | "offset", 111 | torch.stack( 112 | torch.meshgrid( 113 | torch.arange(args.arch.num_cell).float(), 114 | torch.arange(args.arch.num_cell).float(), 115 | )[::-1], 116 | dim=0, 117 | ).view(1, 2, args.arch.num_cell, args.arch.num_cell), 118 | ) 119 | 120 | def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple: 121 | 122 | ( 123 | p_pres_logits, 124 | p_where_mean, 125 | p_where_std, 126 | p_depth_mean, 127 | p_depth_std, 128 | p_what_mean, 129 | p_what_std, 130 | ) = ss 131 | 132 | if phase_use_mode: 133 | z_pres = (p_pres_logits > 0).float() 134 | else: 135 | z_pres = RelaxedBernoulli( 136 | logits=p_pres_logits, temperature=self.args.train.tau_pres 137 | ).rsample() 138 | 139 | # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell) 140 | if phase_use_mode: 141 | z_where_scale, z_where_shift = p_where_mean.chunk(2, 1) 142 | else: 143 | z_where_scale, z_where_shift = ( 144 | Normal(p_where_mean, p_where_std).rsample().chunk(2, 1) 145 | ) 146 | 147 | # z_where_origin: (bs, dim, num_cell, num_cell) 148 | z_where_origin = torch.cat( 149 | [z_where_scale.detach(), z_where_shift.detach()], dim=1 150 | ) 151 | 152 | z_where_shift = (2.0 / self.args.arch.num_cell) * ( 153 | self.offset + 0.5 + torch.tanh(z_where_shift) 154 | ) - 1.0 155 | 156 | scale, ratio = z_where_scale.chunk(2, 1) 157 | scale = scale.sigmoid() 158 | ratio = torch.exp(ratio) 159 | ratio_sqrt = ratio.sqrt() 160 | z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1) 161 | # z_where: (bs, dim, num_cell, num_cell) 162 | z_where = torch.cat([z_where_scale, z_where_shift], dim=1) 163 | 164 | if phase_use_mode: 165 | z_depth = p_depth_mean 166 | z_what = p_what_mean 167 | else: 168 | z_depth = Normal(p_depth_mean, p_depth_std).rsample() 169 | z_what = Normal(p_what_mean, p_what_std).rsample() 170 | 171 | 
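# Reshape each cell's z_what into its own batch entry,
# (bs, z_what_dim, num_cell, num_cell) -> (bs * num_cell ** 2, z_what_dim, 1, 1),
# so the sub-pixel decoder below renders one glimpse o (and, when
# phase_overlap is on, an alpha channel a) per cell.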
z_what_reshape = ( 172 | z_what.permute(0, 2, 3, 1) 173 | .reshape(-1, self.args.z.z_what_dim) 174 | .view(-1, self.args.z.z_what_dim, 1, 1) 175 | ) 176 | 177 | if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap: 178 | o = self.z_what_decoder_net(z_what_reshape) 179 | o = o.sigmoid() 180 | a = o.new_ones(o.size()) 181 | elif self.args.arch.phase_overlap: 182 | o, a = self.z_what_decoder_net(z_what_reshape).split( 183 | [self.args.data.inp_channel, 1], dim=1 184 | ) 185 | o, a = o.sigmoid(), a.sigmoid() 186 | else: 187 | raise NotImplemented 188 | 189 | lv = [z_pres, z_where, z_depth, z_what, z_where_origin] 190 | pa = [o, a] 191 | 192 | return pa, lv 193 | 194 | 195 | class StructDRAW(nn.Module): 196 | def __init__(self, args): 197 | super(StructDRAW, self).__init__() 198 | self.args = args 199 | 200 | self.p_global_decoder_net = StackMLP( 201 | self.args.z.z_global_dim, 202 | self.args.arch.mlp.p_global_decoder_filters, 203 | norm_act_final=True, 204 | ) 205 | 206 | rnn_enc_inp_dim = ( 207 | self.args.arch.img_enc_dim * 2 208 | + self.args.arch.structdraw.rnn_decoder_hid_dim 209 | ) 210 | 211 | rnn_dec_inp_dim = self.args.arch.mlp.p_global_decoder_filters[-1] // ( 212 | self.args.arch.num_cell ** 2 213 | ) 214 | 215 | rnn_dec_inp_dim += self.args.arch.structdraw.hid_to_dec_filters[-1] 216 | 217 | self.rnn_enc = ConvLSTMCell( 218 | input_dim=rnn_enc_inp_dim, 219 | hidden_dim=self.args.arch.structdraw.rnn_encoder_hid_dim, 220 | kernel_size=self.args.arch.structdraw.kernel_size, 221 | num_cell=self.args.arch.num_cell, 222 | ) 223 | 224 | self.rnn_dec = ConvLSTMCell( 225 | input_dim=rnn_dec_inp_dim, 226 | hidden_dim=self.args.arch.structdraw.rnn_decoder_hid_dim, 227 | kernel_size=self.args.arch.structdraw.kernel_size, 228 | num_cell=self.args.arch.num_cell, 229 | ) 230 | 231 | self.p_global_net = StackMLP( 232 | self.args.arch.num_cell ** 2 233 | * self.args.arch.structdraw.rnn_decoder_hid_dim, 234 | self.args.arch.mlp.p_global_encoder_filters, 235 | norm_act_final=False, 236 | ) 237 | 238 | self.q_global_net = StackMLP( 239 | self.args.arch.num_cell ** 2 240 | * self.args.arch.structdraw.rnn_encoder_hid_dim, 241 | self.args.arch.mlp.q_global_encoder_filters, 242 | norm_act_final=False, 243 | ) 244 | 245 | self.hid_to_dec_net = StackConvNorm( 246 | self.args.arch.structdraw.rnn_decoder_hid_dim, 247 | self.args.arch.structdraw.hid_to_dec_filters, 248 | self.args.arch.structdraw.hid_to_dec_kernel_sizes, 249 | self.args.arch.structdraw.hid_to_dec_strides, 250 | self.args.arch.structdraw.hid_to_dec_groups, 251 | norm_act_final=False, 252 | ) 253 | 254 | self.register_buffer( 255 | "dec_step_0", 256 | torch.zeros( 257 | 1, 258 | self.args.arch.structdraw.hid_to_dec_filters[-1], 259 | self.args.arch.num_cell, 260 | self.args.arch.num_cell, 261 | ), 262 | ) 263 | 264 | def forward( 265 | self, 266 | x: torch.Tensor, 267 | phase_generation: bool = False, 268 | generation_from_step: Any = None, 269 | z_global_predefine: Any = None, 270 | ) -> Tuple: 271 | """ 272 | :param x: (bs, dim, num_cell, num_cell) of (bs, dim, img_h, img_w) 273 | :return: 274 | """ 275 | 276 | bs = x.size(0) 277 | 278 | h_enc, c_enc = self.rnn_enc.init_hidden(bs) 279 | h_dec, c_dec = self.rnn_dec.init_hidden(bs) 280 | 281 | p_global_mean_list = [] 282 | p_global_std_list = [] 283 | q_global_mean_list = [] 284 | q_global_std_list = [] 285 | z_global_list = [] 286 | 287 | dec_step = self.dec_step_0.expand(bs, -1, -1, -1) 288 | 289 | for i in range(self.args.arch.draw_step): 290 | 291 | p_global_mean_step, 
p_global_std_step = self.p_global_net( 292 | h_dec.permute(0, 2, 3, 1).reshape(bs, -1) 293 | ).chunk(2, -1) 294 | p_global_std_step = F.softplus(p_global_std_step) 295 | 296 | if phase_generation or ( 297 | generation_from_step is not None and i >= generation_from_step 298 | ): 299 | 300 | q_global_mean_step = x.new_empty(bs, self.args.z.z_global_dim) 301 | q_global_std_step = x.new_empty(bs, self.args.z.z_global_dim) 302 | 303 | if z_global_predefine is None or z_global_predefine.size(1) <= i: 304 | z_global_step = Normal( 305 | p_global_mean_step, p_global_std_step 306 | ).rsample() 307 | else: 308 | z_global_step = z_global_predefine.view( 309 | bs, -1, self.args.z.z_global_dim 310 | )[:, i] 311 | 312 | else: 313 | 314 | if i == 0: 315 | rnn_encoder_inp = torch.cat([x, x, h_dec], dim=1) 316 | else: 317 | rnn_encoder_inp = torch.cat([x, x - dec_step, h_dec], dim=1) 318 | 319 | h_enc, c_enc = self.rnn_enc(rnn_encoder_inp, [h_enc, c_enc]) 320 | 321 | q_global_mean_step, q_global_std_step = self.q_global_net( 322 | h_enc.permute(0, 2, 3, 1).reshape(bs, -1) 323 | ).chunk(2, -1) 324 | 325 | q_global_std_step = F.softplus(q_global_std_step) 326 | z_global_step = Normal(q_global_mean_step, q_global_std_step).rsample() 327 | 328 | rnn_decoder_inp = self.p_global_decoder_net(z_global_step).reshape( 329 | bs, -1, self.args.arch.num_cell, self.args.arch.num_cell 330 | ) 331 | 332 | rnn_decoder_inp = torch.cat([rnn_decoder_inp, dec_step], dim=1) 333 | 334 | h_dec, c_dec = self.rnn_dec(rnn_decoder_inp, [h_dec, c_dec]) 335 | 336 | dec_step = dec_step + self.hid_to_dec_net(h_dec) 337 | 338 | # (bs, dim) 339 | p_global_mean_list.append(p_global_mean_step) 340 | p_global_std_list.append(p_global_std_step) 341 | q_global_mean_list.append(q_global_mean_step) 342 | q_global_std_list.append(q_global_std_step) 343 | z_global_list.append(z_global_step) 344 | 345 | global_dec = dec_step 346 | 347 | # (bs, steps, dim, 1, 1) 348 | p_global_mean_all = torch.stack(p_global_mean_list, 1)[:, :, :, None, None] 349 | p_global_std_all = torch.stack(p_global_std_list, 1)[:, :, :, None, None] 350 | q_global_mean_all = torch.stack(q_global_mean_list, 1)[:, :, :, None, None] 351 | q_global_std_all = torch.stack(q_global_std_list, 1)[:, :, :, None, None] 352 | z_global_all = torch.stack(z_global_list, 1)[:, :, :, None, None] 353 | 354 | pa = [global_dec] 355 | lv = [z_global_all] 356 | ss = [p_global_mean_all, p_global_std_all, q_global_mean_all, q_global_std_all] 357 | 358 | return pa, lv, ss 359 | 360 | 361 | class BgEncoder(nn.Module): 362 | def __init__(self, args): 363 | super(BgEncoder, self).__init__() 364 | self.args = args 365 | 366 | self.p_bg_encoder = StackMLP( 367 | self.args.arch.img_enc_dim * self.args.arch.num_cell ** 2, 368 | self.args.arch.mlp.q_bg_encoder_filters, 369 | norm_act_final=False, 370 | ) 371 | 372 | def forward(self, x: torch.Tensor) -> Tuple: 373 | """ 374 | :param x: (bs, dim, img_h, img_w) or (bs, dim, num_cell, num_cell) 375 | :return: 376 | """ 377 | bs = x.size(0) 378 | q_bg_mean, q_bg_std = self.p_bg_encoder(x.view(bs, -1)).chunk(2, 1) 379 | q_bg_mean = q_bg_mean.view(bs, -1, 1, 1) 380 | q_bg_std = q_bg_std.view(bs, -1, 1, 1) 381 | 382 | q_bg_std = F.softplus(q_bg_std) 383 | 384 | z_bg = Normal(q_bg_mean, q_bg_std).rsample() 385 | 386 | lv = [z_bg] 387 | 388 | ss = [q_bg_mean, q_bg_std] 389 | 390 | return lv, ss 391 | 392 | 393 | class BgGenerator(nn.Module): 394 | def __init__(self, args): 395 | super(BgGenerator, self).__init__() 396 | self.args = args 397 | 398 | inp_dim = 
self.args.z.z_global_dim * self.args.arch.draw_step 399 | 400 | self.p_bg_generator = StackMLP( 401 | inp_dim, self.args.arch.mlp.p_bg_generator_filters, norm_act_final=False 402 | ) 403 | 404 | def forward( 405 | self, z_global_all: torch.Tensor, phase_use_mode: bool = False 406 | ) -> Tuple: 407 | """ 408 | :param x: (bs, step, dim, 1, 1) 409 | :return: 410 | """ 411 | bs = z_global_all.size(0) 412 | 413 | bg_generator_inp = z_global_all 414 | 415 | p_bg_mean, p_bg_std = self.p_bg_generator(bg_generator_inp.view(bs, -1)).chunk( 416 | 2, 1 417 | ) 418 | p_bg_std = F.softplus(p_bg_std) 419 | 420 | p_bg_mean = p_bg_mean.view(bs, -1, 1, 1) 421 | p_bg_std = p_bg_std.view(bs, -1, 1, 1) 422 | 423 | if phase_use_mode: 424 | z_bg = p_bg_mean 425 | else: 426 | z_bg = Normal(p_bg_mean, p_bg_std).rsample() 427 | 428 | lv = [z_bg] 429 | 430 | ss = [p_bg_mean, p_bg_std] 431 | 432 | return lv, ss 433 | 434 | 435 | class BgDecoder(nn.Module): 436 | def __init__(self, args): 437 | super(BgDecoder, self).__init__() 438 | self.args = args 439 | 440 | self.p_bg_decoder = StackSubPixelNorm( 441 | self.args.z.z_bg_dim, 442 | self.args.arch.conv.p_bg_decoder_filters, 443 | self.args.arch.conv.p_bg_decoder_kernel_sizes, 444 | self.args.arch.conv.p_bg_decoder_upscales, 445 | self.args.arch.conv.p_bg_decoder_groups, 446 | norm_act_final=False, 447 | ) 448 | 449 | def forward(self, z_bg: torch.Tensor) -> List: 450 | """ 451 | :param x: (bs, dim, 1, 1) 452 | :return: 453 | """ 454 | bs = z_bg.size(0) 455 | 456 | bg = self.p_bg_decoder(z_bg).sigmoid() 457 | 458 | pa = [bg] 459 | 460 | return pa 461 | -------------------------------------------------------------------------------- /object_discovery/gnm/submodule.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class StackConvNorm(nn.Module): 7 | def __init__( 8 | self, 9 | dim_inp: int, 10 | filters: List[int], 11 | kernel_sizes: List[int], 12 | strides: List[int], 13 | groupings: List[int], 14 | norm_act_final: bool, 15 | activation: Callable = nn.CELU, 16 | ): 17 | super(StackConvNorm, self).__init__() 18 | 19 | layers = [] 20 | 21 | dim_prev = dim_inp 22 | 23 | for i, (f, k, s) in enumerate(zip(filters, kernel_sizes, strides)): 24 | if s == 0: 25 | layers.append(nn.Conv2d(dim_prev, f, k, 1, 0)) 26 | else: 27 | layers.append(nn.Conv2d(dim_prev, f, k, s, (k - 1) // 2)) 28 | if i == len(filters) - 1 and norm_act_final == False: 29 | break 30 | layers.append(activation()) 31 | layers.append(nn.GroupNorm(groupings[i], f)) 32 | # layers.append(nn.BatchNorm2d(f)) 33 | dim_prev = f 34 | 35 | self.conv = nn.Sequential(*layers) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | x = self.conv(x) 39 | 40 | return x 41 | 42 | 43 | class StackSubPixelNorm(nn.Module): 44 | def __init__( 45 | self, 46 | dim_inp: int, 47 | filters: List[int], 48 | kernel_sizes: List[int], 49 | upscale: List[int], 50 | groupings: List[int], 51 | norm_act_final: bool, 52 | activation: Callable = nn.CELU, 53 | ): 54 | super(StackSubPixelNorm, self).__init__() 55 | 56 | layers = [] 57 | 58 | dim_prev = dim_inp 59 | 60 | for i, (f, k, u) in enumerate(zip(filters, kernel_sizes, upscale)): 61 | if u == 1: 62 | layers.append(nn.Conv2d(dim_prev, f, k, 1, (k - 1) // 2)) 63 | else: 64 | layers.append(nn.Conv2d(dim_prev, f * u ** 2, k, 1, (k - 1) // 2)) 65 | layers.append(nn.PixelShuffle(u)) 66 | if i == len(filters) - 1 and norm_act_final == False: 67 | break 68 | 
layers.append(activation()) 69 | layers.append(nn.GroupNorm(groupings[i], f)) 70 | dim_prev = f 71 | 72 | self.conv = nn.Sequential(*layers) 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | x = self.conv(x) 76 | 77 | return x 78 | 79 | 80 | class StackMLP(nn.Module): 81 | def __init__( 82 | self, 83 | dim_inp: int, 84 | filters: List[int], 85 | norm_act_final: bool, 86 | activation: Callable = nn.CELU, 87 | phase_layer_norm: bool = True, 88 | ): 89 | super(StackMLP, self).__init__() 90 | 91 | layers = [] 92 | 93 | dim_prev = dim_inp 94 | 95 | for i, f in enumerate(filters): 96 | layers.append(nn.Linear(dim_prev, f)) 97 | if i == len(filters) - 1 and norm_act_final == False: 98 | break 99 | layers.append(activation()) 100 | if phase_layer_norm: 101 | layers.append(nn.LayerNorm(f)) 102 | dim_prev = f 103 | 104 | self.mlp = nn.Sequential(*layers) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.mlp(x) 108 | 109 | return x 110 | 111 | 112 | class ConvLSTMCell(nn.Module): 113 | def __init__(self, input_dim, hidden_dim, kernel_size=3, num_cell=4): 114 | super(ConvLSTMCell, self).__init__() 115 | 116 | self.input_dim = input_dim 117 | self.hidden_dim = hidden_dim 118 | 119 | self.kernel_size = kernel_size 120 | self.padding = (kernel_size - 1) // 2 121 | self.conv = nn.Conv2d( 122 | in_channels=self.input_dim + hidden_dim, 123 | out_channels=4 * self.hidden_dim, 124 | kernel_size=self.kernel_size, 125 | padding=self.padding, 126 | bias=True, 127 | ) 128 | 129 | self.register_parameter( 130 | "h_0", 131 | torch.nn.Parameter( 132 | torch.zeros(1, self.hidden_dim, num_cell, num_cell), requires_grad=True 133 | ), 134 | ) 135 | self.register_parameter( 136 | "c_0", 137 | torch.nn.Parameter( 138 | torch.zeros(1, self.hidden_dim, num_cell, num_cell), requires_grad=True 139 | ), 140 | ) 141 | 142 | def forward(self, x, h_c): 143 | h_cur, c_cur = h_c 144 | 145 | conv_inp = torch.cat([x, h_cur], dim=1) 146 | 147 | i, f, o, c = self.conv(conv_inp).split(self.hidden_dim, dim=1) 148 | 149 | i = torch.sigmoid(i) 150 | f = torch.sigmoid(f) 151 | c = torch.tanh(c) 152 | o = torch.sigmoid(o) 153 | 154 | c_next = f * c_cur + i * c 155 | h_next = o * torch.tanh(c_next) 156 | 157 | return h_next, c_next 158 | 159 | def init_hidden(self, batch_size): 160 | return ( 161 | self.h_0.expand(batch_size, -1, -1, -1), 162 | self.c_0.expand(batch_size, -1, -1, -1), 163 | ) 164 | -------------------------------------------------------------------------------- /object_discovery/gnm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | border_width = 3 5 | 6 | rbox = torch.zeros(3, 42, 42) 7 | rbox[0, :border_width, :] = 1 8 | rbox[0, -border_width:, :] = 1 9 | rbox[0, :, :border_width] = 1 10 | rbox[0, :, -border_width:] = 1 11 | rbox = rbox.view(1, 3, 42, 42) 12 | 13 | gbox = torch.zeros(3, 42, 42) 14 | gbox[1, :border_width, :] = 1 15 | gbox[1, -border_width:, :] = 1 16 | gbox[1, :, :border_width] = 1 17 | gbox[1, :, -border_width:] = 1 18 | gbox = gbox.view(1, 3, 42, 42) 19 | 20 | 21 | def visualize( 22 | x, z_pres, z_where_scale, z_where_shift, 23 | ): 24 | """ 25 | x: (bs, 3, img_h, img_w) 26 | z_pres: (bs, 4, 4, 1) 27 | z_where_scale: (bs, 4, 4, 2) 28 | z_where_shift: (bs, 4, 4, 2) 29 | """ 30 | global gbox, rbox 31 | bs, _, img_h, img_w = x.size() 32 | z_pres = z_pres.view(-1, 1, 1, 1) 33 | num_obj = z_pres.size(0) // bs 34 | z_scale = z_where_scale.view(-1, 2) 35 | z_shift = 
z_where_shift.view(-1, 2) 36 | gbox = gbox.to(z_pres.device) 37 | rbox = rbox.to(z_pres.device) 38 | bbox = spatial_transform( 39 | z_pres * gbox + (1 - z_pres) * rbox, 40 | torch.cat((z_scale, z_shift), dim=1), 41 | torch.Size([bs * num_obj, 3, img_h, img_w]), 42 | inverse=True, 43 | ) 44 | 45 | return bbox 46 | 47 | 48 | def linear_schedule_tensor(step, start_step, end_step, start_value, end_value, device): 49 | if start_step < step < end_step: 50 | slope = (end_value - start_value) / (end_step - start_step) 51 | x = torch.tensor(start_value + slope * (step - start_step)).to(device) 52 | elif step >= end_step: 53 | x = torch.tensor(end_value).to(device) 54 | else: 55 | x = torch.tensor(start_value).to(device) 56 | 57 | return x 58 | 59 | 60 | def linear_schedule(step, start_step, end_step, start_value, end_value): 61 | if start_step < step < end_step: 62 | slope = (end_value - start_value) / (end_step - start_step) 63 | x = start_value + slope * (step - start_step) 64 | elif step >= end_step: 65 | x = end_value 66 | else: 67 | x = start_value 68 | 69 | return x 70 | 71 | 72 | def spatial_transform(image, z_where, out_dims, inverse=False): 73 | """ spatial transformer network used to scale and shift input according to z_where in: 74 | 1/ x -> x_att -- shapes (H, W) -> (attn_window, attn_window) -- thus inverse = False 75 | 2/ y_att -> y -- (attn_window, attn_window) -> (H, W) -- thus inverse = True 76 | inverting the affine transform as follows: A_inv ( A * image ) = image 77 | A = [R | T] where R is rotation component of angle alpha, T is [tx, ty] translation component 78 | A_inv rotates by -alpha and translates by [-tx, -ty] 79 | if x' = R * x + T --> x = R_inv * (x' - T) = R_inv * x - R_inv * T 80 | here, z_where is 4-dim [scale_x, scale_y, tx, ty] so the inverse transform is [1/scale_x, 1/scale_y, -tx/scale_x, -ty/scale_y] 81 | R = [[sx, 0], -> R_inv = [[1/sx, 0], 82 | [0, sy]] [0, 1/sy]] 83 | """ 84 | # 1. construct 2x3 affine matrix for each datapoint in the minibatch 85 | theta = torch.zeros(2, 3).repeat(image.shape[0], 1, 1).to(image.device) 86 | # set scaling 87 | theta[:, 0, 0] = z_where[:, 0] if not inverse else 1 / (z_where[:, 0] + 1e-15) 88 | theta[:, 1, 1] = z_where[:, 1] if not inverse else 1 / (z_where[:, 1] + 1e-15) 89 | 90 | # set translation 91 | theta[:, 0, -1] = ( 92 | z_where[:, 2] if not inverse else -z_where[:, 2] / (z_where[:, 0] + 1e-15) 93 | ) 94 | theta[:, 1, -1] = ( 95 | z_where[:, 3] if not inverse else -z_where[:, 3] / (z_where[:, 1] + 1e-15) 96 | ) 97 | # 2. construct sampling grid 98 | grid = F.affine_grid(theta, torch.Size(out_dims), align_corners=False) 99 | # 3. sample image from grid 100 | return F.grid_sample(image, grid, align_corners=False) 101 | 102 | 103 | def kl_divergence_bern_bern(q_pres_probs, p_pres_prob, eps=1e-15): 104 | """ 105 | Compute the KL divergence between two Bernoulli distributions. 106 | :param q_pres_probs: posterior Bernoulli probabilities, (B, ...) 107 | :param p_pres_prob: prior Bernoulli probability, float 108 | :return: kl divergence, (B, ...)
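Closed form used below: KL(q || p) = q * log(q / p) + (1 - q) * log((1 - q) / (1 - p)),
computed elementwise with `eps` added inside each log for numerical stability.
Illustrative example (an informal sketch, not a doctest):
    q, p = torch.tensor(0.9), torch.tensor(0.5)
    kl_divergence_bern_bern(q, p)  # ≈ 0.9 * log(1.8) + 0.1 * log(0.2) ≈ 0.368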
109 | """ 110 | # z_pres_probs = torch.sigmoid(z_pres_logits) 111 | kl = q_pres_probs * ( 112 | torch.log(q_pres_probs + eps) - torch.log(p_pres_prob + eps) 113 | ) + (1 - q_pres_probs) * ( 114 | torch.log(1 - q_pres_probs + eps) - torch.log(1 - p_pres_prob + eps) 115 | ) 116 | 117 | return kl 118 | -------------------------------------------------------------------------------- /object_discovery/method.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from typing import Union 3 | from functools import partial 4 | import pytorch_lightning as pl 5 | import wandb 6 | import torch 7 | from torch import optim 8 | import torch.nn.functional as F 9 | from torchvision import utils as vutils 10 | from torchvision import transforms 11 | 12 | from object_discovery.slot_attention_model import SlotAttentionModel 13 | from object_discovery.slate_model import SLATE 14 | from object_discovery.gnm.gnm_model import GNM 15 | from object_discovery.utils import ( 16 | to_rgb_from_tensor, 17 | warm_and_decay_lr_scheduler, 18 | cosine_anneal, 19 | linear_warmup, 20 | visualize, 21 | compute_ari, 22 | sa_segment, 23 | rescale, 24 | get_largest_objects, 25 | cmap_tensor, 26 | ) 27 | from object_discovery.gnm.logging import gnm_log_validation_outputs 28 | 29 | 30 | class SlotAttentionMethod(pl.LightningModule): 31 | def __init__( 32 | self, 33 | model: Union[SlotAttentionModel, SLATE, GNM], 34 | datamodule: pl.LightningDataModule, 35 | params: Namespace, 36 | ): 37 | if type(params) is dict: 38 | params = Namespace(**params) 39 | super().__init__() 40 | self.model = model 41 | self.datamodule = datamodule 42 | self.params = params 43 | self.save_hyperparameters(params) 44 | 45 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 46 | return self.model(input, **kwargs) 47 | 48 | def step(self, batch): 49 | mask = None 50 | if self.params.model_type == "slate": 51 | self.tau = cosine_anneal( 52 | self.trainer.global_step, 53 | self.params.tau_start, 54 | self.params.tau_final, 55 | 0, 56 | self.params.tau_steps, 57 | ) 58 | loss = self.model.loss_function(batch, self.tau, self.params.hard) 59 | elif self.params.model_type == "sa": 60 | separation_tau = None 61 | if self.params.use_separation_loss: 62 | if self.params.separation_tau: 63 | separation_tau = self.params.separation_tau 64 | else: 65 | separation_tau = self.params.separation_tau_max_val - cosine_anneal( 66 | self.trainer.global_step, 67 | self.params.separation_tau_max_val, 68 | 0, 69 | self.params.separation_tau_start, 70 | self.params.separation_tau_end, 71 | ) 72 | area_tau = None 73 | if self.params.use_area_loss: 74 | if self.params.area_tau: 75 | area_tau = self.params.area_tau 76 | else: 77 | area_tau = self.params.area_tau_max_val - cosine_anneal( 78 | self.trainer.global_step, 79 | self.params.area_tau_max_val, 80 | 0, 81 | self.params.area_tau_start, 82 | self.params.area_tau_end, 83 | ) 84 | loss, mask = self.model.loss_function( 85 | batch, separation_tau=separation_tau, area_tau=area_tau 86 | ) 87 | elif self.params.model_type == "gnm": 88 | output = self.model.loss_function(batch, self.trainer.global_step) 89 | loss = { 90 | "loss": output["loss"], 91 | "elbo": output["elbo"], 92 | "kl": output["kl"], 93 | "loss_comp_ratio": output["rec_loss"] / output["kl"], 94 | } 95 | for k in output: 96 | if k.startswith("kl_") or k.endswith("_loss"): 97 | val = output[k] 98 | if isinstance(val, torch.Tensor): 99 | val = val.mean() 100 | loss[k] = val 101 | 102 | 
return loss, mask 103 | 104 | def training_step(self, batch, batch_idx): 105 | loss_dict, _ = self.step(batch) 106 | logs = {"train/" + key: val.item() for key, val in loss_dict.items()} 107 | self.log_dict(logs, sync_dist=True) 108 | return loss_dict["loss"] 109 | 110 | def sample_images(self, stage="validation"): 111 | dl = ( 112 | self.datamodule.val_dataloader() 113 | if stage == "validation" 114 | else self.datamodule.train_dataloader() 115 | ) 116 | perm = torch.randperm(self.params.batch_size) 117 | idx = perm[: self.params.n_samples] 118 | batch = next(iter(dl)) 119 | if type(batch) == list: 120 | batch = batch[0][idx] 121 | else: 122 | batch = batch[idx] 123 | 124 | if self.params.accelerator: 125 | batch = batch.to(self.device) 126 | if self.params.model_type == "sa": 127 | recon_combined, recons, masks, slots = self.model.forward(batch) 128 | # `masks` has shape [batch_size, num_entries, channels, height, width]. 129 | threshold = getattr(self.params, "sa_segmentation_threshold", 0.5) 130 | _, _, cmap_segmentation, cmap_segmentation_thresholded = sa_segment( 131 | masks, threshold 132 | ) 133 | 134 | # combine images in a nice way so we can display all outputs in one grid, output rescaled to be between 0 and 1 135 | out = torch.cat( 136 | [ 137 | to_rgb_from_tensor(batch.unsqueeze(1)), # original images 138 | to_rgb_from_tensor(recon_combined.unsqueeze(1)), # reconstructions 139 | cmap_segmentation.unsqueeze(1), 140 | cmap_segmentation_thresholded.unsqueeze(1), 141 | to_rgb_from_tensor(recons * masks + (1 - masks)), # each slot 142 | ], 143 | dim=1, 144 | ) 145 | 146 | batch_size, num_slots, C, H, W = recons.shape 147 | images = vutils.make_grid( 148 | out.view(batch_size * out.shape[1], C, H, W).cpu(), 149 | normalize=False, 150 | nrow=out.shape[1], 151 | ) 152 | elif self.params.model_type == "slate": 153 | recon, _, _, attns = self.model(batch, self.tau, True) 154 | gen_img = self.model.reconstruct_autoregressive(batch) 155 | vis_recon = visualize(batch, recon, gen_img, attns, N=32) 156 | images = vutils.make_grid( 157 | vis_recon, nrow=self.params.num_slots + 3, pad_value=0.2 158 | )[:, 2:-2, 2:-2] 159 | 160 | return images 161 | 162 | def validation_step(self, batch, batch_idx): 163 | if self.params.model_type == "gnm": 164 | output = self.model.loss_function( 165 | batch[0], self.trainer.global_step, generate_bbox=batch_idx == 0 166 | ) 167 | loss, images = gnm_log_validation_outputs( 168 | batch, batch_idx, output, self.trainer.is_global_zero 169 | ) 170 | for key, image in images.items(): 171 | self.logger.experiment.log( 172 | {f"validation/{key}": [wandb.Image(image)]}, commit=False 173 | ) 174 | elif type(batch) == list and self.model.supports_masks: 175 | loss, predicted_mask = self.step(batch[0]) 176 | predicted_mask = torch.permute(predicted_mask, [0, 3, 4, 2, 1]) 177 | # `predicted_mask` has shape [batch_size, height, width, channels, num_entries] 178 | predicted_mask = torch.squeeze(predicted_mask) 179 | batch_size, height, width, num_entries = predicted_mask.shape 180 | predicted_mask = torch.reshape( 181 | predicted_mask, [batch_size, height * width, num_entries] 182 | ) 183 | # `predicted_mask` has shape [batch_size, height * width, num_entries] 184 | # Scale from [0, 1] to [0, 255] to match the true mask. 
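# Shape walk-through up to this point (the sizes are illustrative assumptions,
# e.g. batch_size=2, num_slots=5, resolution=(128, 128)):
#   masks from the model:    [2, 5, 1, 128, 128]
#   after permute + squeeze: [2, 128, 128, 5]
#   after reshape:           [2, 128 * 128, 5]
# The * 255 below matches the integer-valued ground-truth masks, and
# adjusted_rand_index in segmentation_metrics.py later works on the argmax
# over the last axis to recover per-pixel cluster ids.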
185 | predicted_mask = (predicted_mask * 255).type(torch.int) 186 | ari = compute_ari( 187 | predicted_mask, 188 | batch[1], 189 | len(batch[0]), 190 | self.params.resolution[0], 191 | self.params.resolution[1], 192 | self.datamodule.max_num_entries, 193 | ) 194 | loss["ari"] = ari 195 | else: 196 | if type(batch) == list: 197 | batch = batch[0] 198 | loss, _ = self.step(batch) 199 | return loss 200 | 201 | def validation_epoch_end(self, outputs): 202 | logs = { 203 | "validation/" + key: torch.stack([x[key] for x in outputs]).float().mean() 204 | for key in outputs[0].keys() 205 | } 206 | self.log_dict(logs, sync_dist=True) 207 | 208 | def num_training_steps(self) -> int: 209 | """Total training steps inferred from datamodule and devices.""" 210 | # https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-774265729 211 | if self.trainer.max_steps != -1: 212 | return self.trainer.max_steps 213 | 214 | limit_batches = self.trainer.limit_train_batches 215 | batches = len(self.datamodule.train_dataloader()) 216 | batches = ( 217 | min(batches, limit_batches) 218 | if isinstance(limit_batches, int) 219 | else int(limit_batches * batches) 220 | ) 221 | 222 | num_devices = max(1, self.trainer.num_devices) 223 | 224 | effective_accum = self.trainer.accumulate_grad_batches * num_devices 225 | return (batches // effective_accum) * self.trainer.max_epochs 226 | 227 | def configure_optimizers(self): 228 | if self.params.model_type == "slate": 229 | optimizer = optim.Adam( 230 | [ 231 | { 232 | "params": ( 233 | x[1] 234 | for x in self.model.named_parameters() 235 | if "dvae" in x[0] 236 | ), 237 | "lr": self.params.lr_dvae, 238 | }, 239 | { 240 | "params": ( 241 | x[1] 242 | for x in self.model.named_parameters() 243 | if "dvae" not in x[0] 244 | ), 245 | "lr": self.params.lr_main, 246 | }, 247 | ], 248 | weight_decay=self.params.weight_decay, 249 | ) 250 | elif self.params.model_type == "sa": 251 | optimizer = optim.Adam( 252 | self.model.parameters(), 253 | lr=self.params.lr_main, 254 | weight_decay=self.params.weight_decay, 255 | ) 256 | elif self.params.model_type == "gnm": 257 | optimizer = optim.RMSprop( 258 | self.model.parameters(), 259 | lr=self.params.lr_main, 260 | weight_decay=self.params.weight_decay, 261 | ) 262 | 263 | total_steps = self.num_training_steps() 264 | 265 | scheduler_lambda = None 266 | if self.params.scheduler == "warmup_and_decay": 267 | warmup_steps_pct = self.params.warmup_steps_pct 268 | decay_steps_pct = self.params.decay_steps_pct 269 | scheduler_lambda = partial( 270 | warm_and_decay_lr_scheduler, 271 | warmup_steps_pct=warmup_steps_pct, 272 | decay_steps_pct=decay_steps_pct, 273 | total_steps=total_steps, 274 | gamma=self.params.scheduler_gamma, 275 | ) 276 | elif self.params.scheduler == "warmup": 277 | scheduler_lambda = partial( 278 | linear_warmup, 279 | start_value=0.0, 280 | final_value=1.0, 281 | start_step=0, 282 | final_step=self.params.lr_warmup_steps, 283 | ) 284 | 285 | if scheduler_lambda is not None: 286 | if self.params.model_type == "slate": 287 | lr_lambda = [lambda o: 1, scheduler_lambda] 288 | elif self.params.model_type in ["sa", "gnm"]: 289 | lr_lambda = scheduler_lambda 290 | scheduler = optim.lr_scheduler.LambdaLR( 291 | optimizer=optimizer, lr_lambda=lr_lambda 292 | ) 293 | 294 | if self.params.model_type == "slate" and hasattr(self.params, "patience"): 295 | reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau( 296 | optimizer=optimizer, 297 | mode="min", 298 | factor=0.5, 299 | patience=self.params.patience, 300 | ) 301 | 
return ( 302 | [optimizer], 303 | [ 304 | {"scheduler": scheduler, "interval": "step",}, 305 | { 306 | "scheduler": reduce_on_plateau, 307 | "interval": "epoch", 308 | "monitor": "validation/loss", 309 | }, 310 | ], 311 | ) 312 | 313 | return ( 314 | [optimizer], 315 | [{"scheduler": scheduler, "interval": "step",}], 316 | ) 317 | return optimizer 318 | 319 | def predict( 320 | self, 321 | image, 322 | do_transforms=False, 323 | debug=False, 324 | return_pil=False, 325 | return_slots=False, 326 | background_detection="spread_out", 327 | background_metric="area", 328 | ): 329 | """ 330 | `background_detection` options: 331 | - "spread_out" detects the background as pixels that appear in multiple 332 | slot masks. 333 | - "concentrated" assumes the background has been assigned 334 | to one slot. The slot with the largest distance between two points 335 | in its mask is assumed to be the background when using 336 | "concentrated." 337 | - When using "both," pixels detected using "spread_out" 338 | and the largest object detected using "concentrated" will be the 339 | background. The largest object detected using "concentrated" can be 340 | the background detected by "spread_out." 341 | `background_metric` is used when `background_detection` is set to "both" or 342 | "concentrated" to determine which object is largest. 343 | - "area" will find the object with the largest area 344 | - "distance" will find the object with the greatest distance between 345 | two points in that object. 346 | `return_slots` returns only the Slot Attention slots if using Slot 347 | Attention. 348 | 349 | """ 350 | assert background_detection in ["spread_out", "concentrated", "both"] 351 | if return_pil: 352 | assert debug or ( 353 | len(image.shape) == 3 or image.shape[0] == 1 354 | ), "Only one image can be passed when using `return_pil` and `debug=False`." 355 | 356 | if do_transforms: 357 | if getattr(self, "predict_transforms", True): 358 | current_transforms = [] 359 | if type(image) is not torch.Tensor: 360 | current_transforms.append(transforms.ToTensor()) 361 | if self.params.model_type == "sa": 362 | current_transforms.append(transforms.Lambda(rescale)) 363 | current_transforms.append( 364 | transforms.Resize( 365 | self.params.resolution, 366 | interpolation=transforms.InterpolationMode.NEAREST, 367 | ) 368 | ) 369 | self.predict_transforms = transforms.Compose(current_transforms) 370 | image = self.predict_transforms(image) 371 | 372 | if len(image.shape) == 3: 373 | # Add the batch_size dimension (set to 1) if input is a single image. 374 | image = image.unsqueeze(0) 375 | 376 | if self.params.model_type == "sa": 377 | recon_combined, recons, masks, slots = self.forward(image) 378 | if return_slots: 379 | return slots 380 | threshold = getattr(self.params, "sa_segmentation_threshold", 0.5) 381 | ( 382 | segmentation, 383 | segmentation_thresholded, 384 | cmap_segmentation, 385 | cmap_segmentation_thresholded, 386 | ) = sa_segment(masks, threshold) 387 | # `cmap_segmentation` and `cmap_segmentation_thresholded` have shape 388 | # [batch_size, channels=3, height, width]. 389 | if background_detection in ["concentrated", "both"]: 390 | if background_detection == "both": 391 | # `segmentation_thresholded` has pixels that are masked by 392 | # many slots set to 0 already. 
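# Toy example of the one-hot step below (an illustrative sketch, not part of
# the original code): for a 2x2 segmentation with slot ids {0, 1, 2},
#   seg = torch.tensor([[[0, 1], [2, 1]]])     # [batch=1, H=2, W=2]
#   F.one_hot(seg)                             # -> [1, 2, 2, 3]
#   F.one_hot(seg).permute([0, 3, 1, 2])       # -> [1, 3, 2, 2], one channel per object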
393 | objects = F.one_hot(segmentation_thresholded.to(torch.int64)) 394 | else: 395 | objects = F.one_hot(segmentation.to(torch.int64)) 396 | # `objects` has shape [batch_size, height, width, num_objects] 397 | objects = objects.permute([0, 3, 1, 2]) 398 | # `objects` has shape [batch_size, num_objects, height, width] 399 | largest_object_idx = get_largest_objects( 400 | objects, metric=background_metric 401 | ) 402 | # `largest_object_idx` has shape [batch_size] 403 | largest_object = objects[ 404 | range(len(largest_object_idx)), largest_object_idx 405 | ] 406 | # `largest_object` has shape [batch_size, num_objects=1, height, width] 407 | largest_object = largest_object.squeeze(1).to(torch.bool) 408 | 409 | segmentation_background = ( 410 | segmentation_thresholded.clone() 411 | if background_detection == "both" 412 | else segmentation.clone() 413 | ) 414 | # Set the largest object to be index 0, the background. 415 | segmentation_background[largest_object] = 0 416 | # Recompute the colors now that `largest_object` is the background. 417 | cmap_segmentation_background = cmap_tensor(segmentation_background) 418 | elif background_detection == "spread_out": 419 | segmentation_background = segmentation_thresholded 420 | cmap_segmentation_background = cmap_segmentation_thresholded 421 | if debug: 422 | out = torch.cat( 423 | [ 424 | to_rgb_from_tensor(image.unsqueeze(1)), # original images 425 | to_rgb_from_tensor( 426 | recon_combined.unsqueeze(1) 427 | ), # reconstructions 428 | cmap_segmentation.unsqueeze(1), 429 | cmap_segmentation_background.unsqueeze(1), 430 | to_rgb_from_tensor(recons * masks + (1 - masks)), # each slot 431 | ], 432 | dim=1, 433 | ) 434 | batch_size, num_slots, C, H, W = recons.shape 435 | images = vutils.make_grid( 436 | out.view(batch_size * out.shape[1], C, H, W).cpu(), 437 | normalize=False, 438 | nrow=out.shape[1], 439 | ) 440 | to_return = images 441 | if return_pil: 442 | to_return = transforms.functional.to_pil_image(to_return.squeeze()) 443 | return to_return 444 | else: 445 | to_return = segmentation_background 446 | if return_pil: 447 | to_return = transforms.functional.to_pil_image( 448 | cmap_segmentation_background.squeeze() 449 | ) 450 | return to_return 451 | 452 | else: 453 | raise ValueError( 454 | "The predict function is only implemented for " 455 | + 'Slot Attention (params.model_type == "sa").' 
456 | ) 457 | -------------------------------------------------------------------------------- /object_discovery/params.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | 4 | training_params = Namespace( 5 | model_type="sa", 6 | dataset="boxworld", 7 | num_slots=10, 8 | num_iterations=3, 9 | accumulate_grad_batches=1, 10 | data_root="data/box_world_dataset.h5", 11 | accelerator="gpu", 12 | devices=-1, 13 | max_steps=-1, 14 | num_sanity_val_steps=1, 15 | num_workers=4, 16 | is_logger_enabled=True, 17 | gradient_clip_val=0.0, 18 | n_samples=16, 19 | clevrtex_dataset_variant="full", 20 | alternative_crop=True, # Alternative crop for RAVENS dataset 21 | ) 22 | 23 | slot_attention_params = Namespace( 24 | lr_main=4e-4, 25 | batch_size=64, 26 | val_batch_size=64, 27 | resolution=(128, 128), 28 | slot_size=64, 29 | max_epochs=1000, 30 | max_steps=500000, 31 | weight_decay=0.0, 32 | mlp_hidden_size=128, 33 | scheduler="warmup_and_decay", 34 | scheduler_gamma=0.5, 35 | warmup_steps_pct=0.02, 36 | decay_steps_pct=0.2, 37 | use_separation_loss="entropy", 38 | separation_tau_start=60_000, 39 | separation_tau_end=65_000, 40 | separation_tau_max_val=0.003, 41 | separation_tau=None, 42 | boxworld_group_objects=True, 43 | use_area_loss=True, 44 | area_tau_start=60_000, 45 | area_tau_end=65_000, 46 | area_tau_max_val=0.006, 47 | area_tau=None, 48 | ) 49 | 50 | slate_params = Namespace( 51 | lr_dvae=3e-4, 52 | lr_main=1e-4, 53 | weight_decay=0.0, 54 | batch_size=50, 55 | val_batch_size=50, 56 | max_epochs=1000, 57 | patience=4, 58 | gradient_clip_val=1.0, 59 | resolution=(128, 128), 60 | num_dec_blocks=8, 61 | vocab_size=4096, 62 | d_model=192, 63 | num_heads=8, 64 | dropout=0.1, 65 | slot_size=192, 66 | mlp_hidden_size=192, 67 | tau_start=1.0, 68 | tau_final=0.1, 69 | tau_steps=30000, 70 | scheduler="warmup", 71 | lr_warmup_steps=30000, 72 | hard=False, 73 | ) 74 | 75 | gnm_params = Namespace( 76 | std=0.2, # 0.4 on CLEVR, 0.7 on ClevrTex, 0.2/0.3 on RAVENS 77 | z_what_dim=64, 78 | z_bg_dim=10, 79 | lr_main=1e-4, 80 | batch_size=64, 81 | val_batch_size=64, 82 | resolution=(128, 128), 83 | gradient_clip_val=1.0, 84 | max_epochs=1000, 85 | max_steps=5_000_000, 86 | weight_decay=0.0, 87 | scheduler=None, 88 | ) 89 | 90 | 91 | def merge_namespaces(one: Namespace, two: Namespace): 92 | return Namespace(**{**vars(one), **vars(two)}) 93 | -------------------------------------------------------------------------------- /object_discovery/segmentation_metrics.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/deepmind/multi_object_datasets/blob/9c40290c5e385d3efbccb34fb776b57d44721cba/segmentation_metrics.py 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | """Implementation of the adjusted Rand index.""" 6 | 7 | 8 | def adjusted_rand_index(true_mask, pred_mask, name="ari_score"): 9 | r"""Computes the adjusted Rand index (ARI), a clustering similarity score. 10 | This implementation ignores points with no cluster label in `true_mask` (i.e. 11 | those points for which `true_mask` is a zero vector). In the context of 12 | segmentation, that means this function can ignore points in an image 13 | corresponding to the background (i.e. not to an object). 14 | Args: 15 | true_mask: `Tensor` of shape [batch_size, n_points, n_true_groups]. 16 | The true cluster assignment encoded as one-hot. 17 | pred_mask: `Tensor` of shape [batch_size, n_points, n_pred_groups]. 
18 | The predicted cluster assignment encoded as categorical probabilities. 19 | This function works on the argmax over axis 2. 20 | name: str. Name of this operation (defaults to "ari_score"). 21 | Returns: 22 | ARI scores as a tf.float32 `Tensor` of shape [batch_size]. 23 | Raises: 24 | ValueError: if n_points <= n_true_groups and n_points <= n_pred_groups. 25 | We've chosen not to handle the special cases that can occur when you have 26 | one cluster per datapoint (which would be unusual). 27 | References: 28 | Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions" 29 | https://link.springer.com/article/10.1007/BF01908075 30 | Wikipedia 31 | https://en.wikipedia.org/wiki/Rand_index 32 | Scikit Learn 33 | http://scikit-learn.org/stable/modules/generated/\ 34 | sklearn.metrics.adjusted_rand_score.html 35 | """ 36 | _, n_points, n_true_groups = true_mask.shape 37 | n_pred_groups = pred_mask.shape[-1] 38 | if n_points <= n_true_groups and n_points <= n_pred_groups: 39 | # This rules out the n_true_groups == n_pred_groups == n_points 40 | # corner case, and also n_true_groups == n_pred_groups == 0, since 41 | # that would imply n_points == 0 too. 42 | # The sklearn implementation has a corner-case branch which does 43 | # handle this. We chose not to support these cases to avoid counting 44 | # distinct clusters just to check if we have one cluster per datapoint. 45 | raise ValueError( 46 | "adjusted_rand_index requires n_groups < n_points. We don't handle " 47 | "the special cases that can occur when you have one cluster " 48 | "per datapoint." 49 | ) 50 | 51 | true_group_ids = torch.argmax(true_mask, -1) 52 | pred_group_ids = torch.argmax(pred_mask, -1) 53 | # We convert true and predicted clusters to one-hot ('oh') representations. 54 | true_mask_oh = true_mask.type(torch.float32) # already one-hot 55 | pred_mask_oh = F.one_hot(pred_group_ids, n_pred_groups).type(torch.float32) 56 | 57 | n_points = torch.sum(true_mask_oh, axis=[1, 2]).type(torch.float32) 58 | 59 | nij = torch.einsum("bji,bjk->bki", pred_mask_oh, true_mask_oh) 60 | a = torch.sum(nij, axis=1) 61 | b = torch.sum(nij, axis=2) 62 | 63 | rindex = torch.sum(nij * (nij - 1), axis=[1, 2]) 64 | aindex = torch.sum(a * (a - 1), axis=1) 65 | bindex = torch.sum(b * (b - 1), axis=1) 66 | expected_rindex = aindex * bindex / (n_points * (n_points - 1)) 67 | max_rindex = (aindex + bindex) / 2 68 | 69 | denominator = max_rindex - expected_rindex 70 | ari = (rindex - expected_rindex) / denominator 71 | # If a divide by 0 occurs, set the ARI value to 1. 72 | zeros_in_denominator = torch.argwhere(denominator == 0).flatten() 73 | if zeros_in_denominator.nelement() > 0: 74 | ari[zeros_in_denominator] = 1 75 | 76 | # The case where n_true_groups == n_pred_groups == 1 needs to be 77 | # special-cased (to return 1) as the above formula gives a divide-by-zero. 
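# Recap of the closed form computed above (standard ARI, with each binomial
# coefficient written as k * (k - 1) since the factors of 2 cancel):
#   ARI = (rindex - expected_rindex) / (max_rindex - expected_rindex)
# where rindex = sum_ij n_ij * (n_ij - 1), aindex = sum_i a_i * (a_i - 1),
# bindex = sum_j b_j * (b_j - 1), expected_rindex = aindex * bindex / (n * (n - 1)),
# and max_rindex = (aindex + bindex) / 2.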
78 | # This might not work when true_mask has values that do not sum to one: 79 | both_single_cluster = torch.logical_and( 80 | _all_equal(true_group_ids), _all_equal(pred_group_ids) 81 | ) 82 | return torch.where(both_single_cluster, torch.ones_like(ari), ari) 83 | 84 | 85 | def _all_equal(values): 86 | """Whether values are all equal along the final axis.""" 87 | return torch.all(torch.eq(values, values[..., :1]), axis=-1) 88 | -------------------------------------------------------------------------------- /object_discovery/slate_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from object_discovery.utils import linear 5 | from object_discovery.transformer import PositionalEncoding, TransformerDecoder 6 | 7 | from object_discovery.slot_attention_model import SlotAttention 8 | 9 | 10 | def conv2d( 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | padding_mode="zeros", 20 | weight_init="xavier", 21 | ): 22 | 23 | m = nn.Conv2d( 24 | in_channels, 25 | out_channels, 26 | kernel_size, 27 | stride, 28 | padding, 29 | dilation, 30 | groups, 31 | bias, 32 | padding_mode, 33 | ) 34 | 35 | if weight_init == "kaiming": 36 | nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") 37 | else: 38 | nn.init.xavier_uniform_(m.weight) 39 | 40 | if bias: 41 | nn.init.zeros_(m.bias) 42 | 43 | return m 44 | 45 | 46 | def gumbel_softmax(logits, tau=1.0, hard=False, dim=-1): 47 | 48 | eps = torch.finfo(logits.dtype).tiny 49 | 50 | gumbels = -(torch.empty_like(logits).exponential_() + eps).log() 51 | gumbels = (logits + gumbels) / tau 52 | 53 | y_soft = F.softmax(gumbels, dim) 54 | 55 | if hard: 56 | index = y_soft.argmax(dim, keepdim=True) 57 | y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) 58 | return y_hard - y_soft.detach() + y_soft 59 | else: 60 | return y_soft 61 | 62 | 63 | class Conv2dBlock(nn.Module): 64 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 65 | super().__init__() 66 | 67 | self.m = conv2d( 68 | in_channels, 69 | out_channels, 70 | kernel_size, 71 | stride, 72 | padding, 73 | bias=False, 74 | weight_init="kaiming", 75 | ) 76 | self.weight = nn.Parameter(torch.ones(out_channels)) 77 | self.bias = nn.Parameter(torch.zeros(out_channels)) 78 | 79 | def forward(self, x): 80 | x = self.m(x) 81 | return F.relu(F.group_norm(x, 1, self.weight, self.bias)) 82 | 83 | 84 | class dVAE(nn.Module): 85 | def __init__(self, vocab_size, img_channels=3): 86 | super().__init__() 87 | 88 | self.encoder = nn.Sequential( 89 | Conv2dBlock(img_channels, 64, 4, 4), 90 | Conv2dBlock(64, 64, 1, 1), 91 | Conv2dBlock(64, 64, 1, 1), 92 | Conv2dBlock(64, 64, 1, 1), 93 | Conv2dBlock(64, 64, 1, 1), 94 | Conv2dBlock(64, 64, 1, 1), 95 | Conv2dBlock(64, 64, 1, 1), 96 | conv2d(64, vocab_size, 1), 97 | ) 98 | 99 | self.decoder = nn.Sequential( 100 | Conv2dBlock(vocab_size, 64, 1), 101 | Conv2dBlock(64, 64, 3, 1, 1), 102 | Conv2dBlock(64, 64, 1, 1), 103 | Conv2dBlock(64, 64, 1, 1), 104 | Conv2dBlock(64, 64 * 2 * 2, 1), 105 | nn.PixelShuffle(2), 106 | Conv2dBlock(64, 64, 3, 1, 1), 107 | Conv2dBlock(64, 64, 1, 1), 108 | Conv2dBlock(64, 64, 1, 1), 109 | Conv2dBlock(64, 64 * 2 * 2, 1), 110 | nn.PixelShuffle(2), 111 | conv2d(64, img_channels, 1), 112 | ) 113 | 114 | 115 | class OneHotDictionary(nn.Module): 116 | def __init__(self, vocab_size, emb_size): 117 | super().__init__() 
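# nn.Embedding below acts as a learnable codebook: forward() takes the argmax
# over the vocabulary axis and looks up an emb_size-dimensional vector per token.
# Usage sketch (the sizes are assumptions for illustration only):
#   d = OneHotDictionary(vocab_size=4097, emb_size=192)
#   d(torch.eye(4097)[None, :5])  # [1, 5, 4097] one-hots -> [1, 5, 192] embeddings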
118 | self.dictionary = nn.Embedding(vocab_size, emb_size) 119 | 120 | def forward(self, x): 121 | """ 122 | x: B, N, vocab_size 123 | """ 124 | 125 | tokens = torch.argmax(x, dim=-1) # batch_size x N 126 | token_embs = self.dictionary(tokens) # batch_size x N x emb_size 127 | return token_embs 128 | 129 | 130 | class SLATE(nn.Module): 131 | def __init__( 132 | self, 133 | num_slots, 134 | vocab_size, 135 | d_model, 136 | resolution, 137 | num_iterations, 138 | slot_size, 139 | mlp_hidden_size, 140 | num_heads, 141 | dropout, 142 | num_dec_blocks, 143 | ): 144 | super().__init__() 145 | self.supports_masks = False 146 | 147 | self.num_slots = num_slots 148 | self.vocab_size = vocab_size 149 | self.d_model = d_model 150 | 151 | self.dvae = dVAE(vocab_size) 152 | 153 | self.positional_encoder = PositionalEncoding( 154 | 1 + (resolution[0] // 4) ** 2, d_model, dropout 155 | ) 156 | 157 | self.slot_attn = SlotAttention( 158 | d_model, 159 | num_iterations, 160 | num_slots, 161 | slot_size, 162 | mlp_hidden_size, 163 | do_input_mlp=True, 164 | ) 165 | 166 | self.dictionary = OneHotDictionary(vocab_size + 1, d_model) 167 | self.slot_proj = linear(slot_size, d_model, bias=False) 168 | 169 | self.tf_dec = TransformerDecoder( 170 | num_dec_blocks, (resolution[0] // 4) ** 2, d_model, num_heads, dropout 171 | ).train() 172 | 173 | self.out = linear(d_model, vocab_size, bias=False) 174 | 175 | def forward(self, image, tau, hard): 176 | """ 177 | image: batch_size x img_channels x H x W 178 | """ 179 | 180 | B, C, H, W = image.size() 181 | 182 | # dvae encode 183 | z_logits = F.log_softmax(self.dvae.encoder(image), dim=1) 184 | _, _, H_enc, W_enc = z_logits.size() 185 | z = gumbel_softmax(z_logits, tau, hard, dim=1) 186 | 187 | # dvae recon 188 | recon = self.dvae.decoder(z) 189 | mse = ((image - recon) ** 2).sum() / B 190 | 191 | # hard z 192 | z_hard = gumbel_softmax(z_logits, tau, True, dim=1).detach() 193 | 194 | # target tokens for transformer 195 | z_transformer_target = z_hard.permute(0, 2, 3, 1).flatten( 196 | start_dim=1, end_dim=2 197 | ) 198 | 199 | # add BOS token 200 | z_transformer_input = torch.cat( 201 | [torch.zeros_like(z_transformer_target[..., :1]), z_transformer_target], 202 | dim=-1, 203 | ) 204 | z_transformer_input = torch.cat( 205 | [torch.zeros_like(z_transformer_input[..., :1, :]), z_transformer_input], 206 | dim=-2, 207 | ) 208 | z_transformer_input[:, 0, 0] = 1.0 209 | 210 | # tokens to embeddings 211 | emb_input = self.dictionary(z_transformer_input) 212 | emb_input = self.positional_encoder(emb_input) 213 | 214 | # apply slot attention 215 | slots, attns = self.slot_attn(emb_input[:, 1:], return_attns=True) 216 | attns = attns.transpose(-1, -2) 217 | attns = ( 218 | attns.reshape(B, self.num_slots, 1, H_enc, W_enc) 219 | .repeat_interleave(H // H_enc, dim=-2) 220 | .repeat_interleave(W // W_enc, dim=-1) 221 | ) 222 | attns = image.unsqueeze(1) * attns + 1.0 - attns 223 | # `attns` has shape [batch_size, num_slots, channels, height, width] 224 | 225 | # apply transformer 226 | slots = self.slot_proj(slots) 227 | decoder_output = self.tf_dec(emb_input[:, :-1], slots) 228 | pred = self.out(decoder_output) 229 | cross_entropy = ( 230 | -(z_transformer_target * torch.log_softmax(pred, dim=-1)) 231 | .flatten(start_dim=1) 232 | .sum(-1) 233 | .mean() 234 | ) 235 | 236 | return (recon.clamp(0.0, 1.0), cross_entropy, mse, attns) 237 | 238 | def loss_function(self, input, tau, hard=False): 239 | _, cross_entropy, mse, _ = self.forward(input, tau, hard) 240 | return { 241 | "loss": 
cross_entropy + mse, 242 | "cross_entropy": cross_entropy, 243 | "mse": mse, 244 | "tau": torch.tensor(tau), 245 | } 246 | 247 | def reconstruct_autoregressive(self, image, eval=False): 248 | """ 249 | image: batch_size x img_channels x H x W 250 | """ 251 | 252 | gen_len = (image.size(-1) // 4) ** 2 253 | 254 | B, C, H, W = image.size() 255 | 256 | # dvae encode 257 | z_logits = F.log_softmax(self.dvae.encoder(image), dim=1) 258 | _, _, H_enc, W_enc = z_logits.size() 259 | 260 | # hard z 261 | z_hard = torch.argmax(z_logits, axis=1) 262 | z_hard = ( 263 | F.one_hot(z_hard, num_classes=self.vocab_size).permute(0, 3, 1, 2).float() 264 | ) 265 | one_hot_tokens = z_hard.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2) 266 | 267 | # add BOS token 268 | one_hot_tokens = torch.cat( 269 | [torch.zeros_like(one_hot_tokens[..., :1]), one_hot_tokens], dim=-1 270 | ) 271 | one_hot_tokens = torch.cat( 272 | [torch.zeros_like(one_hot_tokens[..., :1, :]), one_hot_tokens], dim=-2 273 | ) 274 | one_hot_tokens[:, 0, 0] = 1.0 275 | 276 | # tokens to embeddings 277 | emb_input = self.dictionary(one_hot_tokens) 278 | emb_input = self.positional_encoder(emb_input) 279 | 280 | # slot attention 281 | slots, attns = self.slot_attn(emb_input[:, 1:], return_attns=True) 282 | attns = attns.transpose(-1, -2) 283 | attns = ( 284 | attns.reshape(B, self.num_slots, 1, H_enc, W_enc) 285 | .repeat_interleave(H // H_enc, dim=-2) 286 | .repeat_interleave(W // W_enc, dim=-1) 287 | ) 288 | attns = image.unsqueeze(1) * attns + (1.0 - attns) 289 | slots = self.slot_proj(slots) 290 | 291 | # generate image tokens auto-regressively 292 | z_gen = z_hard.new_zeros(0) 293 | z_transformer_input = z_hard.new_zeros(B, 1, self.vocab_size + 1) 294 | z_transformer_input[..., 0] = 1.0 295 | for t in range(gen_len): 296 | decoder_output = self.tf_dec( 297 | self.positional_encoder(self.dictionary(z_transformer_input)), slots 298 | ) 299 | z_next = F.one_hot( 300 | self.out(decoder_output)[:, -1:].argmax(dim=-1), self.vocab_size 301 | ) 302 | z_gen = torch.cat((z_gen, z_next), dim=1) 303 | z_transformer_input = torch.cat( 304 | [ 305 | z_transformer_input, 306 | torch.cat([torch.zeros_like(z_next[:, :, :1]), z_next], dim=-1), 307 | ], 308 | dim=1, 309 | ) 310 | 311 | z_gen = z_gen.transpose(1, 2).float().reshape(B, -1, H_enc, W_enc) 312 | recon_transformer = self.dvae.decoder(z_gen) 313 | 314 | if eval: 315 | return recon_transformer.clamp(0.0, 1.0), attns 316 | 317 | return recon_transformer.clamp(0.0, 1.0) 318 | -------------------------------------------------------------------------------- /object_discovery/slot_attention_model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from object_discovery.utils import ( 8 | assert_shape, 9 | build_grid, 10 | conv_transpose_out_shape, 11 | linear, 12 | gru_cell, 13 | ) 14 | 15 | 16 | class SlotAttention(nn.Module): 17 | def __init__( 18 | self, 19 | in_features, 20 | num_iterations, 21 | num_slots, 22 | slot_size, 23 | mlp_hidden_size, 24 | epsilon=1e-8, 25 | do_input_mlp=False, 26 | ): 27 | super().__init__() 28 | self.in_features = in_features 29 | self.num_iterations = num_iterations 30 | self.num_slots = num_slots 31 | self.slot_size = slot_size # number of hidden layers in slot dimensions 32 | self.mlp_hidden_size = mlp_hidden_size 33 | self.epsilon = epsilon 34 | self.do_input_mlp = do_input_mlp 35 | 36 | if self.do_input_mlp: 37 
| self.input_layer_norm = nn.LayerNorm(self.in_features) 38 | self.input_mlp = nn.Sequential( 39 | linear(self.in_features, self.in_features, weight_init="kaiming"), 40 | nn.ReLU(), 41 | linear(self.in_features, self.in_features), 42 | ) 43 | 44 | self.norm_inputs = nn.LayerNorm(self.in_features) 45 | self.norm_slots = nn.LayerNorm(self.slot_size) 46 | self.norm_mlp = nn.LayerNorm(self.slot_size) 47 | 48 | # Linear maps for the attention module. 49 | self.project_q = linear(self.slot_size, self.slot_size, bias=False) 50 | self.project_k = linear(self.in_features, self.slot_size, bias=False) 51 | self.project_v = linear(self.in_features, self.slot_size, bias=False) 52 | 53 | # Slot update functions. 54 | self.gru = gru_cell(self.slot_size, self.slot_size) 55 | self.mlp = nn.Sequential( 56 | linear(self.slot_size, self.mlp_hidden_size, weight_init="kaiming"), 57 | nn.ReLU(), 58 | linear(self.mlp_hidden_size, self.slot_size), 59 | ) 60 | 61 | # Parameters for Gaussian init (shared by all slots). 62 | self.slot_mu = nn.Parameter(torch.zeros(1, 1, self.slot_size)) 63 | self.slot_log_sigma = nn.Parameter(torch.zeros(1, 1, self.slot_size)) 64 | nn.init.xavier_uniform_(self.slot_mu, gain=nn.init.calculate_gain("linear")) 65 | nn.init.xavier_uniform_( 66 | self.slot_log_sigma, 67 | gain=nn.init.calculate_gain("linear"), 68 | ) 69 | 70 | def step(self, slots, k, v, batch_size, num_inputs): 71 | slots_prev = slots 72 | slots = self.norm_slots(slots) 73 | 74 | # Attention. 75 | q = self.project_q(slots) # Shape: [batch_size, num_slots, slot_size]. 76 | assert_shape(q.size(), (batch_size, self.num_slots, self.slot_size)) 77 | q *= self.slot_size**-0.5 # Normalization 78 | attn_logits = torch.matmul(k, q.transpose(2, 1)) 79 | attn = F.softmax(attn_logits, dim=-1) 80 | # `attn` has shape: [batch_size, num_inputs, num_slots]. 81 | assert_shape(attn.size(), (batch_size, num_inputs, self.num_slots)) 82 | attn_vis = attn.clone() 83 | 84 | # Weighted mean. 85 | attn = attn + self.epsilon 86 | attn = attn / torch.sum(attn, dim=-2, keepdim=True) 87 | updates = torch.matmul(attn.transpose(-1, -2), v) 88 | # `updates` has shape: [batch_size, num_slots, slot_size]. 89 | assert_shape(updates.size(), (batch_size, self.num_slots, self.slot_size)) 90 | 91 | # Slot update. 92 | # GRU is expecting inputs of size (N,H) so flatten batch and slots dimension 93 | slots = self.gru( 94 | updates.view(batch_size * self.num_slots, self.slot_size), 95 | slots_prev.view(batch_size * self.num_slots, self.slot_size), 96 | ) 97 | slots = slots.view(batch_size, self.num_slots, self.slot_size) 98 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size)) 99 | slots = slots + self.mlp(self.norm_mlp(slots)) 100 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size)) 101 | 102 | return slots, attn_vis 103 | 104 | def forward(self, inputs: torch.Tensor, return_attns=False): 105 | # `inputs` has shape [batch_size, num_inputs, inputs_size]. 106 | # `inputs` also has shape [batch_size, enc_height * enc_width, cnn_hidden_size]. 107 | batch_size, num_inputs, inputs_size = inputs.shape 108 | 109 | if self.do_input_mlp: 110 | inputs = self.input_mlp(self.input_layer_norm(inputs)) 111 | 112 | inputs = self.norm_inputs(inputs) # Apply layer norm to the input. 113 | k = self.project_k(inputs) # Shape: [batch_size, num_inputs, slot_size]. 114 | assert_shape(k.size(), (batch_size, num_inputs, self.slot_size)) 115 | v = self.project_v(inputs) # Shape: [batch_size, num_inputs, slot_size]. 
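# Attention layout reminder (a sketch, with batch_size=B, num_inputs=N,
# num_slots=K): step() computes attn_logits = k @ q^T with shape [B, N, K],
# softmaxes over the slot axis (dim=-1) so slots compete for each input
# location, then renormalizes over dim=-2 to take a weighted mean of v per slot.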
116 | assert_shape(v.size(), (batch_size, num_inputs, self.slot_size)) 117 | 118 | # Initialize the slots. Shape: [batch_size, num_slots, slot_size]. 119 | slots_init = inputs.new_empty( 120 | batch_size, self.num_slots, self.slot_size 121 | ).normal_() 122 | slots = self.slot_mu + torch.exp(self.slot_log_sigma) * slots_init 123 | 124 | # Multiple rounds of attention. 125 | for _ in range(self.num_iterations): 126 | slots, attn_vis = self.step(slots, k, v, batch_size, num_inputs) 127 | # Detach slots from the current graph and compute one more step. 128 | # This is implicit slot attention from https://cocosci.princeton.edu/papers/chang2022objfixed.pdf 129 | slots, attn_vis = self.step(slots.detach(), k, v, batch_size, num_inputs) 130 | if return_attns: 131 | return slots, attn_vis 132 | else: 133 | return slots 134 | 135 | 136 | class SlotAttentionModel(nn.Module): 137 | def __init__( 138 | self, 139 | resolution: Tuple[int, int], 140 | num_slots: int, 141 | num_iterations, 142 | in_channels: int = 3, 143 | kernel_size: int = 5, 144 | slot_size: int = 64, 145 | mlp_hidden_size: int = 128, 146 | hidden_dims: Tuple[int, ...] = (64, 64, 64, 64), 147 | decoder_resolution: Tuple[int, int] = (8, 8), 148 | use_separation_loss=False, 149 | use_area_loss=False, 150 | ): 151 | super().__init__() 152 | self.supports_masks = True 153 | 154 | self.resolution = resolution 155 | self.num_slots = num_slots 156 | self.num_iterations = num_iterations 157 | self.in_channels = in_channels 158 | self.kernel_size = kernel_size 159 | self.slot_size = slot_size 160 | self.mlp_hidden_size = mlp_hidden_size 161 | self.hidden_dims = hidden_dims 162 | self.decoder_resolution = decoder_resolution 163 | self.out_features = self.hidden_dims[-1] 164 | self.use_separation_loss = use_separation_loss 165 | self.use_area_loss = use_area_loss 166 | 167 | modules = [] 168 | channels = self.in_channels 169 | # Build Encoder 170 | for h_dim in self.hidden_dims: 171 | modules.append( 172 | nn.Sequential( 173 | nn.Conv2d( 174 | channels, 175 | out_channels=h_dim, 176 | kernel_size=self.kernel_size, 177 | stride=1, 178 | padding=self.kernel_size // 2, 179 | ), 180 | nn.LeakyReLU(), 181 | ) 182 | ) 183 | channels = h_dim 184 | 185 | self.encoder = nn.Sequential(*modules) 186 | self.encoder_pos_embedding = SoftPositionEmbed( 187 | self.in_channels, self.out_features, resolution 188 | ) 189 | self.encoder_out_layer = nn.Sequential( 190 | nn.Linear(self.out_features, self.out_features), 191 | nn.LeakyReLU(), 192 | nn.Linear(self.out_features, self.out_features), 193 | ) 194 | 195 | # Build Decoder 196 | modules = [] 197 | 198 | in_size = decoder_resolution[0] 199 | out_size = in_size 200 | 201 | for i in range(len(self.hidden_dims) - 1, -1, -1): 202 | modules.append( 203 | nn.Sequential( 204 | nn.ConvTranspose2d( 205 | self.hidden_dims[i], 206 | self.hidden_dims[i - 1], 207 | kernel_size=5, 208 | stride=2, 209 | padding=2, 210 | output_padding=1, 211 | ), 212 | nn.LeakyReLU(), 213 | ) 214 | ) 215 | out_size = conv_transpose_out_shape(out_size, 2, 2, 5, 1) 216 | 217 | assert_shape( 218 | resolution, 219 | (out_size, out_size), 220 | message="Output shape of decoder did not match input resolution. 
Try changing `decoder_resolution`.", 221 | ) 222 | 223 | # same convolutions 224 | modules.append( 225 | nn.Sequential( 226 | nn.ConvTranspose2d( 227 | self.out_features, 228 | self.out_features, 229 | kernel_size=5, 230 | stride=1, 231 | padding=2, 232 | output_padding=0, 233 | ), 234 | nn.LeakyReLU(), 235 | nn.ConvTranspose2d( 236 | self.out_features, 237 | 4, 238 | kernel_size=3, 239 | stride=1, 240 | padding=1, 241 | output_padding=0, 242 | ), 243 | ) 244 | ) 245 | 246 | assert_shape(resolution, (out_size, out_size), message="") 247 | 248 | self.decoder = nn.Sequential(*modules) 249 | self.decoder_pos_embedding = SoftPositionEmbed( 250 | self.in_channels, self.out_features, self.decoder_resolution 251 | ) 252 | 253 | self.slot_attention = SlotAttention( 254 | in_features=self.out_features, 255 | num_iterations=self.num_iterations, 256 | num_slots=self.num_slots, 257 | slot_size=self.slot_size, 258 | mlp_hidden_size=self.mlp_hidden_size, 259 | ) 260 | 261 | def forward(self, x): 262 | batch_size, num_channels, height, width = x.shape 263 | encoder_out = self.encoder(x) 264 | encoder_out = self.encoder_pos_embedding(encoder_out) 265 | # `encoder_out` has shape: [batch_size, filter_size, height, width] 266 | encoder_out = torch.flatten(encoder_out, start_dim=2, end_dim=3) 267 | # `encoder_out` has shape: [batch_size, filter_size, height*width] 268 | encoder_out = encoder_out.permute(0, 2, 1) 269 | encoder_out = self.encoder_out_layer(encoder_out) 270 | # `encoder_out` has shape: [batch_size, height*width, filter_size] 271 | 272 | slots = self.slot_attention(encoder_out) 273 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size)) 274 | # `slots` has shape: [batch_size, num_slots, slot_size]. 275 | batch_size, num_slots, slot_size = slots.shape 276 | 277 | slots = slots.view(batch_size * num_slots, slot_size, 1, 1) 278 | decoder_in = slots.repeat( 279 | 1, 1, self.decoder_resolution[0], self.decoder_resolution[1] 280 | ) 281 | 282 | out = self.decoder_pos_embedding(decoder_in) 283 | out = self.decoder(out) 284 | # `out` has shape: [batch_size*num_slots, num_channels+1, height, width]. 285 | assert_shape( 286 | out.size(), (batch_size * num_slots, num_channels + 1, height, width) 287 | ) 288 | 289 | # Perform researchers' `unstack_and_split` using `torch.view`. 290 | out = out.view(batch_size, num_slots, num_channels + 1, height, width) 291 | recons = out[:, :, :num_channels, :, :] 292 | masks = out[:, :, -1:, :, :] 293 | # Normalize alpha masks over slots. 
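# For every pixel the slot weights sum to 1 after the softmax over dim=1 (the
# slot axis), so recon_combined below is a convex combination of the per-slot
# reconstructions; with masks of shape [B, num_slots, 1, H, W]:
#   recon_combined[b, :, h, w] = sum_k masks[b, k, 0, h, w] * recons[b, k, :, h, w]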
294 | masks = F.softmax(masks, dim=1) 295 | recon_combined = torch.sum(recons * masks, dim=1) 296 | return recon_combined, recons, masks, slots 297 | 298 | def loss_function( 299 | self, input, separation_tau=None, area_tau=None, global_step=None 300 | ): 301 | recon_combined, recons, masks, slots = self.forward(input) 302 | # `masks` has shape [batch_size, num_entries, channels, height, width] 303 | mse_loss = F.mse_loss(recon_combined, input) 304 | 305 | loss = mse_loss 306 | to_return = {"loss": loss} 307 | if self.use_area_loss: 308 | max_num_objects = self.num_slots - 1 309 | width, height = masks.shape[-1], masks.shape[-2] 310 | one_object_area = 9 * 9 311 | background_size = width * height - max_num_objects * one_object_area 312 | slot_area = torch.sum(masks.squeeze(), dim=(-1, -2)) 313 | 314 | batch_size, num_slots = slot_area.shape[0], slot_area.shape[1] 315 | area_loss = 0 316 | for batch_idx in range(batch_size): 317 | for slot_idx in range(num_slots): 318 | area = slot_area[batch_idx, slot_idx] 319 | area_loss += min( 320 | (area - 2 * one_object_area) ** 2, 321 | max(background_size - area, 0) * (background_size - area), 322 | ) / (2 * one_object_area)**2 323 | 324 | area_loss /= batch_size * num_slots 325 | loss += area_loss * area_tau 326 | to_return["loss"] = loss 327 | to_return["area_loss"] = area_loss 328 | to_return["area_tau"] = torch.tensor(area_tau) 329 | if self.use_separation_loss: 330 | if self.use_separation_loss == "max": 331 | separation_loss = 1 - torch.mean(torch.max(masks, dim=1).values.float()) 332 | elif self.use_separation_loss == "entropy": 333 | entropy = torch.special.entr(masks + 1e-8) 334 | separation_loss = torch.mean(entropy.sum(dim=1)) 335 | 336 | loss += separation_loss * separation_tau 337 | to_return["loss"] = loss 338 | to_return["mse_loss"] = mse_loss 339 | to_return["separation_loss"] = separation_loss 340 | to_return["separation_tau"] = torch.tensor(separation_tau) 341 | return to_return, masks 342 | 343 | 344 | class SoftPositionEmbed(nn.Module): 345 | def __init__( 346 | self, num_channels: int, hidden_size: int, resolution: Tuple[int, int] 347 | ): 348 | super().__init__() 349 | self.dense = nn.Linear(in_features=num_channels + 1, out_features=hidden_size) 350 | self.register_buffer("grid", build_grid(resolution)) 351 | 352 | def forward(self, inputs: torch.Tensor): 353 | # Permute to move num_channels to 1st dimension. PyTorch layers need 354 | # num_channels as 1st dimension, tensorflow needs num_channels last. 
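# self.grid from build_grid has shape [1, H, W, 4] (x, y, 1 - x, 1 - y), so the
# linear layer maps 4 -> hidden_size and the permute yields [1, hidden_size, H, W],
# which broadcasts over the batch dimension when added to `inputs`.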
355 | emb_proj = self.dense(self.grid).permute(0, 3, 1, 2) 356 | assert_shape(inputs.shape[1:], emb_proj.shape[1:]) 357 | return inputs + emb_proj 358 | -------------------------------------------------------------------------------- /object_discovery/train.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning.loggers as pl_loggers 2 | from pytorch_lightning import Trainer 3 | from pytorch_lightning.callbacks import LearningRateMonitor 4 | 5 | from object_discovery.data import ( 6 | CLEVRDataModule, 7 | Shapes3dDataModule, 8 | RAVENSRobotDataModule, 9 | SketchyDataModule, 10 | ClevrTexDataModule, 11 | BoxWorldDataModule, 12 | TetrominoesDataModule, 13 | ) 14 | from object_discovery.method import SlotAttentionMethod 15 | from object_discovery.slot_attention_model import SlotAttentionModel 16 | from object_discovery.slate_model import SLATE 17 | from object_discovery.params import ( 18 | merge_namespaces, 19 | training_params, 20 | slot_attention_params, 21 | slate_params, 22 | gnm_params, 23 | ) 24 | from object_discovery.utils import ImageLogCallback 25 | from object_discovery.gnm.gnm_model import GNM, hyperparam_anneal 26 | from object_discovery.gnm.config import get_arrow_args 27 | 28 | 29 | def main(params=None): 30 | if params is None: 31 | params = training_params 32 | if params.model_type == "slate": 33 | params = merge_namespaces(params, slate_params) 34 | elif params.model_type == "sa": 35 | params = merge_namespaces(params, slot_attention_params) 36 | elif params.model_type == "gnm": 37 | params = merge_namespaces(params, gnm_params) 38 | 39 | assert params.num_slots > 1, "Must have at least 2 slots." 40 | params.neg_1_to_pos_1_scale = params.model_type == "sa" 41 | 42 | if params.dataset == "clevr": 43 | datamodule = CLEVRDataModule( 44 | data_root=params.data_root, 45 | max_n_objects=params.num_slots - 1, 46 | train_batch_size=params.batch_size, 47 | val_batch_size=params.val_batch_size, 48 | num_workers=params.num_workers, 49 | resolution=params.resolution, 50 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 51 | ) 52 | elif params.dataset == "shapes3d": 53 | assert params.resolution == ( 54 | 64, 55 | 64, 56 | ), "shapes3d dataset requires 64x64 resolution" 57 | datamodule = Shapes3dDataModule( 58 | data_root=params.data_root, 59 | train_batch_size=params.batch_size, 60 | val_batch_size=params.val_batch_size, 61 | num_workers=params.num_workers, 62 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 63 | ) 64 | elif params.dataset == "ravens": 65 | datamodule = RAVENSRobotDataModule( 66 | data_root=params.data_root, 67 | # `max_n_objects` is the number of objects on the table. It does 68 | # not count the background, table, robot, and robot arm. 
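# Presumably (an assumption, not stated in the code): the alternative crop
# removes the table, robot, and arm from view, so only the background slot is
# reserved (num_slots - 1); otherwise three extra slots are reserved for those
# distractors (num_slots - 4).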
69 | max_n_objects=params.num_slots - 1 70 | if params.alternative_crop 71 | else params.num_slots - 4, 72 | train_batch_size=params.batch_size, 73 | val_batch_size=params.val_batch_size, 74 | num_workers=params.num_workers, 75 | resolution=params.resolution, 76 | alternative_crop=params.alternative_crop, 77 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 78 | ) 79 | elif params.dataset == "sketchy": 80 | assert params.resolution == ( 81 | 128, 82 | 128, 83 | ), "sketchy dataset requires 128x128 resolution" 84 | datamodule = SketchyDataModule( 85 | data_root=params.data_root, 86 | train_batch_size=params.batch_size, 87 | val_batch_size=params.val_batch_size, 88 | num_workers=params.num_workers, 89 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 90 | ) 91 | elif params.dataset == "clevrtex": 92 | datamodule = ClevrTexDataModule( 93 | data_root=params.data_root, 94 | train_batch_size=params.batch_size, 95 | val_batch_size=params.val_batch_size, 96 | num_workers=params.num_workers, 97 | resolution=params.resolution, 98 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 99 | dataset_variant=params.clevrtex_dataset_variant, 100 | max_n_objects=params.num_slots - 1, 101 | ) 102 | elif params.dataset == "boxworld": 103 | max_n_objects = params.num_slots - 1 104 | datamodule = BoxWorldDataModule( 105 | data_root=params.data_root, 106 | max_n_objects=2 * max_n_objects 107 | if params.boxworld_group_objects 108 | else max_n_objects, 109 | train_batch_size=params.batch_size, 110 | val_batch_size=params.val_batch_size, 111 | num_workers=params.num_workers, 112 | resolution=params.resolution, 113 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 114 | ) 115 | elif params.dataset == "tetrominoes": 116 | datamodule = TetrominoesDataModule( 117 | data_root=params.data_root, 118 | max_n_objects=params.num_slots - 1, 119 | train_batch_size=params.batch_size, 120 | val_batch_size=params.val_batch_size, 121 | num_workers=params.num_workers, 122 | resolution=params.resolution, 123 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale, 124 | ) 125 | 126 | print( 127 | f"Training set size (images must have {params.num_slots - 1} objects):", 128 | len(datamodule.train_dataset), 129 | ) 130 | 131 | if params.model_type == "sa": 132 | model = SlotAttentionModel( 133 | resolution=params.resolution, 134 | num_slots=params.num_slots, 135 | num_iterations=params.num_iterations, 136 | slot_size=params.slot_size, 137 | use_separation_loss=params.use_separation_loss, 138 | use_area_loss=params.use_area_loss, 139 | ) 140 | elif params.model_type == "slate": 141 | model = SLATE( 142 | num_slots=params.num_slots, 143 | vocab_size=params.vocab_size, 144 | d_model=params.d_model, 145 | resolution=params.resolution, 146 | num_iterations=params.num_iterations, 147 | slot_size=params.slot_size, 148 | mlp_hidden_size=params.mlp_hidden_size, 149 | num_heads=params.num_heads, 150 | dropout=params.dropout, 151 | num_dec_blocks=params.num_dec_blocks, 152 | ) 153 | elif params.model_type == "gnm": 154 | model_params = get_arrow_args() 155 | model_params = hyperparam_anneal(model_params, 0) 156 | params = merge_namespaces(params, model_params) 157 | params.const.likelihood_sigma = params.std 158 | params.z.z_what_dim = params.z_what_dim 159 | params.z.z_bg_dim = params.z_bg_dim 160 | model = GNM(params) 161 | 162 | method = SlotAttentionMethod(model=model, datamodule=datamodule, params=params) 163 | 164 | logger = pl_loggers.WandbLogger(project="slot-attention-clevr6") 165 | 166 | callbacks = [LearningRateMonitor("step")] 167 | 
if params.model_type != "gnm": 168 | callbacks.append(ImageLogCallback()) 169 | 170 | trainer = Trainer( 171 | logger=logger if params.is_logger_enabled else False, 172 | accelerator=params.accelerator, 173 | num_sanity_val_steps=params.num_sanity_val_steps, 174 | devices=params.devices, 175 | max_epochs=params.max_epochs, 176 | max_steps=params.max_steps, 177 | accumulate_grad_batches=params.accumulate_grad_batches, 178 | gradient_clip_val=params.gradient_clip_val, 179 | log_every_n_steps=50, 180 | callbacks=callbacks if params.is_logger_enabled else [], 181 | ) 182 | trainer.fit(method, datamodule=datamodule) 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | -------------------------------------------------------------------------------- /object_discovery/transformer.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/singhgautam/slate/blob/6afe75211a79ef7327071ce198f4427928418bf5/transformer.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from object_discovery.utils import linear 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | def __init__(self, d_model, num_heads, dropout=0.0, gain=1.0): 10 | super().__init__() 11 | 12 | assert d_model % num_heads == 0, "d_model must be divisible by num_heads" 13 | self.d_model = d_model 14 | self.num_heads = num_heads 15 | 16 | self.attn_dropout = nn.Dropout(dropout) 17 | self.output_dropout = nn.Dropout(dropout) 18 | 19 | self.proj_q = linear(d_model, d_model, bias=False) 20 | self.proj_k = linear(d_model, d_model, bias=False) 21 | self.proj_v = linear(d_model, d_model, bias=False) 22 | self.proj_o = linear(d_model, d_model, bias=False, gain=gain) 23 | 24 | def forward(self, q, k, v, attn_mask=None): 25 | """ 26 | q: batch_size x target_len x d_model 27 | k: batch_size x source_len x d_model 28 | v: batch_size x source_len x d_model 29 | attn_mask: target_len x source_len 30 | return: batch_size x target_len x d_model 31 | """ 32 | B, T, _ = q.shape 33 | _, S, _ = k.shape 34 | 35 | q = self.proj_q(q).view(B, T, self.num_heads, -1).transpose(1, 2) 36 | k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2) 37 | v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2) 38 | 39 | q = q * (q.shape[-1] ** (-0.5)) 40 | attn = torch.matmul(q, k.transpose(-1, -2)) 41 | 42 | if attn_mask is not None: 43 | attn = attn.masked_fill(attn_mask, float("-inf")) 44 | 45 | attn = F.softmax(attn, dim=-1) 46 | attn = self.attn_dropout(attn) 47 | 48 | output = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, -1) 49 | output = self.proj_o(output) 50 | output = self.output_dropout(output) 51 | return output 52 | 53 | 54 | class PositionalEncoding(nn.Module): 55 | def __init__(self, max_len, d_model, dropout=0.1): 56 | super().__init__() 57 | self.dropout = nn.Dropout(dropout) 58 | self.pe = nn.Parameter(torch.zeros(1, max_len, d_model), requires_grad=True) 59 | nn.init.trunc_normal_(self.pe) 60 | 61 | def forward(self, input): 62 | """ 63 | input: batch_size x seq_len x d_model 64 | return: batch_size x seq_len x d_model 65 | """ 66 | T = input.shape[1] 67 | return self.dropout(input + self.pe[:, :T]) 68 | 69 | 70 | class TransformerEncoderBlock(nn.Module): 71 | def __init__(self, d_model, num_heads, dropout=0.0, gain=1.0, is_first=False): 72 | super().__init__() 73 | 74 | self.is_first = is_first 75 | 76 | self.attn_layer_norm = nn.LayerNorm(d_model) 77 | self.attn = MultiHeadAttention(d_model, num_heads, dropout, gain) 78 | 79 | 
self.ffn_layer_norm = nn.LayerNorm(d_model) 80 | self.ffn = nn.Sequential( 81 | linear(d_model, 4 * d_model, weight_init="kaiming"), 82 | nn.ReLU(), 83 | linear(4 * d_model, d_model, gain=gain), 84 | nn.Dropout(dropout), 85 | ) 86 | 87 | def forward(self, input): 88 | """ 89 | input: batch_size x source_len x d_model 90 | return: batch_size x source_len x d_model 91 | """ 92 | if self.is_first: 93 | input = self.attn_layer_norm(input) 94 | x = self.attn(input, input, input) 95 | input = input + x 96 | else: 97 | x = self.attn_layer_norm(input) 98 | x = self.attn(x, x, x) 99 | input = input + x 100 | 101 | x = self.ffn_layer_norm(input) 102 | x = self.ffn(x) 103 | return input + x 104 | 105 | 106 | class TransformerEncoder(nn.Module): 107 | def __init__(self, num_blocks, d_model, num_heads, dropout=0.0): 108 | super().__init__() 109 | 110 | if num_blocks > 0: 111 | gain = (2 * num_blocks) ** (-0.5) 112 | self.blocks = nn.ModuleList( 113 | [ 114 | TransformerEncoderBlock( 115 | d_model, num_heads, dropout, gain, is_first=True 116 | ) 117 | ] 118 | + [ 119 | TransformerEncoderBlock( 120 | d_model, num_heads, dropout, gain, is_first=False 121 | ) 122 | for _ in range(num_blocks - 1) 123 | ] 124 | ) 125 | else: 126 | self.blocks = nn.ModuleList() 127 | 128 | self.layer_norm = nn.LayerNorm(d_model) 129 | 130 | def forward(self, input): 131 | """ 132 | input: batch_size x source_len x d_model 133 | return: batch_size x source_len x d_model 134 | """ 135 | for block in self.blocks: 136 | input = block(input) 137 | 138 | return self.layer_norm(input) 139 | 140 | 141 | class TransformerDecoderBlock(nn.Module): 142 | def __init__( 143 | self, max_len, d_model, num_heads, dropout=0.0, gain=1.0, is_first=False 144 | ): 145 | super().__init__() 146 | 147 | self.is_first = is_first 148 | 149 | self.self_attn_layer_norm = nn.LayerNorm(d_model) 150 | self.self_attn = MultiHeadAttention(d_model, num_heads, dropout, gain) 151 | 152 | mask = torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1) 153 | self.self_attn_mask = nn.Parameter(mask, requires_grad=False) 154 | 155 | self.encoder_decoder_attn_layer_norm = nn.LayerNorm(d_model) 156 | self.encoder_decoder_attn = MultiHeadAttention( 157 | d_model, num_heads, dropout, gain 158 | ) 159 | 160 | self.ffn_layer_norm = nn.LayerNorm(d_model) 161 | self.ffn = nn.Sequential( 162 | linear(d_model, 4 * d_model, weight_init="kaiming"), 163 | nn.ReLU(), 164 | linear(4 * d_model, d_model, gain=gain), 165 | nn.Dropout(dropout), 166 | ) 167 | 168 | def forward(self, input, encoder_output): 169 | """ 170 | input: batch_size x target_len x d_model 171 | encoder_output: batch_size x source_len x d_model 172 | return: batch_size x target_len x d_model 173 | """ 174 | T = input.shape[1] 175 | 176 | if self.is_first: 177 | input = self.self_attn_layer_norm(input) 178 | x = self.self_attn(input, input, input, self.self_attn_mask[:T, :T]) 179 | input = input + x 180 | else: 181 | x = self.self_attn_layer_norm(input) 182 | x = self.self_attn(x, x, x, self.self_attn_mask[:T, :T]) 183 | input = input + x 184 | 185 | x = self.encoder_decoder_attn_layer_norm(input) 186 | x = self.encoder_decoder_attn(x, encoder_output, encoder_output) 187 | input = input + x 188 | 189 | x = self.ffn_layer_norm(input) 190 | x = self.ffn(x) 191 | return input + x 192 | 193 | 194 | class TransformerDecoder(nn.Module): 195 | def __init__(self, num_blocks, max_len, d_model, num_heads, dropout=0.0): 196 | super().__init__() 197 | 198 | if num_blocks > 0: 199 | gain = (3 * num_blocks) ** 
(-0.5) 200 | self.blocks = nn.ModuleList( 201 | [ 202 | TransformerDecoderBlock( 203 | max_len, d_model, num_heads, dropout, gain, is_first=True 204 | ) 205 | ] 206 | + [ 207 | TransformerDecoderBlock( 208 | max_len, d_model, num_heads, dropout, gain, is_first=False 209 | ) 210 | for _ in range(num_blocks - 1) 211 | ] 212 | ) 213 | else: 214 | self.blocks = nn.ModuleList() 215 | 216 | self.layer_norm = nn.LayerNorm(d_model) 217 | 218 | def forward(self, input, encoder_output): 219 | """ 220 | input: batch_size x target_len x d_model 221 | encoder_output: batch_size x source_len x d_model 222 | return: batch_size x target_len x d_model 223 | """ 224 | for block in self.blocks: 225 | input = block(input, encoder_output) 226 | 227 | return self.layer_norm(input) 228 | -------------------------------------------------------------------------------- /object_discovery/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | import math 3 | import itertools 4 | 5 | from matplotlib.colors import ListedColormap 6 | from matplotlib.cm import get_cmap 7 | from PIL import ImageFont 8 | from scipy.spatial import ConvexHull 9 | from scipy.spatial._qhull import QhullError 10 | from scipy.spatial.distance import cdist 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | from pytorch_lightning import Callback 16 | from object_discovery.segmentation_metrics import adjusted_rand_index 17 | 18 | import wandb 19 | 20 | 21 | FNT = ImageFont.truetype("dejavu/DejaVuSansMono.ttf", 7) 22 | CMAPSPEC = ListedColormap( 23 | ["black", "red", "green", "blue", "yellow", "cyan", "magenta"] 24 | + list( 25 | itertools.chain.from_iterable( 26 | itertools.chain.from_iterable( 27 | zip( 28 | [get_cmap("tab20b")(i) for i in range(i, 20, 4)], 29 | [get_cmap("tab20c")(i) for i in range(i, 20, 4)], 30 | ) 31 | ) 32 | for i in range(4) 33 | ) 34 | ) 35 | + [get_cmap("Set3")(i) for i in range(12)] 36 | + ["white"], 37 | name="SemSegMap", 38 | ) 39 | 40 | 41 | @torch.no_grad() 42 | def cmap_tensor(t): 43 | t_hw = t.cpu().numpy() 44 | o_hwc = CMAPSPEC(t_hw)[..., :3] # drop alpha 45 | o = torch.from_numpy(o_hwc).transpose(-1, -2).transpose(-2, -3).to(t.device) 46 | return o 47 | 48 | 49 | def conv_transpose_out_shape( 50 | in_size, stride, padding, kernel_size, out_padding, dilation=1 51 | ): 52 | return ( 53 | (in_size - 1) * stride 54 | - 2 * padding 55 | + dilation * (kernel_size - 1) 56 | + out_padding 57 | + 1 58 | ) 59 | 60 | 61 | def assert_shape( 62 | actual: Union[torch.Size, Tuple[int, ...]], 63 | expected: Tuple[int, ...], 64 | message: str = "", 65 | ): 66 | assert ( 67 | actual == expected 68 | ), f"Expected shape: {expected} but passed shape: {actual}. {message}" 69 | 70 | 71 | def build_grid(resolution): 72 | ranges = [torch.linspace(0.0, 1.0, steps=res) for res in resolution] 73 | grid = torch.meshgrid(*ranges) 74 | grid = torch.stack(grid, dim=-1) 75 | grid = torch.reshape(grid, [resolution[0], resolution[1], -1]) 76 | grid = grid.unsqueeze(0) 77 | return torch.cat([grid, 1.0 - grid], dim=-1) 78 | 79 | 80 | def rescale(x: torch.Tensor) -> torch.Tensor: 81 | return x * 2 - 1 82 | 83 | 84 | def slightly_off_center_crop(image: torch.Tensor) -> torch.Tensor: 85 | crop = ((29, 221), (64, 256)) # Get center crop. 
(height, width) 86 | # `image` has shape [channels, height, width] 87 | return image[:, crop[0][0] : crop[0][1], crop[1][0] : crop[1][1]] 88 | 89 | 90 | def slightly_off_center_mask_crop(mask: torch.Tensor) -> torch.Tensor: 91 | # `mask` has shape [max_num_entities, height, width, channels] 92 | crop = ((29, 221), (64, 256)) # Get center crop. (height, width) 93 | return mask[:, crop[0][0] : crop[0][1], crop[1][0] : crop[1][1], :] 94 | 95 | 96 | def compact(l: Any) -> Any: 97 | return list(filter(None, l)) 98 | 99 | 100 | def first(x): 101 | return next(iter(x)) 102 | 103 | 104 | def only(x): 105 | materialized_x = list(x) 106 | assert len(materialized_x) == 1 107 | return materialized_x[0] 108 | 109 | 110 | class ImageLogCallback(Callback): 111 | def log_images(self, trainer, pl_module, stage): 112 | if trainer.logger: 113 | with torch.no_grad(): 114 | pl_module.eval() 115 | images = pl_module.sample_images(stage=stage) 116 | trainer.logger.experiment.log( 117 | {stage + "/images": [wandb.Image(images)]}, commit=False 118 | ) 119 | 120 | def on_validation_epoch_end(self, trainer, pl_module) -> None: 121 | self.log_images(trainer, pl_module, stage="validation") 122 | 123 | def on_train_epoch_end(self, trainer, pl_module) -> None: 124 | self.log_images(trainer, pl_module, stage="train") 125 | 126 | 127 | def to_rgb_from_tensor(x: torch.Tensor): 128 | return (x * 0.5 + 0.5).clamp(0, 1) 129 | 130 | 131 | def unstack_and_split(x, batch_size, num_channels=3): 132 | """Unstack batch dimension and split into channels and alpha mask.""" 133 | unstacked = torch.reshape(x, [batch_size, -1] + list(x.shape[1:])) 134 | channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1) 135 | return channels, masks 136 | 137 | 138 | def linear(in_features, out_features, bias=True, weight_init="xavier", gain=1.0): 139 | m = nn.Linear(in_features, out_features, bias) 140 | 141 | if weight_init == "kaiming": 142 | nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") 143 | else: 144 | nn.init.xavier_uniform_(m.weight, gain) 145 | 146 | if bias: 147 | nn.init.zeros_(m.bias) 148 | 149 | return m 150 | 151 | 152 | def gru_cell(input_size, hidden_size, bias=True): 153 | m = nn.GRUCell(input_size, hidden_size, bias) 154 | 155 | nn.init.xavier_uniform_(m.weight_ih) 156 | nn.init.orthogonal_(m.weight_hh) 157 | 158 | if bias: 159 | nn.init.zeros_(m.bias_ih) 160 | nn.init.zeros_(m.bias_hh) 161 | 162 | return m 163 | 164 | 165 | def warm_and_decay_lr_scheduler( 166 | step: int, warmup_steps_pct, decay_steps_pct, total_steps, gamma 167 | ): 168 | warmup_steps = warmup_steps_pct * total_steps 169 | decay_steps = decay_steps_pct * total_steps 170 | assert step < total_steps 171 | if step < warmup_steps: 172 | factor = step / warmup_steps 173 | else: 174 | factor = 1 175 | factor *= gamma ** (step / decay_steps) 176 | return factor 177 | 178 | 179 | def cosine_anneal(step: int, start_value, final_value, start_step, final_step): 180 | assert start_value >= final_value 181 | assert start_step <= final_step 182 | 183 | if step < start_step: 184 | value = start_value 185 | elif step >= final_step: 186 | value = final_value 187 | else: 188 | a = 0.5 * (start_value - final_value) 189 | b = 0.5 * (start_value + final_value) 190 | progress = (step - start_step) / (final_step - start_step) 191 | value = a * math.cos(math.pi * progress) + b 192 | 193 | return value 194 | 195 | 196 | def linear_warmup(step, start_value, final_value, start_step, final_step): 197 | assert start_value <= final_value 198 | assert start_step <= 
final_step 199 | 200 | if step < start_step: 201 | value = start_value 202 | elif step >= final_step: 203 | value = final_value 204 | else: 205 | a = final_value - start_value 206 | b = start_value 207 | progress = (step + 1 - start_step) / (final_step - start_step) 208 | value = a * progress + b 209 | 210 | return value 211 | 212 | 213 | def visualize(image, recon_orig, gen, attns, N=8): 214 | _, _, H, W = image.shape 215 | image = image[:N].expand(-1, 3, H, W).unsqueeze(dim=1) 216 | recon_orig = recon_orig[:N].expand(-1, 3, H, W).unsqueeze(dim=1) 217 | gen = gen[:N].expand(-1, 3, H, W).unsqueeze(dim=1) 218 | attns = attns[:N].expand(-1, -1, 3, H, W) 219 | 220 | return torch.cat((image, recon_orig, gen, attns), dim=1).view(-1, 3, H, W) 221 | 222 | 223 | def compute_ari(prediction, mask, batch_size, height, width, max_num_entities): 224 | # Ground-truth segmentation masks are always returned in the canonical 225 | # [batch_size, max_num_entities, height, width, channels] format. To use these 226 | # as an input for `segmentation_metrics.adjusted_rand_index`, we need them in 227 | # the [batch_size, n_points, n_true_groups] format, 228 | # where n_true_groups == max_num_entities. We implement this reshape below. 229 | # Note that 'oh' denotes 'one-hot'. 230 | desired_shape = [batch_size, height * width, max_num_entities] 231 | true_groups_oh = torch.permute(mask, [0, 2, 3, 4, 1]) 232 | # `true_groups_oh` has shape [batch_size, height, width, channels, max_num_entries] 233 | true_groups_oh = torch.reshape(true_groups_oh, desired_shape) 234 | 235 | # prediction = tf.random_uniform(desired_shape[:-1], 236 | # minval=0, maxval=max_num_entities, 237 | # dtype=tf.int32) 238 | # prediction_oh = F.one_hot(prediction, max_num_entities) 239 | 240 | ari = adjusted_rand_index(true_groups_oh[..., 1:], prediction) 241 | return ari 242 | 243 | 244 | def flatten_all_but_last(tensor, n_dims=1): 245 | shape = list(tensor.shape) 246 | batch_dims = shape[:-n_dims] 247 | flat_tensor = torch.reshape(tensor, [np.prod(batch_dims)] + shape[-n_dims:]) 248 | 249 | def unflatten(other_tensor): 250 | other_shape = list(other_tensor.shape) 251 | return torch.reshape(other_tensor, batch_dims + other_shape[1:]) 252 | 253 | return flat_tensor, unflatten 254 | 255 | 256 | def sa_segment(masks, threshold, return_cmap=True): 257 | # `masks` has shape [batch_size, num_entries, channels, height, width]. 258 | # Ensure that each pixel belongs to exclusively the slot with the 259 | # greatest mask value for that pixel. Mask values are in the range 260 | # [0, 1] and sum to 1. Add 1 to `segmentation` so the values/colors 261 | # line up with `segmentation_thresholded` 262 | segmentation = masks.to(torch.float).argmax(1).squeeze(1) + 1 263 | segmentation = segmentation.to(torch.uint8) 264 | cmap_segmentation = cmap_tensor(segmentation) 265 | 266 | pred_masks = masks.clone() 267 | # For each pixel, if the mask value for that pixel is less than the 268 | # threshold across all slots, then it belongs to an imaginary slot. This 269 | # imaginary slot is concatenated to the masks as the first mask with any 270 | # pixels that belong to it set to 1. So, the argmax will always select the 271 | # imaginary slot for a pixel if that pixel is in the imaginary slot. This 272 | # has the effect such that any pixel that is masked by multiple slots will 273 | # be grouped together into one mask and removed from its previous mask. 
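# Worked example, with illustrative numbers: with threshold=0.5 and three
# slots whose mask values at one pixel are [0.40, 0.35, 0.25], every value is
# below the threshold, so the prepended "imaginary" mask is 1 there and the
# argmax below assigns that pixel label 0. At a pixel with mask values
# [0.90, 0.06, 0.04], the imaginary mask is 0 and the argmax picks index 1,
# i.e. the first real slot after the concatenation.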
274 | less_than_threshold = pred_masks < threshold 275 | thresholded = torch.all(less_than_threshold, dim=1, keepdim=True).to(torch.float) 276 | pred_masks = torch.cat([thresholded, pred_masks], dim=1) 277 | segmentation_thresholded = pred_masks.to(torch.float).argmax(1).squeeze(1) 278 | cmap_segmentation_thresholded = cmap_tensor( 279 | segmentation_thresholded.to(torch.uint8) 280 | ) 281 | 282 | # `cmap_segmentation` and `cmap_segmentation_thresholded` have shape 283 | # [batch_size, channels=3, height, width]. 284 | if return_cmap: 285 | return ( 286 | segmentation, 287 | segmentation_thresholded, 288 | cmap_segmentation, 289 | cmap_segmentation_thresholded, 290 | ) 291 | return segmentation, segmentation_thresholded 292 | 293 | 294 | def PolyArea(x, y): 295 | """Implementation of Shoelace formula for polygon area""" 296 | # From https://stackoverflow.com/a/30408825 297 | return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) 298 | 299 | 300 | def get_largest_objects(objects, metric="area"): 301 | # Implementation partly from https://stackoverflow.com/a/60955825 302 | # `objects` has shape [batch_size, num_objects, height, width] 303 | largest_objects = [] 304 | for batch_idx in range(objects.shape[0]): 305 | distances = [] 306 | for object_idx in range(objects.shape[1]): 307 | ob = objects[batch_idx, object_idx].to(torch.bool) 308 | if not torch.any(ob): 309 | distances.append(0) 310 | continue 311 | 312 | # Find a convex hull in O(N log N) 313 | points = np.indices((128, 128)).transpose((1, 2, 0))[ob] 314 | try: 315 | hull = ConvexHull(points) 316 | except QhullError: 317 | distances.append(0) 318 | continue 319 | # Extract the points forming the hull 320 | hullpoints = points[hull.vertices, :] 321 | if metric == "distance": 322 | # Naive way of finding the best pair in O(H^2) time if H is number 323 | # of points on hull 324 | hdist = cdist(hullpoints, hullpoints, metric="euclidean") 325 | # Get the farthest apart points 326 | bestpair = np.unravel_index(hdist.argmax(), hdist.shape) 327 | 328 | point1 = hullpoints[bestpair[0]] 329 | point2 = hullpoints[bestpair[1]] 330 | score = np.linalg.norm(point2 - point1) 331 | elif metric == "area": 332 | x = hullpoints[:, 0] 333 | y = hullpoints[:, 1] 334 | score = PolyArea(x, y) 335 | distances.append(score) 336 | 337 | largest_objects.append(np.argmax(np.array(distances))) 338 | return np.array(largest_objects) 339 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch 3 | import h5py 4 | from PIL import Image 5 | from glob import glob 6 | from torchvision import transforms 7 | from matplotlib import pyplot as plt 8 | from object_discovery.method import SlotAttentionMethod 9 | from object_discovery.slot_attention_model import SlotAttentionModel 10 | from object_discovery.utils import slightly_off_center_crop 11 | 12 | 13 | def load_model(ckpt_path): 14 | ckpt = torch.load(ckpt_path) 15 | params = Namespace(**ckpt["hyper_parameters"]) 16 | sa = SlotAttentionModel( 17 | resolution=params.resolution, 18 | num_slots=params.num_slots, 19 | num_iterations=params.num_iterations, 20 | slot_size=params.slot_size, 21 | ) 22 | model = SlotAttentionMethod.load_from_checkpoint( 23 | ckpt_path, model=sa, datamodule=None 24 | ) 25 | model.eval() 26 | return model 27 | 28 | 29 | print("Loading model...") 30 | ckpt_path = "epoch=209-step=299460.ckpt" 31 | model = 
load_model(ckpt_path) 32 | 33 | t = transforms.ToTensor() 34 | 35 | print("Loading images...") 36 | with h5py.File("data/box_world_dataset.h5", "r") as f: 37 | images = f["image"][0:8] 38 | 39 | transformed_images = [] 40 | for image in images: 41 | transformed_images.append(t(image)) 42 | images = torch.stack(transformed_images) 43 | 44 | print("Predicting...") 45 | images = model.predict( 46 | images, 47 | do_transforms=True, 48 | debug=True, 49 | return_pil=True, 50 | background_detection="both", 51 | ) 52 | slots = model.predict(images, do_transforms=True, return_slots=True) 53 | slots = slots.squeeze() 54 | # `slots` has shape (num_slots, num_features) 55 | 56 | print("Saving...") 57 | images.save("output.png") 58 | 59 | 60 | # ckpt_path = "sketchy_sa-epoch=59-step=316440-3nofluv3.ckpt" 61 | # model = load_model(ckpt_path) 62 | 63 | # transformed_images = [] 64 | # for image_path in glob("data/sketchy_sample/*.png"): 65 | # image = Image.open(image_path) 66 | # image = image.convert("RGB") 67 | # transformed_images.append(transforms.functional.to_tensor(image)) 68 | 69 | # images = model.predict( 70 | # torch.stack(transformed_images), 71 | # do_transforms=True, 72 | # debug=True, 73 | # return_pil=True, 74 | # background_detection="both", 75 | # ) 76 | 77 | # plt.imshow(images, interpolation="nearest") 78 | # plt.show() 79 | 80 | # ckpt_path = "clevr6_masks-epoch=673-step=275666-r4nbi6n7.ckpt" 81 | # model = load_model(ckpt_path) 82 | 83 | # t = transforms.Compose( 84 | # [ 85 | # transforms.ToTensor(), 86 | # transforms.Lambda(slightly_off_center_crop), 87 | # ] 88 | # ) 89 | 90 | # with h5py.File("/media/Main/Downloads/clevr_with_masks.h5", "r") as f: 91 | # images = f["image"][0:8] 92 | 93 | # transformed_images = [] 94 | # for image in images: 95 | # transformed_images.append(t(image)) 96 | # images = torch.stack(transformed_images) 97 | 98 | # images = model.predict( 99 | # images, 100 | # do_transforms=True, 101 | # debug=True, 102 | # return_pil=True, 103 | # background_detection="both", 104 | # ) 105 | 106 | # plt.imshow(images, interpolation="nearest") 107 | # plt.show() 108 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "slot-attention-pytorch" 3 | version = "0.1.0" 4 | description = "PyTorch implementation of \"Object-Centric Learning with Slot Attention\"" 5 | authors = ["Bryden Fogelman ", "Hayden Housen "] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">=3.9,<3.12" 9 | torch = "^1.13.1" 10 | torchvision = "^0.14.1" 11 | wandb = "^0.13.7" 12 | Pillow = "^9.3.0" 13 | pytorch-lightning = "^1.8.6" 14 | numpy = "^1.24.1" 15 | h5py = "^3.7.0" 16 | scipy = "^1.10.0" 17 | matplotlib = "^3.6.2" 18 | 19 | [tool.poetry.dev-dependencies] 20 | 21 | [build-system] 22 | requires = ["poetry-core>=1.0.0"] 23 | build-backend = "poetry.core.masonry.api" 24 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | poetry install 6 | 7 | chmod +x data_scripts/download_clevr.sh 8 | ./data_scripts/download_clevr.sh /tmp/CLEVR 9 | 10 | python -m object_discovery.train 11 | -------------------------------------------------------------------------------- /small_logo.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/small_logo.png --------------------------------------------------------------------------------
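
Usage note: `warm_and_decay_lr_scheduler` in object_discovery/utils.py returns a multiplicative learning-rate factor for a given training step, so it is typically wrapped with `functools.partial` and handed to `torch.optim.lr_scheduler.LambdaLR`. The sketch below illustrates that wiring only; the stand-in model, optimizer, and hyperparameter values are illustrative assumptions, not the repository's training configuration.

import functools

import torch

from object_discovery.utils import warm_and_decay_lr_scheduler

model = torch.nn.Linear(8, 8)  # stand-in module; in practice the full model is scheduled
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)

total_steps = 500_000  # assumed value for illustration
lr_lambda = functools.partial(
    warm_and_decay_lr_scheduler,
    warmup_steps_pct=0.02,  # linear warmup over the first 2% of steps
    decay_steps_pct=0.2,    # factor halves (gamma=0.5) every 20% of total steps
    total_steps=total_steps,
    gamma=0.5,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(100):  # minimal training-loop stub
    optimizer.step()
    scheduler.step()  # multiplies the base lr by warm_and_decay_lr_scheduler(step, ...)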