├── LICENSE
├── README.md
├── accel_conf.json
├── assets
│   ├── celeb_manip_2.gif
│   ├── clevrer_manip_1.gif
│   └── tarffic_manip.gif
├── checkpoints
│   ├── PLACE PRE-TRAINED CHECKPOINTS HERE.txt
│   └── sample_images
│       ├── celeb
│       │   ├── 1.jpg
│       │   ├── 2.jpg
│       │   ├── 3.jpg
│       │   ├── 4.jpg
│       │   ├── 5.jpg
│       │   ├── 6.jpg
│       │   ├── 7.jpg
│       │   └── 8.jpg
│       ├── clevrer
│       │   ├── 1.png
│       │   └── 2.png
│       └── traffic
│           └── 1.png
├── datasets
│   ├── celeba_dataset.py
│   ├── clevrer_ds.py
│   ├── shapes_ds.py
│   └── traffic_ds.py
├── dlp_tutorial.ipynb
├── environment17.yml
├── environment19.yml
├── eval
│   └── eval_model.py
├── eval_celeb.py
├── interactive_demo_dlp.py
├── models.py
├── modules
│   └── modules.py
├── requirements17.txt
├── requirements19.txt
├── train_dlp.py
├── train_dlp_accelerate.py
└── utils
    ├── loss_functions.py
    ├── tps.py
    └── util_func.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # deep-latent-particles-pytorch
2 |
3 | [ICML 2022] Official PyTorch implementation of the paper "Unsupervised Image Representation Learning with Deep Latent Particles"
4 |
5 |
DLPv2 and DDLP (DLP for video generation) have been released: DDLP: Unsupervised Object-Centric Video Prediction with Deep Dynamic Latent Particles
6 |
7 |
8 |
9 | [ICML 2022] Unsupervised Image Representation Learning with Deep Latent Particles
10 |
11 |
12 |
13 | Tal Daniel •
14 | Aviv Tamar
15 |
16 |
17 | Official repository of the paper
18 |
19 | ICML 2022
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 | # Deep Latent Particles
38 |
39 | > **Unsupervised Image Representation Learning with Deep Latent Particles**
40 | > Tal Daniel, Aviv Tamar
41 | >
42 | > **Abstract:** *We propose a new representation of visual data that disentangles object position from appearance.
43 | > Our method, termed Deep Latent Particles (DLP), decomposes the visual input into low-dimensional latent ``particles'',
44 | > where each particle is described by its spatial location and features of its surrounding region.
45 | > To drive learning of such representations, we follow a VAE-based approach and introduce a prior for particle positions
46 | > based on a spatial-softmax architecture, and a modification of the evidence lower bound loss
47 | > inspired by the Chamfer distance between particles. We demonstrate that our DLP representations are useful for
48 | > downstream tasks such as unsupervised keypoint (KP) detection, image manipulation, and video prediction for scenes
49 | > composed of multiple dynamic objects. In addition, we show that our probabilistic interpretation of the problem
50 | > naturally provides uncertainty estimates for particle locations, which can be used for model selection,
51 | > among other tasks.*
52 |
53 | ## Citation
54 |
55 | Daniel, Tal, and Aviv Tamar. "Unsupervised Image Representation Learning with Deep Latent Particles." Proceedings of the 39th International Conference on Machine Learning (ICML) 2022.
56 |
57 |     @InProceedings{pmlr-v162-daniel22a,
58 |       title = {Unsupervised Image Representation Learning with Deep Latent Particles},
59 |       author = {Daniel, Tal and Tamar, Aviv},
60 |       booktitle = {Proceedings of the 39th International Conference on Machine Learning},
61 |       pages = {4644--4665},
62 |       year = {2022},
63 |       volume = {162},
64 |       series = {Proceedings of Machine Learning Research},
65 |       month = {17--23 Jul},
66 |       publisher = {PMLR}
67 |     }
68 |
69 |
70 |
71 |
72 | - [deep-latent-particles-pytorch](#deep-latent-particles-pytorch)
73 | - [Deep Latent Particles](#deep-latent-particles)
74 | * [Citation](#citation)
75 | * [Prerequisites](#prerequisites)
76 | * [Pretrained Models](#pretrained-models)
77 | * [Interactive Demo](#interactive-demo)
78 | * [Datasets](#datasets)
79 | * [Training](#training)
80 | * [Evaluation of Unsupervised Keypoint Regression on CelebA](#evaluation-of-unsupervised-keypoint-regression-on-celeba)
81 | * [Recommended Hyper-parameters](#recommended-hyper-parameters)
82 | * [Repository Organization](#repository-organization)
83 | * [Credits](#credits)
84 |
85 | ## Prerequisites
86 |
87 | * For your convenience, we provide an `environment.yml` file which installs the required packages in a `conda`
88 | environment named `torch`. Alternatively, you can use `pip` to install from `requirements.txt`.
89 | * Use the terminal or an Anaconda Prompt and run the following command `conda env create -f environment.yml`.
90 | * For PyTorch 1.7 + CUDA 10.2: `environment17.yml`, `requirements17.txt`
91 | * For PyTorch 1.9 + CUDA 11.1: `environment19.yml`, `requirements19.txt`
92 |
93 | | Library | Version |
94 | |-------------------|------------------|
95 | | `Python` | `3.7 (Anaconda)` |
96 | | `torch`           | >= `1.7.1`       |
97 | | `torch_geometric` | >= `1.7.1`       |
98 | | `torchvision`     | >= `0.4`         |
99 | | `matplotlib`      | >= `2.2.2`       |
100 | | `numpy`           | >= `1.17`        |
101 | | `py-opencv`       | >= `3.4.2`       |
102 | | `tqdm`            | >= `4.36.1`      |
103 | | `scipy`           | >= `1.3.1`       |
104 | | `scikit-image`    | >= `0.18.1`      |
105 | | `accelerate`      | >= `0.3.0`       |
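
A quick way to verify the setup (a minimal sketch, assuming the `conda`/`pip` environment above is activated; not part of the repository):

```python
# sanity-check the main dependencies and their versions
import torch
import torchvision
import cv2
import accelerate

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
print(f"torchvision {torchvision.__version__}")
print(f"opencv {cv2.__version__}")
print(f"accelerate {accelerate.__version__}")
```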
106 |
107 | ## Pretrained Models
108 |
109 | * We provide pre-trained checkpoints for the 3 datasets we used in the paper.
110 | * All model checkpoints should be placed inside the `/checkpoints` directory.
111 | * The interactive demo will use these checkpoints.
112 |
113 | | Dataset | Filename | Link |
114 | |-------------------|------------------------------------|--------------------------------------------------------------------------------------|
115 | | CelebA (128x128) | `dlp_celeba_gauss_pointnetpp_feat.pth` | [MEGA.co.nz](https://mega.nz/file/ZAkiDSIQ#ndtlzAPwG42TEGZmuADAzR1Wo0AZx2k__qyWUfcOkQc)|
116 | | Traffic (128x128) | `dlp_traffic_gauss_pointnetpp.pth` | [MEGA.co.nz](https://mega.nz/file/9cN0HAYQ#K9AvKsWemA5hvk9WleleautIdQu2Euezf8UOI7aKUtE)|
117 | | CLEVRER (128x128) | `dlp_clevrer_gauss_pointnetpp.pth` | [MEGA.co.nz](https://mega.nz/file/VINjRZCL#rJ25UPXlYJUxWPaP7gDEbxjVZaayey5JB6x9P5Z__CU)|
118 |
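The checkpoints are regular PyTorch weight files. Below is a minimal loading sketch (assuming the `.pth` file stores the model `state_dict`; the exact `KeyPointVAE` constructor arguments are the ones the checkpoint was trained with, see `eval_celeb.py` and the recommended hyper-parameters below; the demo and evaluation scripts handle this for you):

```python
import torch
from models import KeyPointVAE  # DLP model class used throughout this repo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load("./checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth", map_location=device)
# model = KeyPointVAE(...)          # construct with the training hyper-parameters (see eval_celeb.py)
# model.load_state_dict(state_dict)
# model.to(device).eval()
```
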
119 | ## Interactive Demo
120 |
121 | * We designed a simple `matplotlib` interactive GUI to plot and control the particles.
122 | * The demo is **standalone and does not require downloading the original datasets**.
123 | * We provide sample images inside `/checkpoints/sample_images/`, which the demo will use.
124 |
125 | To run the demo (after downloading the checkpoints): `python interactive_demo_dlp.py --help`
126 |
127 | * `-d`: dataset to use: [`celeba`, `traffic`, `clevrer`]
128 | * `-i`: index of the image to use inside `/checkpoints/sample_images/`
129 |
130 | Examples:
131 |
132 | * `python interactive_demo_dlp.py -d celeba -i 2`
133 | * `python interactive_demo_dlp.py -d traffic -i 0`
134 | * `python interactive_demo_dlp.py -d clevrer -i 0`
135 |
136 | You can modify `interactive_demo_dlp.py` to add additional datasets.
137 |
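If you prefer to extract particles programmatically rather than through the GUI, a rough sketch of a forward pass follows; the output dictionary keys are the ones used in `eval/eval_model.py`, and `model` is assumed to be a loaded DLP model (see the loading sketch above):

```python
import torch

# 'model' is assumed to be a loaded DLP model; 'x' is a placeholder input batch in [0, 1]
x = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    model_output = model(x, x_prior=x)           # for single images, the prior input is the image itself
    mu = model_output['mu']                      # posterior particle positions, in model.kp_range
    logvar = model_output['logvar']              # positional uncertainty (can be used to rank particles)
    features = model_output['mu_features']       # per-particle latent visual features
    rec = model_output['rec']                    # image reconstruction
```
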
138 | ## Datasets
139 |
140 | * **CelebA**: we follow [DVE](https://github.com/jamt9000/DVE):
141 | * [Download](https://github.com/jamt9000/DVE/blob/master/misc/datasets/celeba/README.md) the dataset from
142 | this [link](http://www.robots.ox.ac.uk/~vgg/research/DVE/data/datasets/celeba.tar.gz).
143 | * The pre-processing is described in `datasets/celeba_dataset.py`.
144 | * **CLEVRER**: download the training and validation videos from [here](http://clevrer.csail.mit.edu/):
145 | * [Training Videos](http://data.csail.mit.edu/clevrer/videos/train/video_train.zip)
146 | , [Validation Videos](http://data.csail.mit.edu/clevrer/videos/validation/video_validation.zip)
147 |   * Follow the pre-processing
148 | in `datasets/clevrer_ds.py` (`prepare_numpy_file(path_to_img, image_size=128, frameskip=3, start_frame=26)`); see the loading sketch after this list.
149 | * **Traffic**: this is a self-collected dataset; please contact us if you wish to use it.
150 | * **Shapes**: this dataset is generated automatically in each run for simplicity, see `generate_shape_dataset_torch()`
151 | in `datasets/shapes_ds.py`.
152 |
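A short data-loading sketch based on the dataset classes in this repository (the `.npy` path and the frames directory below are placeholders for wherever you store the pre-processed data):

```python
from torch.utils.data import DataLoader
from datasets.clevrer_ds import prepare_numpy_file, CLEVRERDataset
from datasets.shapes_ds import generate_shape_dataset_torch

# CLEVRER: extract frames into a single .npy file once, then load it as a dataset
# prepare_numpy_file('/path/to/clevrer/train/frames/', image_size=128, frameskip=3, start_frame=26)
clevrer_ds = CLEVRERDataset('/path/to/clevrer_img128np_fs3.npy', image_size=128, mode='single', train=True)
clevrer_dl = DataLoader(clevrer_ds, batch_size=32, shuffle=True, pin_memory=True)

# Shapes: generated on the fly, no download required
shapes_ds = generate_shape_dataset_torch(img_size=64, num_images=10_000)
```
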
153 | ## Training
154 |
155 | You can train the model on single-GPU and multi-GPU machines. For multi-GPU training, we use
156 | [HuggingFace Accelerate](https://huggingface.co/docs/accelerate/index): `pip install accelerate`.
157 |
158 | 1. Set visible GPUs under: `os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"` (`NUM_GPUS=4`)
159 | 2. Set "num_processes": NUM_GPUS in `accel_conf.json` (e.g. `"num_processes":4`
160 | if `os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"`).
161 |
162 |
163 | * Single-GPU machines: `python train_dlp.py --help`
164 | * Multi-GPU machines: `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --help`
165 |
166 | You should run the `train_dlp.py` or `train_dlp_accelerate.py` files with the following arguments:
167 |
168 | | Argument | Description | Legal Values |
169 | |-------------------------|----------------------------------------------------------------------------------------------------------|----------------------------------------------|
170 | | -h, --help | shows arguments description | |
171 | | -d, --dataset | dataset to train on | str: 'celeba', 'traffic', 'clevrer', 'shapes' |
172 | | -o, --override | if specified, the code will override the default hyper-parameters with the ones specified with `argparse` (command line) | bool: default=False |
173 | | -l, --lr | learning rate | float: default=2e-4 |
174 | | -b, --batch_size | batch size | int: default=32 |
175 | | -n, --num_epochs | total number of epochs to run | int: default=100 |
176 | | -e, --eval_freq | evaluation epoch frequency | int: default=2 |
177 | | -s, --sigma | the prior std of the keypoints | float: default=0.1 |
178 | | -p, --prefix | string prefix for logging | str: default="" |
179 | | -r, --beta_rec | beta coefficient for the reconstruction loss | float: default=1.0 |
180 | | -k, --beta_kl | beta coefficient for the kl divergence | float: default=1.0 |
181 | | -c, --kl_balance | coefficient for the balance between the ChamferKL (for the KP) and the standard KL | float: default=0.001 |
182 | | -v, --rec_loss_function | type of reconstruction loss: 'mse', 'vgg' | str: default="mse" |
183 | | --n_kp_enc | number of posterior kp to be learned | int: default=30 |
184 | | --n_kp_enc_prior | number of kp to filter from the set of prior kp | int: default=50 |
185 | | --dec_bone | decoder backbone: 'gauss_pointnetpp_feat' (Masked Model), 'gauss_pointnetpp' (Object Model) | str: default="gauss_pointnetpp" |
186 | | --patch_size | patch size for the prior KP proposals network (not to be confused with the glimpse size) | int: default=8 |
187 | | --learned_feature_dim | the latent visual features dimensions extracted from glimpses | int: default=10 |
188 | | --use_object_enc | set True to use a separate encoder to encode visual features of glimpses | bool: default=False |
189 | | --use_object_dec | set True to use a separate decoder to decode glimpses (Object Model) | bool: default=False |
190 | | --warmup_epoch | number of epochs where only the object decoder is trained | int: default=2 |
191 | | --anchor_s | defines the glimpse size as a ratio of image_size | float: default=0.25 |
192 | | --exclusive_patches | set True to enable non-overlapping object patches | bool: default=False |
193 |
194 | Examples:
195 |
196 | * Single-GPU:
197 |
198 | `python train_dlp.py --dataset shapes`
199 |
200 | `python train_dlp.py --dataset celeba`
201 |
202 | `python train_dlp.py --dataset clevrer -o --use_object_enc --use_object_dec --warmup_epoch 1 --beta_kl 40.0 --rec_loss_function vgg --learned_feature_dim 6`
203 |
204 | * Multi-GPU:
205 |
206 | `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --dataset celeba`
207 |
208 | `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --dataset clevrer -o --use_object_enc --use_object_dec --warmup_epoch 1 --beta_kl 40.0 --rec_loss_function vgg --learned_feature_dim 6`
209 |
210 | * Note: if you want to launch multiple multi-GPU runs, each run should use a different `accelerate` config file
211 | (e.g., `accel_conf.json`, `accel_conf_2.json`, etc.). The only difference between the files should be
212 | the `main_process_port` field (set a different free port for each additional config file); a sketch for generating such a copy follows.
213 |
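A minimal sketch for generating such a copy from the provided `accel_conf.json` (the port value below is only an illustration; use any free port):

```python
import json

# read the provided config and write a second copy with a different main_process_port
with open('accel_conf.json', 'r') as f:
    conf = json.load(f)
conf['main_process_port'] = 29501  # the provided file leaves this field as null
with open('accel_conf_2.json', 'w') as f:
    json.dump(conf, f, indent=4)
```
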
214 | ## Evaluation of Unsupervised Keypoint Regression on CelebA
215 |
216 | Linear regression of supervised keypoints on the MAFL dataset is performed during training on the CelebA dataset.
217 |
218 | To evaluate a saved checkpoint of the model: modify the hyper-parameters and paths in `eval_celeb.py`,
219 | and then use `python eval_celeb.py` to calculate and print the normalized error with respect to the inter-ocular distance.
220 |
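For reference, the reported metric is the normalized landmark error commonly used for MAFL/CelebA evaluation (following, e.g., Jakab et al.): the mean Euclidean distance between predicted and ground-truth landmarks, divided by the inter-ocular distance of the corresponding face,

$$
\mathrm{err} = \frac{1}{N}\sum_{i=1}^{N} \frac{\lVert \hat{p}_i - p_i \rVert_2}{d_i^{\mathrm{IOD}}},
$$

where $\hat{p}_i$ and $p_i$ are the predicted and ground-truth landmark positions, $d_i^{\mathrm{IOD}}$ is the inter-ocular distance of the corresponding face, and $N$ is the number of evaluated landmarks.
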
221 | ## Recommended Hyper-parameters
222 |
223 | | Dataset             | `dec_bone` (model type) | `n_kp_enc` | `n_kp_prior` | `rec_loss_func` | `beta_kl` | `kl_balance` | `patch_size` | `anchor_s` | `learned_feature_dim` |
224 | |---------------------|-------------------------|------------|--------------|-----------------|-----------|--------------|--------------|------------|-----------------------|
225 | | CelebA (`celeba`)   | `gauss_pointnetpp_feat` | 30         | 50           | `vgg`           | 40        | 0.001        | 8            | 0.125      | 10                    |
226 | | Traffic (`traffic`) | `gauss_pointnetpp`      | 15         | 20           | `vgg`           | 30        | 0.001        | 16           | 0.25       | 20                    |
227 | | CLEVRER (`clevrer`) | `gauss_pointnetpp`      | 10         | 20           | `vgg`           | 40        | 0.001        | 16           | 0.25       | 5                     |
228 | | Shapes (`shapes`)   | `gauss_pointnetpp`      | 10         | 15           | `mse`           | 0.1       | 0.001        | 8            | 0.25       | 6                     |
229 |
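The same settings restated as a Python dict, in case you want to set them programmatically (values copied from the table above; key names follow the table headers):

```python
# recommended hyper-parameters per dataset (copied from the table above)
RECOMMENDED_HPARAMS = {
    'celeba':  {'dec_bone': 'gauss_pointnetpp_feat', 'n_kp_enc': 30, 'n_kp_prior': 50, 'rec_loss_func': 'vgg',
                'beta_kl': 40, 'kl_balance': 0.001, 'patch_size': 8, 'anchor_s': 0.125, 'learned_feature_dim': 10},
    'traffic': {'dec_bone': 'gauss_pointnetpp', 'n_kp_enc': 15, 'n_kp_prior': 20, 'rec_loss_func': 'vgg',
                'beta_kl': 30, 'kl_balance': 0.001, 'patch_size': 16, 'anchor_s': 0.25, 'learned_feature_dim': 20},
    'clevrer': {'dec_bone': 'gauss_pointnetpp', 'n_kp_enc': 10, 'n_kp_prior': 20, 'rec_loss_func': 'vgg',
                'beta_kl': 40, 'kl_balance': 0.001, 'patch_size': 16, 'anchor_s': 0.25, 'learned_feature_dim': 5},
    'shapes':  {'dec_bone': 'gauss_pointnetpp', 'n_kp_enc': 10, 'n_kp_prior': 15, 'rec_loss_func': 'mse',
                'beta_kl': 0.1, 'kl_balance': 0.001, 'patch_size': 8, 'anchor_s': 0.25, 'learned_feature_dim': 6},
}
```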
230 |
231 | ## Repository Organization
232 |
233 | | File name | Content |
234 | |----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
235 | | `/checkpoints` | directory for pre-trained checkpoints and sample images for the interactive demo |
236 | | `/datasets` | directory containing data loading classes for the various datasets |
237 | | `/eval/eval_model.py` | evaluation functions such as evaluating the ELBO |
238 | | `/modules/modules.py` | basic neural network blocks used to implement the DLP model |
239 | | `/utils/tps.py` | implementation of the TPS augmentation used for training on CelebA |
240 | | `/utils/loss_functions.py` | loss functions used to optimize the model such as Chamfer-KL and perceptual (VGG) loss |
241 | | `/utils/util_func.py` | utility functions such as logging and plotting functions |
242 | | `eval_celeb.py`            | functions to evaluate the normalized error of keypoint linear regression with respect to the inter-ocular distance for the MAFL/CelebA dataset |
243 | | `models.py` | implementation of the DLP model |
244 | | `train_dlp.py` | training function of DLP for single-GPU machines |
245 | | `train_dlp_accelerate.py` | training function of DLP for multi-GPU machines |
246 | | `dlp_tutorial.ipynb` | Jupyter Notebook tutorial for explaining and training DLP on the random shapes dataset |
247 | | `interactive_demo_dlp.py` | `matplotlib`-based interactive demo to plot and interact with learned particles |
248 | | `environment17/19.yml` | Anaconda environment file to install the required dependencies |
249 | | `requirements17/19.txt` | requirements file for `pip` |
250 | | `accel_conf.json` | configuration file for `accelerate` to run training on multiple GPUs |
251 |
252 | ## Credits
253 |
254 | * CelebA pre-processing is performed as in [DVE](https://github.com/jamt9000/DVE).
255 | * Normalized inter-ocular distance: [KeyNet (Jakab et al.)](https://github.com/tomasjakab/imm).
256 |
--------------------------------------------------------------------------------
/accel_conf.json:
--------------------------------------------------------------------------------
1 | {
2 | "compute_environment": "LOCAL_MACHINE",
3 | "distributed_type": "MULTI_GPU",
4 | "fp16": false,
5 | "machine_rank": 0,
6 | "main_process_ip": null,
7 | "main_process_port": null,
8 | "main_training_function": "main",
9 | "num_machines": 1,
10 | "num_processes": 4
11 | }
12 |
--------------------------------------------------------------------------------
/assets/celeb_manip_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/celeb_manip_2.gif
--------------------------------------------------------------------------------
/assets/clevrer_manip_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/clevrer_manip_1.gif
--------------------------------------------------------------------------------
/assets/tarffic_manip.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/tarffic_manip.gif
--------------------------------------------------------------------------------
/checkpoints/PLACE PRE-TRAINED CHECKPOINTS HERE.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/PLACE PRE-TRAINED CHECKPOINTS HERE.txt
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/1.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/2.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/3.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/4.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/5.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/6.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/7.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/celeb/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/8.jpg
--------------------------------------------------------------------------------
/checkpoints/sample_images/clevrer/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/clevrer/1.png
--------------------------------------------------------------------------------
/checkpoints/sample_images/clevrer/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/clevrer/2.png
--------------------------------------------------------------------------------
/checkpoints/sample_images/traffic/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/traffic/1.png
--------------------------------------------------------------------------------
/datasets/clevrer_ds.py:
--------------------------------------------------------------------------------
1 | """
2 | functions and classes to process the CLEVRER dataset
3 | """
4 |
5 | import os
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import cv2
9 | import utils.tps as tps
10 |
11 | import torch
12 | from PIL import Image
13 | from tqdm import tqdm
14 | from torch.utils.data import Dataset, DataLoader
15 | import torchvision.transforms as transforms
16 |
17 |
18 | def list_images_in_dir(path):
19 | valid_images = [".jpg", ".gif", ".png"]
20 | img_list = []
21 | for f in os.listdir(path):
22 | ext = os.path.splitext(f)[1]
23 | if ext.lower() not in valid_images:
24 | continue
25 | img_list.append(os.path.join(path, f))
26 | return img_list
27 |
28 |
29 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1, start_frame=1):
30 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/'
31 | img_list = list_images_in_dir(path_to_image_dir)
32 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('_')[-1].split('.')[0]))
33 | img_list = [img_list[i] for i in range(len(img_list)) if
34 | abs(int(img_list[i].split('/')[-1].split('_')[-1].split('.')[0])) % 1000 > start_frame]
35 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}')
36 | img_np_list = []
37 | for i in tqdm(range(len(img_list))):
38 | if i % frameskip != 0:
39 | continue
40 | img = Image.open(img_list[i])
41 | img = img.convert('RGB')
42 | # img = img.crop((60, 0, 480, 420))
43 | img = img.resize((image_size, image_size), Image.BICUBIC)
44 | img_np = np.asarray(img)
45 | img_np_list.append(img_np)
46 | img_np_array = np.stack(img_np_list, axis=0)
47 | print(f'img_np_array: {img_np_array.shape}')
48 | save_path = os.path.join(path_to_image_dir, f'clevrer_img{image_size}np_fs{frameskip}.npy')
49 | np.save(save_path, img_np_array)
50 |     print(f'file saved @ {save_path}')
51 |
52 |
53 | class CLEVRERDataset(Dataset):
54 | def __init__(self, path_to_npy, image_size=128, transform=None, mode='single', train=True, horizon=3,
55 | frames_per_video=34, video_as_index=False):
56 | super(CLEVRERDataset, self).__init__()
57 | assert mode in ['single', 'frames', 'tps', 'horizon']
58 | self.mode = mode
59 | self.frames_per_video = frames_per_video
60 | self.horizon = horizon if (horizon > 0 and self.mode == 'horizon') else self.frames_per_video
61 | self.train_mode = train
62 | if train:
63 | print(f'clevrer dataset mode: {self.mode}')
64 | if self.mode == 'horizon':
65 | print(f'time steps horizon: {self.horizon}')
66 | if self.mode == 'tps':
67 | self.warper = tps.Warper(H=image_size, W=image_size, warpsd_all=0.00001,
68 | warpsd_subset=0.001, transsd=0.1, scalesd=0.1,
69 | rotsd=2, im1_multiplier=0.1, im1_multiplier_aff=0.1)
70 | else:
71 | self.warper = None
72 | data = np.load(path_to_npy)
73 | # train_size = int(0.9 * data.shape[0])
74 | # valid_size = data.shape[0] - train_size
75 | if train:
76 | self.data = data
77 | # self.data = data[:self.frames_per_video * 200]
78 | print(f'loaded data with shape: {self.data.shape}, size: {self.data.shape[0]}')
79 | else:
80 | self.data = data[:5000]
81 | self.image_size = image_size
82 | self.num_videos = self.data.shape[0] // self.frames_per_video
83 | self.video_as_index = video_as_index
84 | if transform is None:
85 | self.input_transform = transforms.Compose([
86 | transforms.ToPILImage(),
87 | transforms.Resize(image_size),
88 | transforms.ToTensor()
89 | ])
90 | else:
91 | self.input_transform = transform
92 |
93 | def __getitem__(self, index):
94 | if not self.video_as_index:
95 | video_num = int(index / self.frames_per_video)
96 | video_start_idx = video_num * self.frames_per_video
97 | curr_idx = index % self.frames_per_video
98 | max_idx = min(video_start_idx + self.frames_per_video - 1, self.data.shape[0] - 1)
99 | global_idx = video_start_idx + curr_idx
100 | if self.mode == 'single':
101 | return self.input_transform(self.data[index])
102 | elif self.mode == 'frames':
103 | min_idx = min(video_start_idx, index - 1)
104 | if min_idx == video_start_idx:
105 | im1 = self.input_transform(self.data[min_idx + 1])
106 | im2 = self.input_transform(self.data[min_idx])
107 | else:
108 | im1 = self.input_transform(self.data[min_idx])
109 | im2 = self.input_transform(self.data[min_idx - 1])
110 | return im1, im2
111 | elif self.mode == 'horizon':
112 | images = []
113 | length = max_idx
114 | if (index + self.horizon) >= length:
115 | slack = index + self.horizon - length
116 | index = index - slack
117 | for i in range(self.horizon):
118 | t = index + i
119 | images.append(self.input_transform(self.data[t]))
120 | images = torch.stack(images, dim=0)
121 | return images
122 | elif self.mode == 'tps':
123 | im = self.input_transform(self.data[index])
124 | im = im * 255
125 | im2, im1, _, _, _, _ = self.warper(im)
126 | return im1 / 255, im2 / 255
127 | else:
128 | raise NotImplementedError
129 | else:
130 | video_num = index
131 | video_start_idx = video_num * self.frames_per_video
132 | max_idx = video_start_idx + self.frames_per_video - 1
133 | images = []
134 | length = max_idx
135 | frame_idx = video_start_idx
136 | actual_horizon = self.frames_per_video if ((frame_idx + self.horizon) >= length) else self.horizon
137 | for i in range(actual_horizon):
138 | t = frame_idx + i
139 | images.append(self.input_transform(self.data[t]))
140 | images = torch.stack(images, dim=0)
141 | return images
142 |
143 | def __len__(self):
144 | if not self.video_as_index:
145 | return self.data.shape[0]
146 | else:
147 | return self.num_videos
148 |
149 |
150 | if __name__ == '__main__':
151 | path_to_img = '/media/newhd/data/clevrer/train/frames/'
152 | # prepare_numpy_file(path_to_img, image_size=128, frameskip=3, start_frame=26)
153 | test_epochs = False
154 | # load data
155 | path_to_npy = '/media/newhd/data/clevrer/valid/clevrer_img128np_fs3_valid.npy'
156 | mode = 'frames'
157 | horizon = 4
158 | train = True
159 | clevrer_ds = CLEVRERDataset(path_to_npy, mode=mode, train=train, horizon=horizon)
160 | clevrer_dl = DataLoader(clevrer_ds, shuffle=True, pin_memory=True, batch_size=5)
161 | batch = next(iter(clevrer_dl))
162 | if mode == 'single':
163 | im1 = batch[0]
164 | elif mode == 'frames' or mode == 'tps':
165 | im1 = batch[0][0]
166 | im2 = batch[1][0]
167 |
168 | if mode == 'single':
169 | print(im1.shape)
170 | img_np = im1.permute(1, 2, 0).data.cpu().numpy()
171 | fig = plt.figure(figsize=(5, 5))
172 | ax = fig.add_subplot(111)
173 | ax.imshow(img_np)
174 | elif mode == 'horizon':
175 | print(f'batch shape: {batch.shape}')
176 | images = batch[0]
177 | print(f'images shape: {images.shape}')
178 | fig = plt.figure(figsize=(8, 8))
179 | for i in range(images.shape[0]):
180 | ax = fig.add_subplot(1, horizon, i + 1)
181 | im = images[i]
182 | im_np = im.permute(1, 2, 0).data.cpu().numpy()
183 | ax.imshow(im_np)
184 | ax.set_title(f'im {i + 1}')
185 | else:
186 | print(f'im1: {im1.shape}, im2: {im2.shape}')
187 | im1_np = im1.permute(1, 2, 0).data.cpu().numpy()
188 | im2_np = im2.permute(1, 2, 0).data.cpu().numpy()
189 | fig = plt.figure(figsize=(8, 8))
190 | ax = fig.add_subplot(1, 2, 1)
191 | ax.imshow(im1_np)
192 | ax.set_title('im1')
193 |
194 | ax = fig.add_subplot(1, 2, 2)
195 | ax.imshow(im2_np)
196 | ax.set_title('im2 [t-1] or [tps]')
197 | plt.show()
198 | if test_epochs:
199 | from tqdm import tqdm
200 |
201 | pbar = tqdm(iterable=clevrer_dl)
202 | for batch in pbar:
203 | pass
204 | pbar.close()
205 |
--------------------------------------------------------------------------------
/datasets/shapes_ds.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple Random Colored Shapes Dataset
3 | """
4 | # imports
5 | import numpy as np
6 | from skimage.draw import random_shapes
7 | from tqdm.auto import tqdm
8 | import torch
9 |
10 |
11 | def generate_shape_dataset(img_size=64, min_shapes=2, max_shapes=5, min_size=10, max_size=12, allow_overlap=False,
12 | num_images=10_000):
13 | images = []
14 | for i in tqdm(range(num_images)):
15 | img, _ = random_shapes((img_size, img_size), min_shapes=min_shapes, max_shapes=max_shapes,
16 | intensity_range=((0, 200),), min_size=min_size, max_size=max_size,
17 | allow_overlap=allow_overlap, num_trials=100)
18 | img[:, :, 0][img[:, :, 0] == 255] = 0
19 | img[:, :, 1][img[:, :, 1] == 255] = 255
20 | img[:, :, 2][img[:, :, 2] == 255] = 255
21 | img = img / 255.0
22 | images.append(img)
23 |     images = np.stack(images, axis=0)  # [num_images, H, W, 3]
24 | return images
25 |
26 |
27 | def generate_shape_dataset_torch(img_size=64, min_shapes=2, max_shapes=5, min_size=11, max_size=13, allow_overlap=False,
28 | num_images=10_000):
29 | images = generate_shape_dataset(img_size=img_size, min_shapes=min_shapes, max_shapes=max_shapes, min_size=min_size,
30 | max_size=max_size,
31 | allow_overlap=allow_overlap, num_images=num_images)
32 | # create torch dataset
33 | img_data_torch = images.transpose(0, 3, 1, 2) # [num_images, 3, H, W]
34 | img_ds = torch.utils.data.TensorDataset(torch.tensor(img_data_torch, dtype=torch.float))
35 | return img_ds
36 |
--------------------------------------------------------------------------------
/datasets/traffic_ds.py:
--------------------------------------------------------------------------------
1 | """
2 | functions and classes to process the Traffic dataset
3 | """
4 |
5 | import os
6 | import numpy as np
7 | import cv2
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | import torchvision.transforms as transforms
11 | import torchvision.utils as vutils
12 | import matplotlib.pyplot as plt
13 | from PIL import Image
14 | from tqdm.auto import tqdm  # used by prepare_numpy_file below
15 | import utils.tps as tps
16 |
17 |
18 | def list_images_in_dir(path):
19 | valid_images = [".jpg", ".gif", ".png"]
20 | img_list = []
21 | for f in os.listdir(path):
22 | ext = os.path.splitext(f)[1]
23 | if ext.lower() not in valid_images:
24 | continue
25 | img_list.append(os.path.join(path, f))
26 | return img_list
27 |
28 |
29 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1):
30 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/'
31 | img_list = list_images_in_dir(path_to_image_dir)
32 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('.')[0]))
33 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}')
34 | img_np_list = []
35 | for i in tqdm(range(len(img_list))):
36 | if i % frameskip != 0:
37 | continue
38 | img = Image.open(img_list[i])
39 | img = img.convert('RGB')
40 | img = img.crop((60, 0, 480, 420))
41 | img = img.resize((image_size, image_size), Image.BICUBIC)
42 | img_np = np.asarray(img)
43 | img_np_list.append(img_np)
44 | img_np_array = np.stack(img_np_list, axis=0)
45 | print(f'img_np_array: {img_np_array.shape}')
46 | save_path = os.path.join(path_to_image_dir, f'img{image_size}np_fs{frameskip}.npy')
47 | np.save(save_path, img_np_array)
48 |     print(f'file saved @ {save_path}')
49 |
50 |
51 | class TrafficDataset(Dataset):
52 | def __init__(self, path_to_npy, image_size=128, transform=None, mode='single', train=True, horizon=3):
53 | super(TrafficDataset, self).__init__()
54 | assert mode in ['single', 'frames', 'tps', 'horizon']
55 | self.mode = mode
56 | self.horizon = horizon
57 | if train:
58 | print(f'traffic dataset mode: {self.mode}')
59 | if self.mode == 'horizon':
60 | print(f'time steps horizon: {self.horizon}')
61 | if self.mode == 'tps':
62 | self.warper = tps.Warper(H=image_size, W=image_size, warpsd_all=0.00001,
63 | warpsd_subset=0.001, transsd=0.1, scalesd=0.1,
64 | rotsd=2, im1_multiplier=0.1, im1_multiplier_aff=0.1)
65 | else:
66 | self.warper = None
67 | data = np.load(path_to_npy)
68 | train_size = int(0.9 * data.shape[0])
69 | valid_size = data.shape[0] - train_size
70 | if train:
71 | print(f'loaded data with shape: {data.shape}, train_size: {train_size}, valid_size: {valid_size}')
72 | self.data = data[:train_size]
73 | else:
74 | self.data = data[train_size:]
75 | self.image_size = image_size
76 | if transform is None:
77 | self.input_transform = transforms.Compose([
78 | transforms.ToPILImage(),
79 | transforms.Resize(image_size),
80 | transforms.ToTensor()
81 | ])
82 | else:
83 | self.input_transform = transform
84 |
85 | def __getitem__(self, index):
86 | if self.mode == 'single':
87 | return self.input_transform(self.data[index])
88 | elif self.mode == 'frames':
89 | if index == 0:
90 | im1 = self.input_transform(self.data[index + 1])
91 | im2 = self.input_transform(self.data[index])
92 | else:
93 | im1 = self.input_transform(self.data[index])
94 | im2 = self.input_transform(self.data[index - 1])
95 | return im1, im2
96 | elif self.mode == 'horizon':
97 | images = []
98 | length = self.data.shape[0]
99 | if (index + self.horizon) >= length:
100 | slack = index + self.horizon - length
101 | index = index - slack
102 | for i in range(self.horizon):
103 | t = index + i
104 | images.append(self.input_transform(self.data[t]))
105 | images = torch.stack(images, dim=0)
106 | return images
107 | elif self.mode == 'tps':
108 | im = self.input_transform(self.data[index])
109 | im = im * 255
110 | im2, im1, _, _, _, _ = self.warper(im)
111 | return im1 / 255, im2 / 255
112 | else:
113 | raise NotImplementedError
114 |
115 | def __len__(self):
116 | return self.data.shape[0]
117 |
118 |
119 | if __name__ == '__main__':
120 | # prepare data
121 | path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/'
122 | frameskip = 3
123 | image_size = 128
124 | # prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1)
125 | test_epochs = True
126 | # load data
127 | path_to_npy = '/media/newhd/data/traffic_data/img128np_fs3.npy'
128 | mode = 'horizon'
129 | horizon = 4
130 | train = True
131 | traffic_ds = TrafficDataset(path_to_npy, mode=mode, train=train, horizon=horizon)
132 | traffic_dl = DataLoader(traffic_ds, shuffle=True, pin_memory=True, batch_size=5)
133 | batch = next(iter(traffic_dl))
134 | if mode == 'single':
135 | im1 = batch[0]
136 | elif mode == 'frames' or mode == 'tps':
137 | im1 = batch[0][0]
138 | im2 = batch[1][0]
139 |
140 | if mode == 'single':
141 | print(im1.shape)
142 | img_np = im1.permute(1, 2, 0).data.cpu().numpy()
143 | fig = plt.figure(figsize=(5, 5))
144 | ax = fig.add_subplot(111)
145 | ax.imshow(img_np)
146 | elif mode == 'horizon':
147 | print(f'batch shape: {batch.shape}')
148 | images = batch[0]
149 | print(f'images shape: {images.shape}')
150 | fig = plt.figure(figsize=(8, 8))
151 | for i in range(images.shape[0]):
152 | ax = fig.add_subplot(1, horizon, i + 1)
153 | im = images[i]
154 | im_np = im.permute(1, 2, 0).data.cpu().numpy()
155 | ax.imshow(im_np)
156 | ax.set_title(f'im {i + 1}')
157 | else:
158 | print(f'im1: {im1.shape}, im2: {im2.shape}')
159 | im1_np = im1.permute(1, 2, 0).data.cpu().numpy()
160 | im2_np = im2.permute(1, 2, 0).data.cpu().numpy()
161 | fig = plt.figure(figsize=(8, 8))
162 | ax = fig.add_subplot(1, 2, 1)
163 | ax.imshow(im1_np)
164 | ax.set_title('im1')
165 |
166 | ax = fig.add_subplot(1, 2, 2)
167 | ax.imshow(im2_np)
168 | ax.set_title('im2 [t-1] or [tps]')
169 | plt.show()
170 | if test_epochs:
171 | from tqdm import tqdm
172 | pbar = tqdm(iterable=traffic_dl)
173 | for batch in pbar:
174 | pass
175 | pbar.close()
176 |
--------------------------------------------------------------------------------
/environment17.yml:
--------------------------------------------------------------------------------
1 | name: torch17
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=10.2.89
8 | - ffmpeg=4.0
9 | - hdf5=1.10.2
10 | - imageio=2.6.1
11 | - matplotlib=3.4.2
12 | - matplotlib-base=3.4.2
13 | - opencv=3.4.2
14 | - pillow=8.3.1
15 | - pip=21.1.3
16 | - py-opencv=3.4.2
17 | - python=3.7.10
18 | - python-dateutil=2.8.2
19 | - pytorch=1.7.1
20 | - scikit-image=0.18.1
21 | - torchaudio=0.7.2
22 | - torchvision=0.8.2
23 | - pip:
24 | - absl-py==0.15.0
25 | - accelerate==0.3.0
26 | - astunparse==1.6.3
27 | - cachetools==4.2.4
28 | - charset-normalizer==2.0.3
29 | - cloudpickle==2.0.0
30 | - dm-tree==0.1.6
31 | - flatbuffers==1.12
32 | - gast==0.3.3
33 | - google-auth==2.3.3
34 | - google-auth-oauthlib==0.4.6
35 | - google-pasta==0.2.0
36 | - googledrivedownloader==0.4
37 | - grpcio==1.32.0
38 | - h5py==2.10.0
39 | - idna==3.2
40 | - importlib-metadata==4.10.0
41 | - isodate==0.6.0
42 | - jinja2==3.0.1
43 | - joblib==1.0.1
44 | - markdown==3.3.6
45 | - markupsafe==2.0.1
46 | - networkx==2.6.2
47 | - numpy==1.19.5
48 | - oauthlib==3.1.1
49 | - opt-einsum==3.3.0
50 | - packaging==21.0
51 | - pandas==1.3.1
52 | - protobuf==3.19.3
53 | - pyaml==20.4.0
54 | - pyasn1==0.4.8
55 | - pyasn1-modules==0.2.8
56 | - python-louvain==0.15
57 | - pytz==2021.1
58 | - pyyaml==5.4.1
59 | - rdflib==6.0.0
60 | - requests-oauthlib==1.3.0
61 | - rsa==4.8
62 | - scikit-learn==0.24.2
63 | - scipy==1.7.0
64 | - six==1.15.0
65 | - termcolor==1.1.0
66 | - threadpoolctl==2.2.0
67 | - torch-cluster==1.5.9
68 | - torch-geometric==1.7.2
69 | - torch-scatter==2.0.7
70 | - torch-sparse==0.6.9
71 | - torch-spline-conv==1.2.1
72 | - tqdm==4.61.2
73 | - typing-extensions==3.7.4.3
74 | - urllib3==1.26.6
75 | - werkzeug==2.0.2
76 | - wrapt==1.12.1
77 | - zipp==3.7.0
78 |
--------------------------------------------------------------------------------
/environment19.yml:
--------------------------------------------------------------------------------
1 | name: torch19
2 | channels:
3 | - pyg
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - cudatoolkit=11.1.74
10 | - ffmpeg=4.3
11 | - h5py=2.10.0
12 | - hdf5=1.10.5
13 | - imageio=2.6.1
14 | - ipykernel=6.4.1
15 | - ipython=7.27.0
16 | - ipython_genutils=0.2.0
17 | - ipywidgets=7.6.4
18 | - jupyter=1.0.0
19 | - jupyter_client=7.0.2
20 | - jupyter_console=6.4.0
21 | - jupyter_core=4.7.1
22 | - jupyterlab_pygments=0.1.2
23 | - jupyterlab_widgets=1.0.1
24 | - matplotlib=3.4.2
25 | - matplotlib-base=3.4.2
26 | - matplotlib-inline=0.1.3
27 | - notebook=6.4.3
28 | - numpy=1.20.3
29 | - numpy-base=1.20.3
30 | - pandas=1.3.0
31 | - pillow=8.3.1
32 | - pip=21.0.1
33 | - python=3.7.11
34 | - pytorch=1.9.0
35 | - pytorch-cluster=1.5.9
36 | - pytorch-scatter=2.0.8
37 | - pytorch-sparse=0.6.12
38 | - pytorch-spline-conv=1.2.1
39 | - pyyaml=5.4.1
40 | - scikit-image=0.18.1
41 | - scikit-learn=0.24.2
42 | - scipy=1.6.3
43 | - torchaudio=0.9.0
44 | - torchfile=0.1.0
45 | - torchvision=0.10.0
46 | - tqdm=4.62.2
47 | - yacs=0.1.6
48 | - yaml=0.2.5
49 | - pip:
50 | - accelerate==0.5.1
51 | - opencv-python==3.4.2.17
52 |
--------------------------------------------------------------------------------
/eval/eval_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation of the ELBO on the validation set
3 | """
4 | # imports
5 | import numpy as np
6 | # torch
7 | import torch
8 | import torch.nn.functional as F
9 | from utils.loss_functions import ChamferLossKL, calc_kl, calc_reconstruction_loss, VGGDistance
10 | from torch.utils.data import DataLoader
11 | import torchvision.utils as vutils
12 | # datasets
13 | from datasets.traffic_ds import TrafficDataset
14 | from datasets.clevrer_ds import CLEVRERDataset
15 | # util functions
16 | from utils.util_func import plot_keypoints_on_image_batch
17 |
18 |
19 | def evaluate_validation_elbo(model, ds, epoch, batch_size=100, recon_loss_type="vgg", device=torch.device('cpu'),
20 | save_image=False, fig_dir='./', topk=5, recon_loss_func=None, beta_rec=1.0, beta_kl=1.0,
21 | kl_balance=1.0, accelerator=None):
22 | model.eval()
23 | kp_range = model.kp_range
24 | # load data
25 | if ds == "traffic":
26 | image_size = 128
27 | root = '/mnt/data/tal/traffic_dataset/img128np_fs3.npy'
28 | mode = 'single'
29 | dataset = TrafficDataset(path_to_npy=root, image_size=image_size, mode=mode, train=False)
30 | elif ds == 'clevrer':
31 | image_size = 128
32 | root = '/mnt/data/tal/clevrer/clevrer_img128np_fs3_valid.npy'
33 | # root = '/media/newhd/data/clevrer/valid/clevrer_img128np_fs3_valid.npy'
34 | mode = 'single'
35 | dataset = CLEVRERDataset(path_to_npy=root, image_size=image_size, mode=mode, train=False)
36 | else:
37 | raise NotImplementedError
38 |
39 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=2, drop_last=False)
40 | kl_loss_func = ChamferLossKL(use_reverse_kl=False)
41 | if recon_loss_func is None:
42 | if recon_loss_type == "vgg":
43 | recon_loss_func = VGGDistance(device=device)
44 | else:
45 | recon_loss_func = calc_reconstruction_loss
46 |
47 | elbos = []
48 | for batch in dataloader:
49 | if ds == 'traffic':
50 | if mode == 'single':
51 | x = batch.to(device)
52 | x_prior = x
53 | else:
54 | x = batch[0].to(device)
55 | x_prior = batch[1].to(device)
56 | elif ds == 'clevrer':
57 | if mode == 'single':
58 | x = batch.to(device)
59 | x_prior = x
60 | else:
61 | x = batch[0].to(device)
62 | x_prior = batch[1].to(device)
63 | else:
64 | x = batch
65 | x_prior = x
66 | batch_size = x.shape[0]
67 | # forward pass
68 | with torch.no_grad():
69 | model_output = model(x, x_prior=x_prior)
70 | mu_p = model_output['kp_p']
71 | gmap = model_output['gmap']
72 | mu = model_output['mu']
73 | logvar = model_output['logvar']
74 | rec_x = model_output['rec']
75 | mu_features = model_output['mu_features']
76 | logvar_features = model_output['logvar_features']
77 | # object stuff
78 | dec_objects_original = model_output['dec_objects_original']
79 | cropped_objects_original = model_output['cropped_objects_original']
80 |
81 | # reconstruction error
82 | if recon_loss_type == "vgg":
83 | loss_rec = recon_loss_func(x, rec_x, reduction="mean")
84 | else:
85 | loss_rec = calc_reconstruction_loss(x, rec_x, loss_type='mse', reduction='mean')
86 |
87 | # kl-divergence
88 | logvar_p = torch.log(torch.tensor(model.sigma ** 2)).to(mu.device) # logvar of the constant std -> for the kl
89 | logvar_kp = logvar_p.expand_as(mu_p)
90 |
91 | mu_post = mu
92 | logvar_post = logvar
93 | mu_prior = mu_p
94 | logvar_prior = logvar_kp
95 |
96 | loss_kl_kp = kl_loss_func(mu_preds=mu_post, logvar_preds=logvar_post, mu_gts=mu_prior,
97 | logvar_gts=logvar_prior).mean()
98 | if model.learned_feature_dim > 0:
99 | loss_kl_feat = calc_kl(logvar_features.view(-1, logvar_features.shape[-1]),
100 | mu_features.view(-1, mu_features.shape[-1]), reduce='none')
101 | loss_kl_feat = loss_kl_feat.view(batch_size, model.n_kp_enc + 1).sum(1).mean()
102 | else:
103 | loss_kl_feat = torch.tensor(0.0, device=mu.device)
104 | loss_kl = loss_kl_kp + kl_balance * loss_kl_feat
105 | elbo = beta_rec * loss_rec + beta_kl * loss_kl
106 | elbos.append(elbo.data.cpu().numpy())
107 | if save_image:
108 | max_imgs = 8
109 | img_with_kp = plot_keypoints_on_image_batch(mu.clamp(min=model.kp_range[0], max=model.kp_range[1]), x, radius=3,
110 | thickness=1, max_imgs=max_imgs, kp_range=model.kp_range)
111 | img_with_kp_p = plot_keypoints_on_image_batch(mu_p, x_prior, radius=3, thickness=1, max_imgs=max_imgs,
112 | kp_range=model.kp_range)
113 | # top-k
114 | with torch.no_grad():
115 | logvar_sum = logvar.sum(-1)
116 | logvar_topk = torch.topk(logvar_sum, k=topk, dim=-1, largest=False)
117 | indices = logvar_topk[1] # [batch_size, topk]
118 | batch_indices = torch.arange(mu.shape[0]).view(-1, 1).to(mu.device)
119 | topk_kp = mu[batch_indices, indices]
120 | img_with_kp_topk = plot_keypoints_on_image_batch(topk_kp.clamp(min=kp_range[0], max=kp_range[1]), x,
121 | radius=3, thickness=1, max_imgs=max_imgs,
122 | kp_range=kp_range)
123 | if model.use_object_dec and dec_objects_original is not None:
124 | dec_objects = model_output['dec_objects']
125 | if accelerator is not None:
126 | if accelerator.is_main_process:
127 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device),
128 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device),
129 | img_with_kp_topk[:max_imgs, -3:].to(mu.device),
130 | dec_objects[:max_imgs, -3:]],
131 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch),
132 | nrow=8, pad_value=1)
133 | else:
134 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device),
135 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device),
136 | img_with_kp_topk[:max_imgs, -3:].to(mu.device),
137 | dec_objects[:max_imgs, -3:]],
138 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch),
139 | nrow=8, pad_value=1)
140 | with torch.no_grad():
141 | _, dec_objects_rgb = torch.split(dec_objects_original, [1, 3], dim=2)
142 | dec_objects_rgb = dec_objects_rgb.reshape(-1, *dec_objects_rgb.shape[2:])
143 | cropped_objects_original = cropped_objects_original.clone().reshape(-1, 3,
144 | cropped_objects_original.shape[
145 | -1],
146 | cropped_objects_original.shape[
147 | -1])
148 | if cropped_objects_original.shape[-1] != dec_objects_rgb.shape[-1]:
149 | cropped_objects_original = F.interpolate(cropped_objects_original,
150 | size=dec_objects_rgb.shape[-1],
151 | align_corners=False, mode='bilinear')
152 | if accelerator is not None:
153 | if accelerator.is_main_process:
154 | vutils.save_image(
155 | torch.cat([cropped_objects_original[:max_imgs * 2, -3:], dec_objects_rgb[:max_imgs * 2, -3:]],
156 | dim=0).data.cpu(), '{}/image_obj_valid_{}.jpg'.format(fig_dir, epoch),
157 | nrow=8, pad_value=1)
158 | else:
159 | vutils.save_image(
160 | torch.cat([cropped_objects_original[:max_imgs * 2, -3:], dec_objects_rgb[:max_imgs * 2, -3:]],
161 | dim=0).data.cpu(), '{}/image_obj_valid_{}.jpg'.format(fig_dir, epoch),
162 | nrow=8, pad_value=1)
163 | else:
164 | if accelerator is not None:
165 | if accelerator.is_main_process:
166 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device),
167 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device),
168 | img_with_kp_topk[:max_imgs, -3:].to(mu.device)],
169 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch),
170 | nrow=8, pad_value=1)
171 | else:
172 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device),
173 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device),
174 | img_with_kp_topk[:max_imgs, -3:].to(mu.device)],
175 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch),
176 | nrow=8, pad_value=1)
177 | return np.mean(elbos)
178 |
--------------------------------------------------------------------------------
/eval_celeb.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate supervised linear regression of facial landmarks on CelebA-HQ 128x128 (MAFL)
3 | """
4 | # imports
5 | import torch
6 | from datasets.celeba_dataset import evaluate_lin_reg_on_mafl_topk
7 | from models import KeyPointVAE
8 |
9 | if __name__ == '__main__':
10 | image_size = 128
11 | imwidth = 160
12 | crop = 16
13 | ch = 3
14 | enc_channels = [32, 64, 128, 256]
15 | prior_channels = (16, 32, 64)
16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17 | use_logsoftmax = False
18 | pad_mode = 'replicate'
19 | sigma = 0.1 # default sigma for the gaussian maps
20 | n_kp = 1 # num kp per patch
21 | n_kp_enc = 30 # total kp to output from the encoder / filter from prior
22 | n_kp_prior = 50
23 | mask_threshold = 0.2 # mask threshold for the features from the encoder
24 |     patch_size = 8  # prior KP proposal patch size (8 for celeb, 16 for traffic/clevrer)
25 |     learned_feature_dim = 10  # additional learned features (beyond x, y) for each kp
26 | kp_range = (-1, 1)
27 | dec_bone = "gauss_pointnetpp_feat"
28 | topk = 10
29 | kp_activation = "tanh"
30 | use_object_enc = True # separate object encoder
31 | use_object_dec = False # separate object decoder
32 | anchor_s = 0.125
33 | learn_order = False
34 | kl_balance = 0.001
35 | dropout = 0.0
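    # the paths below are machine-specific examples -- point 'root' to your local CelebA/MAFL data directory
    # and 'path_to_model_ckpt' to the downloaded pre-trained checkpoint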
36 | root = '/mnt/data/tal/celeba'
37 | path_to_model_ckpt = './checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth'
38 |
39 | model = KeyPointVAE(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels,
40 | image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim,
41 | use_logsoftmax=use_logsoftmax, pad_mode=pad_mode, sigma=sigma,
42 | dropout=dropout, dec_bone=dec_bone, patch_size=patch_size, n_kp_enc=n_kp_enc,
43 | n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation,
44 | mask_threshold=mask_threshold, use_object_enc=use_object_enc,
45 | use_object_dec=use_object_dec, anchor_s=anchor_s, learn_order=learn_order).to(device)
46 | model.load_state_dict(
47 | torch.load(path_to_model_ckpt, map_location=device))
48 | print("loaded model from checkpoint")
49 | print('evaluating linear regression')
50 | result = evaluate_lin_reg_on_mafl_topk(model, root=root, batch_size=100, device=device, topk=topk)
51 | print(result)
52 | print(f"all kp (mu) err: {result['mu_err'] * 100}%")
53 | print(f"top-10 confident kp err: {result['mu_confident_err'] * 100}%")
54 | print(f"top-10 uncertain kp err: {result['mu_uncertainty_err'] * 100}%")
55 | print(f"all kp with logvar err: {result['logvar_err'] * 100}%")
56 | print(f"all kp with logvar and features err: {result['feat_err'] * 100}%")
57 |
--------------------------------------------------------------------------------
/interactive_demo_dlp.py:
--------------------------------------------------------------------------------
1 | """
2 | Interactive demo to observe and explore the learned particles.
3 | """
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms as transforms
7 | import torchvision.transforms.functional as TF
8 |
9 | from models import KeyPointVAE
10 | from utils.util_func import reparameterize
11 |
12 | import matplotlib
13 | from matplotlib.widgets import Slider, Button
14 | import matplotlib as mpl
15 | from matplotlib import pyplot as plt
16 | import numpy as np
17 | import argparse
18 |
19 | matplotlib.use('Qt5Agg')
20 |
21 |
22 | def update_from_slider(val):
23 | for i in np.arange(N - 1):
24 | yvals[i] = sliders_y[i].val
25 | # xvals[i] = sliders_x[i].val
26 | if learned_feature_dim > 0:
27 | feature_1_vals[i] = sliders_features[i].val
28 | update(val)
29 |
30 |
31 | def update(val):
32 | global yvals
33 | global xvals
34 | global dec_bone
35 | global learned_feature_dim
36 | if learned_feature_dim > 0:
37 | global feature_1_vals
38 | # update curve
39 | for i in np.arange(N - 1):
40 | if learned_feature_dim > 0:
41 | # print(f'{i}: {feature_1_vals[i]}')
42 | feature_1_vals[i] = sliders_features[i].val
43 | # print(f'{i}: {feature_1_vals[i]}')
44 | l.set_offsets(np.c_[xvals, yvals])
45 | # convert to tensors
46 | new_mu = torch.from_numpy(np.stack([yvals, xvals], axis=-1)).unsqueeze(0).to(device) / (image_size - 1) # [0, 1]
47 | new_mu = new_mu * (kp_range[1] - kp_range[0]) + kp_range[0] # [kp_range[0], kp_range[1]]
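    # pixel -> normalized keypoint coordinates, e.g. for image_size=128 and kp_range=(-1, 1):
    # pixel 63.5 -> 63.5 / 127 = 0.5 -> 0.5 * 2 - 1 = 0.0 (the image center)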
48 | new_mu = torch.cat([new_mu, original_mu[:, -1].unsqueeze(1)], dim=1)
49 | delta_mu = new_mu - original_mu
50 | # print(f'delta_mu: {delta_mu}')
51 | if learned_feature_dim > 0:
52 | new_features = torch.from_numpy(feature_1_vals[None, :, None]).to(device)
53 | new_features = torch.cat([mu_features[:, :, :-1], new_features], dim=-1)
54 | else:
55 | new_features = None
56 | with torch.no_grad():
57 | rec_new, _, _ = model.decode_all(new_mu, new_features, kp_heatmap, obj_on, deterministic=deterministic,
58 | order_weights=order_weights)
59 | rec_new = rec_new.clamp(0, 1)
60 |
61 | image_rec_new = rec_new[0].permute(1, 2, 0).data.cpu().numpy()
62 | m.set_data(image_rec_new)
63 | # redraw canvas while idle
64 | fig.canvas.draw_idle()
65 |
66 |
67 | def reset(event):
68 | global yvals
69 | global xvals
70 | global learned_feature_dim
71 | if learned_feature_dim > 0:
72 | global feature_1_vals
73 | # reset the values
74 | xvals = mu[0, :-1, 1].data.cpu().numpy() * (image_size - 1)
75 | yvals = mu[0, :-1, 0].data.cpu().numpy() * (image_size - 1)
76 | if learned_feature_dim > 0:
77 | # a slider for the last feature dimension
78 | # feature_1_vals = mu_features[0, :, 0].data.cpu().numpy()
79 | feature_1_vals = mu_features[0, :, -1].data.cpu().numpy()
80 | for i in np.arange(N - 1):
81 | sliders_y[i].reset()
82 | if learned_feature_dim > 0:
83 | sliders_features[i].reset()
84 | l.set_offsets(np.c_[xvals, yvals])
85 | m.set_data(image_rec)
86 | # redraw canvas while idle
87 | fig.canvas.draw_idle()
88 |
89 |
90 | def button_press_callback(event):
91 | 'whenever a mouse button is pressed'
92 | global pind
93 | if event.inaxes is None:
94 | return
95 | if event.button != 1:
96 | return
97 | pind = get_ind_under_point(event)
98 |
99 |
100 | def button_release_callback(event):
101 | 'whenever a mouse button is released'
102 | global pind
103 | if event.button != 1:
104 | return
105 | pind = None
106 |
107 |
108 | def get_ind_under_point(event):
109 | 'get the index of the vertex under point if within epsilon tolerance'
110 | t = ax1.transData.inverted()
111 | tinv = ax1.transData
112 | xy = t.transform([event.x, event.y])
113 | xr = np.reshape(xvals, (np.shape(xvals)[0], 1))
114 | yr = np.reshape(yvals, (np.shape(yvals)[0], 1))
115 | xy_vals = np.append(xr, yr, 1)
116 | xyt = tinv.transform(xy_vals)
117 | xt, yt = xyt[:, 0], xyt[:, 1]
118 | d = np.hypot(xt - event.x, yt - event.y)
119 | indseq, = np.nonzero(d == d.min())
120 | ind = indseq[0]
121 |
122 | if d[ind] >= epsilon:
123 | ind = None
124 | return ind
125 |
126 |
127 | def motion_notify_callback(event):
128 | 'on mouse movement'
129 | global xvals
130 | global yvals
131 | if pind is None:
132 | return
133 | if event.inaxes is None:
134 | return
135 | if event.button != 1:
136 | return
137 |
138 | # update yvals
139 | # print('motion x: {0}; y: {1}'.format(event.xdata,event.ydata))
140 | # print(f'delta: x: {event.xdata - xvals[pind]}. y: {event.ydata - yvals[pind]}')
141 | delta_x = event.xdata - xvals[pind]
142 | delta_y = event.ydata - yvals[pind]
143 | yvals[pind] = event.ydata
144 | xvals[pind] = event.xdata
145 |
146 | # yvals[pind + 1] = yvals[pind + 1] + delta_y
147 | # xvals[pind + 1] = xvals[pind + 1] + delta_x
148 |
149 | update(None)
150 |
151 | # update curve via sliders and draw
152 | sliders_y[pind].set_val(yvals[pind])
153 | # sliders_y[pind + 1].set_val(yvals[pind + 1])
154 |
155 | # sliders_x[pind].set_val(xvals[pind])
156 | fig.canvas.draw_idle()
157 |
158 |
159 | def bn_eval(model):
160 | """
161 | https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
162 |     for batch_size = 1, don't use the running stats (use the current batch statistics instead)
163 | """
164 | for m in model.modules():
165 | for child in m.children():
166 | if type(child) == torch.nn.BatchNorm2d or type(child) == torch.nn.BatchNorm1d:
167 | child.track_running_stats = False
168 | child.running_mean = None
169 | child.running_var = None
170 |
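# usage sketch: bn_eval(model) can be called after loading the weights so that BatchNorm layers use the
# current batch statistics when feeding single images (see the linked discussion above).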
171 |
172 | if __name__ == '__main__':
173 | parser = argparse.ArgumentParser(description="DLP Interactive Demo")
174 | parser.add_argument("-d", "--dataset", type=str, default='celeba',
175 | help="dataset of pretrained model: ['celeba', 'traffic', 'clevrer']")
176 | parser.add_argument("-i", "--index", type=int,
177 | help="index of image in ./checkpoints/sample_images/dataset/", default=0)
178 | args = parser.parse_args()
179 | # hyper-parameters for model
180 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
181 | use_logsoftmax = False
182 | pad_mode = 'replicate'
183 | sigma = 0.1 # default sigma for the gaussian maps
184 | dropout = 0.0
185 | n_kp = 1 # num kp per patch
186 | kp_range = (-1, 1)
187 | kp_activation = "tanh"
188 | mask_threshold = 0.2
189 | learn_order = False
190 |
191 | ds = args.dataset
192 |
193 | image_idx = args.index
194 | image_idx = max(0, image_idx)
195 |
196 | if ds == 'celeba':
197 | path_to_model_ckpt = './checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth'
198 | image_size = 128
199 | ch = 3
200 | enc_channels = [32, 64, 128, 256]
201 | prior_channels = (16, 32, 64)
202 | imwidth = 160
203 | crop = 16
204 | n_kp_enc = 30 # total kp to output from the encoder / filter from prior
205 | n_kp_prior = 50 # total kp to filter from prior
206 | use_object_enc = True
207 | use_object_dec = False
208 | learned_feature_dim = 10
209 | patch_size = 8
210 | anchor_s = 0.125
211 | dec_bone = "gauss_pointnetpp_feat"
212 | exclusive_patches = False
213 | elif ds == 'traffic':
214 | path_to_model_ckpt = './checkpoints/dlp_traffic_gauss_pointnetpp.pth'
215 | image_size = 128
216 | ch = 3
217 | enc_channels = [32, 64, 128, 256]
218 | prior_channels = (16, 32, 64)
219 | imwidth = 160
220 | crop = 16
221 | n_kp_enc = 15 # total kp to output from the encoder / filter from prior
222 | n_kp_prior = 20 # total kp to filter from prior
223 | use_object_enc = True
224 | use_object_dec = True
225 | learned_feature_dim = 20
226 | patch_size = 16
227 | anchor_s = 0.25
228 | dec_bone = "gauss_pointnetpp"
229 | exclusive_patches = False
230 | elif ds == 'clevrer':
231 | path_to_model_ckpt = './checkpoints/dlp_clevrer_gauss_pointnetpp_orig.pth'
232 | image_size = 128
233 | ch = 3
234 | enc_channels = [32, 64, 128, 256]
235 | prior_channels = (16, 32, 64)
236 | imwidth = 160
237 | crop = 16
238 | n_kp_enc = 10 # total kp to output from the encoder / filter from prior
239 | n_kp_prior = 20 # total kp to filter from prior
240 | use_object_enc = True
241 | use_object_dec = True
242 | learned_feature_dim = 5
243 | # learned_feature_dim = 8
244 | patch_size = 16
245 | anchor_s = 0.25
246 | dec_bone = "gauss_pointnetpp"
247 | exclusive_patches = False
248 | else:
249 | raise NotImplementedError
250 |
251 | model = KeyPointVAE(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels,
252 | image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim,
253 | use_logsoftmax=use_logsoftmax, pad_mode=pad_mode, sigma=sigma,
254 | dropout=dropout, dec_bone=dec_bone, patch_size=patch_size, n_kp_enc=n_kp_enc,
255 | n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation,
256 | mask_threshold=mask_threshold, use_object_enc=use_object_enc,
257 | exclusive_patches=exclusive_patches, use_object_dec=use_object_dec, anchor_s=anchor_s,
258 | learn_order=learn_order).to(device)
259 | model.load_state_dict(torch.load(path_to_model_ckpt, map_location=device), strict=False)
260 | model.eval()
261 | print("loaded model from checkpoint")
262 | logvar_threshold = 0.0 # threshold to filter particles
263 | if ds == 'celeba':
264 | # load image
265 | path_to_images = ['./checkpoints/sample_images/celeb/1.jpg', './checkpoints/sample_images/celeb/2.jpg',
266 | './checkpoints/sample_images/celeb/3.jpg', './checkpoints/sample_images/celeb/4.jpg',
267 | './checkpoints/sample_images/celeb/5.jpg', './checkpoints/sample_images/celeb/6.jpg',
268 | './checkpoints/sample_images/celeb/7.jpg', './checkpoints/sample_images/celeb/8.jpg']
269 | image_idx = min(image_idx, len(path_to_images) - 1)
270 | path_to_image = path_to_images[image_idx]
271 | im = Image.open(path_to_image)
272 | # move head up a bit
273 | vertical_shift = 30
274 |         initial_crop = lambda im: transforms.functional.crop(im, vertical_shift, 0, 178, 178)
275 | initial_transforms = transforms.Compose([initial_crop, transforms.Resize(imwidth)])
276 | trans = transforms.ToTensor()
277 | data = trans(initial_transforms(im.convert("RGB")))
278 | if crop != 0:
279 | data = data[:, crop:-crop, crop:-crop]
280 | data = data.unsqueeze(0).to(device)
281 | elif ds == 'traffic':
282 | path_to_images = ['./checkpoints/sample_images/traffic/1.png', ]
283 | image_idx = min(image_idx, len(path_to_images) - 1)
284 | path_to_image = path_to_images[image_idx]
285 | im = Image.open(path_to_image)
286 | im = im.convert('RGB')
287 | im = im.crop((60, 0, 480, 420))
288 | im = im.resize((image_size, image_size), Image.BICUBIC)
289 | trans = transforms.ToTensor()
290 | data = trans(im)
291 | data = data.unsqueeze(0).to(device)
292 | x = data
293 | logvar_threshold = 14.0 # threshold to filter particles
294 | elif ds == 'clevrer':
295 | path_to_images = ['./checkpoints/sample_images/clevrer/1.png',
296 | './checkpoints/sample_images/clevrer/2.png']
297 | image_idx = min(image_idx, len(path_to_images) - 1)
298 | path_to_image = path_to_images[image_idx]
299 | im = Image.open(path_to_image)
300 | im = im.convert('RGB')
301 | im = im.resize((image_size, image_size), Image.BICUBIC)
302 | trans = transforms.ToTensor()
303 | data = trans(im)
304 | data = data.unsqueeze(0).to(device)
305 | x = data
306 | logvar_threshold = 13.0 # threshold to filter particles
307 | else:
308 | raise NotImplementedError
309 |
310 | with torch.no_grad():
311 | deterministic = True
312 | enc_out = model.encode_all(data, return_heatmap=True, deterministic=deterministic)
313 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = enc_out
314 | if deterministic:
315 | z = mu
316 | z_features = mu_features
317 | else:
318 | z = reparameterize(mu, logvar)
319 | z_features = reparameterize(mu_features, logvar_features)
320 |
321 | # top-k
322 | logvar_sum = logvar.sum(-1)
323 | logvar_topk = torch.topk(logvar_sum, k=5, dim=-1, largest=False)
324 | indices = logvar_topk[1] # [batch_size, topk]
325 | batch_indices = torch.arange(mu.shape[0]).view(-1, 1).to(mu.device)
326 | topk_kp = mu[batch_indices, indices]
327 | print(f'logvar: {logvar_sum[0].data.cpu()}')
328 |
329 | if learn_order:
330 | order_of_kp = [torch.argmax(order_weights[0][i]).item() for i in range(order_weights.shape[-1])]
331 | print(f'order of kp: {order_of_kp}')
332 | if obj_on is not None:
333 | obj_on = torch.where((torch.abs(logvar_sum[:, :-1]) > logvar_threshold), obj_on,
334 | torch.tensor(0.0, dtype=torch.float, device=obj_on.device))
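        # particles whose summed |logvar| does not exceed logvar_threshold are switched off (obj_on = 0),
        # so their decoded glimpses are not composited into the reconstruction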
335 | print(f'obj_on: {obj_on[0].data.cpu()}')
336 |
337 | rec, _, _ = model.decode_all(z, z_features, kp_heatmap, obj_on, deterministic=deterministic,
338 | order_weights=order_weights)
339 | rec = rec.clamp(0, 1)
340 |
341 | N = mu.shape[1]
342 | xmin = 0
343 | xmax = image_size
344 |
345 | x = np.linspace(xmin, xmax, N)
346 |
347 | mu = mu.clamp(kp_range[0], kp_range[1])
348 | original_mu = mu.clone()
349 | mu = (mu - kp_range[0]) / (kp_range[1] - kp_range[0])
350 | xvals = mu[0, :-1, 1].data.cpu().numpy() * (image_size - 1)
351 | yvals = mu[0, :-1, 0].data.cpu().numpy() * (image_size - 1)
352 | if learned_feature_dim > 0:
353 | # feature_1_vals = mu_features[0, :, 0].data.cpu().numpy()
354 | feature_1_vals = mu_features[0, :, -1].data.cpu().numpy()
355 |
356 | # set up a plot for topk
357 | topk_kp = topk_kp.clamp(kp_range[0], kp_range[1])
358 | topk_kp = (topk_kp - kp_range[0]) / (kp_range[1] - kp_range[0])
359 | xvals_topk = topk_kp[0, :, 1].data.cpu().numpy() * (image_size - 1)
360 | yvals_topk = topk_kp[0, :, 0].data.cpu().numpy() * (image_size - 1)
361 |
362 | fig = plt.figure(figsize=(10, 10))
363 | ax1 = fig.add_subplot(111)
364 | image = data[0].permute(1, 2, 0).data.cpu().numpy()
365 | ax1.imshow(image)
366 | ax1.scatter(xvals, yvals, label='original', s=70)
367 | ax1.set_axis_off()
368 | ax1.set_title('all particles')
369 | fig = plt.figure(figsize=(10, 10))
370 | ax2 = fig.add_subplot(111)
371 | ax2.imshow(image)
372 | ax2.scatter(xvals_topk, yvals_topk, label='topk', s=70, color='red')
373 | ax2.set_axis_off()
374 | ax2.set_title('top-5 lowest variance particles')
375 |
376 | # figure.subplot.right
377 | mpl.rcParams['figure.subplot.right'] = 0.8
378 |
379 | # set up a plot
380 | fig, axes = plt.subplots(1, 2, figsize=(15.0, 15.0), sharex=True)
381 | ax1, ax2 = axes
382 |
383 | image = data[0].permute(1, 2, 0).data.cpu().numpy()
384 | ax1.imshow(image)
385 | ax1.set_axis_off()
386 | ax2.set_axis_off()
387 | image_rec = rec[0].permute(1, 2, 0).data.cpu().numpy()
388 | m = ax2.imshow(image_rec)
389 |
390 | pind = None # active point
391 | epsilon = 10 # max pixel distance
392 |
393 | ax1.scatter(xvals, yvals, label='original', s=70)
394 | l = ax1.scatter(xvals, yvals, color='red', marker='*', s=70)
395 |
396 | ax1.set_xlabel('x')
397 | ax1.set_ylabel('y')
398 |
399 | sliders_y = []
400 | sliders_x = []
401 | if learned_feature_dim > 0:
402 | sliders_features = []
403 |
404 | for i in np.arange(N - 1):
405 | slider_width = 0.04 if learned_feature_dim > 0 else 0.12
406 | axamp = plt.axes([0.84, 0.85 - (i * 0.025), slider_width, 0.01])
407 | # Slider y
408 | s_y = Slider(axamp, 'p_y{0}'.format(i), 0, image_size, valinit=yvals[i])
409 | sliders_y.append(s_y)
410 | if learned_feature_dim > 0:
411 | axamp_f = plt.axes([0.93, 0.85 - (i * 0.025), slider_width, 0.01])
412 | s_feat = Slider(axamp_f, f'f_1', -5, 5, valinit=feature_1_vals[i])
413 | sliders_features.append(s_feat)
414 |
415 | for i in np.arange(N - 1):
416 | sliders_y[i].on_changed(update_from_slider)
417 | if learned_feature_dim > 0:
418 | sliders_features[i].on_changed(update_from_slider)
419 |
420 | axres = plt.axes([0.84, 0.85 - ((N) * 0.025), 0.12, 0.01])
421 | bres = Button(axres, 'Reset')
422 | bres.on_clicked(reset)
423 |
424 | fig.canvas.mpl_connect('button_press_event', button_press_callback)
425 | fig.canvas.mpl_connect('button_release_event', button_release_callback)
426 | fig.canvas.mpl_connect('motion_notify_event', motion_notify_callback)
427 |
428 | plt.show()
429 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Main DLP model neural network.
3 | """
4 | # imports
5 | import numpy as np
6 | # torch
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | # modules
11 | from modules.modules import KeyPointCNNOriginal, VariationalKeyPointPatchEncoder, SpatialSoftmaxKP, SpatialLogSoftmaxKP, \
12 | ToGaussianMapHW, CNNDecoder, ObjectEncoder, ObjectDecoderCNN, PointNetPPToCNN
13 | # util functions
14 | from utils.util_func import reparameterize, get_kp_mask_from_gmap, create_masks_fast
15 |
16 |
17 | class KeyPointVAE(nn.Module):
18 | def __init__(self, cdim=3, enc_channels=(16, 16, 32), prior_channels=(16, 16, 32), image_size=64, n_kp=1,
19 | use_logsoftmax=False, pad_mode='replicate', sigma=0.1, dropout=0.0, dec_bone="gauss_pointnetpp",
20 | patch_size=16, n_kp_enc=20, n_kp_prior=20, learned_feature_dim=16,
21 | kp_range=(-1, 1), kp_activation="tanh", mask_threshold=0.2, anchor_s=0.25,
22 | use_object_enc=False, use_object_dec=False, learn_order=False, exclusive_patches=False):
23 | super(KeyPointVAE, self).__init__()
24 | """
25 | cdim: channels of the input image (3...)
26 | enc_channels: channels for the posterior CNN (takes in the whole image)
27 | prior_channels: channels for prior CNN (takes in patches)
28 | n_kp: number of kp to extract from each (!) patch
29 | n_kp_prior: number of kp to filter from the set of prior kp (of size n_kp x num_patches)
30 | n_kp_enc: number of posterior kp to be learned (this is the actual number of kp that will be learnt)
31 | use_logsoftmax: for spatial-softmax, set True to use log-softmax for numerical stability
32 | pad_mode: padding for the CNNs, 'zeros' or 'replicate' (default)
33 | sigma: the prior std of the KP
34 |         dropout: dropout for the CNNs (not used in practice)
35 |         dec_bone: decoder backbone -- "gauss_pointnetpp_feat": Masked Model, "gauss_pointnetpp": Object Model
36 | patch_size: patch size for the prior KP proposals network (not to be confused with the glimpse size)
37 | kp_range: the range of keypoints, can be [-1, 1] (default) or [0,1]
38 |         learned_feature_dim: the dimension of the latent visual features extracted from glimpses.
39 | kp_activation: the type of activation to apply on the keypoints: "tanh" for kp_range [-1, 1], "sigmoid" for [0, 1]
40 | mask_threshold: activation threshold (>thresh -> 1, else 0) for the binary mask created from the Gaussian-maps.
41 | anchor_s: defines the glimpse size as a ratio of image_size (e.g., 0.25 for image_size=128 -> glimpse_size=32)
42 |         learn_order: experimental feature to learn the order of keypoints (not functional yet).
43 | use_object_enc: set True to use a separate encoder to encode visual features of glimpses.
44 | use_object_dec: set True to use a separate decoder to decode glimpses (Object Model).
45 |         exclusive_patches: (mostly) enforce one particle per object by masking out regions that were already encoded.
46 | """
47 | if dec_bone not in ["gauss_pointnetpp", "gauss_pointnetpp_feat"]:
48 | raise SystemError(f'unrecognized decoder backbone: {dec_bone}')
49 | print(f'decoder backbone: {dec_bone}')
50 | self.dec_bone = dec_bone
51 | self.image_size = image_size
52 | self.use_logsoftmax = use_logsoftmax
53 | self.sigma = sigma
54 | print(f'prior std: {self.sigma}')
55 | self.dropout = dropout
56 | self.kp_range = kp_range
57 | print(f'keypoints range: {self.kp_range}')
58 | self.num_patches = int((image_size // patch_size) ** 2)
59 | self.n_kp = n_kp
60 | self.n_kp_total = self.n_kp * self.num_patches
61 | self.n_kp_prior = min(self.n_kp_total, n_kp_prior)
62 | print(f'total number of kp: {self.n_kp_total} -> prior kp: {self.n_kp_prior}')
63 | self.n_kp_enc = n_kp_enc
64 | print(f'number of kp from encoder: {self.n_kp_enc}')
65 | self.kp_activation = kp_activation
66 | print(f'kp_activation: {self.kp_activation}')
67 | self.patch_size = patch_size
68 | self.features_dim = int(image_size // (2 ** (len(enc_channels) - 1)))
69 | self.learned_feature_dim = learned_feature_dim
70 | print(f'learnable feature dim: {learned_feature_dim}')
71 | self.mask_threshold = mask_threshold
72 | print(f'mask threshold: {self.mask_threshold}')
73 | self.anchor_s = anchor_s
74 | self.obj_patch_size = np.round(anchor_s * (image_size - 1)).astype(int)
75 | print(f'object patch size: {self.obj_patch_size}')
76 | self.use_object_enc = True if use_object_dec else use_object_enc
77 | self.use_object_dec = use_object_dec
78 | print(f'object encoder: {self.use_object_enc}, object decoder: {self.use_object_dec}')
79 | self.learn_order = learn_order
80 | print(f'learn particles order: {self.learn_order}')
81 | self.exclusive_patches = exclusive_patches
82 |
83 | # encoder
84 | self.enc = KeyPointCNNOriginal(cdim=cdim, channels=enc_channels, image_size=image_size, n_kp=self.n_kp_enc,
85 | pad_mode=pad_mode, use_resblock=False)
86 | enc_output_dim = 2 * 2
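        # 2 coordinates (x, y) times 2 statistics (mu, logvar) per posterior particle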
87 | # flatten feature maps and extract statistics
88 | self.to_normal_stats = nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256),
89 | nn.ReLU(True),
90 | nn.Linear(256, 128),
91 | nn.ReLU(True),
92 | nn.Linear(128, self.n_kp_enc * enc_output_dim))
93 | if self.use_object_dec:
94 | if self.learn_order:
95 | enc_aux_output_dim = 1 + self.n_kp_enc # obj_on, ordering weights
96 | else:
97 | enc_aux_output_dim = 1 # obj_on
98 | self.aux_enc = nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256),
99 | nn.ReLU(True),
100 | nn.Linear(256, 128),
101 | nn.ReLU(True),
102 | nn.Linear(128, self.n_kp_enc * enc_aux_output_dim))
103 | else:
104 | self.aux_enc = None
105 | # object encoder
106 | object_enc_output_dim = self.learned_feature_dim * 2 # [mu_features, sigma_features]
107 | self.object_enc = nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256),
108 | nn.ReLU(True),
109 | nn.Linear(256, 128),
110 | nn.ReLU(True),
111 | nn.Linear(128, object_enc_output_dim))
112 | if self.use_object_enc:
113 | if self.use_object_dec:
114 | self.object_enc_sep = ObjectEncoder(z_dim=learned_feature_dim, anchor_size=anchor_s,
115 | image_size=image_size, ch=cdim, margin=0, cnn=True)
116 | else:
117 | self.object_enc_sep = ObjectEncoder(z_dim=learned_feature_dim, anchor_size=anchor_s,
118 | image_size=self.features_dim, ch=self.n_kp_enc,
119 | margin=0, cnn=False, encode_location=True)
120 | else:
121 | self.object_enc_sep = None
122 | self.prior = VariationalKeyPointPatchEncoder(cdim=cdim, channels=prior_channels, image_size=image_size,
123 | n_kp=n_kp, kp_range=self.kp_range,
124 | patch_size=patch_size, use_logsoftmax=use_logsoftmax,
125 | pad_mode=pad_mode, sigma=sigma, dropout=dropout,
126 | learnable_logvar=False, learned_feature_dim=0)
127 | self.ssm = SpatialLogSoftmaxKP(kp_range=kp_range) if use_logsoftmax else SpatialSoftmaxKP(kp_range=kp_range)
128 |
129 | # decoder
130 | decoder_n_kp = 3 * self.n_kp_enc if self.dec_bone == "gauss_pointnetpp_feat" else 2 * self.n_kp_enc
131 | self.to_gauss_map = ToGaussianMapHW(sigma_w=sigma, sigma_h=sigma, kp_range=kp_range)
132 | self.pointnet = PointNetPPToCNN(axis_dim=2, target_hw=self.features_dim,
133 | n_kp=self.n_kp_enc, features_dim=self.learned_feature_dim,
134 | pad_mode=pad_mode)
135 | self.dec = CNNDecoder(cdim=cdim, channels=enc_channels, image_size=image_size, in_ch=decoder_n_kp,
136 | n_kp=self.n_kp_enc + 1, pad_mode=pad_mode)
137 | # object decoder
138 | if self.use_object_dec:
139 | self.object_dec = ObjectDecoderCNN(patch_size=(self.obj_patch_size, self.obj_patch_size), num_chans=4,
140 | bottleneck_size=learned_feature_dim)
141 | else:
142 | self.object_dec = None
143 | self.init_weights()
144 |
145 | def get_parameters(self, prior=True, encoder=True, decoder=True):
146 | parameters = []
147 | if prior:
148 | parameters.extend(list(self.prior.parameters()))
149 | if encoder:
150 | parameters.extend(list(self.enc.parameters()))
151 | parameters.extend(list(self.to_normal_stats.parameters()))
152 | parameters.extend(list(self.object_enc.parameters()))
153 | if self.use_object_enc:
154 | parameters.extend(list(self.object_enc_sep.parameters()))
155 | if self.use_object_dec:
156 | parameters.extend(list(self.aux_enc.parameters()))
157 | if decoder:
158 | parameters.extend(list(self.dec.parameters()))
159 | parameters.extend(list(self.pointnet.parameters()))
160 | if self.use_object_dec:
161 | parameters.extend(list(self.object_dec.parameters()))
162 | return parameters
163 |
164 | def set_require_grad(self, prior_value=True, enc_value=True, dec_value=True):
165 | for param in self.prior.parameters():
166 | param.requires_grad = prior_value
167 | for param in self.enc.parameters():
168 | param.requires_grad = enc_value
169 | for param in self.to_normal_stats.parameters():
170 | param.requires_grad = enc_value
171 | for param in self.object_enc.parameters():
172 | param.requires_grad = enc_value
173 | if self.use_object_enc:
174 | for param in self.object_enc_sep.parameters():
175 | param.requires_grad = enc_value
176 | if self.use_object_dec:
177 | for param in self.aux_enc.parameters():
178 | param.requires_grad = enc_value
179 | for param in self.dec.parameters():
180 | param.requires_grad = dec_value
181 | for param in self.pointnet.parameters():
182 | param.requires_grad = dec_value
183 | if self.use_object_dec:
184 | for param in self.object_dec.parameters():
185 | param.requires_grad = dec_value
186 |
187 | def init_weights(self):
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | nn.init.normal_(m.weight, 0, 0.01)
191 | if m.bias is not None:
192 | nn.init.constant_(m.bias, 0)
193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
194 | nn.init.constant_(m.weight, 1)
195 | nn.init.constant_(m.bias, 0)
196 | elif isinstance(m, nn.Linear):
197 | # use pytorch's default
198 | pass
199 |
200 | def encode(self, x, return_heatmap=False, mask=None):
201 | _, z_kp = self.enc(x) # [batch_size, n_kp, features_dim, features_dim]
202 | if mask is None:
203 | masked_hm = z_kp
204 | else:
205 | masked_hm = mask * z_kp
206 | z_kp_v = masked_hm.view(masked_hm.shape[0], -1) # [batch_size, n_kp * features_dim * features_dim]
207 | stats = self.to_normal_stats(z_kp_v) # [batch_size, n_kp * 4]
208 | stats = stats.view(stats.shape[0], self.n_kp_enc, 2 * 2)
209 |         # [batch_size, n_kp, 4] -> chunk into mu and logvar
210 |         mu_enc, logvar_enc = torch.chunk(stats, chunks=2, dim=-1)  # [batch_size, n_kp, 2] each
211 | mu, logvar = mu_enc[:, :, :2], logvar_enc[:, :, :2] # [x, y]
212 |
213 | logvar_p = torch.log(torch.tensor(self.sigma ** 2, device=logvar.device))
214 | if self.use_object_dec:
215 | stats_aux = self.aux_enc(z_kp_v.detach())
216 | if self.learn_order:
217 | stats_aux = stats_aux.view(stats_aux.shape[0], self.n_kp_enc, 1 + self.n_kp_enc)
218 | order_weights = stats_aux[:, :, 1:]
219 | else:
220 | stats_aux = stats_aux.view(stats_aux.shape[0], self.n_kp_enc, 1)
221 | order_weights = None
222 | mu_obj_weight = stats_aux[:, :, 0]
223 | mu_obj_weight = torch.sigmoid(mu_obj_weight)
224 | else:
225 | mu_obj_weight = None
226 | order_weights = None
227 | mu_features, logvar_features = None, None
228 | if self.kp_activation == "tanh":
229 | mu = torch.tanh(mu)
230 | elif self.kp_activation == "sigmoid":
231 | mu = torch.sigmoid(mu)
232 |
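        # append a fixed "background" particle: position at the origin, log-variance set to the prior's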
233 | mu = torch.cat([mu, torch.zeros_like(mu[:, 0]).unsqueeze(1)], dim=1)
234 | logvar = torch.cat([logvar, logvar_p * torch.ones_like(logvar[:, 0]).unsqueeze(1)], dim=1)
235 |
236 | if return_heatmap:
237 | return mu, logvar, z_kp, mu_features, logvar_features, mu_obj_weight, order_weights
238 | else:
239 | return mu, logvar, mu_features, logvar_features, mu_obj_weight, order_weights
240 |
241 | def encode_object_features(self, features_map, masks):
242 | # features_map [bs, n_kp, feature_dim, feature_dim]
243 | # masks: [bs, n_kp + 1, feature_dim, feature_dim]
244 | y = masks.unsqueeze(2) * features_map.unsqueeze(1) # [bs, n_kp + 1, n_kp, feature_dim, feature_dim]
245 | y = y.view(y.shape[0], y.shape[1], -1) # [bs, n_kp + 1, n_kp * feature_dim ** 2]
246 | enc_out = self.object_enc(y) # [bs, n_kp + 1, learned_feature_dim * 2]
247 | mu_features, logvar_features = torch.chunk(enc_out, chunks=2, dim=-1) # [bs, n_kp + 1, learned_feature_dim]
248 | return mu_features, logvar_features
249 |
250 | def encode_object_features_sep(self, x, kp, features_map, masks, exclusive_patches=False, obj_on=None):
251 | # x: [bs, ch, image_size, image_size]
252 | # kp :[bs, n_kp, 2]
253 | # features_map: [bs, n_kp, features_dim, features_dim]
254 | # masks: [bs, n_kp, features_dim, features_dim]
255 |
256 | batch_size, n_kp, features_dim, _ = masks.shape
257 |
258 | # object features
259 | obj_enc_out = self.object_enc_sep(x, kp.detach(), exclusive_patches=exclusive_patches, obj_on=obj_on)
260 | mu_obj, logvar_obj, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2]
261 | if len(obj_enc_out) > 3:
262 | cropped_objects_masks = obj_enc_out[3]
263 | else:
264 | cropped_objects_masks = None
265 |
266 |         # bg features
267 | if self.use_object_dec:
268 | obj_fmap_masks = create_masks_fast(kp.detach(), anchor_s=self.anchor_s, feature_dim=self.features_dim)
269 | bg_mask = 1 - obj_fmap_masks.squeeze(2).sum(1, keepdim=True).clamp(0, 1)
270 | # [bs, 1, features_dim, features_dim]
271 | else:
272 | bg_mask = masks[:, -1].unsqueeze(1) # [bs, 1, features_dim, features_dim]
273 | masked_features = bg_mask.unsqueeze(2) * features_map.unsqueeze(1) # [bs, 1, n_kp, f_dim, f_dim]
274 | masked_features = masked_features.view(batch_size, masked_features.shape[1], -1) # flatten
275 | object_enc_out = self.object_enc(masked_features) # [bs, 1, 2 * learned_features_dim]
276 | mu_bg, logvar_bg = object_enc_out.chunk(2, dim=-1)
277 |
278 | mu_features = torch.cat([mu_obj, mu_bg], dim=1)
279 | logvar_features = torch.cat([logvar_obj, logvar_bg], dim=1)
280 |
281 | return mu_features, logvar_features, cropped_objects, cropped_objects_masks
282 |
283 | def encode_all(self, x, return_heatmap=False, mask=None, deterministic=False):
284 | # posterior
285 | enc_out = self.encode(x, return_heatmap=True, mask=mask)
286 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = enc_out
287 | if deterministic:
288 | z = mu
289 | else:
290 | z = reparameterize(mu, logvar)
291 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim)
292 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True,
293 | elementwise=True).detach()
294 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1)
295 | bg_masks = 1 - fg_masks
296 | masks_sep = torch.cat([fg_masks_sep, bg_masks], dim=1)
297 |
298 | if self.learned_feature_dim > 0:
299 | if self.use_object_enc:
300 | feat_source = x if self.use_object_dec else kp_heatmap.detach()
301 | obj_enc_out = self.encode_object_features_sep(feat_source, mu[:, :-1], kp_heatmap.detach(),
302 | masks_sep.detach())
303 | mu_features, logvar_features, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2]
304 | else:
305 | mu_features, logvar_features = self.encode_object_features(kp_heatmap.detach(), masks_sep)
306 | if return_heatmap:
307 | return mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights
308 | else:
309 | return mu, logvar, mu_features, logvar_features, obj_on, order_weights
310 |
311 | def encode_prior(self, x):
312 | return self.prior(x)
313 |
314 | def decode(self, z):
315 | return self.dec(z)
316 |
317 | def get_prior_kp(self, x, probs=False):
318 | _, z = self.encode_prior(x)
319 | return self.ssm(z, probs)
320 |
321 | def translate_patches(self, kp_batch, patches_batch, scale=None, translation=None):
322 | """
323 | translate patches to be centered around given keypoints
324 | kp_batch: [bs, n_kp, 2] in [-1, 1]
325 | patches: [bs, n_kp, ch_patches, patch_size, patch_size]
326 | scale: None or [bs, n_kp, 2] or [bs, n_kp, 1]
327 | translation: None or [bs, n_kp, 2] or [bs, n_kp, 1] (delta from kp)
328 |         :return: translated_padded_patches [bs, n_kp, ch, img_size, img_size]
329 | """
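        # rough intuition for the affine grid below: grid_sample maps output coordinates back to input
        # coordinates, so using a translation of (0 - kp) samples the zero-centered padded patch such that
        # its center ends up at the keypoint location in the output canvas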
330 | batch_size, n_kp, ch_patch, patch_size, _ = patches_batch.shape
331 | img_size = self.image_size
332 | pad_size = (img_size - patch_size) // 2
333 | padded_patches_batch = F.pad(patches_batch, pad=[pad_size] * 4)
334 | delta_t_batch = 0.0 - kp_batch
335 | delta_t_batch = delta_t_batch.reshape(-1, delta_t_batch.shape[-1]) # [bs * n_kp, 2]
336 | padded_patches_batch = padded_patches_batch.reshape(-1, *padded_patches_batch.shape[2:])
337 | # [bs * n_kp, 3, patch_size, patch_size]
338 | zeros = torch.zeros([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
339 | ones = torch.ones([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
340 |
341 | if scale is None:
342 | scale_w = ones
343 | scale_h = ones
344 | elif scale.shape[-1] == 1:
345 | scale_w = scale[:, :-1].reshape(-1, scale.shape[-1]) # no need for bg kp
346 | scale_h = scale[:, :-1].reshape(-1, scale.shape[-1]) # no need for bg kp
347 | else:
348 | scale_h, scale_w = torch.split(scale[:, :-1], [1, 1], dim=-1)
349 | scale_w = scale_w.reshape(-1, scale_w.shape[-1])
350 | scale_h = scale_h.reshape(-1, scale_h.shape[-1])
351 | if translation is None:
352 | trans_w = zeros
353 | trans_h = zeros
354 | elif translation.shape[-1] == 1:
355 | trans_w = translation[:, :-1].reshape(-1, translation.shape[-1]) # no need for bg kp
356 | trans_h = translation[:, :-1].reshape(-1, translation.shape[-1]) # no need for bg kp
357 | else:
358 | trans_h, trans_w = torch.split(translation[:, :-1], [1, 1], dim=-1)
359 | trans_w = trans_w.reshape(-1, trans_w.shape[-1])
360 | trans_h = trans_h.reshape(-1, trans_h.shape[-1])
361 |
362 | theta = torch.cat([scale_h, zeros, delta_t_batch[:, 1].unsqueeze(-1) + trans_h,
363 | zeros, scale_w, delta_t_batch[:, 0].unsqueeze(-1) + trans_w], dim=-1)
364 |
365 | theta = theta.view(-1, 2, 3) # [batch_size * n_kp, 2, 3]
366 | align_corners = False
367 | padding_mode = 'zeros'
368 | # mode = "nearest"
369 | mode = 'bilinear'
370 |
371 | grid = F.affine_grid(theta, padded_patches_batch.size(), align_corners=align_corners)
372 | trans_padded_patches_batch = F.grid_sample(padded_patches_batch, grid, align_corners=align_corners,
373 | mode=mode, padding_mode=padding_mode)
374 |
375 | trans_padded_patches_batch = trans_padded_patches_batch.view(batch_size, n_kp, *padded_patches_batch.shape[1:])
376 | # [bs, n_kp, ch, img_size, img_size]
377 | return trans_padded_patches_batch
378 |
379 | def get_objects_alpha_rgb(self, z_kp, z_features, scale=None, translation=None, deterministic=False,
380 | order_weights=None):
381 | dec_objects = self.object_dec(z_features[:, :-1]) # [bs * n_kp, 4, patch_size, patch_size]
382 | dec_objects = dec_objects.view(-1, self.n_kp_enc,
383 | *dec_objects.shape[1:]) # [bs, n_kp, 4, patch_size, patch_size]
384 | # translate patches
385 | dec_objects_trans = self.translate_patches(z_kp[:, :-1], dec_objects, scale, translation)
386 | dec_objects_trans = dec_objects_trans.clamp(0, 1) # STN can change values to be < 0
387 | # dec_objects_trans: [bs, n_kp, 3, im_size, im_size]
388 | if order_weights is not None:
389 | # for each particle, we get a one-hot vector of its place in the order
390 | # we then move all of its maps [4, h, w] to its new place via 1x1 grouped-convolution (group_size=batch_size)
391 | bs, n_kp, n_ch, h, w = dec_objects_trans.shape
392 | order_weights = order_weights.view(order_weights.shape[0], self.n_kp_enc, self.n_kp_enc, 1, 1)
393 | order_weights = F.gumbel_softmax(order_weights, hard=True, dim=1) # straight-through gradients (hard=True)
394 | # order weights: [bs, n_kp, n_kp, 1, 1] - for each kp, its location in the order, in one-hot form
395 | # i.e., if kp 1 is 6 in the order of 8 kp, then its vector: [0 0 0 0 0 0 1 0]
396 | order_weights = order_weights.view(order_weights.shape[0] * self.n_kp_enc, self.n_kp_enc, 1, 1)
397 | reordered_objects = dec_objects_trans.reshape(1, -1, h * n_ch, w) # [1, bs * n_kp, h * n_ch, w]
398 | ordered_objects = F.conv2d(reordered_objects, order_weights, bias=None, stride=1, groups=bs)
399 | ordered_objects = ordered_objects.view(bs, n_kp, n_ch, h, w)
400 | dec_objects_trans = ordered_objects
401 |
402 | # multiply by alpha channel
403 | a_obj, rgb_obj = torch.split(dec_objects_trans, [1, 3], dim=2)
404 |
405 | if not deterministic:
406 | attn_mask = torch.where(a_obj > 0.1, 1.0, 0.0)
407 | # attn_mask = self.to_gauss_map(z_kp[:, :-1], a_obj.shape[-1], a_obj.shape[-1]).unsqueeze(
408 | # 2).detach()
409 | a_obj = a_obj + self.sigma * torch.randn_like(a_obj) * attn_mask
410 | return dec_objects, a_obj, rgb_obj
411 |
412 | def stitch_objects(self, a_obj, rgb_obj, obj_on, bg, stitch_method='c'):
413 | # turn off inactive kp
414 | # obj_on: [bs, n_kp, 1]
415 | a_obj = obj_on[:, :, None, None, None] * a_obj # [bs, n_kp, 4, im_size, im_size]
416 | if stitch_method == 'a':
417 | # layer-wise stitching, each particle is a layer
418 | # x_0 = bg
419 | # x_i = (1-a_i) * x_(i-1) + a_i * rgb_i
420 | rec = bg
421 | curr_mask = a_obj[:, 0]
422 | comp_masks = [curr_mask] # to calculate the effective mask, only for plotting
423 | for i in range(a_obj.shape[1]):
424 | rec = (1 - a_obj[:, i]) * rec + a_obj[:, i] * rgb_obj[:, i]
425 | # rec = (1 - a_obj[:, i].detach()) * rec + a_obj[:, i] * rgb_obj[:, i]
426 |                 # (detaching the alpha as in the line above degrades results: the masks are not learned properly)
427 | if i > 0:
428 | available_space = 1.0 - curr_mask.detach()
429 | curr_mask_tmp = torch.min(available_space, a_obj[:, i])
430 | comp_masks.append(curr_mask_tmp)
431 | curr_mask = curr_mask + curr_mask_tmp
432 | comp_masks = torch.stack(comp_masks, dim=1)
433 | dec_objects_trans = comp_masks * rgb_obj
434 | dec_objects_trans = dec_objects_trans.sum(1)
435 | elif stitch_method == 'b':
436 | # same formula as method 'a', but with detach and opening the recursive formula
437 | # x_n = bg * \prod_{i=1}^n (1-a_i) + a_n * rgb_n + a_(n-1) * rgb_(n-1) * (1-a_n) + ...
438 |             #       + a_1 * rgb_1 * \prod_{i=2}^{n} (1-a_i)
439 | bg_comp = torch.prod(1 - a_obj, dim=1) * bg
440 | obj = a_obj * rgb_obj
441 | # stitch
442 | rec = obj[:, -1]
443 | for i in reversed(range(a_obj.shape[1] - 1)):
444 | rec = rec + obj[:, i] * torch.prod((1 - a_obj[:, i + 1:].detach()), dim=1)
445 | dec_objects_trans = rec.detach()
446 | rec = rec + bg_comp
447 | else:
448 | # alpha-based stitching: we first calculate the effective masks, assuming the previous
449 | # masks already occupy some space that cannot be taken and finally we multiply the effective masks
450 | # by the rgb channel, the bg mask is the space left from the sum of all effective masks.
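            # tiny single-pixel example with two particles, a_1 = 0.8 and a_2 = 0.9:
            # effective masks are 0.8 and min(1 - 0.8, 0.9) = 0.2, leaving 1 - (0.8 + 0.2) = 0.0 for the bg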
451 | curr_mask = a_obj[:, 0]
452 | comp_masks = [curr_mask]
453 | for i in range(1, a_obj.shape[1]):
454 | available_space = 1.0 - curr_mask.detach()
455 | curr_mask_tmp = torch.min(available_space, a_obj[:, i])
456 | comp_masks.append(curr_mask_tmp)
457 | curr_mask = curr_mask + curr_mask_tmp
458 | comp_masks = torch.stack(comp_masks, dim=1)
459 | comp_masks_sum = comp_masks.sum(1).clamp(0, 1)
460 | alpha_mask = 1.0 - comp_masks_sum
461 | dec_objects_trans = comp_masks * rgb_obj
462 | dec_objects_trans = dec_objects_trans.sum(1) # [bs, 3, im_size, im_size]
463 | rec = alpha_mask * bg + dec_objects_trans
464 | return rec, dec_objects_trans
465 |
466 | def decode_objects(self, z_kp, z_features, obj_on, scale=None, translation=None, deterministic=False,
467 | order_weights=None, bg=None):
468 | dec_objects, a_obj, rgb_obj = self.get_objects_alpha_rgb(z_kp, z_features, scale=scale, translation=translation,
469 | deterministic=deterministic,
470 | order_weights=order_weights)
471 | if bg is None:
472 | bg = torch.zeros_like(rgb_obj[:, 0])
473 | # stitching
474 | rec, dec_objects_trans = self.stitch_objects(a_obj, rgb_obj, obj_on=obj_on, bg=bg)
475 | return dec_objects, dec_objects_trans, rec
476 |
477 | def decode_all(self, z, z_features, kp_heatmap, obj_on, deterministic=False, order_weights=None):
478 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim)
479 | gmap_1_bg = 1 - gmap_1_fg.sum(1, keepdim=True).clamp(0, 1).detach()
480 | gmap_1 = torch.cat([gmap_1_fg, gmap_1_bg], dim=1)
481 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True,
482 | elementwise=True).detach()
483 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1)
484 | bg_masks = 1 - fg_masks
485 | masks = torch.cat([fg_masks.expand_as(gmap_1_fg), bg_masks], dim=1)
486 |         # decode objects and translate them to the positions of the keypoints
487 | # decode
488 | z_features_in = z_features
489 | if self.dec_bone == "gauss_pointnetpp":
490 | if self.learned_feature_dim > 0:
491 | gmap_2 = self.pointnet(position=z.detach(),
492 | features=torch.cat([z.detach(), z_features_in], dim=-1))
493 | else:
494 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach())
495 | gmap = torch.cat([gmap_1[:, :-1], gmap_2], dim=1)
496 | elif self.dec_bone == "gauss_pointnetpp_feat":
497 | if self.learned_feature_dim > 0:
498 | gmap_2 = self.pointnet(position=z.detach(),
499 | features=torch.cat([z.detach(), z_features_in], dim=-1))
500 | else:
501 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach())
502 |
503 | fg_masks = masks[:, :-1]
504 | bg_masks = masks[:, -1].unsqueeze(1)
505 | gmap_2 = fg_masks * gmap_2
506 | gmap_3 = bg_masks * kp_heatmap.detach()
507 | gmap = torch.cat([gmap_1[:, :-1], gmap_2, gmap_3], dim=1)
508 | else:
509 | raise NotImplementedError('grow a dec bone')
510 | rec = self.dec(gmap)
511 |
512 | if z_features is not None and self.use_object_dec:
513 | object_dec_out = self.decode_objects(z, z_features, obj_on, deterministic=deterministic,
514 | order_weights=order_weights, bg=rec)
515 | dec_objects, dec_objects_trans, rec = object_dec_out
516 | else:
517 | dec_objects_trans = None
518 | dec_objects = None
519 | return rec, dec_objects, dec_objects_trans
520 |
521 | def forward(self, x, deterministic=False, detach_decoder=False, x_prior=None, warmup=False, stg=False,
522 | noisy_masks=False):
523 | # stg: straight-through-gradients. not used.
524 | # first, extract prior KP proposals
525 | # prior
526 | if x_prior is None:
527 | x_prior = x
528 | kp_p = self.prior(x_prior, global_kp=True)
529 | kp_p = kp_p.view(x_prior.shape[0], -1, 2) # [batch_size, n_kp_total, 2]
530 | # filter proposals by distance to the patches' center
531 | dist_from_center = self.prior.get_distance_from_patch_centers(kp_p, global_kp=True)
532 | _, indices = torch.topk(dist_from_center, k=self.n_kp_prior, dim=-1, largest=True)
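        # largest=True keeps the n_kp_prior proposals that are farthest from their patch centers --
        # presumably the informative ones, since a proposal stuck at the patch center suggests the patch
        # contains nothing distinctive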
533 | batch_indices = torch.arange(kp_p.shape[0]).view(-1, 1).to(kp_p.device)
534 | kp_p = kp_p[batch_indices, indices]
535 | # alternatively, just sample random kp
536 | # kp_p = kp_p[:, torch.randperm(kp_p.shape[1])[:self.n_kp_prior]]
537 |
538 | # encode posterior KP
539 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = self.encode(x,
540 | return_heatmap=True)
541 | if deterministic:
542 | z = mu
543 | else:
544 | z = reparameterize(mu, logvar)
545 |
546 | # create gaussian maps (and masks) from the posterior keypoints
547 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim)
548 | gmap_1_bg = 1 - gmap_1_fg.sum(1, keepdim=True).clamp(0, 1).detach()
549 | gmap_1 = torch.cat([gmap_1_fg, gmap_1_bg], dim=1)
550 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True,
551 | elementwise=True).detach()
552 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1)
553 | bg_masks = 1 - fg_masks
554 | masks = torch.cat([fg_masks.expand_as(gmap_1_fg), bg_masks], dim=1)
555 | masks_sep = torch.cat([fg_masks_sep, bg_masks], dim=1)
556 |
557 | # encode visual features
558 | if self.learned_feature_dim > 0:
559 | if self.use_object_enc:
560 | feat_source = x if self.use_object_dec else kp_heatmap.detach()
561 | obj_on_in = obj_on if not noisy_masks else 0.0 * obj_on + torch.rand_like(obj_on)
562 | obj_enc_out = self.encode_object_features_sep(feat_source, mu[:, :-1], kp_heatmap.detach(),
563 | masks_sep.detach(),
564 | exclusive_patches=self.exclusive_patches,
565 | obj_on=obj_on_in)
566 | mu_features, logvar_features, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2]
567 | if len(obj_enc_out) > 3:
568 | cropped_objects_masks = obj_enc_out[3]
569 | else:
570 | cropped_objects_masks = None
571 | else:
572 | mu_features, logvar_features = self.encode_object_features(kp_heatmap.detach(), masks_sep)
573 | cropped_objects = None
574 | cropped_objects_masks = None
575 |
576 | if deterministic:
577 | z_features = mu_features
578 | else:
579 | z_features = reparameterize(mu_features, logvar_features)
580 | else:
581 | z_features = None
582 | cropped_objects = None
583 | cropped_objects_masks = None
584 |
585 | # decode
586 | if not warmup or not self.use_object_dec:
587 | z_features_fg, z_features_bg = torch.split(z_features, [self.n_kp_enc, 1], dim=1)
588 | z_features_in = torch.cat([z_features_fg.detach(), z_features_bg],
589 | dim=1) if self.use_object_dec else z_features
590 | if self.dec_bone == "gauss_pointnetpp":
591 | if self.learned_feature_dim > 0:
592 | gmap_2 = self.pointnet(position=z.detach(),
593 | features=torch.cat([z.detach(), z_features_in], dim=-1))
594 | else:
595 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach())
596 | gmap = torch.cat([gmap_1[:, :-1], gmap_2], dim=1)
597 | # gmap = torch.cat([gmap_2.detach(), gmap_2], dim=1)
598 | elif self.dec_bone == "gauss_pointnetpp_feat":
599 | if self.learned_feature_dim > 0:
600 | gmap_2 = self.pointnet(position=z.detach(),
601 | features=torch.cat([z.detach(), z_features_in], dim=-1))
602 | else:
603 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach())
604 |
605 | fg_masks = masks[:, :-1]
606 | bg_masks = masks[:, -1].unsqueeze(1)
607 | gmap_2 = fg_masks * gmap_2
608 | gmap_3 = bg_masks * kp_heatmap.detach()
609 | gmap = torch.cat([gmap_1[:, :-1], gmap_2, gmap_3], dim=1)
610 | else:
611 | raise NotImplementedError('grow a dec bone')
612 | if detach_decoder:
613 | rec = self.dec(gmap.detach())
614 | else:
615 | rec = self.dec(gmap)
616 | else:
617 | rec = torch.zeros_like(x)
618 | gmap = None
619 |
620 |         # decode objects and translate them to the positions of the keypoints
621 | if z_features is not None and self.use_object_dec:
622 | obj_on_in = obj_on if not noisy_masks else 0.0 * obj_on + torch.rand_like(obj_on)
623 | object_dec_out = self.decode_objects(z, z_features, obj_on_in, deterministic=not noisy_masks,
624 | order_weights=order_weights, bg=rec)
625 | dec_objects, dec_objects_trans, rec = object_dec_out
626 | else:
627 | dec_objects_trans = None
628 | dec_objects = None
629 | gmap = None
630 |
631 | output_dict = {}
632 | output_dict['kp_p'] = kp_p
633 | output_dict['gmap'] = gmap
634 | output_dict['rec'] = rec
635 | output_dict['mu'] = mu
636 | output_dict['logvar'] = logvar
637 | output_dict['mu_features'] = mu_features
638 | output_dict['logvar_features'] = logvar_features
639 | # object stuff
640 | output_dict['cropped_objects_original'] = cropped_objects
641 | output_dict['cropped_objects_masks'] = cropped_objects_masks
642 | output_dict['obj_on'] = obj_on
643 | output_dict['dec_objects_original'] = dec_objects
644 | output_dict['dec_objects'] = dec_objects_trans
645 | output_dict['order_weights'] = order_weights
646 |
647 | return output_dict
648 |
649 | def lerp(self, other, betta):
650 | # weight interpolation for ema - not used in the paper
651 | if hasattr(other, 'module'):
652 | other = other.module
653 | with torch.no_grad():
654 | params = self.parameters()
655 | other_param = other.parameters()
656 | for p, p_other in zip(params, other_param):
657 | p.data.lerp_(p_other.data, 1.0 - betta)
658 |
--------------------------------------------------------------------------------
/modules/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic modules and layers.
3 | """
4 |
5 | # imports
6 | import numpy as np
7 | # torch
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.nn as nn
11 | from utils.util_func import create_masks_fast
12 | # torch geometric
13 | from torch_geometric.nn import MessagePassing, global_max_pool
14 | from torch_cluster import knn_graph
15 | from torch_cluster import fps
16 |
17 |
18 | class ResidualBlock(nn.Module):
19 | def __init__(self, inc=64, outc=64, groups=1, scale=1.0, padding="zeros"):
20 | super(ResidualBlock, self).__init__()
21 |
22 | midc = int(outc * scale)
23 |
24 |         if inc != outc:
25 | self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1, padding=0,
26 | groups=1, bias=False)
27 | else:
28 | self.conv_expand = None
29 | if padding == "zeros":
30 | self.conv1 = nn.Conv2d(in_channels=inc, out_channels=midc, kernel_size=3, stride=1, padding=1,
31 | groups=groups,
32 | bias=False)
33 | else:
34 | self.conv1 = nn.Sequential(nn.ReplicationPad2d(1),
35 | nn.Conv2d(in_channels=inc, out_channels=midc, kernel_size=3, stride=1,
36 | padding=0, groups=groups, bias=False))
37 | self.bn1 = nn.BatchNorm2d(midc)
38 | self.relu1 = nn.LeakyReLU(0.2, inplace=True)
39 | if padding == "zeros":
40 | self.conv2 = nn.Conv2d(in_channels=midc, out_channels=outc, kernel_size=3, stride=1, padding=1,
41 | groups=groups, bias=False)
42 | else:
43 | self.conv2 = nn.Sequential(nn.ReplicationPad2d(1),
44 | nn.Conv2d(in_channels=midc, out_channels=outc, kernel_size=3, stride=1,
45 | padding=0, groups=groups, bias=False))
46 | self.bn2 = nn.BatchNorm2d(outc)
47 | self.relu2 = nn.LeakyReLU(0.01, inplace=True)
48 |
49 | def forward(self, x):
50 | if self.conv_expand is not None:
51 | identity_data = self.conv_expand(x)
52 | else:
53 | identity_data = x
54 |
55 | output = self.relu1(self.bn1(self.conv1(x)))
56 | output = self.conv2(output)
57 | output = self.bn2(output)
58 | output = self.relu2(torch.add(output, identity_data))
59 | return output
60 |
61 |
62 | class ConvBlock(nn.Module):
63 | """
64 | Basic convolutional nn block.
65 | """
66 |
67 | def __init__(self, c_in, c_out, kernel_size, stride=1, pad=0, pool=False, upsample=False, bias=False,
68 | activation=True, batchnorm=True, relu_type='leaky', pad_mode='zeros', use_resblock=False):
69 | super(ConvBlock, self).__init__()
70 | self.main = nn.Sequential()
71 | if use_resblock:
72 | self.main.add_module(f'conv_{c_in}_to_{c_out}', ResidualBlock(c_in, c_out, padding=pad_mode))
73 | else:
74 | if pad_mode != 'zeros':
75 | self.main.add_module('replicate_pad', nn.ReplicationPad2d(pad))
76 | pad = 0
77 | self.main.add_module(f'conv_{c_in}_to_{c_out}', nn.Conv2d(c_in, c_out, kernel_size,
78 | stride=stride, padding=pad, bias=bias))
79 | if batchnorm:
80 | # note: for better performance on small datasets/batches, it is better to replace with nn.GroupNorm
81 | self.main.add_module(f'bathcnorm_{c_out}', nn.BatchNorm2d(c_out))
82 | if activation:
83 | if relu_type == 'leaky':
84 | self.main.add_module(f'relu', nn.LeakyReLU(0.01))
85 | else:
86 | self.main.add_module(f'relu', nn.ReLU())
87 | if pool:
88 | self.main.add_module(f'max_pool2', nn.MaxPool2d(kernel_size=2, stride=2))
89 | if upsample:
90 |             # note: literature recommends using 'nearest' interpolation instead of 'bilinear'.
91 | self.main.add_module(f'upsample_bilinear_2',
92 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
93 |
94 | def forward(self, x):
95 | y = self.main(x)
96 | return y
97 |
98 |
99 | class KeyPointCNNOriginal(nn.Module):
100 | """
101 |     CNN to extract heatmaps, inspired by KeyNet (Jakab et al.)
102 | """
103 |
104 | def __init__(self, cdim=3, channels=(32, 64, 128, 256), image_size=64, n_kp=8, pad_mode='replicate',
105 | use_resblock=False, first_conv_kernel_size=7):
106 | super(KeyPointCNNOriginal, self).__init__()
107 | self.cdim = cdim
108 | self.image_size = image_size
109 | self.n_kp = n_kp
110 | cc = channels[0]
111 | ch = cc
112 | first_conv_pad = first_conv_kernel_size // 2
113 | self.main = nn.Sequential()
114 | self.main.add_module(f'in_block_1',
115 | ConvBlock(cdim, cc, kernel_size=first_conv_kernel_size, stride=1,
116 | pad=first_conv_pad, pool=False, pad_mode=pad_mode,
117 | use_resblock=use_resblock, relu_type='relu'))
118 | self.main.add_module(f'in_block_2',
119 | ConvBlock(cc, cc, kernel_size=3, stride=1, pad=1, pool=False, pad_mode=pad_mode,
120 | use_resblock=use_resblock, relu_type='relu'))
121 |
122 | sz = image_size
123 | for ch in channels[1:]:
124 | self.main.add_module('conv_in_{}_0'.format(sz), ConvBlock(cc, ch, kernel_size=3, stride=2, pad=1,
125 | pool=False, pad_mode=pad_mode,
126 | use_resblock=use_resblock, relu_type='relu'))
127 | self.main.add_module('conv_in_{}_1'.format(ch), ConvBlock(ch, ch, kernel_size=3, stride=1, pad=1,
128 | pool=False, pad_mode=pad_mode,
129 | use_resblock=use_resblock, relu_type='relu'))
130 | cc, sz = ch, sz // 2
131 |
132 | self.keymap = nn.Conv2d(channels[-1], n_kp, kernel_size=1)
133 | self.conv_output_size = self.calc_conv_output_size()
134 | num_fc_features = torch.zeros(self.conv_output_size).view(-1).shape[0]
135 | print("conv shape: ", self.conv_output_size)
136 | # print("num fc features: ", num_fc_features)
137 | # self.fc = nn.Linear(num_fc_features, self.fc_output)
138 |
139 | def calc_conv_output_size(self):
140 | dummy_input = torch.zeros(1, self.cdim, self.image_size, self.image_size)
141 | dummy_input = self.main(dummy_input)
142 | return dummy_input[0].shape
143 |
144 | def forward(self, x):
145 | y = self.main(x)
146 | # heatmap
147 | hm = self.keymap(y)
148 | return y, hm
149 |
150 |
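# Usage sketch (editor's addition, not in the original source): a minimal call to
# KeyPointCNNOriginal with the default channel configuration above; with 3 stride-2 stages,
# a 64x64 input is reduced to an 8x8 heatmap per keypoint.
#   net = KeyPointCNNOriginal(cdim=3, channels=(32, 64, 128, 256), image_size=64, n_kp=8)
#   x = torch.randn(2, 3, 64, 64)
#   feats, heatmaps = net(x)  # feats: [2, 256, 8, 8], heatmaps: [2, 8, 8, 8]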
151 | # converting to probabilities
152 | class SpatialSoftmaxKP(torch.nn.Module):
153 | """
154 | This module performs spatial-softmax (ssm) by performing marginalization over heatmaps.
155 | """
156 |
157 | def __init__(self, kp_range=(0, 1)):
158 | super().__init__()
159 | self.kp_range = kp_range
160 |
161 | def forward(self, heatmap, probs=False):
162 | batch_size, n_kp, height, width = heatmap.shape
163 | # marginalize over height (y)
164 | s_h = torch.mean(heatmap, dim=3) # [batch_size, n_kp, features_dim_height]
165 | sm_h = torch.softmax(s_h, dim=-1) # [batch_size, n_kp, features_dim_height]
166 | # marginalize over width (x)
167 | s_w = torch.mean(heatmap, dim=2) # [batch_size, n_kp, features_dim_width]
168 | # probability per spatial coordinate
169 | sm_w = torch.softmax(s_w, dim=-1) # [batch_size, n_kp, features_dim_width]
170 | # each coordinate [0, 1] is assigned a probability
171 | y_axis = torch.linspace(self.kp_range[0], self.kp_range[1], height).type_as(sm_h).expand(1, 1, -1).to(
172 | sm_h.device) # [1, 1, features_dim_height]
173 |         # expected value: probability per coordinate * coordinate
174 | kp_h = torch.sum(sm_h * y_axis, dim=-1, keepdim=True) # [batch_size, n_kp, 1]
175 | kp_h = kp_h.squeeze(-1) # [batch_size, n_kp], y coordinate of each kp
176 |
177 | x_axis = torch.linspace(self.kp_range[0], self.kp_range[1], width).type_as(sm_w).expand(1, 1, -1).to(
178 | sm_w.device) # [1, 1, features_dim_width]
179 | kp_w = torch.sum(sm_w * x_axis, dim=-1, keepdim=True).squeeze(-1) # [batch_size, n_kp], x coordinate of each kp
180 |
181 | # stack keypoints
182 |         kp = torch.stack([kp_h, kp_w], dim=-1)  # [batch_size, n_kp, 2], (y, x) coordinates of each kp
183 |
184 | if probs:
185 | return kp, sm_h, sm_w
186 | else:
187 | return kp
188 |
189 |
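# Usage sketch (editor's addition): SpatialSoftmaxKP turns each heatmap channel into an
# expected (y, x) coordinate by marginalizing over rows and columns; shapes are illustrative.
#   ssm = SpatialSoftmaxKP(kp_range=(-1, 1))
#   heatmap = torch.randn(2, 8, 16, 16)   # [batch, n_kp, height, width]
#   kp = ssm(heatmap)                     # [2, 8, 2], per-keypoint (y, x) in [-1, 1]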
190 | class SpatialLogSoftmaxKP(torch.nn.Module):
191 | """
192 | This module performs spatial-softmax (ssm) by performing marginalization over heatmaps. Uses log-softmax
193 | for numerical stability. In practice, we do not use it, but feel free to explore it.
194 | """
195 |
196 | def __init__(self, kp_range=(0, 1)):
197 | super().__init__()
198 | self.kp_range = kp_range
199 |
200 | def forward(self, heatmap, probs=False):
201 | batch_size, n_kp, height, width = heatmap.shape
202 | # marginalize over height (y)
203 | s_h = torch.mean(heatmap, dim=3) # [batch_size, n_kp, features_dim_height]
204 | sm_h = torch.log_softmax(s_h, dim=-1) # [batch_size, n_kp, features_dim_height]
205 | # marginalize over width (x)
206 | s_w = torch.mean(heatmap, dim=2) # [batch_size, n_kp, features_dim_width]
207 | sm_w = torch.log_softmax(s_w, dim=-1) # [batch_size, n_kp, features_dim_width]
208 | # each coordinate [0, 1] is assigned a probability
209 | y_axis = torch.log(torch.linspace(self.kp_range[0], self.kp_range[1], height)).type_as(sm_h).expand(1, 1,
210 | -1).to(
211 | sm_h.device)
212 | # [1, 1, features_dim_height]
213 |         # expected value: probability per coordinate * coordinate
214 | kp_h = torch.sum(torch.exp(sm_h + y_axis), dim=-1, keepdim=True) # [batch_size, n_kp, 1]
215 | kp_h = kp_h.squeeze(-1) # [batch_size, n_kp], y coordinate of each kp
216 |
217 | x_axis = torch.log(torch.linspace(self.kp_range[0], self.kp_range[1], width)).type_as(sm_w).expand(1, 1, -1).to(
218 | sm_w.device)
219 | # [1, 1, features_dim_width]
220 | kp_w = torch.sum(torch.exp(sm_w + x_axis), dim=-1, keepdim=True).squeeze(-1)
221 | # [batch_size, n_kp], x coordinate of each kp
222 |
223 | # stack keypoints
224 |         kp = torch.stack([kp_h, kp_w], dim=-1)  # [batch_size, n_kp, 2], (y, x) coordinates of each kp
225 |
226 | if probs:
227 | return kp, sm_h, sm_w
228 | else:
229 | return kp
230 |
231 |
232 | class ToGaussianMapHW(nn.Module):
233 | """
234 | This module converts KP to a gaussian map centered at the coordinates of the keypoint
235 | """
236 |
237 | def __init__(self, sigma_w=0.1, sigma_h=0.1, kp_range=(0, 1)):
238 | super().__init__()
239 | self.sigma_w = sigma_w
240 | self.sigma_h = sigma_h
241 | self.kp_range = kp_range
242 |
243 | def forward(self, kp, height, width, logvar_h=None, logvar_w=None):
244 | batch_size, n_kp, _ = kp.shape
245 | # get means
246 | h_mean, w_mean = kp[:, :, 0], kp[:, :, 1]
247 | # create a coordinate map for each axis
248 | h_map = torch.linspace(self.kp_range[0], self.kp_range[1], height, device=h_mean.device).type_as(
249 | h_mean) # [height]
250 | w_map = torch.linspace(self.kp_range[0], self.kp_range[1], width, device=w_mean.device).type_as(
251 | w_mean) # [width]
252 | # duplicate for all keypoints in the batch
253 | h_map = h_map.expand(batch_size, n_kp, height) # [batch_size, n_kp, height]
254 | w_map = w_map.expand(batch_size, n_kp, width) # [batch_size, n_kp, width]
255 | # repeat the mean to match dimensions
256 | h_mean_m = h_mean.expand(height, -1, -1) # [height, batch_size, n_kp]
257 | h_mean_m = h_mean_m.permute(1, 2, 0) # [batch_size, n_kp, height]
258 | w_mean_m = w_mean.expand(width, -1, -1) # [width, batch_size, n_kp]
259 | w_mean_m = w_mean_m.permute(1, 2, 0) # [batch_size, n_kp, width]
260 | # for each pixel in the map, calculate the squared distance from the mean
261 | h_sdiff = (h_map - h_mean_m) ** 2 # [batch_size, n_kp, height]
262 | w_sdiff = (w_map - w_mean_m) ** 2 # [batch_size, n_kp, width]
263 | # compute gaussian
264 | # duplicate for the other dimension
265 | hm = h_sdiff.expand(width, -1, -1, -1).permute(1, 2, 3, 0) # [batch_size, n_kp, height, width]
266 | wm = w_sdiff.expand(height, -1, -1, -1).permute(1, 2, 0, 3) # [batch_size, n_kp, height, width]
267 | if logvar_h is not None:
268 | sigma_h = torch.exp(0.5 * logvar_h)
269 | sigma_h = sigma_h.expand(height, -1, -1) # [height, batch_size, n_kp]
270 | sigma_h = sigma_h.permute(1, 2, 0) # [batch_size, n_kp, height]
271 | sigma_h = sigma_h.expand(width, -1, -1, -1).permute(1, 2, 3, 0) # [batch_size, n_kp, height, width]
272 | else:
273 | sigma_h = self.sigma_h
274 | if logvar_w is not None:
275 | sigma_w = torch.exp(0.5 * logvar_w)
276 | sigma_w = sigma_w.expand(width, -1, -1) # [width, batch_size, n_kp]
277 | sigma_w = sigma_w.permute(1, 2, 0) # [batch_size, n_kp, width]
278 | sigma_w = sigma_w.expand(height, -1, -1, -1).permute(1, 2, 0, 3) # [batch_size, n_kp, height, width]
279 | else:
280 | sigma_w = self.sigma_w
281 | # print(hm.shape, sigma_h.shape, wm.shape, sigma_w.shape)
282 | gm = -0.5 * (hm / (sigma_h ** 2) + wm / (sigma_w ** 2))
283 | gm = torch.exp(gm) # [batch_size, n_kp, height, width]
284 | return gm
285 |
286 |
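# Usage sketch (editor's addition): rendering keypoints back into Gaussian heatmaps with a
# fixed std; the keypoints here are placed at the image center purely for illustration.
#   to_gmap = ToGaussianMapHW(sigma_w=0.1, sigma_h=0.1, kp_range=(-1, 1))
#   kp = torch.zeros(2, 8, 2)                     # [batch, n_kp, 2] in kp_range
#   gmap = to_gmap(kp, height=32, width=32)       # [2, 8, 32, 32]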
287 | class CNNDecoder(nn.Module):
288 | """
289 |     This module upsamples activation maps to the original image size.
290 | """
291 |
292 | def __init__(self, cdim=3, channels=(64, 128, 256, 512, 512, 512), image_size=64, in_ch=16, n_kp=8,
293 | pad_mode='zeros', use_resblock=False):
294 | super(CNNDecoder, self).__init__()
295 | self.cdim = cdim
296 | self.image_size = image_size
297 | cc = channels[-1]
298 | self.in_ch = in_ch
299 | self.n_kp = n_kp
300 |
301 | sz = 4
302 |
303 | self.main = nn.Sequential()
304 | self.main.add_module('depth_up',
305 | ConvBlock(self.in_ch, cc, kernel_size=3, pad=1, upsample=True, pad_mode=pad_mode,
306 | use_resblock=use_resblock))
307 | for ch in reversed(channels[1:-1]):
308 | self.main.add_module('conv_to_{}'.format(sz * 2), ConvBlock(cc, ch, kernel_size=3, pad=1, upsample=True,
309 | pad_mode=pad_mode, use_resblock=use_resblock))
310 | cc, sz = ch, sz * 2
311 |
312 | self.main.add_module('conv_to_{}'.format(sz * 2),
313 | ConvBlock(cc, self.n_kp * (channels[0] // self.n_kp + 1), kernel_size=3, pad=1,
314 | upsample=False,
315 | pad_mode=pad_mode, use_resblock=use_resblock))
316 | self.final_conv = ConvBlock(self.n_kp * (channels[0] // self.n_kp + 1), cdim, kernel_size=1, bias=True,
317 | activation=False, batchnorm=False)
318 |
319 | def forward(self, z, masks=None):
320 | y = self.main(z)
321 | if masks is not None:
322 | # this is not used in the paper
323 | # masks: [bs, n_kp, feat_dim, feat_dim]
324 | bs, n_kp, fs, _ = masks.shape
325 | # y: [bs, n_kp * ch[0], feat_dim, feat_dim]
326 | y = y.view(bs, n_kp, -1, fs, fs)
327 | y = masks.unsqueeze(2) * y
328 | y = y.view(bs, -1, fs, fs)
329 | y = self.final_conv(y)
330 | return y
331 |
332 |
333 | class ImagePatcher(nn.Module):
334 | """
335 |     This module takes an image of size B x cdim x H x W and returns a patchified tensor
336 |     B x cdim x num_patches x patch_size x patch_size. It also gives you the global location of the patch
337 |     w.r.t. the original image. We use this module to extract prior KP from patches, and we need to know their
338 | global coordinates for the Chamfer-KL.
339 | """
340 |
341 | def __init__(self, cdim=3, image_size=64, patch_size=16):
342 | super(ImagePatcher, self).__init__()
343 | self.cdim = cdim
344 | self.image_size = image_size
345 | self.patch_size = patch_size
346 | self.kh, self.kw = self.patch_size, self.patch_size # kernel size
347 |         self.dh, self.dw = self.patch_size, self.patch_size  # stride
348 | self.unfold_shape = self.get_unfold_shape()
349 | self.patch_location_idx = self.get_patch_location_idx()
350 | # print(f'unfold shape: {self.unfold_shape}')
351 | # print(f'patch locations: {self.patch_location_idx}')
352 |
353 | def get_patch_location_idx(self):
354 | h = np.arange(0, self.image_size)[::self.patch_size]
355 | w = np.arange(0, self.image_size)[::self.patch_size]
356 | ww, hh = np.meshgrid(h, w)
357 | hw = np.stack((hh, ww), axis=-1)
358 | hw = hw.reshape(-1, 2)
359 | return torch.from_numpy(hw).int()
360 |
361 | def get_patch_centers(self):
362 | mid = self.patch_size // 2
363 | patch_locations_idx = self.get_patch_location_idx()
364 | patch_locations_idx += mid
365 | return patch_locations_idx
366 |
367 | def get_unfold_shape(self):
368 | dummy_input = torch.zeros(1, self.cdim, self.image_size, self.image_size)
369 | patches = dummy_input.unfold(2, self.kh, self.dh).unfold(3, self.kw, self.dw)
370 | unfold_shape = patches.shape[1:]
371 | return unfold_shape
372 |
373 | def img_to_patches(self, x):
374 | patches = x.unfold(2, self.kh, self.dh).unfold(3, self.kw, self.dw)
375 | patches = patches.contiguous().view(patches.shape[0], patches.shape[1], -1, self.kh, self.kw)
376 | return patches
377 |
378 | def patches_to_img(self, x):
379 | patches_orig = x.view(x.shape[0], *self.unfold_shape)
380 |         output_h = self.unfold_shape[1] * self.unfold_shape[3]  # n_patches_h * patch_h
381 | output_w = self.unfold_shape[2] * self.unfold_shape[4]
382 | patches_orig = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
383 | patches_orig = patches_orig.view(-1, self.cdim, output_h, output_w)
384 | return patches_orig
385 |
386 | def forward(self, x, patches=True):
387 | # x [batch_size, 3, image_size, image_size] or [batch_size, 3, num_patches, image_size, image_size]
388 | if patches:
389 | return self.img_to_patches(x)
390 | else:
391 | return self.patches_to_img(x)
392 |
393 |
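# Usage sketch (editor's addition): patchify/un-patchify round trip, assuming a 64x64 image
# and 16x16 non-overlapping patches (so num_patches = 16).
#   patcher = ImagePatcher(cdim=3, image_size=64, patch_size=16)
#   x = torch.randn(2, 3, 64, 64)
#   patches = patcher(x)                        # [2, 3, 16, 16, 16]
#   x_rec = patcher(patches, patches=False)     # [2, 3, 64, 64]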
394 | class VariationalKeyPointPatchEncoder(nn.Module):
395 | """
396 | This module encodes patches to KP via SSM. Additionally, we implement a variational version that encodes
397 | log-variance in addition to the mean, but we don't use it in practice, as constant prior std works better.
398 | We also experimented with extracting features directly from the patches for a prior for the visual features,
399 | but we didn't find it better than a constant prior (~N(0,1)). However, you can explore it by setting
400 | `learned_feature_dim`>0.
401 | """
402 |
403 | def __init__(self, cdim=3, channels=(16, 16, 32), image_size=64, n_kp=4, patch_size=16, kp_range=(0, 1),
404 | use_logsoftmax=False, pad_mode='replicate', sigma=0.1, dropout=0.0, learnable_logvar=False,
405 | learned_feature_dim=0):
406 | super(VariationalKeyPointPatchEncoder, self).__init__()
407 | self.use_logsoftmax = use_logsoftmax
408 | self.image_size = image_size
409 | self.dropout = dropout
410 | self.kp_range = kp_range
411 | self.n_kp = n_kp # kp per patch
412 | self.patcher = ImagePatcher(cdim=cdim, image_size=image_size, patch_size=patch_size)
413 | self.features_dim = int(patch_size // (2 ** (len(channels) - 1)))
414 | self.enc = KeyPointCNNOriginal(cdim=cdim, channels=channels, image_size=patch_size, n_kp=n_kp,
415 | pad_mode=pad_mode, use_resblock=False, first_conv_kernel_size=3)
416 | self.ssm = SpatialLogSoftmaxKP(kp_range=kp_range) if use_logsoftmax else SpatialSoftmaxKP(kp_range=kp_range)
417 | self.sigma = sigma
418 | self.learnable_logvar = learnable_logvar
419 | self.learned_feature_dim = learned_feature_dim
420 | if self.learnable_logvar:
421 | self.to_logvar = nn.Sequential(nn.Linear(self.n_kp * (self.features_dim ** 2), 512),
422 | nn.ReLU(True),
423 | nn.Linear(512, 256),
424 | nn.ReLU(True),
425 | nn.Linear(256, self.n_kp * 2)) # logvar_x, logvar_y
426 | if self.learned_feature_dim > 0:
427 | self.to_features = nn.Sequential(nn.Linear(self.n_kp * (self.features_dim ** 2), 512),
428 | nn.ReLU(True),
429 | nn.Linear(512, 256),
430 | nn.ReLU(True),
431 |                                              nn.Linear(256, self.n_kp * self.learned_feature_dim))  # features per kp
432 |
433 | def img_to_patches(self, x):
434 | return self.patcher.img_to_patches(x)
435 |
436 | def patches_to_img(self, x):
437 | return self.patcher.patches_to_img(x)
438 |
439 | def get_global_kp(self, local_kp):
440 | # local_kp: [batch_size, num_patches, n_kp, 2]
441 | # returns the global coordinates of a KP within the original image.
442 | batch_size, num_patches, n_kp, _ = local_kp.shape
443 | global_coor = self.patcher.get_patch_location_idx().to(local_kp.device) # [num_patches, 2]
444 | global_coor = global_coor[:, None, :].repeat(1, n_kp, 1)
445 | global_coor = (((local_kp - self.kp_range[0]) / (self.kp_range[1] - self.kp_range[0])) * (
446 | self.patcher.patch_size - 1) + global_coor) / (self.image_size - 1)
447 | global_coor = global_coor * (self.kp_range[1] - self.kp_range[0]) + self.kp_range[0]
448 | return global_coor
449 |
450 | def get_distance_from_patch_centers(self, kp, global_kp=False):
451 |         # calculates the distance of a KP from the center of its parent patch. This is useful to understand (and filter)
452 |         # whether SSM detected something; otherwise, the KP will likely land in the center of the patch
453 |         # (e.g., a solid-color patch will have the same activation in all pixels).
454 | if not global_kp:
455 | global_coor = self.get_global_kp(kp).view(kp.shape[0], -1, 2)
456 | else:
457 | global_coor = kp
458 | centers = 0.5 * (self.kp_range[1] + self.kp_range[0]) * torch.ones_like(kp).to(kp.device)
459 | global_centers = self.get_global_kp(centers.view(kp.shape[0], -1, self.n_kp, 2)).view(kp.shape[0], -1, 2)
460 | return ((global_coor - global_centers) ** 2).sum(-1)
461 |
462 | def encode(self, x, global_kp=False):
463 | # x: [batch_size, cdim, image_size, image_size]
464 | # global_kp: set True to get the global coordinates within the image (instead of local KP inside the patch)
465 | batch_size, cdim, image_size, image_size = x.shape
466 | x_patches = self.img_to_patches(x) # [batch_size, cdim, num_patches, patch_size, patch_size]
467 | x_patches = x_patches.permute(0, 2, 1, 3, 4) # [batch_size, num_patches, cdim, patch_size, patch_size]
468 | x_patches = x_patches.contiguous().view(-1, cdim, self.patcher.patch_size, self.patcher.patch_size)
469 | _, z = self.enc(x_patches) # [batch_size*num_patches, n_kp, features_dim, features_dim]
470 | mu_kp = self.ssm(z, probs=False) # [batch_size * num_patches, n_kp, 2]
471 | mu_kp = mu_kp.view(batch_size, -1, self.n_kp, 2) # [batch_size, num_patches, n_kp, 2]
472 | if global_kp:
473 | mu_kp = self.get_global_kp(mu_kp)
474 | if self.learned_feature_dim > 0:
475 | mu_features = self.to_features(z.view(z.shape[0], -1))
476 | mu_features = mu_features.view(batch_size, -1, self.n_kp, self.learned_feature_dim)
477 | # [batch_size, num_patches, n_kp, learned_feature_dim]
478 | if self.learnable_logvar:
479 | logvar_kp = self.to_logvar(z.view(z.shape[0], -1))
480 | logvar_kp = logvar_kp.view(batch_size, -1, self.n_kp, 2) # [batch_size, num_patches, n_kp, 2]
481 | if self.learned_feature_dim > 0:
482 | return mu_kp, logvar_kp, mu_features
483 | else:
484 | return mu_kp, logvar_kp
485 | elif self.learned_feature_dim > 0:
486 | return mu_kp, mu_features
487 | else:
488 | return mu_kp
489 |
490 | def forward(self, x, global_kp=False):
491 | return self.encode(x, global_kp)
492 |
493 |
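# Usage sketch (editor's addition): extracting prior keypoint proposals per patch; with a
# 64x64 image and 16x16 patches there are 16 patches, each contributing n_kp proposals.
#   prior_enc = VariationalKeyPointPatchEncoder(cdim=3, channels=(16, 16, 32), image_size=64,
#                                               n_kp=4, patch_size=16, kp_range=(-1, 1))
#   x = torch.randn(2, 3, 64, 64)
#   kp_p = prior_enc(x, global_kp=True)         # [2, 16, 4, 2], image-level coordinates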
494 | class ObjectEncoder(nn.Module):
495 | """
496 |     Glimpse-encoder: encodes the visual features of patches (glimpses) in a variational fashion (mu, log-variance).
497 | Useful for object-based scenes.
498 | """
499 |
500 | def __init__(self, z_dim, anchor_size, image_size, cnn_channels=(16, 16, 32), margin=0, ch=3, cnn=False,
501 | encode_location=False):
502 | super().__init__()
503 |
504 | self.anchor_size = anchor_size
505 | self.channels = cnn_channels
506 | self.z_dim = z_dim
507 | self.image_size = image_size
508 | self.patch_size = np.round(anchor_size * (image_size - 1)).astype(int)
509 | self.margin = margin
510 | self.crop_size = self.patch_size + 2 * margin
511 | self.ch = ch
512 | self.encode_location = encode_location
513 |
514 | if cnn:
515 | self.cnn = KeyPointCNNOriginal(cdim=ch, channels=cnn_channels, image_size=self.crop_size, n_kp=32,
516 | pad_mode='replicate', use_resblock=False, first_conv_kernel_size=3)
517 | else:
518 | self.cnn = None
519 | fc_in_dim = 32 * ((self.crop_size // 4) ** 2) if cnn else self.ch * (self.crop_size ** 2)
520 | fc_in_dim = fc_in_dim + 2 if self.encode_location else fc_in_dim
521 | self.fc = nn.Sequential(nn.Linear(fc_in_dim, 256),
522 | nn.ReLU(True),
523 | nn.Linear(256, 128),
524 | nn.ReLU(True),
525 | nn.Linear(128, self.z_dim * 2))
526 |
527 | def center_objects(self, kp, padded_objects):
528 | # after masking everything outside of the glimpse, move the unmasked area to the center -- makes it easier
529 | # to crop the patches.
530 | batch_size, n_kp, _ = kp.shape
531 | delta_tr_batch = kp
532 | delta_tr_batch = delta_tr_batch.reshape(-1, delta_tr_batch.shape[-1]) # [bs * n_kp, 2]
533 | source_batch = padded_objects.reshape(-1, *padded_objects.shape[2:])
534 |
535 | zeros = torch.zeros([delta_tr_batch.shape[0], 1], device=delta_tr_batch.device).float()
536 | ones = torch.ones([delta_tr_batch.shape[0], 1], device=delta_tr_batch.device).float()
537 |
538 | theta = torch.cat([ones, zeros, delta_tr_batch[:, 1].unsqueeze(-1),
539 | zeros, ones, delta_tr_batch[:, 0].unsqueeze(-1)], dim=-1)
540 | theta = theta.view(-1, 2, 3) # [batch_size * n_kp, 2, 3]
541 |
542 | align_corners = False
543 | padding_mode = 'zeros'
544 | mode = 'bilinear'
545 |
546 | grid = F.affine_grid(theta, source_batch.size(), align_corners=align_corners)
547 | trans_source_batch = F.grid_sample(source_batch, grid, align_corners=align_corners,
548 | mode=mode, padding_mode=padding_mode)
549 | trans_source_batch = trans_source_batch.view(batch_size, n_kp, *source_batch.shape[1:])
550 | return trans_source_batch
551 |
552 | def get_cropped_objects(self, centered_objects):
553 | # extracts the unmasked area of the image after the glimpses have been centered in the original image.
554 | center_idx = self.image_size // 2
555 | margin = self.margin
556 | w_start = center_idx - self.patch_size // 2 - margin
557 | w_end = center_idx + self.patch_size // 2 + margin
558 | h_start = center_idx - self.patch_size // 2 - margin
559 | h_end = center_idx + self.patch_size // 2 + margin
560 | cropped_objects = centered_objects[:, :, :, w_start:w_end, h_start:h_end]
561 | return cropped_objects
562 |
563 | def forward(self, x, kp, exclusive_patches=False, obj_on=None):
564 | # x: [bs, ch, image_size, image_size]
565 | # kp: [bs, n_kp, 2] in [-1, 1]
566 |         # exclusive_patches: create cumulative masks to avoid overlapping objects, THIS WAS NOT USED IN THE PAPER
567 |         # this will (mostly) enforce one particle per object and
568 | # won't allow several layers per object
569 | batch_size, _, _, img_size = x.shape
570 | _, n_kp, _ = kp.shape
571 | # create masks from kp
572 | masks = create_masks_fast(kp.detach(), self.anchor_size, feature_dim=self.image_size)
573 | # [batch_size, n_kp, 1, feature_dim, feature_dim]
574 | if exclusive_patches:
575 | if obj_on is not None:
576 | masks = obj_on[:, :, None, None, None] * masks
577 | masks = masks.clamp(0, 1) # STN can cause values to be outside [0, 1]
578 | # create cumulative masks to avoid overlapping objects
579 | curr_mask = masks[:, 0]
580 | comp_masks = [curr_mask]
581 | for i in range(1, masks.shape[1]):
582 | available_space = 1.0 - curr_mask.detach()
583 | curr_mask_tmp = torch.min(available_space, masks[:, i])
584 | comp_masks.append(curr_mask_tmp)
585 | curr_mask = curr_mask + curr_mask_tmp
586 | masks = torch.stack(comp_masks, dim=1)
587 | # extract objects
588 | padded_objects = masks * x.unsqueeze(1) # [batch_size, n_kp, ch, image_size, image_size]
589 | # center objects
590 | centered_objects = self.center_objects(kp, padded_objects) # [batch_size, n_kp, ch, image_size, image_size]
591 | # get crop
592 | cropped_objects = self.get_cropped_objects(centered_objects)
593 |         # [batch_size, n_kp, ch, patch_size + margin * 2, patch_size + margin * 2]
594 |
595 | # encode objects - fc
596 | if self.cnn is not None:
597 | _, cropped_objects_cnn = self.cnn(cropped_objects.view(-1, *cropped_objects.shape[2:]))
598 | else:
599 | cropped_objects_cnn = cropped_objects
600 | cropped_objects_flat = cropped_objects_cnn.reshape(batch_size, n_kp, -1) # flatten
601 | if self.encode_location:
602 | cropped_objects_flat = torch.cat([cropped_objects_flat, kp], dim=-1) # [batch_size, n_kp, .. + 2]
603 | enc_out = self.fc(cropped_objects_flat)
604 | mu, logvar = enc_out.chunk(2, dim=-1) # [batch_size, n_kp, z_dim]
605 | return mu, logvar, cropped_objects
606 |
607 |
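# Usage sketch (editor's addition): encoding per-particle glimpses; this assumes
# create_masks_fast (from utils.util_func) is imported at the top of this file, as used in
# forward() above. With anchor_size=0.25 and a 64x64 image the glimpse is roughly 16x16.
#   obj_enc = ObjectEncoder(z_dim=16, anchor_size=0.25, image_size=64, cnn=True)
#   x = torch.randn(2, 3, 64, 64)
#   kp = torch.zeros(2, 10, 2)                  # 10 particles at the center, in [-1, 1]
#   mu, logvar, glimpses = obj_enc(x, kp)       # mu, logvar: [2, 10, 16]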
608 | class ObjectDecoderCNN(nn.Module):
609 | """
610 | Glimpse Decoder: this module takes in a latent object tensor (vector) and upsamples it to an RGBA patch
611 | (RGBA = RGB + alpha channel).
612 | """
613 |
614 | def __init__(self, patch_size, num_chans=4, bottleneck_size=128, pad_mode='replicate'):
615 | super().__init__()
616 |
617 | if isinstance(patch_size, int):
618 | patch_size = (patch_size, patch_size)
619 | self.patch_size = patch_size
620 | self.num_chans = num_chans
621 |
622 | self.in_ch = 32
623 |
624 | fc_out_dim = self.in_ch * 8 * 8
625 | self.fc = nn.Sequential(nn.Linear(bottleneck_size, 256, bias=True),
626 | nn.ReLU(True),
627 | nn.Linear(256, fc_out_dim),
628 | nn.ReLU(True))
629 |
630 | # num_upsample = int(np.log(patch_size[0]) // np.log(2)) - 3
631 | num_upsample = int(np.log2(patch_size[0])) - 3 # thanks @MoritzLange
632 | # print(f'ObjDecCNN: fc to cnn num upsample: {num_upsample}')
633 | self.channels = [32]
634 | for i in range(num_upsample):
635 | self.channels.append(64)
636 | cc = self.channels[-1]
637 |
638 | sz = 8
639 |
640 | self.main = nn.Sequential()
641 | self.main.add_module('depth_up',
642 | ConvBlock(self.in_ch, cc, kernel_size=3, pad=1, upsample=True, pad_mode=pad_mode,
643 | use_resblock=False))
644 | for ch in reversed(self.channels[1:-1]):
645 | self.main.add_module('conv_to_{}'.format(sz * 2), ConvBlock(cc, ch, kernel_size=3, pad=1, upsample=True,
646 | pad_mode=pad_mode, use_resblock=False))
647 | cc, sz = ch, sz * 2
648 |
649 | self.main.add_module('conv_to_{}'.format(sz * 2),
650 | ConvBlock(cc, self.channels[0], kernel_size=3, pad=1,
651 | upsample=False, pad_mode=pad_mode, use_resblock=False))
652 | self.main.add_module('final_conv', ConvBlock(self.channels[0], num_chans, kernel_size=1, bias=True,
653 | activation=False, batchnorm=False))
654 | self.decode = self.main
655 |
656 | def forward(self, x):
657 | if x.dim() == 3:
658 | x = x.reshape(-1, x.shape[-1])
659 | conv_in = self.fc(x)
660 | conv_in = conv_in.view(-1, 32, 8, 8)
661 | out = self.decode(conv_in).view(-1, self.num_chans, *self.patch_size)
662 | out = torch.sigmoid(out)
663 | return out
664 |
665 |
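# Usage sketch (editor's addition): decoding per-particle latents into RGBA glimpses;
# patch_size should be a power of two >= 8 for the upsampling arithmetic below to work out.
#   dec = ObjectDecoderCNN(patch_size=16, num_chans=4, bottleneck_size=128)
#   z = torch.randn(2, 10, 128)                 # [batch, n_kp, bottleneck_size]
#   rgba = dec(z)                               # [20, 4, 16, 16], values in [0, 1]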
666 | # torch geometric modules to process graphs/KP
667 | class PointNetPPLayer(MessagePassing):
668 | def __init__(self, in_channels, out_channels, axis_dim=2):
669 | # Message passing with "max" aggregation.
670 | super(PointNetPPLayer, self).__init__('max')
671 | self.axis_dim = axis_dim
672 | # Initialization of the MLP:
673 |         # Here, the number of input features corresponds to the hidden node
674 |         # dimensionality plus the point dimensionality (axis_dim).
675 | self.mlp = nn.Sequential(nn.Linear(in_channels + axis_dim, out_channels),
676 | nn.BatchNorm1d(out_channels),
677 | nn.ReLU(),
678 | nn.Linear(out_channels, out_channels),
679 | nn.BatchNorm1d(out_channels))
680 |
681 | def forward(self, h, pos, edge_index):
682 | # Start propagating messages.
683 | return self.propagate(edge_index, h=h, pos=pos)
684 |
685 | def message(self, h_j, pos_j, pos_i):
686 | # h_j defines the features of neighboring nodes as shape [num_edges, in_channels]
687 |         # pos_j defines the position of neighboring nodes as shape [num_edges, axis_dim]
688 |         # pos_i defines the position of central nodes as shape [num_edges, axis_dim]
689 |
690 | input_pos = pos_j - pos_i # Compute spatial relation.
691 |
692 | if h_j is not None:
693 | # In the first layer, we may not have any hidden node features,
694 | # so we only combine them in case they are present.
695 | input_pos = torch.cat([h_j, input_pos], dim=-1)
696 |
697 | return self.mlp(input_pos) # Apply our final MLP.
698 |
699 |
700 | class PointNetPPToCNN(nn.Module):
701 | def __init__(self, axis_dim=2, target_hw=16, n_kp=8, pad_mode='replicate', with_fps=False, features_dim=2):
702 | super(PointNetPPToCNN, self).__init__()
703 | # features_dim : 2 [logvar] + additional features
704 | self.with_fps = with_fps
705 | self.axis_dim = axis_dim # mu
706 | self.features_dim = features_dim # logvar, features
707 | self.conv1 = PointNetPPLayer(self.axis_dim + self.features_dim, 64, axis_dim=axis_dim)
708 | self.conv2 = PointNetPPLayer(64, 128, axis_dim=axis_dim)
709 | self.conv3 = PointNetPPLayer(128, 256, axis_dim=axis_dim)
710 | self.conv4 = PointNetPPLayer(256, 512, axis_dim=axis_dim)
711 |
712 | self.n_kp = n_kp
713 |
714 | fc_out_dim = self.n_kp * 8 * 8
715 |
716 | self.fc = nn.Sequential(nn.Linear(512, fc_out_dim, bias=True),
717 | nn.ReLU(True))
718 |
719 | # num_upsample = int(np.log(target_hw) // np.log(2)) - 3
720 |         num_upsample = int(np.log2(target_hw)) - 3  # thanks @MoritzLange
721 | # print(f'pointnet to cnn num upsample: {num_upsample}')
722 | self.cnn = nn.Sequential()
723 | for i in range(num_upsample):
724 | self.cnn.add_module(f'depth_up_{i}', ConvBlock(n_kp, n_kp, kernel_size=3, pad=1,
725 | upsample=True, pad_mode=pad_mode))
726 |
727 | def forward(self, position, features):
728 | # position [batch_size, n_kp, 2]
729 | # features [batch_size, n_kp, features_dim]
730 | pos = position
731 | batch = torch.arange(pos.shape[0]).view(-1, 1).repeat(1, pos.shape[1]).view(-1).to(pos.device)
732 | pos = pos.view(-1, pos.shape[-1]) # [batch_size * n_kp, 2]
733 | features = features.view(-1, features.shape[-1]) # [batch_size * n_kp, features]
734 | # x [batch_size, n_kp, 2 or features_dim]
735 | # Compute the kNN graph:
736 | # Here, we need to pass the batch vector to the function call in order
737 | # to prevent creating edges between points of different examples.
738 | # We also add `loop=True` which will add self-loops to the graph in
739 | # order to preserve central point information.
740 | edge_index = knn_graph(pos, k=10, batch=batch, loop=True)
741 |
742 | # 3. Start bipartite message passing.
743 | h = self.conv1(h=features, pos=pos, edge_index=edge_index)
744 | h = h.relu()
745 | # print(f'conv1 h: {h.shape}')
746 | if self.with_fps:
747 | index = fps(pos, batch=batch, ratio=0.5)
748 | pos = pos[index]
749 | h = h[index]
750 | batch = batch[index]
751 | edge_index = knn_graph(pos, k=5, batch=batch, loop=True)
752 | h = self.conv2(h=h, pos=pos, edge_index=edge_index)
753 | h = h.relu()
754 | # print(f'conv2 h: {h.shape}')
755 | if self.with_fps:
756 | index = fps(pos, batch=batch, ratio=0.5)
757 | pos = pos[index]
758 | h = h[index]
759 | batch = batch[index]
760 | edge_index = knn_graph(pos, k=3, batch=batch, loop=True)
761 | h = self.conv3(h=h, pos=pos, edge_index=edge_index)
762 | h = h.relu()
763 | h = self.conv4(h=h, pos=pos, edge_index=edge_index)
764 | h = h.relu()
765 | # 4. Global Pooling.
766 | h = global_max_pool(h, batch) # [num_examples, hidden_channels]
767 | # 5. FC
768 | h = self.fc(h)
769 |         h = h.view(-1, self.n_kp, 8, 8)  # [batch_size, n_kp, 8, 8]
770 | cnn_out = self.cnn(h) # [batch_size, n_kp, target_hw, target_hw]
771 | return cnn_out
772 |
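# Usage sketch (editor's addition): mapping a particle set back to a spatial feature map;
# requires torch-geometric (see the requirements files) and assumes the target_hw fix above.
#   pnet = PointNetPPToCNN(axis_dim=2, target_hw=16, n_kp=8, features_dim=4)
#   pos = torch.rand(2, 8, 2)                   # particle positions
#   feat = torch.randn(2, 8, 4)                 # per-particle features (e.g., logvar + extras)
#   fmap = pnet(pos, feat)                      # [2, 8, 16, 16]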
--------------------------------------------------------------------------------
/requirements17.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.3.0
2 | h5py==2.10.0
3 | imageio==2.6.1
4 | matplotlib==3.4.2
5 | numpy==1.19.5
6 | pandas==1.3.1
7 | Pillow==8.3.1
8 | pip==21.1.3
9 | scikit-image==0.18.1
10 | scikit-learn==0.24.2
11 | scipy==1.7.0
12 | torch==1.7.1
13 | torch-cluster==1.5.9
14 | torch-geometric==1.7.2
15 | torch-scatter==2.0.7
16 | torch-sparse==0.6.9
17 | torch-spline-conv==1.2.1
18 | torchvision==0.8.2
19 | tqdm==4.62.3
20 |
--------------------------------------------------------------------------------
/requirements19.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.5.1
2 | h5py==2.10.0
3 | imageio==2.6.1
4 | matplotlib==3.4.2
5 | matplotlib-inline==0.1.3
6 | numpy==1.20.3
7 | opencv-python==3.4.2.17
8 | pandas==1.3.0
9 | Pillow==8.3.1
10 | pip==21.0.1
11 | scikit-image==0.18.1
12 | scikit-learn==0.24.2
13 | scipy==1.6.3
14 | torch==1.9.0
15 | torch-cluster==1.5.9
16 | torch-geometric==2.0.0
17 | torch-scatter==2.0.8
18 | torch-sparse==0.6.12
19 | torch-spline-conv==1.2.1
20 | torchfile==0.1.0
21 | torchvision==0.10.0
22 | tqdm==4.62.2
23 | yacs==0.1.6
24 |
--------------------------------------------------------------------------------
/utils/loss_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss functions implementations used in the optimization of DLP.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision import models
8 |
9 |
10 | # functions
11 | def batch_pairwise_kl(mu_x, logvar_x, mu_y, logvar_y, reverse_kl=False):
12 | """
13 | Calculate batch-wise KL-divergence
14 | mu_x, logvar_x: [batch_size, n_x, points_dim]
15 | mu_y, logvar_y: [batch_size, n_y, points_dim]
16 | kl = -0.5 * Σ_points_dim (1 + logvar_x - logvar_y - exp(logvar_x)/exp(logvar_y)
17 | - ((mu_x - mu_y) ** 2)/exp(logvar_y))
18 | """
19 | if reverse_kl:
20 | mu_a, logvar_a = mu_y, logvar_y
21 | mu_b, logvar_b = mu_x, logvar_x
22 | else:
23 | mu_a, logvar_a = mu_x, logvar_x
24 | mu_b, logvar_b = mu_y, logvar_y
25 | bs, n_a, points_dim = mu_a.size()
26 | _, n_b, _ = mu_b.size()
27 | logvar_aa = logvar_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim]
28 | logvar_bb = logvar_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim]
29 | mu_aa = mu_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim]
30 | mu_bb = mu_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim]
31 | p_kl = -0.5 * (1 + logvar_aa - logvar_bb - logvar_aa.exp() / logvar_bb.exp()
32 | - ((mu_aa - mu_bb) ** 2) / logvar_bb.exp()).sum(-1) # [batch_size, n_x, n_y]
33 | return p_kl
34 |
35 |
36 | def calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='sum'):
37 | """
38 |
39 | :param x: original inputs
40 | :param recon_x: reconstruction of the VAE's input
41 | :param loss_type: "mse", "l1", "bce"
42 | :param reduction: "sum", "mean", "none"
43 | :return: recon_loss
44 | """
45 | if reduction not in ['sum', 'mean', 'none']:
46 | raise NotImplementedError
47 | recon_x = recon_x.view(recon_x.size(0), -1)
48 | x = x.view(x.size(0), -1)
49 | if loss_type == 'mse':
50 | recon_error = F.mse_loss(recon_x, x, reduction='none')
51 | recon_error = recon_error.sum(1)
52 | if reduction == 'sum':
53 | recon_error = recon_error.sum()
54 | elif reduction == 'mean':
55 | recon_error = recon_error.mean()
56 | elif loss_type == 'l1':
57 | recon_error = F.l1_loss(recon_x, x, reduction=reduction)
58 | elif loss_type == 'bce':
59 | recon_error = F.binary_cross_entropy(recon_x, x, reduction=reduction)
60 | else:
61 | raise NotImplementedError
62 | return recon_error
63 |
64 |
65 | def calc_kl(logvar, mu, mu_o=0.0, logvar_o=0.0, reduce='sum'):
66 | """
67 | Calculate kl-divergence
68 | :param logvar: log-variance from the encoder
69 | :param mu: mean from the encoder
70 | :param mu_o: negative mean for outliers (hyper-parameter)
71 | :param logvar_o: negative log-variance for outliers (hyper-parameter)
72 | :param reduce: type of reduce: 'sum', 'none'
73 | :return: kld
74 | """
75 | if not isinstance(mu_o, torch.Tensor):
76 | mu_o = torch.tensor(mu_o).to(mu.device)
77 | if not isinstance(logvar_o, torch.Tensor):
78 | logvar_o = torch.tensor(logvar_o).to(mu.device)
79 | kl = -0.5 * (1 + logvar - logvar_o - logvar.exp() / torch.exp(logvar_o) - (mu - mu_o).pow(2) / torch.exp(
80 | logvar_o)).sum(1)
81 | if reduce == 'sum':
82 | kl = torch.sum(kl)
83 | elif reduce == 'mean':
84 | kl = torch.mean(kl)
85 | return kl
86 |
87 |
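# Usage sketch (editor's addition): KL between a diagonal Gaussian posterior and a standard
# Normal prior (mu_o=0, logvar_o=0), summed over the batch.
#   mu, logvar = torch.randn(4, 16), torch.zeros(4, 16)
#   kld = calc_kl(logvar, mu, reduce='sum')     # scalar, equals 0.5 * sum(mu ** 2) here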
88 | # classes
89 | class ChamferLossKL(nn.Module):
90 | """
91 | Calculates the KL-divergence between two sets of (R.V.) particle coordinates.
92 | """
93 | def __init__(self, use_reverse_kl=False):
94 | super(ChamferLossKL, self).__init__()
95 | self.use_reverse_kl = use_reverse_kl
96 |
97 | def forward(self, mu_preds, logvar_preds, mu_gts, logvar_gts):
98 | p_kl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=False)
99 | if self.use_reverse_kl:
100 | p_rkl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=True)
101 | p_kl = 0.5 * (p_kl + p_rkl.transpose(2, 1))
102 | mins, _ = torch.min(p_kl, 1)
103 | loss_1 = torch.sum(mins, 1)
104 | mins, _ = torch.min(p_kl, 2)
105 | loss_2 = torch.sum(mins, 1)
106 | return loss_1 + loss_2
107 |
108 |
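# Usage sketch (editor's addition): Chamfer-KL between two particle sets of different sizes,
# e.g. posterior keypoints vs. patch-based prior proposals; shapes are illustrative.
#   chamfer_kl = ChamferLossKL()
#   mu_post, logvar_post = torch.rand(2, 10, 2), torch.zeros(2, 10, 2)
#   mu_prior, logvar_prior = torch.rand(2, 64, 2), torch.zeros(2, 64, 2)
#   loss = chamfer_kl(mu_post, logvar_post, mu_prior, logvar_prior)   # [2], per-sample loss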
109 | class NetVGGFeatures(nn.Module):
110 |
111 | def __init__(self, layer_ids):
112 | super().__init__()
113 |
114 | self.vggnet = models.vgg16(pretrained=True)
115 | self.layer_ids = layer_ids
116 |
117 | def forward(self, x):
118 | output = []
119 | for i in range(self.layer_ids[-1] + 1):
120 | x = self.vggnet.features[i](x)
121 |
122 | if i in self.layer_ids:
123 | output.append(x)
124 |
125 | return output
126 |
127 |
128 | class VGGDistance(nn.Module):
129 |
130 | def __init__(self, layer_ids=(2, 7, 12, 21, 30), accumulate_mode='sum', device=torch.device("cpu")):
131 | super().__init__()
132 |
133 | self.vgg = NetVGGFeatures(layer_ids).to(device)
134 | self.layer_ids = layer_ids
135 | self.accumulate_mode = accumulate_mode
136 | self.device = device
137 |
138 | def forward(self, I1, I2, reduction='sum', only_image=False):
139 | b_sz = I1.size(0)
140 | num_ch = I1.size(1)
141 |
142 | if self.accumulate_mode == 'sum':
143 | loss = ((I1 - I2) ** 2).view(b_sz, -1).sum(1)
144 | else:
145 | loss = ((I1 - I2) ** 2).view(b_sz, -1).mean(1)
146 |
147 | if num_ch == 1:
148 | I1 = I1.repeat(1, 3, 1, 1)
149 | I2 = I2.repeat(1, 3, 1, 1)
150 | f1 = self.vgg(I1)
151 | f2 = self.vgg(I2)
152 |
153 | if not only_image:
154 | for i in range(len(self.layer_ids)):
155 | if self.accumulate_mode == 'sum':
156 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, -1).sum(1)
157 | else:
158 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, -1).mean(1)
159 | loss = loss + layer_loss
160 |
161 | if reduction == 'mean':
162 | return loss.mean()
163 | elif reduction == 'sum':
164 | return loss.sum()
165 | else:
166 | return loss
167 |
168 | def get_dimensions(self, device=torch.device("cpu")):
169 | dims = []
170 | dummy_input = torch.zeros(1, 3, 128, 128).to(device)
171 | dims.append(dummy_input.view(1, -1).size(1))
172 | f = self.vgg(dummy_input)
173 | for i in range(len(self.layer_ids)):
174 | dims.append(f[i].view(1, -1).size(1))
175 | return dims
176 |
177 |
178 | class ChamferLoss(nn.Module):
179 |
180 | def __init__(self):
181 | super(ChamferLoss, self).__init__()
182 | self.use_cuda = torch.cuda.is_available()
183 |
184 | def forward(self, preds, gts):
185 | P = self.batch_pairwise_dist(gts, preds)
186 | mins, _ = torch.min(P, 1)
187 | loss_1 = torch.sum(mins, 1)
188 | mins, _ = torch.min(P, 2)
189 | loss_2 = torch.sum(mins, 1)
190 | return loss_1 + loss_2
191 |
192 | def batch_pairwise_dist(self, x, y):
193 | bs, num_points_x, points_dim = x.size()
194 | _, num_points_y, _ = y.size()
195 | xx = torch.bmm(x, x.transpose(2, 1))
196 | yy = torch.bmm(y, y.transpose(2, 1))
197 | zz = torch.bmm(x, y.transpose(2, 1))
198 | if self.use_cuda:
199 | dtype = torch.cuda.LongTensor
200 | else:
201 | dtype = torch.LongTensor
202 | diag_ind_x = torch.arange(0, num_points_x).type(dtype)
203 | diag_ind_y = torch.arange(0, num_points_y).type(dtype)
204 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(
205 | zz.transpose(2, 1))
206 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz)
207 | P = rx.transpose(2, 1) + ry - 2 * zz
208 | return P
209 |
--------------------------------------------------------------------------------
/utils/tps.py:
--------------------------------------------------------------------------------
1 | """
2 | Credits:
3 | https://github.com/jamt9000/DVE
4 | """
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import math
9 |
10 |
11 | def tps_grid(H, W):
12 | xi = torch.linspace(-1, 1, W)
13 | yi = torch.linspace(-1, 1, H)
14 |
15 | yy, xx = torch.meshgrid(yi, xi)
16 | grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), 1)
17 | return grid
18 |
19 |
20 | def spatial_grid_unnormalized(H, W):
21 | xi = torch.linspace(0, W - 1, W)
22 | yi = torch.linspace(0, H - 1, H)
23 |
24 | yy, xx = torch.meshgrid(yi, xi)
25 | grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), 1)
26 | return grid.reshape(H, W, 2)
27 |
28 |
29 | def tps_U(grid1, grid2):
30 | D = grid1.reshape(-1, 1, 2) - grid2.reshape(1, -1, 2)
31 | D = torch.sum(D ** 2., 2)
32 | U = D * torch.log(D + 1e-5)
33 | return U
34 |
35 |
36 | def grid_unnormalize(grid, H, W):
37 | x = grid.reshape(-1, H, W, 2)
38 | constants = torch.tensor([W - 1., H - 1.], dtype=x.dtype)
39 | constants = constants.reshape(1, 1, 1, 2).to(x.device)
40 | x = (x + 1.) / 2. * constants
41 | return x.reshape(grid.shape)
42 |
43 |
44 | def grid_normalize(grid, H, W):
45 | x = grid.reshape(-1, H, W, 2)
46 | x = 2. * x / torch.Tensor([W - 1., H - 1.]).reshape(1, 1, 1, 2).to(x.device) - 1
47 | return x.reshape(grid.shape)
48 |
49 |
50 | def random_tps_weights(nctrlpts, warpsd_all, warpsd_subset, transsd, scalesd, rotsd):
51 | W = torch.randn(nctrlpts, 2) * warpsd_all
52 | subset = torch.rand(W.shape) > 0.5
53 | W[subset] = torch.randn(subset.sum()) * warpsd_subset
54 | rot = torch.randn([]) * rotsd * math.pi / 180
55 | sc = 1. + torch.randn([]) * scalesd
56 | tx = torch.randn([]) * transsd
57 | ty = torch.randn([]) * transsd
58 |
59 | aff = torch.Tensor([[tx, ty],
60 | [sc * torch.cos(rot), sc * -torch.sin(rot)],
61 | [sc * torch.sin(rot), sc * torch.cos(rot)]])
62 |
63 | Wa = torch.cat((W, aff), 0)
64 | return Wa
65 |
66 |
67 | class Warper(object):
68 | returns_pairs = True
69 |
70 | def __init__(self, H, W, warpsd_all=0.001, warpsd_subset=0.01, transsd=0.1,
71 | scalesd=0.1, rotsd=5, im1_multiplier=0.5, im1_multiplier_aff=1.):
72 | self.H = H
73 | self.W = W
74 | self.warpsd_all = warpsd_all
75 | self.warpsd_subset = warpsd_subset
76 | self.transsd = transsd
77 | self.scalesd = scalesd
78 | self.rotsd = rotsd
79 | self.im1_multiplier = im1_multiplier
80 | self.im1_multiplier_aff = im1_multiplier_aff
81 |
82 | self.npixels = H * W
83 | self.nc = 10
84 | self.nctrlpts = self.nc * self.nc
85 |
86 | self.grid_pixels = tps_grid(H, W)
87 | self.grid_pixels_unnormalized = grid_unnormalize(self.grid_pixels.reshape(1, H, W, 2), self.H, self.W)
88 | self.grid_ctrlpts = tps_grid(self.nc, self.nc)
89 | self.U_ctrlpts = tps_U(self.grid_ctrlpts, self.grid_ctrlpts)
90 | self.U_pixels_ctrlpts = tps_U(self.grid_pixels, self.grid_ctrlpts)
91 | self.F = torch.cat((self.U_pixels_ctrlpts, torch.ones(self.npixels, 1), self.grid_pixels), 1)
92 |
93 | def __call__(self, im1, im2=None, keypts=None, crop=0):
94 | Hc = self.H - crop - crop
95 | Wc = self.W - crop - crop
96 |
97 | # im2 should be a copy of im1 with different colour jitter
98 | if im2 is None:
99 | im2 = im1
100 |
101 | kp1 = kp2 = 0
102 |
103 | unsqueezed = False
104 | if len(im1.shape) == 3:
105 | im1 = im1.unsqueeze(0)
106 | im2 = im2.unsqueeze(0)
107 | unsqueezed = True
108 |
109 | assert im1.shape[0] == 1 and im2.shape[0] == 1
110 |
111 | a = self.im1_multiplier
112 | b = self.im1_multiplier_aff
113 | weights1 = random_tps_weights(self.nctrlpts, a * self.warpsd_all, a * self.warpsd_subset, b * self.transsd,
114 | b * self.scalesd, b * self.rotsd)
115 |
116 | grid1 = torch.matmul(self.F, weights1).reshape(1, self.H, self.W, 2)
117 | grid1_unnormalized = grid_unnormalize(grid1, self.H, self.W)
118 | if keypts is not None:
119 | kp1 = self.warp_keypoints(keypts, grid1_unnormalized)
120 |
121 | im1 = F.grid_sample(im1, grid1, align_corners=True)
122 | im2 = F.grid_sample(im2, grid1, align_corners=True)
123 |
124 | weights2 = random_tps_weights(self.nctrlpts, self.warpsd_all, self.warpsd_subset, self.transsd,
125 | self.scalesd, self.rotsd)
126 | grid2 = torch.matmul(self.F, weights2).reshape(1, self.H, self.W, 2)
127 | im2 = F.grid_sample(im2, grid2, align_corners=True)
128 |
129 | if crop != 0:
130 | im1 = im1[:, :, crop:-crop, crop:-crop]
131 | im2 = im2[:, :, crop:-crop, crop:-crop]
132 |
133 | if unsqueezed:
134 | im1 = im1.squeeze(0)
135 | im2 = im2.squeeze(0)
136 |
137 | grid = grid2
138 | grid_unnormalized = grid_unnormalize(grid, self.H, self.W)
139 |
140 | if keypts is not None:
141 | kp2 = self.warp_keypoints(kp1, grid_unnormalized)
142 |
143 | flow = grid_unnormalized - self.grid_pixels_unnormalized
144 |
145 | if crop != 0:
146 | flow = flow[:, crop:-crop, crop:-crop, :]
147 |
148 | grid_cropped = grid_unnormalized[:, crop:-crop, crop:-crop, :] - crop
149 | grid = grid_normalize(grid_cropped, Hc, Wc)
150 |
151 | # hc = flow.shape[1]
152 | # wc = flow.shape[2]
153 | # gridc = flow + grid_unnormalize(tps_grid(hc, wc).reshape(1, hc, wc, 2), hc, wc)
154 | # grid = grid_normalize(gridc, hc, wc)
155 |
156 | if keypts is not None:
157 | kp1 -= crop
158 | kp2 -= crop
159 |
160 |         # Reverse the order: due to inverse warping, the "flow" is in the direction im2->im1,
161 |         # and we want to be consistent with optical flow from videos
162 | return im2, im1, flow, grid, kp2, kp1
163 |
164 | def warp_keypoints(self, keypoints, grid_unnormalized):
165 |         from scipy.spatial import KDTree
166 | warp_grid = grid_unnormalized.reshape(-1, 2)
167 | regular_grid = self.grid_pixels_unnormalized.reshape(-1, 2)
168 | kd = KDTree(warp_grid)
169 | dists, idxs = kd.query(keypoints)
170 | new_keypoints = regular_grid[idxs]
171 | return new_keypoints
172 |
173 |
174 | class WarperSingle(object):
175 | returns_pairs = False
176 |
177 | def __init__(self, H, W, warpsd_all=0.0005, warpsd_subset=0.0, transsd=0.02,
178 | scalesd=0.02, rotsd=2):
179 | self.H = H
180 | self.W = W
181 | self.warpsd_all = warpsd_all
182 | self.warpsd_subset = warpsd_subset
183 | self.transsd = transsd
184 | self.scalesd = scalesd
185 | self.rotsd = rotsd
186 |
187 | self.npixels = H * W
188 | self.nc = 10
189 | self.nctrlpts = self.nc * self.nc
190 |
191 | self.grid_pixels = tps_grid(H, W)
192 | self.grid_pixels_unnormalized = grid_unnormalize(self.grid_pixels.reshape(1, H, W, 2), self.H, self.W)
193 | self.grid_ctrlpts = tps_grid(self.nc, self.nc)
194 | self.U_ctrlpts = tps_U(self.grid_ctrlpts, self.grid_ctrlpts)
195 | self.U_pixels_ctrlpts = tps_U(self.grid_pixels, self.grid_ctrlpts)
196 | self.F = torch.cat((self.U_pixels_ctrlpts, torch.ones(self.npixels, 1), self.grid_pixels), 1)
197 |
198 | def __call__(self, im1, keypts=None, crop=0):
199 | kp1 = 0
200 |
201 | unsqueezed = False
202 | if len(im1.shape) == 3:
203 | im1 = im1.unsqueeze(0)
204 | unsqueezed = True
205 |
206 | assert im1.shape[0] == 1
207 |
208 | a = 1
209 | weights1 = random_tps_weights(self.nctrlpts, a * self.warpsd_all, a * self.warpsd_subset, a * self.transsd,
210 | a * self.scalesd, a * self.rotsd)
211 |
212 | grid1 = torch.matmul(self.F, weights1).reshape(1, self.H, self.W, 2)
213 | grid1_unnormalized = grid_unnormalize(grid1, self.H, self.W)
214 | if keypts is not None:
215 | kp1 = self.warp_keypoints(keypts, grid1_unnormalized)
216 |
217 | im1 = F.grid_sample(im1, grid1, align_corners=True)
218 |
219 | if crop != 0:
220 | im1 = im1[:, :, crop:-crop, crop:-crop]
221 |
222 | if unsqueezed:
223 | im1 = im1.squeeze(0)
224 |
225 | if crop != 0 and keypts is not None:
226 | kp1 -= crop
227 |
228 |         # single-image version: return the warped image and (optionally) the warped keypoints
230 | return im1, kp1
231 |
232 | def warp_keypoints(self, keypoints, grid_unnormalized):
233 |         from scipy.spatial import KDTree
234 | warp_grid = grid_unnormalized.reshape(-1, 2)
235 | regular_grid = self.grid_pixels_unnormalized.reshape(-1, 2)
236 | kd = KDTree(warp_grid)
237 | dists, idxs = kd.query(keypoints)
238 | new_keypoints = regular_grid[idxs]
239 | return new_keypoints
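# Usage sketch (editor's addition): a single random TPS deformation of an image, e.g. for
# mild data augmentation; keypoint warping additionally requires scipy (KDTree).
#   warper = WarperSingle(H=64, W=64)
#   im = torch.rand(3, 64, 64)
#   im_warped, _ = warper(im)                   # [3, 64, 64]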
--------------------------------------------------------------------------------
/utils/util_func.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility logging, plotting and animation functions.
3 | """
4 |
5 | # imports
6 | import numpy as np
7 | # import matplotlib.pyplot as plt
8 | import cv2
9 | # from PIL import Image
10 | import datetime
11 | import os
12 | import json
13 | # import imageio
14 | # torch
15 | import torch
16 | import torch.nn.functional as F
17 | import torchvision.transforms as transforms
18 | import torchvision.ops as ops
19 |
20 |
21 | def color_map(num=100):
22 | """
23 | Color maps for the keypoints
24 | """
25 | colormap = ["FF355E",
26 | "8ffe09",
27 | "1d5dec",
28 | "FF9933",
29 | "FFFF66",
30 | "CCFF00",
31 | "AAF0D1",
32 | "FF6EFF",
33 | "FF00CC",
34 | "299617",
35 | "AF6E4D"] * num
36 | s = ''
37 | for color in colormap:
38 | s += color
39 | b = bytes.fromhex(s)
40 | cm = np.frombuffer(b, np.uint8)
41 | cm = cm.reshape(len(colormap), 3)
42 | return cm
43 |
44 |
45 | def plot_keypoints_on_image(k, image_tensor, radius=1, thickness=1, kp_range=(0, 1)):
46 | # https://github.com/DuaneNielsen/keypoints
47 | height, width = image_tensor.size(1), image_tensor.size(2)
48 | num_keypoints = k.size(0)
49 |
50 | if len(k.shape) != 2:
51 | raise Exception('Individual images and keypoints, not batches')
52 |
53 | k = k.clone()
54 | k[:, 0] = ((k[:, 0] - kp_range[0]) / (kp_range[1] - kp_range[0])) * (height - 1)
55 | k[:, 1] = ((k[:, 1] - kp_range[0]) / (kp_range[1] - kp_range[0])) * (width - 1)
56 | # k.floor_()
57 | k.round_()
58 | k = k.detach().cpu().numpy()
59 | # print(k)
60 |
61 | img = transforms.ToPILImage()(image_tensor.cpu())
62 |
63 | img = np.array(img)
64 | cmap = color_map()
65 | cm = cmap[:num_keypoints].astype(int)
66 | count = 0
67 | for co_ord, color in zip(k, cm):
68 | c = color.item(0), color.item(1), color.item(2)
69 | co_ord = co_ord.squeeze()
70 | cv2.circle(img, (int(co_ord[1]), int(co_ord[0])), radius, c, thickness)
71 | count += 1
72 |
73 | return img
74 |
75 |
76 | def plot_keypoints_on_image_batch(kp_batch_tensor, img_batch_tensor, radius=1, thickness=1, max_imgs=8,
77 | kp_range=(-1, 1)):
78 | num_plot = min(max_imgs, img_batch_tensor.shape[0])
79 | img_with_kp = []
80 | for i in range(num_plot):
81 | img_np = plot_keypoints_on_image(kp_batch_tensor[i], img_batch_tensor[i], radius=radius, thickness=thickness,
82 | kp_range=kp_range)
83 | img_tensor = torch.tensor(img_np).float() / 255.0
84 | img_with_kp.append(img_tensor.permute(2, 0, 1))
85 | img_with_kp = torch.stack(img_with_kp, dim=0)
86 | return img_with_kp
87 |
88 |
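# Usage sketch (editor's addition): overlaying keypoints on a batch of images for logging
# (e.g., with TensorBoard's add_images); requires OpenCV, see the requirements files.
#   imgs = torch.rand(8, 3, 64, 64)
#   kps = torch.rand(8, 10, 2) * 2 - 1          # keypoints in kp_range (-1, 1)
#   vis = plot_keypoints_on_image_batch(kps, imgs, radius=2, kp_range=(-1, 1))  # [8, 3, 64, 64]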
89 | def get_kp_mask_from_gmap(gmaps, threshold=0.2, binary=True, elementwise=False):
90 | """
91 | Transforms the Gaussian-maps created from the KP to (binary) masks.
92 | gmaps: [B, K, H, W]
93 | threshold: above it pixels are one and below zero
94 | """
95 | if elementwise:
96 | mask = gmaps
97 | else:
98 | mask = gmaps.sum(1, keepdim=True)
99 | if binary:
100 | mask = torch.where(mask > threshold, 1.0, 0.0)
101 | else:
102 | mask = mask.clamp(0, 1)
103 | return mask
104 |
105 |
106 | def reparameterize(mu, logvar):
107 | """
108 | This function applies the reparameterization trick:
109 | z = mu(X) + sigma(X)^0.5 * epsilon, where epsilon ~ N(0,I)
110 | :param mu: mean of x
111 |     :param logvar: log variance of x
112 | :return z: the sampled latent variable
113 | """
114 | device = mu.device
115 | std = torch.exp(0.5 * logvar)
116 | eps = torch.randn_like(mu).to(device)
117 | return mu + eps * std
118 |
119 |
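# Usage sketch (editor's addition): drawing one sample per row from N(mu, exp(logvar)).
#   mu, logvar = torch.zeros(4, 16), torch.zeros(4, 16)
#   z = reparameterize(mu, logvar)              # [4, 16], z ~ N(0, I) in this example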
120 | def create_masks_fast(center, anchor_s, feature_dim=16, patch_size=None):
121 | """
122 |     Creates binary masks where only a box of size round(anchor_s * (feature_dim - 1)) centered around `center`
123 |     is 1, rest is 0.
124 |     center: [batch_size, n_kp, 2] in kp_range
125 |     anchor_s: size of the anchor in [0, 1]
126 | """
127 | batch_size, n_kp = center.shape[0], center.shape[1]
128 | if patch_size is None:
129 | patch_size = np.round(anchor_s * (feature_dim - 1)).astype(int)
130 | # create white rectangles
131 | masks = torch.ones(batch_size * n_kp, 1, patch_size, patch_size, device=center.device).float()
132 | # pad the masks to image size
133 | pad_size = (feature_dim - patch_size) // 2
134 | padded_patches_batch = F.pad(masks, pad=[pad_size] * 4)
135 | # move the masks to be centered around the kp
136 | delta_t_batch = 0.0 - center
137 | delta_t_batch = delta_t_batch.reshape(-1, delta_t_batch.shape[-1]) # [bs * n_kp, 2]
138 | zeros = torch.zeros([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
139 | ones = torch.ones([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
140 | theta = torch.cat([ones, zeros, delta_t_batch[:, 1].unsqueeze(-1),
141 | zeros, ones, delta_t_batch[:, 0].unsqueeze(-1)], dim=-1)
142 | theta = theta.view(-1, 2, 3) # [batch_size * n_kp, 2, 3]
143 | align_corners = False
144 | padding_mode = 'zeros'
145 | mode = "nearest"
146 | # mode = 'bilinear' # makes it differentiable, but we don't care about it here
147 | grid = F.affine_grid(theta, padded_patches_batch.size(), align_corners=align_corners)
148 | trans_padded_patches_batch = F.grid_sample(padded_patches_batch, grid, align_corners=align_corners,
149 | mode=mode, padding_mode=padding_mode)
150 | trans_padded_patches_batch = trans_padded_patches_batch.view(batch_size, n_kp, *padded_patches_batch.shape[1:])
151 | # [bs, n_kp, 1, feature_dim, feature_dim]
152 | return trans_padded_patches_batch
153 |
154 |
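# Usage sketch (editor's addition): building square glimpse masks around particle centers;
# with anchor_s=0.25 and feature_dim=64 the box is round(0.25 * 63) = 16 pixels wide.
#   kp = torch.zeros(2, 10, 2)                              # centers in kp_range (-1, 1)
#   masks = create_masks_fast(kp, anchor_s=0.25, feature_dim=64)   # [2, 10, 1, 64, 64]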
155 | def get_bb_from_masks(masks, width, height):
156 | # extracts bounding boxes (bb) from masks.
157 | # masks: [n_masks, 1, feature_dim, feature_dim]
158 | n_masks = masks.shape[0]
159 | mask_h, mask_w = masks.shape[2], masks.shape[3]
160 | coor = torch.zeros(size=(n_masks, 4), dtype=torch.int, device=masks.device)
161 | for i in range(n_masks):
162 | mask = masks[i].int().squeeze() # [feature_dim, feature_dim]
163 | indices = (mask == 1).nonzero(as_tuple=False)
164 | if indices.shape[0] > 0:
165 | ws = (indices[0][1] * (width / mask_w)).clamp(0, width).int()
166 | wt = (indices[-1][1] * (width / mask_w)).clamp(0, width).int()
167 | hs = (indices[0][0] * (height / mask_h)).clamp(0, height).int()
168 | ht = (indices[-1][0] * (height / mask_h)).clamp(0, height).int()
169 | coor[i, 0] = ws
170 | coor[i, 1] = hs
171 | coor[i, 2] = wt
172 | coor[i, 3] = ht
173 | return coor
174 |
175 |
176 | def get_bb_from_masks_batch(masks, width, height):
177 | # extracts bounding boxes (bb) from a batch of masks.
178 | # masks: [batch_size, n_masks, 1, feature_dim, feature_dim]
179 | coor = torch.zeros(size=(masks.shape[0], masks.shape[1], 4), dtype=torch.int, device=masks.device)
180 | for i in range(masks.shape[0]):
181 | coor[i, :, :] = get_bb_from_masks(masks[i], width, height)
182 | return coor
183 |
184 |
185 | def nms_single(boxes, scores, iou_thresh=0.5, return_scores=False, remove_ind=None):
186 | # non-maximal suppression on bb and scores from one image.
187 | # boxes: [n_bb, 4], scores: [n_boxes]
188 | nms_indices = ops.nms(boxes.float(), scores, iou_thresh)
189 | # remove low scoring indices from nms output
190 | if remove_ind is not None:
191 | # final_indices = [ind for ind in nms_indices if ind not in remove_ind]
192 | final_indices = list(set(nms_indices.data.cpu().numpy()) - set(remove_ind))
193 | # print(f'removed indices: {remove_ind}')
194 | else:
195 | final_indices = nms_indices
196 | nms_boxes = boxes[final_indices] # [n_bb_nms, 4]
197 | if return_scores:
198 | return nms_boxes, final_indices, scores[final_indices]
199 | else:
200 | return nms_boxes, final_indices
201 |
202 |
203 | def remove_low_score_bb_single(boxes, scores, return_scores=False, mode='mean', thresh=0.4, hard_thresh=None):
204 | # filters out low-scoring bounding boxes. The score is usually the variance of the particle.
205 | # boxes: [n_bb, 4], scores: [n_boxes]
206 | if hard_thresh is None:
207 | if mode == 'mean':
208 | mean_score = scores.mean()
209 | # indices = (scores > mean_score)
210 | indices = torch.nonzero(scores > thresh, as_tuple=True)[0].data.cpu().numpy()
211 | else:
212 |             normalized_scores = (scores - scores.min()) / (scores.max() - scores.min())
213 |             # indices = (normalized_scores > thresh)
214 |             indices = torch.nonzero(normalized_scores > thresh, as_tuple=True)[0].data.cpu().numpy()
215 | else:
216 | # indices = (scores > hard_thresh)
217 | indices = torch.nonzero(scores > hard_thresh, as_tuple=True)[0].data.cpu().numpy()
218 | boxes_t = boxes[indices]
219 | scores_t = scores[indices]
220 | if return_scores:
221 | return indices, boxes_t, scores_t
222 | else:
223 | return indices, boxes_t
224 |
225 |
226 | def get_low_score_bb_single(scores, mode='mean', thresh=0.4, hard_thresh=None):
227 | # get indices of low-scoring bounding boxes.
228 | # boxes: [n_bb, 4], scores: [n_boxes]
229 | if hard_thresh is None:
230 | if mode == 'mean':
231 | mean_score = scores.mean()
232 | # indices = (scores > mean_score)
233 | indices = torch.nonzero(scores < thresh, as_tuple=True)[0].data.cpu().numpy()
234 | else:
235 |             normalized_scores = (scores - scores.min()) / (scores.max() - scores.min())
236 |             # indices = (normalized_scores > thresh)
237 |             indices = torch.nonzero(normalized_scores < thresh, as_tuple=True)[0].data.cpu().numpy()
238 | else:
239 | # indices = (scores > hard_thresh)
240 | indices = torch.nonzero(scores < hard_thresh, as_tuple=True)[0].data.cpu().numpy()
241 | return indices
242 |
243 |
244 | def plot_bb_on_image_from_masks_nms(masks, image_tensor, scores, iou_thresh=0.5, thickness=1, hard_thresh=None):
245 | # plot bounding boxes on a single image, use non-maximal suppression to filter low-scoring bbs.
246 | # masks: [n_masks, 1, feature_dim, feature_dim]
247 | n_masks = masks.shape[0]
248 | mask_h, mask_w = masks.shape[2], masks.shape[3]
249 | height, width = image_tensor.size(1), image_tensor.size(2)
250 | img = transforms.ToPILImage()(image_tensor.cpu())
251 | img = np.array(img)
252 | cmap = color_map()
253 | cm = cmap[:n_masks].astype(int)
254 | count = 0
255 | # get bb coor
256 | coors = get_bb_from_masks(masks, width, height) # [n_masks, 4]
257 | # remove low-score bb
258 | low_score_ind = get_low_score_bb_single(scores, mode='mean', hard_thresh=hard_thresh)
259 | # nms
260 | coors_nms, nms_indices, scores_nms = nms_single(coors, scores, iou_thresh, return_scores=True,
261 | remove_ind=low_score_ind)
262 | # [n_masks_nms, 4]
263 | for coor, color in zip(coors_nms, cm):
264 | c = color.item(0), color.item(1), color.item(2)
265 | ws = (coor[0] - thickness).clamp(0, width)
266 | hs = (coor[1] - thickness).clamp(0, height)
267 | wt = (coor[2] + thickness).clamp(0, width)
268 | ht = (coor[3] + thickness).clamp(0, height)
269 | bb_s = (int(ws), int(hs))
270 | bb_t = (int(wt), int(ht))
271 | cv2.rectangle(img, bb_s, bb_t, c, thickness, 1)
272 | score_text = f'{scores_nms[count]:.2f}'
273 | font = cv2.FONT_HERSHEY_SIMPLEX
274 | fontScale = 0.3
275 | font_thickness = 1  # separate name so the bb thickness argument is not overwritten inside the loop
276 | box_w = bb_t[0] - bb_s[0]
277 | box_h = bb_t[1] - bb_s[1]
278 | org = (int(bb_s[0] + box_w / 2), int(bb_s[1] + box_h / 2))
279 | cv2.putText(img, score_text, org, font, fontScale, thickness=font_thickness, color=c, lineType=cv2.LINE_AA)
280 | count += 1
281 |
282 | return img, nms_indices
283 |
284 |
285 | def plot_bb_on_image_batch_from_masks_nms(mask_batch_tensor, img_batch_tensor, scores, iou_thresh=0.5, thickness=1,
286 | max_imgs=8, hard_thresh=None):
287 | # plot bounding boxes on a batch of images, use non-maximal suppression to filter low-scoring bbs.
288 | # mask_batch_tensor: [batch_size, n_kp, 1, feature_dim, feature_dim]
289 | num_plot = min(max_imgs, img_batch_tensor.shape[0])
290 | img_with_bb = []
291 | indices = []
292 | for i in range(num_plot):
293 | img_np, nms_indices = plot_bb_on_image_from_masks_nms(mask_batch_tensor[i], img_batch_tensor[i], scores[i],
294 | iou_thresh, thickness=thickness, hard_thresh=hard_thresh)
295 | img_tensor = torch.tensor(img_np).float() / 255.0
296 | img_with_bb.append(img_tensor.permute(2, 0, 1))
297 | indices.append(nms_indices)
298 | img_with_bb = torch.stack(img_with_bb, dim=0)
299 | return img_with_bb, indices
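# Usage sketch (shapes are illustrative): overlay NMS-filtered boxes on up to 8 images of a batch, e.g. for
# qualitative logging during training; the result is a [num_plot, 3, height, width] float tensor in [0, 1]:
#   imgs_bb, keep_idx = plot_bb_on_image_batch_from_masks_nms(masks, images, scores, iou_thresh=0.5, max_imgs=8)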
300 |
301 |
302 | def plot_bb_on_image_from_masks(masks, image_tensor, thickness=1):
303 | # vanilla plotting of bbs from masks.
304 | # masks: [n_masks, 1, feature_dim, feature_dim]
305 | n_masks = masks.shape[0]
306 | mask_h, mask_w = masks.shape[2], masks.shape[3]
307 | height, width = image_tensor.size(1), image_tensor.size(2)
308 |
309 | img = transforms.ToPILImage()(image_tensor.cpu())
310 |
311 | img = np.array(img)
312 | cmap = color_map()
313 | cm = cmap[:n_masks].astype(int)
314 | count = 0
315 | for mask, color in zip(masks, cm):
316 | c = color.item(0), color.item(1), color.item(2)
317 | mask = mask.int().squeeze() # [feature_dim, feature_dim]
318 | # print(mask.shape)
319 | indices = (mask == 1).nonzero(as_tuple=False)
320 | # print(indices.shape)
321 | if indices.shape[0] > 0:
322 | ws = (indices[0][1] * (width / mask_w) - thickness).clamp(0, width).int()
323 | wt = (indices[-1][1] * (width / mask_w) + thickness).clamp(0, width).int()
324 | hs = (indices[0][0] * (height / mask_h) - thickness).clamp(0, height).int()
325 | ht = (indices[-1][0] * (height / mask_h) + thickness).clamp(0, height).int()
326 | bb_s = (int(ws), int(hs))
327 | bb_t = (int(wt), int(ht))
328 | cv2.rectangle(img, bb_s, bb_t, c, thickness, 1)
329 | count += 1
330 | return img
331 |
332 |
333 | def plot_bb_on_image_batch_from_masks(mask_batch_tensor, img_batch_tensor, thickness=1, max_imgs=8):
334 | # vanilla plotting of bbs from a batch of masks.
335 | # mask_batch_tensor: [batch_size, n_kp, 1, feature_dim, feature_dim]
336 | num_plot = min(max_imgs, img_batch_tensor.shape[0])
337 | img_with_bb = []
338 | for i in range(num_plot):
339 | img_np = plot_bb_on_image_from_masks(mask_batch_tensor[i], img_batch_tensor[i], thickness=thickness)
340 | img_tensor = torch.tensor(img_np).float() / 255.0
341 | img_with_bb.append(img_tensor.permute(2, 0, 1))
342 | img_with_bb = torch.stack(img_with_bb, dim=0)
343 | return img_with_bb
344 |
345 |
346 | def prepare_logdir(runname, src_dir='./'):
347 | """
348 | Prepare the log directory in which checkpoints, plots and stats will be saved.
349 | """
350 | td_prefix = datetime.datetime.now().strftime("%d%m%y_%H%M%S")
351 | dir_name = f'{td_prefix}_{runname}'
352 | path_to_dir = os.path.join(src_dir, dir_name)
353 | os.makedirs(path_to_dir, exist_ok=True)
354 | path_to_fig_dir = os.path.join(path_to_dir, 'figures')
355 | os.makedirs(path_to_fig_dir, exist_ok=True)
356 | path_to_save_dir = os.path.join(path_to_dir, 'saves')
357 | os.makedirs(path_to_save_dir, exist_ok=True)
358 | return path_to_dir
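# Example (run name is illustrative): prepare_logdir('dlp_celeba', src_dir='./') creates and returns a directory
# named '<ddmmyy_HHMMSS>_dlp_celeba' containing empty 'figures/' and 'saves/' subdirectories.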
359 |
360 |
361 | def save_config(src_dir, hparams):
362 | # saves the hyperparameters of a single run.
363 | path_to_conf = os.path.join(src_dir, 'hparams.json')
364 | with open(path_to_conf, "w") as outfile:
365 | json.dump(hparams, outfile, indent=2)
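# Usage sketch (the hparams keys are illustrative, not the project's actual config):
#   save_config(run_dir, {'lr': 2e-4, 'batch_size': 32, 'n_kp': 30})  # writes run_dir/hparams.json with indent=2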
366 |
367 |
368 | def log_line(src_dir, line):
369 | log_file = os.path.join(src_dir, 'log.txt')
370 | with open(log_file, 'a') as fp:
371 | fp.write(line)
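# Note: the text is appended as-is (no newline is added), so callers should include it themselves, e.g.:
#   log_line(run_dir, f'epoch: {epoch}, train loss: {loss:.4f}\n')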
372 |
--------------------------------------------------------------------------------