├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── data_scripts
│   ├── clevr_with_masks.py
│   ├── download_clevr.sh
│   ├── download_clevrtex.sh
│   ├── preprocess_clevr_with_masks.py
│   ├── preprocess_tetrominoes.py
│   └── tetrominoes.py
├── media
│   ├── gnm_clevr6.png
│   ├── gnm_clevrtex6.png
│   ├── sa_clevr10_with_masks.png
│   ├── sa_clevr6_example.png
│   ├── sa_clevrtex6.png
│   ├── sa_sketchy_with_masks.png
│   └── slate_clevr6.png
├── object_discovery
│   ├── __init__.py
│   ├── data.py
│   ├── gnm
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── gnm_model.py
│   │   ├── logging.py
│   │   ├── metrics.py
│   │   ├── module.py
│   │   ├── submodule.py
│   │   └── utils.py
│   ├── method.py
│   ├── params.py
│   ├── segmentation_metrics.py
│   ├── slate_model.py
│   ├── slot_attention_model.py
│   ├── train.py
│   ├── transformer.py
│   └── utils.py
├── poetry.lock
├── predict.py
├── pyproject.toml
├── run.sh
└── small_logo.png
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | checkpoint_details.json
3 | .idea/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # Custom
136 | wandb/
137 | slot-attention-clevr/
138 | .vscode/
139 | data/
140 | debug/
141 | saved_checkpoints/
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Object Discovery PyTorch
2 |
3 | ![Logo](./small_logo.png)
4 |
5 | > This is an implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch.
6 |
7 | [License](https://github.com/HHousen/object-discovery-pytorch/blob/master/LICENSE) [Commits](https://github.com/HHousen/object-discovery-pytorch/commits/master) [Python](https://www.python.org/) [Issues](https://GitHub.com/HHousen/object-discovery-pytorch/issues/) [Pull Requests](https://GitHub.com/HHousen/object-discovery-pytorch/pull/)
8 |
9 | The initial code for this repo was forked from [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention).
10 |
11 | ![Slot Attention applied to CLEVR6](./media/sa_clevr6_example.png)
12 |
13 | ## Setup
14 |
15 | ### Requirements
16 |
17 | - [Poetry](https://python-poetry.org/docs/)
18 | - Python >= 3.9
19 | - CUDA enabled computing device
20 |
21 | ### Getting Started
22 |
23 | 1. Clone the repo: `git clone https://github.com/HHousen/object-discovery-pytorch/ && cd object-discovery-pytorch`.
24 | 2. Install requirements and activate environment: `poetry install` then `poetry shell`.
25 | 3. Download the [CLEVR (with masks)](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/clevr_with_masks.h5) dataset (or the original [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/) dataset by running `./data_scripts/download_clevr.sh /tmp/CLEVR`). More details about the datasets are below.
26 | 4. Modify the hyperparameters in [object_discovery/params.py](object_discovery/params.py) to fit your needs. Make sure to change `data_root` to the location of your dataset.
27 | 5. Train a model: `python -m object_discovery.train`.
28 |
29 | ## Pre-trained Models
30 |
31 | Code to load these models can be adapted from [predict.py](./predict.py).
32 |
33 | | Model | Dataset | Download |
34 | |---|---|---|
35 | | Slot Attention | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_masks-epoch=673-step=275666-r4nbi6n7.ckpt) |
36 | | Slot Attention | Sketchy | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/sketchy_sa-epoch=59-step=316440-3nofluv3.ckpt) |
37 | | GNM | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_gnm_0.4std-epoch=546-step=223723-3fi81x4s.ckpt) |
38 | | Slot Attention | ClevrTex6 | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevrtex_sa-epoch=843-step=263328-p0hcfm7j.ckpt) |
39 | | GNM | ClevrTex6 | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevrtex_gnm_0.7std-epoch=890-step=277992-35w3i9mg.ckpt) |
40 | | SLATE | CLEVR6 Masks | [Hugging Face](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Saved%20Checkpoints/clevr6_slate_no_rescale-1pvy8x3x/epoch=708-step=371516.ckpt) |
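
As a rough sketch (not code from this repo), a checkpoint listed above can be fetched with `huggingface_hub` and inspected before adapting [predict.py](./predict.py); the chosen row is only an example:

```python
# Hedged sketch: download one of the checkpoints from the table and peek inside.
# Assumes the `huggingface_hub` and `torch` packages are installed; the filename
# is copied from the "Slot Attention | CLEVR6 Masks" row above.
from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(
    repo_id="HHousen/object-discovery-pytorch",
    filename="Saved Checkpoints/clevr6_masks-epoch=673-step=275666-r4nbi6n7.ckpt",
)

# The file naming suggests standard PyTorch Lightning checkpoints, so the weights
# should sit under "state_dict" (on newer PyTorch you may need weights_only=False).
# See predict.py for how the repo actually restores a model.
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(sorted(checkpoint.keys()))
```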
41 |
42 | ## Usage
43 |
44 | Train a model by running `python -m object_discovery.train`.
45 |
46 | Hyperparameters can be changed in [object_discovery/params.py](object_discovery/params.py). `training_params` has global parameters that apply to all model types. These parameters can be overridden if the same key is present in `slot_attention_params` or `slate_params`. Change the global parameter `model_type` to `sa` to use Slot Attention (`SlotAttentionModel` in slot_attention_model.py) or `slate` to use SLATE (`SLATE` in slate_model.py). This will determine which model's set of parameters will be merged with `training_params`.
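
The override rule behaves like a plain dictionary update. The snippet below only illustrates that behavior; the parameter names and values are hypothetical, and the real definitions live in [object_discovery/params.py](object_discovery/params.py):

```python
# Illustration of "model-specific keys override global keys" (hypothetical values).
training_params = {"model_type": "sa", "batch_size": 64, "lr": 4e-4}
slot_attention_params = {"lr": 2e-4, "num_slots": 7}

merged = {**training_params, **slot_attention_params}  # the model-specific "lr" wins
print(merged)  # {'model_type': 'sa', 'batch_size': 64, 'lr': 0.0002, 'num_slots': 7}
```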
47 |
48 | Perform inference by modifying and running the [predict.py](./predict.py) script.
49 |
50 | ### Models
51 |
52 | Our implementations are based on several open-source repositories.
53 |
54 | 1. Slot Attention (["Object-Centric Learning with Slot Attention"](https://arxiv.org/abs/2006.15055)): [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention) & [Official](https://github.com/google-research/google-research/tree/master/slot_attention)
55 | 2. SLATE (["Illiterate DALL-E Learns to Compose"](https://arxiv.org/abs/2110.11405)): [Official](https://github.com/singhgautam/slate)
56 | 3. GNM (["Generative Neurosymbolic Machines"](https://arxiv.org/abs/2010.12152)): [karazijal/clevrtex](https://github.com/karazijal/clevrtex) & [Official](https://github.com/JindongJiang/GNM)
57 |
58 | ### Datasets
59 |
60 | Select a dataset by changing the `dataset` parameter in [object_discovery/params.py](object_discovery/params.py) to the name of the dataset, such as `clevr`, `shapes3d`, or `ravens` (the full list of supported datasets is below). Then, set the `data_root` parameter to the location of the data. The code for loading supported datasets is in [object_discovery/data.py](object_discovery/data.py).
61 |
62 | 1. [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/): Download by executing [download_clevr.sh](./data_scripts/download_clevr.sh).
63 | 2. [CLEVR (with masks)](https://github.com/deepmind/multi_object_datasets#clevr-with-masks): [Original TFRecords Download](https://console.cloud.google.com/storage/browser/multi-object-datasets/clevr_with_masks) / [Our HDF5 PyTorch Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/clevr_with_masks.h5).
64 | - This dataset is a regenerated version of CLEVR but with ground-truth segmentation masks. This enables the training script to calculate Adjusted Rand Index (ARI) during validation runs.
65 |    - The dataset contains 100,000 images with a resolution of 240x320 pixels. The dataloader splits them into 70K train, 15K validation, and 15K test images. Test images are not used by the [object_discovery/train.py](object_discovery/train.py) script.
66 |    - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the [data_scripts/preprocess_clevr_with_masks.py](./data_scripts/preprocess_clevr_with_masks.py) script, which takes approximately 2 hours to execute depending on your machine. A short sketch of reading the resulting HDF5 file follows this list.
67 | 3. [3D Shapes](https://github.com/deepmind/3d-shapes): [Official Google Cloud Bucket](https://console.cloud.google.com/storage/browser/3d-shapes)
68 | 4. RAVENS Robot Data: [Train](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/RAVENS%20Robot%20Dataset/ravens_robot_data_train.h5) & [Test](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/RAVENS%20Robot%20Dataset/ravens_robot_data_test.h5)
69 | - We generated a dataset similar in structure to CLEVR (with masks) but of robotic images using [RAVENS](https://github.com/google-research/ravens). Our modified version of RAVENS used to generate the dataset is [HHousen/ravens](https://github.com/HHousen/ravens).
70 | - The dataset contains 85,002 images split 70,002 train and 15K validation/test.
71 | 5. Sketchy: Download and process by following directions in [applied-ai-lab/genesis](https://github.com/applied-ai-lab/genesis#sketchy) / [Download Our Processed Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/genesis_sketchy_processed.tar.gz)
72 | - Dataset details are in the paper [Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200).
73 | 6. ClevrTex: Download by executing [download_clevrtex.sh](./data_scripts/download_clevrtex.sh). Our dataloader needs to index the entire dataset before training can begin. This can take around 2 hours. Thus, it is recommended to download our pre-made index from [this Hugging Face folder](https://huggingface.co/HHousen/object-discovery-pytorch/tree/main/Datasets/ClevrTex%20Cache) and put it in `./data/cache/`.
74 | 7. [Tetrominoes](https://github.com/deepmind/multi_object_datasets#tetrominoes): [Original TFRecords Download](https://console.cloud.google.com/storage/browser/multi-object-datasets/tetrominoes) / [Our HDF5 PyTorch Version](https://huggingface.co/HHousen/object-discovery-pytorch/blob/main/Datasets/tetrominoes.h5).
75 | - There are 1,000,000 samples in the dataset. However, following [the Slot Attention paper](https://arxiv.org/abs/2006.15055), we only use the first 60K samples for training.
76 | - We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the [data_scripts/preprocess_tetrominoes.py](./data_scripts/preprocess_tetrominoes.py) script, which takes approximately 2 hours to execute depending on your machine.
77 |
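As referenced in the CLEVR (with masks) entry above, here is a minimal sketch of reading one of the converted HDF5 files. The path is an example; the key names and shapes follow [data_scripts/clevr_with_masks.py](./data_scripts/clevr_with_masks.py):

```python
# Hedged example: inspect the converted CLEVR (with masks) HDF5 file.
# Adjust the path to wherever you stored clevr_with_masks.h5.
import h5py

with h5py.File("./data/clevr_with_masks.h5", "r") as f:
    print(list(f.keys()))     # 'image', 'mask', 'visibility', ... per the TFRecord features
    images = f["image"]       # shape (100000, 240, 320, 3), uint8 RGB
    masks = f["mask"]         # shape (100000, 11, 240, 320, 1), one mask per entity
    first_image = images[0]   # HDF5 datasets support NumPy-style indexing
    print(first_image.shape, masks[0].shape)
```
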
78 | ### Logging
79 |
80 | To log outputs to [wandb](https://wandb.ai/home), run `wandb login YOUR_API_KEY` and set `is_logging_enabled=True` in `SlotAttentionParams`.
81 |
82 | If you use a dataset with ground-truth segmentation masks, then the Adjusted Rand Index (ARI), a clustering similarity score, will be logged for each validation loop. We convert the implementation from [deepmind/multi_object_datasets](https://github.com/deepmind/multi_object_datasets) to PyTorch in [object_discovery/segmentation_metrics.py](object_discovery/segmentation_metrics.py).
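
For intuition only (this is not the repo's code path), ARI scores how well two per-pixel labelings agree while ignoring label permutations; scikit-learn's `adjusted_rand_score` computes the same quantity for flat label vectors, while the PyTorch port converted from DeepMind's code instead works on batched segmentation masks:

```python
# Illustrative ARI computation on toy per-pixel labels (not the repo's implementation).
import numpy as np
from sklearn.metrics import adjusted_rand_score

true_labels = np.array([0, 0, 1, 1, 2, 2])  # ground-truth object id per pixel
pred_labels = np.array([1, 1, 0, 0, 2, 2])  # predicted segment id per pixel
print(adjusted_rand_score(true_labels, pred_labels))  # 1.0: identical up to relabeling
```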
83 |
84 | ## More Visualizations
85 |
86 | Slot Attention CLEVR10 | Slot Attention Sketchy
87 | :-----------------------:|:--------------------:
88 | ![Slot Attention CLEVR10](./media/sa_clevr10_with_masks.png) | ![Slot Attention Sketchy](./media/sa_sketchy_with_masks.png)
89 |
90 | The visualizations above show a model trained on CLEVR6 predicting on CLEVR10 (with no increase in the number of slots) and a model trained and predicting on Sketchy. From left to right, the images are the original, the reconstruction, the raw predicted segmentation mask, the processed segmentation mask, and then the slots.
91 |
92 | Slot Attention ClevrTex6 | GNM ClevrTex6
93 | :-----------------------:|:--------------------:
94 | ![Slot Attention ClevrTex6](./media/sa_clevrtex6.png) | ![GNM ClevrTex6](./media/gnm_clevrtex6.png)
95 |
96 | The Slot Attention image order is the same as in the visualizations above. For GNM, the order is original, reconstruction, ground-truth segmentation mask, and predicted segmentation mask (repeated 4 times).
97 |
98 | SLATE CLEVR6 | GNM CLEVR6
99 | :-----------------------:|:--------------------:
100 | ![SLATE CLEVR6](./media/slate_clevr6.png) | ![GNM CLEVR6](./media/gnm_clevr6.png)
101 |
102 | For SLATE, the image order is original, dVAE reconstruction, autoregressive reconstruction, and then the pixels each slot pays attention to.
103 |
104 | ## References
105 |
106 | 1. [untitled-ai/slot_attention](https://github.com/untitled-ai/slot_attention): An unofficial implementation of Slot Attention from which this repo was forked.
107 | 2. Slot Attention: [Official Code](https://github.com/google-research/google-research/tree/master/slot_attention) / ["Object-Centric Learning with Slot Attention"](https://arxiv.org/abs/2006.15055).
108 | 3. SLATE: [Official Code](https://github.com/singhgautam/slate) / ["Illiterate DALL-E Learns to Compose"](https://arxiv.org/abs/2110.11405).
109 | 4. IODINE: [Official Code](https://github.com/deepmind/deepmind-research/tree/master/iodine) / ["Multi-Object Representation Learning with Iterative Variational Inference"](https://arxiv.org/abs/1903.00450). In the Slot Attention paper, IODINE was frequently used for comparison. The IODINE code was helpful to create this repo.
110 | 5. Multi-Object Datasets: [deepmind/multi_object_datasets](https://github.com/deepmind/multi_object_datasets). This is the original source of the [CLEVR (with masks)](https://github.com/deepmind/multi_object_datasets#clevr-with-masks) dataset.
111 | 6. Implicit Slot Attention: ["Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation"](https://arxiv.org/abs/2207.00787). This paper explains a one-line change that improves the optimization of Slot Attention while making the space and time cost of backpropagation constant in the number of refinement iterations.
112 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/__init__.py
--------------------------------------------------------------------------------
/data_scripts/clevr_with_masks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """CLEVR (with masks) dataset reader."""
16 |
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP")
21 | IMAGE_SIZE = [240, 320]
22 | # The maximum number of foreground and background entities in the provided
23 | # dataset. This corresponds to the number of segmentation masks returned per
24 | # scene.
25 | MAX_NUM_ENTITIES = 11
26 | BYTE_FEATURES = ["mask", "image", "color", "material", "shape", "size"]
27 |
28 | # Create a dictionary mapping feature names to `tf.Example`-compatible
29 | # shape and data type descriptors.
30 | features = {
31 | "image": tf.FixedLenFeature(IMAGE_SIZE + [3], tf.string),
32 | "mask": tf.FixedLenFeature([MAX_NUM_ENTITIES] + IMAGE_SIZE + [1], tf.string),
33 | "x": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
34 | "y": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
35 | "z": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
36 | "pixel_coords": tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
37 | "rotation": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
38 | "size": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
39 | "material": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
40 | "shape": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
41 | "color": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
42 | "visibility": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
43 | }
44 |
45 |
46 | def _decode(example_proto):
47 | # Parse the input `tf.Example` proto using the feature description dict above.
48 | single_example = tf.parse_single_example(example_proto, features)
49 | for k in BYTE_FEATURES:
50 | single_example[k] = tf.squeeze(
51 | tf.decode_raw(single_example[k], tf.uint8), axis=-1
52 | )
53 | return single_example
54 |
55 |
56 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
57 | """Read, decompress, and parse the TFRecords file.
58 |
59 | Args:
60 | tfrecords_path: str. Path to the dataset file.
61 | read_buffer_size: int. Number of bytes in the read buffer. See documentation
62 | for `tf.data.TFRecordDataset.__init__`.
63 | map_parallel_calls: int. Number of elements decoded asynchronously in
64 | parallel. See documentation for `tf.data.Dataset.map`.
65 |
66 | Returns:
67 | An unbatched `tf.data.TFRecordDataset`.
68 | """
69 | raw_dataset = tf.data.TFRecordDataset(
70 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size
71 | )
72 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls)
73 |
--------------------------------------------------------------------------------
/data_scripts/download_clevr.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | DATA_DIR=$1
5 |
6 | if [ ! -d $DATA_DIR ]; then
7 | mkdir $DATA_DIR
8 | fi
9 |
10 | cd $DATA_DIR
11 |
12 | if [ ! -f "$DATA_DIR/CLEVR_v1.0.zip" ]; then
13 | wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip
14 | echo "CLEVR_v1 downloaded to $DATA_DIR/CLEVR_v1.0.zip"
15 | else
16 | echo "$DATA_DIR/CLEVR_v1.0.zip already exists, skipping download"
17 | fi
18 |
19 | echo "unzipping CLEVR_v1 to $DATA_DIR/CLEVR_v1.0"
20 | rm -rf CLEVR_v1.0
21 | unzip -q CLEVR_v1.0.zip
22 | rm CLEVR_v1.0.zip
23 |
--------------------------------------------------------------------------------
/data_scripts/download_clevrtex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | DATA_DIR=$1
5 |
6 | if [ ! -d $DATA_DIR ]; then
7 | mkdir $DATA_DIR
8 | fi
9 |
10 | cd $DATA_DIR
11 |
12 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part1.tar.gz
13 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part2.tar.gz
14 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part3.tar.gz
15 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part4.tar.gz
16 | wget https://thor.robots.ox.ac.uk/~vgg/data/clevrtex/clevrtex_packaged/clevrtex_full_part5.tar.gz
17 | echo "ClevrTex downloaded to $DATA_DIR"
18 |
19 |
20 | echo "Unzipping ClevrTex to $DATA_DIR/ClevrTex"
21 | rm -rf ClevrTex
22 | for file in *.tar.gz; do tar -zxf "$file"; done
23 | for file in *.tar.gz; do rm "$file"; done
24 |
--------------------------------------------------------------------------------
/data_scripts/preprocess_clevr_with_masks.py:
--------------------------------------------------------------------------------
1 | # Script to convert CLEVR (with masks) tfrecords file to h5py.
2 | # CLEVR (with masks) information: https://github.com/deepmind/multi_object_datasets#clevr-with-masks
3 | # Download: https://console.cloud.google.com/storage/browser/multi-object-datasets/clevr_with_masks
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 | import clevr_with_masks
8 | import h5py
9 |
10 |
11 | dataset = clevr_with_masks.dataset(
12 | "clevr_with_masks_train.tfrecords"
13 | ).as_numpy_iterator()
14 |
15 | with h5py.File("clevr_with_masks.h5", "w") as f:
16 | for idx, entry in tqdm(enumerate(dataset), total=100_000):
17 | for key, value in entry.items():
18 | value = value[np.newaxis, ...]
19 | if idx == 0:
20 | f.create_dataset(
21 | key,
22 | data=value,
23 | dtype=value.dtype,
24 | maxshape=(None, *value.shape[1:]),
25 | compression="gzip",
26 | chunks=True,
27 | )
28 | else:
29 | f[key].resize((f[key].shape[0] + value.shape[0]), axis=0)
30 | f[key][-value.shape[0] :] = value
31 |
--------------------------------------------------------------------------------
/data_scripts/preprocess_tetrominoes.py:
--------------------------------------------------------------------------------
1 | # Script to convert Tetrominoes tfrecords file to h5py.
2 | # Tetrominoes information: https://github.com/deepmind/multi_object_datasets#tetrominoes
3 | # Download: https://console.cloud.google.com/storage/browser/multi-object-datasets/tetrominoes
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 | import tetrominoes
8 | import h5py
9 |
10 |
11 | dataset = tetrominoes.dataset("tetrominoes_train.tfrecords").as_numpy_iterator()
12 |
13 | with h5py.File("tetrominoes.h5", "w") as f:
14 | for idx, entry in tqdm(enumerate(dataset), total=1_000_000):
15 | for key, value in entry.items():
16 | value = value[np.newaxis, ...]
17 | if idx == 0:
18 | f.create_dataset(
19 | key,
20 | data=value,
21 | dtype=value.dtype,
22 | maxshape=(None, *value.shape[1:]),
23 | compression="gzip",
24 | chunks=True,
25 | )
26 | else:
27 | f[key].resize((f[key].shape[0] + value.shape[0]), axis=0)
28 | f[key][-value.shape[0] :] = value
29 |
--------------------------------------------------------------------------------
/data_scripts/tetrominoes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tetrominoes dataset reader."""
16 |
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP")
21 | IMAGE_SIZE = [35, 35]
22 | # The maximum number of foreground and background entities in the provided
23 | # dataset. This corresponds to the number of segmentation masks returned per
24 | # scene.
25 | MAX_NUM_ENTITIES = 4
26 | BYTE_FEATURES = ["mask", "image"]
27 |
28 | # Create a dictionary mapping feature names to `tf.Example`-compatible
29 | # shape and data type descriptors.
30 | features = {
31 | "image": tf.FixedLenFeature(IMAGE_SIZE + [3], tf.string),
32 | "mask": tf.FixedLenFeature([MAX_NUM_ENTITIES] + IMAGE_SIZE + [1], tf.string),
33 | "x": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
34 | "y": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
35 | "shape": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
36 | "color": tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
37 | "visibility": tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
38 | }
39 |
40 |
41 | def _decode(example_proto):
42 | # Parse the input `tf.Example` proto using the feature description dict above.
43 | single_example = tf.parse_single_example(example_proto, features)
44 | for k in BYTE_FEATURES:
45 | single_example[k] = tf.squeeze(
46 | tf.decode_raw(single_example[k], tf.uint8), axis=-1
47 | )
48 | return single_example
49 |
50 |
51 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
52 | """Read, decompress, and parse the TFRecords file.
53 |
54 | Args:
55 | tfrecords_path: str. Path to the dataset file.
56 | read_buffer_size: int. Number of bytes in the read buffer. See documentation
57 | for `tf.data.TFRecordDataset.__init__`.
58 | map_parallel_calls: int. Number of elements decoded asynchronously in
59 | parallel. See documentation for `tf.data.Dataset.map`.
60 |
61 | Returns:
62 | An unbatched `tf.data.TFRecordDataset`.
63 | """
64 | raw_dataset = tf.data.TFRecordDataset(
65 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size
66 | )
67 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls)
68 |
--------------------------------------------------------------------------------
/media/gnm_clevr6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/gnm_clevr6.png
--------------------------------------------------------------------------------
/media/gnm_clevrtex6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/gnm_clevrtex6.png
--------------------------------------------------------------------------------
/media/sa_clevr10_with_masks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevr10_with_masks.png
--------------------------------------------------------------------------------
/media/sa_clevr6_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevr6_example.png
--------------------------------------------------------------------------------
/media/sa_clevrtex6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_clevrtex6.png
--------------------------------------------------------------------------------
/media/sa_sketchy_with_masks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/sa_sketchy_with_masks.png
--------------------------------------------------------------------------------
/media/slate_clevr6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/media/slate_clevr6.png
--------------------------------------------------------------------------------
/object_discovery/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/object_discovery/__init__.py
--------------------------------------------------------------------------------
/object_discovery/gnm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/object_discovery/gnm/__init__.py
--------------------------------------------------------------------------------
/object_discovery/gnm/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import json
3 | import io
4 | from argparse import Namespace
5 |
6 |
7 | def dict_to_ns(d):
8 | return Namespace(**d)
9 |
10 |
11 | CONFIG_YAML = """exp_name: ''
12 | data_dir: ''
13 | summary_dir: ''
14 | model_dir: ''
15 | last_ckpt: ''
16 | data:
17 | img_w: 128
18 | img_h: 128
19 | inp_channel: 3
20 | blender_dir_list_train: []
21 | blender_dir_list_test: []
22 | dataset: 'mnist'
23 | z:
24 | z_global_dim: 32
25 | z_what_dim: 64
26 | z_where_scale_dim: 2
27 | z_where_shift_dim: 2
28 | z_where_dim: 4
29 | z_pres_dim: 1
30 | z_depth_dim: 1
31 | z_local_dim: 64
32 | z_bg_dim: 10
33 | arch:
34 | glimpse_size: 64
35 | num_cell: 4
36 | phase_overlap: True
37 | phase_background: True
38 | img_enc_dim: 128
39 | p_global_decoder_type: 'MLP'
40 | draw_step: 4
41 | phase_graph_net_on_global_decoder: False
42 | phase_graph_net_on_global_encoder: False
43 | conv:
44 | img_encoder_filters: [16, 16, 32, 32, 64, 64, 128, 128, 128]
45 | img_encoder_groups: [1, 1, 1, 1, 1, 1, 1, 1, 1]
46 | img_encoder_strides: [2, 1, 2, 1, 2, 1, 2, 1, 2]
47 | img_encoder_kernel_sizes: [4, 3, 4, 3, 4, 3, 4, 3, 4]
48 | p_what_decoder_filters: [128, 64, 32, 16, 8, 4]
49 | p_what_decoder_kernel_sizes: [3, 3, 3, 3, 3, 3]
50 | p_what_decoder_upscales: [2, 2, 2, 2, 2, 2]
51 | p_what_decoder_groups: [1, 1, 1, 1, 1, 1]
52 | p_bg_decoder_filters: [128, 64, 32, 16, 8, 3]
53 | p_bg_decoder_kernel_sizes: [1, 1, 1, 1, 1, 3]
54 | p_bg_decoder_upscales: [4, 2, 4, 2, 2, 1]
55 | p_bg_decoder_groups: [1, 1, 1, 1, 1, 1]
56 | deconv:
57 | p_global_decoder_filters: [128, 128, 128]
58 | p_global_decoder_kernel_sizes: [1, 1, 1]
59 | p_global_decoder_upscales: [2, 1, 2]
60 | p_global_decoder_groups: [1, 1, 1]
61 | mlp:
62 | p_global_decoder_filters: [512, 1024, 2048]
63 | q_global_encoder_filters: [512, 512, 64]
64 | p_global_encoder_filters: [512, 512, 64]
65 | p_bg_generator_filters: [128, 64, 20]
66 | q_bg_encoder_filters: [512, 256, 20]
67 | pwdw:
68 | pwdw_filters: [128, 128]
69 | pwdw_kernel_sizes: [1, 1]
70 | pwdw_strides: [1, 1]
71 | pwdw_groups: [1, 1]
72 | structdraw:
73 | kernel_size: 1
74 | rnn_decoder_hid_dim: 128
75 | rnn_encoder_hid_dim: 128
76 | hid_to_dec_filters: [128]
77 | hid_to_dec_kernel_sizes: [3]
78 | hid_to_dec_strides: [1]
79 | hid_to_dec_groups: [1]
80 | log:
81 | num_summary_img: 15
82 | num_img_per_row: 5
83 | save_epoch_freq: 10
84 | print_step_freq: 2000
85 | num_sample: 50
86 | compute_nll_freq: 20
87 | phase_nll: False
88 | nll_num_sample: 30
89 | phase_log: True
90 | const:
91 | pres_logit_scale: 8.8
92 | scale_mean: -1.5
93 | scale_std: 0.1
94 | ratio_mean: 0
95 | ratio_std: 0.3
96 | shift_std: 1
97 | eps: 0.000000000000001
98 | likelihood_sigma: 0.2
99 | bg_likelihood_sigma: 0.3
100 | train:
101 | start_epoch: 0
102 | epoch: 600
103 | batch_size: 32
104 | lr: 0.0001
105 | cp: 1.0
106 | beta_global_anneal_start_step: 0
107 | beta_global_anneal_end_step: 100000
108 | beta_global_anneal_start_value: 0.
109 | beta_global_anneal_end_value: 1.
110 | beta_pres_anneal_start_step: 0
111 | beta_pres_anneal_end_step: 0
112 | beta_pres_anneal_start_value: 1.
113 | beta_pres_anneal_end_value: 0.
114 | beta_where_anneal_start_step: 0
115 | beta_where_anneal_end_step: 0
116 | beta_where_anneal_start_value: 1.
117 | beta_where_anneal_end_value: 0.
118 | beta_what_anneal_start_step: 0
119 | beta_what_anneal_end_step: 0
120 | beta_what_anneal_start_value: 1.
121 | beta_what_anneal_end_value: 0.
122 | beta_depth_anneal_start_step: 0
123 | beta_depth_anneal_end_step: 0
124 | beta_depth_anneal_start_value: 1.
125 | beta_depth_anneal_end_value: 0.
126 | beta_bg_anneal_start_step: 1000
127 | beta_bg_anneal_end_step: 0
128 | beta_bg_anneal_start_value: 1.
129 | beta_bg_anneal_end_value: 0.
130 | beta_aux_pres_anneal_start_step: 1000
131 | beta_aux_pres_anneal_end_step: 0
132 | beta_aux_pres_anneal_start_value: 1.
133 | beta_aux_pres_anneal_end_value: 0.
134 | beta_aux_where_anneal_start_step: 0
135 | beta_aux_where_anneal_end_step: 500
136 | beta_aux_where_anneal_start_value: 10.
137 | beta_aux_where_anneal_end_value: 1.
138 | beta_aux_what_anneal_start_step: 1000
139 | beta_aux_what_anneal_end_step: 0
140 | beta_aux_what_anneal_start_value: 1.
141 | beta_aux_what_anneal_end_value: 0.
142 | beta_aux_depth_anneal_start_step: 1000
143 | beta_aux_depth_anneal_end_step: 0
144 | beta_aux_depth_anneal_start_value: 1.
145 | beta_aux_depth_anneal_end_value: 0.
146 | beta_aux_global_anneal_start_step: 0
147 | beta_aux_global_anneal_end_step: 100000
148 | beta_aux_global_anneal_start_value: 0.
149 | beta_aux_global_anneal_end_value: 1.
150 | beta_aux_bg_anneal_start_step: 0
151 | beta_aux_bg_anneal_end_step: 50000
152 | beta_aux_bg_anneal_start_value: 50.
153 | beta_aux_bg_anneal_end_value: 1.
154 | tau_pres_anneal_start_step: 1000
155 | tau_pres_anneal_end_step: 20000
156 | tau_pres_anneal_start_value: 1.
157 | tau_pres_anneal_end_value: 0.5
158 | tau_pres: 1.
159 | p_pres_anneal_start_step: 0
160 | p_pres_anneal_end_step: 4000
161 | p_pres_anneal_start_value: 0.1
162 | p_pres_anneal_end_value: 0.001
163 | aux_p_scale_anneal_start_step: 0
164 | aux_p_scale_anneal_end_step: 0
165 | aux_p_scale_anneal_start_value: -1.5
166 | aux_p_scale_anneal_end_value: -1.5
167 | phase_bg_alpha_curriculum: True
168 | bg_alpha_curriculum_period: [0, 500]
169 | bg_alpha_curriculum_value: 0.9
170 | seed: 666
171 | """
172 |
173 |
174 | def get_arrow_args():
175 | arrow_args = json.loads(
176 | json.dumps(yaml.safe_load(io.StringIO(CONFIG_YAML))), object_hook=dict_to_ns
177 | )
178 | return arrow_args
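# Illustrative usage (not called in this file): the YAML above is parsed into
# nested argparse.Namespace objects, so values are read with attribute access:
#   args = get_arrow_args()
#   args.data.img_w         # -> 128
#   args.arch.num_cell      # -> 4
#   args.train.batch_size   # -> 32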
179 |
--------------------------------------------------------------------------------
/object_discovery/gnm/gnm_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .module import (
4 | LocalLatentDecoder,
5 | LocalSampler,
6 | StructDRAW,
7 | BgDecoder,
8 | BgGenerator,
9 | BgEncoder,
10 | )
11 | from .submodule import StackConvNorm
12 | import torch.distributions as dist
13 | from .utils import (
14 | linear_schedule_tensor,
15 | spatial_transform,
16 | kl_divergence_bern_bern,
17 | linear_schedule,
18 | visualize,
19 | )
20 | from typing import List, Tuple
21 |
22 |
23 | class GNM(nn.Module):
24 | shortname = "gnm"
25 |
26 | def __init__(self, args):
27 | super(GNM, self).__init__()
28 | self.args = args
29 |
30 | self.img_encoder = StackConvNorm(
31 | self.args.data.inp_channel,
32 | self.args.arch.conv.img_encoder_filters,
33 | self.args.arch.conv.img_encoder_kernel_sizes,
34 | self.args.arch.conv.img_encoder_strides,
35 | self.args.arch.conv.img_encoder_groups,
36 | norm_act_final=True,
37 | )
38 | self.global_struct_draw = StructDRAW(self.args)
39 | self.p_z_given_x_or_g_net = LocalLatentDecoder(self.args)
40 | # Share latent decoder for p and q
41 | self.local_latent_sampler = LocalSampler(self.args)
42 |
43 | if self.args.arch.phase_background:
44 | self.p_bg_decoder = BgDecoder(self.args)
45 | self.p_bg_given_g_net = BgGenerator(self.args)
46 | self.q_bg_given_x_net = BgEncoder(self.args)
47 |
48 | self.register_buffer("aux_p_what_mean", torch.zeros(1))
49 | self.register_buffer("aux_p_what_std", torch.ones(1))
50 | self.register_buffer("aux_p_bg_mean", torch.zeros(1))
51 | self.register_buffer("aux_p_bg_std", torch.ones(1))
52 | self.register_buffer("aux_p_depth_mean", torch.zeros(1))
53 | self.register_buffer("aux_p_depth_std", torch.ones(1))
54 | self.register_buffer(
55 | "aux_p_where_mean",
56 | torch.tensor(
57 | [self.args.const.scale_mean, self.args.const.ratio_mean, 0, 0]
58 | )[None, :],
59 | )
60 | # self.register_buffer('auxiliary_where_std', torch.ones(1))
61 | self.register_buffer(
62 | "aux_p_where_std",
63 | torch.tensor(
64 | [
65 | self.args.const.scale_std,
66 | self.args.const.ratio_std,
67 | self.args.const.shift_std,
68 | self.args.const.shift_std,
69 | ]
70 | )[None, :],
71 | )
72 | self.register_buffer(
73 | "aux_p_pres_probs", torch.tensor(self.args.train.p_pres_anneal_start_value)
74 | )
75 | self.register_buffer(
76 | "background",
77 | torch.zeros(
78 | 1,
79 | self.args.data.inp_channel,
80 | self.args.data.img_h,
81 | self.args.data.img_w,
82 | ),
83 | )
84 |
85 | @property
86 | def aux_p_what(self):
87 | return dist.Normal(self.aux_p_what_mean, self.aux_p_what_std)
88 |
89 | @property
90 | def aux_p_bg(self):
91 | return dist.Normal(self.aux_p_bg_mean, self.aux_p_bg_std)
92 |
93 | @property
94 | def aux_p_depth(self):
95 | return dist.Normal(self.aux_p_depth_mean, self.aux_p_depth_std)
96 |
97 | @property
98 | def aux_p_where(self):
99 | return dist.Normal(self.aux_p_where_mean, self.aux_p_where_std)
100 |
101 | def forward(self, x: torch.Tensor, global_step, generate_bbox=False):
102 | self.args = hyperparam_anneal(self.args, global_step)
103 | bs = x.size(0)
104 |
105 | img_enc = self.img_encoder(x)
106 | if self.args.arch.phase_background:
107 | lv_q_bg, ss_q_bg = self.q_bg_given_x_net(img_enc)
108 | q_bg_mean, q_bg_std = ss_q_bg
109 | else:
110 | lv_q_bg = [self.background.new_zeros(1, 1)]
111 | q_bg_mean = self.background.new_zeros(1, 1)
112 | q_bg_std = self.background.new_ones(1, 1)
113 | ss_q_bg = [q_bg_mean, q_bg_std]
114 |
115 | q_bg = dist.Normal(q_bg_mean, q_bg_std)
116 |
117 | pa_g, lv_g, ss_g = self.global_struct_draw(img_enc)
118 |
119 | global_dec = pa_g[0]
120 |
121 | p_global_mean_all, p_global_std_all, q_global_mean_all, q_global_std_all = ss_g
122 |
123 | p_global_all = dist.Normal(p_global_mean_all, p_global_std_all)
124 |
125 | q_global_all = dist.Normal(q_global_mean_all, q_global_std_all)
126 |
127 | ss_p_z = self.p_z_given_x_or_g_net(global_dec)
128 |
129 | # (bs, dim, num_cell, num_cell)
130 | (
131 | p_pres_logits,
132 | p_where_mean,
133 | p_where_std,
134 | p_depth_mean,
135 | p_depth_std,
136 | p_what_mean,
137 | p_what_std,
138 | ) = ss_p_z
139 |
140 | p_pres_given_g_probs_reshaped = torch.sigmoid(
141 | p_pres_logits.permute(0, 2, 3, 1).reshape(
142 | bs * self.args.arch.num_cell ** 2, -1
143 | )
144 | )
145 |
146 | p_where_given_g = dist.Normal(
147 | p_where_mean.permute(0, 2, 3, 1).reshape(
148 | bs * self.args.arch.num_cell ** 2, -1
149 | ),
150 | p_where_std.permute(0, 2, 3, 1).reshape(
151 | bs * self.args.arch.num_cell ** 2, -1
152 | ),
153 | )
154 | p_depth_given_g = dist.Normal(
155 | p_depth_mean.permute(0, 2, 3, 1).reshape(
156 | bs * self.args.arch.num_cell ** 2, -1
157 | ),
158 | p_depth_std.permute(0, 2, 3, 1).reshape(
159 | bs * self.args.arch.num_cell ** 2, -1
160 | ),
161 | )
162 | p_what_given_g = dist.Normal(
163 | p_what_mean.permute(0, 2, 3, 1).reshape(
164 | bs * self.args.arch.num_cell ** 2, -1
165 | ),
166 | p_what_std.permute(0, 2, 3, 1).reshape(
167 | bs * self.args.arch.num_cell ** 2, -1
168 | ),
169 | )
170 |
171 | ss_q_z = self.p_z_given_x_or_g_net(img_enc, ss_p_z=ss_p_z)
172 |
173 | # (bs, dim, num_cell, num_cell)
174 | (
175 | q_pres_logits,
176 | q_where_mean,
177 | q_where_std,
178 | q_depth_mean,
179 | q_depth_std,
180 | q_what_mean,
181 | q_what_std,
182 | ) = ss_q_z
183 |
184 | q_pres_given_x_and_g_probs_reshaped = torch.sigmoid(
185 | q_pres_logits.permute(0, 2, 3, 1).reshape(
186 | bs * self.args.arch.num_cell ** 2, -1
187 | )
188 | )
189 |
190 | q_where_given_x_and_g = dist.Normal(
191 | q_where_mean.permute(0, 2, 3, 1).reshape(
192 | bs * self.args.arch.num_cell ** 2, -1
193 | ),
194 | q_where_std.permute(0, 2, 3, 1).reshape(
195 | bs * self.args.arch.num_cell ** 2, -1
196 | ),
197 | )
198 | q_depth_given_x_and_g = dist.Normal(
199 | q_depth_mean.permute(0, 2, 3, 1).reshape(
200 | bs * self.args.arch.num_cell ** 2, -1
201 | ),
202 | q_depth_std.permute(0, 2, 3, 1).reshape(
203 | bs * self.args.arch.num_cell ** 2, -1
204 | ),
205 | )
206 | q_what_given_x_and_g = dist.Normal(
207 | q_what_mean.permute(0, 2, 3, 1).reshape(
208 | bs * self.args.arch.num_cell ** 2, -1
209 | ),
210 | q_what_std.permute(0, 2, 3, 1).reshape(
211 | bs * self.args.arch.num_cell ** 2, -1
212 | ),
213 | )
214 |
215 | if self.args.arch.phase_background:
216 | # lv_p_bg, ss_p_bg = self.ss_p_bg_given_g(lv_g)
217 | lv_p_bg, ss_p_bg = self.p_bg_given_g_net(lv_g[0], phase_use_mode=False)
218 | p_bg_mean, p_bg_std = ss_p_bg
219 | else:
220 | lv_p_bg = [self.background.new_zeros(1, 1)]
221 | p_bg_mean = self.background.new_zeros(1, 1)
222 | p_bg_std = self.background.new_ones(1, 1)
223 |
224 | p_bg = dist.Normal(p_bg_mean, p_bg_std)
225 |
226 | pa_recon, lv_z = self.lv_p_x_given_z_and_bg(ss_q_z, lv_q_bg, global_step)
227 | *pa_recon, patches, masks = pa_recon
228 | canvas = pa_recon[0]
229 | background = pa_recon[-1]
230 |
231 | z_pres, z_where, z_depth, z_what, z_where_origin = lv_z
232 |
233 | p_dists = [
234 | p_global_all,
235 | p_pres_given_g_probs_reshaped,
236 | p_where_given_g,
237 | p_depth_given_g,
238 | p_what_given_g,
239 | p_bg,
240 | ]
241 |
242 | q_dists = [
243 | q_global_all,
244 | q_pres_given_x_and_g_probs_reshaped,
245 | q_where_given_x_and_g,
246 | q_depth_given_x_and_g,
247 | q_what_given_x_and_g,
248 | q_bg,
249 | ]
250 |
251 | log_like, kl, log_imp = self.elbo(
252 | x, p_dists, q_dists, lv_z, lv_g, lv_q_bg, pa_recon, global_step
253 | )
254 |
255 | self.log = {}
256 |
257 | if self.args.log.phase_log:
258 | pa_recon_from_q_g, _ = self.get_recon_from_q_g(
259 | global_step, global_dec=global_dec, lv_g=lv_g
260 | )
261 |
262 | z_pres_permute = z_pres.permute(0, 2, 3, 1)
263 | self.log = {
264 | "z_what": z_what.permute(0, 2, 3, 1).reshape(
265 | -1, self.args.z.z_what_dim
266 | ),
267 | "z_where_scale": z_where.permute(0, 2, 3, 1).reshape(
268 | -1, self.args.z.z_where_dim
269 | )[:, : self.args.z.z_where_scale_dim],
270 | "z_where_shift": z_where.permute(0, 2, 3, 1).reshape(
271 | -1, self.args.z.z_where_dim
272 | )[:, self.args.z.z_where_scale_dim :],
273 | "z_where_origin": z_where_origin.permute(0, 2, 3, 1).reshape(
274 | -1, self.args.z.z_where_dim
275 | ),
276 | "z_pres": z_pres_permute,
277 | "p_pres_probs": p_pres_given_g_probs_reshaped,
278 | "p_pres_logits": p_pres_logits,
279 | "p_what_std": p_what_std.permute(0, 2, 3, 1).reshape(
280 | -1, self.args.z.z_what_dim
281 | )[z_pres_permute.view(-1) > 0.05],
282 | "p_what_mean": p_what_mean.permute(0, 2, 3, 1).reshape(
283 | -1, self.args.z.z_what_dim
284 | )[z_pres_permute.view(-1) > 0.05],
285 | "p_where_scale_std": p_where_std.permute(0, 2, 3, 1).reshape(
286 | -1, self.args.z.z_where_dim
287 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim],
288 | "p_where_scale_mean": p_where_mean.permute(0, 2, 3, 1).reshape(
289 | -1, self.args.z.z_where_dim
290 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim],
291 | "p_where_shift_std": p_where_std.permute(0, 2, 3, 1).reshape(
292 | -1, self.args.z.z_where_dim
293 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :],
294 | "p_where_shift_mean": p_where_mean.permute(0, 2, 3, 1).reshape(
295 | -1, self.args.z.z_where_dim
296 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :],
297 | "q_pres_probs": q_pres_given_x_and_g_probs_reshaped,
298 | "q_pres_logits": q_pres_logits,
299 | "q_what_std": q_what_std.permute(0, 2, 3, 1).reshape(
300 | -1, self.args.z.z_what_dim
301 | )[z_pres_permute.view(-1) > 0.05],
302 | "q_what_mean": q_what_mean.permute(0, 2, 3, 1).reshape(
303 | -1, self.args.z.z_what_dim
304 | )[z_pres_permute.view(-1) > 0.05],
305 | "q_where_scale_std": q_where_std.permute(0, 2, 3, 1).reshape(
306 | -1, self.args.z.z_where_dim
307 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim],
308 | "q_where_scale_mean": q_where_mean.permute(0, 2, 3, 1).reshape(
309 | -1, self.args.z.z_where_dim
310 | )[z_pres_permute.view(-1) > 0.05][:, : self.args.z.z_where_scale_dim],
311 | "q_where_shift_std": q_where_std.permute(0, 2, 3, 1).reshape(
312 | -1, self.args.z.z_where_dim
313 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :],
314 | "q_where_shift_mean": q_where_mean.permute(0, 2, 3, 1).reshape(
315 | -1, self.args.z.z_where_dim
316 | )[z_pres_permute.view(-1) > 0.05][:, self.args.z.z_where_scale_dim :],
317 | "z_depth": z_depth.permute(0, 2, 3, 1).reshape(
318 | -1, self.args.z.z_depth_dim
319 | ),
320 | "p_depth_std": p_depth_std.permute(0, 2, 3, 1).reshape(
321 | -1, self.args.z.z_depth_dim
322 | )[z_pres_permute.view(-1) > 0.05],
323 | "p_depth_mean": p_depth_mean.permute(0, 2, 3, 1).reshape(
324 | -1, self.args.z.z_depth_dim
325 | )[z_pres_permute.view(-1) > 0.05],
326 | "q_depth_std": q_depth_std.permute(0, 2, 3, 1).reshape(
327 | -1, self.args.z.z_depth_dim
328 | )[z_pres_permute.view(-1) > 0.05],
329 | "q_depth_mean": q_depth_mean.permute(0, 2, 3, 1).reshape(
330 | -1, self.args.z.z_depth_dim
331 | )[z_pres_permute.view(-1) > 0.05],
332 | "recon": pa_recon[0],
333 | "recon_from_q_g": pa_recon_from_q_g[0],
334 | "log_prob_x_given_g": dist.Normal(
335 | pa_recon_from_q_g[0], self.args.const.likelihood_sigma
336 | )
337 | .log_prob(x)
338 | .flatten(start_dim=1)
339 | .sum(1),
340 | "global_dec": global_dec,
341 | }
342 | z_global_all = lv_g[0]
343 | for i in range(self.args.arch.draw_step):
344 | self.log[f"z_global_step_{i}"] = z_global_all[:, i]
345 | self.log[f"q_global_mean_step_{i}"] = q_global_mean_all[:, i]
346 | self.log[f"q_global_std_step_{i}"] = q_global_std_all[:, i]
347 | self.log[f"p_global_mean_step_{i}"] = p_global_mean_all[:, i]
348 | self.log[f"p_global_std_step_{i}"] = p_global_std_all[:, i]
349 | if self.args.arch.phase_background:
350 | self.log["z_bg"] = lv_q_bg[0]
351 | self.log["p_bg_mean"] = p_bg_mean
352 | self.log["p_bg_std"] = p_bg_std
353 | self.log["q_bg_mean"] = q_bg_mean
354 | self.log["q_bg_std"] = q_bg_std
355 | self.log["recon_from_q_g_bg"] = pa_recon_from_q_g[-1]
356 | self.log["recon_from_q_g_fg"] = pa_recon_from_q_g[1]
357 | self.log["recon_from_q_g_alpha"] = pa_recon_from_q_g[2]
358 | self.log["recon_bg"] = pa_recon[-1]
359 | self.log["recon_fg"] = pa_recon[1]
360 | self.log["recon_alpha"] = pa_recon[2]
361 |
362 | ss = [ss_q_z, ss_q_bg, ss_g[2:]]
363 | (
364 | aux_kl_pres,
365 | aux_kl_where,
366 | aux_kl_depth,
367 | aux_kl_what,
368 | aux_kl_bg,
369 | kl_pres,
370 | kl_where,
371 | kl_depth,
372 | kl_what,
373 | kl_global_all,
374 | kl_bg,
375 | ) = kl
376 |
377 | aux_kl_pres_raw = aux_kl_pres.mean(dim=0)
378 | aux_kl_where_raw = aux_kl_where.mean(dim=0)
379 | aux_kl_depth_raw = aux_kl_depth.mean(dim=0)
380 | aux_kl_what_raw = aux_kl_what.mean(dim=0)
381 | aux_kl_bg_raw = aux_kl_bg.mean(dim=0)
382 | kl_pres_raw = kl_pres.mean(dim=0)
383 | kl_where_raw = kl_where.mean(dim=0)
384 | kl_depth_raw = kl_depth.mean(dim=0)
385 | kl_what_raw = kl_what.mean(dim=0)
386 | kl_bg_raw = kl_bg.mean(dim=0)
387 |
388 | log_like = log_like.mean(dim=0)
389 |
390 | aux_kl_pres = aux_kl_pres_raw * self.args.train.beta_aux_pres
391 | aux_kl_where = aux_kl_where_raw * self.args.train.beta_aux_where
392 | aux_kl_depth = aux_kl_depth_raw * self.args.train.beta_aux_depth
393 | aux_kl_what = aux_kl_what_raw * self.args.train.beta_aux_what
394 | aux_kl_bg = aux_kl_bg_raw * self.args.train.beta_aux_bg
395 | kl_pres = kl_pres_raw * self.args.train.beta_pres
396 | kl_where = kl_where_raw * self.args.train.beta_where
397 | kl_depth = kl_depth_raw * self.args.train.beta_depth
398 | kl_what = kl_what_raw * self.args.train.beta_what
399 | kl_bg = kl_bg_raw * self.args.train.beta_bg
400 |
401 | kl_global_raw = kl_global_all.sum(dim=-1).mean(dim=0)
402 | kl_global = kl_global_raw * self.args.train.beta_global
403 |
404 | recon_loss = log_like
405 | kl = (
406 | kl_pres
407 | + kl_where
408 | + kl_depth
409 | + kl_what
410 | + kl_bg
411 | + kl_global
412 | + aux_kl_pres
413 | + aux_kl_where
414 | + aux_kl_depth
415 | + aux_kl_what
416 | + aux_kl_bg
417 | )
418 | elbo = recon_loss - kl
419 | loss = -elbo
420 |
421 | bbox = None
422 | if (not self.training) and generate_bbox:
423 | bbox = visualize(
424 | x,
425 | self.log["z_pres"].view(bs, self.args.arch.num_cell ** 2, -1),
426 | self.log["z_where_scale"].view(bs, self.args.arch.num_cell ** 2, -1),
427 | self.log["z_where_shift"].view(bs, self.args.arch.num_cell ** 2, -1),
428 | )
429 |
430 | bbox = (
431 | bbox.view(x.shape[0], -1, 3, self.args.data.img_h, self.args.data.img_w)
432 | .sum(1)
433 | .clamp(0.0, 1.0)
434 | )
435 | # bbox_img = x.cpu().expand(-1, 3, -1, -1).contiguous()
436 | # bbox_img[bbox.sum(dim=1, keepdim=True).expand(-1, 3, -1, -1) > 0.5] = \
437 | # bbox[bbox.sum(dim=1, keepdim=True).expand(-1, 3, -1, -1) > 0.5]
438 | ret = {
439 | "canvas": canvas,
440 | "canvas_with_bbox": bbox,
441 | "background": background,
442 | "steps": {
443 | "patch": patches,
444 | "mask": masks,
445 | "z_pres": z_pres.view(bs, self.args.arch.num_cell ** 2, -1),
446 | },
447 | "counts": torch.round(z_pres).flatten(1).sum(-1),
448 | "loss": loss,
449 | "elbo": elbo,
450 | "kl": kl,
451 | "rec_loss": recon_loss,
452 | "kl_pres": kl_pres,
453 | "kl_aux_pres": aux_kl_pres,
454 | "kl_where": kl_where,
455 | "kl_aux_where": aux_kl_where,
456 | "kl_what": kl_what,
457 | "kl_aux_what": aux_kl_what,
458 | "kl_depth": kl_depth,
459 | "kl_aux_depth": aux_kl_depth,
460 | "kl_bg": kl_bg,
461 | "kl_aux_bg": aux_kl_bg,
462 | "kl_global": kl_global,
463 | }
464 |
465 | # return pa_recon, log_like, kl, log_imp, lv_z + lv_g + lv_q_bg, ss, self.log
466 | return ret
467 |
468 | def get_recon_from_q_g(
469 | self,
470 | global_step,
471 | img: torch.Tensor = None,
472 | global_dec: torch.Tensor = None,
473 | lv_g: List = None,
474 | phase_use_mode: bool = False,
475 | ) -> Tuple:
476 |
477 | assert img is not None or (
478 | global_dec is not None and lv_g is not None
479 |         ), "Provide either img, or both global_dec and lv_g"
480 | if img is not None:
481 | img_enc = self.img_encoder(img)
482 | pa_g, lv_g, ss_g = self.global_struct_draw(img_enc)
483 |
484 | global_dec = pa_g[0]
485 |
486 | if self.args.arch.phase_background:
487 | lv_p_bg, _ = self.p_bg_given_g_net(lv_g[0], phase_use_mode=phase_use_mode)
488 | else:
489 | lv_p_bg = [self.background.new_zeros(1, 1)]
490 |
491 | ss_z = self.p_z_given_x_or_g_net(global_dec)
492 |
493 | pa, lv = self.lv_p_x_given_z_and_bg(
494 | ss_z, lv_p_bg, global_step, phase_use_mode=phase_use_mode
495 | )
496 |
497 | lv = lv + lv_p_bg
498 |
499 | return pa, lv
500 |
501 | def elbo(
502 | self,
503 | x: torch.Tensor,
504 | p_dists: List,
505 | q_dists: List,
506 | lv_z: List,
507 | lv_g: List,
508 | lv_bg: List,
509 | pa_recon: List,
510 | global_step,
511 | ) -> Tuple:
512 |
513 | bs = x.size(0)
514 |
515 | (
516 | p_global_all,
517 | p_pres_given_g_probs_reshaped,
518 | p_where_given_g,
519 | p_depth_given_g,
520 | p_what_given_g,
521 | p_bg,
522 | ) = p_dists
523 |
524 | (
525 | q_global_all,
526 | q_pres_given_x_and_g_probs_reshaped,
527 | q_where_given_x_and_g,
528 | q_depth_given_x_and_g,
529 | q_what_given_x_and_g,
530 | q_bg,
531 | ) = q_dists
532 |
533 | y, y_nobg, alpha_map, bg = pa_recon
534 |
535 | if self.args.log.phase_nll:
536 | # (bs, dim, num_cell, num_cell)
537 | z_pres, _, z_depth, z_what, z_where_origin = lv_z
538 | # (bs * num_cell * num_cell, dim)
539 | z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(
540 | -1, self.args.z.z_pres_dim
541 | )
542 | z_depth_reshape = z_depth.permute(0, 2, 3, 1).reshape(
543 | -1, self.args.z.z_depth_dim
544 | )
545 | z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(
546 | -1, self.args.z.z_what_dim
547 | )
548 | z_where_origin_reshape = z_where_origin.permute(0, 2, 3, 1).reshape(
549 | -1, self.args.z.z_where_dim
550 | )
551 | # (bs, dim, 1, 1)
552 | z_bg = lv_bg[0]
553 | # (bs, step, dim, 1, 1)
554 | z_g = lv_g[0]
555 | else:
556 | z_pres, _, _, _, z_where_origin = lv_z
557 |
558 | z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(
559 | -1, self.args.z.z_pres_dim
560 | )
561 |
562 | if self.args.train.p_pres_anneal_end_step != 0:
563 | self.aux_p_pres_probs = linear_schedule_tensor(
564 | global_step,
565 | self.args.train.p_pres_anneal_start_step,
566 | self.args.train.p_pres_anneal_end_step,
567 | self.args.train.p_pres_anneal_start_value,
568 | self.args.train.p_pres_anneal_end_value,
569 | self.aux_p_pres_probs.device,
570 | )
571 |
572 | if self.args.train.aux_p_scale_anneal_end_step != 0:
573 | aux_p_scale_mean = linear_schedule_tensor(
574 | global_step,
575 | self.args.train.aux_p_scale_anneal_start_step,
576 | self.args.train.aux_p_scale_anneal_end_step,
577 | self.args.train.aux_p_scale_anneal_start_value,
578 | self.args.train.aux_p_scale_anneal_end_value,
579 | self.aux_p_where_mean.device,
580 | )
581 | self.aux_p_where_mean[:, 0] = aux_p_scale_mean
582 |
583 | auxiliary_prior_z_pres_probs = self.aux_p_pres_probs[None][None, :].expand(
584 | bs * self.args.arch.num_cell ** 2, -1
585 | )
586 |
587 | aux_kl_pres = kl_divergence_bern_bern(
588 | q_pres_given_x_and_g_probs_reshaped, auxiliary_prior_z_pres_probs
589 | )
590 | aux_kl_where = dist.kl_divergence(
591 | q_where_given_x_and_g, self.aux_p_where
592 | ) * z_pres_reshape.clamp(min=1e-5)
593 | aux_kl_depth = dist.kl_divergence(
594 | q_depth_given_x_and_g, self.aux_p_depth
595 | ) * z_pres_reshape.clamp(min=1e-5)
596 | aux_kl_what = dist.kl_divergence(
597 | q_what_given_x_and_g, self.aux_p_what
598 | ) * z_pres_reshape.clamp(min=1e-5)
599 |
600 | kl_pres = kl_divergence_bern_bern(
601 | q_pres_given_x_and_g_probs_reshaped, p_pres_given_g_probs_reshaped
602 | )
603 |
604 | kl_where = dist.kl_divergence(q_where_given_x_and_g, p_where_given_g)
605 | kl_depth = dist.kl_divergence(q_depth_given_x_and_g, p_depth_given_g)
606 | kl_what = dist.kl_divergence(q_what_given_x_and_g, p_what_given_g)
607 |
608 | kl_global_all = dist.kl_divergence(q_global_all, p_global_all)
609 |
610 | if self.args.arch.phase_background:
611 | kl_bg = dist.kl_divergence(q_bg, p_bg)
612 | aux_kl_bg = dist.kl_divergence(q_bg, self.aux_p_bg)
613 | else:
614 | kl_bg = self.background.new_zeros(bs, 1)
615 | aux_kl_bg = self.background.new_zeros(bs, 1)
616 |
617 | log_like = dist.Normal(y, self.args.const.likelihood_sigma).log_prob(x)
618 |
619 | log_imp_list = []
620 | if self.args.log.phase_nll:
621 | log_pres_prior = z_pres_reshape * torch.log(
622 | p_pres_given_g_probs_reshaped + self.args.const.eps
623 | ) + (1 - z_pres_reshape) * torch.log(
624 | 1 - p_pres_given_g_probs_reshaped + self.args.const.eps
625 | )
626 | log_pres_pos = z_pres_reshape * torch.log(
627 | q_pres_given_x_and_g_probs_reshaped + self.args.const.eps
628 | ) + (1 - z_pres_reshape) * torch.log(
629 | 1 - q_pres_given_x_and_g_probs_reshaped + self.args.const.eps
630 | )
631 |
632 | log_imp_pres = log_pres_prior - log_pres_pos
633 |
634 | log_imp_depth = p_depth_given_g.log_prob(
635 | z_depth_reshape
636 | ) - q_depth_given_x_and_g.log_prob(z_depth_reshape)
637 |
638 | log_imp_what = p_what_given_g.log_prob(
639 | z_what_reshape
640 | ) - q_what_given_x_and_g.log_prob(z_what_reshape)
641 |
642 | log_imp_where = p_where_given_g.log_prob(
643 | z_where_origin_reshape
644 | ) - q_where_given_x_and_g.log_prob(z_where_origin_reshape)
645 |
646 | if self.args.arch.phase_background:
647 | log_imp_bg = p_bg.log_prob(z_bg) - q_bg.log_prob(z_bg)
648 | else:
649 | log_imp_bg = x.new_zeros(bs, 1)
650 |
651 | log_imp_g = p_global_all.log_prob(z_g) - q_global_all.log_prob(z_g)
652 |
653 | log_imp_list = [
654 | log_imp_pres.view(
655 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
656 | )
657 | .flatten(start_dim=1)
658 | .sum(1),
659 | log_imp_depth.view(
660 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
661 | )
662 | .flatten(start_dim=1)
663 | .sum(1),
664 | log_imp_what.view(
665 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
666 | )
667 | .flatten(start_dim=1)
668 | .sum(1),
669 | log_imp_where.view(
670 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
671 | )
672 | .flatten(start_dim=1)
673 | .sum(1),
674 | log_imp_bg.flatten(start_dim=1).sum(1),
675 | log_imp_g.flatten(start_dim=1).sum(1),
676 | ]
677 |
678 | return (
679 | log_like.flatten(start_dim=1).sum(1),
680 | [
681 | aux_kl_pres.view(
682 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
683 | )
684 | .flatten(start_dim=1)
685 | .sum(-1),
686 | aux_kl_where.view(
687 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
688 | )
689 | .flatten(start_dim=1)
690 | .sum(-1),
691 | aux_kl_depth.view(
692 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
693 | )
694 | .flatten(start_dim=1)
695 | .sum(-1),
696 | aux_kl_what.view(
697 | bs, self.args.arch.num_cell, self.args.arch.num_cell, -1
698 | )
699 | .flatten(start_dim=1)
700 | .sum(-1),
701 | aux_kl_bg.flatten(start_dim=1).sum(-1),
702 | kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1)
703 | .flatten(start_dim=1)
704 | .sum(-1),
705 | kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1)
706 | .flatten(start_dim=1)
707 | .sum(-1),
708 | kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1)
709 | .flatten(start_dim=1)
710 | .sum(-1),
711 | kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1)
712 | .flatten(start_dim=1)
713 | .sum(-1),
714 | kl_global_all.flatten(start_dim=2).sum(-1),
715 | kl_bg.flatten(start_dim=1).sum(-1),
716 | ],
717 | log_imp_list,
718 | )
719 |
720 | # def get_img_enc(self, x: torch.Tensor) -> torch.Tensor:
721 | # """
722 | # :param x: (bs, inp_channel, img_h, img_w)
723 | # :return: img_enc: (bs, dim, num_cell, num_cell)
724 | # """
725 | #
726 | # img_enc = self.img_encoder(x)
727 | #
728 | # return img_enc
729 |
730 | # def ss_p_z_given_g(self, global_dec: torch.Tensor) -> List:
731 | # """
732 | # :param x: sample of z_global variable (bs, dim, 1, 1)
733 | # :return:
734 | # """
735 | # ss_z = self.p_z_given_g_net(global_dec)
736 | #
737 | # return ss_z
738 |
739 | # def ss_q_z_given_x(self, img_enc: torch.Tensor, global_dec: torch.Tensor, ss_p_z: List) -> List:
740 | # """
741 | # :param x: sample of z_global variable (bs, dim, 1, 1)
742 | # :return:
743 | # """
744 | # ss_z = self.p_z_given_x_or_g_net(img_enc, ss_p_z=ss_p_z)
745 | #
746 | # return ss_z
747 |
748 | # def ss_q_bg_given_x(self, x: torch.Tensor) -> Tuple:
749 | # """
750 | # :param x: (bs, dim, img_h, img_w)
751 | # :return:
752 | # """
753 | # lv_q_bg, ss_q_bg = self.q_bg_given_x_net(x)
754 | #
755 | # return lv_q_bg, ss_q_bg
756 |
757 | # def ss_p_bg_given_g(self, lv_g: List, phase_use_mode: bool = False) -> Tuple:
758 | # """
759 | # :param x: (bs, dim, img_h, img_w)
760 | # :return:
761 | # """
762 | # z_global_all = lv_g[0]
763 | # lv_p_bg, ss_p_bg = self.p_bg_given_g_net(z_global_all, phase_use_mode=phase_use_mode)
764 | #
765 | # return lv_p_bg, ss_p_bg
766 |
767 | def lv_p_x_given_z_and_bg(
768 | self, ss: List, lv_bg: List, global_step, phase_use_mode: bool = False
769 | ) -> Tuple:
770 | """
771 |         :param ss: local latent sufficient statistics; lv_bg[0] is z_bg with shape (bs, z_bg_dim, 1, 1)
772 |         :return: pa = [y, y_fg, alpha_map, y_bg, patches, masks] and lv_z
773 | """
774 | # x: (bs, inp_channel, img_h, img_w)
775 | pa, lv_z = self.local_latent_sampler(ss, phase_use_mode=phase_use_mode)
776 |
777 | o_att, a_att, *_ = pa
778 | z_pres, z_where, z_depth, *_ = lv_z
779 |
780 | if self.args.arch.phase_background:
781 | z_bg = lv_bg[0]
782 | pa_bg = self.p_bg_decoder(z_bg)
783 | y_bg = pa_bg[0]
784 | else:
785 | # pa_bg = [self.background.expand(lv_z[0].size(0), -1, -1, -1)]
786 | y_bg = self.background.expand(lv_z[0].size(0), -1, -1, -1)
787 |
788 | # pa = pa + pa_bg
789 |
790 | y, y_fg, alpha_map, patches, masks = self.render(
791 | o_att, a_att, y_bg, z_pres, z_where, z_depth, global_step
792 | )
793 |
794 | return [y, y_fg, alpha_map, y_bg, patches, masks], lv_z
795 |
796 | # def pa_bg_given_z_bg(self, lv_bg: List) -> List:
797 | # """
798 | # :param lv_bg[0]: (bs, z_bg_dim, 1, 1)
799 | # :return:
800 | # """
801 | # z_bg = lv_bg[0]
802 | # pa = self.p_bg_decoder(z_bg)
803 | #
804 | # return pa
805 |
806 | def render(self, o_att, a_att, bg, z_pres, z_where, z_depth, global_step) -> List:
807 | """
808 |         :param o_att, a_att: glimpse appearance and alpha, (bs * num_cell * num_cell, dim, glimpse_size, glimpse_size)
809 |         :param z_pres, z_where, z_depth: local latents with size (bs, dim, num_cell, num_cell)
810 |         :return: y, y_nobg, alpha_map, o_att_full_res, a_att_hat_full_res
811 | """
812 |
813 | bs = z_pres.size(0)
814 |
815 | z_pres = z_pres.permute(0, 2, 3, 1).reshape(
816 | bs * self.args.arch.num_cell ** 2, -1
817 | )
818 | z_where = z_where.permute(0, 2, 3, 1).reshape(
819 | bs * self.args.arch.num_cell ** 2, -1
820 | )
821 | z_depth = z_depth.permute(0, 2, 3, 1).reshape(
822 | bs * self.args.arch.num_cell ** 2, -1
823 | )
824 |
825 |         if self.args.arch.phase_overlap:
826 | if (
827 | self.args.train.phase_bg_alpha_curriculum
828 | and self.args.train.bg_alpha_curriculum_period[0]
829 | < global_step
830 | < self.args.train.bg_alpha_curriculum_period[1]
831 | ):
832 | z_pres = z_pres.clamp(max=0.99)
833 | a_att_hat = a_att * z_pres.view(-1, 1, 1, 1)
834 | y_att = a_att_hat * o_att
835 |
836 | # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 3, img_h, img_w)
837 | y_att_full_res = spatial_transform(
838 | y_att,
839 | z_where,
840 | (
841 | bs * self.args.arch.num_cell ** 2,
842 | self.args.data.inp_channel,
843 | self.args.data.img_h,
844 | self.args.data.img_w,
845 | ),
846 | inverse=True,
847 | ).view(
848 | -1,
849 | self.args.arch.num_cell * self.args.arch.num_cell,
850 | self.args.data.inp_channel,
851 | self.args.data.img_h,
852 | self.args.data.img_w,
853 | )
854 | o_att_full_res = spatial_transform(
855 | o_att,
856 | z_where,
857 | (
858 | bs * self.args.arch.num_cell ** 2,
859 | self.args.data.inp_channel,
860 | self.args.data.img_h,
861 | self.args.data.img_w,
862 | ),
863 | inverse=True,
864 | ).view(
865 | -1,
866 | self.args.arch.num_cell * self.args.arch.num_cell,
867 | self.args.data.inp_channel,
868 | self.args.data.img_h,
869 | self.args.data.img_w,
870 | )
871 |
872 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, glimpse_size, glimpse_size)
873 | importance_map = a_att_hat * torch.sigmoid(-z_depth).view(-1, 1, 1, 1)
874 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, img_h, img_w)
875 | importance_map_full_res = spatial_transform(
876 | importance_map,
877 | z_where,
878 | (
879 | self.args.arch.num_cell * self.args.arch.num_cell * bs,
880 | 1,
881 | self.args.data.img_h,
882 | self.args.data.img_w,
883 | ),
884 | inverse=True,
885 | )
886 | # # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w)
887 | importance_map_full_res = importance_map_full_res.view(
888 | -1,
889 | self.args.arch.num_cell * self.args.arch.num_cell,
890 | 1,
891 | self.args.data.img_h,
892 | self.args.data.img_w,
893 | )
894 | importance_map_full_res_norm = importance_map_full_res / (
895 | importance_map_full_res.sum(dim=1, keepdim=True) + self.args.const.eps
896 | )
897 |
898 | # (bs, 3, img_h, img_w)
899 | y_nobg = (y_att_full_res * importance_map_full_res_norm).sum(dim=1)
900 |
901 | # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w)
902 | a_att_hat_full_res = spatial_transform(
903 | a_att_hat,
904 | z_where,
905 | (
906 | self.args.arch.num_cell * self.args.arch.num_cell * bs,
907 | 1,
908 | self.args.data.img_h,
909 | self.args.data.img_w,
910 | ),
911 | inverse=True,
912 | ).view(
913 | -1,
914 | self.args.arch.num_cell * self.args.arch.num_cell,
915 | 1,
916 | self.args.data.img_h,
917 | self.args.data.img_w,
918 | )
919 | alpha_map = a_att_hat_full_res.sum(dim=1)
920 | # (bs, 1, img_h, img_w)
921 | alpha_map = (
922 | alpha_map
923 | + (
924 | alpha_map.clamp(self.args.const.eps, 1 - self.args.const.eps)
925 | - alpha_map
926 | ).detach()
927 | )
928 |
929 | if self.args.train.phase_bg_alpha_curriculum:
930 | if (
931 | self.args.train.bg_alpha_curriculum_period[0]
932 | < global_step
933 | < self.args.train.bg_alpha_curriculum_period[1]
934 | ):
935 | alpha_map = (
936 | alpha_map.new_ones(alpha_map.size())
937 | * self.args.train.bg_alpha_curriculum_value
938 | )
939 | # y_nobg = alpha_map * y_nobg
940 | y = y_nobg + (1.0 - alpha_map) * bg
941 | else:
942 | y_att = a_att * o_att
943 |
944 | o_att_full_res = spatial_transform(
945 | o_att,
946 | z_where,
947 | (
948 | bs * self.args.arch.num_cell ** 2,
949 | self.args.data.inp_channel,
950 | self.args.data.img_h,
951 | self.args.data.img_w,
952 | ),
953 | inverse=True,
954 | ).view(
955 | -1,
956 | self.args.arch.num_cell * self.args.arch.num_cell,
957 | self.args.data.inp_channel,
958 | self.args.data.img_h,
959 | self.args.data.img_w,
960 | )
961 | a_att_hat_full_res = spatial_transform(
962 | a_att * z_pres.view(bs * self.args.arch.num_cell ** 2, 1, 1, 1),
963 | z_where,
964 | (
965 | self.args.arch.num_cell * self.args.arch.num_cell * bs,
966 | 1,
967 | self.args.data.img_h,
968 | self.args.data.img_w,
969 | ),
970 | inverse=True,
971 | ).view(
972 | -1,
973 | self.args.arch.num_cell * self.args.arch.num_cell,
974 | 1,
975 | self.args.data.img_h,
976 | self.args.data.img_w,
977 | )
978 |
979 | # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 3, img_h, img_w)
980 | y_att_full_res = spatial_transform(
981 | y_att,
982 | z_where,
983 | (
984 | bs * self.args.arch.num_cell ** 2,
985 | self.args.data.inp_channel,
986 | self.args.data.img_h,
987 | self.args.data.img_w,
988 | ),
989 | inverse=True,
990 | )
991 | y = (
992 | (
993 | y_att_full_res
994 | * z_pres.view(bs * self.args.arch.num_cell ** 2, 1, 1, 1)
995 | )
996 | .view(
997 | bs,
998 | -1,
999 | self.args.data.inp_channel,
1000 | self.args.data.img_h,
1001 | self.args.data.img_w,
1002 | )
1003 | .sum(dim=1)
1004 | )
1005 | y_nobg = y
1006 | alpha_map = y.new_ones(y.size(0), 1, y.size(2), y.size(3))
1007 |
1008 | return y, y_nobg, alpha_map, o_att_full_res, a_att_hat_full_res
1009 |
1010 | def loss_function(self, x, global_step, generate_bbox=False):
1011 | return self.forward(x, global_step, generate_bbox)
1012 |
1013 |
1014 | def hyperparam_anneal(args, global_step):
1015 | if args.train.beta_aux_pres_anneal_end_step == 0:
1016 | args.train.beta_aux_pres = args.train.beta_aux_pres_anneal_start_value
1017 | else:
1018 | args.train.beta_aux_pres = linear_schedule(
1019 | global_step,
1020 | args.train.beta_aux_pres_anneal_start_step,
1021 | args.train.beta_aux_pres_anneal_end_step,
1022 | args.train.beta_aux_pres_anneal_start_value,
1023 | args.train.beta_aux_pres_anneal_end_value,
1024 | )
1025 |
1026 | if args.train.beta_aux_where_anneal_end_step == 0:
1027 | args.train.beta_aux_where = args.train.beta_aux_where_anneal_start_value
1028 | else:
1029 | args.train.beta_aux_where = linear_schedule(
1030 | global_step,
1031 | args.train.beta_aux_where_anneal_start_step,
1032 | args.train.beta_aux_where_anneal_end_step,
1033 | args.train.beta_aux_where_anneal_start_value,
1034 | args.train.beta_aux_where_anneal_end_value,
1035 | )
1036 |
1037 | if args.train.beta_aux_what_anneal_end_step == 0:
1038 | args.train.beta_aux_what = args.train.beta_aux_what_anneal_start_value
1039 | else:
1040 | args.train.beta_aux_what = linear_schedule(
1041 | global_step,
1042 | args.train.beta_aux_what_anneal_start_step,
1043 | args.train.beta_aux_what_anneal_end_step,
1044 | args.train.beta_aux_what_anneal_start_value,
1045 | args.train.beta_aux_what_anneal_end_value,
1046 | )
1047 |
1048 | if args.train.beta_aux_depth_anneal_end_step == 0:
1049 | args.train.beta_aux_depth = args.train.beta_aux_depth_anneal_start_value
1050 | else:
1051 | args.train.beta_aux_depth = linear_schedule(
1052 | global_step,
1053 | args.train.beta_aux_depth_anneal_start_step,
1054 | args.train.beta_aux_depth_anneal_end_step,
1055 | args.train.beta_aux_depth_anneal_start_value,
1056 | args.train.beta_aux_depth_anneal_end_value,
1057 | )
1058 |
1059 | if args.train.beta_aux_global_anneal_end_step == 0:
1060 | args.train.beta_aux_global = args.train.beta_aux_global_anneal_start_value
1061 | else:
1062 | args.train.beta_aux_global = linear_schedule(
1063 | global_step,
1064 | args.train.beta_aux_global_anneal_start_step,
1065 | args.train.beta_aux_global_anneal_end_step,
1066 | args.train.beta_aux_global_anneal_start_value,
1067 | args.train.beta_aux_global_anneal_end_value,
1068 | )
1069 |
1070 | if args.train.beta_aux_bg_anneal_end_step == 0:
1071 | args.train.beta_aux_bg = args.train.beta_aux_bg_anneal_start_value
1072 | else:
1073 | args.train.beta_aux_bg = linear_schedule(
1074 | global_step,
1075 | args.train.beta_aux_bg_anneal_start_step,
1076 | args.train.beta_aux_bg_anneal_end_step,
1077 | args.train.beta_aux_bg_anneal_start_value,
1078 | args.train.beta_aux_bg_anneal_end_value,
1079 | )
1080 |
1081 | ########################### split here ###########################
1082 | if args.train.beta_pres_anneal_end_step == 0:
1083 | args.train.beta_pres = args.train.beta_pres_anneal_start_value
1084 | else:
1085 | args.train.beta_pres = linear_schedule(
1086 | global_step,
1087 | args.train.beta_pres_anneal_start_step,
1088 | args.train.beta_pres_anneal_end_step,
1089 | args.train.beta_pres_anneal_start_value,
1090 | args.train.beta_pres_anneal_end_value,
1091 | )
1092 |
1093 | if args.train.beta_where_anneal_end_step == 0:
1094 | args.train.beta_where = args.train.beta_where_anneal_start_value
1095 | else:
1096 | args.train.beta_where = linear_schedule(
1097 | global_step,
1098 | args.train.beta_where_anneal_start_step,
1099 | args.train.beta_where_anneal_end_step,
1100 | args.train.beta_where_anneal_start_value,
1101 | args.train.beta_where_anneal_end_value,
1102 | )
1103 |
1104 | if args.train.beta_what_anneal_end_step == 0:
1105 | args.train.beta_what = args.train.beta_what_anneal_start_value
1106 | else:
1107 | args.train.beta_what = linear_schedule(
1108 | global_step,
1109 | args.train.beta_what_anneal_start_step,
1110 | args.train.beta_what_anneal_end_step,
1111 | args.train.beta_what_anneal_start_value,
1112 | args.train.beta_what_anneal_end_value,
1113 | )
1114 |
1115 | if args.train.beta_depth_anneal_end_step == 0:
1116 | args.train.beta_depth = args.train.beta_depth_anneal_start_value
1117 | else:
1118 | args.train.beta_depth = linear_schedule(
1119 | global_step,
1120 | args.train.beta_depth_anneal_start_step,
1121 | args.train.beta_depth_anneal_end_step,
1122 | args.train.beta_depth_anneal_start_value,
1123 | args.train.beta_depth_anneal_end_value,
1124 | )
1125 |
1126 | if args.train.beta_global_anneal_end_step == 0:
1127 | args.train.beta_global = args.train.beta_global_anneal_start_value
1128 | else:
1129 | args.train.beta_global = linear_schedule(
1130 | global_step,
1131 | args.train.beta_global_anneal_start_step,
1132 | args.train.beta_global_anneal_end_step,
1133 | args.train.beta_global_anneal_start_value,
1134 | args.train.beta_global_anneal_end_value,
1135 | )
1136 |
1137 | if args.train.tau_pres_anneal_end_step == 0:
1138 | args.train.tau_pres = args.train.tau_pres_anneal_start_value
1139 | else:
1140 | args.train.tau_pres = linear_schedule(
1141 | global_step,
1142 | args.train.tau_pres_anneal_start_step,
1143 | args.train.tau_pres_anneal_end_step,
1144 | args.train.tau_pres_anneal_start_value,
1145 | args.train.tau_pres_anneal_end_value,
1146 | )
1147 |
1148 | if args.train.beta_bg_anneal_end_step == 0:
1149 | args.train.beta_bg = args.train.beta_bg_anneal_start_value
1150 | else:
1151 | args.train.beta_bg = linear_schedule(
1152 | global_step,
1153 | args.train.beta_bg_anneal_start_step,
1154 | args.train.beta_bg_anneal_end_step,
1155 | args.train.beta_bg_anneal_start_value,
1156 | args.train.beta_bg_anneal_end_value,
1157 | )
1158 |
1159 | return args
1160 |
--------------------------------------------------------------------------------
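A minimal training-step sketch for the model above, not part of the repository: the names train_step, model, and optimizer are hypothetical, the import path object_discovery.gnm.gnm_model is assumed, and only the two calls whose signatures are visible above (hyperparam_anneal and loss_function) are taken from the source.

    import torch
    from object_discovery.gnm.gnm_model import hyperparam_anneal  # assumed module path

    def train_step(model, optimizer, x, global_step):
        # Refresh the annealed beta/tau coefficients for this step; the schedule
        # writes the current values back into model.args.train.
        model.args = hyperparam_anneal(model.args, global_step)

        out = model.loss_function(x, global_step)  # dict with "loss", "elbo", "kl", ...
        loss = out["loss"]                         # already reduced over the batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return float(loss.detach())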
/object_discovery/gnm/logging.py:
--------------------------------------------------------------------------------
1 | # Partially based on https://github.com/karazijal/clevrtex/blob/fe982ab224689526f5e2f83a2f542ba958d88abd/experiments/framework/vis_mixin.py
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from .metrics import align_masks_iou, dices
6 | import itertools
7 |
8 | import torch
9 | import torchvision as tv
10 | from ..segmentation_metrics import adjusted_rand_index
11 | from ..utils import cmap_tensor
12 |
13 | @torch.no_grad()
14 | def _to_img(img, lim=None, dim=-3):
15 | if lim:
16 | img = img[:lim]
17 | img = (img.clamp(0, 1) * 255).to(torch.uint8)
18 | if img.shape[dim] < 3:
19 | epd_dims = [-1 for _ in range(len(img.shape))]
20 | epd_dims[dim] = 3
21 | img = img.expand(*epd_dims)
22 | return img
23 |
24 |
25 | @torch.no_grad()
26 | def log_recons(input, output, patches, masks, background=None, pres=None):
27 | vis_imgs = []
28 | img = _to_img(input)
29 | vis_imgs.extend(img)
30 | omg = _to_img(output, lim=len(img))
31 | vis_imgs.extend(omg)
32 |
33 | if background is not None and not torch.all(background == 0.0):
34 | bg = _to_img(background, lim=len(img))
35 | vis_imgs.extend(bg)
36 |
37 | masks = masks[: len(img)]
38 | patches = patches[: len(img)]
39 | ms = masks * patches
40 | for sid in range(patches.size(1)):
41 | # p = (patches[:len(img), sid].clamp(0., 1.) * 255.).to(torch.uint8).detach().cpu()
42 | # if p.shape[1] == 3:
43 | # vis_imgs.extend(p)
44 | m = _to_img(ms[:, sid])
45 | m_hat = []
46 | if pres is not None:
47 | for i in range(0, len(img)):
48 | if pres[i, sid][0] == 1:
49 | m[i, 0, :2, :2] = 0
50 | m[i, 1, :2, :2] = 255
51 | m[i, 2, :2, :2] = 0
52 | m_hat.append(m[i])
53 | else:
54 | m_hat.append(m[i])
55 | else:
56 | m_hat.extend(m)
57 | vis_imgs.extend(m_hat)
58 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=len(img), padding=1)
59 | return grid.permute([1, 2, 0]).cpu().numpy()
60 |
61 |
62 | @torch.no_grad()
63 | def log_images(input, output):
64 | vis_imgs = []
65 | img = _to_img(input)
66 | omg = _to_img(output, lim=len(img))
67 |
68 | for i, (i_img, o_img) in enumerate(zip(img, omg)):
69 | vis_imgs.append(i_img)
70 | vis_imgs.append(o_img)
71 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=16)
72 | return grid.permute([1, 2, 0]).cpu().numpy()
73 |
74 |
75 | @torch.no_grad()
76 | def log_semantic_images(input, output, true_masks, pred_masks):
77 | assert len(true_masks.shape) == 5 and len(pred_masks.shape) == 5
78 | img = _to_img(input)
79 | omg = _to_img(output, lim=len(img))
80 | true_masks = true_masks[: len(img)].to(torch.float).argmax(1).squeeze(1)
81 | pred_masks = pred_masks[: len(img)].to(torch.float).argmax(1).squeeze(1)
82 | tms = (cmap_tensor(true_masks) * 255.0).to(torch.uint8)
83 | pms = (cmap_tensor(pred_masks) * 255.0).to(torch.uint8)
84 | vis_imgs = list(itertools.chain.from_iterable(zip(img, omg, tms, pms)))
85 | grid = tv.utils.make_grid(vis_imgs, pad_value=128, nrow=16)
86 | return grid.permute([1, 2, 0]).cpu().numpy()
87 |
88 |
89 | def gnm_log_validation_outputs(batch, batch_idx, output, is_global_zero):
90 | logs = {}
91 | img, masks, vis = batch
92 | masks = masks.transpose(-1, -2).transpose(-2, -3).to(torch.float)
93 |
94 | mse = F.mse_loss(output["canvas"], img, reduction="none").sum((1, 2, 3))
95 | logs["mse"] = mse
96 |
97 | ali_pmasks = None
98 | ali_tmasks = None
99 |
100 | # Transforms might have changed this.
101 | # cnts = torch.sum(vis, dim=-1) - 1 # -1 for discounting the background from visibility
102 | # estimate from masks
103 | cnts = (
104 | torch.round(masks.to(torch.float)).flatten(2).any(-1).to(torch.float).sum(-1)
105 | - 1
106 | )
107 |
108 | if "steps" in output:
109 | pred_masks = output["steps"]["mask"]
110 | pred_vis = output["steps"]["z_pres"].squeeze(-1)
111 |
112 | # `align_masks_iou` adds the background pixels to `ali_pmasks` (aligned
113 | # predicted masks).
114 | ali_pmasks, ali_tmasks, ious, ali_pvis, ali_tvis = align_masks_iou(
115 | pred_masks, masks, pred_mask_vis=pred_vis, true_mask_vis=vis, has_bg=False
116 | )
117 |
118 | ali_cvis = ali_pvis | ali_tvis
119 | num_paired_slots = ali_cvis.sum(-1) - 1
120 | mses = F.mse_loss(
121 | output["canvas"][:, None] * ali_pmasks,
122 | img[:, None] * ali_tmasks,
123 | reduction="none",
124 | ).sum((-1, -2, -3))
125 |
126 | bg_mse = mses[:, 0]
127 | logs["bg_mse"] = bg_mse
128 |
129 | slot_mse = mses[:, 1:].sum(-1) / num_paired_slots
130 | logs["slot_mse"] = slot_mse
131 |
132 | mious = ious[:, 1:].sum(-1) / num_paired_slots
133 | # mious = torch.where(zero_mask, 0., mious)
134 | logs["miou"] = mious
135 |
136 | dice = dices(ali_pmasks, ali_tmasks)
137 | mdice = dice[:, 1:].sum(-1) / num_paired_slots
138 | logs["dice"] = mdice
139 |
140 | # aris = ari(ali_pmasks, ali_tmasks)
141 | batch_size, num_entries, channels, height, width = ali_tmasks.shape
142 | ali_tmasks_reshaped = (
143 | torch.reshape(
144 | ali_tmasks.squeeze(), [batch_size, num_entries, height * width]
145 | )
146 | .permute([0, 2, 1])
147 | .to(torch.float)
148 | )
149 | batch_size, num_entries, channels, height, width = ali_pmasks.shape
150 | ali_pmasks_reshaped = (
151 | torch.reshape(
152 | ali_pmasks.squeeze(), [batch_size, num_entries, height * width]
153 | )
154 | .permute([0, 2, 1])
155 | .to(torch.float)
156 | )
157 | logs["ari_with_background"] = adjusted_rand_index(
158 | ali_tmasks_reshaped, ali_pmasks_reshaped
159 | )
160 |
161 | # aris_fg = ari(ali_pmasks, ali_tmasks, True).mean().detach()
162 | # `[..., 1:]` removes the background pixels group from the true mask.
163 | logs["ari"] = adjusted_rand_index(
164 | ali_tmasks_reshaped[..., 1:], ali_pmasks_reshaped
165 | )
166 |
167 | # Can also calculate ari (same as above line) without using the aligned
168 | # masks by directly adding the background to the predicted masks. The
169 | # background is everything that is not predicted as an object. This is
170 | # done automatically when aligning the masks in `align_masks_iou`.
171 | # pred_masks_with_background = torch.cat([1 - pred_masks.sum(1, keepdim=True), pred_masks], 1)
172 | # batch_size, num_entries, channels, height, width = masks.shape
173 | # masks_reshaped = torch.reshape(masks, [batch_size, num_entries, height * width]).permute([0, 2, 1]).to(torch.float)
174 | # batch_size, num_entries, channels, height, width = pred_masks_with_background.shape
175 | # pred_masks_with_background_reshaped = torch.reshape(pred_masks_with_background, [batch_size, num_entries, height * width]).permute([0, 2, 1]).to(torch.float)
176 | # logs["ari4"] = adjusted_rand_index(masks_reshaped[..., 1:], pred_masks_with_background_reshaped)
177 |
178 | batch_size, num_entries, channels, height, width = masks.shape
179 | masks_reshaped = (
180 | torch.reshape(masks, [batch_size, num_entries, height * width])
181 | .permute([0, 2, 1])
182 | .to(torch.float)
183 | )
184 | batch_size, num_entries, channels, height, width = pred_masks.shape
185 | pred_masks_reshaped = (
186 | torch.reshape(pred_masks, [batch_size, num_entries, height * width])
187 | .permute([0, 2, 1])
188 | .to(torch.float)
189 | )
190 | logs["ari_no_background"] = adjusted_rand_index(
191 | masks_reshaped[..., 1:], pred_masks_reshaped
192 | )
193 |
194 | pred_counts = output["counts"].detach().to(torch.int)
195 | logs["acc"] = (pred_counts == cnts).to(float)
196 | logs["cnt"] = pred_counts.to(float)
197 |
198 | images = {}
199 | if batch_idx == 0 and is_global_zero:
200 | images["output"] = log_images(img, output["canvas_with_bbox"])
201 | if "steps" in output:
202 | images["recon"] = log_recons(
203 | img[:32],
204 | output["canvas"],
205 | output["steps"]["patch"],
206 | output["steps"]["mask"],
207 | output.get("background", None),
208 | pres=output["steps"]["z_pres"],
209 | )
210 |
211 | # If masks have been aligned; log semantic map
212 | if ali_pmasks is not None and ali_tmasks is not None:
213 | images["segmentation"] = log_semantic_images(
214 | img[:32], output["canvas"], ali_tmasks, ali_pmasks
215 | )
216 |
217 | return logs, images
218 |
--------------------------------------------------------------------------------
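A short sketch, with assumptions noted in comments, of the mask-reshaping convention used in gnm_log_validation_outputs above before calling adjusted_rand_index: masks of shape (B, K, 1, H, W) are flattened to (B, H*W, K), one row per pixel and one column per group, and the background column is dropped from the true masks for the foreground-only ARI. The helper name masks_to_ari_format is hypothetical.

    import torch
    # assumed absolute path for the relative import "..segmentation_metrics" used above
    from object_discovery.segmentation_metrics import adjusted_rand_index

    def masks_to_ari_format(masks: torch.Tensor) -> torch.Tensor:
        # masks: (batch, num_groups, 1, height, width), one-hot or soft
        b, k, _, h, w = masks.shape
        return masks.reshape(b, k, h * w).permute(0, 2, 1).to(torch.float)

    # Foreground-only ARI, as in the "ari" entry above: drop the background
    # column (index 0) from the ground-truth masks only.
    # ari = adjusted_rand_index(masks_to_ari_format(true_masks)[..., 1:],
    #                           masks_to_ari_format(pred_masks))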
/object_discovery/gnm/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import one_hot
3 |
4 | from scipy.optimize import linear_sum_assignment
5 |
6 |
7 | def binarise_masks(pred_masks, num_classes=None):
8 | assert len(pred_masks.shape) == 5
9 | assert pred_masks.shape[2] == 1
10 | num_classes = num_classes or pred_masks.shape[1]
11 | return (
12 | one_hot(pred_masks.argmax(axis=1), num_classes=num_classes)
13 | .transpose(-1, -2)
14 | .transpose(-2, -3)
15 | .transpose(-3, -4)
16 | .to(bool)
17 | )
18 |
19 |
20 | def convert_masks(pred_masks, true_masks, correct_bg=True):
21 | if correct_bg:
22 | pred_masks = torch.cat(
23 | [1.0 - pred_masks.sum(axis=1, keepdims=True), pred_masks], dim=1
24 | )
25 | num_classes = max(pred_masks.shape[1], true_masks.shape[1])
26 | pred_masks = binarise_masks(pred_masks, num_classes=num_classes)
27 | true_masks = binarise_masks(true_masks, num_classes=num_classes)
28 | # if torch.any(torch.isnan(pred_masks)) or torch.any(torch.isinf(pred_masks)): import ipdb; ipdb.set_trace()
29 | # if torch.any(torch.isnan(true_masks)) or torch.any(torch.isinf(true_masks)): import ipdb; ipdb.set_trace()
30 | return pred_masks, true_masks
31 |
32 |
33 | def iou_matching(pred_masks, true_masks, threshold=1e-2):
34 |     """The order of true_masks is preserved, up to a potentially missing background dim."""
35 | assert pred_masks.shape[0] == true_masks.shape[0], "Batchsize mismatch"
36 | # true_masks = true_masks.to(torch.bool)
37 | assert (
38 | pred_masks.dtype == true_masks.dtype
39 | ), f"Dtype mismatch ({pred_masks.dtype}!={true_masks.dtype})"
40 |
41 | if pred_masks.dtype != torch.bool:
42 | pred_masks, true_masks = convert_masks(pred_masks, true_masks, correct_bg=False)
43 | assert pred_masks.shape[-3:] == true_masks.shape[-3:], "Mask shape mismatch"
44 |
45 | pred_masks = pred_masks.to(float)
46 | true_masks = true_masks.to(float)
47 |
48 | tspec = dict(device=pred_masks.device)
49 | iou_matrix = torch.zeros(
50 | pred_masks.shape[0], pred_masks.shape[1], true_masks.shape[1], **tspec
51 | )
52 | true_masks_sums = true_masks.sum((-1, -2, -3))
53 | pred_masks_sums = pred_masks.sum((-1, -2, -3))
54 |
55 | for pi in range(pred_masks.shape[1]):
56 | pandt = (pred_masks[:, pi : pi + 1] * true_masks).sum((-1, -2, -3))
57 | port = pred_masks_sums[:, pi : pi + 1] + true_masks_sums
58 | iou_matrix[:, pi] = (pandt + 1e-2) / (port + 1e-2)
59 | iou_matrix[pred_masks_sums[:, pi] == 0.0, pi] = 0.0
60 |
61 | for ti in range(true_masks.shape[1]):
62 | iou_matrix[true_masks_sums[:, ti] == 0.0, :, ti] = 0.0
63 |
64 | cost_matrix = iou_matrix.cpu().detach().numpy()
65 | inds = torch.zeros(
66 | pred_masks.shape[0],
67 | 2,
68 | min(pred_masks.shape[1], true_masks.shape[1]),
69 | dtype=torch.int64,
70 | **tspec,
71 | )
72 | ious = torch.zeros(
73 | pred_masks.shape[0],
74 | min(pred_masks.shape[1], true_masks.shape[1]),
75 | dtype=float,
76 | **tspec,
77 | )
78 | for bi in range(cost_matrix.shape[0]):
79 | col_ind, row_ind = linear_sum_assignment(cost_matrix[bi].T, maximize=True)
80 | inds[bi, 0] = torch.tensor(row_ind, **tspec)
81 | inds[bi, 1] = torch.tensor(col_ind, **tspec)
82 | ious[bi] = torch.tensor(cost_matrix[bi, row_ind, col_ind], **tspec)
83 | # if torch.any(torch.isnan(inds)) or torch.any(torch.isinf(inds)): import ipdb; ipdb.set_trace()
84 | # if torch.any(torch.isnan(ious)) or torch.any(torch.isinf(ious)): import ipdb; ipdb.set_trace()
85 | return inds, ious
86 |
87 |
88 | def align_masks_iou(
89 | pred_mask, true_mask, pred_mask_vis=None, true_mask_vis=None, has_bg=False
90 | ):
91 | pred_mask, true_mask = convert_masks(pred_mask, true_mask, correct_bg=not has_bg)
92 | inds, ious = iou_matching(pred_mask, true_mask)
93 |
94 |     # Reindex the masks into aligned order
95 | B, S, *s = pred_mask.shape
96 | bias = S * torch.arange(B, device=pred_mask.device)[:, None]
97 | pred_mask = pred_mask.reshape(B * S, *s)[(bias + inds[:, 0]).flatten()].view(
98 | B, S, *s
99 | )
100 | true_mask = true_mask.reshape(B * S, *s)[(bias + inds[:, 1]).flatten()].view(
101 | B, S, *s
102 | )
103 |
104 | ret = pred_mask, true_mask, ious
105 |
106 | if pred_mask_vis is not None:
107 | if pred_mask_vis.dtype != torch.bool:
108 | pred_mask_vis = pred_mask_vis > 0.5
109 |
110 | if has_bg:
111 | pred_mask_vis = torch.cat(
112 | [
113 | pred_mask_vis,
114 | torch.zeros(
115 | B,
116 | S - pred_mask_vis.shape[1],
117 | dtype=pred_mask_vis.dtype,
118 | device=pred_mask_vis.device,
119 | ),
120 | ],
121 | axis=1,
122 | )
123 | else:
124 | pred_mask_vis = torch.cat(
125 | [
126 | torch.ones(
127 | B, 1, dtype=pred_mask_vis.dtype, device=pred_mask_vis.device
128 | ),
129 | pred_mask_vis,
130 | torch.zeros(
131 | B,
132 | S - pred_mask_vis.shape[1] - 1,
133 | dtype=pred_mask_vis.dtype,
134 | device=pred_mask_vis.device,
135 | ),
136 | ],
137 | axis=1,
138 | )
139 | pred_mask_vis = pred_mask_vis.reshape(B * S)[
140 | (bias + inds[:, 0]).flatten()
141 | ].view(B, S)
142 | ret += (pred_mask_vis,)
143 |
144 | if true_mask_vis is not None:
145 | if true_mask_vis.dtype != torch.bool:
146 | true_mask_vis = true_mask_vis > 0.5
147 |
148 | true_mask_vis = torch.cat(
149 | [
150 | true_mask_vis,
151 | torch.zeros(
152 | B,
153 | S - true_mask_vis.shape[1],
154 | dtype=true_mask_vis.dtype,
155 | device=true_mask_vis.device,
156 | ),
157 | ],
158 | axis=1,
159 | )
160 | true_mask_vis = true_mask_vis.reshape(B * S)[
161 | (bias + inds[:, 1]).flatten()
162 | ].view(B, S)
163 | ret += (true_mask_vis,)
164 | # for i,r in enumerate(ret):
165 | # if torch.any(torch.isnan(r)) or torch.any(torch.isinf(r)):
166 |     #         print('\tStopcon:', i)
167 | # import ipdb; ipdb.set_trace()
168 | return ret
169 |
170 |
171 | def dices(pred_mask, true_mask):
172 | dice = (
173 | 2
174 | * (pred_mask * true_mask).sum((-3, -2, -1))
175 | / (pred_mask.sum((-3, -2, -1)) + true_mask.sum((-3, -2, -1)))
176 | )
177 | dice = torch.where(torch.isnan(dice) | torch.isinf(dice), 0.0, dice.to(float))
178 | # if torch.any(torch.isnan(dice)) or torch.any(torch.isinf(dice)): import ipdb; ipdb.set_trace()
179 | return dice
180 |
181 |
182 | # def ari(pred_mask, true_mask, skip_0=False):
183 | # B = pred_mask.shape[0]
184 | # pm = pred_mask.to(int).argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy()
185 | # tm = true_mask.to(int).argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy()
186 | # aris = []
187 | # for bi in range(B):
188 | # t = tm[bi]
189 | # p = pm[bi]
190 | # if skip_0:
191 | # p = p[t > 0]
192 | # t = t[t > 0]
193 | # ari_score = adjusted_rand_score(t, p)
194 | # aris.append(ari_score)
195 | # aris = torch.tensor(np.array(aris), device=pred_mask.device)
196 | # # if torch.any(torch.isnan(aris)) or torch.any(torch.isinf(aris)): import ipdb; ipdb.set_trace()
197 | # return aris
198 |
--------------------------------------------------------------------------------
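A small usage sketch for the matching utilities above, not repository code, using random tensors with the shapes the functions expect: soft predicted masks without a background slot, and one-hot true masks whose first channel is the background. The import path is assumed from the file location.

    import torch
    from object_discovery.gnm.metrics import align_masks_iou, dices

    B, S, H, W = 2, 5, 32, 32
    pred = torch.rand(B, S - 1, 1, H, W)                 # per-slot soft masks, no background slot
    pred = 0.9 * pred / pred.sum(dim=1, keepdim=True)    # leave some mass for the background
                                                         # channel added by convert_masks
    true = torch.nn.functional.one_hot(
        torch.randint(0, S, (B, 1, H, W)), num_classes=S
    ).permute(0, 4, 1, 2, 3).to(torch.float)             # (B, S, 1, H, W), channel 0 = background

    ali_pred, ali_true, ious = align_masks_iou(pred, true, has_bg=False)
    per_slot_dice = dices(ali_pred, ali_true)            # per-slot Dice after Hungarian alignment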
/object_discovery/gnm/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from typing import Any, List, Tuple
5 | from .submodule import StackConvNorm, StackSubPixelNorm, StackMLP, ConvLSTMCell
6 | from torch.distributions import RelaxedBernoulli, Normal
7 |
8 |
9 | class LocalLatentDecoder(nn.Module):
10 | def __init__(self, args: Any):
11 | super(LocalLatentDecoder, self).__init__()
12 | self.args = args
13 |
14 | pwdw_net_inp_dim = self.args.arch.img_enc_dim
15 |
16 | self.pwdw_net = StackConvNorm(
17 | pwdw_net_inp_dim,
18 | self.args.arch.pwdw.pwdw_filters,
19 | self.args.arch.pwdw.pwdw_kernel_sizes,
20 | self.args.arch.pwdw.pwdw_strides,
21 | self.args.arch.pwdw.pwdw_groups,
22 | norm_act_final=True,
23 | )
24 |
25 | self.q_depth_net = nn.Conv2d(
26 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_depth_dim * 2, 1
27 | )
28 | self.q_where_net = nn.Conv2d(
29 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_where_dim * 2, 1
30 | )
31 | self.q_what_net = nn.Conv2d(
32 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_what_dim * 2, 1
33 | )
34 | self.q_pres_net = nn.Conv2d(
35 | self.args.arch.pwdw.pwdw_filters[-1], self.args.z.z_pres_dim, 1
36 | )
37 |
38 | torch.nn.init.uniform_(self.q_where_net.weight.data, -0.01, 0.01)
39 | # scale
40 | torch.nn.init.constant_(self.q_where_net.bias.data[0], -1.0)
41 | # ratio, x, y, std
42 | torch.nn.init.constant_(self.q_where_net.bias.data[1:], 0)
43 |
44 | def forward(self, img_enc: torch.Tensor, ss_p_z: List = None) -> List:
45 | """
46 |
47 | :param img_enc: (bs, dim, 4, 4)
48 |         :param ss_p_z: optional list of prior sufficient statistics
49 | :return:
50 | """
51 |
52 | if ss_p_z is not None:
53 | (
54 | p_pres_logits,
55 | p_where_mean,
56 | p_where_std,
57 | p_depth_mean,
58 | p_depth_std,
59 | p_what_mean,
60 | p_what_std,
61 | ) = ss_p_z
62 |
63 | pwdw_inp = img_enc
64 |
65 | pwdw_ss = self.pwdw_net(pwdw_inp)
66 |
67 | q_pres_logits = (
68 | self.q_pres_net(pwdw_ss).tanh() * self.args.const.pres_logit_scale
69 | )
70 |
71 | # q_where_mean, q_where_std: (bs, dim, num_cell, num_cell)
72 | q_where_mean, q_where_std = self.q_where_net(pwdw_ss).chunk(2, 1)
73 | q_where_std = F.softplus(q_where_std)
74 |
75 | # q_depth_mean, q_depth_std: (bs, dim, num_cell, num_cell)
76 | q_depth_mean, q_depth_std = self.q_depth_net(pwdw_ss).chunk(2, 1)
77 | q_depth_std = F.softplus(q_depth_std)
78 |
79 | q_what_mean, q_what_std = self.q_what_net(pwdw_ss).chunk(2, 1)
80 | q_what_std = F.softplus(q_what_std)
81 |
82 | ss = [
83 | q_pres_logits,
84 | q_where_mean,
85 | q_where_std,
86 | q_depth_mean,
87 | q_depth_std,
88 | q_what_mean,
89 | q_what_std,
90 | ]
91 |
92 | return ss
93 |
94 |
95 | class LocalSampler(nn.Module):
96 | def __init__(self, args: Any):
97 | super(LocalSampler, self).__init__()
98 | self.args = args
99 |
100 | self.z_what_decoder_net = StackSubPixelNorm(
101 | self.args.z.z_what_dim,
102 | self.args.arch.conv.p_what_decoder_filters,
103 | self.args.arch.conv.p_what_decoder_kernel_sizes,
104 | self.args.arch.conv.p_what_decoder_upscales,
105 | self.args.arch.conv.p_what_decoder_groups,
106 | norm_act_final=False,
107 | )
108 |
109 | self.register_buffer(
110 | "offset",
111 | torch.stack(
112 | torch.meshgrid(
113 | torch.arange(args.arch.num_cell).float(),
114 | torch.arange(args.arch.num_cell).float(),
115 | )[::-1],
116 | dim=0,
117 | ).view(1, 2, args.arch.num_cell, args.arch.num_cell),
118 | )
119 |
120 | def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:
121 |
122 | (
123 | p_pres_logits,
124 | p_where_mean,
125 | p_where_std,
126 | p_depth_mean,
127 | p_depth_std,
128 | p_what_mean,
129 | p_what_std,
130 | ) = ss
131 |
132 | if phase_use_mode:
133 | z_pres = (p_pres_logits > 0).float()
134 | else:
135 | z_pres = RelaxedBernoulli(
136 | logits=p_pres_logits, temperature=self.args.train.tau_pres
137 | ).rsample()
138 |
139 | # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
140 | if phase_use_mode:
141 | z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
142 | else:
143 | z_where_scale, z_where_shift = (
144 | Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)
145 | )
146 |
147 | # z_where_origin: (bs, dim, num_cell, num_cell)
148 | z_where_origin = torch.cat(
149 | [z_where_scale.detach(), z_where_shift.detach()], dim=1
150 | )
151 |
152 | z_where_shift = (2.0 / self.args.arch.num_cell) * (
153 | self.offset + 0.5 + torch.tanh(z_where_shift)
154 | ) - 1.0
155 |
156 | scale, ratio = z_where_scale.chunk(2, 1)
157 | scale = scale.sigmoid()
158 | ratio = torch.exp(ratio)
159 | ratio_sqrt = ratio.sqrt()
160 | z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)
161 | # z_where: (bs, dim, num_cell, num_cell)
162 | z_where = torch.cat([z_where_scale, z_where_shift], dim=1)
163 |
164 | if phase_use_mode:
165 | z_depth = p_depth_mean
166 | z_what = p_what_mean
167 | else:
168 | z_depth = Normal(p_depth_mean, p_depth_std).rsample()
169 | z_what = Normal(p_what_mean, p_what_std).rsample()
170 |
171 | z_what_reshape = (
172 | z_what.permute(0, 2, 3, 1)
173 | .reshape(-1, self.args.z.z_what_dim)
174 | .view(-1, self.args.z.z_what_dim, 1, 1)
175 | )
176 |
177 | if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
178 | o = self.z_what_decoder_net(z_what_reshape)
179 | o = o.sigmoid()
180 | a = o.new_ones(o.size())
181 | elif self.args.arch.phase_overlap:
182 | o, a = self.z_what_decoder_net(z_what_reshape).split(
183 | [self.args.data.inp_channel, 1], dim=1
184 | )
185 | o, a = o.sigmoid(), a.sigmoid()
186 | else:
187 |             raise NotImplementedError
188 |
189 | lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
190 | pa = [o, a]
191 |
192 | return pa, lv
193 |
194 |
195 | class StructDRAW(nn.Module):
196 | def __init__(self, args):
197 | super(StructDRAW, self).__init__()
198 | self.args = args
199 |
200 | self.p_global_decoder_net = StackMLP(
201 | self.args.z.z_global_dim,
202 | self.args.arch.mlp.p_global_decoder_filters,
203 | norm_act_final=True,
204 | )
205 |
206 | rnn_enc_inp_dim = (
207 | self.args.arch.img_enc_dim * 2
208 | + self.args.arch.structdraw.rnn_decoder_hid_dim
209 | )
210 |
211 | rnn_dec_inp_dim = self.args.arch.mlp.p_global_decoder_filters[-1] // (
212 | self.args.arch.num_cell ** 2
213 | )
214 |
215 | rnn_dec_inp_dim += self.args.arch.structdraw.hid_to_dec_filters[-1]
216 |
217 | self.rnn_enc = ConvLSTMCell(
218 | input_dim=rnn_enc_inp_dim,
219 | hidden_dim=self.args.arch.structdraw.rnn_encoder_hid_dim,
220 | kernel_size=self.args.arch.structdraw.kernel_size,
221 | num_cell=self.args.arch.num_cell,
222 | )
223 |
224 | self.rnn_dec = ConvLSTMCell(
225 | input_dim=rnn_dec_inp_dim,
226 | hidden_dim=self.args.arch.structdraw.rnn_decoder_hid_dim,
227 | kernel_size=self.args.arch.structdraw.kernel_size,
228 | num_cell=self.args.arch.num_cell,
229 | )
230 |
231 | self.p_global_net = StackMLP(
232 | self.args.arch.num_cell ** 2
233 | * self.args.arch.structdraw.rnn_decoder_hid_dim,
234 | self.args.arch.mlp.p_global_encoder_filters,
235 | norm_act_final=False,
236 | )
237 |
238 | self.q_global_net = StackMLP(
239 | self.args.arch.num_cell ** 2
240 | * self.args.arch.structdraw.rnn_encoder_hid_dim,
241 | self.args.arch.mlp.q_global_encoder_filters,
242 | norm_act_final=False,
243 | )
244 |
245 | self.hid_to_dec_net = StackConvNorm(
246 | self.args.arch.structdraw.rnn_decoder_hid_dim,
247 | self.args.arch.structdraw.hid_to_dec_filters,
248 | self.args.arch.structdraw.hid_to_dec_kernel_sizes,
249 | self.args.arch.structdraw.hid_to_dec_strides,
250 | self.args.arch.structdraw.hid_to_dec_groups,
251 | norm_act_final=False,
252 | )
253 |
254 | self.register_buffer(
255 | "dec_step_0",
256 | torch.zeros(
257 | 1,
258 | self.args.arch.structdraw.hid_to_dec_filters[-1],
259 | self.args.arch.num_cell,
260 | self.args.arch.num_cell,
261 | ),
262 | )
263 |
264 | def forward(
265 | self,
266 | x: torch.Tensor,
267 | phase_generation: bool = False,
268 | generation_from_step: Any = None,
269 | z_global_predefine: Any = None,
270 | ) -> Tuple:
271 | """
272 |         :param x: (bs, dim, num_cell, num_cell) or (bs, dim, img_h, img_w)
273 | :return:
274 | """
275 |
276 | bs = x.size(0)
277 |
278 | h_enc, c_enc = self.rnn_enc.init_hidden(bs)
279 | h_dec, c_dec = self.rnn_dec.init_hidden(bs)
280 |
281 | p_global_mean_list = []
282 | p_global_std_list = []
283 | q_global_mean_list = []
284 | q_global_std_list = []
285 | z_global_list = []
286 |
287 | dec_step = self.dec_step_0.expand(bs, -1, -1, -1)
288 |
289 | for i in range(self.args.arch.draw_step):
290 |
291 | p_global_mean_step, p_global_std_step = self.p_global_net(
292 | h_dec.permute(0, 2, 3, 1).reshape(bs, -1)
293 | ).chunk(2, -1)
294 | p_global_std_step = F.softplus(p_global_std_step)
295 |
296 | if phase_generation or (
297 | generation_from_step is not None and i >= generation_from_step
298 | ):
299 |
300 | q_global_mean_step = x.new_empty(bs, self.args.z.z_global_dim)
301 | q_global_std_step = x.new_empty(bs, self.args.z.z_global_dim)
302 |
303 | if z_global_predefine is None or z_global_predefine.size(1) <= i:
304 | z_global_step = Normal(
305 | p_global_mean_step, p_global_std_step
306 | ).rsample()
307 | else:
308 | z_global_step = z_global_predefine.view(
309 | bs, -1, self.args.z.z_global_dim
310 | )[:, i]
311 |
312 | else:
313 |
314 | if i == 0:
315 | rnn_encoder_inp = torch.cat([x, x, h_dec], dim=1)
316 | else:
317 | rnn_encoder_inp = torch.cat([x, x - dec_step, h_dec], dim=1)
318 |
319 | h_enc, c_enc = self.rnn_enc(rnn_encoder_inp, [h_enc, c_enc])
320 |
321 | q_global_mean_step, q_global_std_step = self.q_global_net(
322 | h_enc.permute(0, 2, 3, 1).reshape(bs, -1)
323 | ).chunk(2, -1)
324 |
325 | q_global_std_step = F.softplus(q_global_std_step)
326 | z_global_step = Normal(q_global_mean_step, q_global_std_step).rsample()
327 |
328 | rnn_decoder_inp = self.p_global_decoder_net(z_global_step).reshape(
329 | bs, -1, self.args.arch.num_cell, self.args.arch.num_cell
330 | )
331 |
332 | rnn_decoder_inp = torch.cat([rnn_decoder_inp, dec_step], dim=1)
333 |
334 | h_dec, c_dec = self.rnn_dec(rnn_decoder_inp, [h_dec, c_dec])
335 |
336 | dec_step = dec_step + self.hid_to_dec_net(h_dec)
337 |
338 | # (bs, dim)
339 | p_global_mean_list.append(p_global_mean_step)
340 | p_global_std_list.append(p_global_std_step)
341 | q_global_mean_list.append(q_global_mean_step)
342 | q_global_std_list.append(q_global_std_step)
343 | z_global_list.append(z_global_step)
344 |
345 | global_dec = dec_step
346 |
347 | # (bs, steps, dim, 1, 1)
348 | p_global_mean_all = torch.stack(p_global_mean_list, 1)[:, :, :, None, None]
349 | p_global_std_all = torch.stack(p_global_std_list, 1)[:, :, :, None, None]
350 | q_global_mean_all = torch.stack(q_global_mean_list, 1)[:, :, :, None, None]
351 | q_global_std_all = torch.stack(q_global_std_list, 1)[:, :, :, None, None]
352 | z_global_all = torch.stack(z_global_list, 1)[:, :, :, None, None]
353 |
354 | pa = [global_dec]
355 | lv = [z_global_all]
356 | ss = [p_global_mean_all, p_global_std_all, q_global_mean_all, q_global_std_all]
357 |
358 | return pa, lv, ss
359 |
360 |
361 | class BgEncoder(nn.Module):
362 | def __init__(self, args):
363 | super(BgEncoder, self).__init__()
364 | self.args = args
365 |
366 | self.p_bg_encoder = StackMLP(
367 | self.args.arch.img_enc_dim * self.args.arch.num_cell ** 2,
368 | self.args.arch.mlp.q_bg_encoder_filters,
369 | norm_act_final=False,
370 | )
371 |
372 | def forward(self, x: torch.Tensor) -> Tuple:
373 | """
374 | :param x: (bs, dim, img_h, img_w) or (bs, dim, num_cell, num_cell)
375 | :return:
376 | """
377 | bs = x.size(0)
378 | q_bg_mean, q_bg_std = self.p_bg_encoder(x.view(bs, -1)).chunk(2, 1)
379 | q_bg_mean = q_bg_mean.view(bs, -1, 1, 1)
380 | q_bg_std = q_bg_std.view(bs, -1, 1, 1)
381 |
382 | q_bg_std = F.softplus(q_bg_std)
383 |
384 | z_bg = Normal(q_bg_mean, q_bg_std).rsample()
385 |
386 | lv = [z_bg]
387 |
388 | ss = [q_bg_mean, q_bg_std]
389 |
390 | return lv, ss
391 |
392 |
393 | class BgGenerator(nn.Module):
394 | def __init__(self, args):
395 | super(BgGenerator, self).__init__()
396 | self.args = args
397 |
398 | inp_dim = self.args.z.z_global_dim * self.args.arch.draw_step
399 |
400 | self.p_bg_generator = StackMLP(
401 | inp_dim, self.args.arch.mlp.p_bg_generator_filters, norm_act_final=False
402 | )
403 |
404 | def forward(
405 | self, z_global_all: torch.Tensor, phase_use_mode: bool = False
406 | ) -> Tuple:
407 | """
408 |         :param z_global_all: (bs, step, dim, 1, 1)
409 | :return:
410 | """
411 | bs = z_global_all.size(0)
412 |
413 | bg_generator_inp = z_global_all
414 |
415 | p_bg_mean, p_bg_std = self.p_bg_generator(bg_generator_inp.view(bs, -1)).chunk(
416 | 2, 1
417 | )
418 | p_bg_std = F.softplus(p_bg_std)
419 |
420 | p_bg_mean = p_bg_mean.view(bs, -1, 1, 1)
421 | p_bg_std = p_bg_std.view(bs, -1, 1, 1)
422 |
423 | if phase_use_mode:
424 | z_bg = p_bg_mean
425 | else:
426 | z_bg = Normal(p_bg_mean, p_bg_std).rsample()
427 |
428 | lv = [z_bg]
429 |
430 | ss = [p_bg_mean, p_bg_std]
431 |
432 | return lv, ss
433 |
434 |
435 | class BgDecoder(nn.Module):
436 | def __init__(self, args):
437 | super(BgDecoder, self).__init__()
438 | self.args = args
439 |
440 | self.p_bg_decoder = StackSubPixelNorm(
441 | self.args.z.z_bg_dim,
442 | self.args.arch.conv.p_bg_decoder_filters,
443 | self.args.arch.conv.p_bg_decoder_kernel_sizes,
444 | self.args.arch.conv.p_bg_decoder_upscales,
445 | self.args.arch.conv.p_bg_decoder_groups,
446 | norm_act_final=False,
447 | )
448 |
449 | def forward(self, z_bg: torch.Tensor) -> List:
450 | """
451 |         :param z_bg: (bs, dim, 1, 1)
452 | :return:
453 | """
454 | bs = z_bg.size(0)
455 |
456 | bg = self.p_bg_decoder(z_bg).sigmoid()
457 |
458 | pa = [bg]
459 |
460 | return pa
461 |
--------------------------------------------------------------------------------
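An illustrative sketch, not repository code, of the z_where parameterisation used by LocalSampler above: the raw shift is squashed with tanh, offset by its grid-cell index, and mapped to [-1, 1] image coordinates, while the raw scale/ratio pair is turned into per-axis box sizes whose product depends only on scale. The function decode_where and its scalar-per-cell calling convention are assumptions for illustration.

    import torch

    def decode_where(raw_scale, raw_ratio, raw_shift_xy, cell_xy, num_cell):
        # cell_xy: integer (x, y) index of the grid cell that produced this latent
        shift = (2.0 / num_cell) * (
            torch.as_tensor(cell_xy, dtype=torch.float) + 0.5 + torch.tanh(raw_shift_xy)
        ) - 1.0
        scale = torch.sigmoid(raw_scale)           # overall box size in (0, 1)
        ratio_sqrt = torch.exp(raw_ratio).sqrt()   # aspect ratio, split evenly across axes
        wh = torch.stack([scale / ratio_sqrt, scale * ratio_sqrt])
        return wh, shift                           # (width, height), (x, y) in [-1, 1]

    # e.g. a latent from cell (1, 2) on a 4x4 grid with zero raw shift lands at (-0.25, 0.25):
    # wh, xy = decode_where(torch.tensor(0.0), torch.tensor(0.3), torch.zeros(2), (1, 2), num_cell=4)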
/object_discovery/gnm/submodule.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class StackConvNorm(nn.Module):
7 | def __init__(
8 | self,
9 | dim_inp: int,
10 | filters: List[int],
11 | kernel_sizes: List[int],
12 | strides: List[int],
13 | groupings: List[int],
14 | norm_act_final: bool,
15 | activation: Callable = nn.CELU,
16 | ):
17 | super(StackConvNorm, self).__init__()
18 |
19 | layers = []
20 |
21 | dim_prev = dim_inp
22 |
23 | for i, (f, k, s) in enumerate(zip(filters, kernel_sizes, strides)):
24 | if s == 0:
25 | layers.append(nn.Conv2d(dim_prev, f, k, 1, 0))
26 | else:
27 | layers.append(nn.Conv2d(dim_prev, f, k, s, (k - 1) // 2))
28 |             if i == len(filters) - 1 and not norm_act_final:
29 | break
30 | layers.append(activation())
31 | layers.append(nn.GroupNorm(groupings[i], f))
32 | # layers.append(nn.BatchNorm2d(f))
33 | dim_prev = f
34 |
35 | self.conv = nn.Sequential(*layers)
36 |
37 | def forward(self, x: torch.Tensor) -> torch.Tensor:
38 | x = self.conv(x)
39 |
40 | return x
41 |
42 |
43 | class StackSubPixelNorm(nn.Module):
44 | def __init__(
45 | self,
46 | dim_inp: int,
47 | filters: List[int],
48 | kernel_sizes: List[int],
49 | upscale: List[int],
50 | groupings: List[int],
51 | norm_act_final: bool,
52 | activation: Callable = nn.CELU,
53 | ):
54 | super(StackSubPixelNorm, self).__init__()
55 |
56 | layers = []
57 |
58 | dim_prev = dim_inp
59 |
60 | for i, (f, k, u) in enumerate(zip(filters, kernel_sizes, upscale)):
61 | if u == 1:
62 | layers.append(nn.Conv2d(dim_prev, f, k, 1, (k - 1) // 2))
63 | else:
64 | layers.append(nn.Conv2d(dim_prev, f * u ** 2, k, 1, (k - 1) // 2))
65 | layers.append(nn.PixelShuffle(u))
66 |             if i == len(filters) - 1 and not norm_act_final:
67 | break
68 | layers.append(activation())
69 | layers.append(nn.GroupNorm(groupings[i], f))
70 | dim_prev = f
71 |
72 | self.conv = nn.Sequential(*layers)
73 |
74 | def forward(self, x: torch.Tensor) -> torch.Tensor:
75 | x = self.conv(x)
76 |
77 | return x
78 |
79 |
80 | class StackMLP(nn.Module):
81 | def __init__(
82 | self,
83 | dim_inp: int,
84 | filters: List[int],
85 | norm_act_final: bool,
86 | activation: Callable = nn.CELU,
87 | phase_layer_norm: bool = True,
88 | ):
89 | super(StackMLP, self).__init__()
90 |
91 | layers = []
92 |
93 | dim_prev = dim_inp
94 |
95 | for i, f in enumerate(filters):
96 | layers.append(nn.Linear(dim_prev, f))
97 |             if i == len(filters) - 1 and not norm_act_final:
98 | break
99 | layers.append(activation())
100 | if phase_layer_norm:
101 | layers.append(nn.LayerNorm(f))
102 | dim_prev = f
103 |
104 | self.mlp = nn.Sequential(*layers)
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = self.mlp(x)
108 |
109 | return x
110 |
111 |
112 | class ConvLSTMCell(nn.Module):
113 | def __init__(self, input_dim, hidden_dim, kernel_size=3, num_cell=4):
114 | super(ConvLSTMCell, self).__init__()
115 |
116 | self.input_dim = input_dim
117 | self.hidden_dim = hidden_dim
118 |
119 | self.kernel_size = kernel_size
120 | self.padding = (kernel_size - 1) // 2
121 | self.conv = nn.Conv2d(
122 | in_channels=self.input_dim + hidden_dim,
123 | out_channels=4 * self.hidden_dim,
124 | kernel_size=self.kernel_size,
125 | padding=self.padding,
126 | bias=True,
127 | )
128 |
129 | self.register_parameter(
130 | "h_0",
131 | torch.nn.Parameter(
132 | torch.zeros(1, self.hidden_dim, num_cell, num_cell), requires_grad=True
133 | ),
134 | )
135 | self.register_parameter(
136 | "c_0",
137 | torch.nn.Parameter(
138 | torch.zeros(1, self.hidden_dim, num_cell, num_cell), requires_grad=True
139 | ),
140 | )
141 |
142 | def forward(self, x, h_c):
143 | h_cur, c_cur = h_c
144 |
145 | conv_inp = torch.cat([x, h_cur], dim=1)
146 |
147 | i, f, o, c = self.conv(conv_inp).split(self.hidden_dim, dim=1)
148 |
149 | i = torch.sigmoid(i)
150 | f = torch.sigmoid(f)
151 | c = torch.tanh(c)
152 | o = torch.sigmoid(o)
153 |
154 | c_next = f * c_cur + i * c
155 | h_next = o * torch.tanh(c_next)
156 |
157 | return h_next, c_next
158 |
159 | def init_hidden(self, batch_size):
160 | return (
161 | self.h_0.expand(batch_size, -1, -1, -1),
162 | self.c_0.expand(batch_size, -1, -1, -1),
163 | )
164 |
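165 | 
166 | if __name__ == "__main__":
167 |     # Quick shape-check sketch for the building blocks above; the sizes used
168 |     # here (a 16-dim MLP input, a 4x4 ConvLSTM grid) are arbitrary illustrations.
169 |     torch.manual_seed(0)
170 | 
171 |     mlp = StackMLP(dim_inp=16, filters=[32, 8], norm_act_final=False)
172 |     print(mlp(torch.randn(4, 16)).shape)  # torch.Size([4, 8])
173 | 
174 |     cell = ConvLSTMCell(input_dim=8, hidden_dim=16, kernel_size=3, num_cell=4)
175 |     h, c = cell.init_hidden(batch_size=4)
176 |     h, c = cell(torch.randn(4, 8, 4, 4), (h, c))
177 |     print(h.shape, c.shape)  # torch.Size([4, 16, 4, 4]) for both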
--------------------------------------------------------------------------------
/object_discovery/gnm/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | border_width = 3
5 |
6 | rbox = torch.zeros(3, 42, 42)
7 | rbox[0, :border_width, :] = 1
8 | rbox[0, -border_width:, :] = 1
9 | rbox[0, :, :border_width] = 1
10 | rbox[0, :, -border_width:] = 1
11 | rbox = rbox.view(1, 3, 42, 42)
12 |
13 | gbox = torch.zeros(3, 42, 42)
14 | gbox[1, :border_width, :] = 1
15 | gbox[1, -border_width:, :] = 1
16 | gbox[1, :, :border_width] = 1
17 | gbox[1, :, -border_width:] = 1
18 | gbox = gbox.view(1, 3, 42, 42)
19 |
20 |
21 | def visualize(
22 | x, z_pres, z_where_scale, z_where_shift,
23 | ):
24 | """
25 | x: (bs, 3, img_h, img_w)
26 | z_pres: (bs, 4, 4, 1)
27 | z_where_scale: (bs, 4, 4, 2)
28 | z_where_shift: (bs, 4, 4, 2)
29 | """
30 | global gbox, rbox
31 | bs, _, img_h, img_w = x.size()
32 | z_pres = z_pres.view(-1, 1, 1, 1)
33 | num_obj = z_pres.size(0) // bs
34 | z_scale = z_where_scale.view(-1, 2)
35 | z_shift = z_where_shift.view(-1, 2)
36 | gbox = gbox.to(z_pres.device)
37 | rbox = rbox.to(z_pres.device)
38 | bbox = spatial_transform(
39 | z_pres * gbox + (1 - z_pres) * rbox,
40 | torch.cat((z_scale, z_shift), dim=1),
41 | torch.Size([bs * num_obj, 3, img_h, img_w]),
42 | inverse=True,
43 | )
44 |
45 | return bbox
46 |
47 |
48 | def linear_schedule_tensor(step, start_step, end_step, start_value, end_value, device):
49 | if start_step < step < end_step:
50 | slope = (end_value - start_value) / (end_step - start_step)
51 | x = torch.tensor(start_value + slope * (step - start_step)).to(device)
52 | elif step >= end_step:
53 | x = torch.tensor(end_value).to(device)
54 | else:
55 | x = torch.tensor(start_value).to(device)
56 |
57 | return x
58 |
59 |
60 | def linear_schedule(step, start_step, end_step, start_value, end_value):
61 | if start_step < step < end_step:
62 | slope = (end_value - start_value) / (end_step - start_step)
63 | x = start_value + slope * (step - start_step)
64 | elif step >= end_step:
65 | x = end_value
66 | else:
67 | x = start_value
68 |
69 | return x
70 |
71 |
72 | def spatial_transform(image, z_where, out_dims, inverse=False):
73 | """ spatial transformer network used to scale and shift input according to z_where in:
74 | 1/ x -> x_att -- shapes (H, W) -> (attn_window, attn_window) -- thus inverse = False
75 | 2/ y_att -> y -- (attn_window, attn_window) -> (H, W) -- thus inverse = True
76 | inverting the affine transform as follows: A_inv ( A * image ) = image
77 | A = [R | T] where R is rotation component of angle alpha, T is [tx, ty] translation component
78 | A_inv rotates by -alpha and translates by [-tx, -ty]
79 | if x' = R * x + T --> x = R_inv * (x' - T) = R_inv * x - R_inv * T
80 |     here, z_where is 4-dim [scale_x, scale_y, tx, ty] so the inverse transform is [1/sx, 1/sy, -tx/sx, -ty/sy]
81 | R = [[s, 0], -> R_inv = [[1/s, 0],
82 | [0, s]] [0, 1/s]]
83 | """
84 | # 1. construct 2x3 affine matrix for each datapoint in the minibatch
85 | theta = torch.zeros(2, 3).repeat(image.shape[0], 1, 1).to(image.device)
86 | # set scaling
87 | theta[:, 0, 0] = z_where[:, 0] if not inverse else 1 / (z_where[:, 0] + 1e-15)
88 | theta[:, 1, 1] = z_where[:, 1] if not inverse else 1 / (z_where[:, 1] + 1e-15)
89 |
90 | # set translation
91 | theta[:, 0, -1] = (
92 | z_where[:, 2] if not inverse else -z_where[:, 2] / (z_where[:, 0] + 1e-15)
93 | )
94 | theta[:, 1, -1] = (
95 | z_where[:, 3] if not inverse else -z_where[:, 3] / (z_where[:, 1] + 1e-15)
96 | )
97 | # 2. construct sampling grid
98 | grid = F.affine_grid(theta, torch.Size(out_dims), align_corners=False)
99 | # 3. sample image from grid
100 | return F.grid_sample(image, grid, align_corners=False)
101 |
102 |
103 | def kl_divergence_bern_bern(q_pres_probs, p_pres_prob, eps=1e-15):
104 | """
105 |     Compute the KL divergence between two Bernoulli distributions.
106 |     :param q_pres_probs: posterior presence probabilities, (B, ...)
107 |     :param p_pres_prob: prior presence probability, float
108 |     :return: kl divergence, (B, ...)
109 | """
110 | # z_pres_probs = torch.sigmoid(z_pres_logits)
111 | kl = q_pres_probs * (
112 | torch.log(q_pres_probs + eps) - torch.log(p_pres_prob + eps)
113 | ) + (1 - q_pres_probs) * (
114 | torch.log(1 - q_pres_probs + eps) - torch.log(1 - p_pres_prob + eps)
115 | )
116 |
117 | return kl
118 |
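119 | 
120 | if __name__ == "__main__":
121 |     # Usage sketch for the helpers above; the shapes and schedule values are
122 |     # illustrative assumptions, not values used during training.
123 |     torch.manual_seed(0)
124 | 
125 |     # Crop a 42x42 glimpse from each image; z_where is [scale_x, scale_y, shift_x, shift_y].
126 |     image = torch.rand(2, 3, 64, 64)
127 |     z_where = torch.tensor([[0.5, 0.5, 0.0, 0.0], [0.5, 0.5, 0.25, -0.25]])
128 |     glimpse = spatial_transform(image, z_where, torch.Size([2, 3, 42, 42]), inverse=False)
129 |     print(glimpse.shape)  # torch.Size([2, 3, 42, 42])
130 | 
131 |     # Scalar annealing: ramps linearly from 0.0 to 1.0 between steps 100 and 200.
132 |     print(linear_schedule(150, 100, 200, 0.0, 1.0))  # 0.5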
--------------------------------------------------------------------------------
/object_discovery/method.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from typing import Union
3 | from functools import partial
4 | import pytorch_lightning as pl
5 | import wandb
6 | import torch
7 | from torch import optim
8 | import torch.nn.functional as F
9 | from torchvision import utils as vutils
10 | from torchvision import transforms
11 |
12 | from object_discovery.slot_attention_model import SlotAttentionModel
13 | from object_discovery.slate_model import SLATE
14 | from object_discovery.gnm.gnm_model import GNM
15 | from object_discovery.utils import (
16 | to_rgb_from_tensor,
17 | warm_and_decay_lr_scheduler,
18 | cosine_anneal,
19 | linear_warmup,
20 | visualize,
21 | compute_ari,
22 | sa_segment,
23 | rescale,
24 | get_largest_objects,
25 | cmap_tensor,
26 | )
27 | from object_discovery.gnm.logging import gnm_log_validation_outputs
28 |
29 |
30 | class SlotAttentionMethod(pl.LightningModule):
31 | def __init__(
32 | self,
33 | model: Union[SlotAttentionModel, SLATE, GNM],
34 | datamodule: pl.LightningDataModule,
35 | params: Namespace,
36 | ):
37 | if type(params) is dict:
38 | params = Namespace(**params)
39 | super().__init__()
40 | self.model = model
41 | self.datamodule = datamodule
42 | self.params = params
43 | self.save_hyperparameters(params)
44 |
45 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
46 | return self.model(input, **kwargs)
47 |
48 | def step(self, batch):
49 | mask = None
50 | if self.params.model_type == "slate":
51 | self.tau = cosine_anneal(
52 | self.trainer.global_step,
53 | self.params.tau_start,
54 | self.params.tau_final,
55 | 0,
56 | self.params.tau_steps,
57 | )
58 | loss = self.model.loss_function(batch, self.tau, self.params.hard)
59 | elif self.params.model_type == "sa":
60 | separation_tau = None
61 | if self.params.use_separation_loss:
62 | if self.params.separation_tau:
63 | separation_tau = self.params.separation_tau
64 | else:
65 | separation_tau = self.params.separation_tau_max_val - cosine_anneal(
66 | self.trainer.global_step,
67 | self.params.separation_tau_max_val,
68 | 0,
69 | self.params.separation_tau_start,
70 | self.params.separation_tau_end,
71 | )
72 | area_tau = None
73 | if self.params.use_area_loss:
74 | if self.params.area_tau:
75 | area_tau = self.params.area_tau
76 | else:
77 | area_tau = self.params.area_tau_max_val - cosine_anneal(
78 | self.trainer.global_step,
79 | self.params.area_tau_max_val,
80 | 0,
81 | self.params.area_tau_start,
82 | self.params.area_tau_end,
83 | )
84 | loss, mask = self.model.loss_function(
85 | batch, separation_tau=separation_tau, area_tau=area_tau
86 | )
87 | elif self.params.model_type == "gnm":
88 | output = self.model.loss_function(batch, self.trainer.global_step)
89 | loss = {
90 | "loss": output["loss"],
91 | "elbo": output["elbo"],
92 | "kl": output["kl"],
93 | "loss_comp_ratio": output["rec_loss"] / output["kl"],
94 | }
95 | for k in output:
96 | if k.startswith("kl_") or k.endswith("_loss"):
97 | val = output[k]
98 | if isinstance(val, torch.Tensor):
99 | val = val.mean()
100 | loss[k] = val
101 |
102 | return loss, mask
103 |
104 | def training_step(self, batch, batch_idx):
105 | loss_dict, _ = self.step(batch)
106 | logs = {"train/" + key: val.item() for key, val in loss_dict.items()}
107 | self.log_dict(logs, sync_dist=True)
108 | return loss_dict["loss"]
109 |
110 | def sample_images(self, stage="validation"):
111 | dl = (
112 | self.datamodule.val_dataloader()
113 | if stage == "validation"
114 | else self.datamodule.train_dataloader()
115 | )
116 | perm = torch.randperm(self.params.batch_size)
117 | idx = perm[: self.params.n_samples]
118 | batch = next(iter(dl))
119 | if type(batch) == list:
120 | batch = batch[0][idx]
121 | else:
122 | batch = batch[idx]
123 |
124 | if self.params.accelerator:
125 | batch = batch.to(self.device)
126 | if self.params.model_type == "sa":
127 | recon_combined, recons, masks, slots = self.model.forward(batch)
128 | # `masks` has shape [batch_size, num_entries, channels, height, width].
129 | threshold = getattr(self.params, "sa_segmentation_threshold", 0.5)
130 | _, _, cmap_segmentation, cmap_segmentation_thresholded = sa_segment(
131 | masks, threshold
132 | )
133 |
134 | # combine images in a nice way so we can display all outputs in one grid, output rescaled to be between 0 and 1
135 | out = torch.cat(
136 | [
137 | to_rgb_from_tensor(batch.unsqueeze(1)), # original images
138 | to_rgb_from_tensor(recon_combined.unsqueeze(1)), # reconstructions
139 | cmap_segmentation.unsqueeze(1),
140 | cmap_segmentation_thresholded.unsqueeze(1),
141 | to_rgb_from_tensor(recons * masks + (1 - masks)), # each slot
142 | ],
143 | dim=1,
144 | )
145 |
146 | batch_size, num_slots, C, H, W = recons.shape
147 | images = vutils.make_grid(
148 | out.view(batch_size * out.shape[1], C, H, W).cpu(),
149 | normalize=False,
150 | nrow=out.shape[1],
151 | )
152 | elif self.params.model_type == "slate":
153 | recon, _, _, attns = self.model(batch, self.tau, True)
154 | gen_img = self.model.reconstruct_autoregressive(batch)
155 | vis_recon = visualize(batch, recon, gen_img, attns, N=32)
156 | images = vutils.make_grid(
157 | vis_recon, nrow=self.params.num_slots + 3, pad_value=0.2
158 | )[:, 2:-2, 2:-2]
159 |
160 | return images
161 |
162 | def validation_step(self, batch, batch_idx):
163 | if self.params.model_type == "gnm":
164 | output = self.model.loss_function(
165 | batch[0], self.trainer.global_step, generate_bbox=batch_idx == 0
166 | )
167 | loss, images = gnm_log_validation_outputs(
168 | batch, batch_idx, output, self.trainer.is_global_zero
169 | )
170 | for key, image in images.items():
171 | self.logger.experiment.log(
172 | {f"validation/{key}": [wandb.Image(image)]}, commit=False
173 | )
174 | elif type(batch) == list and self.model.supports_masks:
175 | loss, predicted_mask = self.step(batch[0])
176 | predicted_mask = torch.permute(predicted_mask, [0, 3, 4, 2, 1])
177 | # `predicted_mask` has shape [batch_size, height, width, channels, num_entries]
178 | predicted_mask = torch.squeeze(predicted_mask)
179 | batch_size, height, width, num_entries = predicted_mask.shape
180 | predicted_mask = torch.reshape(
181 | predicted_mask, [batch_size, height * width, num_entries]
182 | )
183 | # `predicted_mask` has shape [batch_size, height * width, num_entries]
184 | # Scale from [0, 1] to [0, 255] to match the true mask.
185 | predicted_mask = (predicted_mask * 255).type(torch.int)
186 | ari = compute_ari(
187 | predicted_mask,
188 | batch[1],
189 | len(batch[0]),
190 | self.params.resolution[0],
191 | self.params.resolution[1],
192 | self.datamodule.max_num_entries,
193 | )
194 | loss["ari"] = ari
195 | else:
196 | if type(batch) == list:
197 | batch = batch[0]
198 | loss, _ = self.step(batch)
199 | return loss
200 |
201 | def validation_epoch_end(self, outputs):
202 | logs = {
203 | "validation/" + key: torch.stack([x[key] for x in outputs]).float().mean()
204 | for key in outputs[0].keys()
205 | }
206 | self.log_dict(logs, sync_dist=True)
207 |
208 | def num_training_steps(self) -> int:
209 | """Total training steps inferred from datamodule and devices."""
210 | # https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-774265729
211 | if self.trainer.max_steps != -1:
212 | return self.trainer.max_steps
213 |
214 | limit_batches = self.trainer.limit_train_batches
215 | batches = len(self.datamodule.train_dataloader())
216 | batches = (
217 | min(batches, limit_batches)
218 | if isinstance(limit_batches, int)
219 | else int(limit_batches * batches)
220 | )
221 |
222 | num_devices = max(1, self.trainer.num_devices)
223 |
224 | effective_accum = self.trainer.accumulate_grad_batches * num_devices
225 | return (batches // effective_accum) * self.trainer.max_epochs
226 |
227 | def configure_optimizers(self):
228 | if self.params.model_type == "slate":
229 | optimizer = optim.Adam(
230 | [
231 | {
232 | "params": (
233 | x[1]
234 | for x in self.model.named_parameters()
235 | if "dvae" in x[0]
236 | ),
237 | "lr": self.params.lr_dvae,
238 | },
239 | {
240 | "params": (
241 | x[1]
242 | for x in self.model.named_parameters()
243 | if "dvae" not in x[0]
244 | ),
245 | "lr": self.params.lr_main,
246 | },
247 | ],
248 | weight_decay=self.params.weight_decay,
249 | )
250 | elif self.params.model_type == "sa":
251 | optimizer = optim.Adam(
252 | self.model.parameters(),
253 | lr=self.params.lr_main,
254 | weight_decay=self.params.weight_decay,
255 | )
256 | elif self.params.model_type == "gnm":
257 | optimizer = optim.RMSprop(
258 | self.model.parameters(),
259 | lr=self.params.lr_main,
260 | weight_decay=self.params.weight_decay,
261 | )
262 |
263 | total_steps = self.num_training_steps()
264 |
265 | scheduler_lambda = None
266 | if self.params.scheduler == "warmup_and_decay":
267 | warmup_steps_pct = self.params.warmup_steps_pct
268 | decay_steps_pct = self.params.decay_steps_pct
269 | scheduler_lambda = partial(
270 | warm_and_decay_lr_scheduler,
271 | warmup_steps_pct=warmup_steps_pct,
272 | decay_steps_pct=decay_steps_pct,
273 | total_steps=total_steps,
274 | gamma=self.params.scheduler_gamma,
275 | )
276 | elif self.params.scheduler == "warmup":
277 | scheduler_lambda = partial(
278 | linear_warmup,
279 | start_value=0.0,
280 | final_value=1.0,
281 | start_step=0,
282 | final_step=self.params.lr_warmup_steps,
283 | )
284 |
285 | if scheduler_lambda is not None:
286 | if self.params.model_type == "slate":
287 | lr_lambda = [lambda o: 1, scheduler_lambda]
288 | elif self.params.model_type in ["sa", "gnm"]:
289 | lr_lambda = scheduler_lambda
290 | scheduler = optim.lr_scheduler.LambdaLR(
291 | optimizer=optimizer, lr_lambda=lr_lambda
292 | )
293 |
294 | if self.params.model_type == "slate" and hasattr(self.params, "patience"):
295 | reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau(
296 | optimizer=optimizer,
297 | mode="min",
298 | factor=0.5,
299 | patience=self.params.patience,
300 | )
301 | return (
302 | [optimizer],
303 | [
304 | {"scheduler": scheduler, "interval": "step",},
305 | {
306 | "scheduler": reduce_on_plateau,
307 | "interval": "epoch",
308 | "monitor": "validation/loss",
309 | },
310 | ],
311 | )
312 |
313 | return (
314 | [optimizer],
315 | [{"scheduler": scheduler, "interval": "step",}],
316 | )
317 | return optimizer
318 |
319 | def predict(
320 | self,
321 | image,
322 | do_transforms=False,
323 | debug=False,
324 | return_pil=False,
325 | return_slots=False,
326 | background_detection="spread_out",
327 | background_metric="area",
328 | ):
329 | """
330 | `background_detection` options:
331 | - "spread_out" detects the background as pixels that appear in multiple
332 | slot masks.
333 | - "concentrated" assumes the background has been assigned
334 | to one slot. The slot with the largest distance between two points
335 | in its mask is assumed to be the background when using
336 | "concentrated."
337 | - When using "both," pixels detected using "spread_out"
338 | and the largest object detected using "concentrated" will be the
339 | background. The largest object detected using "concentrated" can be
340 | the background detected by "spread_out."
341 | `background_metric` is used when `background_detection` is set to "both" or
342 | "concentrated" to determine which object is largest.
343 | - "area" will find the object with the largest area
344 | - "distance" will find the object with the greatest distance between
345 | two points in that object.
346 | `return_slots` returns only the Slot Attention slots if using Slot
347 | Attention.
348 |
349 | """
350 | assert background_detection in ["spread_out", "concentrated", "both"]
351 | if return_pil:
352 | assert debug or (
353 | len(image.shape) == 3 or image.shape[0] == 1
354 | ), "Only one image can be passed when using `return_pil` and `debug=False`."
355 |
356 | if do_transforms:
357 | if getattr(self, "predict_transforms", True):
358 | current_transforms = []
359 | if type(image) is not torch.Tensor:
360 | current_transforms.append(transforms.ToTensor())
361 | if self.params.model_type == "sa":
362 | current_transforms.append(transforms.Lambda(rescale))
363 | current_transforms.append(
364 | transforms.Resize(
365 | self.params.resolution,
366 | interpolation=transforms.InterpolationMode.NEAREST,
367 | )
368 | )
369 | self.predict_transforms = transforms.Compose(current_transforms)
370 | image = self.predict_transforms(image)
371 |
372 | if len(image.shape) == 3:
373 | # Add the batch_size dimension (set to 1) if input is a single image.
374 | image = image.unsqueeze(0)
375 |
376 | if self.params.model_type == "sa":
377 | recon_combined, recons, masks, slots = self.forward(image)
378 | if return_slots:
379 | return slots
380 | threshold = getattr(self.params, "sa_segmentation_threshold", 0.5)
381 | (
382 | segmentation,
383 | segmentation_thresholded,
384 | cmap_segmentation,
385 | cmap_segmentation_thresholded,
386 | ) = sa_segment(masks, threshold)
387 | # `cmap_segmentation` and `cmap_segmentation_thresholded` have shape
388 | # [batch_size, channels=3, height, width].
389 | if background_detection in ["concentrated", "both"]:
390 | if background_detection == "both":
391 | # `segmentation_thresholded` has pixels that are masked by
392 | # many slots set to 0 already.
393 | objects = F.one_hot(segmentation_thresholded.to(torch.int64))
394 | else:
395 | objects = F.one_hot(segmentation.to(torch.int64))
396 | # `objects` has shape [batch_size, height, width, num_objects]
397 | objects = objects.permute([0, 3, 1, 2])
398 | # `objects` has shape [batch_size, num_objects, height, width]
399 | largest_object_idx = get_largest_objects(
400 | objects, metric=background_metric
401 | )
402 | # `largest_object_idx` has shape [batch_size]
403 | largest_object = objects[
404 | range(len(largest_object_idx)), largest_object_idx
405 | ]
406 | # `largest_object` has shape [batch_size, num_objects=1, height, width]
407 | largest_object = largest_object.squeeze(1).to(torch.bool)
408 |
409 | segmentation_background = (
410 | segmentation_thresholded.clone()
411 | if background_detection == "both"
412 | else segmentation.clone()
413 | )
414 | # Set the largest object to be index 0, the background.
415 | segmentation_background[largest_object] = 0
416 | # Recompute the colors now that `largest_object` is the background.
417 | cmap_segmentation_background = cmap_tensor(segmentation_background)
418 | elif background_detection == "spread_out":
419 | segmentation_background = segmentation_thresholded
420 | cmap_segmentation_background = cmap_segmentation_thresholded
421 | if debug:
422 | out = torch.cat(
423 | [
424 | to_rgb_from_tensor(image.unsqueeze(1)), # original images
425 | to_rgb_from_tensor(
426 | recon_combined.unsqueeze(1)
427 | ), # reconstructions
428 | cmap_segmentation.unsqueeze(1),
429 | cmap_segmentation_background.unsqueeze(1),
430 | to_rgb_from_tensor(recons * masks + (1 - masks)), # each slot
431 | ],
432 | dim=1,
433 | )
434 | batch_size, num_slots, C, H, W = recons.shape
435 | images = vutils.make_grid(
436 | out.view(batch_size * out.shape[1], C, H, W).cpu(),
437 | normalize=False,
438 | nrow=out.shape[1],
439 | )
440 | to_return = images
441 | if return_pil:
442 | to_return = transforms.functional.to_pil_image(to_return.squeeze())
443 | return to_return
444 | else:
445 | to_return = segmentation_background
446 | if return_pil:
447 | to_return = transforms.functional.to_pil_image(
448 | cmap_segmentation_background.squeeze()
449 | )
450 | return to_return
451 |
452 | else:
453 | raise ValueError(
454 | "The predict function is only implemented for "
455 | + 'Slot Attention (params.model_type == "sa").'
456 | )
457 |
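458 | 
459 | # Usage sketch (not executable as-is): `predict` is meant to be called on a module
460 | # restored from a trained Slot Attention checkpoint; the path below is hypothetical.
461 | #
462 | #   from PIL import Image
463 | #   method = SlotAttentionMethod.load_from_checkpoint(
464 | #       "checkpoints/sa_clevr6.ckpt", model=model, datamodule=None, params=params
465 | #   )  # `model` / `params` built the same way as in train.py
466 | #   method.eval()
467 | #   seg = method.predict(Image.open("example.png"), do_transforms=True)
468 | #   # `seg` is a per-pixel segmentation map in which index 0 marks the background.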
--------------------------------------------------------------------------------
/object_discovery/params.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 |
3 |
4 | training_params = Namespace(
5 | model_type="sa",
6 | dataset="boxworld",
7 | num_slots=10,
8 | num_iterations=3,
9 | accumulate_grad_batches=1,
10 | data_root="data/box_world_dataset.h5",
11 | accelerator="gpu",
12 | devices=-1,
13 | max_steps=-1,
14 | num_sanity_val_steps=1,
15 | num_workers=4,
16 | is_logger_enabled=True,
17 | gradient_clip_val=0.0,
18 | n_samples=16,
19 | clevrtex_dataset_variant="full",
20 | alternative_crop=True, # Alternative crop for RAVENS dataset
21 | )
22 |
23 | slot_attention_params = Namespace(
24 | lr_main=4e-4,
25 | batch_size=64,
26 | val_batch_size=64,
27 | resolution=(128, 128),
28 | slot_size=64,
29 | max_epochs=1000,
30 | max_steps=500000,
31 | weight_decay=0.0,
32 | mlp_hidden_size=128,
33 | scheduler="warmup_and_decay",
34 | scheduler_gamma=0.5,
35 | warmup_steps_pct=0.02,
36 | decay_steps_pct=0.2,
37 | use_separation_loss="entropy",
38 | separation_tau_start=60_000,
39 | separation_tau_end=65_000,
40 | separation_tau_max_val=0.003,
41 | separation_tau=None,
42 | boxworld_group_objects=True,
43 | use_area_loss=True,
44 | area_tau_start=60_000,
45 | area_tau_end=65_000,
46 | area_tau_max_val=0.006,
47 | area_tau=None,
48 | )
49 |
50 | slate_params = Namespace(
51 | lr_dvae=3e-4,
52 | lr_main=1e-4,
53 | weight_decay=0.0,
54 | batch_size=50,
55 | val_batch_size=50,
56 | max_epochs=1000,
57 | patience=4,
58 | gradient_clip_val=1.0,
59 | resolution=(128, 128),
60 | num_dec_blocks=8,
61 | vocab_size=4096,
62 | d_model=192,
63 | num_heads=8,
64 | dropout=0.1,
65 | slot_size=192,
66 | mlp_hidden_size=192,
67 | tau_start=1.0,
68 | tau_final=0.1,
69 | tau_steps=30000,
70 | scheduler="warmup",
71 | lr_warmup_steps=30000,
72 | hard=False,
73 | )
74 |
75 | gnm_params = Namespace(
76 | std=0.2, # 0.4 on CLEVR, 0.7 on ClevrTex, 0.2/0.3 on RAVENS
77 | z_what_dim=64,
78 | z_bg_dim=10,
79 | lr_main=1e-4,
80 | batch_size=64,
81 | val_batch_size=64,
82 | resolution=(128, 128),
83 | gradient_clip_val=1.0,
84 | max_epochs=1000,
85 | max_steps=5_000_000,
86 | weight_decay=0.0,
87 | scheduler=None,
88 | )
89 |
90 |
91 | def merge_namespaces(one: Namespace, two: Namespace):
92 | return Namespace(**{**vars(one), **vars(two)})
93 |
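94 | 
95 | if __name__ == "__main__":
96 |     # Example of how the training script combines settings: the second namespace
97 |     # overrides any keys it shares with the first.
98 |     merged = merge_namespaces(training_params, slot_attention_params)
99 |     print(merged.model_type, merged.lr_main, merged.resolution)  # sa 0.0004 (128, 128)
100 |     print(vars(merged).keys() >= vars(slot_attention_params).keys())  # True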
--------------------------------------------------------------------------------
/object_discovery/segmentation_metrics.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/deepmind/multi_object_datasets/blob/9c40290c5e385d3efbccb34fb776b57d44721cba/segmentation_metrics.py
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | """Implementation of the adjusted Rand index."""
6 |
7 |
8 | def adjusted_rand_index(true_mask, pred_mask, name="ari_score"):
9 | r"""Computes the adjusted Rand index (ARI), a clustering similarity score.
10 | This implementation ignores points with no cluster label in `true_mask` (i.e.
11 | those points for which `true_mask` is a zero vector). In the context of
12 | segmentation, that means this function can ignore points in an image
13 | corresponding to the background (i.e. not to an object).
14 | Args:
15 | true_mask: `Tensor` of shape [batch_size, n_points, n_true_groups].
16 | The true cluster assignment encoded as one-hot.
17 | pred_mask: `Tensor` of shape [batch_size, n_points, n_pred_groups].
18 | The predicted cluster assignment encoded as categorical probabilities.
19 | This function works on the argmax over axis 2.
20 | name: str. Name of this operation (defaults to "ari_score").
21 | Returns:
22 |       ARI scores as a torch.float32 `Tensor` of shape [batch_size].
23 | Raises:
24 | ValueError: if n_points <= n_true_groups and n_points <= n_pred_groups.
25 | We've chosen not to handle the special cases that can occur when you have
26 | one cluster per datapoint (which would be unusual).
27 | References:
28 | Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
29 | https://link.springer.com/article/10.1007/BF01908075
30 | Wikipedia
31 | https://en.wikipedia.org/wiki/Rand_index
32 | Scikit Learn
33 | http://scikit-learn.org/stable/modules/generated/\
34 | sklearn.metrics.adjusted_rand_score.html
35 | """
36 | _, n_points, n_true_groups = true_mask.shape
37 | n_pred_groups = pred_mask.shape[-1]
38 | if n_points <= n_true_groups and n_points <= n_pred_groups:
39 | # This rules out the n_true_groups == n_pred_groups == n_points
40 | # corner case, and also n_true_groups == n_pred_groups == 0, since
41 | # that would imply n_points == 0 too.
42 | # The sklearn implementation has a corner-case branch which does
43 | # handle this. We chose not to support these cases to avoid counting
44 | # distinct clusters just to check if we have one cluster per datapoint.
45 | raise ValueError(
46 | "adjusted_rand_index requires n_groups < n_points. We don't handle "
47 | "the special cases that can occur when you have one cluster "
48 | "per datapoint."
49 | )
50 |
51 | true_group_ids = torch.argmax(true_mask, -1)
52 | pred_group_ids = torch.argmax(pred_mask, -1)
53 | # We convert true and predicted clusters to one-hot ('oh') representations.
54 | true_mask_oh = true_mask.type(torch.float32) # already one-hot
55 | pred_mask_oh = F.one_hot(pred_group_ids, n_pred_groups).type(torch.float32)
56 |
57 | n_points = torch.sum(true_mask_oh, axis=[1, 2]).type(torch.float32)
58 |
59 | nij = torch.einsum("bji,bjk->bki", pred_mask_oh, true_mask_oh)
60 | a = torch.sum(nij, axis=1)
61 | b = torch.sum(nij, axis=2)
62 |
63 | rindex = torch.sum(nij * (nij - 1), axis=[1, 2])
64 | aindex = torch.sum(a * (a - 1), axis=1)
65 | bindex = torch.sum(b * (b - 1), axis=1)
66 | expected_rindex = aindex * bindex / (n_points * (n_points - 1))
67 | max_rindex = (aindex + bindex) / 2
68 |
69 | denominator = max_rindex - expected_rindex
70 | ari = (rindex - expected_rindex) / denominator
71 | # If a divide by 0 occurs, set the ARI value to 1.
72 | zeros_in_denominator = torch.argwhere(denominator == 0).flatten()
73 | if zeros_in_denominator.nelement() > 0:
74 | ari[zeros_in_denominator] = 1
75 |
76 | # The case where n_true_groups == n_pred_groups == 1 needs to be
77 | # special-cased (to return 1) as the above formula gives a divide-by-zero.
78 | # This might not work when true_mask has values that do not sum to one:
79 | both_single_cluster = torch.logical_and(
80 | _all_equal(true_group_ids), _all_equal(pred_group_ids)
81 | )
82 | return torch.where(both_single_cluster, torch.ones_like(ari), ari)
83 |
84 |
85 | def _all_equal(values):
86 | """Whether values are all equal along the final axis."""
87 | return torch.all(torch.eq(values, values[..., :1]), axis=-1)
88 |
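89 | 
90 | if __name__ == "__main__":
91 |     # Sanity-check sketch: a clustering compared with itself scores 1, while a
92 |     # random assignment scores close to 0. Shapes follow the docstring above.
93 |     torch.manual_seed(0)
94 |     true_ids = torch.randint(0, 3, (2, 100))
95 |     true_mask = F.one_hot(true_ids, num_classes=3).float()
96 |     pred_random = torch.rand(2, 100, 5)
97 |     print(adjusted_rand_index(true_mask, true_mask.clone()))  # ones
98 |     print(adjusted_rand_index(true_mask, pred_random))  # values near 0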
--------------------------------------------------------------------------------
/object_discovery/slate_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from object_discovery.utils import linear
5 | from object_discovery.transformer import PositionalEncoding, TransformerDecoder
6 |
7 | from object_discovery.slot_attention_model import SlotAttention
8 |
9 |
10 | def conv2d(
11 | in_channels,
12 | out_channels,
13 | kernel_size,
14 | stride=1,
15 | padding=0,
16 | dilation=1,
17 | groups=1,
18 | bias=True,
19 | padding_mode="zeros",
20 | weight_init="xavier",
21 | ):
22 |
23 | m = nn.Conv2d(
24 | in_channels,
25 | out_channels,
26 | kernel_size,
27 | stride,
28 | padding,
29 | dilation,
30 | groups,
31 | bias,
32 | padding_mode,
33 | )
34 |
35 | if weight_init == "kaiming":
36 | nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
37 | else:
38 | nn.init.xavier_uniform_(m.weight)
39 |
40 | if bias:
41 | nn.init.zeros_(m.bias)
42 |
43 | return m
44 |
45 |
46 | def gumbel_softmax(logits, tau=1.0, hard=False, dim=-1):
47 |
48 | eps = torch.finfo(logits.dtype).tiny
49 |
50 | gumbels = -(torch.empty_like(logits).exponential_() + eps).log()
51 | gumbels = (logits + gumbels) / tau
52 |
53 | y_soft = F.softmax(gumbels, dim)
54 |
55 | if hard:
56 | index = y_soft.argmax(dim, keepdim=True)
57 | y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
58 | return y_hard - y_soft.detach() + y_soft
59 | else:
60 | return y_soft
61 |
62 |
63 | class Conv2dBlock(nn.Module):
64 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
65 | super().__init__()
66 |
67 | self.m = conv2d(
68 | in_channels,
69 | out_channels,
70 | kernel_size,
71 | stride,
72 | padding,
73 | bias=False,
74 | weight_init="kaiming",
75 | )
76 | self.weight = nn.Parameter(torch.ones(out_channels))
77 | self.bias = nn.Parameter(torch.zeros(out_channels))
78 |
79 | def forward(self, x):
80 | x = self.m(x)
81 | return F.relu(F.group_norm(x, 1, self.weight, self.bias))
82 |
83 |
84 | class dVAE(nn.Module):
85 | def __init__(self, vocab_size, img_channels=3):
86 | super().__init__()
87 |
88 | self.encoder = nn.Sequential(
89 | Conv2dBlock(img_channels, 64, 4, 4),
90 | Conv2dBlock(64, 64, 1, 1),
91 | Conv2dBlock(64, 64, 1, 1),
92 | Conv2dBlock(64, 64, 1, 1),
93 | Conv2dBlock(64, 64, 1, 1),
94 | Conv2dBlock(64, 64, 1, 1),
95 | Conv2dBlock(64, 64, 1, 1),
96 | conv2d(64, vocab_size, 1),
97 | )
98 |
99 | self.decoder = nn.Sequential(
100 | Conv2dBlock(vocab_size, 64, 1),
101 | Conv2dBlock(64, 64, 3, 1, 1),
102 | Conv2dBlock(64, 64, 1, 1),
103 | Conv2dBlock(64, 64, 1, 1),
104 | Conv2dBlock(64, 64 * 2 * 2, 1),
105 | nn.PixelShuffle(2),
106 | Conv2dBlock(64, 64, 3, 1, 1),
107 | Conv2dBlock(64, 64, 1, 1),
108 | Conv2dBlock(64, 64, 1, 1),
109 | Conv2dBlock(64, 64 * 2 * 2, 1),
110 | nn.PixelShuffle(2),
111 | conv2d(64, img_channels, 1),
112 | )
113 |
114 |
115 | class OneHotDictionary(nn.Module):
116 | def __init__(self, vocab_size, emb_size):
117 | super().__init__()
118 | self.dictionary = nn.Embedding(vocab_size, emb_size)
119 |
120 | def forward(self, x):
121 | """
122 | x: B, N, vocab_size
123 | """
124 |
125 | tokens = torch.argmax(x, dim=-1) # batch_size x N
126 | token_embs = self.dictionary(tokens) # batch_size x N x emb_size
127 | return token_embs
128 |
129 |
130 | class SLATE(nn.Module):
131 | def __init__(
132 | self,
133 | num_slots,
134 | vocab_size,
135 | d_model,
136 | resolution,
137 | num_iterations,
138 | slot_size,
139 | mlp_hidden_size,
140 | num_heads,
141 | dropout,
142 | num_dec_blocks,
143 | ):
144 | super().__init__()
145 | self.supports_masks = False
146 |
147 | self.num_slots = num_slots
148 | self.vocab_size = vocab_size
149 | self.d_model = d_model
150 |
151 | self.dvae = dVAE(vocab_size)
152 |
153 | self.positional_encoder = PositionalEncoding(
154 | 1 + (resolution[0] // 4) ** 2, d_model, dropout
155 | )
156 |
157 | self.slot_attn = SlotAttention(
158 | d_model,
159 | num_iterations,
160 | num_slots,
161 | slot_size,
162 | mlp_hidden_size,
163 | do_input_mlp=True,
164 | )
165 |
166 | self.dictionary = OneHotDictionary(vocab_size + 1, d_model)
167 | self.slot_proj = linear(slot_size, d_model, bias=False)
168 |
169 | self.tf_dec = TransformerDecoder(
170 | num_dec_blocks, (resolution[0] // 4) ** 2, d_model, num_heads, dropout
171 | ).train()
172 |
173 | self.out = linear(d_model, vocab_size, bias=False)
174 |
175 | def forward(self, image, tau, hard):
176 | """
177 | image: batch_size x img_channels x H x W
178 | """
179 |
180 | B, C, H, W = image.size()
181 |
182 | # dvae encode
183 | z_logits = F.log_softmax(self.dvae.encoder(image), dim=1)
184 | _, _, H_enc, W_enc = z_logits.size()
185 | z = gumbel_softmax(z_logits, tau, hard, dim=1)
186 |
187 | # dvae recon
188 | recon = self.dvae.decoder(z)
189 | mse = ((image - recon) ** 2).sum() / B
190 |
191 | # hard z
192 | z_hard = gumbel_softmax(z_logits, tau, True, dim=1).detach()
193 |
194 | # target tokens for transformer
195 | z_transformer_target = z_hard.permute(0, 2, 3, 1).flatten(
196 | start_dim=1, end_dim=2
197 | )
198 |
199 | # add BOS token
200 | z_transformer_input = torch.cat(
201 | [torch.zeros_like(z_transformer_target[..., :1]), z_transformer_target],
202 | dim=-1,
203 | )
204 | z_transformer_input = torch.cat(
205 | [torch.zeros_like(z_transformer_input[..., :1, :]), z_transformer_input],
206 | dim=-2,
207 | )
208 | z_transformer_input[:, 0, 0] = 1.0
209 |
210 | # tokens to embeddings
211 | emb_input = self.dictionary(z_transformer_input)
212 | emb_input = self.positional_encoder(emb_input)
213 |
214 | # apply slot attention
215 | slots, attns = self.slot_attn(emb_input[:, 1:], return_attns=True)
216 | attns = attns.transpose(-1, -2)
217 | attns = (
218 | attns.reshape(B, self.num_slots, 1, H_enc, W_enc)
219 | .repeat_interleave(H // H_enc, dim=-2)
220 | .repeat_interleave(W // W_enc, dim=-1)
221 | )
222 | attns = image.unsqueeze(1) * attns + 1.0 - attns
223 | # `attns` has shape [batch_size, num_slots, channels, height, width]
224 |
225 | # apply transformer
226 | slots = self.slot_proj(slots)
227 | decoder_output = self.tf_dec(emb_input[:, :-1], slots)
228 | pred = self.out(decoder_output)
229 | cross_entropy = (
230 | -(z_transformer_target * torch.log_softmax(pred, dim=-1))
231 | .flatten(start_dim=1)
232 | .sum(-1)
233 | .mean()
234 | )
235 |
236 | return (recon.clamp(0.0, 1.0), cross_entropy, mse, attns)
237 |
238 | def loss_function(self, input, tau, hard=False):
239 | _, cross_entropy, mse, _ = self.forward(input, tau, hard)
240 | return {
241 | "loss": cross_entropy + mse,
242 | "cross_entropy": cross_entropy,
243 | "mse": mse,
244 | "tau": torch.tensor(tau),
245 | }
246 |
247 | def reconstruct_autoregressive(self, image, eval=False):
248 | """
249 | image: batch_size x img_channels x H x W
250 | """
251 |
252 | gen_len = (image.size(-1) // 4) ** 2
253 |
254 | B, C, H, W = image.size()
255 |
256 | # dvae encode
257 | z_logits = F.log_softmax(self.dvae.encoder(image), dim=1)
258 | _, _, H_enc, W_enc = z_logits.size()
259 |
260 | # hard z
261 | z_hard = torch.argmax(z_logits, axis=1)
262 | z_hard = (
263 | F.one_hot(z_hard, num_classes=self.vocab_size).permute(0, 3, 1, 2).float()
264 | )
265 | one_hot_tokens = z_hard.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)
266 |
267 | # add BOS token
268 | one_hot_tokens = torch.cat(
269 | [torch.zeros_like(one_hot_tokens[..., :1]), one_hot_tokens], dim=-1
270 | )
271 | one_hot_tokens = torch.cat(
272 | [torch.zeros_like(one_hot_tokens[..., :1, :]), one_hot_tokens], dim=-2
273 | )
274 | one_hot_tokens[:, 0, 0] = 1.0
275 |
276 | # tokens to embeddings
277 | emb_input = self.dictionary(one_hot_tokens)
278 | emb_input = self.positional_encoder(emb_input)
279 |
280 | # slot attention
281 | slots, attns = self.slot_attn(emb_input[:, 1:], return_attns=True)
282 | attns = attns.transpose(-1, -2)
283 | attns = (
284 | attns.reshape(B, self.num_slots, 1, H_enc, W_enc)
285 | .repeat_interleave(H // H_enc, dim=-2)
286 | .repeat_interleave(W // W_enc, dim=-1)
287 | )
288 | attns = image.unsqueeze(1) * attns + (1.0 - attns)
289 | slots = self.slot_proj(slots)
290 |
291 | # generate image tokens auto-regressively
292 | z_gen = z_hard.new_zeros(0)
293 | z_transformer_input = z_hard.new_zeros(B, 1, self.vocab_size + 1)
294 | z_transformer_input[..., 0] = 1.0
295 | for t in range(gen_len):
296 | decoder_output = self.tf_dec(
297 | self.positional_encoder(self.dictionary(z_transformer_input)), slots
298 | )
299 | z_next = F.one_hot(
300 | self.out(decoder_output)[:, -1:].argmax(dim=-1), self.vocab_size
301 | )
302 | z_gen = torch.cat((z_gen, z_next), dim=1)
303 | z_transformer_input = torch.cat(
304 | [
305 | z_transformer_input,
306 | torch.cat([torch.zeros_like(z_next[:, :, :1]), z_next], dim=-1),
307 | ],
308 | dim=1,
309 | )
310 |
311 | z_gen = z_gen.transpose(1, 2).float().reshape(B, -1, H_enc, W_enc)
312 | recon_transformer = self.dvae.decoder(z_gen)
313 |
314 | if eval:
315 | return recon_transformer.clamp(0.0, 1.0), attns
316 |
317 | return recon_transformer.clamp(0.0, 1.0)
318 |
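319 | 
320 | if __name__ == "__main__":
321 |     # Sketch of the Gumbel-softmax relaxation used by the dVAE above (run as a
322 |     # module from the repository root so the absolute imports resolve). The hard
323 |     # sample is (numerically) one-hot in the forward pass but keeps the soft
324 |     # gradient via the straight-through trick; the logits are random placeholders.
325 |     torch.manual_seed(0)
326 |     logits = torch.randn(2, 6)
327 |     soft = gumbel_softmax(logits, tau=1.0, hard=False)
328 |     hard = gumbel_softmax(logits, tau=1.0, hard=True)
329 |     print(soft.sum(dim=-1))  # each row sums to 1
330 |     print(hard.argmax(dim=-1), hard.sum(dim=-1))  # selected tokens, rows sum to 1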
--------------------------------------------------------------------------------
/object_discovery/slot_attention_model.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from object_discovery.utils import (
8 | assert_shape,
9 | build_grid,
10 | conv_transpose_out_shape,
11 | linear,
12 | gru_cell,
13 | )
14 |
15 |
16 | class SlotAttention(nn.Module):
17 | def __init__(
18 | self,
19 | in_features,
20 | num_iterations,
21 | num_slots,
22 | slot_size,
23 | mlp_hidden_size,
24 | epsilon=1e-8,
25 | do_input_mlp=False,
26 | ):
27 | super().__init__()
28 | self.in_features = in_features
29 | self.num_iterations = num_iterations
30 | self.num_slots = num_slots
31 |         self.slot_size = slot_size  # dimensionality of each slot representation
32 | self.mlp_hidden_size = mlp_hidden_size
33 | self.epsilon = epsilon
34 | self.do_input_mlp = do_input_mlp
35 |
36 | if self.do_input_mlp:
37 | self.input_layer_norm = nn.LayerNorm(self.in_features)
38 | self.input_mlp = nn.Sequential(
39 | linear(self.in_features, self.in_features, weight_init="kaiming"),
40 | nn.ReLU(),
41 | linear(self.in_features, self.in_features),
42 | )
43 |
44 | self.norm_inputs = nn.LayerNorm(self.in_features)
45 | self.norm_slots = nn.LayerNorm(self.slot_size)
46 | self.norm_mlp = nn.LayerNorm(self.slot_size)
47 |
48 | # Linear maps for the attention module.
49 | self.project_q = linear(self.slot_size, self.slot_size, bias=False)
50 | self.project_k = linear(self.in_features, self.slot_size, bias=False)
51 | self.project_v = linear(self.in_features, self.slot_size, bias=False)
52 |
53 | # Slot update functions.
54 | self.gru = gru_cell(self.slot_size, self.slot_size)
55 | self.mlp = nn.Sequential(
56 | linear(self.slot_size, self.mlp_hidden_size, weight_init="kaiming"),
57 | nn.ReLU(),
58 | linear(self.mlp_hidden_size, self.slot_size),
59 | )
60 |
61 | # Parameters for Gaussian init (shared by all slots).
62 | self.slot_mu = nn.Parameter(torch.zeros(1, 1, self.slot_size))
63 | self.slot_log_sigma = nn.Parameter(torch.zeros(1, 1, self.slot_size))
64 | nn.init.xavier_uniform_(self.slot_mu, gain=nn.init.calculate_gain("linear"))
65 | nn.init.xavier_uniform_(
66 | self.slot_log_sigma,
67 | gain=nn.init.calculate_gain("linear"),
68 | )
69 |
70 | def step(self, slots, k, v, batch_size, num_inputs):
71 | slots_prev = slots
72 | slots = self.norm_slots(slots)
73 |
74 | # Attention.
75 | q = self.project_q(slots) # Shape: [batch_size, num_slots, slot_size].
76 | assert_shape(q.size(), (batch_size, self.num_slots, self.slot_size))
77 | q *= self.slot_size**-0.5 # Normalization
78 | attn_logits = torch.matmul(k, q.transpose(2, 1))
79 | attn = F.softmax(attn_logits, dim=-1)
80 | # `attn` has shape: [batch_size, num_inputs, num_slots].
81 | assert_shape(attn.size(), (batch_size, num_inputs, self.num_slots))
82 | attn_vis = attn.clone()
83 |
84 | # Weighted mean.
85 | attn = attn + self.epsilon
86 | attn = attn / torch.sum(attn, dim=-2, keepdim=True)
87 | updates = torch.matmul(attn.transpose(-1, -2), v)
88 | # `updates` has shape: [batch_size, num_slots, slot_size].
89 | assert_shape(updates.size(), (batch_size, self.num_slots, self.slot_size))
90 |
91 | # Slot update.
92 | # GRU is expecting inputs of size (N,H) so flatten batch and slots dimension
93 | slots = self.gru(
94 | updates.view(batch_size * self.num_slots, self.slot_size),
95 | slots_prev.view(batch_size * self.num_slots, self.slot_size),
96 | )
97 | slots = slots.view(batch_size, self.num_slots, self.slot_size)
98 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size))
99 | slots = slots + self.mlp(self.norm_mlp(slots))
100 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size))
101 |
102 | return slots, attn_vis
103 |
104 | def forward(self, inputs: torch.Tensor, return_attns=False):
105 | # `inputs` has shape [batch_size, num_inputs, inputs_size].
106 | # `inputs` also has shape [batch_size, enc_height * enc_width, cnn_hidden_size].
107 | batch_size, num_inputs, inputs_size = inputs.shape
108 |
109 | if self.do_input_mlp:
110 | inputs = self.input_mlp(self.input_layer_norm(inputs))
111 |
112 | inputs = self.norm_inputs(inputs) # Apply layer norm to the input.
113 | k = self.project_k(inputs) # Shape: [batch_size, num_inputs, slot_size].
114 | assert_shape(k.size(), (batch_size, num_inputs, self.slot_size))
115 | v = self.project_v(inputs) # Shape: [batch_size, num_inputs, slot_size].
116 | assert_shape(v.size(), (batch_size, num_inputs, self.slot_size))
117 |
118 | # Initialize the slots. Shape: [batch_size, num_slots, slot_size].
119 | slots_init = inputs.new_empty(
120 | batch_size, self.num_slots, self.slot_size
121 | ).normal_()
122 | slots = self.slot_mu + torch.exp(self.slot_log_sigma) * slots_init
123 |
124 | # Multiple rounds of attention.
125 | for _ in range(self.num_iterations):
126 | slots, attn_vis = self.step(slots, k, v, batch_size, num_inputs)
127 | # Detach slots from the current graph and compute one more step.
128 | # This is implicit slot attention from https://cocosci.princeton.edu/papers/chang2022objfixed.pdf
129 | slots, attn_vis = self.step(slots.detach(), k, v, batch_size, num_inputs)
130 | if return_attns:
131 | return slots, attn_vis
132 | else:
133 | return slots
134 |
135 |
136 | class SlotAttentionModel(nn.Module):
137 | def __init__(
138 | self,
139 | resolution: Tuple[int, int],
140 | num_slots: int,
141 | num_iterations,
142 | in_channels: int = 3,
143 | kernel_size: int = 5,
144 | slot_size: int = 64,
145 | mlp_hidden_size: int = 128,
146 | hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
147 | decoder_resolution: Tuple[int, int] = (8, 8),
148 | use_separation_loss=False,
149 | use_area_loss=False,
150 | ):
151 | super().__init__()
152 | self.supports_masks = True
153 |
154 | self.resolution = resolution
155 | self.num_slots = num_slots
156 | self.num_iterations = num_iterations
157 | self.in_channels = in_channels
158 | self.kernel_size = kernel_size
159 | self.slot_size = slot_size
160 | self.mlp_hidden_size = mlp_hidden_size
161 | self.hidden_dims = hidden_dims
162 | self.decoder_resolution = decoder_resolution
163 | self.out_features = self.hidden_dims[-1]
164 | self.use_separation_loss = use_separation_loss
165 | self.use_area_loss = use_area_loss
166 |
167 | modules = []
168 | channels = self.in_channels
169 | # Build Encoder
170 | for h_dim in self.hidden_dims:
171 | modules.append(
172 | nn.Sequential(
173 | nn.Conv2d(
174 | channels,
175 | out_channels=h_dim,
176 | kernel_size=self.kernel_size,
177 | stride=1,
178 | padding=self.kernel_size // 2,
179 | ),
180 | nn.LeakyReLU(),
181 | )
182 | )
183 | channels = h_dim
184 |
185 | self.encoder = nn.Sequential(*modules)
186 | self.encoder_pos_embedding = SoftPositionEmbed(
187 | self.in_channels, self.out_features, resolution
188 | )
189 | self.encoder_out_layer = nn.Sequential(
190 | nn.Linear(self.out_features, self.out_features),
191 | nn.LeakyReLU(),
192 | nn.Linear(self.out_features, self.out_features),
193 | )
194 |
195 | # Build Decoder
196 | modules = []
197 |
198 | in_size = decoder_resolution[0]
199 | out_size = in_size
200 |
201 | for i in range(len(self.hidden_dims) - 1, -1, -1):
202 | modules.append(
203 | nn.Sequential(
204 | nn.ConvTranspose2d(
205 | self.hidden_dims[i],
206 | self.hidden_dims[i - 1],
207 | kernel_size=5,
208 | stride=2,
209 | padding=2,
210 | output_padding=1,
211 | ),
212 | nn.LeakyReLU(),
213 | )
214 | )
215 | out_size = conv_transpose_out_shape(out_size, 2, 2, 5, 1)
216 |
217 | assert_shape(
218 | resolution,
219 | (out_size, out_size),
220 | message="Output shape of decoder did not match input resolution. Try changing `decoder_resolution`.",
221 | )
222 |
223 | # same convolutions
224 | modules.append(
225 | nn.Sequential(
226 | nn.ConvTranspose2d(
227 | self.out_features,
228 | self.out_features,
229 | kernel_size=5,
230 | stride=1,
231 | padding=2,
232 | output_padding=0,
233 | ),
234 | nn.LeakyReLU(),
235 | nn.ConvTranspose2d(
236 | self.out_features,
237 | 4,
238 | kernel_size=3,
239 | stride=1,
240 | padding=1,
241 | output_padding=0,
242 | ),
243 | )
244 | )
245 |
246 | assert_shape(resolution, (out_size, out_size), message="")
247 |
248 | self.decoder = nn.Sequential(*modules)
249 | self.decoder_pos_embedding = SoftPositionEmbed(
250 | self.in_channels, self.out_features, self.decoder_resolution
251 | )
252 |
253 | self.slot_attention = SlotAttention(
254 | in_features=self.out_features,
255 | num_iterations=self.num_iterations,
256 | num_slots=self.num_slots,
257 | slot_size=self.slot_size,
258 | mlp_hidden_size=self.mlp_hidden_size,
259 | )
260 |
261 | def forward(self, x):
262 | batch_size, num_channels, height, width = x.shape
263 | encoder_out = self.encoder(x)
264 | encoder_out = self.encoder_pos_embedding(encoder_out)
265 | # `encoder_out` has shape: [batch_size, filter_size, height, width]
266 | encoder_out = torch.flatten(encoder_out, start_dim=2, end_dim=3)
267 | # `encoder_out` has shape: [batch_size, filter_size, height*width]
268 | encoder_out = encoder_out.permute(0, 2, 1)
269 | encoder_out = self.encoder_out_layer(encoder_out)
270 | # `encoder_out` has shape: [batch_size, height*width, filter_size]
271 |
272 | slots = self.slot_attention(encoder_out)
273 | assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size))
274 | # `slots` has shape: [batch_size, num_slots, slot_size].
275 | batch_size, num_slots, slot_size = slots.shape
276 |
277 | slots = slots.view(batch_size * num_slots, slot_size, 1, 1)
278 | decoder_in = slots.repeat(
279 | 1, 1, self.decoder_resolution[0], self.decoder_resolution[1]
280 | )
281 |
282 | out = self.decoder_pos_embedding(decoder_in)
283 | out = self.decoder(out)
284 | # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
285 | assert_shape(
286 | out.size(), (batch_size * num_slots, num_channels + 1, height, width)
287 | )
288 |
289 |         # Equivalent of the reference implementation's `unstack_and_split`, using `torch.view`.
290 | out = out.view(batch_size, num_slots, num_channels + 1, height, width)
291 | recons = out[:, :, :num_channels, :, :]
292 | masks = out[:, :, -1:, :, :]
293 | # Normalize alpha masks over slots.
294 | masks = F.softmax(masks, dim=1)
295 | recon_combined = torch.sum(recons * masks, dim=1)
296 | return recon_combined, recons, masks, slots
297 |
298 | def loss_function(
299 | self, input, separation_tau=None, area_tau=None, global_step=None
300 | ):
301 | recon_combined, recons, masks, slots = self.forward(input)
302 | # `masks` has shape [batch_size, num_entries, channels, height, width]
303 | mse_loss = F.mse_loss(recon_combined, input)
304 |
305 | loss = mse_loss
306 | to_return = {"loss": loss}
307 | if self.use_area_loss:
308 | max_num_objects = self.num_slots - 1
309 | width, height = masks.shape[-1], masks.shape[-2]
310 | one_object_area = 9 * 9
311 | background_size = width * height - max_num_objects * one_object_area
312 | slot_area = torch.sum(masks.squeeze(), dim=(-1, -2))
313 |
314 | batch_size, num_slots = slot_area.shape[0], slot_area.shape[1]
315 | area_loss = 0
316 | for batch_idx in range(batch_size):
317 | for slot_idx in range(num_slots):
318 | area = slot_area[batch_idx, slot_idx]
319 | area_loss += min(
320 | (area - 2 * one_object_area) ** 2,
321 | max(background_size - area, 0) * (background_size - area),
322 | ) / (2 * one_object_area)**2
323 |
324 | area_loss /= batch_size * num_slots
325 | loss += area_loss * area_tau
326 | to_return["loss"] = loss
327 | to_return["area_loss"] = area_loss
328 | to_return["area_tau"] = torch.tensor(area_tau)
329 | if self.use_separation_loss:
330 | if self.use_separation_loss == "max":
331 | separation_loss = 1 - torch.mean(torch.max(masks, dim=1).values.float())
332 | elif self.use_separation_loss == "entropy":
333 | entropy = torch.special.entr(masks + 1e-8)
334 | separation_loss = torch.mean(entropy.sum(dim=1))
335 |
336 | loss += separation_loss * separation_tau
337 | to_return["loss"] = loss
338 | to_return["mse_loss"] = mse_loss
339 | to_return["separation_loss"] = separation_loss
340 | to_return["separation_tau"] = torch.tensor(separation_tau)
341 | return to_return, masks
342 |
343 |
344 | class SoftPositionEmbed(nn.Module):
345 | def __init__(
346 | self, num_channels: int, hidden_size: int, resolution: Tuple[int, int]
347 | ):
348 | super().__init__()
349 | self.dense = nn.Linear(in_features=num_channels + 1, out_features=hidden_size)
350 | self.register_buffer("grid", build_grid(resolution))
351 |
352 | def forward(self, inputs: torch.Tensor):
353 | # Permute to move num_channels to 1st dimension. PyTorch layers need
354 | # num_channels as 1st dimension, tensorflow needs num_channels last.
355 | emb_proj = self.dense(self.grid).permute(0, 3, 1, 2)
356 | assert_shape(inputs.shape[1:], emb_proj.shape[1:])
357 | return inputs + emb_proj
358 |
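359 | 
360 | if __name__ == "__main__":
361 |     # Shape-check sketch for the SlotAttention module on its own (run as a module
362 |     # from the repository root so the absolute imports resolve). The 32x32 feature
363 |     # grid and the slot count below are arbitrary choices for illustration.
364 |     torch.manual_seed(0)
365 |     sa = SlotAttention(
366 |         in_features=64,
367 |         num_iterations=3,
368 |         num_slots=5,
369 |         slot_size=64,
370 |         mlp_hidden_size=128,
371 |     )
372 |     features = torch.randn(2, 32 * 32, 64)  # [batch_size, num_inputs, in_features]
373 |     slots, attn = sa(features, return_attns=True)
374 |     print(slots.shape)  # torch.Size([2, 5, 64])
375 |     print(attn.shape)  # torch.Size([2, 1024, 5])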
--------------------------------------------------------------------------------
/object_discovery/train.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning.loggers as pl_loggers
2 | from pytorch_lightning import Trainer
3 | from pytorch_lightning.callbacks import LearningRateMonitor
4 |
5 | from object_discovery.data import (
6 | CLEVRDataModule,
7 | Shapes3dDataModule,
8 | RAVENSRobotDataModule,
9 | SketchyDataModule,
10 | ClevrTexDataModule,
11 | BoxWorldDataModule,
12 | TetrominoesDataModule,
13 | )
14 | from object_discovery.method import SlotAttentionMethod
15 | from object_discovery.slot_attention_model import SlotAttentionModel
16 | from object_discovery.slate_model import SLATE
17 | from object_discovery.params import (
18 | merge_namespaces,
19 | training_params,
20 | slot_attention_params,
21 | slate_params,
22 | gnm_params,
23 | )
24 | from object_discovery.utils import ImageLogCallback
25 | from object_discovery.gnm.gnm_model import GNM, hyperparam_anneal
26 | from object_discovery.gnm.config import get_arrow_args
27 |
28 |
29 | def main(params=None):
30 | if params is None:
31 | params = training_params
32 | if params.model_type == "slate":
33 | params = merge_namespaces(params, slate_params)
34 | elif params.model_type == "sa":
35 | params = merge_namespaces(params, slot_attention_params)
36 | elif params.model_type == "gnm":
37 | params = merge_namespaces(params, gnm_params)
38 |
39 | assert params.num_slots > 1, "Must have at least 2 slots."
40 | params.neg_1_to_pos_1_scale = params.model_type == "sa"
41 |
42 | if params.dataset == "clevr":
43 | datamodule = CLEVRDataModule(
44 | data_root=params.data_root,
45 | max_n_objects=params.num_slots - 1,
46 | train_batch_size=params.batch_size,
47 | val_batch_size=params.val_batch_size,
48 | num_workers=params.num_workers,
49 | resolution=params.resolution,
50 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
51 | )
52 | elif params.dataset == "shapes3d":
53 | assert params.resolution == (
54 | 64,
55 | 64,
56 | ), "shapes3d dataset requires 64x64 resolution"
57 | datamodule = Shapes3dDataModule(
58 | data_root=params.data_root,
59 | train_batch_size=params.batch_size,
60 | val_batch_size=params.val_batch_size,
61 | num_workers=params.num_workers,
62 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
63 | )
64 | elif params.dataset == "ravens":
65 | datamodule = RAVENSRobotDataModule(
66 | data_root=params.data_root,
67 | # `max_n_objects` is the number of objects on the table. It does
68 | # not count the background, table, robot, and robot arm.
69 | max_n_objects=params.num_slots - 1
70 | if params.alternative_crop
71 | else params.num_slots - 4,
72 | train_batch_size=params.batch_size,
73 | val_batch_size=params.val_batch_size,
74 | num_workers=params.num_workers,
75 | resolution=params.resolution,
76 | alternative_crop=params.alternative_crop,
77 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
78 | )
79 | elif params.dataset == "sketchy":
80 | assert params.resolution == (
81 | 128,
82 | 128,
83 | ), "sketchy dataset requires 128x128 resolution"
84 | datamodule = SketchyDataModule(
85 | data_root=params.data_root,
86 | train_batch_size=params.batch_size,
87 | val_batch_size=params.val_batch_size,
88 | num_workers=params.num_workers,
89 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
90 | )
91 | elif params.dataset == "clevrtex":
92 | datamodule = ClevrTexDataModule(
93 | data_root=params.data_root,
94 | train_batch_size=params.batch_size,
95 | val_batch_size=params.val_batch_size,
96 | num_workers=params.num_workers,
97 | resolution=params.resolution,
98 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
99 | dataset_variant=params.clevrtex_dataset_variant,
100 | max_n_objects=params.num_slots - 1,
101 | )
102 | elif params.dataset == "boxworld":
103 | max_n_objects = params.num_slots - 1
104 | datamodule = BoxWorldDataModule(
105 | data_root=params.data_root,
106 | max_n_objects=2 * max_n_objects
107 | if params.boxworld_group_objects
108 | else max_n_objects,
109 | train_batch_size=params.batch_size,
110 | val_batch_size=params.val_batch_size,
111 | num_workers=params.num_workers,
112 | resolution=params.resolution,
113 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
114 | )
115 | elif params.dataset == "tetrominoes":
116 | datamodule = TetrominoesDataModule(
117 | data_root=params.data_root,
118 | max_n_objects=params.num_slots - 1,
119 | train_batch_size=params.batch_size,
120 | val_batch_size=params.val_batch_size,
121 | num_workers=params.num_workers,
122 | resolution=params.resolution,
123 | neg_1_to_pos_1_scale=params.neg_1_to_pos_1_scale,
124 | )
125 |
126 | print(
127 |         f"Training set size (images have at most {params.num_slots - 1} objects):",
128 | len(datamodule.train_dataset),
129 | )
130 |
131 | if params.model_type == "sa":
132 | model = SlotAttentionModel(
133 | resolution=params.resolution,
134 | num_slots=params.num_slots,
135 | num_iterations=params.num_iterations,
136 | slot_size=params.slot_size,
137 | use_separation_loss=params.use_separation_loss,
138 | use_area_loss=params.use_area_loss,
139 | )
140 | elif params.model_type == "slate":
141 | model = SLATE(
142 | num_slots=params.num_slots,
143 | vocab_size=params.vocab_size,
144 | d_model=params.d_model,
145 | resolution=params.resolution,
146 | num_iterations=params.num_iterations,
147 | slot_size=params.slot_size,
148 | mlp_hidden_size=params.mlp_hidden_size,
149 | num_heads=params.num_heads,
150 | dropout=params.dropout,
151 | num_dec_blocks=params.num_dec_blocks,
152 | )
153 | elif params.model_type == "gnm":
154 | model_params = get_arrow_args()
155 | model_params = hyperparam_anneal(model_params, 0)
156 | params = merge_namespaces(params, model_params)
157 | params.const.likelihood_sigma = params.std
158 | params.z.z_what_dim = params.z_what_dim
159 | params.z.z_bg_dim = params.z_bg_dim
160 | model = GNM(params)
161 |
162 | method = SlotAttentionMethod(model=model, datamodule=datamodule, params=params)
163 |
164 | logger = pl_loggers.WandbLogger(project="slot-attention-clevr6")
165 |
166 | callbacks = [LearningRateMonitor("step")]
167 | if params.model_type != "gnm":
168 | callbacks.append(ImageLogCallback())
169 |
170 | trainer = Trainer(
171 | logger=logger if params.is_logger_enabled else False,
172 | accelerator=params.accelerator,
173 | num_sanity_val_steps=params.num_sanity_val_steps,
174 | devices=params.devices,
175 | max_epochs=params.max_epochs,
176 | max_steps=params.max_steps,
177 | accumulate_grad_batches=params.accumulate_grad_batches,
178 | gradient_clip_val=params.gradient_clip_val,
179 | log_every_n_steps=50,
180 | callbacks=callbacks if params.is_logger_enabled else [],
181 | )
182 | trainer.fit(method, datamodule=datamodule)
183 |
184 |
185 | if __name__ == "__main__":
186 | main()
187 |
--------------------------------------------------------------------------------
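Note (training entry point): `main()` above falls back to `training_params` when called with no arguments, so a run can also be launched from Python by copying that namespace and overriding a few fields before passing it in. The sketch below is illustrative only; the field names come from train.py above, but the chosen values are assumptions, not the project's defaults.

    # Hedged sketch: start a training run with overridden hyperparameters.
    from copy import deepcopy

    from object_discovery.params import training_params
    from object_discovery.train import main

    params = deepcopy(training_params)
    params.model_type = "sa"   # "sa", "slate", or "gnm"
    params.dataset = "clevr"   # any dataset branch handled in main()
    params.num_slots = 7       # datamodules receive max_n_objects = num_slots - 1
    main(params)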
/object_discovery/transformer.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/singhgautam/slate/blob/6afe75211a79ef7327071ce198f4427928418bf5/transformer.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from object_discovery.utils import linear
6 |
7 |
8 | class MultiHeadAttention(nn.Module):
9 | def __init__(self, d_model, num_heads, dropout=0.0, gain=1.0):
10 | super().__init__()
11 |
12 | assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
13 | self.d_model = d_model
14 | self.num_heads = num_heads
15 |
16 | self.attn_dropout = nn.Dropout(dropout)
17 | self.output_dropout = nn.Dropout(dropout)
18 |
19 | self.proj_q = linear(d_model, d_model, bias=False)
20 | self.proj_k = linear(d_model, d_model, bias=False)
21 | self.proj_v = linear(d_model, d_model, bias=False)
22 | self.proj_o = linear(d_model, d_model, bias=False, gain=gain)
23 |
24 | def forward(self, q, k, v, attn_mask=None):
25 | """
26 | q: batch_size x target_len x d_model
27 | k: batch_size x source_len x d_model
28 | v: batch_size x source_len x d_model
29 | attn_mask: target_len x source_len
30 | return: batch_size x target_len x d_model
31 | """
32 | B, T, _ = q.shape
33 | _, S, _ = k.shape
34 |
35 | q = self.proj_q(q).view(B, T, self.num_heads, -1).transpose(1, 2)
36 | k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2)
37 | v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2)
38 |
39 | q = q * (q.shape[-1] ** (-0.5))
40 | attn = torch.matmul(q, k.transpose(-1, -2))
41 |
42 | if attn_mask is not None:
43 | attn = attn.masked_fill(attn_mask, float("-inf"))
44 |
45 | attn = F.softmax(attn, dim=-1)
46 | attn = self.attn_dropout(attn)
47 |
48 | output = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, -1)
49 | output = self.proj_o(output)
50 | output = self.output_dropout(output)
51 | return output
52 |
53 |
54 | class PositionalEncoding(nn.Module):
55 | def __init__(self, max_len, d_model, dropout=0.1):
56 | super().__init__()
57 | self.dropout = nn.Dropout(dropout)
58 | self.pe = nn.Parameter(torch.zeros(1, max_len, d_model), requires_grad=True)
59 | nn.init.trunc_normal_(self.pe)
60 |
61 | def forward(self, input):
62 | """
63 | input: batch_size x seq_len x d_model
64 | return: batch_size x seq_len x d_model
65 | """
66 | T = input.shape[1]
67 | return self.dropout(input + self.pe[:, :T])
68 |
69 |
70 | class TransformerEncoderBlock(nn.Module):
71 | def __init__(self, d_model, num_heads, dropout=0.0, gain=1.0, is_first=False):
72 | super().__init__()
73 |
74 | self.is_first = is_first
75 |
76 | self.attn_layer_norm = nn.LayerNorm(d_model)
77 | self.attn = MultiHeadAttention(d_model, num_heads, dropout, gain)
78 |
79 | self.ffn_layer_norm = nn.LayerNorm(d_model)
80 | self.ffn = nn.Sequential(
81 | linear(d_model, 4 * d_model, weight_init="kaiming"),
82 | nn.ReLU(),
83 | linear(4 * d_model, d_model, gain=gain),
84 | nn.Dropout(dropout),
85 | )
86 |
87 | def forward(self, input):
88 | """
89 | input: batch_size x source_len x d_model
90 | return: batch_size x source_len x d_model
91 | """
92 | if self.is_first:
93 | input = self.attn_layer_norm(input)
94 | x = self.attn(input, input, input)
95 | input = input + x
96 | else:
97 | x = self.attn_layer_norm(input)
98 | x = self.attn(x, x, x)
99 | input = input + x
100 |
101 | x = self.ffn_layer_norm(input)
102 | x = self.ffn(x)
103 | return input + x
104 |
105 |
106 | class TransformerEncoder(nn.Module):
107 | def __init__(self, num_blocks, d_model, num_heads, dropout=0.0):
108 | super().__init__()
109 |
110 | if num_blocks > 0:
111 | gain = (2 * num_blocks) ** (-0.5)
112 | self.blocks = nn.ModuleList(
113 | [
114 | TransformerEncoderBlock(
115 | d_model, num_heads, dropout, gain, is_first=True
116 | )
117 | ]
118 | + [
119 | TransformerEncoderBlock(
120 | d_model, num_heads, dropout, gain, is_first=False
121 | )
122 | for _ in range(num_blocks - 1)
123 | ]
124 | )
125 | else:
126 | self.blocks = nn.ModuleList()
127 |
128 | self.layer_norm = nn.LayerNorm(d_model)
129 |
130 | def forward(self, input):
131 | """
132 | input: batch_size x source_len x d_model
133 | return: batch_size x source_len x d_model
134 | """
135 | for block in self.blocks:
136 | input = block(input)
137 |
138 | return self.layer_norm(input)
139 |
140 |
141 | class TransformerDecoderBlock(nn.Module):
142 | def __init__(
143 | self, max_len, d_model, num_heads, dropout=0.0, gain=1.0, is_first=False
144 | ):
145 | super().__init__()
146 |
147 | self.is_first = is_first
148 |
149 | self.self_attn_layer_norm = nn.LayerNorm(d_model)
150 | self.self_attn = MultiHeadAttention(d_model, num_heads, dropout, gain)
151 |
152 | mask = torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1)
153 | self.self_attn_mask = nn.Parameter(mask, requires_grad=False)
154 |
155 | self.encoder_decoder_attn_layer_norm = nn.LayerNorm(d_model)
156 | self.encoder_decoder_attn = MultiHeadAttention(
157 | d_model, num_heads, dropout, gain
158 | )
159 |
160 | self.ffn_layer_norm = nn.LayerNorm(d_model)
161 | self.ffn = nn.Sequential(
162 | linear(d_model, 4 * d_model, weight_init="kaiming"),
163 | nn.ReLU(),
164 | linear(4 * d_model, d_model, gain=gain),
165 | nn.Dropout(dropout),
166 | )
167 |
168 | def forward(self, input, encoder_output):
169 | """
170 | input: batch_size x target_len x d_model
171 | encoder_output: batch_size x source_len x d_model
172 | return: batch_size x target_len x d_model
173 | """
174 | T = input.shape[1]
175 |
176 | if self.is_first:
177 | input = self.self_attn_layer_norm(input)
178 | x = self.self_attn(input, input, input, self.self_attn_mask[:T, :T])
179 | input = input + x
180 | else:
181 | x = self.self_attn_layer_norm(input)
182 | x = self.self_attn(x, x, x, self.self_attn_mask[:T, :T])
183 | input = input + x
184 |
185 | x = self.encoder_decoder_attn_layer_norm(input)
186 | x = self.encoder_decoder_attn(x, encoder_output, encoder_output)
187 | input = input + x
188 |
189 | x = self.ffn_layer_norm(input)
190 | x = self.ffn(x)
191 | return input + x
192 |
193 |
194 | class TransformerDecoder(nn.Module):
195 | def __init__(self, num_blocks, max_len, d_model, num_heads, dropout=0.0):
196 | super().__init__()
197 |
198 | if num_blocks > 0:
199 | gain = (3 * num_blocks) ** (-0.5)
200 | self.blocks = nn.ModuleList(
201 | [
202 | TransformerDecoderBlock(
203 | max_len, d_model, num_heads, dropout, gain, is_first=True
204 | )
205 | ]
206 | + [
207 | TransformerDecoderBlock(
208 | max_len, d_model, num_heads, dropout, gain, is_first=False
209 | )
210 | for _ in range(num_blocks - 1)
211 | ]
212 | )
213 | else:
214 | self.blocks = nn.ModuleList()
215 |
216 | self.layer_norm = nn.LayerNorm(d_model)
217 |
218 | def forward(self, input, encoder_output):
219 | """
220 | input: batch_size x target_len x d_model
221 | encoder_output: batch_size x source_len x d_model
222 | return: batch_size x target_len x d_model
223 | """
224 | for block in self.blocks:
225 | input = block(input, encoder_output)
226 |
227 | return self.layer_norm(input)
228 |
--------------------------------------------------------------------------------
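Note (shape conventions): the encoder and decoder above both operate on `batch_size x seq_len x d_model` tensors, with the decoder applying a causal self-attention mask of size `max_len`. A tiny self-contained shape check; the sizes here are arbitrary, and `d_model` only needs to be divisible by `num_heads`:

    import torch
    from object_discovery.transformer import TransformerEncoder, TransformerDecoder

    B, S, T, D = 2, 16, 10, 64  # batch, source_len, target_len, d_model
    enc = TransformerEncoder(num_blocks=2, d_model=D, num_heads=4)
    dec = TransformerDecoder(num_blocks=2, max_len=T, d_model=D, num_heads=4)

    memory = enc(torch.randn(B, S, D))       # batch_size x source_len x d_model
    out = dec(torch.randn(B, T, D), memory)  # batch_size x target_len x d_model
    assert out.shape == (B, T, D)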
/object_discovery/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, Union
2 | import math
3 | import itertools
4 |
5 | from matplotlib.colors import ListedColormap
6 | from matplotlib.cm import get_cmap
7 | from PIL import ImageFont
8 | from scipy.spatial import ConvexHull
9 | from scipy.spatial import QhullError
10 | from scipy.spatial.distance import cdist
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | from pytorch_lightning import Callback
16 | from object_discovery.segmentation_metrics import adjusted_rand_index
17 |
18 | import wandb
19 |
20 |
21 | FNT = ImageFont.truetype("dejavu/DejaVuSansMono.ttf", 7)
22 | CMAPSPEC = ListedColormap(
23 | ["black", "red", "green", "blue", "yellow", "cyan", "magenta"]
24 | + list(
25 | itertools.chain.from_iterable(
26 | itertools.chain.from_iterable(
27 | zip(
28 | [get_cmap("tab20b")(i) for i in range(i, 20, 4)],
29 | [get_cmap("tab20c")(i) for i in range(i, 20, 4)],
30 | )
31 | )
32 | for i in range(4)
33 | )
34 | )
35 | + [get_cmap("Set3")(i) for i in range(12)]
36 | + ["white"],
37 | name="SemSegMap",
38 | )
39 |
40 |
41 | @torch.no_grad()
42 | def cmap_tensor(t):
43 | t_hw = t.cpu().numpy()
44 | o_hwc = CMAPSPEC(t_hw)[..., :3] # drop alpha
45 | o = torch.from_numpy(o_hwc).transpose(-1, -2).transpose(-2, -3).to(t.device)
46 | return o
47 |
48 |
49 | def conv_transpose_out_shape(
50 | in_size, stride, padding, kernel_size, out_padding, dilation=1
51 | ):
52 | return (
53 | (in_size - 1) * stride
54 | - 2 * padding
55 | + dilation * (kernel_size - 1)
56 | + out_padding
57 | + 1
58 | )
59 |
60 |
61 | def assert_shape(
62 | actual: Union[torch.Size, Tuple[int, ...]],
63 | expected: Tuple[int, ...],
64 | message: str = "",
65 | ):
66 | assert (
67 | actual == expected
68 | ), f"Expected shape: {expected} but passed shape: {actual}. {message}"
69 |
70 |
71 | def build_grid(resolution):
72 | ranges = [torch.linspace(0.0, 1.0, steps=res) for res in resolution]
73 | grid = torch.meshgrid(*ranges)
74 | grid = torch.stack(grid, dim=-1)
75 | grid = torch.reshape(grid, [resolution[0], resolution[1], -1])
76 | grid = grid.unsqueeze(0)
77 | return torch.cat([grid, 1.0 - grid], dim=-1)
78 |
79 |
80 | def rescale(x: torch.Tensor) -> torch.Tensor:
81 | return x * 2 - 1
82 |
83 |
84 | def slightly_off_center_crop(image: torch.Tensor) -> torch.Tensor:
85 | crop = ((29, 221), (64, 256)) # Get center crop. (height, width)
86 | # `image` has shape [channels, height, width]
87 | return image[:, crop[0][0] : crop[0][1], crop[1][0] : crop[1][1]]
88 |
89 |
90 | def slightly_off_center_mask_crop(mask: torch.Tensor) -> torch.Tensor:
91 | # `mask` has shape [max_num_entities, height, width, channels]
92 | crop = ((29, 221), (64, 256)) # Get center crop. (height, width)
93 | return mask[:, crop[0][0] : crop[0][1], crop[1][0] : crop[1][1], :]
94 |
95 |
96 | def compact(l: Any) -> Any:
97 | return list(filter(None, l))
98 |
99 |
100 | def first(x):
101 | return next(iter(x))
102 |
103 |
104 | def only(x):
105 | materialized_x = list(x)
106 | assert len(materialized_x) == 1
107 | return materialized_x[0]
108 |
109 |
110 | class ImageLogCallback(Callback):
111 | def log_images(self, trainer, pl_module, stage):
112 | if trainer.logger:
113 | with torch.no_grad():
114 | pl_module.eval()
115 | images = pl_module.sample_images(stage=stage)
116 | trainer.logger.experiment.log(
117 | {stage + "/images": [wandb.Image(images)]}, commit=False
118 | )
119 |
120 | def on_validation_epoch_end(self, trainer, pl_module) -> None:
121 | self.log_images(trainer, pl_module, stage="validation")
122 |
123 | def on_train_epoch_end(self, trainer, pl_module) -> None:
124 | self.log_images(trainer, pl_module, stage="train")
125 |
126 |
127 | def to_rgb_from_tensor(x: torch.Tensor):
128 | return (x * 0.5 + 0.5).clamp(0, 1)
129 |
130 |
131 | def unstack_and_split(x, batch_size, num_channels=3):
132 | """Unstack batch dimension and split into channels and alpha mask."""
133 |     unstacked = torch.reshape(x, [batch_size, -1] + list(x.shape[1:]))
134 | channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
135 | return channels, masks
136 |
137 |
138 | def linear(in_features, out_features, bias=True, weight_init="xavier", gain=1.0):
139 | m = nn.Linear(in_features, out_features, bias)
140 |
141 | if weight_init == "kaiming":
142 | nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
143 | else:
144 | nn.init.xavier_uniform_(m.weight, gain)
145 |
146 | if bias:
147 | nn.init.zeros_(m.bias)
148 |
149 | return m
150 |
151 |
152 | def gru_cell(input_size, hidden_size, bias=True):
153 | m = nn.GRUCell(input_size, hidden_size, bias)
154 |
155 | nn.init.xavier_uniform_(m.weight_ih)
156 | nn.init.orthogonal_(m.weight_hh)
157 |
158 | if bias:
159 | nn.init.zeros_(m.bias_ih)
160 | nn.init.zeros_(m.bias_hh)
161 |
162 | return m
163 |
164 |
165 | def warm_and_decay_lr_scheduler(
166 | step: int, warmup_steps_pct, decay_steps_pct, total_steps, gamma
167 | ):
168 | warmup_steps = warmup_steps_pct * total_steps
169 | decay_steps = decay_steps_pct * total_steps
170 | assert step < total_steps
171 | if step < warmup_steps:
172 | factor = step / warmup_steps
173 | else:
174 | factor = 1
175 | factor *= gamma ** (step / decay_steps)
176 | return factor
177 |
178 |
179 | def cosine_anneal(step: int, start_value, final_value, start_step, final_step):
180 | assert start_value >= final_value
181 | assert start_step <= final_step
182 |
183 | if step < start_step:
184 | value = start_value
185 | elif step >= final_step:
186 | value = final_value
187 | else:
188 | a = 0.5 * (start_value - final_value)
189 | b = 0.5 * (start_value + final_value)
190 | progress = (step - start_step) / (final_step - start_step)
191 | value = a * math.cos(math.pi * progress) + b
192 |
193 | return value
194 |
195 |
196 | def linear_warmup(step, start_value, final_value, start_step, final_step):
197 | assert start_value <= final_value
198 | assert start_step <= final_step
199 |
200 | if step < start_step:
201 | value = start_value
202 | elif step >= final_step:
203 | value = final_value
204 | else:
205 | a = final_value - start_value
206 | b = start_value
207 | progress = (step + 1 - start_step) / (final_step - start_step)
208 | value = a * progress + b
209 |
210 | return value
211 |
212 |
213 | def visualize(image, recon_orig, gen, attns, N=8):
214 | _, _, H, W = image.shape
215 | image = image[:N].expand(-1, 3, H, W).unsqueeze(dim=1)
216 | recon_orig = recon_orig[:N].expand(-1, 3, H, W).unsqueeze(dim=1)
217 | gen = gen[:N].expand(-1, 3, H, W).unsqueeze(dim=1)
218 | attns = attns[:N].expand(-1, -1, 3, H, W)
219 |
220 | return torch.cat((image, recon_orig, gen, attns), dim=1).view(-1, 3, H, W)
221 |
222 |
223 | def compute_ari(prediction, mask, batch_size, height, width, max_num_entities):
224 | # Ground-truth segmentation masks are always returned in the canonical
225 | # [batch_size, max_num_entities, height, width, channels] format. To use these
226 | # as an input for `segmentation_metrics.adjusted_rand_index`, we need them in
227 | # the [batch_size, n_points, n_true_groups] format,
228 | # where n_true_groups == max_num_entities. We implement this reshape below.
229 | # Note that 'oh' denotes 'one-hot'.
230 | desired_shape = [batch_size, height * width, max_num_entities]
231 | true_groups_oh = torch.permute(mask, [0, 2, 3, 4, 1])
232 |     # `true_groups_oh` has shape [batch_size, height, width, channels, max_num_entities]
233 | true_groups_oh = torch.reshape(true_groups_oh, desired_shape)
234 |
235 | # prediction = tf.random_uniform(desired_shape[:-1],
236 | # minval=0, maxval=max_num_entities,
237 | # dtype=tf.int32)
238 | # prediction_oh = F.one_hot(prediction, max_num_entities)
239 |
240 | ari = adjusted_rand_index(true_groups_oh[..., 1:], prediction)
241 | return ari
242 |
243 |
244 | def flatten_all_but_last(tensor, n_dims=1):
245 | shape = list(tensor.shape)
246 | batch_dims = shape[:-n_dims]
247 | flat_tensor = torch.reshape(tensor, [np.prod(batch_dims)] + shape[-n_dims:])
248 |
249 | def unflatten(other_tensor):
250 | other_shape = list(other_tensor.shape)
251 | return torch.reshape(other_tensor, batch_dims + other_shape[1:])
252 |
253 | return flat_tensor, unflatten
254 |
255 |
256 | def sa_segment(masks, threshold, return_cmap=True):
257 |     # `masks` has shape [batch_size, num_slots, channels, height, width].
258 |     # Ensure that each pixel belongs exclusively to the slot with the
259 |     # greatest mask value for that pixel. Mask values are in the range
260 |     # [0, 1] and sum to 1. Add 1 to `segmentation` so the values/colors
261 |     # line up with `segmentation_thresholded`.
262 | segmentation = masks.to(torch.float).argmax(1).squeeze(1) + 1
263 | segmentation = segmentation.to(torch.uint8)
264 | cmap_segmentation = cmap_tensor(segmentation)
265 |
266 | pred_masks = masks.clone()
267 | # For each pixel, if the mask value for that pixel is less than the
268 | # threshold across all slots, then it belongs to an imaginary slot. This
269 | # imaginary slot is concatenated to the masks as the first mask with any
270 | # pixels that belong to it set to 1. So, the argmax will always select the
271 |     # imaginary slot for a pixel if that pixel is in the imaginary slot. The
272 |     # effect is that any pixel that is masked by multiple slots will be
273 |     # grouped together into one mask and removed from its previous mask.
274 | less_than_threshold = pred_masks < threshold
275 | thresholded = torch.all(less_than_threshold, dim=1, keepdim=True).to(torch.float)
276 | pred_masks = torch.cat([thresholded, pred_masks], dim=1)
277 | segmentation_thresholded = pred_masks.to(torch.float).argmax(1).squeeze(1)
278 | cmap_segmentation_thresholded = cmap_tensor(
279 | segmentation_thresholded.to(torch.uint8)
280 | )
281 |
282 | # `cmap_segmentation` and `cmap_segmentation_thresholded` have shape
283 | # [batch_size, channels=3, height, width].
284 | if return_cmap:
285 | return (
286 | segmentation,
287 | segmentation_thresholded,
288 | cmap_segmentation,
289 | cmap_segmentation_thresholded,
290 | )
291 | return segmentation, segmentation_thresholded
292 |
293 |
294 | def PolyArea(x, y):
295 | """Implementation of Shoelace formula for polygon area"""
296 | # From https://stackoverflow.com/a/30408825
297 | return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
298 |
299 |
300 | def get_largest_objects(objects, metric="area"):
301 | # Implementation partly from https://stackoverflow.com/a/60955825
302 | # `objects` has shape [batch_size, num_objects, height, width]
303 | largest_objects = []
304 | for batch_idx in range(objects.shape[0]):
305 | distances = []
306 | for object_idx in range(objects.shape[1]):
307 | ob = objects[batch_idx, object_idx].to(torch.bool)
308 | if not torch.any(ob):
309 | distances.append(0)
310 | continue
311 |
312 | # Find a convex hull in O(N log N)
313 |             points = np.indices(tuple(ob.shape)).transpose((1, 2, 0))[ob]
314 | try:
315 | hull = ConvexHull(points)
316 | except QhullError:
317 | distances.append(0)
318 | continue
319 | # Extract the points forming the hull
320 | hullpoints = points[hull.vertices, :]
321 | if metric == "distance":
322 | # Naive way of finding the best pair in O(H^2) time if H is number
323 | # of points on hull
324 | hdist = cdist(hullpoints, hullpoints, metric="euclidean")
325 | # Get the farthest apart points
326 | bestpair = np.unravel_index(hdist.argmax(), hdist.shape)
327 |
328 | point1 = hullpoints[bestpair[0]]
329 | point2 = hullpoints[bestpair[1]]
330 | score = np.linalg.norm(point2 - point1)
331 | elif metric == "area":
332 | x = hullpoints[:, 0]
333 | y = hullpoints[:, 1]
334 | score = PolyArea(x, y)
335 | distances.append(score)
336 |
337 | largest_objects.append(np.argmax(np.array(distances)))
338 | return np.array(largest_objects)
339 |
--------------------------------------------------------------------------------
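Note (schedules): `warm_and_decay_lr_scheduler`, `cosine_anneal`, and `linear_warmup` above are plain functions of the global step. One way to hook the first into PyTorch is via `LambdaLR`, as sketched below; the percentages, total step count, gamma, and learning rate are placeholders, not the repository's settings.

    from functools import partial

    import torch
    from object_discovery.utils import warm_and_decay_lr_scheduler

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=partial(
            warm_and_decay_lr_scheduler,
            warmup_steps_pct=0.02,  # linear warmup over the first 2% of steps
            decay_steps_pct=0.2,    # multiply by gamma once per 20% of total steps
            total_steps=100_000,
            gamma=0.5,
        ),
    )
    # Call scheduler.step() once per optimizer step to apply the factor.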
/predict.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | import torch
3 | import h5py
4 | from PIL import Image
5 | from glob import glob
6 | from torchvision import transforms
7 | from matplotlib import pyplot as plt
8 | from object_discovery.method import SlotAttentionMethod
9 | from object_discovery.slot_attention_model import SlotAttentionModel
10 | from object_discovery.utils import slightly_off_center_crop
11 |
12 |
13 | def load_model(ckpt_path):
14 | ckpt = torch.load(ckpt_path)
15 | params = Namespace(**ckpt["hyper_parameters"])
16 | sa = SlotAttentionModel(
17 | resolution=params.resolution,
18 | num_slots=params.num_slots,
19 | num_iterations=params.num_iterations,
20 | slot_size=params.slot_size,
21 | )
22 | model = SlotAttentionMethod.load_from_checkpoint(
23 | ckpt_path, model=sa, datamodule=None
24 | )
25 | model.eval()
26 | return model
27 |
28 |
29 | print("Loading model...")
30 | ckpt_path = "epoch=209-step=299460.ckpt"
31 | model = load_model(ckpt_path)
32 |
33 | t = transforms.ToTensor()
34 |
35 | print("Loading images...")
36 | with h5py.File("data/box_world_dataset.h5", "r") as f:
37 | images = f["image"][0:8]
38 |
39 | transformed_images = []
40 | for image in images:
41 | transformed_images.append(t(image))
42 | images = torch.stack(transformed_images)
43 |
44 | print("Predicting...")
45 | visualization = model.predict(  # returns a PIL image since return_pil=True
46 |     images,
47 |     do_transforms=True,
48 |     debug=True,
49 |     return_pil=True,
50 |     background_detection="both",
51 | )
52 | slots = model.predict(images, do_transforms=True, return_slots=True)
53 | slots = slots.squeeze()
54 | # `slots` has shape (num_slots, num_features)
55 | 
56 | print("Saving...")
57 | visualization.save("output.png")
58 |
59 |
60 | # ckpt_path = "sketchy_sa-epoch=59-step=316440-3nofluv3.ckpt"
61 | # model = load_model(ckpt_path)
62 |
63 | # transformed_images = []
64 | # for image_path in glob("data/sketchy_sample/*.png"):
65 | # image = Image.open(image_path)
66 | # image = image.convert("RGB")
67 | # transformed_images.append(transforms.functional.to_tensor(image))
68 |
69 | # images = model.predict(
70 | # torch.stack(transformed_images),
71 | # do_transforms=True,
72 | # debug=True,
73 | # return_pil=True,
74 | # background_detection="both",
75 | # )
76 |
77 | # plt.imshow(images, interpolation="nearest")
78 | # plt.show()
79 |
80 | # ckpt_path = "clevr6_masks-epoch=673-step=275666-r4nbi6n7.ckpt"
81 | # model = load_model(ckpt_path)
82 |
83 | # t = transforms.Compose(
84 | # [
85 | # transforms.ToTensor(),
86 | # transforms.Lambda(slightly_off_center_crop),
87 | # ]
88 | # )
89 |
90 | # with h5py.File("/media/Main/Downloads/clevr_with_masks.h5", "r") as f:
91 | # images = f["image"][0:8]
92 |
93 | # transformed_images = []
94 | # for image in images:
95 | # transformed_images.append(t(image))
96 | # images = torch.stack(transformed_images)
97 |
98 | # images = model.predict(
99 | # images,
100 | # do_transforms=True,
101 | # debug=True,
102 | # return_pil=True,
103 | # background_detection="both",
104 | # )
105 |
106 | # plt.imshow(images, interpolation="nearest")
107 | # plt.show()
108 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "slot-attention-pytorch"
3 | version = "0.1.0"
4 | description = "PyTorch implementation of \"Object-Centric Learning with Slot Attention\""
5 | authors = ["Bryden Fogelman", "Hayden Housen"]
6 |
7 | [tool.poetry.dependencies]
8 | python = ">=3.9,<3.12"
9 | torch = "^1.13.1"
10 | torchvision = "^0.14.1"
11 | wandb = "^0.13.7"
12 | Pillow = "^9.3.0"
13 | pytorch-lightning = "^1.8.6"
14 | numpy = "^1.24.1"
15 | h5py = "^3.7.0"
16 | scipy = "^1.10.0"
17 | matplotlib = "^3.6.2"
18 |
19 | [tool.poetry.dev-dependencies]
20 |
21 | [build-system]
22 | requires = ["poetry-core>=1.0.0"]
23 | build-backend = "poetry.core.masonry.api"
24 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | set -x
4 |
5 | poetry install
6 |
7 | chmod +x data_scripts/download_clevr.sh
8 | ./data_scripts/download_clevr.sh /tmp/CLEVR
9 |
10 | python -m object_discovery.train
11 |
--------------------------------------------------------------------------------
/small_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHousen/object-discovery-pytorch/db22b8c7f230a79361927df0fa91e9fd19da2160/small_logo.png
--------------------------------------------------------------------------------